Welcome to my research log 🙂👋

I’m an AI researcher at DeepMind, and I checkpoint my everyday learnings here. I strongly believe in Feynman’s way of learning by teaching others, and in writing simple, small-scale code snippets to understand stuff. This blog is heavily inspired by the awesome Lil’Log, which you should definitely check out :)

Shard data, shard model, shard sequence ... shard everything!!!

Introduction

To scale up model training and inference across multiple TPUs/GPUs, you need to know how to shard well, and this post is my attempt to build intuition for how to do it!!

Basics of Neural Network Training
Collectives
Profiling code in JAX
Data Parallel: discuss the DDP paper
Model Parallel: Megatron sharding basics
Optimizer State Sharding
FSDP
Sequence Sharding
Efficient Checkpointing of Sharded Data
Sharding Design Patterns

January 10, 2025