Introduction
To scale model training or inference across multiple TPUs or GPUs, you need to know how to shard well. This post is my attempt to build intuition for how to do that!
Basics of Neural Network Training
Collectives
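A minimal sketch of two common collectives, all-reduce and all-gather, assuming JAX is installed. So that it runs on a single CPU, it uses `jax.vmap` with a named axis to stand in for four devices. On real hardware the same `jax.lax.psum` and `jax.lax.all_gather` calls would run under `jax.pmap` or `shard_map`. The axis name `"devices"` and the helper `collectives` are illustrative.

```python
import jax
import jax.numpy as jnp

NUM_DEVICES = 4  # simulated devices (assumption: vmap stands in for real hardware)

def collectives(x):
    # x is this "device's" local shard, shape (2,)
    total = jax.lax.psum(x, axis_name="devices")           # all-reduce: every device gets the sum
    gathered = jax.lax.all_gather(x, axis_name="devices")  # all-gather: every device gets every shard
    return total, gathered

# One row per simulated device: [[0, 1], [2, 3], [4, 5], [6, 7]]
shards = jnp.arange(NUM_DEVICES * 2, dtype=jnp.float32).reshape(NUM_DEVICES, 2)
totals, gathered = jax.vmap(collectives, axis_name="devices")(shards)
```

Every simulated device ends up holding the same result, which is the defining property of these two collectives. A reduce-scatter would instead leave each device with only its own slice of the sum.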
Profiling code in JAX
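A common first step before reaching for a profiler is wall-clock timing, which in JAX has two pitfalls. The first call of a `jax.jit` function includes compilation, and work is dispatched asynchronously. A minimal sketch that handles both (the `time_fn` helper is illustrative). For a per-op timeline, `jax.profiler.trace(log_dir)` can capture a trace to inspect in TensorBoard or Perfetto.

```python
import time

import jax
import jax.numpy as jnp

@jax.jit
def matmul(a, b):
    return a @ b

def time_fn(fn, *args, n_iters=10):
    """Average wall-clock seconds per call, excluding compilation."""
    fn(*args).block_until_ready()  # warm-up call triggers JIT compilation outside the timed loop
    start = time.perf_counter()
    for _ in range(n_iters):
        # JAX dispatches asynchronously; without block_until_ready we would
        # only measure how long it takes to enqueue the work.
        fn(*args).block_until_ready()
    return (time.perf_counter() - start) / n_iters

a = jnp.ones((1024, 1024))
b = jnp.ones((1024, 1024))
seconds_per_call = time_fn(matmul, a, b)
```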
Data Parallel
- discuss the PyTorch DistributedDataParallel (DDP) paper
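Data parallelism rests on one idea. Every replica holds a full copy of the parameters and computes gradients on its own slice of the batch. The gradients are then averaged with an all-reduce before the update. Below is a minimal sketch of just that math, assuming JAX and a toy linear model with mean-squared-error loss. `jax.vmap` with a named axis simulates 4 replicas on one machine, and `loss_fn` and `dp_grads` are hypothetical names. DDP itself adds engineering on top, such as bucketing gradients and overlapping the all-reduce with the backward pass, which this sketch does not show.

```python
import jax
import jax.numpy as jnp

NUM_REPLICAS = 4

def loss_fn(w, x, y):
    # Mean-squared error of a linear model y_hat = x @ w.
    return jnp.mean((x @ w - y) ** 2)

def dp_grads(w, x_shard, y_shard):
    g = jax.grad(loss_fn)(w, x_shard, y_shard)     # local gradient on this replica's shard
    return jax.lax.pmean(g, axis_name="replicas")  # all-reduce (mean) across replicas

key_x, key_y, key_w = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(key_x, (32, 8))
y = jax.random.normal(key_y, (32,))
w = jax.random.normal(key_w, (8,))

# Split the global batch of 32 evenly: shapes (4, 8, 8) and (4, 8).
x_shards = x.reshape(NUM_REPLICAS, -1, 8)
y_shards = y.reshape(NUM_REPLICAS, -1)

# Every replica sees the same full parameters, hence in_axes=None for w.
replica_grads = jax.vmap(dp_grads, in_axes=(None, 0, 0), axis_name="replicas")(
    w, x_shards, y_shards
)

full_batch_grad = jax.grad(loss_fn)(w, x, y)
```

The shards have equal size and the loss is a mean, so averaging the per-replica gradients reproduces the full-batch gradient up to floating-point error. Every replica therefore applies the identical update and the parameter copies stay in sync.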
Model Parallel
- Megatron-LM sharding basics
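A minimal sketch of the Megatron-LM tensor-parallel MLP, assuming JAX and simulating 2 shards on one device. The first weight matrix `A` is split by columns, so each shard can apply GeLU to its own slice of `X @ A` without communicating. The second matrix `B` is split by rows, so each shard produces a partial output, and one all-reduce (simulated here as a plain sum) recovers `GeLU(X @ A) @ B`. Shapes and names are illustrative.

```python
import jax
import jax.numpy as jnp

NUM_SHARDS = 2
d_model, d_ff, batch = 8, 32, 4

kx, ka, kb = jax.random.split(jax.random.PRNGKey(0), 3)
X = jax.random.normal(kx, (batch, d_model))
A = jax.random.normal(ka, (d_model, d_ff))
B = jax.random.normal(kb, (d_ff, d_model))

# Reference: the unsharded MLP.
reference = jax.nn.gelu(X @ A) @ B

# Column-parallel first layer: shard i holds a block of A's columns.
A_shards = jnp.split(A, NUM_SHARDS, axis=1)
# Row-parallel second layer: shard i holds the matching block of B's rows.
B_shards = jnp.split(B, NUM_SHARDS, axis=0)

partial_outputs = [
    jax.nn.gelu(X @ A_i) @ B_i  # GeLU is elementwise, so no communication is needed here
    for A_i, B_i in zip(A_shards, B_shards)
]
# The single all-reduce in the forward pass, simulated as a sum of partial outputs.
megatron_out = sum(partial_outputs)
```

Splitting `A` by rows instead would put a partial sum inside the GeLU, which is nonlinear, forcing an extra all-reduce before the activation. The column-then-row layout avoids that.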