Inferensys


Gradient Checkpointing

Gradient checkpointing is a memory optimization technique for neural network training that trades increased computation for reduced memory usage by selectively discarding intermediate activations during the forward pass and recomputing them as needed during the backward pass.
MEMORY OPTIMIZATION

What is Gradient Checkpointing?

Gradient checkpointing is a memory-for-compute trade-off technique used to train deep neural networks that would otherwise exceed GPU memory limits.

Gradient checkpointing is a memory optimization technique that enables the training of larger neural networks by selectively discarding intermediate activations during the forward pass and recomputing them on demand during the backward pass. With checkpoints placed roughly every √n layers, this trade-off reduces peak activation memory from O(n) to O(√n) in the number of layers n. The cost is roughly one additional forward pass per training step, because each discarded segment is recomputed once. It is a critical tool for large language model (LLM) training and parameter-efficient fine-tuning, where memory is the primary constraint.
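
To see where the √n comes from, here is a rough counting sketch. It assumes (our simplification, not a framework measurement) that every layer output has the same size, that one boundary activation is kept per segment, and that the backward pass rebuilds one segment's activations at a time. Peak storage is then about k + n/k activations for k segments, which is smallest near k = √n.

```python
import math


def peak_stored_activations(n_layers: int, n_segments: int) -> int:
    """Activations held at the memory peak under a simplified model.

    Assumption: equal-sized layer outputs; one stored boundary per segment,
    plus the interior of the single segment currently being recomputed.
    """
    segment_len = math.ceil(n_layers / n_segments)
    return n_segments + segment_len


def best_segment_count(n_layers: int) -> int:
    """Brute-force the segment count that minimizes the peak."""
    return min(range(1, n_layers + 1),
               key=lambda k: peak_stored_activations(n_layers, k))


n = 100
print(peak_stored_activations(n, 1))   # 101: one segment, no savings
print(peak_stored_activations(n, 10))  # 20: sqrt(100) segments of 10 layers
print(best_segment_count(n))           # 10 == math.isqrt(100)
```

Quadrupling the depth to 400 layers only doubles the minimum to 40 (20 segments of 20 layers), which is the O(√n) behavior in miniature.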

The technique works by dividing the model's computational graph into segments. Only the activations at segment boundaries (the checkpoints) are kept in memory; activations inside each segment are discarded. During backpropagation, the forward pass is re-executed one segment at a time to regenerate the activations needed for gradient calculation. In PyTorch, this approach is available through torch.utils.checkpoint. It differs fundamentally from mixed precision training or model parallelism because it directly trades extra computation (FLOPs) for a smaller memory footprint.

MEMORY OPTIMIZATION TECHNIQUE

Key Characteristics of Gradient Checkpointing

Gradient checkpointing is a memory-for-compute trade-off technique that makes it possible to train networks whose activations would otherwise exceed available GPU memory, by strategically discarding and recomputing intermediate activations.

01

Core Trade-Off: Memory vs. Compute

Gradient checkpointing fundamentally trades increased computation time for reduced memory consumption. In standard backpropagation, every intermediate activation from the forward pass is stored until its gradient has been computed, so activation memory grows as O(n) with network depth n. Checkpointing reduces this to O(√n) by storing only a subset of activations (the checkpoints) and recomputing the rest during the backward pass. The freed memory can be spent on deeper or wider models, larger batches, or longer sequences on the same hardware. The price is roughly one extra forward pass per step; because a backward pass typically costs about twice a forward pass, this adds up to roughly one-third more compute per training step.
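
The overhead figure can be reproduced with a back-of-the-envelope cost model. The assumptions are ours: every layer costs one unit forward and two units backward (the common "backward ≈ 2× forward" rule of thumb), and, similar to how PyTorch's checkpoint_sequential treats the final segment, the last segment is not checkpointed, so every other layer's forward pass is re-run exactly once.

```python
from typing import Optional


def step_cost(n_layers: int, n_segments: Optional[int] = None,
              fwd: float = 1.0, bwd: float = 2.0) -> float:
    """Relative compute for one training step (forward + backward).

    Assumptions: uniform per-layer costs; with checkpointing, all layers
    except those in the final segment are recomputed once during backward.
    """
    cost = n_layers * (fwd + bwd)
    if n_segments is not None:
        last_segment = n_layers / n_segments
        cost += (n_layers - last_segment) * fwd  # one re-run per checkpointed layer
    return cost


def overhead(n_layers: int, n_segments: int) -> float:
    """Fractional increase in step compute relative to no checkpointing."""
    return step_cost(n_layers, n_segments) / step_cost(n_layers) - 1.0


for k in (2, 10, 100):
    print(k, round(overhead(100, k), 3))
# 2 0.167
# 10 0.3
# 100 0.33
```

Under these assumptions the overhead approaches one-third as segments get shorter but never exceeds it; it does not grow in proportion to the number of segments.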

02

Selective Activation Storage

The technique does not store every layer's output. Instead, it strategically selects specific layers as checkpoints. Common strategies include:

  • Uniform Checkpointing: Saving activations at evenly spaced intervals (e.g., every k layers).
  • Manual Checkpointing: Hand-picking the layers to checkpoint (e.g., after expensive operations or at residual connections).
  • Dynamic Programming: Using an algorithm to find the checkpoint schedule that minimizes total recomputation cost for a given memory budget.

Non-checkpointed activations are discarded after their initial use in the forward pass and regenerated when their gradients are needed. This selective storage is the primary source of memory savings.
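
A minimal sketch of how the uniform and manual strategies translate into a concrete schedule. The indexing convention is an assumption: x_0 is the network input, x_i is the output of layer i - 1, and a schedule lists which x_i are kept during the forward pass. The dynamic-programming strategy is not shown here.

```python
import math


def uniform_checkpoints(n_layers: int, every: int) -> list[int]:
    """Keep x_0, x_every, x_2*every, ...: segments of `every` layers each."""
    if every < 1:
        raise ValueError("every must be at least 1")
    return list(range(0, n_layers, every))


def sqrt_checkpoints(n_layers: int) -> list[int]:
    """Uniform schedule with segment length of about sqrt(n_layers)."""
    return uniform_checkpoints(n_layers, max(1, math.isqrt(n_layers)))


def add_manual(schedule: list[int], extra: list[int]) -> list[int]:
    """Merge hand-picked boundaries (e.g., the input of an expensive layer)."""
    return sorted(set(schedule) | set(extra))


print(uniform_checkpoints(12, 4))   # [0, 4, 8]
print(sqrt_checkpoints(16))         # [0, 4, 8, 12]
print(add_manual([0, 4, 8], [6]))   # [0, 4, 6, 8]
```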

03

Recomputation During Backward Pass

When the backward pass reaches a segment between two checkpoints, it executes a local forward pass, starting from that segment's stored checkpoint, to regenerate the discarded activations for that segment only. (Nested, or recursive, checkpointing applies the same idea inside each segment to save more memory at the cost of more recomputation.) The recomputation uses the same weights and precision as the original forward pass, and frameworks restore random-number-generator state (e.g., for dropout) so the regenerated activations match the originals, barring nondeterministic kernels; PyTorch does this by default (preserve_rng_state=True). Segments are recomputed one at a time, so only one segment's activations are live at once, and together these recomputations add up to roughly one extra forward pass rather than one per segment.
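
The segment-by-segment recomputation can be sketched on a toy network whose activations are plain floats. Everything here is hypothetical: a 12-layer chain where layer i computes tanh(w_i * x) with arbitrary weights, and gradients taken with respect to the network input. Both functions produce the same gradient, but the checkpointed one holds at most 6 values at a time instead of 12. For simplicity it counts every checkpoint as live until the end; a real implementation could free each one once its segment is done.

```python
import math

# Hypothetical 12-layer network on scalars: layer i computes tanh(w_i * x).
WEIGHTS = [0.5 + 0.1 * i for i in range(12)]


def layer(i: int, x: float) -> float:
    return math.tanh(WEIGHTS[i] * x)


def layer_grad(i: int, x: float) -> float:
    """d layer(i, x) / dx; note that it needs the layer's *input* x."""
    t = math.tanh(WEIGHTS[i] * x)
    return WEIGHTS[i] * (1.0 - t * t)


def grad_full(x0: float):
    """Standard backprop: every layer input stays in memory."""
    inputs, x = [], x0
    for i in range(len(WEIGHTS)):
        inputs.append(x)
        x = layer(i, x)
    g = 1.0
    for i in reversed(range(len(WEIGHTS))):
        g *= layer_grad(i, inputs[i])
    return x, g, len(inputs)


def grad_checkpointed(x0: float, every: int):
    """Keep every `every`-th layer input; rebuild the others segment by segment."""
    n = len(WEIGHTS)
    checkpoints, x = {}, x0
    for i in range(n):
        if i % every == 0:
            checkpoints[i] = x          # segment boundary: kept
        x = layer(i, x)                 # all other inputs are dropped right away
    g, peak = 1.0, 0
    for start in sorted(checkpoints, reverse=True):
        end = min(start + every, n)
        # Local forward pass from this segment's checkpoint only.
        seg = [checkpoints[start]]
        for i in range(start, end - 1):
            seg.append(layer(i, seg[-1]))
        # Live values: all checkpoints plus this segment's recomputed inputs.
        peak = max(peak, len(checkpoints) + len(seg) - 1)
        for i in reversed(range(start, end)):
            g *= layer_grad(i, seg[i - start])
        # `seg` is rebound on the next iteration, freeing this segment.
    return x, g, peak


y_full, g_full, stored = grad_full(0.3)
y_ckpt, g_ckpt, peak = grad_checkpointed(0.3, every=4)
print(stored, peak)        # 12 6
print(g_full == g_ckpt)    # True: same operations on the same inputs
```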

04

Implementation in Deep Learning Frameworks

Gradient checkpointing is natively supported in major frameworks, abstracting the complexity from the user.

  • PyTorch: The torch.utils.checkpoint.checkpoint function wraps a network segment. It runs the segment without saving intermediate activations in the forward pass and recomputes them during the backward pass.
  • TensorFlow: The tf.recompute_grad function, usable as a decorator, provides similar functionality; TensorFlow 1.x offered tf.contrib.layers.recompute_grad.
  • JAX: The jax.checkpoint transformation (also available as jax.remat) is the core primitive for automatic rematerialization.

These implementations handle the autograd-tape or computation-graph modifications required to correctly record operations for the recomputed forward pass.

05

Critical Use Case: Large Model Training

This technique is essential for overcoming GPU memory bottlenecks when training large language models (LLMs), vision transformers, and scientific models. It is a foundational enabling technology for:

  • Research Scalability: Allowing academic researchers with limited hardware to experiment with architectures that would otherwise require industrial-scale clusters.
  • Efficient Fine-Tuning: Enabling parameter-efficient fine-tuning (PEFT) methods like LoRA or Adapter Layers on larger base models by reducing the memory footprint of the frozen backbone's activations.
  • Long Sequence Processing: Making it feasible to train models on very long input sequences (e.g., in genomics or document analysis) where activation memory is prohibitive.

06

Interaction with Other Optimizations

Gradient checkpointing is rarely used in isolation and interacts with other memory-saving techniques:

  • Mixed Precision Training: Under mixed precision, stored checkpoints are typically FP16/BF16 tensors, so each takes half the memory of an FP32 activation. Recomputation runs under the same precision settings as the original forward pass.
  • Model Parallelism & FSDP: Often combined with Fully Sharded Data Parallel (FSDP) or tensor parallelism. Checkpointing reduces the per-GPU memory for activations, while parallelism shards the model parameters.
  • Activation Offloading: A more aggressive technique that moves activations to CPU RAM and fetches them back during the backward pass. Combined with checkpointing, only the checkpoints need to be offloaded, which reduces the volume of data transferred.
  • Compiler Optimizations: Frameworks like PyTorch's torch.compile can optimize the recomputation graph, potentially reducing the overhead.
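
A rough sketch of how the precision and checkpointing savings compound. The shapes are hypothetical (64 layers, batch 8 × sequence 1024 × hidden 4096), and the model assumes one equal-sized saved tensor per layer, whereas real transformer layers save several tensors each; only the ratios are meant to carry over.

```python
import math
from typing import Optional

GIB = 2 ** 30


def activation_bytes(n_layers: int, elems_per_layer: int, bytes_per_elem: int,
                     n_segments: Optional[int] = None) -> int:
    """Peak bytes of stored layer outputs under a simplified model.

    Assumption: one equal-sized saved tensor per layer. Without checkpointing
    all n_layers outputs are kept; with k segments, k boundaries plus one
    segment's worth of recomputed outputs are live at the peak.
    """
    if n_segments is None:
        stored = n_layers
    else:
        stored = n_segments + math.ceil(n_layers / n_segments)
    return stored * elems_per_layer * bytes_per_elem


# Hypothetical shapes: 64 layers, batch 8 x sequence 1024 x hidden 4096.
elems = 8 * 1024 * 4096
for label, nbytes in (("fp32", 4), ("bf16", 2)):
    full = activation_bytes(64, elems, nbytes)
    ckpt = activation_bytes(64, elems, nbytes, n_segments=8)
    print(label, full / GIB, ckpt / GIB)
# fp32 8.0 2.0
# bf16 4.0 1.0
```

The two savings multiply: 2x from 16-bit storage times 4x from checkpointing gives an 8x drop, from 8 GiB to 1 GiB, in this toy setting.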

MEMORY OPTIMIZATION COMPARISON

Gradient Checkpointing vs. Other Memory Optimization Techniques

A technical comparison of memory reduction strategies used during the training of large neural networks, highlighting trade-offs between memory, compute, and implementation complexity.

| Feature / Metric | Gradient Checkpointing | Mixed Precision Training | ZeRO / FSDP (Distributed) | Model Compression (PTQ/QAT) |
|---|---|---|---|---|
| Primary Mechanism | Selectively discards and recomputes intermediate activations | Uses lower precision (e.g., FP16/BF16) for most operations | Partitions model states (parameters, gradients, optimizer states) across GPUs | Reduces numerical precision of weights/activations (e.g., to INT8) |
| Memory Reduction Target | Activations saved during the forward pass | Activations and working copies of weights (FP32 master weights are often kept) | Optimizer states, gradients, and parameters | Model weights and sometimes activations |
| Compute Overhead | Moderate (one extra forward pass, up to roughly one-third more compute per step) | Low; often a net speedup on hardware with Tensor Cores | Moderate (communication overhead between GPUs) | None for PTQ; moderate for QAT (retraining) |
| Typical Memory Savings | Activation memory reduced from O(n) to O(√n) in network depth | ~50% for activations stored in 16-bit instead of 32-bit | Per-GPU model-state memory divided roughly by the number of data-parallel GPUs with full sharding; enables scaling to 1T+ parameters | 4x reduction for weights (FP32 -> INT8) |
| Best For | Training very large models on limited GPU memory | Speeding up training and reducing memory footprint generally | Distributed training of extremely large models | Deploying models for inference on resource-constrained hardware |
| Requires Model Changes | Minimal (wrap segments or enable a flag; no architecture change) | No (framework-level precision policy) | No architecture change, but requires a distributed training setup | Yes (quantization ops or simulated quantization) |
| Applicable Phase | Training | Training | Training | Primarily inference (PTQ); training (QAT) |
| Key Trade-off | Compute time for memory | Numerical stability for speed/memory | Inter-GPU communication for memory scalability | Model accuracy for size/speed |

GRADIENT CHECKPOINTING

Implementation and Framework Support

Gradient checkpointing is implemented as a memory-for-compute trade-off within the training loop. Major deep learning frameworks provide native APIs and automatic mechanisms to apply this optimization.

01

PyTorch Implementation

PyTorch provides the torch.utils.checkpoint.checkpoint function, which wraps a segment of the forward pass. During execution, it:

  • Saves only the inputs and the function object for the wrapped segment.
  • Discards all intermediate activations produced within the segment.
  • During the backward pass, it re-runs the forward function with the saved inputs to recompute the necessary activations.

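Key API: output = torch.utils.checkpoint.checkpoint(function, *inputs, use_reentrant=False). Recent PyTorch versions ask callers to pass use_reentrant explicitly, and the non-reentrant variant (False) is the recommended one. For nn.Sequential models, torch.utils.checkpoint.checkpoint_sequential(functions, segments, input) splits the model into segments automatically. For transformer models, a common strategy is to checkpoint every N layers (often every block), where N is tuned to the available memory.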
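
The save-inputs-and-recompute contract described above can be illustrated without PyTorch. This is a hypothetical stdlib-only sketch, not PyTorch's implementation: each "layer" returns its output together with a backward closure, and the variables that closure captures stand in for the activations autograd would normally keep alive. The stand-in checkpoint function keeps only the wrapped function and its input and reruns the function when backward is called; it handles a single scalar input for simplicity.

```python
import math

FORWARD_CALLS = {"count": 0}  # counts layer forward executions, for illustration


def tanh_layer(w: float):
    """Hypothetical layer: forward returns (output, backward closure)."""
    def forward(x: float):
        FORWARD_CALLS["count"] += 1
        y = math.tanh(w * x)

        def backward(grad_y: float) -> float:
            return grad_y * w * (1.0 - y * y)  # uses the saved activation y
        return y, backward
    return forward


def sequential(*layers):
    """Chain layers; the returned backward keeps every layer's closure alive."""
    def forward(x: float):
        backs = []
        for f in layers:
            x, b = f(x)
            backs.append(b)

        def backward(grad: float) -> float:
            for b in reversed(backs):
                grad = b(grad)
            return grad
        return x, backward
    return forward


def checkpoint(function, x: float):
    """Stand-in for the checkpoint contract: store only `function` and `x`."""
    y, _ = function(x)  # forward pass; the closure (and its activations) is dropped

    def backward(grad_y: float) -> float:
        _, inner_backward = function(x)  # recompute the segment from the saved input
        return inner_backward(grad_y)
    return y, backward


block = sequential(tanh_layer(0.5), tanh_layer(1.5), tanh_layer(0.8))
y_ref, back_ref = block(0.7)
FORWARD_CALLS["count"] = 0
y_ck, back_ck = checkpoint(block, 0.7)
grad_ck = back_ck(1.0)
print(y_ck == y_ref, grad_ck == back_ref(1.0))  # True True
print(FORWARD_CALLS["count"])                   # 6: three layers, run twice
```

The doubled forward count is the compute cost; the memory win is that between the forward and backward passes only `function` and `x` are held, not the per-layer closures.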

02

TensorFlow Implementation

TensorFlow 2.x implements gradient checkpointing via tf.recompute_grad, which wraps a function (or is applied as a decorator) so that its intermediate activations are recomputed rather than stored. Gradients are then computed as usual with tf.GradientTape.

Core Mechanism: tf.recompute_grad attaches a custom gradient that re-runs the wrapped forward computation during the backward pass. For Keras models, checkpointing can be integrated by wrapping layer calls or using custom training loops. TensorFlow's asynchronous checkpoint options concern saving model state to disk (tf.train.Checkpoint) and are unrelated to gradient checkpointing.

03

JAX and Flax Support

In JAX, the jax.checkpoint transformation (also exported as jax.remat) is the fundamental primitive. It is a higher-order function that returns a rematerialized version of its input function.

Key Characteristics:

  • Flexible Policies: The policy argument controls which intermediate values may be saved instead of recomputed (e.g., jax.checkpoint_policies.nothing_saveable, dots_saveable, or everything_saveable).
  • Integration with jax.grad: Checkpointing composes with automatic differentiation. In Flax Linen, the lifted transform flax.linen.remat (also exported as flax.linen.checkpoint) wraps a Module class and is commonly applied to transformer blocks.

04

Automatic Selection and Policies

Advanced implementations use checkpointing policies to decide which layers or operations to recompute, optimizing the trade-off.

Common Policies:

  • Uniform: Checkpoint every k-th layer; choosing k ≈ √n gives the O(√n) memory bound.
  • Optimal (Chen et al.): An algorithm that selects checkpoints to minimize recomputation under a fixed memory budget. Dynamic-programming formulations of this problem are optimal with respect to their cost model.
  • Sublinear: Used for very long unrolled sequences (e.g., backpropagation through time in recurrent models used in reinforcement learning), where the policy keeps memory growing sublinearly with sequence length.

Libraries such as FairScale and DeepSpeed provide activation-checkpointing wrappers built on these ideas.
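
A minimal sketch of the kind of search an "optimal" policy performs, using a simplified recurrence in the style of binomial checkpointing (Griewank's Revolve). The cost model is our assumption, and this is neither the algorithm from Chen et al. nor what FairScale or DeepSpeed ship: a chain of equal-cost layers whose input is stored, a fixed number of extra activation slots, and cost counted as layer forward executions during the backward sweep. The recurrence tries every position for the next checkpoint and keeps the cheapest.

```python
from functools import lru_cache


@lru_cache(maxsize=None)
def reversal_cost(length: int, slots: int) -> int:
    """Fewest layer re-executions needed to backpropagate through a chain.

    Simplified model (assumption): the chain has `length` uniform-cost layers,
    its input is already stored, and at most `slots` extra activations may be
    held at once. Backprop through layer i needs that layer's input.
    """
    if length <= 1:
        return 0                               # the only needed input is stored
    cost = length * (length - 1) // 2          # no slots: recompute from the start each time
    if slots == 0:
        return cost
    for j in range(1, length):
        # Advance j layers and store that activation in a slot; reverse the
        # tail with one fewer slot, then the head with the slot freed again.
        cost = min(cost, j + reversal_cost(length - j, slots - 1)
                   + reversal_cost(j, slots))
    return cost


for s in (0, 1, 2, 3, 11):
    print(s, reversal_cost(12, s))
```

With no slots the cost is quadratic (66 re-executions for 12 layers); with a slot for every layer it falls to 11, one per layer; intermediate budgets land in between, which is the trade-off a budgeted policy navigates.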

05

Interaction with Other Optimizations

Gradient checkpointing is often combined with other memory-saving techniques:

  • Mixed Precision Training: Checkpointed inputs are saved in whatever precision the forward pass produced them (often FP16/BF16 under autocast), and PyTorch replays the forward pass's autocast state during recomputation so the regenerated activations match the originals.
  • Model Parallelism: Checkpointing is applied within each device's sub-graph to further reduce peak memory per GPU.
  • Activation Offloading: In DeepSpeed, the activation-checkpointing configuration can move checkpointed activations to CPU memory (cpu_checkpointing), compounding the memory savings.

Critical Note: The recomputation cost is roughly one extra forward pass over the checkpointed layers, whether they are split into few segments or many. The number and placement of checkpoints mainly determine peak memory, not total recomputation.

GRADIENT CHECKPOINTING

Frequently Asked Questions

Gradient checkpointing is a critical memory optimization technique in deep learning, enabling the training of models that would otherwise exceed GPU memory limits by trading compute for memory. These FAQs address its core mechanisms, trade-offs, and implementation.

Gradient checkpointing is a memory optimization technique for training deep neural networks that trades compute time for GPU memory by selectively discarding and recomputing intermediate activations during the backward pass.

It works by dividing the model's forward pass into segments. During the initial forward pass, only the activations at the boundaries of these segments (the 'checkpoints') are stored in memory; all other intermediate activations are discarded. During the backward pass, when gradients for a segment are needed, the forward pass for that segment is recomputed from its stored checkpoint. This recomputation regenerates the required activations locally, the gradients are calculated, and the temporary activations are discarded again. With checkpoints spaced roughly every √n layers, this reduces peak activation memory from O(n) (where n is the number of layers) to O(√n), at the cost of approximately one additional forward pass per training iteration.


About the author

Prasad Kumkar

CEO & MD, Inference Systems

Prasad Kumkar is the CEO & MD of Inference Systems and writes about AI systems architecture, LLM infrastructure, model serving, evaluation, and production deployment. Over the past 5+ years, he has worked across computer vision models, L5 autonomous vehicle systems, and LLM research, with a focus on taking complex AI ideas into real-world engineering systems.

His work and writing cover AI systems, large language models, AI agents, multimodal systems, autonomous systems, inference optimization, RAG, evaluation, and production AI engineering.