8  Training at Scale: Stability and Distributed Parallelism

A frontier model does not fit on one accelerator, and the run that produces it spans thousands of them for weeks. This chapter owns the systems that make that run fast, numerically stable, and recoverable: the four orthogonal axes of parallelism and how they compose, the precision formats that buy throughput without breaking the math, the collectives that move gradients and the overlap that hides them, and the fault tolerance that lets a run survive its own hardware. By the end a reader can explain how a device mesh is laid out for a given model shape and interconnect, why each parallelism axis lives where it does, and why MFU is the one number that says whether the systems work is done.

8.1 Problem

Two facts force everything in this chapter. A frontier model’s parameters, gradients, and optimizer state do not fit in one device’s memory, and the arithmetic to train it does not finish on one device in any acceptable time. So the work is split across thousands of accelerators, and the moment it is split, three constraints start to fight.

The first is memory: weights, the Adam moments whose state dominates (see Chapter 3), activations, and gradients must each find a home, and peak activation memory grows with sequence length. The second is communication: every split introduces traffic to keep the shards consistent, and that traffic competes with compute for the same time budget. The third is failure: at thousands of nodes for weeks, a dead GPU or NIC is the steady state, not the exception, and a single one kills the whole collective. The job of the training-systems layer is to spend memory, communication, and recovery against each other so that achieved throughput stays high and a failure costs minutes, not days.

8.2 Design

The core idea is that no single way of splitting the work suffices, so the design offers four orthogonal axes, each cutting a different resource, and composes them. Around the split sit two more concerns the split creates: which numerical precision the arithmetic runs in, and how the resulting communication is hidden under compute.

8.2.1 Four axes of parallelism

Each axis answers a different scaling problem.

Data parallelism (DP) replicates the model and shards the batch across replicas; each replica computes gradients on its slice and an all-reduce averages them. Plain DP does not save memory, since every replica holds a full copy. ZeRO and its PyTorch sibling FSDP fix that by additionally sharding optimizer state, gradients, and parameters across the DP group, so per-device memory drops without changing the math: a layer’s parameters are all-gathered just before they are needed and freed right after. This is the sharding Chapter 3 points to when it notes that the optimizer state is the dominant memory term.
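The memory arithmetic behind the ZeRO stages fits in a few lines. This minimal sketch follows the accounting in the ZeRO paper for mixed-precision Adam: 2 bytes per parameter for the 16-bit working weights, 2 for 16-bit gradients, and K = 12 for the fp32 master weights plus the two Adam moments. It ignores activations, communication buffers, and fragmentation. The function name and the 7B-parameter, DP=64 example are illustrative.

```python
def model_state_bytes(num_params: float, dp: int, stage: int, k: int = 12) -> float:
    """Bytes of weights + gradients + optimizer state held by one DP rank.

    stage 0: plain DP, everything replicated
    stage 1: optimizer state sharded across the dp ranks
    stage 2: optimizer state and gradients sharded
    stage 3: optimizer state, gradients, and parameters sharded (FSDP full shard)
    """
    weights = 2.0 * num_params   # 16-bit (bf16/fp16) working weights
    grads = 2.0 * num_params     # 16-bit gradients
    optim = k * num_params       # fp32 master weights + Adam first and second moments
    if stage >= 1:
        optim /= dp
    if stage >= 2:
        grads /= dp
    if stage >= 3:
        weights /= dp
    return weights + grads + optim


GIB = 2**30
for stage in range(4):
    gib = model_state_bytes(7e9, dp=64, stage=stage) / GIB
    print(f"7B params, DP=64, ZeRO stage {stage}: {gib:6.1f} GiB of model state per device")
```

Stage 3's saving is paid for in communication: every layer's parameters must be all-gathered again in both the forward and backward passes.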

Tensor parallelism (TP) shards individual matmuls inside a layer, splitting attention by heads and the feed-forward weight matrices by columns and rows across devices. It is communication-heavy: each TP layer needs an all-reduce (or a reduce-scatter and all-gather pair) on the critical path of every forward and backward step. That cost is only affordable over the fastest interconnect, so TP stays inside a node.
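The reason a Megatron-style TP feed-forward block needs only one all-reduce per pass can be checked numerically. The sketch below, in plain Python with no framework, splits the first matrix A by columns and the second matrix B by rows across simulated ranks. Because GeLU is elementwise, each rank applies it to its own slice without talking to the others, and a single sum (standing in for the all-reduce) recovers the unsharded output. The sizes, helper names, and the tanh form of GeLU are illustrative choices.

```python
import math
import random

def matmul(a, b):
    """Matrix product of row-major lists of lists."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def gelu(x):
    """Tanh approximation of GeLU (elementwise)."""
    return 0.5 * x * (1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x**3)))

def mlp_reference(x, a, b):
    """Unsharded feed-forward block: GeLU(x @ A) @ B."""
    return matmul([[gelu(h) for h in row] for row in matmul(x, a)], b)

def mlp_tensor_parallel(x, a, b, tp):
    """Simulate a TP feed-forward block on `tp` ranks.

    Rank k holds a column slice of A and the matching row slice of B. GeLU is
    elementwise, so each rank applies it locally; the only communication is
    the final sum of partial outputs (the all-reduce)."""
    shard = len(a[0]) // tp
    partials = []
    for k in range(tp):
        lo, hi = k * shard, (k + 1) * shard
        a_cols = [row[lo:hi] for row in a]   # A[:, lo:hi]
        b_rows = b[lo:hi]                    # B[lo:hi, :]
        h = [[gelu(v) for v in row] for row in matmul(x, a_cols)]
        partials.append(matmul(h, b_rows))   # this rank's partial sum of the output
    rows, cols = len(partials[0]), len(partials[0][0])
    return [[sum(p[i][j] for p in partials) for j in range(cols)] for i in range(rows)]

def max_abs_diff(m1, m2):
    return max(abs(u - w) for r1, r2 in zip(m1, m2) for u, w in zip(r1, r2))

rng = random.Random(0)
tokens, d_model, d_ff = 3, 4, 8
x = [[rng.uniform(-1, 1) for _ in range(d_model)] for _ in range(tokens)]
a = [[rng.uniform(-1, 1) for _ in range(d_ff)] for _ in range(d_model)]
b = [[rng.uniform(-1, 1) for _ in range(d_model)] for _ in range(d_ff)]
reference = mlp_reference(x, a, b)
for tp in (2, 4):
    err = max_abs_diff(mlp_tensor_parallel(x, a, b, tp), reference)
    print(f"tp={tp}: max |TP - reference| = {err:.1e}")
```

Splitting A by rows instead would leave each rank with partial sums before the nonlinearity, and since GeLU of a sum is not the sum of GeLUs, the ranks would have to synchronize in the middle of the block. The column-then-row order defers the synchronization to the end.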

Pipeline parallelism (PP) splits the layer stack into stages placed on different nodes and streams micro-batches through them, so stage two works on micro-batch one while stage one works on micro-batch two. It is bandwidth-cheap, moving only the activations (and, in the backward pass, their gradients) across stage boundaries, but it pays a pipeline bubble: the fill and drain at the ends of each step, when some stages sit idle.
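The size of the bubble follows from a simple idealized model, the one used in the PTD-P analysis: every stage takes the same time per micro-batch and stage-to-stage sends are free. This sketch expresses the bubble as a share of the whole step; the function name and example sizes are illustrative.

```python
def bubble_fraction(stages: int, micro_batches: int, interleave: int = 1) -> float:
    """Share of a training step lost to pipeline fill and drain.

    Idealized: uniform per-stage time and free communication. GPipe and 1F1B
    schedules idle for (p - 1) micro-batch slots on top of m useful ones; an
    interleaved schedule with v model chunks per device cuts the idle time by v.
    """
    p, m, v = stages, micro_batches, interleave
    idle = (p - 1) / v
    return idle / (m + idle)


for m in (4, 16, 64):
    print(f"p=8, m={m:3d}: plain {bubble_fraction(8, m):.1%}, "
          f"interleaved v=2 {bubble_fraction(8, m, interleave=2):.1%}")
```

Both mitigations have a price: more micro-batches means a larger global batch or smaller micro-batches, and the interleaved schedule multiplies stage-boundary communication by v.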

Sequence and context parallelism (SP/CP) shard along the sequence dimension. SP, as introduced by Korthikanti et al., splits the parts of a layer that TP leaves replicated (the norms and dropouts) along the sequence to cut activation memory; CP, of which ring attention is the canonical form, shards the attention computation itself so a context too long for one device can be processed in blocks that pass key and value tiles around a ring.
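The blockwise idea behind ring attention can be simulated on one machine. In this minimal sketch, assuming single-head, unmasked attention (a causal LLM would also mask future blocks), each simulated device keeps its query block while key/value blocks move one hop around the ring per step, and an online softmax (running max, running denominator, running weighted sum) merges each arriving block. The result matches full attention without any device ever holding scores for the whole sequence. All names and sizes are illustrative; the loop over ranks runs sequentially here, whereas real devices run concurrently and overlap the send with compute.

```python
import math
import random

def scores(qi, keys):
    """Scaled dot-product scores of one query against a list of keys."""
    scale = 1.0 / math.sqrt(len(qi))
    return [scale * sum(x * y for x, y in zip(qi, kj)) for kj in keys]

def full_attention(q, k, v):
    """Reference: single-head, unmasked softmax attention over the whole sequence."""
    out = []
    for qi in q:
        s = scores(qi, k)
        m = max(s)
        w = [math.exp(x - m) for x in s]
        z = sum(w)
        out.append([sum(wj * vj[d] for wj, vj in zip(w, v)) / z for d in range(len(v[0]))])
    return out

def ring_attention(q_blocks, k_blocks, v_blocks):
    """Device r keeps query block r; at step t it holds KV block (r - t) mod n."""
    n, dim = len(q_blocks), len(v_blocks[0][0])
    # Per query: [running max, running softmax denominator, running weighted sum of values].
    state = [[[-math.inf, 0.0, [0.0] * dim] for _ in qb] for qb in q_blocks]
    for step in range(n):
        for rank in range(n):
            src = (rank - step) % n
            keys, values = k_blocks[src], v_blocks[src]
            for qi, st in zip(q_blocks[rank], state[rank]):
                s = scores(qi, keys)
                new_max = max(st[0], max(s))
                rescale = math.exp(st[0] - new_max)   # shrink old sums to the new max
                w = [math.exp(x - new_max) for x in s]
                st[1] = st[1] * rescale + sum(w)
                st[2] = [acc * rescale + sum(wj * vj[d] for wj, vj in zip(w, values))
                         for d, acc in enumerate(st[2])]
                st[0] = new_max
    return [[x / st[1] for x in st[2]] for dev in state for st in dev]

def random_matrix(rows, cols, rng):
    return [[rng.gauss(0.0, 1.0) for _ in range(cols)] for _ in range(rows)]

def split_blocks(m, num_blocks):
    size = len(m) // num_blocks
    return [m[i * size:(i + 1) * size] for i in range(num_blocks)]

rng = random.Random(0)
seq_len, head_dim, devices = 8, 4, 4
q, k, v = (random_matrix(seq_len, head_dim, rng) for _ in range(3))
reference = full_attention(q, k, v)
ringed = ring_attention(split_blocks(q, devices), split_blocks(k, devices), split_blocks(v, devices))
max_err = max(abs(x - y) for rx, ry in zip(reference, ringed) for x, y in zip(rx, ry))
print(f"max |ring - full| = {max_err:.2e}")
```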

For each axis the design question is the same: what does it shard, what does it cost in communication, and which tier of the interconnect can it afford to use.

8.2.2 Expert parallelism

Mixture-of-experts layers add a fifth axis. The experts are sharded across devices, and each token is dispatched to its chosen experts by an all-to-all collective, then the results are combined back by a second all-to-all. The systems layer owns only this execution: the two all-to-all exchanges per MoE layer, the load-imbalance stragglers when a hot expert draws more than its share of tokens, the capacity padding that bounds per-expert work, and the overlap of dispatch and combine with expert compute. The routing decision and the load-balancing loss that create this traffic pattern belong to Chapter 7.
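The bookkeeping the systems layer owns can be sketched for one source device. This minimal sketch assumes top-1 routing (the routing itself is taken as given, since it belongs to Chapter 7), a per-expert capacity of ceil(capacity_factor × tokens / experts), experts laid out in contiguous blocks per device, and the common convention that overflow tokens skip the expert. The function name and the example routing are hypothetical.

```python
import math
from collections import Counter

def plan_dispatch(expert_of_token, num_experts, experts_per_device, capacity_factor=1.0):
    """Top-1 dispatch with a fixed per-expert capacity, seen from one source device.

    Returns (capacity, dropped, send_counts, padded_slots):
      capacity:     max tokens any expert accepts this step
      dropped:      token ids that overflowed an already-full expert
      send_counts:  tokens sent to each destination device in the dispatch all-to-all
                    (the combine all-to-all returns the same number of results)
      padded_slots: empty slots left in the fixed-size expert buffers
    """
    capacity = math.ceil(capacity_factor * len(expert_of_token) / num_experts)
    load = Counter()
    dropped = []
    for token, expert in enumerate(expert_of_token):
        if load[expert] < capacity:
            load[expert] += 1
        else:
            dropped.append(token)  # overflow: the token bypasses the expert
    send_counts = Counter()
    for expert, n in load.items():
        send_counts[expert // experts_per_device] += n
    padded_slots = capacity * num_experts - sum(load.values())
    return capacity, dropped, dict(send_counts), padded_slots


# Hypothetical routing in which expert 0 is "hot" and draws half the tokens.
routing = [0, 0, 0, 1, 2, 3, 3, 0]
capacity, dropped, send_counts, padded = plan_dispatch(routing, num_experts=4, experts_per_device=2)
print(capacity, dropped, send_counts, padded)
```

The capacity factor is the knob between the two failure modes in the text: raise it and fewer tokens drop but more padding is carried; lower it and the hot expert sheds tokens instead of becoming a straggler.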

8.2.3 Precision

The arithmetic does not need to run in fp32. Mixed precision keeps the forward and backward matmuls and the activations in bf16, while holding an fp32 master copy of the weights and fp32 optimizer state so that small updates do not vanish when added to a large weight. bf16’s wide exponent is what makes this safe: it retired the loss scaling that fp16 needed to keep gradients off the floor of its narrow range.
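Why the fp32 master copy matters can be shown by emulating bf16 rounding with the standard library. Near 1.0, bf16 values are spaced 2**-7 apart, so an update of 1e-3 is less than half a step and rounds away entirely, while the same updates accumulated in fp32 add up. This is a minimal sketch: the rounding helpers are written here for illustration, ignore NaN and overflow, and the constant update stands in for learning rate times gradient.

```python
import struct

def to_fp32(x: float) -> float:
    """Round a Python float (fp64) to the nearest fp32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]

def to_bf16(x: float) -> float:
    """Round to the nearest bf16 value (ties to even) by keeping the top 16 bits
    of the fp32 encoding. Sketch only: NaN and overflow are not handled."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits = (bits + 0x7FFF + ((bits >> 16) & 1)) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]

update = 1e-3      # hypothetical constant per-step update (lr * grad)
w_bf16 = 1.0       # weight stored and updated only in bf16
w_master = 1.0     # fp32 master copy, as in mixed-precision training
for _ in range(1000):
    w_bf16 = to_bf16(w_bf16 + update)               # 1.001 rounds back to 1.0 every step
    w_master = to_fp32(w_master + to_fp32(update))  # small updates accumulate in fp32
print("bf16-only weight:", w_bf16)
print("fp32 master weight:", w_master, "-> cast to bf16 for compute:", to_bf16(w_master))
```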

fp8 (the E4M3 and E5M2 formats) pushes the matmuls further on Hopper-class and newer hardware, roughly doubling peak matmul throughput over bf16 and halving the memory the operands occupy. The price is range and precision: with only four or five exponent bits and three or two mantissa bits, the operands must be rescaled per tensor, or more finely, to keep values inside the representable band. It is applied selectively: the matmuls run in fp8, while the numerically sensitive operations (the normalizations, residual adds, and attention softmax) stay in higher precision. Pushed too far, fp8 does not crash; it quietly costs quality, which is what makes it hard to use well.
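Per-tensor scaling can be illustrated by emulating E4M3 rounding in plain Python. This minimal sketch assumes the E4M3 variant from the FP8 formats paper (3 mantissa bits, largest finite value 448, smallest normal 2**-6, subnormal spacing 2**-9) and one simple scaling rule, mapping the tensor's absolute max onto 448. It models values, not bit encodings, and the gradient-like inputs are made up. Without scaling, every value sits below half the smallest subnormal and flushes to zero; with scaling, each lands in the normal range with at most 2**-4 relative error.

```python
import math

E4M3_MAX = 448.0           # largest finite E4M3 magnitude
E4M3_MIN_NORMAL = 2.0**-6  # smallest normal; subnormals are spaced 2**-9 apart
MANTISSA_BITS = 3

def quantize_e4m3(x: float) -> float:
    """Round x to the nearest E4M3 value (ties to even), saturating at +/-448.
    Sketch only: no NaN handling and no bit-level encoding."""
    if x == 0.0:
        return 0.0
    mag = abs(x)
    if mag >= E4M3_MIN_NORMAL:
        _, e = math.frexp(mag)                 # mag = f * 2**e with 0.5 <= f < 1
        step = 2.0 ** (e - 1 - MANTISSA_BITS)  # value spacing inside this binade
    else:
        step = E4M3_MIN_NORMAL * 2.0 ** -MANTISSA_BITS  # subnormal spacing, 2**-9
    q = round(mag / step) * step
    return math.copysign(min(q, E4M3_MAX), x)

def fp8_roundtrip(values, per_tensor_scale=True):
    """Quantize to E4M3 and back. With per_tensor_scale, the tensor's absolute max
    is mapped onto E4M3_MAX before quantizing and the scale is divided out after."""
    amax = max(abs(v) for v in values)
    scale = E4M3_MAX / amax if per_tensor_scale and amax > 0 else 1.0
    return [quantize_e4m3(v * scale) / scale for v in values]

grads = [1e-4 * (1 + i / 10) for i in range(40)]  # hypothetical small gradient values
unscaled = fp8_roundtrip(grads, per_tensor_scale=False)
scaled = fp8_roundtrip(grads, per_tensor_scale=True)
print("nonzero without scaling:", sum(v != 0.0 for v in unscaled))
print("max relative error with scaling:", max(abs(s - g) / g for s, g in zip(scaled, grads)))
```

A single outlier inflates amax and pushes everything else toward the subnormal range, which is why finer-grained (per-block or per-channel) scaling is one of the knobs in the fp8 debate.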

8.2.4 Collectives and overlap

The traffic every axis generates runs through a small set of collectives: all-reduce, reduce-scatter, all-gather, and all-to-all. NCCL (RCCL on AMD) implements them and maps each onto the topology, choosing a ring or a tree algorithm and routing over NVLink within a node or the network between nodes.
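The ring algorithm can be simulated directly, which also shows how an all-reduce decomposes into a reduce-scatter followed by an all-gather. This is a minimal single-process sketch, assuming the buffer splits into equal chunks; it illustrates the algorithm, not NCCL's implementation. Each rank sends 2(N-1)/N of the buffer, nearly independent of N, which is why the ring is bandwidth-optimal.

```python
def ring_all_reduce(buffers):
    """Simulate a ring all-reduce across len(buffers) ranks.

    Phase 1 (reduce-scatter): after n-1 steps, rank r holds the fully summed
    chunk (r + 1) % n. Phase 2 (all-gather): n-1 more steps circulate those
    summed chunks until every rank has all of them. Returns the final per-rank
    buffers and the number of elements each rank sent.
    """
    n, size = len(buffers), len(buffers[0])
    assert size % n == 0, "sketch assumes the buffer splits into n equal chunks"
    c = size // n
    data = [list(b) for b in buffers]

    def chunk(i):
        return slice(i * c, (i + 1) * c)

    sent = 0
    for step in range(n - 1):  # reduce-scatter
        # Every rank sends at once, so snapshot all outgoing chunks first.
        msgs = [(r, (r - step) % n, data[r][chunk((r - step) % n)]) for r in range(n)]
        for r, idx, payload in msgs:
            dst = data[(r + 1) % n]
            dst[chunk(idx)] = [x + y for x, y in zip(dst[chunk(idx)], payload)]
        sent += c
    for step in range(n - 1):  # all-gather
        msgs = [(r, (r + 1 - step) % n, data[r][chunk((r + 1 - step) % n)]) for r in range(n)]
        for r, idx, payload in msgs:
            data[(r + 1) % n][chunk(idx)] = payload
        sent += c
    return data, sent


ranks = 4
grads = [[float(r * 10 + i) for i in range(8)] for r in range(ranks)]
result, sent = ring_all_reduce(grads)
expected = [sum(g[i] for g in grads) for i in range(8)]
print(all(row == expected for row in result), f"each rank sent {sent} elements for an 8-element buffer")
```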

The collectives are pure overhead unless they run while the device is doing something else, so overlap is the central throughput trick. FSDP prefetches the next layer’s parameters with an all-gather while the current layer computes. The gradient reduce-scatter overlaps the backward pass. PP communication overlaps stage compute. EP all-to-all overlaps expert matmuls. Whatever cannot be hidden is exposed communication, and it shows up directly as lost throughput.
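A toy timeline makes "exposed communication" concrete for FSDP-style prefetch. This sketch assumes a one-layer prefetch depth and that compute and communication run concurrently without slowing each other, which real kernels only approximate; the function names and timings are hypothetical.

```python
def step_time(compute, gather, prefetch=True):
    """Toy forward-pass timeline under per-layer parameter all-gathers.

    compute[i]: time to run layer i; gather[i]: time to all-gather layer i's
    parameters. Without prefetch, each gather runs and then its layer runs. With
    a one-layer prefetch, layer i+1's gather runs alongside layer i's compute,
    so each slot costs whichever of the two is longer.
    """
    if not prefetch:
        return sum(compute) + sum(gather)
    total = gather[0]  # nothing to hide the first gather behind
    for i, c in enumerate(compute):
        next_gather = gather[i + 1] if i + 1 < len(gather) else 0.0
        total += max(c, next_gather)
    return total

def exposed_comm(compute, gather):
    """Communication time left on the critical path after overlap."""
    return step_time(compute, gather, prefetch=True) - sum(compute)


layers = 4
print(exposed_comm([4.0] * layers, [3.0] * layers))  # gathers hide under compute
print(exposed_comm([4.0] * layers, [6.0] * layers))  # gathers outrun compute
```

Once a gather takes longer than the layer it hides behind, every extra unit of communication lands directly on the step time.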

8.2.5 MFU, the scoreboard

Model FLOPs utilization measures whether the overlap worked. It is achieved FLOPs over peak hardware FLOPs, counting only the FLOPs the model mathematically requires. It is distinct from hardware FLOPs utilization (HFU), which also counts the redundant FLOPs of activation recomputation, so HFU exceeds MFU whenever recomputation is on. MFU is the single number that says the systems work is done, and its budget is a sum of losses: exposed communication, the pipeline bubble, recomputation, and unfused kernels each take a slice. A low MFU is diagnosed from a per-step timeline, not from the loss curve.
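A back-of-the-envelope MFU estimate uses the common approximation that a dense transformer needs about 6N FLOPs per training token (2N forward, 4N backward), ignoring the attention-score FLOPs that grow with sequence length. Full activation recomputation reruns the forward pass, so the hardware executes about 8N. The run below is hypothetical: 70B parameters, 1,024 devices at an assumed 1e15 peak FLOP/s each, and 1M tokens/s.

```python
def mfu_and_hfu(params, tokens_per_sec, num_devices, peak_flops_per_device, full_recompute=False):
    """Return (MFU, HFU) under the 6N-FLOPs-per-token approximation.

    Model FLOPs count only what training mathematically requires (~6N/token);
    hardware FLOPs also count the recomputed forward pass (~8N/token).
    """
    peak = num_devices * peak_flops_per_device
    model_flops = 6.0 * params * tokens_per_sec
    hardware_flops = (8.0 if full_recompute else 6.0) * params * tokens_per_sec
    return model_flops / peak, hardware_flops / peak


mfu, hfu = mfu_and_hfu(70e9, 1e6, 1024, 1e15, full_recompute=True)
print(f"MFU = {mfu:.1%}, HFU = {hfu:.1%}")
```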

8.3 Evolution

The toolkit arrived one axis at a time, each in response to a wall the previous generation hit.

Tensor parallelism came first. Megatron-LM (Shoeybi et al., 2019) showed how to shard the transformer’s matmuls across GPUs with only two all-reduces per layer in the forward pass (and two in the backward), making a multi-billion-parameter model trainable when it no longer fit on one device. Pipeline parallelism arrived at about the same time: GPipe (Huang et al., 2019) split the layer stack into stages and used micro-batching to keep them busy, and PipeDream (Narayanan et al., 2019) generalized the schedule to cut the bubble.

Data parallelism then learned to save memory. ZeRO (Rajbhandari et al., 2019) observed that the optimizer state, gradients, and parameters in standard DP are redundantly replicated, and sharded them across the DP group in three progressive stages; ZeRO-Infinity (2021) extended the idea to offload onto CPU and NVMe, and PyTorch FSDP (Zhao et al., 2023) made sharded DP a native, ergonomic primitive. Activation memory got its own axis next: Korthikanti et al. (2022) introduced sequence parallelism alongside selective recomputation, recomputing only the cheap-to-redo, memory-heavy operations rather than the whole layer.

By 2021 the axes were combined deliberately. Narayanan et al.’s PTD-P work mapped data, tensor, and pipeline parallelism onto a single GPU cluster as a composed 3D layout, establishing the ordering that frontier runs still follow. Long context then drove context-parallel attention: ring attention (Liu et al., 2023) shards the sequence across devices and streams key and value blocks around a ring so the context can grow nearly without bound. On the precision axis, mixed-precision training (Micikevicius et al., 2017) established the recipe of low-precision arithmetic over an fp32 master copy of the weights, originally in fp16 with loss scaling; bf16 later kept the recipe and dropped the loss scaling. The fp8 formats paper (Micikevicius et al., 2022) defined E4M3 and E5M2 for the Hopper generation that made fp8 matmuls native. Expert parallelism’s all-to-all dispatch traces to GShard (Lepikhin et al., 2020).

8.4 Trade-offs

Each axis spends a different resource, and the layout that wins is the one that hides the most communication under compute for a specific model shape on a specific interconnect.

  • Parallelism layout, and its non-portability. TP cuts per-device memory and latency but burns intra-node bandwidth, so it caps at the NVLink domain. PP is bandwidth-cheap but pays a pipeline bubble, mitigated by more micro-batches and interleaved schedules. DP and ZeRO scale the batch and shrink each device’s optimizer-state shard as the DP degree grows, but their collectives then span more devices and the global batch grows with the DP degree. The right layout is not portable: change the model size, the sequence length, or the cluster, and the optimum moves.
  • Memory versus compute. Activation checkpointing (also called gradient checkpointing) trades extra forward FLOPs in the backward pass for lower peak activation memory. Selective recomputation redoes only the cheap, memory-heavy operations. More recomputation lets you fit a bigger model or a longer sequence at the cost of MFU; because the recomputed FLOPs count toward HFU but not MFU, HFU exceeds MFU whenever recompute is on.
  • Precision versus stability. bf16 is the safe default, and fp32 master weights are non-negotiable for the optimizer. fp8 buys throughput and memory but narrows range, so it is applied with per-tensor (or finer) scaling and kept away from the sensitive operations. Pushed too far it produces silent quality loss, not a crash.
  • Checkpoint cadence. Frequent checkpoints shrink the work lost per failure but tax storage bandwidth and can stall the run. Asynchronous and sharded checkpointing is what makes a short cadence affordable, and the optimum depends on cluster mean time between failures and write cost.
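The memory-versus-compute trade can be put in numbers with the per-layer activation accounting from Korthikanti et al. (2022), which assumes 16-bit activations and 1-byte dropout masks. In this sketch s is sequence length, b micro-batch size, h hidden size, a attention heads, and t the TP degree; the 5as/h term is the attention-score, softmax, and dropout storage that grows with s², and it is exactly what selective recomputation stops storing. Treat the output as a first-order estimate; the model shape in the example is hypothetical.

```python
def activation_bytes_per_layer(s, b, h, a, t=1, sp=False, recompute="none"):
    """Per-layer activation memory in bytes, following Korthikanti et al. (2022).

    recompute: "none", "selective" (drop the s^2 attention terms), or
    "full" (keep only the 16-bit layer input, 2sbh bytes).
    """
    if recompute == "full":
        return 2 * s * b * h
    attn = 0 if recompute == "selective" else 5 * a * s / h
    if sp:
        return s * b * h / t * (34 + attn)           # TP + SP shard everything by t
    return s * b * h * (10 + 24 / t + attn / t)      # TP alone leaves 10sbh replicated


GIB = 2**30
shape = dict(s=4096, b=1, h=8192, a=64, t=8)  # hypothetical model shape
for label, kw in [("TP only", {}), ("TP + SP", {"sp": True}),
                  ("TP + SP + selective", {"sp": True, "recompute": "selective"}),
                  ("full recompute", {"recompute": "full"})]:
    print(f"{label:>20}: {activation_bytes_per_layer(**shape, **kw) / GIB:.3f} GiB per layer")
```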

What’s contested

How far fp8 can be pushed in frontier pre-training is unsettled. bf16 is the agreed safe default, and fp8 plainly buys throughput and memory on Hopper-class hardware. The disagreement is over the boundary: which operations can move to fp8, how fine the scaling must be, and how many trillions of tokens a run can sustain in fp8 before the narrowed range costs final quality. The failure mode is what makes this hard to settle empirically: fp8 pushed too far does not crash, it quietly degrades the model, so the cost is invisible until a late evaluation. Treat the fp8 boundary as a per-operation, per-recipe choice that must be validated, not a setting you can copy.

Constraint arrow

The interconnect dictates the parallelism layout. Tensor parallelism needs an all-reduce on the critical path of every layer, which is only affordable over the fastest links, so the size of the NVLink domain sets the maximum TP group; beyond it, TP traffic crosses the network and stalls. Pipeline and data parallelism, which move less data and can hide their traffic under compute, are the axes that span nodes, and expert parallelism gets its own all-to-all group. The bandwidth hierarchy of Chapter 30 is therefore upstream of every mesh decision in this chapter: the wires at a lower layer of the stack decide which split an upper layer may place on which tier of the interconnect.

8.5 Implementation

8.5.1 Composing the axes

A frontier run maps the device mesh as a product of axes, conventionally (DP × TP × PP), adding EP for MoE and layering SP on top of TP. The ordering rule of thumb follows directly from the constraint arrow: put TP inside a node under NVLink, run PP and DP across nodes where their cheaper communication can hide, and give EP its own all-to-all group. The layout is co-designed with the model shape from Chapter 6 and Chapter 7, and it shifts with model size, sequence length, and interconnect.
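The ordering rule can be made concrete by assigning global ranks to mesh coordinates with TP as the fastest-varying axis. This is one illustrative ordering, not any particular framework's (real frameworks make the order configurable), and it assumes ranks are numbered node by node. Under those assumptions every TP group is a block of consecutive ranks inside one NVLink domain, while PP and DP groups stride across nodes.

```python
from itertools import product

def build_mesh(dp, pp, tp, gpus_per_node):
    """Assign global ranks to (dp, pp, tp) coordinates with TP varying fastest.

    Assumes node 0 holds ranks 0..gpus_per_node-1, node 1 the next block, and so on.
    """
    assert gpus_per_node % tp == 0, "a TP group must not straddle nodes"
    rank_of = {coord: r for r, coord in enumerate(product(range(dp), range(pp), range(tp)))}
    tp_groups = [[rank_of[d, p, t] for t in range(tp)] for d in range(dp) for p in range(pp)]
    pp_groups = [[rank_of[d, p, t] for p in range(pp)] for d in range(dp) for t in range(tp)]
    dp_groups = [[rank_of[d, p, t] for d in range(dp)] for p in range(pp) for t in range(tp)]
    for group in tp_groups:
        assert len({r // gpus_per_node for r in group}) == 1, "TP group crosses a node"
    return tp_groups, pp_groups, dp_groups


# Hypothetical cluster: 4 nodes x 8 GPUs, laid out as DP=2, PP=2, TP=8.
tp_groups, pp_groups, dp_groups = build_mesh(dp=2, pp=2, tp=8, gpus_per_node=8)
print("TP group 0:", tp_groups[0])  # one whole node
print("PP group 0:", pp_groups[0])  # stages on different nodes
print("DP group 0:", dp_groups[0])  # replicas on different nodes
```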

```mermaid
flowchart TB
  subgraph across["Across nodes"]
    DP["DP / ZeRO-FSDP: shard batch and optimizer state"]
    PP["PP: layer-stack stages, stream micro-batches"]
    EP["EP: experts sharded, all-to-all dispatch/combine"]
  end
  subgraph node["Inside a node under NVLink"]
    TP["TP: shard matmuls, all-reduce per layer"]
    SP["SP: shard sequence dim, cut activation memory"]
  end
  DP --> TP
  PP --> TP
  TP --> SP
```

8.5.2 Frameworks

The frameworks differ mostly in which layout and which hardware each makes ergonomic. Megatron-LM and Megatron-Core carry TP, PP, SP, and EP on the NVIDIA stack. DeepSpeed packages ZeRO and offload. PyTorch FSDP is native sharded DP. JAX and XLA express sharding as GSPMD-style annotations (Xu et al., 2021) on TPU, where the JAX pjit interface drives the partitioner. In practice a frontier stack combines pieces, for example Megatron-Core for model parallelism plus a custom data and checkpoint plane.

8.5.3 Fault tolerance

At thousands of nodes for weeks, recovery is a first-class subsystem. Checkpointing cadence is set by the write cost against the work lost per failure; asynchronous and sharded checkpointing, as in CheckFreq (Mohan et al., 2021), keep the write off the critical path so a short cadence stays affordable. Elastic restart detects a dead node, fences it, and resumes from the last checkpoint, ideally onto spare capacity without a full restart. Stragglers and silent data corruption are harder than clean crashes, because the run keeps going while it is slow or subtly wrong.
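The cadence trade-off has a classic first-order answer, Young's approximation: with a checkpoint stall of C seconds and a whole-job mean time between failures of M seconds, the lost-time fraction is roughly C/T + T/(2M), minimized at T = sqrt(2CM). This sketch ignores restart time, and the 3-hour job MTBF and 60 s versus 5 s stalls are hypothetical. If node failures are independent, job MTBF shrinks roughly as node MTBF divided by node count, which is why cadence tightens as clusters grow.

```python
import math

def young_interval(stall_s, mtbf_s):
    """Young's first-order optimal checkpoint interval, sqrt(2 * C * M)."""
    return math.sqrt(2.0 * stall_s * mtbf_s)

def expected_overhead(interval_s, stall_s, mtbf_s):
    """Approximate fraction of time lost: checkpoint stalls plus, per failure,
    half an interval of redone work on average (restart time ignored)."""
    return stall_s / interval_s + interval_s / (2.0 * mtbf_s)


MTBF = 3 * 3600.0  # hypothetical whole-job mean time between failures
for label, stall in [("synchronous write", 60.0), ("async/sharded write", 5.0)]:
    t = young_interval(stall, MTBF)
    print(f"{label}: checkpoint every {t / 60:.1f} min, "
          f"~{100 * expected_overhead(t, stall, MTBF):.1f}% of time lost")
```

Shrinking the stall is what buys a short cadence: the optimum interval and the lost fraction both scale with the square root of the write cost.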

One fault-tolerance concern reaches into the data plane: a restart must resume the exact same data order, which means the sampler and shuffle state are part of the checkpoint. Without that, a resume re-feeds or skips data and breaks the mixture contract from Chapter 4, surfacing as an unexplained loss discontinuity at every restart. The storage fabric that streams the corpus and the cluster orchestration that schedules the run live in Chapter 31; the resumability contract is the piece that stays here.
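A minimal sketch of the resumability contract, assuming a single-process sampler whose entire state is a seed, an epoch, and a position within that epoch's permutation. The class is hypothetical (its `state_dict`/`load_state_dict` method names only echo the PyTorch convention), and a real data plane carries more state than this, but the contract is the same: save the sampler state with the model, and a restart continues the exact stream.

```python
import random

class ResumableSampler:
    """Deterministic shuffled sampler whose full state is (seed, epoch, position)."""

    def __init__(self, num_examples, seed):
        assert num_examples > 0
        self.num_examples, self.seed = num_examples, seed
        self.epoch, self.position = 0, 0

    def _order(self):
        # Same seed and epoch always give the same permutation (cache it in practice).
        order = list(range(self.num_examples))
        random.Random(self.seed * 1_000_003 + self.epoch).shuffle(order)
        return order

    def next_batch(self, batch_size):
        batch = []
        while len(batch) < batch_size:
            take = self._order()[self.position:self.position + batch_size - len(batch)]
            batch.extend(take)
            self.position += len(take)
            if self.position == self.num_examples:  # epoch boundary
                self.epoch, self.position = self.epoch + 1, 0
        return batch

    def state_dict(self):
        return {"seed": self.seed, "epoch": self.epoch, "position": self.position}

    def load_state_dict(self, state):
        self.seed, self.epoch, self.position = state["seed"], state["epoch"], state["position"]


uninterrupted = ResumableSampler(10, seed=7)
reference = [uninterrupted.next_batch(4) for _ in range(6)]

run = ResumableSampler(10, seed=7)
before = [run.next_batch(4) for _ in range(3)]
saved = run.state_dict()                  # written alongside the model checkpoint
restarted = ResumableSampler(10, seed=7)  # fresh process after a failure
restarted.load_state_dict(saved)
after = [restarted.next_batch(4) for _ in range(3)]
print(before + after == reference)
```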

8.5.4 Failure modes

The symptoms cluster. A low MFU is the catch-all, traced by a per-step profile to exposed communication, an oversized bubble, too much recomputation, or unfused kernels. A communication bottleneck means the network is the limiter: TP spilling past the NVLink domain, a DP all-gather that cannot hide under the backward pass, or MoE all-to-all stalling on a hot expert. An out-of-memory error means the wrong layout for the budget: activation memory at long sequence length, optimizer state under-sharded, or a fragmentation cliff that appears only at a certain micro-batch count. Without fast detection, fencing, and a recent checkpoint, a node failure rolls the run back to its last save, and the work lost in between is pure waste.

Recipe-level stability (the z-loss, QK-norm, and loss-spike recovery that keep the optimizer from diverging) belongs to Chapter 3. The stability this chapter owns is numerical, set by the precision choice, and run-level, set by fault tolerance.

8.6 Further reading

  • Shoeybi et al., “Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism,” 2019. arXiv:1909.08053
  • Rajbhandari et al., “ZeRO: Memory Optimizations Toward Training Trillion Parameter Models,” 2019. arXiv:1910.02054
  • Narayanan et al., “Efficient Large-Scale Language Model Training on GPU Clusters Using Megatron-LM,” 2021 (SC’21; PTD-P 3D parallelism). arXiv:2104.04473
  • Korthikanti et al., “Reducing Activation Recomputation in Large Transformer Models,” 2022 (sequence parallelism + selective recomputation). arXiv:2205.05198
  • Huang et al., “GPipe: Efficient Training of Giant Neural Networks using Pipeline Parallelism,” 2019 (NeurIPS). arXiv:1811.06965
  • Narayanan et al., “PipeDream: Generalized Pipeline Parallelism for DNN Training,” 2019 (SOSP). doi:10.1145/3341301.3359646
  • Zhao et al., “PyTorch FSDP: Experiences on Scaling Fully Sharded Data Parallel,” 2023 (VLDB). arXiv:2304.11277
  • Rajbhandari et al., “ZeRO-Infinity: Breaking the GPU Memory Wall for Extreme Scale Deep Learning,” 2021. arXiv:2104.07857
  • Micikevicius et al., “Mixed Precision Training,” 2017. arXiv:1710.03740
  • Micikevicius et al., “FP8 Formats for Deep Learning,” 2022. arXiv:2209.05433
  • Dao et al., “FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness,” 2022 (kernel-level IO-awareness; also in 03). arXiv:2205.14135
  • Liu et al., “Ring Attention with Blockwise Transformers for Near-Infinite Context,” 2023 (context-parallel attention). arXiv:2310.01889
  • Lepikhin et al., “GShard: Scaling Giant Models with Conditional Computation and Automatic Sharding,” 2020 (expert sharding + all-to-all; also in 04). arXiv:2006.16668
  • NCCL (NVIDIA Collective Communications Library): optimized inter-GPU collective primitives (all-reduce, all-gather, reduce-scatter, all-to-all) over NVLink/PCIe/InfiniBand. Engineering library, not a single canonical paper. NVIDIA/nccl
  • Mohan et al., “CheckFreq: Frequent, Fine-Grained DNN Checkpointing,” 2021 (USENIX FAST’21; asynchronous, low-overhead checkpointing). usenix.org/conference/fast21/presentation/mohan
  • Xu et al., “GSPMD: General and Scalable Parallelization for ML Computation Graphs,” 2021 (XLA/TPU sharding annotations underlying JAX pjit). arXiv:2105.04663