Fully Sharded Data Parallel (FSDP) is a memory-efficient distributed training paradigm that implements Stage 3 of the Zero Redundancy Optimizer (ZeRO). Unlike traditional data-parallel training, which replicates the entire model on each GPU, FSDP shards the model parameters, gradients, and optimizer states across all devices. During the forward and backward passes, each GPU temporarily gathers the full parameters only for the layer or block it is currently computing and frees them afterward. This significantly reduces the peak memory footprint and enables the training of models much larger than a single GPU's memory capacity.
Fully Sharded Data Parallel (FSDP)

What is Fully Sharded Data Parallel (FSDP)?
Fully Sharded Data Parallel (FSDP) is a PyTorch-native distributed training optimization that shards a model's parameters, gradients, and optimizer states across all GPUs in a data-parallel group to dramatically reduce per-device memory consumption.
FSDP's efficiency stems from its sharding strategy and communication collectives. It uses all-gather operations to reconstruct full parameters for computation and reduce-scatter operations to aggregate gradients so that each GPU keeps only its own shard. This approach is orthogonal to other memory-saving techniques such as gradient checkpointing and mixed precision training, both of which can be combined with FSDP. Often paired with parameter-efficient fine-tuning methods for massive models, FSDP allows engineers to adapt large pre-trained models on limited hardware by removing redundant copies of model state across the data-parallel dimension.
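To make the two collectives concrete, here is a minimal pure-Python simulation of one training step, assuming a toy linear-regression model and a few simulated ranks held in ordinary lists. The helper names (`shard`, `all_gather`, `reduce_scatter`, `fsdp_step`, `ddp_step`) are hypothetical and not part of any PyTorch API; the point is that sharded training produces the same update as replicated data parallelism.

```python
from typing import List, Tuple

Batch = List[Tuple[List[float], float]]  # (features, target) pairs


def shard(flat: List[float], world_size: int) -> List[List[float]]:
    """Split a flat tensor into equal per-rank shards, zero-padding the tail."""
    size = -(-len(flat) // world_size)  # ceiling division
    padded = flat + [0.0] * (size * world_size - len(flat))
    return [padded[r * size:(r + 1) * size] for r in range(world_size)]


def all_gather(shards: List[List[float]], numel: int) -> List[float]:
    """Concatenate every rank's shard and drop the padding."""
    return [x for s in shards for x in s][:numel]


def reduce_scatter(full_grads: List[List[float]]) -> List[List[float]]:
    """Average per-rank gradients element-wise, then give each rank only its shard."""
    world_size = len(full_grads)
    numel = len(full_grads[0])
    avg = [sum(g[i] for g in full_grads) / world_size for i in range(numel)]
    return shard(avg, world_size)


def local_grad(w: List[float], batch: Batch) -> List[float]:
    """Gradient of mean squared error for a linear model on one rank's batch."""
    g = [0.0] * len(w)
    for x, y in batch:
        err = sum(wi * xi for wi, xi in zip(w, x)) - y
        for i, xi in enumerate(x):
            g[i] += 2.0 * err * xi / len(batch)
    return g


def fsdp_step(w_shards: List[List[float]], numel: int,
              batches: List[Batch], lr: float) -> List[List[float]]:
    # Every rank all-gathers the full weights (simulated once; the result is identical).
    full_w = all_gather(w_shards, numel)
    # Each rank computes a full gradient on its own data...
    grads = [local_grad(full_w, b) for b in batches]
    # ...then reduce-scatter leaves each rank with the averaged gradient for its shard.
    g_shards = reduce_scatter(grads)
    # Each rank updates only the shard it owns; the gathered full weights are discarded.
    return [[wi - lr * gi for wi, gi in zip(ws, gs)] for ws, gs in zip(w_shards, g_shards)]


def ddp_step(w: List[float], batches: List[Batch], lr: float) -> List[float]:
    # Replicated baseline: average (all-reduce) full gradients and update the full copy.
    grads = [local_grad(w, b) for b in batches]
    avg = [sum(g[i] for g in grads) / len(grads) for i in range(len(w))]
    return [wi - lr * gi for wi, gi in zip(w, avg)]
```

Zero-padding the flat parameter so it splits evenly is similar in spirit to how FSDP handles its flattened parameters; the padded slots stay zero because their gradients are always zero.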
Key Features and Capabilities
Fully Sharded Data Parallel (FSDP) is a PyTorch-native distributed training strategy that shards model parameters, gradients, and optimizer states across GPUs to maximize memory efficiency, enabling the training of models larger than the memory of any single device.
Zero Redundancy Optimizer (ZeRO-3) Implementation
FSDP is PyTorch's native implementation of the ZeRO Stage 3 (ZeRO-3) optimization introduced by Microsoft in its DeepSpeed framework. It eliminates memory redundancy by partitioning the three primary components of the training state:
- Model Parameters: Sharded across all data-parallel processes.
- Gradients: Each GPU computes gradients on its local batch, then a reduce-scatter leaves it holding only the reduced gradients for its own parameter shard.
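- Optimizer States: Partitioned to match the parameter shards, so each GPU keeps optimizer state (e.g., Adam moments) only for the parameters it owns.
This approach reduces the per-GPU memory needed for model states roughly in proportion to 1/N for N GPUs, enabling, on sufficiently large clusters, the training of models with trillions of parameters.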
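The per-GPU savings can be estimated with the accounting used in the ZeRO paper, assuming mixed-precision Adam: 2 bytes per parameter for 16-bit weights, 2 for 16-bit gradients, and 12 for FP32 master weights, momentum, and variance. The sketch below ignores activations, buffers, and fragmentation, and the function name is illustrative.

```python
def per_gpu_model_state_gb(num_params: float, world_size: int, zero_stage: int,
                           optimizer_bytes: int = 12) -> float:
    """Per-GPU memory (decimal GB) for params + grads + optimizer state.

    Assumes mixed-precision Adam: 2 B/param for 16-bit weights, 2 B/param for
    16-bit grads, `optimizer_bytes` B/param for FP32 master weights + moments.
    Stage 0 = plain DDP; stages 1-3 shard optimizer state, then grads, then params.
    """
    params = 2 * num_params
    grads = 2 * num_params
    optim = optimizer_bytes * num_params
    if zero_stage >= 1:
        optim /= world_size
    if zero_stage >= 2:
        grads /= world_size
    if zero_stage >= 3:  # corresponds to FSDP FULL_SHARD
        params /= world_size
    return (params + grads + optim) / 1e9


for stage in range(4):
    print(stage, per_gpu_model_state_gb(7.5e9, world_size=64, zero_stage=stage))
```

For a 7.5B-parameter model on 64 GPUs this gives 120 GB per GPU with full replication versus about 1.9 GB with all three states sharded, in line with the figures reported in the ZeRO paper.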
Memory-Efficient Activation Management
Beyond sharding parameters, FSDP limits how much unsharded state exists at any moment. During the forward pass, each GPU materializes full parameters only for the FSDP unit (typically a layer or block) currently being computed; all other parameters remain sharded. Activations, however, are not sharded: the intermediate activations from each layer are still stored for the backward pass. To save further memory, FSDP can be combined with gradient checkpointing (also known as activation checkpointing), which discards selected activations during the forward pass and recomputes them during the backward pass, trading extra compute for a significantly reduced memory footprint. This is crucial for training deep networks with large batch sizes.
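The recompute trade-off can be seen in a toy, framework-free sketch: a chain of scalar `tanh` layers trained with hand-written backpropagation, once storing every activation and once keeping only the input of every `segment`-th layer and recomputing the rest during backward. The layer, the loss (the final output itself), and the function names are assumptions made for illustration; in PyTorch this behavior is provided by `torch.utils.checkpoint`.

```python
import math
from typing import List, Tuple


def layer(x: float, w: float) -> float:
    """Toy layer y = tanh(w * x)."""
    return math.tanh(w * x)


def backprop_segment(acts: List[float], ws: List[float], start: int,
                     g: float, grads: List[float]) -> float:
    """Backprop through the layers whose input/output activations are in `acts`."""
    for j in reversed(range(len(acts) - 1)):
        x, y = acts[j], acts[j + 1]
        dpre = g * (1.0 - y * y)          # d tanh(u)/du = 1 - tanh(u)^2
        grads[start + j] = dpre * x       # dL/dw for layer start + j
        g = dpre * ws[start + j]          # dL/dx, passed to the previous layer
    return g


def grads_store_all(x0: float, ws: List[float]) -> Tuple[List[float], int]:
    acts = [x0]
    for w in ws:
        acts.append(layer(acts[-1], w))
    grads = [0.0] * len(ws)
    backprop_segment(acts, ws, 0, 1.0, grads)  # loss = final output
    return grads, len(acts)                    # every activation held at once


def grads_checkpointed(x0: float, ws: List[float], segment: int) -> Tuple[List[float], int]:
    # Forward: keep only the input of every `segment`-th layer.
    checkpoints = {}
    x = x0
    for i, w in enumerate(ws):
        if i % segment == 0:
            checkpoints[i] = x
        x = layer(x, w)
    grads = [0.0] * len(ws)
    g = 1.0
    peak = len(checkpoints)
    for start in sorted(checkpoints, reverse=True):
        end = min(start + segment, len(ws))
        acts = [checkpoints[start]]
        for i in range(start, end):           # recompute this segment's activations
            acts.append(layer(acts[-1], ws[i]))
        peak = max(peak, len(checkpoints) + len(acts))
        g = backprop_segment(acts, ws, start, g, grads)
    return grads, peak
```

With 16 layers and a segment of 4, the checkpointed version holds at most 9 activation values instead of 17, pays for it by running the forward computation roughly twice, and produces identical gradients.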
Flexible Sharding Strategies
FSDP provides configurable policies to balance memory savings with communication overhead:
- FULL_SHARD: The default and most memory-efficient. Shards parameters, gradients, and optimizer states (ZeRO-3).
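- SHARD_GRAD_OP: Shards gradients and optimizer states (similar to ZeRO-2); parameters are sharded only outside of computation, staying unsharded from the start of the forward pass until the end of the backward pass. Skipping the re-gather before backward trades higher peak memory for less communication.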
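- NO_SHARD: Fully replicates parameters, gradients, and optimizer states and synchronizes gradients with all-reduce, which is equivalent to DistributedDataParallel (DDP). Useful as a baseline for comparison.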
- HYBRID_SHARD: Applies full sharding within each node and replicates across nodes, so the frequent parameter all-gathers stay on fast intra-node links (e.g., NVLink) while only gradient synchronization crosses the slower inter-node network.
This flexibility allows engineers to tune for their specific hardware topology.
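The rank layout behind HYBRID_SHARD can be sketched in a few lines, assuming ranks are numbered node by node with a fixed number of GPUs per node. The function is a hypothetical illustration of the grouping, not PyTorch's implementation: each shard group is one node, and each replicate group links the GPUs with the same local index across nodes.

```python
from typing import List, Tuple


def hybrid_shard_groups(world_size: int,
                        gpus_per_node: int) -> Tuple[List[List[int]], List[List[int]]]:
    """Return (shard_groups, replicate_groups) for a node-major rank layout."""
    if world_size % gpus_per_node != 0:
        raise ValueError("world_size must be a multiple of gpus_per_node")
    num_nodes = world_size // gpus_per_node
    # Parameters are sharded (all-gather / reduce-scatter) among the GPUs of one node.
    shard_groups = [list(range(n * gpus_per_node, (n + 1) * gpus_per_node))
                    for n in range(num_nodes)]
    # Each shard is replicated across nodes; its gradients are all-reduced along these groups.
    replicate_groups = [list(range(local, world_size, gpus_per_node))
                        for local in range(gpus_per_node)]
    return shard_groups, replicate_groups


print(hybrid_shard_groups(8, 4))
```

With two 4-GPU nodes, each GPU stores a quarter of the model, and only the gradients for that quarter travel over the inter-node network.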
Communication and Overlap Optimizations
FSDP is designed to hide communication latency. It issues the all-gather that collects a unit's full parameters just before that unit's computation ("just-in-time" materialization) and, with prefetching enabled, can start the next unit's all-gather while the current unit is still computing, overlapping communication with computation. Similarly, during the backward pass, the reduce-scatter of each unit's gradients is overlapped with the backward computation of earlier units. These overlaps reduce the performance penalty of sharding and, on fast interconnects, allow good scaling efficiency. Separately, optional CPU offloading of parameters and gradients trades extra host-device transfer time for additional GPU memory savings.
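A back-of-the-envelope timeline model shows why prefetching matters. The sketch below assumes one compute stream and one communication stream, identical units with compute time `compute` and all-gather time `gather`, and at most one unit of prefetch; it models the forward pass only, and all names are hypothetical.

```python
def forward_time(num_units: int, compute: float, gather: float, prefetch: bool) -> float:
    """Simulated forward-pass wall time for identical FSDP units."""
    if not prefetch:
        # Gather, then compute, strictly one after the other for each unit.
        return num_units * (gather + compute)
    gather_end = 0.0      # when the most recent all-gather finished
    compute_start = 0.0   # when the most recent unit began computing
    compute_end = 0.0     # when the most recent unit finished computing
    for i in range(num_units):
        # Unit i's gather starts once the previous gather is done and, to cap memory
        # at one prefetched unit, once unit i-1 has begun computing.
        gather_start = gather_end if i == 0 else max(gather_end, compute_start)
        gather_end = gather_start + gather
        # Unit i computes once its parameters are gathered and unit i-1 is done.
        compute_start = max(gather_end, compute_end)
        compute_end = compute_start + compute
    return compute_end


print(forward_time(4, compute=10.0, gather=6.0, prefetch=False))
print(forward_time(4, compute=10.0, gather=6.0, prefetch=True))
```

When compute dominates (10 time units of compute, 6 of communication, four units), prefetching cuts the simulated time from 64 to 46 because only the first all-gather is exposed; when communication dominates, the pass becomes bandwidth-bound no matter how well it overlaps.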
Integration with Mixed Precision Training
FSDP supports mixed precision through its own `MixedPrecision` configuration and can also be combined with PyTorch's Automatic Mixed Precision (AMP) autocast to further accelerate training and save memory. It supports:
- BF16/FP16 for Computation: Using lower precision (bfloat16 or float16) for forward and backward passes.
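- Reduction Precision Control: Setting the data type used for gradient reduce-scatter (e.g., FP32 for numerical stability), while the sharded parameters that the optimizer updates stay in full precision (FP32) and act as master weights.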
- Buffer Precision Control: Setting the data type of module buffers (non-parameter tensors such as BatchNorm running statistics), which can be kept in full precision.
This combination allows FSDP to leverage the speed of tensor cores on modern GPUs while maintaining the convergence stability of full-precision optimization.
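Why the full-precision master copy matters can be shown with Python's standard `struct` module, which can round a value to IEEE half precision (format `"e"`) or single precision (`"f"`). The sketch assumes a single weight of 1.0 receiving 100 small updates of 1e-4; FP16 stands in for low precision generally (BF16, with fewer mantissa bits, loses small updates even more readily). Function names are illustrative.

```python
import struct


def to_fp16(x: float) -> float:
    """Round to the nearest IEEE 754 half-precision value."""
    return struct.unpack("e", struct.pack("e", x))[0]


def to_fp32(x: float) -> float:
    """Round to the nearest IEEE 754 single-precision value."""
    return struct.unpack("f", struct.pack("f", x))[0]


def apply_updates(steps: int, update: float, fp32_master: bool) -> float:
    """Apply `steps` identical updates to a weight that starts at 1.0."""
    to_storage = to_fp32 if fp32_master else to_fp16
    weight = to_storage(1.0)
    for _ in range(steps):
        # With FP32 master weights, the 16-bit compute copy would be re-cast from
        # this value each step; without them, each update is rounded into FP16 directly.
        weight = to_storage(weight + update)
    return weight


print(apply_updates(100, 1e-4, fp32_master=False))
print(apply_updates(100, 1e-4, fp32_master=True))
```

The FP16-only weight never moves, because 1e-4 is smaller than half the FP16 spacing near 1.0 (about 4.9e-4), whereas the FP32 master weight ends close to 1.01.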
Comparison to Alternative Strategies
FSDP sits within a spectrum of distributed training paradigms:
- vs. Distributed Data Parallel (DDP): DDP replicates the entire model on each GPU, synchronizing gradients. It has lower communication overhead but cannot train models larger than a single GPU's memory.
- vs. Pipeline Parallelism: Pipeline parallelism splits the model by layers across devices, which can leave GPUs idle while the pipeline fills and drains (the pipeline "bubble"). FSDP is orthogonal and can be combined with it, together with tensor parallelism, in so-called 3D parallelism.
- vs. Tensor Parallelism: Tensor parallelism splits individual layer operations (e.g., matrix multiplications) across devices and communicates inside every layer, so it requires very high bandwidth, typically within a single node. FSDP is often preferred for its simpler implementation and good scaling on commodity clusters.
FSDP's primary advantage is its transparency: it requires minimal code changes from a standard training loop, unlike more complex parallel strategies.
FSDP vs. Other Distributed Training Strategies
This table compares the core mechanisms and tradeoffs of Fully Sharded Data Parallel (FSDP) against other common distributed training paradigms, focusing on memory efficiency, communication overhead, and implementation complexity.
| Feature / Mechanism | Fully Sharded Data Parallel (FSDP) | Data Parallel (DP / DDP) | Model Parallel (MP) | Pipeline Parallel (PP) |
|---|---|---|---|---|
| Core Partitioning Unit | Model states (params, gradients, optimizer states), plus data batches | Data batches | Model layers or weights | Model layers (sequential stages) |
| Memory Redundancy | None for persistent model states (full parameters exist only transiently, one unit at a time) | Full model, gradients, and optimizer states on each GPU | Partial model per GPU, but full gradients/optimizer states for local params | One model stage per GPU, with full optimizer states for local params |
| Peak GPU Memory per Device | ~1/N of full model states + full parameters of the largest unit + activations | Full model states + activations for the local batch (1/N of the global batch) | ~1/N of model weights + full activations for local layers | ~1/N of model weights + activations for in-flight micro-batches of the local stage |
| Inter-GPU Communication Volume | High (all-gather params in forward and backward, reduce-scatter grads; ~1.5× DDP) | Moderate (all-reduce gradients only) | Very High (forward/backward pass activations & gradients) | Low to moderate (forward activations, backward gradients between stages only) |
| Communication Synchronization | Synchronous per step; all-gather/reduce-scatter overlapped with compute via prefetching | Synchronous all-reduce, overlapped with the backward pass via gradient bucketing | Synchronous (blocking send/recv per layer) | Synchronous with pipeline bubbles (non-blocking between micro-batches) |
| Implementation Complexity | Moderate (PyTorch native, but wrapping policy and sharding config need tuning) | Low (PyTorch DDP wraps the model with minimal code changes) | Very High (Manual model splitting, custom communication) | High (Requires pipeline scheduling, chunking, balance) |
| Optimal Use Case | Training extremely large models that don't fit on a single GPU | Training models that fit on a single GPU with large batches | Training models with layers too large for one GPU (e.g., massive FFN) | Training models with a long, sequential layer stack (e.g., deep transformers) |
| Activation Memory Optimization | Not built in; commonly combined with activation checkpointing per FSDP-wrapped block | Not built in; activations stored for the full local forward pass unless checkpointing is added | Yes (activations only for local layers) | Partial (activations stored per micro-batch in pipeline) |
Framework Integration and Usage
Core Sharding Strategy
FSDP implements ZeRO Stage 3 optimization by partitioning three key components across all data-parallel workers:
- Model Parameters: Each GPU stores only a shard of the full model's parameters.
- Gradients: Gradients are reduce-scattered so that each GPU keeps only the reduced gradients for the shard of parameters it owns.
- Optimizer States: Each GPU maintains optimizer states (e.g., momentum) only for its parameter shard.
During the forward pass, each unit's full parameters are gathered via all-gather just before they are needed and freed right after the unit finishes; they are gathered again for that unit's backward computation and freed once its gradients have been computed. This approach eliminates persistent memory redundancy, allowing the trainable model size to grow roughly linearly with the number of GPUs.
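The per-unit lifecycle above can be written out as an event schedule for one rank. This is a simplified sketch, assuming FULL_SHARD, no prefetching, and counting only parameter elements (gradients and activations are ignored); the unit names and sizes are made up for illustration.

```python
from typing import Dict, List, Tuple


def full_shard_schedule(unit_sizes: Dict[str, int],
                        world_size: int) -> Tuple[List[str], float]:
    """Event order for one rank and peak resident parameter elements (simplified)."""
    resident = sum(unit_sizes.values()) / world_size  # persistent local shards
    peak = resident
    events: List[str] = []

    def run(unit: str, phase: str) -> None:
        nonlocal resident, peak
        events.append(f"all_gather {unit}")
        resident += unit_sizes[unit]                 # full parameters materialized
        peak = max(peak, resident)
        events.append(f"{phase} {unit}")
        if phase == "backward":
            events.append(f"reduce_scatter {unit}")  # keep only this rank's grad shard
        events.append(f"free {unit}")
        resident -= unit_sizes[unit]                 # back to shards only

    for unit in unit_sizes:                          # forward: model order
        run(unit, "forward")
    for unit in reversed(list(unit_sizes)):          # backward: reverse order
        run(unit, "backward")
    return events, peak


events, peak = full_shard_schedule({"embed": 100, "block0": 400, "head": 100}, world_size=4)
print(peak)
```

Peak parameter memory is the persistent shards (600 / 4 = 150 elements) plus the largest unit (400), i.e. 550 instead of 600 for a replicated copy; with many similar-sized blocks, the unit term becomes a small fraction of the total, which is why wrapping granularity matters.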
PyTorch API Integration
FSDP is integrated directly into PyTorch via `torch.distributed.fsdp`. Wrapping a model with `FullyShardedDataParallel` is the primary API call:
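```python
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Assumes torch.distributed is already initialized (e.g., via torchrun) and `model` is an nn.Module
model = FSDP(model)
```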
Key configuration options include:
- Sharding Strategy: `FULL_SHARD` (default), `SHARD_GRAD_OP`, `NO_SHARD`, and `HYBRID_SHARD`.
- Auto Wrapping: Recursively wraps sub-modules selected by an auto-wrap policy to create nested FSDP instances, so parameters are gathered and freed one unit at a time, allowing finer-grained communication and memory control.
- Mixed Precision: Configurable via `MixedPrecision` policies that set the data types for parameters, gradient reduction, and buffers.
- CPU Offload: `CPUOffload(offload_params=True)` keeps parameters and gradients in CPU RAM (so the optimizer step and its state also live on the CPU), trading host-device transfer time for even greater GPU memory savings.
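Size-based auto wrapping can be illustrated on a hypothetical module tree. The sketch below is a pure-Python approximation of the idea behind PyTorch's `size_based_auto_wrap_policy` (wrap a submodule once the parameters not already covered by wrapped descendants reach `min_num_params`); the `Module` class, the tree, and the threshold are illustrative assumptions, not PyTorch's code.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class Module:
    """Hypothetical stand-in for an nn.Module: its own parameter count plus children."""
    name: str
    own_params: int = 0
    children: List["Module"] = field(default_factory=list)


def auto_wrap(module: Module, min_num_params: int, wrapped: List[str]) -> int:
    """Post-order traversal; returns params of `module` not covered by a wrapped unit."""
    unwrapped = module.own_params
    for child in module.children:
        unwrapped += auto_wrap(child, min_num_params, wrapped)
    if unwrapped >= min_num_params:
        wrapped.append(module.name)  # becomes its own FSDP unit
        return 0
    return unwrapped


def wrap_model(root: Module, min_num_params: int) -> List[str]:
    wrapped: List[str] = []
    auto_wrap(root, min_num_params, wrapped)
    if root.name not in wrapped:
        wrapped.append(root.name)  # the outermost FSDP(...) call always wraps the root
    return wrapped


blocks = [Module(f"block{i}", children=[Module(f"block{i}.attn", 2000),
                                        Module(f"block{i}.mlp", 3000)])
          for i in range(4)]
model = Module("model", children=[Module("embed", 1000), *blocks, Module("head", 1000)])
print(wrap_model(model, min_num_params=4000))
```

Each block becomes its own unit whose parameters are gathered and freed independently, while the small embedding and head parameters remain in the root unit.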
Communication and Computation Overlap
A key performance optimization in FSDP is overlapping communication (all-gather) with computation through prefetching:
- Parameters for the next layer are gathered while the current layer's computation is ongoing.
- This hides a significant portion of the communication latency inherent in fetching remote parameter shards.
Similarly, after the backward pass through a layer, the reduce-scatter of its gradients can be overlapped with the backward computation of the preceding layer. Effective overlap depends on wrapping granularity (each wrapped unit's parameters are flattened into a single buffer, so unit size sets the size of each collective) and on running collectives on a separate CUDA stream with non-blocking semantics so the GPU compute stream is not left idle.
Memory vs. Communication Trade-off
FSDP introduces a fundamental trade-off: reduced GPU memory footprint at the cost of increased inter-GPU communication. The primary communication costs are:
- All-gather in the forward pass to collect parameters.
- Reduce-scatter in the backward pass to aggregate gradients (plus, under FULL_SHARD, a second all-gather of parameters for the backward computation).
The total communication volume is proportional to the model size; in the ZeRO analysis, full sharding costs about 1.5× the gradient all-reduce traffic of DDP. Therefore, FSDP is most beneficial when the model is too large to fit in memory, making the communication overhead acceptable. Performance is highly dependent on interconnect bandwidth (e.g., NVLink, InfiniBand). For smaller models that fit in memory, traditional Distributed Data Parallel (DDP) may be faster due to lower communication overhead.
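The 1.5× figure follows from standard ring-collective costs: for a buffer of S bytes on N ranks, each rank sends about S·(N - 1)/N bytes for an all-gather or a reduce-scatter, and twice that for an all-reduce. The sketch assumes ring algorithms, no compression, and parameters and gradients of equal size; function names are illustrative.

```python
def ring_bytes_sent(size_bytes: float, world_size: int, collective: str) -> float:
    """Bytes each rank sends for one ring collective over a buffer of `size_bytes`."""
    fraction = (world_size - 1) / world_size
    if collective in ("all_gather", "reduce_scatter"):
        return size_bytes * fraction
    if collective == "all_reduce":  # = reduce-scatter followed by all-gather
        return 2 * size_bytes * fraction
    raise ValueError(f"unknown collective: {collective}")


def per_step_bytes(model_bytes: float, world_size: int) -> dict:
    ddp = ring_bytes_sent(model_bytes, world_size, "all_reduce")
    fsdp = (ring_bytes_sent(model_bytes, world_size, "all_gather")        # forward
            + ring_bytes_sent(model_bytes, world_size, "all_gather")      # backward re-gather
            + ring_bytes_sent(model_bytes, world_size, "reduce_scatter"))  # gradients
    return {"ddp": ddp, "fsdp_full_shard": fsdp, "ratio": fsdp / ddp}


print(per_step_bytes(2e9, world_size=8))
```

For a model whose 16-bit parameters occupy 2 GB on 8 GPUs, each rank sends about 3.5 GB per step under DDP and about 5.25 GB under FULL_SHARD.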
Comparison with DDP and ZeRO
FSDP vs. Distributed Data Parallel (DDP):
- DDP: Replicates the full model on each GPU. Only gradients are communicated and averaged via all-reduce. Higher memory usage, lower communication.
- FSDP: Shards model states. Uses all-gather/reduce-scatter. Lower memory usage, higher communication.
FSDP vs. DeepSpeed's ZeRO-3:
Both implement the same core ZeRO-3 ideas. FSDP is PyTorch-native, offering tighter integration with PyTorch APIs, modules, and the torch.compile ecosystem. DeepSpeed ZeRO-3 is part of a broader suite of optimizations and may offer more granular configuration for extreme-scale training. The choice often depends on the existing codebase and framework preference.
Use Case: Fine-Tuning Large Language Models
FSDP is widely used to fine-tune massive models where even a single copy of the model exceeds GPU memory, including in combination with parameter-efficient fine-tuning (PEFT) methods. Common patterns include:
- Combining FSDP with LoRA or Adapter modules, where the frozen base model is sharded and only the small adapter weights receive gradients and optimizer state.
- Using CPU offload to fine-tune multi-billion parameter models on a limited number of GPUs by leveraging system RAM.
- Applying selective activation checkpointing (gradient checkpointing) within FSDP-wrapped modules to trade extra recomputation for even greater memory savings, enabling larger batch sizes or sequence lengths.
This makes FSDP essential for adapting state-of-the-art models to specific domains without access to massive GPU clusters.
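Adapter fine-tuning pairs well with sharding because the trainable state is tiny. For a frozen d × k weight matrix, a rank-r LoRA update adds factors B (d × r) and A (r × k), i.e. r(d + k) trainable parameters. The sketch below assumes a 4096 × 4096 projection and r = 8; the function name is illustrative.

```python
def lora_trainable_fraction(d: int, k: int, rank: int) -> tuple:
    """Trainable LoRA params vs. frozen params for one d x k weight matrix."""
    frozen = d * k
    trainable = rank * (d + k)   # B: d x r plus A: r x k
    return trainable, frozen, trainable / frozen


trainable, frozen, frac = lora_trainable_fraction(4096, 4096, rank=8)
print(f"{trainable:,} trainable vs {frozen:,} frozen ({frac:.2%})")
```

Because optimizer state is kept only for the trainable parameters, the Adam moments for this matrix shrink to about 0.39% of a full fine-tune, leaving the sharded frozen weights as the dominant memory cost.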
Frequently Asked Questions
Fully Sharded Data Parallel (FSDP) is a PyTorch-native distributed training strategy that implements the ZeRO-3 optimization stage introduced by Microsoft's DeepSpeed. It works by sharding a model's parameters, gradients, and optimizer states across all GPUs in a data-parallel group. Between computations, each GPU holds only the parameter shard it owns; just before a layer (FSDP unit) runs, its full parameters are reconstructed via all-gather and then freed. After the backward pass computes gradients on each GPU's local batch, the gradients are reduce-scattered so each GPU receives the reduced gradients for, and updates, only its own parameter shard. This eliminates memory redundancy, allowing the training of models significantly larger than the memory of a single GPU.
Key Mechanism:
- Parameter Sharding: Model parameters are partitioned across devices.
- On-Demand Materialization: Full parameters are assembled via `all_gather` only for the layers currently being computed.
- Gradient Reduction: Gradients are reduced via `reduce_scatter` to the GPU owning the corresponding parameter shard.
- Optimizer State Sharding: Each GPU stores and updates optimizer states (e.g., momentum) only for its parameter shard.
Related Terms
FSDP is a key technique within a broader ecosystem of methods designed to train massive models efficiently. These related concepts address the core challenges of distributed computation, memory management, and parameter efficiency.
Distributed Data Parallel (DDP)
Distributed Data Parallel is PyTorch's standard synchronous data-parallel training method. Each GPU holds a full replica of the model, optimizer, and gradients. In each iteration:
- Each process runs the forward and backward pass on its own slice of the global mini-batch.
- Gradients are averaged across all processes via an all-reduce operation.
- Identical weight updates are applied, keeping the replicas in sync.
FSDP uses the same torch.distributed collective backends as DDP but shards the model parameters, whereas DDP replicates them. This makes DDP memory-inefficient for very large models, as GPU memory must hold the entire model along with its gradients and optimizer states.
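The relationship between DDP's all-reduce and FSDP's collectives shows up in a simulated ring all-reduce, which is a reduce-scatter phase followed by an all-gather phase. This is a minimal pure-Python sketch, assuming each rank's vector has exactly one element per rank (one chunk each) and integer values; it is not NCCL's actual implementation.

```python
from typing import List, Tuple


def ring_all_reduce(data: List[List[int]]) -> Tuple[List[List[int]], int]:
    """Sum vectors across ranks; returns (per-rank results, chunks sent per rank)."""
    n = len(data)
    assert all(len(v) == n for v in data), "one chunk per rank in this sketch"
    buf = [list(v) for v in data]
    sent_per_rank = 0
    # Phase 1, reduce-scatter: each step, rank r passes one partial sum to rank r+1.
    for step in range(n - 1):
        msgs = [((r + 1) % n, (r - step) % n, buf[r][(r - step) % n]) for r in range(n)]
        for dst, idx, val in msgs:
            buf[dst][idx] += val
        sent_per_rank += 1
    # Rank r now holds the complete sum for chunk (r + 1) % n, like an FSDP gradient shard.
    # Phase 2, all-gather: circulate the finished chunks so every rank has all of them.
    for step in range(n - 1):
        msgs = [((r + 1) % n, (r + 1 - step) % n, buf[r][(r + 1 - step) % n])
                for r in range(n)]
        for dst, idx, val in msgs:
            buf[dst][idx] = val
        sent_per_rank += 1
    return buf, sent_per_rank


result, sent = ring_all_reduce([[1, 2, 3, 4], [10, 20, 30, 40],
                                [100, 200, 300, 400], [1000, 2000, 3000, 4000]])
print(result[0], sent)
```

Each rank sends 2(N - 1) chunks in total; FSDP's gradient reduce-scatter stops after the first phase, which is why it moves half the data of an all-reduce.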
Mixed Precision Training
Mixed precision training uses lower numerical precision (e.g., torch.float16 or torch.bfloat16) for most operations to speed up computation and reduce memory usage, while maintaining higher precision (torch.float32) for critical operations like weight updates to preserve numerical stability. FSDP supports this through its own `MixedPrecision` policy and is also compatible with PyTorch's AMP (Automatic Mixed Precision). The memory savings from mixed precision (smaller parameters, gradients, and activations) compound with the savings from FSDP's sharding, enabling the training of even larger models or the use of larger batch sizes.
Model Parallelism
Model parallelism is a broader class of techniques for partitioning a model across multiple devices. Unlike data parallelism (where each device has a full model copy), model parallelism splits the model's layers or components. Key variants include:
- Tensor Parallelism: Splits individual layer operations (e.g., matrix multiplications) across devices.
- Pipeline Parallelism: Places different groups of model layers on different devices.
FSDP is a form of data parallelism enhanced with model-state sharding. It can be combined with pipeline or tensor parallelism in a 3D parallelism strategy for extreme-scale training.
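Tensor parallelism's core trick, splitting one matrix multiplication across devices, can be checked with plain lists. This sketch assumes a column-wise split of the weight matrix over hypothetical devices (each computes a slice of the output columns, followed by a gather); real implementations such as Megatron-LM's also use row-wise splits followed by an all-reduce.

```python
from typing import List

Matrix = List[List[float]]


def matmul(x: Matrix, w: Matrix) -> Matrix:
    """Plain (rows x k) @ (k x n) matrix product."""
    return [[sum(row[t] * w[t][j] for t in range(len(w))) for j in range(len(w[0]))]
            for row in x]


def column_parallel_matmul(x: Matrix, w: Matrix, num_devices: int) -> Matrix:
    """Each simulated device holds a slice of W's columns and computes its output slice."""
    cols = len(w[0])
    if cols % num_devices != 0:
        raise ValueError("columns must divide evenly across devices")
    size = cols // num_devices
    partial_outputs = []
    for d in range(num_devices):
        w_slice = [row[d * size:(d + 1) * size] for row in w]  # device d's weight shard
        partial_outputs.append(matmul(x, w_slice))             # needs the full input x
    # Gather: concatenate each output row's column slices in device order.
    return [sum((p[i] for p in partial_outputs), []) for i in range(len(x))]


x = [[1.0, 2.0], [3.0, 4.0]]
w = [[1.0, 0.0, 2.0, 1.0], [0.0, 1.0, 1.0, 3.0]]
print(column_parallel_matmul(x, w, num_devices=2) == matmul(x, w))
```

Unlike FSDP, which gathers full weights and lets each device run the whole layer on its own data, tensor parallelism keeps the weights split during the computation and instead exchanges layer outputs in every layer.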
Parameter-Efficient Fine-Tuning (PEFT)
Parameter-Efficient Fine-Tuning methods, such as LoRA or Adapter Layers, adapt a pre-trained model to a new task by updating only a tiny fraction of its parameters. While FSDP addresses the memory cost of training, PEFT addresses the storage and compute cost of adaptation. They typically target different lifecycle stages: FSDP is for large-scale pre-training or full fine-tuning, where all parameters are updated but sharded. PEFT is for downstream task adaptation, where most parameters are frozen. However, PEFT methods can themselves be trained using FSDP when the base model is enormous.

About the author
Prasad Kumkar
CEO & MD, Inference Systems
Prasad Kumkar is the CEO & MD of Inference Systems and writes about AI systems architecture, LLM infrastructure, model serving, evaluation, and production deployment. Over 5+ years, he has worked across computer vision models, L5 autonomous vehicle systems, and LLM research, with a focus on taking complex AI ideas into real-world engineering systems.
His work and writing cover AI systems, large language models, AI agents, multimodal systems, autonomous systems, inference optimization, RAG, evaluation, and production AI engineering.