Pedram Agand

Large Model Checkpointing: What Actually Fails and How to Fix It

Checkpointing large models is not a solved problem. The failure modes — silent corruption, resume divergence, storage bottlenecks — are predictable and…

2026-04-06·6 min read·training, pytorch, infrastructure, checkpointing, distributed

Checkpointing is infrastructure work — it's not the part of a training run that makes papers. It's also the part that, when it fails, costs you days of compute and forces you to explain to your organization why the training job needs to restart from scratch. Understanding what actually fails in large model checkpointing, and why, is more valuable than following a recipe.

What You're Actually Serializing

A naive checkpoint saves the model weights. A correct checkpoint saves everything needed to resume a training run from the exact same state:

  • Model weights — the obvious part
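  • Optimizer state — for Adam, this includes the first and second moment estimates for each parameter. Stored in FP32, the two moments alone are 2× the size of FP32 weights and 4× the size of BF16 weights, and mixed-precision setups often keep an FP32 master copy of the weights on top of that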
  • Scheduler state — learning rate schedule position
  • RNG states — the state of every random number generator on every device. Without this, your dropout masks, data shuffling, and other stochastic operations will differ from what they would have been, causing divergence
  • Gradient scaler state — if you're using mixed precision training with automatic loss scaling
  • Data loader state — your position in the dataset, to avoid reprocessing the same examples after resume (especially important for streaming datasets)
  • Step count and epoch — to correctly position the scheduler and determine when to checkpoint next
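A minimal sketch of bundling that state into one dictionary, assuming a single-process PyTorch loop. The function names and the checkpoint layout are hypothetical, and `data_state` stands in for whatever position information your data loader can export; a plain `torch.utils.data.DataLoader` does not expose one, so that part is an assumption.

```python
import random

import torch


def capture_rng_state() -> dict:
    """Snapshot the RNGs a typical PyTorch training loop draws from."""
    state = {
        "python": random.getstate(),
        "torch_cpu": torch.get_rng_state(),
        # Add numpy.random.get_state() here if your pipeline uses NumPy randomness.
    }
    if torch.cuda.is_available():
        # One generator state per visible GPU.
        state["torch_cuda"] = torch.cuda.get_rng_state_all()
    return state


def restore_rng_state(state: dict) -> None:
    random.setstate(state["python"])
    torch.set_rng_state(state["torch_cpu"])
    if "torch_cuda" in state and torch.cuda.is_available():
        torch.cuda.set_rng_state_all(state["torch_cuda"])


def build_training_checkpoint(model, optimizer, scheduler, step, epoch,
                              data_state, scaler=None) -> dict:
    """Collect everything needed for an exact resume (hypothetical layout)."""
    return {
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),   # Adam moments live here
        "scheduler": scheduler.state_dict(),   # LR schedule position
        "scaler": scaler.state_dict() if scaler is not None else None,
        "rng": capture_rng_state(),
        "data": data_state,                    # hypothetical: loader position
        "step": step,
        "epoch": epoch,
    }


def restore_training_checkpoint(ckpt, model, optimizer, scheduler, scaler=None):
    """Load every component back; returns (step, epoch, data_state)."""
    model.load_state_dict(ckpt["model"])
    optimizer.load_state_dict(ckpt["optimizer"])
    scheduler.load_state_dict(ckpt["scheduler"])
    if scaler is not None and ckpt["scaler"] is not None:
        scaler.load_state_dict(ckpt["scaler"])
    # Restore RNG last so no later setup code advances the generators first.
    restore_rng_state(ckpt["rng"])
    return ckpt["step"], ckpt["epoch"], ckpt["data"]
```

Pass the dictionary straight to `torch.save`. Because `state_dict()` returns references to live tensors, serialize it before the next optimizer step rather than holding the dictionary in memory.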

Missing any of these produces a checkpoint that loads but does not resume correctly. The training will continue, loss will be roughly right, and you may not notice until you compare the final model against what you'd have gotten without the interruption.

The Distributed Training Case

Single-GPU checkpointing is manageable. Distributed training multiplies the complexity.

In tensor-parallel, pipeline-parallel, or fully sharded data-parallel (FSDP, ZeRO) configurations, model weights are sharded across devices. Each device holds a shard of the model, and the optimizer state for that shard lives on the same device. There is no single process with the full checkpoint.

The options:

Full checkpoint on rank 0. Each rank sends its shard to rank 0, which assembles and writes the full checkpoint. Simple to restore (just load the file), expensive to create for large models (network bandwidth, memory pressure on rank 0, blocking the entire training job while it completes).

Sharded checkpoints. Each rank writes its own shard independently. Checkpoint creation is fast and parallel. Restoration requires knowing the original parallelism configuration — restoring with a different number of GPUs requires a resharding step.

PyTorch Distributed Checkpoint (DCP). PyTorch's built-in solution (torch.distributed.checkpoint) since 2.0. Writes sharded files with metadata that describes the global layout. Restoration can handle different parallelism configurations by reading the metadata and redistributing shards. For most large-model training in PyTorch, this is the right default.
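The metadata idea behind resharding can be shown with a toy one-dimensional parameter in plain Python. This is not DCP's file format: `ShardMeta`, the even-split layout, and the in-memory "files" are all hypothetical. Each saved shard records where it sits in the global tensor, so a loader running with a different world size can work out which saved shards overlap the slice it owns.

```python
from dataclasses import dataclass


@dataclass
class ShardMeta:
    offset: int  # where this shard starts in the global (flattened) tensor
    length: int
    file: str    # hypothetical per-rank file name


def even_split(total: int, parts: int) -> list[tuple[int, int]]:
    """(offset, length) per part, spreading any remainder over the first parts."""
    base, extra = divmod(total, parts)
    ranges, start = [], 0
    for i in range(parts):
        length = base + (1 if i < extra else 0)
        ranges.append((start, length))
        start += length
    return ranges


def save_sharded(values: list[float], world_size: int):
    """Simulate each rank writing its own shard plus global-layout metadata."""
    files, metadata = {}, []
    for rank, (off, n) in enumerate(even_split(len(values), world_size)):
        name = f"shard_{rank}.bin"
        files[name] = values[off:off + n]
        metadata.append(ShardMeta(off, n, name))
    return files, metadata


def load_for_rank(files, metadata, total: int, new_world_size: int, rank: int):
    """Assemble this rank's slice under a new world size from overlapping shards."""
    start, length = even_split(total, new_world_size)[rank]
    end = start + length
    out = []
    for meta in sorted(metadata, key=lambda m: m.offset):
        lo = max(start, meta.offset)
        hi = min(end, meta.offset + meta.length)
        if lo < hi:  # this saved shard overlaps the slice we need
            out.extend(files[meta.file][lo - meta.offset:hi - meta.offset])
    return out
```

Saving from 4 ranks and loading onto 3 works because each loader only needs the global offsets, not the original rank count. DCP applies the same idea to multi-dimensional tensors and sharded optimizer state.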

The most common practical failure: teams use torch.save(model.state_dict()) on rank 0 without handling the distributed optimizer state correctly. Weights restore fine. Optimizer state is lost. Training resumes but the loss trajectory diverges from where it would have been. The divergence is often subtle enough to miss on the training curve but significant enough to affect the final model.

Silent Corruption

Silent corruption is the failure mode that costs the most, because you don't know about it until you try to resume.

How it happens: a checkpoint write is interrupted (preemption, hardware failure, storage error). The file exists but is incomplete. On resume, loading either fails with a confusing error, such as a deserialization failure or a missing-key or shape mismatch, or, worse, loads a partially valid state (for example, a sharded checkpoint where some ranks finished writing and others did not) that produces subtly wrong training behavior.

Mitigations:

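Atomic writes. Write to a temporary file in the same directory, flush and fsync it, then rename it to the final path. On POSIX filesystems, a rename within the same filesystem is atomic: readers see either the old complete checkpoint or the new complete one, never a partial file. Object stores have no atomic rename; a common equivalent is to write a small completion marker last and ignore checkpoints that lack it.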
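A sketch of the temp-file-then-rename pattern using only the Python standard library, assuming a local or POSIX-compliant filesystem. `atomic_write_bytes` is a hypothetical helper name.

```python
import os
import tempfile


def atomic_write_bytes(path: str, data: bytes) -> None:
    """Write data so readers see either the old file or the complete new one."""
    directory = os.path.dirname(os.path.abspath(path))
    # The temp file must be on the same filesystem as the target for rename to be atomic.
    fd, tmp_path = tempfile.mkstemp(dir=directory, prefix=".tmp-ckpt-")
    try:
        with os.fdopen(fd, "wb") as f:
            f.write(data)
            f.flush()
            os.fsync(f.fileno())  # bytes reach stable storage before the rename
        os.replace(tmp_path, path)  # atomic rename on POSIX
    except BaseException:
        # Remove the partial temp file; the previous checkpoint is untouched.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
    # Persist the directory entry so the rename itself survives a crash (POSIX only).
    if hasattr(os, "O_DIRECTORY"):
        dir_fd = os.open(directory, os.O_RDONLY | os.O_DIRECTORY)
        try:
            os.fsync(dir_fd)
        finally:
            os.close(dir_fd)
```

With PyTorch you would typically call `torch.save` on the temporary path instead of buffering the whole checkpoint in memory; the fsync and `os.replace` steps stay the same.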

Verification reads. After writing a checkpoint, immediately read it back and verify that the loaded tensors match the originals. Adds latency to checkpoint creation but catches corruption before you depend on the file.
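A minimal verification read for a flat dictionary of tensors such as a model's `state_dict()`, assuming single-process PyTorch and a local path. Nested structures like optimizer state would need a recursive comparison; `save_and_verify` is a hypothetical helper.

```python
import torch


def save_and_verify(state_dict: dict, path: str) -> None:
    """Save a flat dict of tensors, read it back, and compare every entry."""
    torch.save(state_dict, path)
    loaded = torch.load(path, map_location="cpu", weights_only=True)
    mismatched_keys = state_dict.keys() ^ loaded.keys()
    if mismatched_keys:
        raise RuntimeError(f"{path}: key mismatch after reload: {sorted(mismatched_keys)}")
    for name, original in state_dict.items():
        # torch.equal is True only if shapes and every element match exactly.
        if not torch.equal(original.detach().cpu(), loaded[name]):
            raise RuntimeError(f"{path}: tensor {name!r} differs after reload")
```

Running the check on the temporary file before the rename combines this with atomic writes, so a checkpoint that fails verification never replaces the last good one. A read issued immediately after the write may be served from the OS page cache, so this mainly catches serialization and truncation problems rather than media errors.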

Checkpoint rotation. Keep the last N checkpoints, not just the most recent one. If the latest is corrupted, you can resume from N-1. One checkpoint is a single point of failure.
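A sketch of rotation over checkpoint directories, assuming a hypothetical `<root>/step_<N>/` naming scheme. Sorting by the parsed step number rather than the directory name matters, since `step_900` sorts after `step_1000` as a string.

```python
import os
import re
import shutil

CKPT_PATTERN = re.compile(r"^step_(\d+)$")  # hypothetical naming: <root>/step_<N>/


def list_checkpoints(root: str) -> list[tuple[int, str]]:
    """Return (step, path) for every checkpoint directory, oldest first."""
    found = []
    for name in os.listdir(root):
        match = CKPT_PATTERN.match(name)
        if match and os.path.isdir(os.path.join(root, name)):
            found.append((int(match.group(1)), os.path.join(root, name)))
    return sorted(found)


def rotate_checkpoints(root: str, keep_last: int = 3) -> list[int]:
    """Delete all but the newest `keep_last` checkpoints; return the deleted steps."""
    if keep_last < 1:
        raise ValueError("keep_last must be at least 1")
    doomed = list_checkpoints(root)[:-keep_last]
    for _, path in doomed:
        shutil.rmtree(path)
    return [step for step, _ in doomed]
```

Call it only after the newest checkpoint has been written and verified; otherwise a failed write can leave you with fewer good checkpoints than `keep_last`.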

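Off-node storage with integrity checks. Write to object storage (S3, GCS, Azure Blob) and record checksums. These services can validate a checksum supplied at upload time and report stored checksums for later comparison, so corruption in transit or at rest is detectable rather than silent.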
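Storage-side checks cover transfer and at-rest corruption; a client-side manifest additionally lets you verify a checkpoint end to end after download. The sketch below is standard-library Python only. The `MANIFEST.json` name and layout are assumptions, and it deliberately does not call any cloud SDK.

```python
import hashlib
import json
import os

MANIFEST = "MANIFEST.json"  # hypothetical file name


def sha256_file(path: str, chunk_size: int = 1 << 20) -> str:
    """Hash a file in chunks so multi-GB shards never need to fit in memory."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def write_manifest(ckpt_dir: str) -> dict:
    """Record a checksum for every regular file in a checkpoint directory."""
    manifest = {
        name: sha256_file(os.path.join(ckpt_dir, name))
        for name in sorted(os.listdir(ckpt_dir))
        if name != MANIFEST and os.path.isfile(os.path.join(ckpt_dir, name))
    }
    with open(os.path.join(ckpt_dir, MANIFEST), "w") as f:
        json.dump(manifest, f, indent=2)
    return manifest


def verify_manifest(ckpt_dir: str) -> list[str]:
    """Return the names of files that are missing or whose checksum changed."""
    with open(os.path.join(ckpt_dir, MANIFEST)) as f:
        manifest = json.load(f)
    bad = []
    for name, expected in manifest.items():
        path = os.path.join(ckpt_dir, name)
        if not os.path.exists(path) or sha256_file(path) != expected:
            bad.append(name)
    return bad
```

Run `write_manifest` before upload and `verify_manifest` after download, before handing the files to the loader.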

Storage Bottlenecks

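For a 70B-parameter model in BF16, the weights alone are ~140GB. Adam's two FP32 moment estimates add ~560GB (70B × 2 × 4 bytes), and a mixed-precision setup that keeps an FP32 master copy of the weights adds another ~280GB. A full checkpoint is therefore roughly 700GB to 1TB. Checkpointing every 1,000 steps over a long run produces dozens to hundreds of checkpoints; kept naively, that is tens to hundreds of terabytes per run, so petabyte-scale storage across runs is not hypothetical.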
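The arithmetic behind these storage figures, assuming Adam with two FP32 moments per parameter, BF16 weights, and optionally an FP32 master copy of the weights. Scheduler, RNG, and data-loader state are small enough to ignore; `checkpoint_size_gb` is a hypothetical helper.

```python
GB = 1e9  # decimal gigabytes


def checkpoint_size_gb(n_params: float, weight_bytes: int = 2,
                       moment_bytes: int = 4, fp32_master: bool = True) -> dict:
    """Rough Adam checkpoint size; ignores scheduler/RNG/data state (a few MB)."""
    weights = n_params * weight_bytes
    moments = n_params * moment_bytes * 2  # Adam keeps first and second moments
    master = n_params * 4 if fp32_master else 0
    return {
        "weights": weights / GB,
        "adam_moments": moments / GB,
        "fp32_master": master / GB,
        "total": (weights + moments + master) / GB,
    }
```

For `checkpoint_size_gb(70e9)` this gives 140 GB of weights, 560 GB of moments, and 280 GB of master weights, for 980 GB in total; with `fp32_master=False` the total is 700 GB.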

Practical mitigations:

Match checkpoint frequency to checkpoint cost. Early in training, frequent checkpoints catch instabilities before they become expensive. Late in training, less frequent checkpoints are acceptable. Don't checkpoint at fixed intervals without weighing what each checkpoint costs in write time and storage.
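One standard way to make "cost-aware" concrete is Young's first-order approximation for the checkpoint interval, sqrt(2 × C × MTBF), where C is the time a checkpoint blocks training and MTBF is the mean time between failures. This is a swapped-in technique, not something the text prescribes, and it captures only the failure-recovery side of the trade-off, not the early-versus-late-training argument. The function names are hypothetical.

```python
import math


def young_interval_seconds(checkpoint_cost_s: float, mtbf_s: float) -> float:
    """Young's first-order approximation of the checkpoint interval that
    minimizes expected lost work: sqrt(2 * C * MTBF)."""
    return math.sqrt(2.0 * checkpoint_cost_s * mtbf_s)


def interval_in_steps(checkpoint_cost_s: float, mtbf_s: float,
                      step_time_s: float) -> int:
    """Convert the interval into a whole number of training steps (at least 1)."""
    seconds = young_interval_seconds(checkpoint_cost_s, mtbf_s)
    return max(1, round(seconds / step_time_s))
```

For example, a 72-second blocking checkpoint on a cluster with a 25-hour (90,000-second) MTBF gives a 3,600-second interval; at 1.5 seconds per step that is every 2,400 steps. Doubling C lengthens the interval only by a factor of √2.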

Only save full optimizer state at key milestones. For recovery purposes, weights, scheduler, and RNG state are enough to resume approximately: the optimizer moments restart from zero, so expect a short loss transient rather than an exact continuation. A milestone checkpoint (beginning of each epoch, end of each major phase) gets the full optimizer state. Recovery checkpoints in between save the lighter version.
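A sketch of the milestone policy with hypothetical names: epoch boundaries and an explicit set of phase-boundary steps get the full state, and every other checkpoint drops the optimizer entry. The key names (`model`, `optimizer`, `scheduler`, `rng`, and so on) are assumptions about how your checkpoint dictionary is laid out.

```python
# Hypothetical key names for the "light" recovery checkpoint.
LIGHT_KEYS = {"model", "scheduler", "scaler", "rng", "step", "epoch", "data"}


def is_milestone(step: int, steps_per_epoch: int, phase_boundaries: set[int]) -> bool:
    """Epoch starts and explicitly listed phase boundaries get full checkpoints."""
    return step % steps_per_epoch == 0 or step in phase_boundaries


def select_state(full_state: dict, milestone: bool) -> dict:
    """Keep everything at milestones; drop the large optimizer state otherwise."""
    if milestone:
        return dict(full_state)
    return {k: v for k, v in full_state.items() if k in LIGHT_KEYS}
```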

Asynchronous checkpointing. The training process copies tensors to CPU memory, then a background thread or process serializes and writes them while training continues. In PyTorch, torch.distributed.checkpoint.async_save implements this pattern for DCP checkpoints.
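The snapshot-then-write pattern can be sketched with a background thread in plain Python. This is not PyTorch's implementation: `AsyncCheckpointer` is hypothetical, `copy.deepcopy` stands in for the GPU-to-CPU copy, and `pickle` stands in for the real serializer. The two properties worth copying are that the snapshot is taken before `save()` returns, and that a new save waits for the previous write so two writes never overlap.

```python
import copy
import pickle
import threading


class AsyncCheckpointer:
    """Snapshot state synchronously, then write it in a background thread."""

    def __init__(self):
        self._thread = None
        self._error = None

    def save(self, state: dict, path: str) -> None:
        # Never run two writes at once; wait for (and check) the previous one.
        self.wait()
        # Snapshot now so the training loop can keep mutating `state` immediately.
        snapshot = copy.deepcopy(state)
        # Non-daemon thread: the interpreter will not kill it mid-write at exit.
        self._thread = threading.Thread(target=self._write, args=(snapshot, path))
        self._thread.start()

    def _write(self, snapshot: dict, path: str) -> None:
        try:
            with open(path, "wb") as f:
                pickle.dump(snapshot, f)
        except BaseException as exc:  # surfaced to the caller on the next wait()
            self._error = exc

    def wait(self) -> None:
        """Block until the in-flight write finishes; re-raise any write error."""
        if self._thread is not None:
            self._thread.join()
            self._thread = None
        if self._error is not None:
            error, self._error = self._error, None
            raise error
```

Call `wait()` before the process exits, and have the background write use the temp-file-and-rename pattern so a crash mid-write never leaves a partial file at the final path.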

Checkpoint pruning. If storage is constrained, keep the most recent few checkpoints and thin older ones to geometrically increasing steps (1000, 2000, 4000, 8000, ...) rather than a fixed linear interval. Storage then grows logarithmically with run length instead of linearly, you can still recover from any of the last few checkpoints, and earlier points remain available if you need to roll back past a late instability.
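A sketch of the geometric retention policy, assuming checkpoints are identified by step number and the base interval is 1,000 steps. Function names and defaults are hypothetical.

```python
def is_geometric(step: int, base: int) -> bool:
    """True for base, 2*base, 4*base, 8*base, ..."""
    if step < base or step % base != 0:
        return False
    ratio = step // base
    return ratio & (ratio - 1) == 0  # ratio is a power of two


def steps_to_keep(steps: list[int], base: int = 1000, keep_recent: int = 3) -> list[int]:
    """Keep the newest `keep_recent` checkpoints plus geometrically spaced older ones."""
    ordered = sorted(steps)
    recent = set(ordered[-keep_recent:]) if keep_recent > 0 else set()
    return [s for s in ordered if s in recent or is_geometric(s, base)]
```

For checkpoints at steps 1000 through 10000, this keeps 1000, 2000, 4000, and 8000 (geometric) plus 9000 and 10000 (recent); everything else can be deleted.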

Resume Verification

After resuming from a checkpoint, verify that training is on track before committing to the resumed run:

  1. Load the checkpoint and verify the step count and learning rate match expectations
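  2. Run a small number of steps (100–500) and compare the losses against those logged for the same steps before the interruption; the steps between the checkpoint and the crash were already run once, so a correct resume should replay them closely
  3. If the loss diverges within those steps beyond normal run-to-run noise, your checkpoint has an issue: the RNG state, optimizer state, or data loader state is wrong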
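Step 2 can be automated by comparing replayed losses with the losses logged for the same steps before the interruption. When data-loader and RNG state are restored, the resumed run sees the same batches, so a per-step comparison is meaningful. The 2% relative tolerance is an assumption meant to absorb nondeterministic GPU kernels and should be tuned to your observed run-to-run noise; `check_resume` is a hypothetical helper.

```python
def check_resume(logged: dict[int, float], replayed: dict[int, float],
                 rel_tol: float = 0.02) -> list[int]:
    """Return the steps where the replayed loss deviates from the logged loss.

    `logged` holds losses recorded before the interruption; `replayed` holds
    losses from the resumed run for the same steps.
    """
    common = sorted(logged.keys() & replayed.keys())
    if not common:
        raise ValueError("no overlapping steps to compare")
    return [s for s in common
            if abs(replayed[s] - logged[s]) > rel_tol * abs(logged[s])]
```

An empty list means the resume tracks the original run; a non-empty list points to the first step where RNG, optimizer, or data-loader state went wrong.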

This verification step is cheap relative to the cost of discovering the problem 10,000 steps into a resumed run.

The Takeaway

Checkpointing failures have a characteristic feel: the training resumes without error, the loss curve looks approximately right, and you only discover the problem when the model behaves differently than expected in evaluation. By then, you've burned compute and the checkpoint is long overwritten.

The preventive investment — atomic writes, complete state serialization, rotation, and resume verification — is measurably cheaper than the cost of a single silent failure in a long training run. Treat your checkpointing infrastructure with the same rigor you'd apply to your training loop, because a training job is only as reproducible as its last valid checkpoint.
