Gradient checkpointing is a memory optimization technique that enables the training of larger neural networks by selectively discarding intermediate activations during the forward pass and recomputing them on demand during the backward pass. With evenly spaced checkpoints, this trade-off reduces peak activation memory from O(n) to O(√n) in the number of layers, at the cost of roughly one additional forward pass per training iteration. It is a critical tool for large language model (LLM) training and parameter-efficient fine-tuning, where memory is the primary constraint.
Gradient Checkpointing

What is Gradient Checkpointing?
Gradient checkpointing is a memory-for-compute trade-off technique used to train deep neural networks that would otherwise exceed GPU memory limits.
The technique works by dividing the model's computational graph into segments. Only the inputs and outputs of these segments are stored in memory; internal activations are discarded. During backpropagation, the forward pass is re-executed for each segment to regenerate the activations needed for gradient calculation. Frameworks such as PyTorch expose this approach through torch.utils.checkpoint. It differs from mixed precision training and model parallelism in that it directly trades additional FLOPs for a smaller memory footprint.
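The segment-and-recompute procedure described above can be illustrated with a small pure-Python sketch. It is not how any framework implements checkpointing: the toy layer y = tanh(w * x) and the names layer_forward, layer_backward, full_backprop, and checkpointed_backprop are assumptions made for illustration only. The sketch keeps just the input of each segment, re-runs the segment during the backward pass, and arrives at the same gradients as the store-everything baseline.

```python
import math

def layer_forward(w, x):
    """One toy layer: y = tanh(w * x)."""
    return math.tanh(w * x)

def layer_backward(w, x, grad_y):
    """Given the layer input x and dL/dy, return (dL/dx, dL/dw)."""
    y = math.tanh(w * x)
    local = 1.0 - y * y  # derivative of tanh evaluated at w * x
    return grad_y * local * w, grad_y * local * x

def full_backprop(weights, x0):
    """Baseline: keep every layer input in memory."""
    inputs, x = [], x0
    for w in weights:
        inputs.append(x)
        x = layer_forward(w, x)
    grad, grads_w = 1.0, [0.0] * len(weights)
    for i in reversed(range(len(weights))):
        grad, grads_w[i] = layer_backward(weights[i], inputs[i], grad)
    return x, grad, grads_w

def checkpointed_backprop(weights, x0, segment_size):
    """Keep only each segment's input; recompute the rest on demand."""
    n = len(weights)
    starts = list(range(0, n, segment_size))
    checkpoints, x = {}, x0
    for start in starts:
        checkpoints[start] = x  # the only activation stored for this segment
        for w in weights[start:start + segment_size]:
            x = layer_forward(w, x)  # intermediates are not kept
    out = x
    grad, grads_w = 1.0, [0.0] * n
    for start in reversed(starts):
        end = min(start + segment_size, n)
        # Local forward pass: regenerate this segment's layer inputs.
        seg_inputs, x = [], checkpoints[start]
        for i in range(start, end):
            seg_inputs.append(x)
            x = layer_forward(weights[i], x)
        for i in reversed(range(start, end)):
            grad, grads_w[i] = layer_backward(weights[i], seg_inputs[i - start], grad)
    return out, grad, grads_w

weights = [0.5, -1.2, 0.8, 1.5, -0.3, 0.9, 1.1]
reference = full_backprop(weights, 0.7)
checkpointed = checkpointed_backprop(weights, 0.7, segment_size=3)
print(reference[1], checkpointed[1])
```

In this toy setting the baseline holds seven layer inputs at once, while the checkpointed version holds three segment inputs plus at most three regenerated inputs for the segment currently being differentiated.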
Key Characteristics of Gradient Checkpointing
Gradient checkpointing is a memory-for-compute trade-off technique that enables the training of neural networks larger than available GPU memory by strategically discarding and recomputing intermediate activations.
Core Trade-Off: Memory vs. Compute
Gradient checkpointing trades additional computation for reduced memory consumption. In standard backpropagation, all intermediate activations from the forward pass must be stored to compute gradients, so activation memory grows as O(n) with network depth. Checkpointing reduces this to O(√n) by storing only a subset of activations (checkpoints) and recomputing the discarded segments during the backward pass. This can allow substantially larger models to be trained on the same hardware (figures of 2-10x are often quoted, though the actual gain depends on how much of the memory budget activations occupy), at the cost of roughly a 20-40% increase in total training time due to the extra forward computation.
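The O(√n) figure can be recovered with a short calculation under simplifying assumptions that are not stated in the text above: all n layers produce activations of equal size, checkpoints are evenly spaced every k layers, and only one segment is recomputed at a time. Peak memory M(k), counted in activations, is then the number of checkpoints plus the length of one segment.

```latex
M(k) = \frac{n}{k} + k,
\qquad
\frac{dM}{dk} = -\frac{n}{k^{2}} + 1 = 0
\;\Longrightarrow\;
k^{*} = \sqrt{n},
\qquad
M(k^{*}) = 2\sqrt{n} = O(\sqrt{n}).
```

For example, with n = 64 layers this gives k* = 8 and about 16 stored activations instead of 64.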
Selective Activation Storage
The technique does not store every layer's output. Instead, it strategically selects specific layers as checkpoints. Common strategies include:
- Uniform Checkpointing: Saving activations at evenly spaced intervals (e.g., every k layers).
- Manual Checkpointing: Manually defining critical layers (e.g., after expensive operations or at residual connections).
- Dynamic Programming: Using an algorithm to determine the checkpoint schedule that minimizes total recomputation cost.
The non-checkpointed activations are discarded after their initial use in the forward pass and must be regenerated when their gradients are needed. This selective storage is the primary source of memory savings.
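The uniform strategy from the list above can be sketched in a few lines. The cost model is a deliberate simplification and an assumption of this example: every layer's activation counts as one unit, and peak memory is the number of stored checkpoints plus the longest segment that has to be rebuilt. The function names are hypothetical.

```python
def uniform_checkpoints(num_layers, k):
    """Indices of layers whose inputs are kept when checkpointing every k layers."""
    return list(range(0, num_layers, k))

def peak_stored_activations(num_layers, k):
    """Toy count: stored checkpoints plus the largest segment rebuilt in backward."""
    starts = uniform_checkpoints(num_layers, k)
    largest_segment = max(min(k, num_layers - s) for s in starts)
    return len(starts) + largest_segment

def best_uniform_interval(num_layers):
    """Interval k with the smallest toy peak; ties resolve to the smaller k."""
    return min(
        range(1, num_layers + 1),
        key=lambda k: (peak_stored_activations(num_layers, k), k),
    )

for k in (1, 4, 8, 16, 64):
    print(k, peak_stored_activations(64, k))
print("best interval for 64 layers:", best_uniform_interval(64))
```

Under this model the peak is lowest when the interval is close to the square root of the layer count, which is where the O(√n) behaviour comes from.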
Recomputation During Backward Pass
When the backward pass reaches a segment between two checkpoints, it executes a local forward pass to regenerate the discarded activations for that segment only: to obtain the activation of layer i, it recomputes forward from the nearest preceding checkpoint. The recomputation uses the same model weights and the same precision as the original forward pass, so it reproduces the original activations as long as the segment is deterministic; stochastic operations such as dropout require the random number generator state to be restored before recomputation. In total the recomputation amounts to at most about one additional forward pass, but it is carried out as a series of smaller, per-segment recomputations, so only one segment's activations are live at any time.
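Recomputation only matches the original forward pass if any randomness inside the segment is replayed identically. The sketch below illustrates that point with Python's standard random module; the dropout_segment function is a toy stand-in, not a framework API. PyTorch's checkpoint function addresses the same concern through its preserve_rng_state argument, whose internals are not shown here.

```python
import random

def dropout_segment(values, rng, drop_prob=0.5):
    """Toy stochastic segment: inverted dropout driven by the given generator."""
    scale = 1.0 / (1.0 - drop_prob)
    return [0.0 if rng.random() < drop_prob else v * scale for v in values]

rng = random.Random(0)
inputs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]

# Forward pass: remember the generator state alongside the checkpointed input.
saved_state = rng.getstate()
forward_out = dropout_segment(inputs, rng)

# Other random draws happen before the backward pass reaches this segment.
for _ in range(10):
    rng.random()

# Recomputing without restoring the state may draw a different dropout mask.
naive_recompute = dropout_segment(inputs, rng)

# Recomputing after restoring the saved state reproduces the forward pass.
rng.setstate(saved_state)
replayed = dropout_segment(inputs, rng)

print(replayed == forward_out)
```

A real implementation would also put the generator back to its pre-recomputation state afterwards so that later draws are unaffected; the sketch omits that step for brevity.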
Implementation in Deep Learning Frameworks
Gradient checkpointing is natively supported in major frameworks, abstracting the complexity from the user.
- PyTorch: The torch.utils.checkpoint.checkpoint function wraps a network segment. It runs the segment without saving intermediate activations in the forward pass and recomputes them during the backward pass.
- TensorFlow: The tf.recompute_grad wrapper, which can be used as a decorator, provides similar functionality. (TensorFlow 1.x offered tf.contrib.layers.recompute_grad; the tf.contrib namespace was removed in TensorFlow 2.)
- JAX: The jax.checkpoint transformation (also exposed as jax.remat) is a core primitive for automatic rematerialization.
These implementations handle the autograd tape or computation graph modifications required to correctly record operations for the recomputed forward pass.
Critical Use Case: Large Model Training
This technique is essential for overcoming GPU memory bottlenecks when training large language models (LLMs), vision transformers, and scientific models. It is a foundational enabling technology for:
- Research Scalability: Allowing academic researchers with limited hardware to experiment with architectures that would otherwise require industrial-scale clusters.
- Efficient Fine-Tuning: Enabling parameter-efficient fine-tuning (PEFT) methods like LoRA or Adapter Layers on larger base models by reducing the memory footprint of the frozen backbone's activations.
- Long Sequence Processing: Making it feasible to train models on very long input sequences (e.g., in genomics or document analysis) where activation memory is prohibitive.
Interaction with Other Optimizations
Gradient checkpointing is rarely used in isolation and interacts with other memory-saving techniques:
- Mixed Precision Training: Under mixed precision, checkpointed activations are kept in FP16/BF16, which reduces memory further. Recomputation also runs in the lower precision.
- Model Parallelism & FSDP: Often combined with Fully Sharded Data Parallel (FSDP) or tensor parallelism. Checkpointing reduces the per-GPU memory for activations, while parallelism shards the model parameters.
- Activation Offloading: An even more aggressive technique where all activations are offloaded to CPU RAM and fetched back during backward pass. Checkpointing reduces the volume of data that needs to be offloaded and retrieved.
- Compiler Optimizations: Compilers such as PyTorch's torch.compile can optimize the recomputation graph, potentially reducing the overhead.
Gradient Checkpointing vs. Other Memory Optimization Techniques
A technical comparison of memory reduction strategies used during the training of large neural networks, highlighting trade-offs between memory, compute, and implementation complexity.
| Feature / Metric | Gradient Checkpointing | Mixed Precision Training | ZeRO / FSDP (Distributed) | Model Compression (PTQ/QAT) |
|---|---|---|---|---|
| Primary Mechanism | Selectively discards and recomputes intermediate activations | Uses lower precision (e.g., FP16/BF16) for most operations | Partitions model states (params, gradients, optimizer) across GPUs | Reduces numerical precision of weights/activations (e.g., to INT8) |
| Memory Reduction Target | Activations saved for the backward pass | Activations and model weights | Optimizer states, gradients, and parameters | Model weights and sometimes activations |
| Compute Overhead | Moderate to high (roughly 20-40% longer training due to recomputation) | Low (accelerated via Tensor Cores) | Moderate (communication overhead between GPUs) | None for PTQ; Moderate for QAT (re-training) |
| Typical Memory Savings | Activation memory from O(n) toward O(√n) in network depth | ~50% for activation/weight memory | Enables scaling to 1T+ parameters | 4x reduction (FP32 -> INT8) for weights |
| Best For | Training very large models on limited GPU memory | Speeding up training and reducing memory footprint generally | Distributed training of extremely large models | Deploying models for inference on resource-constrained hardware |
| Requires Model Changes | No (implementation-level technique) | No (framework-level precision policy) | Yes (requires distributed training setup) | Yes (quantization ops or simulated quantization) |
| Applicable Phase | Training | Training | Training | Primarily Inference (PTQ), Training (QAT) |
| Key Trade-off | Compute time for memory | Numerical stability for speed/memory | Inter-GPU communication for memory scalability | Model accuracy for size/speed |
Implementation and Framework Support
Gradient checkpointing is implemented as a memory-for-compute trade-off within the training loop. Major deep learning frameworks provide native APIs and automatic mechanisms to apply this optimization.
PyTorch Implementation
PyTorch provides the torch.utils.checkpoint.checkpoint function, which wraps a segment of the forward pass. During execution, it:
- Saves only the inputs and the function object for the wrapped segment.
- Discards all intermediate activations produced within the segment.
- Re-runs the forward function with the saved inputs during the backward pass to recompute the necessary activations.
Key API: output = torch.utils.checkpoint.checkpoint(function, *inputs). Recent PyTorch releases also expect the use_reentrant argument to be passed explicitly, with use_reentrant=False as the recommended setting. For transformer models, a common strategy is to checkpoint every N layers, where N is tuned based on available memory.
TensorFlow Implementation
TensorFlow 2.x implements gradient checkpointing via tf.recompute_grad, which can be applied as a function decorator or used to wrap a callable. Gradients are then taken as usual inside a tf.GradientTape context; the tape itself has no separate checkpointing switch.
Core Mechanism: The wrapper creates a custom gradient function that recomputes the forward pass during the backward pass. For Keras models, checkpointing can be integrated by wrapping layer calls or using custom training loops. TensorFlow's model-saving checkpoints (tf.train.Checkpoint and its asynchronous saving options) are a different feature: they persist weights to disk and do not reduce activation memory.
JAX and Flax Support
In JAX, the jax.checkpoint transformation (also exposed under the alias jax.remat) is the fundamental primitive. It is a higher-order function that returns a rematerialized version of its input function.
Key Characteristics:
- Flexible Policies: JAX supports a policy argument to control which intermediate values are saved rather than recomputed (e.g., jax.checkpoint_policies.everything_saveable).
- Integration with jax.grad: Checkpointing is composable with automatic differentiation. In Flax, it can be applied through the flax.linen.remat lifted transformation, for example around transformer block definitions.
Automatic Selection and Policies
Advanced implementations use checkpointing policies to decide which layers or operations to recompute, optimizing the trade-off.
Common Policies:
- Uniform: Checkpoint every k-th layer.
- Budget-optimal scheduling: A dynamic programming (or similar search) procedure that selects checkpoints to minimize recomputation under a fixed memory budget. The result is optimal only with respect to the cost model it assumes.
- Sublinear (Chen et al.): Placing roughly √n evenly spaced checkpoints so that activation memory grows sublinearly with depth or, for recurrent and long-sequence models, with sequence length.
Libraries such as FairScale and DeepSpeed provide activation checkpointing wrappers; which segments are checkpointed is usually chosen by the user or by configuration.
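To make the idea of searching for a checkpoint schedule concrete, the sketch below uses exhaustive search as a stand-in for dynamic programming; it is only feasible for a handful of layers and is not the algorithm of any particular paper or library. The cost model (all checkpointed outputs plus the largest recomputed segment) and the function names are assumptions of this example.

```python
from itertools import combinations

def peak_memory(sizes, checkpoints):
    """Toy cost: all checkpointed outputs plus the largest recomputed segment."""
    kept = sum(sizes[i] for i in checkpoints)
    largest, current = 0, 0
    for i, size in enumerate(sizes):
        if i in checkpoints:
            current = 0  # a checkpoint closes the running segment
        else:
            current += size
            largest = max(largest, current)
    return kept + largest

def best_checkpoints(sizes):
    """Try every checkpoint set and return (lowest peak, first set achieving it)."""
    n = len(sizes)
    best = (peak_memory(sizes, frozenset()), ())
    for m in range(1, n + 1):
        for combo in combinations(range(n), m):
            cost = peak_memory(sizes, frozenset(combo))
            if cost < best[0]:
                best = (cost, combo)
    return best

# Equal-sized layers: the search lands near the square-root rule.
print(best_checkpoints([1] * 9))

# One unusually large activation changes where checkpoints pay off.
print(best_checkpoints([1, 1, 8, 1, 1]))
```

When layer sizes are uneven, the best schedule is no longer evenly spaced, which is the motivation for budget-aware policies over the uniform rule.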
Interaction with Other Optimizations
Gradient checkpointing is often combined with other memory-saving techniques:
- Mixed Precision Training: Checkpointed inputs are stored in whatever precision the forward pass produced them (typically FP16/BF16 under mixed precision), and recomputation runs under the same precision settings as the original forward pass.
- Model Parallelism: Checkpointing is applied within each device's sub-graph to further reduce peak memory per GPU.
- Activation Offloading: In systems such as DeepSpeed, checkpointed activations can be moved to CPU memory, compounding the memory savings.
Critical Note: The recomputation cost is additive and scales with the amount of computation that is re-executed, not with the number of segments as such. With a single level of checkpointing it is bounded by roughly one extra forward pass per iteration.
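The size of the recomputation overhead can be estimated with simple arithmetic. The sketch below assumes, as a common rule of thumb that is not taken from the text above, that a backward pass costs about twice as much as a forward pass; the function name and its parameters are hypothetical.

```python
def recompute_overhead(recomputed_fraction, backward_to_forward_ratio=2.0):
    """Extra compute as a fraction of a normal forward-plus-backward step.

    recomputed_fraction: share of the forward pass that is re-executed
    (1.0 means every layer is recomputed once).
    """
    forward = 1.0
    backward = backward_to_forward_ratio * forward
    extra = recomputed_fraction * forward
    return extra / (forward + backward)

for fraction in (0.25, 0.5, 1.0):
    print(f"{fraction:.2f} of forward recomputed -> {recompute_overhead(fraction):.1%} extra compute")
```

Under that assumption, recomputing the whole forward pass once adds about a third to the compute of a training step, and recomputing only part of it adds proportionally less. Measured wall-clock overhead can differ because of memory bandwidth, kernel launch costs, and communication.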
Frequently Asked Questions
Gradient checkpointing is a critical memory optimization technique in deep learning, enabling the training of models that would otherwise exceed GPU memory limits by trading compute for memory. These FAQs address its core mechanisms, trade-offs, and implementation.
Gradient checkpointing is a memory optimization technique for training deep neural networks that trades compute time for GPU memory by selectively discarding and recomputing intermediate activations during the backward pass.
It works by dividing the model's forward pass into segments. During the initial forward pass, only the activations at the boundaries of these segments (the 'checkpoints') are stored in memory. All other intermediate activations are discarded. During the backward pass, when gradients for a segment are needed, the forward pass for that segment is recomputed from the nearest stored checkpoint. This recomputation generates the required activations locally, gradients are calculated, and then the temporary activations are discarded again. This process significantly reduces the peak memory consumption from O(n) (where n is the number of layers) to O(√n), at the cost of approximately one additional forward pass per training iteration.
Related Terms
Gradient checkpointing is a foundational memory optimization technique that enables the training of larger models. It is often used in conjunction with other methods for efficient model adaptation and deployment.
Activation Recomputation
Activation recomputation is the core mechanism behind gradient checkpointing. It is the strategic process of discarding intermediate layer outputs (activations) during the forward pass of neural network training and recalculating them during the backward pass when they are needed for gradient computation.
- Trade-off: Explicitly trades extra compute cycles for reduced GPU memory consumption.
- Granularity: Can be applied at the level of individual transformer blocks, groups of layers, or custom segments.
- Implementation: Modern deep learning frameworks (PyTorch's torch.utils.checkpoint, TensorFlow's tf.recompute_grad) provide APIs to automate this process.
Memory-Bound vs. Compute-Bound
These terms describe the primary limiting factor in a training or inference workload. Understanding this distinction is key to applying techniques like gradient checkpointing effectively.
- Memory-Bound: The process is limited by the available GPU RAM (VRAM). This is the typical scenario for training large models where activations, parameters, and optimizer states exceed memory capacity. Gradient checkpointing is a solution for memory-bound training.
- Compute-Bound: The process is limited by the computational throughput (FLOPs) of the GPU. The workload can fully utilize the GPU's compute units, and memory is not the bottleneck.
- Engineering Decision: Checkpointing introduces compute overhead. It is most beneficial when moving from a memory-bound to a compute-bound regime, allowing training to proceed where it was previously impossible.
Selective Activation Checkpointing
Selective activation checkpointing is an advanced optimization of the basic gradient checkpointing strategy. Instead of checkpointing every layer, it uses heuristics or profiling to identify and only checkpoint the most memory-intensive layers.
- Objective: Minimize the compute overhead of recomputation while still achieving significant memory savings.
- Targets: Often focuses on operations whose outputs are large relative to the cost of recomputing them. Which transformer blocks qualify (feed-forward network (FFN) blocks versus attention layers) depends on sequence length and hidden size.
- Benefit: Can provide a superior trade-off curve compared to uniform checkpointing, leading to faster training times for the same memory budget.
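One possible heuristic for choosing which activations to drop is to rank them by memory freed per unit of recomputation work. The sketch below is a greedy illustration of that idea only; the LayerProfile fields, the layer names, and all numbers are hypothetical and would in practice come from profiling.

```python
from dataclasses import dataclass

@dataclass
class LayerProfile:
    name: str
    activation_bytes: int  # memory freed if this layer's activations are dropped
    recompute_flops: int   # extra work needed to regenerate them (must be > 0)

def select_layers_to_recompute(profiles, bytes_to_free):
    """Greedy heuristic: prefer the most bytes saved per recompute FLOP."""
    ranked = sorted(
        profiles,
        key=lambda p: p.activation_bytes / p.recompute_flops,
        reverse=True,
    )
    chosen, freed = [], 0
    for profile in ranked:
        if freed >= bytes_to_free:
            break
        chosen.append(profile.name)
        freed += profile.activation_bytes
    return chosen, freed

# Hypothetical profile of one transformer block.
profiles = [
    LayerProfile("attn_scores", activation_bytes=400, recompute_flops=50),
    LayerProfile("ffn_hidden", activation_bytes=300, recompute_flops=200),
    LayerProfile("layernorm", activation_bytes=50, recompute_flops=5),
    LayerProfile("proj", activation_bytes=100, recompute_flops=100),
]
print(select_layers_to_recompute(profiles, bytes_to_free=420))
```

A greedy ranking is not guaranteed to be optimal, but it captures the objective of meeting a memory target with as little recomputation as possible.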

About the author
Prasad Kumkar
CEO & MD, Inference Systems
Prasad Kumkar is the CEO & MD of Inference Systems and writes about AI systems architecture, LLM infrastructure, model serving, evaluation, and production deployment. Over 5+ years, he has worked across computer vision models, L5 autonomous vehicle systems, and LLM research, with a focus on taking complex AI ideas into real-world engineering systems.
His work and writing cover AI systems, large language models, AI agents, multimodal systems, autonomous systems, inference optimization, RAG, evaluation, and production AI engineering.