
[Megatron-FSDP] Add dtype customization to Megatron-FSDP.#3067

Merged
cspades merged 19 commits into NVIDIA:main from cspades:cye/mfsdp-custom-dtype
Mar 3, 2026

Conversation

@cspades
Member

@cspades cspades commented Jan 24, 2026

What does this PR do ?

  • Add customization options for main_params_dtype, main_grads_dtype, and grad_comm_dtype to Megatron-FSDP.
    • grad_accum_dtype (high-precision gradient reduce-scatter / all-reduce) will be handled on NVLink by NCCL UBR / SymMem for v2.27+. IB domain reduce-scatter requires NCCL 2.29U1+, while all-reduce is not currently supported for IB.
  • Various bug-fixes, refactors, documentation updates, commentary, etc.

All planned performance benchmarks completed with no bugs. Ready for expert & final review!

Details

Mixed-Precision Support (megatron_fsdp.MixedPrecisionPolicy)

  • main_params_dtype / --megatron-fsdp-main-params-dtype (🍀 NEW! 🍀) and main_grads_dtype / --megatron-fsdp-main-grads-dtype (🍀 NEW! 🍀) are simple generalizations of preserve_fp32_weights (⛔ DEPRECATED ⛔) and grad_reduce_in_fp32.
    • If not specified, the model weight buffer (or, failing that, the model parameters) becomes the main weights, and the main gradient buffer's data-type will match the model compute weight data-type.
  • grad_comm_dtype / --megatron-fsdp-grad-comm-dtype (🍀 NEW! 🍀) controls the data-type used for gradient communication (all-reduce & reduce-scatter).
    • If main_grads_dtype is not equivalent to grad_comm_dtype, a communication bucket with the communication data-type will be allocated. Otherwise, and if not specified, the main_grads_dtype will be the communication data-type.
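The fallback rules above can be sketched as a small helper (the helper name and signature are hypothetical illustrations, not the Megatron-FSDP API):

```python
import torch

def resolve_grad_dtypes(params_dtype, main_grads_dtype=None, grad_comm_dtype=None):
    """Hypothetical sketch of the dtype fallback rules described above."""
    # Main gradients default to the model compute (parameter) dtype.
    main_grads = main_grads_dtype if main_grads_dtype is not None else params_dtype
    # Communication defaults to the main gradient dtype; a dedicated
    # communication bucket is only needed when the two dtypes differ.
    comm = grad_comm_dtype if grad_comm_dtype is not None else main_grads
    return main_grads, comm, comm != main_grads
```

For example, `resolve_grad_dtypes(torch.bfloat16, torch.float32, torch.bfloat16)` yields FP32 main gradients with a separate BF16 communication bucket.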

Megatron-FSDP Gradient Lifecycle

To summarize the gradient pipeline of Megatron-FSDP for the uninitiated:

# Gradient Memory Lifecycle
# (...) = HFSDP pre-optimizer steps only!
param.grad -> DP-Shard Grad Alloc -> Reduce -> (Wait -> Free -> DP-Outer Grad Alloc -> Reduce) -> Wait -> Free
              ^ These are communication buffers to hold un-sharded or partially-sharded gradients.
  • Autograd produces the raw, un-reduced model gradient.
  • Megatron-FSDP's gradient buffer allocates a communication bucket to temporarily hold the un-reduced gradient.
    • 🍀 On the main branch, this communication bucket matches the main gradient buffer data-type, so we cannot have low-precision communication buckets with high-precision main gradients.
    • This is now controllable with grad_comm_dtype, which supports low-precision communication with high-precision reduction via NCCL (v2.27+).
  • The raw un-reduced gradient is either copied (sharded gradients) or accumulated (un-sharded gradients) into the allocated gradient bucket or main gradient buffer.
  • The gradient communication bucket is retrieved, and passed to the reduce-scatter or all-reduce collective. Accumulation is performed via type-promotion with respect to the main_grads_dtype, typically FP32.
    • For no_shard and optim, this is a local all-reduce or reduce-scatter that can only be called once per optimization cycle to avoid corrupted gradients.
      • This is why we do not immediately allocate a communication buffer for these two cases (and instead require an allocation per-unit before the collective when a custom communication data-type is used): the no_shard and optim sharding strategies definitively cannot afford a second un-sharded memory allocation just to maintain both a communication buffer and an accumulation buffer for the gradient (one for BF16 communication, another for FP32 accumulation) until the single DP-reduction right before the optimization step. Instead, we temporarily allocate / deallocate a BF16 communication buffer right before gradient reduction, while persistently allocating an FP32 main gradient bucket.
    • For optim_grads and optim_grads_params, this is a reduce-scatter into the allocated communication bucket, and shards of the result are accumulated into the main gradient buffer. Because we reduce every layer of every step, we only persistently hold onto a reduced and accumulated shard of the gradient.
  • All previous gradient collectives are synchronized if overlapped, gradient shards are attached to the parameter shards of the optimizer state, the model parameters are re-referenced to the distributed optimizer state parameters, and the distributed optimizer step is performed.
  • Finally, the main gradient buffer is zero'd out, which establishes a clean slate for subsequent reduction and accumulation, and the optimized main weights are installed into Megatron-FSDP's compute weight buffer.
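The optim_grads / optim_grads_params path above can be mimicked in a single-process toy (no NCCL; the "ranks" are just tensors, and all names here are illustrative, not Megatron-FSDP code):

```python
import torch

# Toy stand-in for the lifecycle above: two "ranks" each hold an
# un-reduced BF16 gradient; we reduce in the main_grads_dtype (FP32),
# keep only this rank's shard, accumulate it into the persistent FP32
# main gradient buffer, then zero the buffer after the optimizer step.
world_size = 2
grads = [torch.full((8,), 0.5, dtype=torch.bfloat16) for _ in range(world_size)]

# Reduce (type-promoted to FP32) and scatter: rank 0 keeps shard 0.
reduced = torch.stack([g.to(torch.float32) for g in grads]).sum(dim=0)
shard = reduced.chunk(world_size)[0]

main_grad_shard = torch.zeros(4)   # persistent FP32 main gradient buffer
main_grad_shard.add_(shard)        # accumulate the reduced shard
# ... optimizer step consumes main_grad_shard ...
main_grad_shard.zero_()            # clean slate for the next iteration
```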

🚨 Bug Fixes 🚨

  • optim had corrupted gradients, where the main gradient would be reduce-scattered into a temporary shard, but the reduced shard would be accumulated back into the source main gradient shard (without zero'ing the buffer), leading to duplicate gradients.
    • Fixed by adding copy and += cases to the DP-Shard gradient reduction.
    • For example, with DP-Shard=2, and only 1 accumulation / optimization step for simplicity with (...) representing the reduced gradient and gN representing the pre-reduce accumulated gradient:
      • Rank 0 Expected: (g1 + g2)
      • Rank 0 Actual: g1 + (g1 + g2)
      • Without the torch.empty_like temporary shard, the bug would have doubled the gradient when using optim, i.e. (g1 + g2) += (g1 + g2)!
      • Causes main gradient disparity on all DP ranks.
    • With custom data-type buckets, the generalized logic also works correctly, where for optim we copy the reduced gradient shard into the main gradient buffer if a communication buffer was allocated, otherwise the reduce-scatter directly updates the shard of the main gradient buffer. (Same for no_shard as well, but using all-reduce and copying the reduced un-sharded gradient.)
  • Also discovered a broken gradient DP-scaling while working on this PR: [Megatron-FSDP] Fix incorrect gradient scaling target. #3023
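The optim double-accumulation described above can be reproduced scalar-style (toy values, no collectives; variable names are illustrative):

```python
import torch

# Toy reproduction of the `optim` bug on "rank 0" with DP-Shard=2:
# the reduce-scatter produces (g1 + g2), but the old '+=' path added it
# on top of the un-zeroed local main gradient instead of replacing it.
g1 = torch.tensor([1.0])   # rank 0's accumulated gradient
g2 = torch.tensor([2.0])   # rank 1's accumulated gradient

reduced = g1 + g2          # what the reduce-scatter yields
buggy = g1 + reduced       # old behavior: g1 + (g1 + g2) = 4.0
fixed = reduced.clone()    # fix: copy the reduced shard in = 3.0
```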

Minor Edits

  • Refactored free_bucket_storage() to remove the criteria that only deallocates buckets for sharded buffers and factor out the param.main_grad reset to reset_param_main_grad().
    • fetch_bucket() will only allocate temporary buckets if the data-type is different, or if the buffer is sharded. So there is a loophole where a custom data-type allocation will not be deallocated if the buffer is sharded.
    • Modules that are not FSDP units should not have their buckets deallocated, but this is controlled by the post-forward and post-backward un-shard hooks that call AllGatherPipeline.release_bucket().
    • reset_param_main_grad() only needs to be called when the FSDP gradient buffer on DP-Shard has completed its collectives and installed the reduced gradient in local data.
      • param.main_grad will first point to the unreduced gradient bucket, and then point to the DP-Shard reduced main gradient buffer data (or a custom data-type variant of the aforementioned values).
  • Implemented check_for_nan_in_grad for Megatron-LM (called in start_grad_sync) and report_nan_in_param_grad for fully_shard, which both default to False in MegatronFSDP. report_nan_in_param_grad in particular is an expensive operation that can degrade performance by around 5%, but can be extremely useful for quickly debugging the source of NaNs, whether they come from Megatron-FSDP or user models.
  • Updated and removed all un-used config options in Megatron-FSDP's version of DDPConfig.
    • @BoxiangW and I discussed this during our meetings, but this is just the first baby step toward perhaps completely annihilating the second variant of the DDPConfig so it doesn't confuse users, and flattening all necessary arguments into Megatron-FSDP directly.
  • Updated documentation on DDP, where all-reduce should not be called repetitively during gradient accumulation, and added warning messages for the user to zero the gradient buffer every step when this kind of behavior is detected.
    • We have a variety of options to handle this such as no_sync in Megatron-LM and an even simpler sync() / MegatronFSDP.set_model_auto_sync() for Megatron-external use (the opposite of no_sync that basically calls all the necessary functions to make Megatron-FSDP low-code in a vanilla training loop).
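The kind of scan report_nan_in_param_grad performs can be sketched as follows (helper name hypothetical; the real check lives inside Megatron-FSDP's gradient buffer):

```python
import torch

def find_nonfinite_grads(named_parameters):
    """Hypothetical sketch: return names of parameters whose gradients
    contain NaN/Inf. Walking every gradient like this is what makes
    the check expensive on large models."""
    return [
        name
        for name, p in named_parameters
        if p.grad is not None and not torch.isfinite(p.grad).all()
    ]

# Usage sketch: plant a NaN gradient and locate it by name.
model = torch.nn.Linear(4, 4)
model.weight.grad = torch.full_like(model.weight, float("nan"))
model.bias.grad = torch.zeros_like(model.bias)
bad = find_nonfinite_grads(model.named_parameters())
```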

Tests

  • Added unit tests that cover all relevant sharding and mixed-precision strategies, as well as changing the gradient communication data-type mid-flight (which is not allowed for NCCL UBR).

All performance tests below use the following configuration (unless otherwise specified):

  • Llama 8B
    • The larger the model, the more performance improvements we can observe using low-precision communication, because the volume of communications will increase.
  • FP32 Main Parameters and Main Gradients
  • FP32 Gradient Accumulation and Reduction
  • FP8 Delayed Scaling + Parameter AG
  • optim_grads_params
    • For HFSDP, --outer-dp-sharding-strategy optim and --num-distributed-optimizer-instances 2.
  • Full Activation Recompute
  • GBS128 / MBS1
  • NCCL User Buffer Registration / FSDP Double Buffer
    • --use-nccl-ub --fsdp-manual-registration --fsdp-double-buffer for NCCL UB perf experiments.

Performance & Accuracy Parity with FP32 Gradient Communication + Accumulation (Reduce-Scatter)

  • With both communication and accumulation set to FP32 (still the default in Megatron-LM), the performance and accuracy are identical with main branch.
# Main Branch (FP32 Gradient Reduce-Scatter)
[2026-01-29 09:54:18.103917] iteration      100/15258789 | consumed samples:        12800 | elapsed time per iteration (ms): 10925.4 | throughput per GPU (TFLOP/s/GPU): 617.5 | learning rate: 4.915198E-07 | global batch size:   128 | lm loss: 1.154529E+00 | loss scale: 1.0 | grad norm: 5.414 | num zeros: 0 | number of skipped iterations:   0 | number of nan iterations:   0 |

# Mixed-Precision (FP32 Gradient Reduce-Scatter)
[Rank 0] (after 2 iterations) memory (MB) | allocated: 18516.73 | max allocated: 60039.38 | reserved: 21228.00 | max reserved: 66490.00
[2026-02-27 18:13:37.705927] iteration      100/15258789 | consumed samples:        12800 | elapsed time per iteration (ms): 10893.1 | throughput per GPU (TFLOP/s/GPU): 619.4 | learning rate: 4.915198E-07 | global batch size:   128 | lm loss: 1.154007E+00 | loss scale: 1.0 | grad norm: 5.409 | num zeros: 0 | number of skipped iterations:   0 | number of nan iterations:   0 |

Mixed-Precision BF16 Gradient Communication + FP32 Gradient Reduction / Accumulation

  • Setting --megatron-fsdp-grad-comm-dtype bf16 enables BF16 communication and FP32 reduction / accumulation if NCCL 2.27+ is used with NCCL UBR for pure FSDP.
    • Compared to FP32 communications, we see a ~5% speedup even for a model as small as Llama 8B.
    • Loss is equivalent to the FP32 communications case, which implies FP32 reduction over NCCL.
# Mixed Precision (BF16 Comm / FP32 Reduce-Accum via NCCL UBR SymMem)
[Rank 0] (after 2 iterations) memory (MB) | allocated: 19765.60 | max allocated: 60646.15 | reserved: 22424.00 | max reserved: 66844.00
[2026-02-27 18:35:04.045681] iteration      100/15258789 | consumed samples:        12800 | elapsed time per iteration (ms): 10424.1 | throughput per GPU (TFLOP/s/GPU): 647.2 | learning rate: 4.915198E-07 | global batch size:   128 | lm loss: 1.154608E+00 | loss scale: 1.0 | grad norm: 5.418 | num zeros: 0 | number of skipped iterations:   0 | number of nan iterations:   0 |
  • The performance improvement is even more apparent with a large model and smaller compute, such as with Llama 70B @ 1K SeqLen, where we get loss parity and a 22.5% speedup in communications compared to main branch!
# Llama 70B Main Branch (FP32 Gradient Reduce-Scatter + NCCL UB)
[2026-02-16 21:53:58.402418] iteration      100/15258789 | consumed samples:        12800 | elapsed time per iteration (ms): 3300.5 | throughput per GPU (TFLOP/s/GPU): 261.3 | learning rate: 4.915198E-07 | global batch size:   128 | lm loss: 2.233609E+00 | loss scale: 1.0 | grad norm: 26.389 | num zeros: 0 | number of skipped iterations:   0 | number of nan iterations:   0 |

# Llama 70B Mixed Precision (BF16 Comm / FP32 Reduce-Accum via NCCL UBR SymMem)
[2026-02-27 23:33:50.230377] iteration      100/15258789 | consumed samples:        12800 | elapsed time per iteration (ms): 2559.1 | throughput per GPU (TFLOP/s/GPU): 337.0 | learning rate: 4.915198E-07 | global batch size:   128 | lm loss: 2.234128E+00 | loss scale: 1.0 | grad norm: 26.477 | num zeros: 0 | number of skipped iterations:   0 | number of nan iterations:   0 |

HFSDP Performance & Accuracy Tests (BF16 Gradient Communication + FP32 Gradient Reduction / Accumulation)

  • --num-distributed-optimizer-instances 2 and --outer-dp-sharding-strategy optim has parity on loss after 100 steps, and is just shy of 4x faster (3.62x) per global batch from 4 Nodes on Llama 8B, compared with FSDP.
    • Slight accuracy disparity only when using HFSDP with BF16 gradient communication.
# HFSDP 4-Node (DP-Outer=2, DP-Shard=16) + BF16 Comm + FP32 Reduce / Accum + NCCL UBR
[2026-02-27 21:28:39.433296] iteration      100/15258789 | consumed samples:        12800 | elapsed time per iteration (ms): 2845.5 | throughput per GPU (TFLOP/s/GPU): 592.8 | learning rate: 4.915198E-07 | global batch size:   128 | lm loss: 1.209103E+00 | loss scale: 1.0 | grad norm: 3.962 | num zeros: 0 | number of skipped iterations:   0 | number of nan iterations:   0 |

Extra Tests

  • With the optim gradient fix, and GBS 128 / MBS 1, we have improved loss (5.48 vs. 5.56) and reduced gradient norm (19.143 vs. 22.110) as we are no longer duplicating the gradient on the local rank, i.e. grad_i + sum(grad_i) instead of the expected sum(grad_i).
    • This difference becomes harder to see at higher DP degrees, since the DP-reduced sum dominates the local gradient in magnitude.
# Main (Optimizer Sharding / Llama 8B @ SeqLen 1K / FP32 Main Params, Grads, Grad Comm)
[2026-02-17 08:41:03.817490] iteration      100/15258789 | consumed samples:        12800 | elapsed time per iteration (ms): 2793.4 | throughput per GPU (TFLOP/s/GPU): 268.8 | learning rate: 4.915198E-07 | global batch size:   128 | lm loss: 5.560552E+00 | loss scale: 1.0 | grad norm: 22.110 | num zeros: 0 | number of skipped iterations:   0 | number of nan iterations:   0 |

# Optim Bug-Fix (Optimizer Sharding / Llama 8B @ SeqLen 1K / FP32 Main Params, Grads, Grad Comm)
[2026-02-17 08:52:45.746740] iteration      100/15258789 | consumed samples:        12800 | elapsed time per iteration (ms): 2776.0 | throughput per GPU (TFLOP/s/GPU): 270.5 | learning rate: 4.915198E-07 | global batch size:   128 | lm loss: 5.486691E+00 | loss scale: 1.0 | grad norm: 19.143 | num zeros: 0 | number of skipped iterations:   0 | number of nan iterations:   0 |
  • Symmetric kernels are used for AG, but not always for RS.
    • In particular, I have observed a minor loss disparity when using HFSDP with grad_comm_dtype=torch.bfloat16, due to the lack of symmetric RS kernels, which is difficult to reproduce on H100.
eos0552:2730518:2731948 [3] NCCL INFO ReduceScatter: 54534144 Bytes -> Algo RING proto SIMPLE channel{Lo..Hi}={0..23}
eos0552:2730516:2731922 [1] NCCL INFO ReduceScatter: 131465216 Bytes -> Algo RING proto SIMPLE channel{Lo..Hi}={0..23}
eos0552:2730518:2731948 [3] NCCL INFO ReduceScatter: 131465216 Bytes -> Algo RING proto SIMPLE channel{Lo..Hi}={0..23}
eos0552:2730519:2731952 [4] NCCL INFO ReduceScatter: 4096 Bytes -> Algo RING proto LL channel{Lo..Hi}={0..0}
eos0552:2730519:2731952 [4] NCCL INFO ReduceScatter: 109068288 Bytes -> Algo RING proto SIMPLE channel{Lo..Hi}={0..23}
eos0552:2730519:2731952 [4] NCCL INFO ReduceScatter: 262930432 Bytes -> Algo RING proto SIMPLE channel{Lo..Hi}={0..23}
eos0552:2730515:2730515 [0] NCCL INFO AllReduce: 4 Bytes -> Algo RING proto LL channel{Lo..Hi}={0..0}
eos0552:2730515:2730515 [0] NCCL INFO AllReduce: 4 Bytes -> Algo RING proto LL channel{Lo..Hi}={0..0}
eos0552:2730519:2730519 [4] NCCL INFO AllReduce: 4 Bytes -> Algo RING proto LL channel{Lo..Hi}={0..0}
eos0552:2730522:2730522 [7] NCCL INFO AllGather [Symmetric]: 131465216 Bytes -> Kernel AllGather_ST nchannels 9 nthreads 512
eos0552:2730521:2730521 [6] NCCL INFO AllGather [Symmetric]: 131465216 Bytes -> Kernel AllGather_ST nchannels 9 nthreads 512
eos0552:2730522:2730522 [7] NCCL INFO AllGather [Symmetric]: 131473408 Bytes -> Kernel AllGather_ST nchannels 9 nthreads 512
eos0552:2730521:2730521 [6] NCCL INFO AllGather [Symmetric]: 131473408 Bytes -> Kernel AllGather_ST nchannels 9 nthreads 512
  • FP32 activations (no --bf16 argument) work without any issues with NCCL UBR.
# FP32 Activations (BF16 Grad Comms + NCCL UBR / SymMem)
[2026-02-12 18:09:30.972237] iteration      100/15258789 | consumed samples:        12800 | elapsed time per iteration (ms): 19961.1 | throughput per GPU (TFLOP/s/GPU): 158.4 | learning rate: 4.915198E-07 | global batch size:   128 | lm loss: 1.620105E+00 | loss scale: 1.0 | grad norm: 6.975 | num zeros: 0 | number of skipped iterations:   0 | number of nan iterations:   0 |
  • Test HFSDP vs. FSDP loss parity without using a torch.empty_like output buffer for HFSDP, i.e. DP-Outer reduce-scatter is in-place as in DP-Shard on 1 Node / 8 GPUs.
# Main Branch FSDP (Llama 8B @ SeqLen 1K / FP32 Main Params, Grads, Grad Comm)
[Rank 0] (after 2 iterations) memory (MB) | allocated: 23425.69 | max allocated: 33573.09 | reserved: 26276.00 | max reserved: 39934.00
[2026-02-27 20:19:44.121167] iteration      100/15258789 | consumed samples:        12800 | elapsed time per iteration (ms): 3789.9 | throughput per GPU (TFLOP/s/GPU): 198.2 | learning rate: 4.915198E-07 | global batch size:   128 | lm loss: 5.476863E+00 | loss scale: 1.0 | grad norm: 9.411 | num zeros: 0 | number of skipped iterations:   0 | number of nan iterations:   0 |

# Mixed-Precision HFSDP (Llama 8B @ SeqLen 1K / FP32 Main Params, Grads, Grad Comm)
[Rank 0] (after 2 iterations) memory (MB) | allocated: 23425.69 | max allocated: 33573.09 | reserved: 26276.00 | max reserved: 39176.00
[2026-02-27 19:45:20.030360] iteration      100/15258789 | consumed samples:        12800 | elapsed time per iteration (ms): 3779.1 | throughput per GPU (TFLOP/s/GPU): 198.7 | learning rate: 4.915198E-07 | global batch size:   128 | lm loss: 5.475965E+00 | loss scale: 1.0 | grad norm: 9.424 | num zeros: 0 | number of skipped iterations:   0 | number of nan iterations:   0 |

# Mixed-Precision HFSDP (Llama 8B @ SeqLen 1K / FP32 Main Params, Grads, Grad Comm / NCCL UBR)
[2026-02-27 21:12:32.207370] iteration      100/15258789 | consumed samples:        12800 | elapsed time per iteration (ms): 3685.9 | throughput per GPU (TFLOP/s/GPU): 203.7 | learning rate: 4.915198E-07 | global batch size:   128 | lm loss: 5.475298E+00 | loss scale: 1.0 | grad norm: 9.426 | num zeros: 0 | number of skipped iterations:   0 | number of nan iterations:   0 |

# HFSDP BF16 Comm + FP32 Reduce / Accum + NCCL UBR
[2026-02-27 21:01:44.028797] iteration      100/15258789 | consumed samples:        12800 | elapsed time per iteration (ms): 3482.6 | throughput per GPU (TFLOP/s/GPU): 215.6 | learning rate: 4.915198E-07 | global batch size:   128 | lm loss: 5.534179E+00 | loss scale: 1.0 | grad norm: 9.544 | num zeros: 0 | number of skipped iterations:   0 | number of nan iterations:   0 |
  • Checking for NaN for all weight gradients with fully_shard(report_nan_in_param_grad=True) costs a slight performance regression of +5% global step time. Should only be turned on for debugging!
[2026-01-30 09:51:37.769937] iteration        6/15258789 | consumed samples:          768 | elapsed time per iteration (ms): 11512.2 | throughput per GPU (TFLOP/s/GPU): 586.0 | learning rate: 2.949118E-08 | global batch size:   128 | lm loss: 1.213158E+01 | loss scale: 1.0 | grad norm: 28.643 | num zeros: 0 | number of skipped iterations:   0 | number of nan iterations:   0 |

Future Work

  • NCCL symmetric memory is finicky, as not all collectives are guaranteed to use symmetric kernels. @youngeunkwon0405 @cspades @shjwudp to investigate how to increase the probability of symmetric kernels in Megatron-FSDP.
  • param_comm_dtype doesn't have that much use right now outside of the already supported TransformerEngine FP8 AG, so we will defer this to the future when we have plans for quantized AG for non-FP8 parameters, which in itself requires some research into the effect of extra quantization operations on sharded parameters vs. un-sharded parameters in model training.
  • @cspades Add a MegatronFSDP.__init__(debug=False) argument for improved unit tests.
  • @cspades Better MixedPrecisionPolicy modification support, currently only easy to use for training steps, or if the user adds hooks or code to modify the gradient communication data-type before the post-backward reduction.

Appendix

Type-Promotion Examples

  • TL;DR Type-promotion is equivalent to casting everything to the higher precision before operation, and can affect the numerics through down-casts even when the output precision is lower than the type-promoted precision.
"""
Input DType: torch.float32
Output DType: torch.bfloat16

REDUCTION TESTS

Reduce with torch.float32, and cast to torch.bfloat16:
 tensor([ 11.8750, 252.0000, -31.0000], dtype=torch.bfloat16)
Reduce via type-promotion into torch.bfloat16:
 tensor([ 11.8750, 252.0000, -31.0000], dtype=torch.bfloat16)
Cast to torch.bfloat16, and then reduce:
 tensor([ 11.8750, 252.0000, -31.1250], dtype=torch.bfloat16)
Use torch.sum(torch.bfloat16) to reduce:
 tensor([ 11.8750, 252.0000, -31.1250], dtype=torch.bfloat16)

ACCUMULATION TESTS

torch.bfloat16.add_(torch.float32):
 tensor([   3.6094, -180.0000,   -0.7031], dtype=torch.bfloat16)
torch.sum(dtype=torch.float32).to(torch.bfloat16):
 tensor([   3.6094, -180.0000,   -0.7031], dtype=torch.bfloat16)
torch.sum(dtype=torch.bfloat16):
 tensor([   3.6094, -180.0000,   -0.7188], dtype=torch.bfloat16)

-------------------

Input DType: torch.bfloat16
Output DType: torch.float32

REDUCTION TESTS

Reduce with torch.bfloat16, and cast to torch.float32:
 tensor([ 11.8750, 252.0000, -31.1250])
Reduce via type-promotion into torch.float32:
 tensor([ 11.8750, 252.2500, -31.0781])
Cast to torch.float32, and then reduce:
 tensor([ 11.8750, 252.2500, -31.0781])
Use torch.sum(torch.float32) to reduce:
 tensor([ 11.8750, 252.2500, -31.0781])

ACCUMULATION TESTS

torch.float32.add_(torch.bfloat16):
 tensor([   3.6099, -179.7130,   -0.7188])
torch.sum(dtype=torch.bfloat16).to(torch.float32):
 tensor([   3.6094, -180.0000,   -0.7188])
torch.sum(dtype=torch.float32):
 tensor([   3.6099, -179.7130,   -0.7188])
"""

⚠️ For major changes (either in lines of code or in its impact), please make sure to first share a design doc with the team. If you're unsure what's the best way to do so, contact the @mcore-oncall.

Contribution process

flowchart LR
    A[Pre-checks] --> B[PR Tests]
    subgraph Code Review/Approval
        C1[Expert Review] --> C2[Final Review]
    end
    B --> C1
    C2 --> D[Merge]

Pre-checks

  • I want this PR in a versioned release and have added the appropriate Milestone (e.g., Core 0.8)
  • I have added relevant unit tests
  • I have added relevant functional tests
  • I have added proper typing to my code Typing guidelines
  • I have added relevant documentation
  • I have run the autoformatter.sh on my PR

Code review

The following process is enforced via the CODEOWNERS file for changes into megatron/core. For changes outside of megatron/core, it is up to the PR author whether or not to tag the Final Reviewer team.

For MRs into `main` branch

Feel free to message or comment the @mcore-oncall to help accelerate your merge into main. The less complex your PR is, the faster it will be approved and merged!

(Step 1): Add PR label Expert Review

(Step 2): Collect the expert reviewers reviews

  1. Attach the Expert Review label when your PR is ready for review.
  2. GitHub auto-assigns expert reviewers based on your changes. They will get notified and pick up your PR soon.

⚠️ Only proceed to the next step once all reviewers have approved, merge conflicts are resolved, and the CI is passing.
Final Review might get declined if these requirements are not fulfilled.

(Step 3): Final Review

  1. Add Final Review label
  2. GitHub auto-assigns final reviewers based on your changes. They will get notified and pick up your PR soon.

(Optional Step 4): Cherry-pick into release branch

If this PR also needs to be merged into core_r* release branches, after this PR has been merged, select Cherry-pick to open a new PR into the release branch.

For MRs into `dev` branch

The proposed review process for the `dev` branch is under active discussion.

MRs are mergeable after one approval by either eharper@nvidia.com or zijiey@nvidia.com.

Merging your PR

Any member of core-adlr and core-nemo will be able to merge your PR.

@cspades cspades requested a review from deepakn94 January 24, 2026 02:05
@cspades cspades self-assigned this Jan 24, 2026
@cspades cspades requested review from a team as code owners January 24, 2026 02:05
@ko3n1g ko3n1g requested a review from a team January 24, 2026 02:06
@ko3n1g ko3n1g added this to the Core 0.16 milestone Jan 24, 2026
@cspades cspades requested a review from shjwudp January 24, 2026 02:09
@cspades cspades added the Expert Review and module: megatron-fsdp labels Jan 24, 2026
@cspades cspades force-pushed the cye/mfsdp-custom-dtype branch from 8714977 to c72e1d3 Compare January 24, 2026 02:50
@cspades cspades force-pushed the cye/mfsdp-custom-dtype branch from c72e1d3 to 107e81a Compare January 24, 2026 02:54
Contributor

@shjwudp shjwudp left a comment


Thanks for helping clarify the necessity of implementing reduce-scatter based on A2A and the trade-offs with NCCL’s native reduce-scatter.

I think there’s still room to simplify this MR. If we keep the main goal in focus, perhaps the changes related to the gradient reduce pipeline and the bucket fetch/free operations could be reverted. A well-scoped PR focusing on a few clear objectives will make it easier to review, trace, and maintain.

Member

@youngeunkwon0405 youngeunkwon0405 left a comment


Thanks for the contribution. The user-side interface you suggested seems great!

I checked the param_and_grad_buffer part. At first glance, it seems fine to me. I wish we could have actual test results for the following cases with nccl-ub + manual-registration.

  • FSDP-only
    • AG/RS should be symmetric-kernel
  • HSDP (within the single rack of GB200 or GB300)
    • AG/RS/AR should be symmetric-kernel
  • HSDP (multi-rack of GB200 or GB300. FSDP within 64 GPUs and outer-dp for inter-rack)
    • AG/RS should be symmetric-kernel
  • HSDP + TP
    • AG/RS should be symmetric-kernel

There are two ways to check if the symmetric kernel is called or not

  1. See directly from nsys-rep
  2. Set NCCL_DEBUG=INFO NCCL_DEBUG_SUBSYS=TUNING and search for [Symmetric] in the log. In case of manual-registration, you are expected to see the [Symmetric] kernels after the first iteration (since we are registering the buffer after the first iteration).

For me, the second way was more convenient to check, but it's up to you.

@cspades
Member Author

cspades commented Feb 26, 2026

Aligned with @shjwudp on the final design: we'll use a separate (dummy / no local data) DataParallelBuffer(data=None) to support gradient communication allocations and to separate FP32 from BF16 allocation, and modify the API of allocate_bucket_storage:

class DataParallelBuffer:
    def allocate_bucket_storage(
        self,
        # alloc_shard: Standardized way of getting shards or full buckets.
        alloc_shard: bool,
        # dtype: Used to dynamically configure `MixedPrecisionPolicy`. Defaults to the initial `grad_comm_dtype`.
        dtype: torch.dtype = None,
        # init_values: To initialize the allocated bucket values.
        init_values: torch.Tensor = None,
    ):
        bucket = self.temp_allocator(
            self.shard_bucket_index.size if alloc_shard else self.bucket_index.size,
            dtype if dtype is not None else self.dtype,
            self.alloc_context,
        )
        if init_values is not None:
            bucket.data.detach().copy_(init_values)
        return bucket

This should have no functional difference from the current PR state; it's just a refactor underneath for improved maintenance and memory robustness.

The idea is to make de-allocation in the DP-Shard gradient buffer independent of allocation in the DP-Outer gradient communications.

  • Use dp_shard_grad_buf.fetch_bucket for FP32 (DDP / ZeRO-1) or BF16 (ZeRO-2/3) communication bucket.
    • If no_shard or optim, we must allocate a 2nd un-sharded gradient bucket for DP-Shard communications, so this must use dp_shard_grad_buf.allocate_bucket_storage.
  • After dp_shard_grad_rs_stream.wait() and before HSDP/HFSDP, call dp_shard_grad_buf.free_bucket_storage (and dp_shard_grad_buf.reset_main_grad, if ZeRO-2 or ZeRO-3).
  • For HFSDP/HSDP reduction, call dp_outer_grad_buf.allocate_bucket_storage for the HSDP/HFSDP communication buffer.
    • Note that dp_shard_grad_buf.free_bucket_storage and dp_outer_grad_buf.allocate_bucket_storage do NOT affect/depend on each other at all! This is the key to making the logic more robust for future development: each allocator is only invoked once over one step of the entire training stream.
  • dp_outer_grad_rs_stream.wait() and dp_shard_grad_buf.free_bucket_storage are used in HSDP/HFSDP.

Warning: FixedPoolAllocator will use more memory for the H(F)SDP communication bucket, which is no longer managed by the DP-Shard buffer allocator!

"""
mp_policy_reset = MixedPrecisionPolicy(
# Preserve the original main parameter + gradient data-type.
main_params_dtype=self.mp_policy.main_params_dtype,
Contributor


Just a note: Dynamically changing main_params_dtype and main_grads_dtype could be achieved by rebuilding the DataParallelBuffer, but this is not supported for now.

Member Author

@cspades cspades Feb 28, 2026


Yeah... though it would also require calling fsdp_manual_registration again for NCCL UBR as well, it's complex indeed.

Signed-off-by: Cory Ye <cye@nvidia.com>
Contributor

@shjwudp shjwudp left a comment


LGTM, thanks!

@svcnvidia-nemo-ci

🔄 Merge queue validation started!

You can track the progress here: https://github.com/NVIDIA/Megatron-LM/actions/runs/22600987668


Labels

bug, Final Review, module: megatron-fsdp
