Introduction

To scale model training or inference across multiple TPUs or GPUs, you need to know how to shard well. This post is my attempt to build intuition for how to do it!

Basics of Neural Network Training

Collectives

Profiling code in JAX

Data Parallel

  • discuss DDP paper

Model Parallel

  • Megatron sharding basics

Optimizer State Sharding

FSDP

Sequence Sharding

Efficient Checkpointing of Sharded Data

Sharding Design Patterns