Shard data, shard model, shard sequence ... shard everything!!!
Introduction

To scale model training or inference across multiple TPUs/GPUs, you need to know how to shard well. This post is my attempt to build intuition for how to do it well!

- Basics of Neural Network Training
- Collectives
- Profiling code in JAX
- Data Parallel (discuss the DDP paper)
- Model Parallel (Megatron sharding basics)
- Optimizer State Sharding
- FSDP
- Sequence Sharding
- Efficient Checkpointing of Sharded Data
- Sharding Design Patterns