
[main] feat(moe): Support gated delta net for Qwen3-Next (1/4) #1989

Merged
Phlip79 merged 4 commits into NVIDIA:main from yuzhongw-nvidia:qwen3next on Jan 16, 2026

Conversation

@yuzhongw-nvidia
Contributor

@yuzhongw-nvidia commented Oct 28, 2025

What does this PR do?

MR to dev.

Design doc

Qwen3-Next functionality PRs.

Changes in this PR:

  1. Gated Delta Net (GDN);
  2. num_floating_point_operations

Newly supported arguments for Qwen3-Next

--attention-output-gate: true  # Qwen3-Next applies an output gate to standard attention layers
--no-weight-decay-cond-type: qwen3_next  # Qwen3-Next applies weight decay to qk layernorm as a special case
--moe-shared-expert-gate: true  # Qwen3-Next has a shared expert gate
# GDN args
--linear-attention-type: gated_delta_net
--linear-attention-freq: 4  # layer pattern 1 1 1 0 1 1 1 0 ...; equivalently [1,1,1,0]*12 (see the sketch below)
--linear-conv-kernel-dim: 4
--linear-key-head-dim: 128
--linear-value-head-dim: 128
--linear-num-key-heads: 16
--linear-num-value-heads: 32
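
A minimal sketch of the frequency-pattern expansion referenced above (the helper name build_layer_pattern is hypothetical; the PR's actual argument handling may differ):

    # Hypothetical helper illustrating --linear-attention-freq: 4.
    # 1 = gated delta net (linear attention) layer, 0 = standard attention layer.
    def build_layer_pattern(num_layers: int, linear_attention_freq: int) -> list[int]:
        # Every linear_attention_freq-th layer remains standard attention.
        return [0 if (i + 1) % linear_attention_freq == 0 else 1 for i in range(num_layers)]

    assert build_layer_pattern(8, 4) == [1, 1, 1, 0, 1, 1, 1, 0]
    # For a 48-layer model this matches the [1,1,1,0]*12 shorthand:
    assert build_layer_pattern(48, 4) == [1, 1, 1, 0] * 12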

Needed for Qwen3-Next, but already supported in MCore

--rotary-percent: 0.25  # Qwen3-Next applies RoPE only to the first 25% of position dimensions (see the sketch after this block)
--rotary-base: 10000000  # Qwen3-Next uses a larger rotary base
--apply-layernorm-1p: true  # Zero-Centered RMSNorm for Qwen3-Next
--max-position-embeddings: 262144  # Qwen3-Next-80B-A3B supports up to 256k sequence length
--moe-shared-expert-intermediate-size: 512  # Qwen3-Next has a shared expert
# MTP: Qwen3-Next has one MTP layer
--mtp-num-layers: 1
--mtp-loss-scaling-factor: 0.1
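
A minimal sketch of the partial-rotary behavior behind --rotary-percent, assuming the usual rotate-half convention (apply_partial_rope and the freqs layout are illustrative, not MCore's actual API):

    import torch

    def apply_partial_rope(x: torch.Tensor, freqs: torch.Tensor, rotary_percent: float = 0.25) -> torch.Tensor:
        # RoPE touches only the first rotary_percent of each head's dimensions;
        # the remaining dimensions pass through unchanged.
        rot_dim = int(x.shape[-1] * rotary_percent)
        x_rot, x_pass = x[..., :rot_dim], x[..., rot_dim:]
        # Rotate-half style rotation; freqs is assumed to broadcast to [..., rot_dim].
        x1, x2 = x_rot.chunk(2, dim=-1)
        rotated = torch.cat((-x2, x1), dim=-1)
        x_rot = x_rot * freqs.cos() + rotated * freqs.sin()
        return torch.cat((x_rot, x_pass), dim=-1)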

LM loss curves on the Qwen3 training dataset are shown below (GBS=256, seq_len=4096; TP1 in green, TP2 in blue).

wandb url

[loss curve images]

⚠️ For major changes (either in lines of code or in its impact), please make sure to first share and discuss a design doc with the team.

Contribution process

flowchart LR
    A[Pre-checks] --> B[PR Tests]
    subgraph Code Review/Approval
        C1[Expert Review] --> C2[Final Review]
    end
    B --> C1
    C2 --> D[Merge]

Pre-checks

  • I want this PR in a versioned release and have added the appropriate Milestone (e.g., Core 0.8)
  • I have added relevant unit tests
  • I have added relevant functional tests
  • I have added proper typing to my code (Typing guidelines)
  • I have added relevant documentation
  • I have run the autoformatter.sh on my PR

Code review

The following process is enforced via the CODEOWNERS file for changes into megatron/core. For changes outside of megatron/core, it is up to the PR author whether or not to tag the Final Reviewer team.

For MRs into `main` branch

(Step 1): Add PR label Expert Review

(Step 2): Collect the expert reviewers' reviews

  1. Attach the Expert Review label when your PR is ready for review.
  2. GitHub auto-assigns expert reviewers based on your changes. They will get notified and pick up your PR soon.

⚠️ Only proceed to the next step once all reviewers have approved, merge conflicts are resolved, and the CI is passing.
Final Review might get declined if these requirements are not fulfilled.

(Step 3): Final Review

  1. Add Final Review label
  2. GitHub auto-assigns final reviewers based on your changes. They will get notified and pick up your PR soon.

(Optional Step 4): Cherry-pick into release branch

If this PR also needs to be merged into core_r* release branches, after this PR has been merged, select Cherry-pick to open a new PR into the release branch.

For MRs into `dev` branch

The proposed review process for the `dev` branch is under active discussion.

MRs are mergeable after one approval by either eharper@nvidia.com or zijiey@nvidia.com.

Merging your PR

Any member of core-adlr and core-nemo will be able to merge your PR.

@yuzhongw-nvidia requested review from a team as code owners on October 28, 2025 02:59
@copy-pr-bot

copy-pr-bot bot commented Oct 28, 2025

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@jaredcasper
Contributor

What tests have been done to validate the implementation of the gated delta net is correct?

@yanring added the Expert Review [deprecated] label on Nov 5, 2025
@yuzhongw-nvidia requested review from a team as code owners on November 5, 2025 07:04
@yuzhongw-nvidia force-pushed the qwen3next branch 2 times, most recently from 99d022e to 1a63170 on November 5, 2025 08:03
@yuzhongw-nvidia
Contributor Author

yuzhongw-nvidia commented Nov 5, 2025

Hi @jaredcasper, I have rebased onto main and made the following changes to this MR:

  1. Resolved [main] feat(moe): Support gated delta net for Qwen3-Next (1/4) #1989 (comment)
  2. Tried to make Attention.get_query_key_value_tensors clearer. Does it look better now?

Thanks. cc @yanring

@yuzhongw-nvidia
Contributor Author

What tests have been done to validate the implementation of the gated delta net is correct?

Hi @jaredcasper, we have done several things to validate the implementation.

  1. We have a UT to verify the implementation of GDN and parallel GDN.
  2. We have used the comparison tool in Megatron-Bridge to verify alignment with the HF implementation (see the sketch below).
  3. We have run an E2E training test and the loss curve looks reasonable. (wandb)

In the next few weeks, we will further add a functional test to guarantee the long-term correctness of Qwen3-Next.
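
A hedged sketch of the kind of HF-alignment check described in point 2 above; the model handles and tolerances are illustrative, not the actual Megatron-Bridge tool:

    import torch

    def check_alignment(mcore_model, hf_model, input_ids, atol=1e-3, rtol=1e-3):
        # Run both implementations on the same inputs and compare output logits.
        with torch.no_grad():
            mcore_logits = mcore_model(input_ids)
            hf_logits = hf_model(input_ids).logits
        torch.testing.assert_close(mcore_logits, hf_logits, atol=atol, rtol=rtol)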

@jaredcasper
Contributor

jaredcasper commented Nov 10, 2025

In the next few weeks, we will further add a functional test to guarantee the long-term correctness of Qwen3-Next.

In the meantime do you think it makes sense to split this up and get the other features in?

@yuzhongw-nvidia
Contributor Author

yuzhongw-nvidia commented Nov 11, 2025

In the next few weeks, we will further add a functional test to guarantee the long-term correctness of Qwen3-Next.

In the meantime do you think it makes sense to split this up and get the other features in?

Do you mean covering multiple features in a single functional test? That would be excellent, as long as we do not put too many features in one test case, which would make it hard to maintain.

@yuzhongw-nvidia
Contributor Author

Hi @jaredcasper, I remember you mentioned that you wanted a thorough review of gpt_layer_spec.py. Do you have any comments or feedback on it? Does it look good, or do you think we need to refine it?

@jaredcasper
Contributor

jaredcasper commented Nov 13, 2025

Do you mean covering multiple features in a single functional test? That would be excellent, as long as we do not put too many features in one test case, which would make it hard to maintain.

No, I mean there are 7 changes listed in the PR; do they all need to be in one PR? I'm saying that while tests are being added for some of the features, does it make sense to open PRs for the others that have already been tested?

@jaredcasper
Contributor

Hi @jaredcasper, I remember you mentioned that you wanted a thorough review of gpt_layer_spec.py. Do you have any comments or feedback on it? Does it look good, or do you think we need to refine it?

I thought you were going to work on it a bit after our conversation, but if you are done refining, I've asked some people from my team to do a thorough review.

@Phlip79
Member

Phlip79 commented Jan 12, 2026

/ok to test c574f43

Contributor

@JRD971000 left a comment


lgtm

@kunlunl mentioned this pull request on Jan 14, 2026
@yuzhongw-nvidia
Contributor Author

/ok to test 271bb87

@yuzhongw-nvidia
Contributor Author

/ok to test 7c2a449

Contributor

@yanring left a comment


LGTM, thanks for the refinement!

@yuzhongw-nvidia
Contributor Author

/ok to test 8fe3705

@yuzhongw-nvidia
Contributor Author

/ok to test 476bc57

key_projection_size = args.kv_channels * args.num_query_groups
value_projection_size = args.kv_channels * args.num_query_groups
standard_self_attn_term = (
    3
Contributor


Can you add explanations here, since you aren't using expansion_factor? Or even better, use expansion_factor to make this more readable. The transformation of this equation isn't immediately obvious; can you make it closer to the previous formulation?

Contributor Author

@yuzhongw-nvidia commented Jan 15, 2026


I'm sure it is correct because:

  • I have checked carefully and made sure it is equivalent to the previous code.
  • I have verified in an E2E test that it outputs the same result as before.

I also added some comments to make it easier to understand.

I do not agree with the point "use expansion_factor to make this more readable". I know most people are confused by it because there is a 2x in the expansion_factor:

        # - 2x: GEMMs of a particular size are stacked twice in the standard Transformer model
        #       architectures implemented in this codebase (e.g., h->ffn_h GEMM and ffn_h->h GEMM
        #       in MLP layer).

which is only immediately obvious for the MLP. Could you please tell me what this 2x factor means in each term? To tell the truth, many colleagues and I only understood it after a careful derivation, which reflects how unreadable it previously was. In contrast, the MLA calculation doesn't use expansion_factor, and as you can see, it is much cleaner than the MHA and GQA ones. Furthermore, the previous version had a chaotic calculation order and lacked comments, rendering it completely unreadable. That's why I argue that we need to rewrite it.
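
As a minimal illustration of what the fused 2x means for the MLP term (the variable names below are illustrative, not the actual ones in num_floating_point_operations):

    # The MLP stacks two GEMMs of the same size: h -> ffn_h and ffn_h -> h.
    seq, hidden, ffn_hidden = 4096, 2048, 5440
    flops_up = 2 * seq * hidden * ffn_hidden    # FMA: each multiply-add counts as 2 FLOPs
    flops_down = 2 * seq * ffn_hidden * hidden
    # The "stacked twice" 2x is simply up-projection + down-projection:
    assert flops_up + flops_down == 2 * 2 * seq * hidden * ffn_hidden

For the attention terms no such identically sized GEMM pair exists, which is the source of the confusion described above.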

Contributor


Different approaches; I thought the previous one was much more readable. :)

In any case, there is a large comment explaining expansion_factor, and the creation of the variable itself, but expansion_factor is not used any more (unless I'm missing its use?). Please clean this up if you are going to remove expansion_factor. The comments explaining expansion factor can be moved down here.

Contributor Author


Thanks for your suggestion and the discussion. I refined it again, breaking expansion_factor into forward_backward_expansion_factor, fma_expansion_factor, and ffn_expansion_factor, and refined their comments.
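
A sketch of the resulting decomposition; the three names come from the comment above, while the numeric values are the conventional ones (assumed here, not quoted from the PR):

    forward_backward_expansion_factor = 3  # 1x forward pass + 2x backward pass
    fma_expansion_factor = 2               # each multiply-accumulate counts as 2 FLOPs
    ffn_expansion_factor = 2               # the MLP stacks two GEMMs: h -> ffn_h and ffn_h -> h

    # The old fused constant was the product of all three:
    assert forward_backward_expansion_factor * fma_expansion_factor * ffn_expansion_factor == 12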

@yuzhongw-nvidia
Contributor Author

/ok to test 82ba0d0

@yuzhongw-nvidia
Contributor Author

/ok to test 3529e98

@yuzhongw-nvidia
Contributor Author

/ok to test 3e1ca94

yuzhongw-nvidia and others added 2 commits January 16, 2026 02:38
Minor refines

Co-authored-by: Li Tao <lit@nvidia.com>

fix CI

get_transformer_block_with_experimental_attention_variant_spec

Update test_mamba_moe_model.py

Reopen qwen3next functional test in lightweight mode (NVIDIA#2493)

Signed-off-by: oliver könig <okoenig@nvidia.com>
Co-authored-by: oliver könig <okoenig@nvidia.com>
@yuzhongw-nvidia
Contributor Author

/ok to test 418be3c

Signed-off-by: oliver könig <okoenig@nvidia.com>
@yuzhongw-nvidia
Contributor Author

/ok to test fd96ce8

@yuzhongw-nvidia
Contributor Author

/ok to test df953ac

@yuzhongw-nvidia
Contributor Author

/ok to test 13d2cc9

@yuzhongw-nvidia
Contributor Author

/ok to test 06e3ea9

@yuzhongw-nvidia
Contributor Author

/ok to test f58c310


Labels

  • complexity: high
  • dev2main: mbridge (dev to main: this PR is needed in main for mbridge)
  • Expert Review [deprecated]
  • module: moe

Projects

None yet

Development

Successfully merging this pull request may close these issues.

  • [gated attention] Support Gated Attention from NeurIPS 25 best paper
  • Support Qwen3 80b Next (hybrid attn)
  • Gated deltaNet support?