From b27aab2c04aaf5fca81071bbddfa611eb1135ee9 Mon Sep 17 00:00:00 2001 From: yuzhongw-nvidia Date: Wed, 24 Dec 2025 13:39:52 +0800 Subject: [PATCH 1/4] support gated delta net MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Minor refines Co-authored-by: Li Tao fix CI get_transformer_block_with_experimental_attention_variant_spec Update test_mamba_moe_model.py Reopen qwen3next functional test in lightweight mode (#2493) Signed-off-by: oliver könig Co-authored-by: oliver könig --- gpt_builders.py | 11 +- megatron/core/jit.py | 27 +- ...rimental_attention_variant_module_specs.py | 423 +++++++++-- megatron/core/models/gpt/gpt_layer_specs.py | 25 - megatron/core/models/gpt/moe_module_specs.py | 4 +- megatron/core/ssm/gated_delta_net.py | 664 ++++++++++++++++++ megatron/core/transformer/spec_utils.py | 2 + .../core/transformer/transformer_block.py | 6 + .../core/transformer/transformer_config.py | 71 ++ megatron/training/arguments.py | 51 +- megatron/training/checkpointing.py | 8 +- megatron/training/global_vars.py | 4 + megatron/training/training.py | 112 ++- megatron/training/utils.py | 1 + pyproject.toml | 1 + .../shell_test_utils/run_ci_test.sh | 2 + .../model_config.yaml | 81 +++ tests/test_utils/recipes/gpt.yaml | 5 + .../unit_tests/models/test_mamba_moe_model.py | 7 + .../test_modelopt_module_spec.py | 1 + tests/unit_tests/ssm/test_gated_delta_net.py | 329 +++++++++ uv.lock | 30 + 22 files changed, 1764 insertions(+), 101 deletions(-) create mode 100644 megatron/core/ssm/gated_delta_net.py create mode 100644 tests/functional_tests/test_cases/gpt/gpt3_mcore_te_tp2_pp1_gdn/model_config.yaml create mode 100644 tests/unit_tests/ssm/test_gated_delta_net.py diff --git a/gpt_builders.py b/gpt_builders.py index 972ac738d50..dfe41f7b88e 100644 --- a/gpt_builders.py +++ b/gpt_builders.py @@ -9,6 +9,9 @@ get_gpt_mtp_block_spec, get_gpt_decoder_layer_specs, ) +from megatron.core.models.gpt.experimental_attention_variant_module_specs 
import ( + get_transformer_block_with_experimental_attention_variant_spec, +) from megatron.core.models.gpt.heterogeneous.heterogeneous_layer_specs import ( get_gpt_heterogeneous_layer_spec, ) @@ -43,7 +46,13 @@ def gpt_builder(args, pre_process, post_process, vp_stage=None, config=None, pg_ else: use_te = args.transformer_impl == "transformer_engine" - if args.num_experts: + if args.experimental_attention_variant is not None: + transformer_layer_spec = ( + get_transformer_block_with_experimental_attention_variant_spec( + config=config, vp_stage=vp_stage + ) + ) + elif args.num_experts: assert not (config.transformer_impl == "inference_optimized") # Define the decoder block spec transformer_layer_spec = get_gpt_decoder_block_spec( diff --git a/megatron/core/jit.py b/megatron/core/jit.py index b1aa3e0b611..b67810f2e34 100644 --- a/megatron/core/jit.py +++ b/megatron/core/jit.py @@ -7,12 +7,27 @@ jit_fuser = torch.jit.script # nvFuser is deprecated in PyTorch JIT starting from 2.2 -try: - if is_torch_min_version("2.2.0a0"): - jit_fuser = torch.compile -except ImportError: - def noop_decorator(func): - return func +def noop_decorator(func): + '''No-op decorator''' + return func + +def enable_jit_fuser(): + '''Enable the JIT fuser''' + global jit_fuser + try: + if is_torch_min_version("2.2.0a0"): + jit_fuser = torch.compile + except ImportError: + + jit_fuser = noop_decorator + + +def disable_jit_fuser(): + '''Disable the JIT fuser''' + global jit_fuser jit_fuser = noop_decorator + + +enable_jit_fuser() diff --git a/megatron/core/models/gpt/experimental_attention_variant_module_specs.py b/megatron/core/models/gpt/experimental_attention_variant_module_specs.py index b7f613e87d6..a7cc7cc0a55 100644 --- a/megatron/core/models/gpt/experimental_attention_variant_module_specs.py +++ b/megatron/core/models/gpt/experimental_attention_variant_module_specs.py @@ -1,10 +1,11 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. 
+# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. -from typing import Optional +from typing import List, Optional from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add from megatron.core.models.backends import BackendSpecProvider -from megatron.core.transformer.enums import AttnMaskType +from megatron.core.ssm.gated_delta_net import GatedDeltaNet, GatedDeltaNetSubmodules +from megatron.core.transformer.enums import AttnMaskType, LayerType from megatron.core.transformer.experimental_attention_variant.dsa import ( DSAIndexer, DSAIndexerSubmodules, @@ -17,29 +18,77 @@ MLASelfAttentionSubmodules, ) from megatron.core.transformer.spec_utils import ModuleSpec -from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules +from megatron.core.transformer.transformer_block import ( + TransformerBlockSubmodules, + get_num_layers_to_build, +) +from megatron.core.transformer.transformer_config import TransformerConfig +from megatron.core.transformer.transformer_layer import ( + TransformerLayer, + TransformerLayerSubmodules, + get_transformer_layer_offset, +) + +try: + import transformer_engine as te # type: ignore[import-untyped] # pylint: disable=unused-import + + from megatron.core.extensions.transformer_engine_spec_provider import TESpecProvider + + HAVE_TE = True +except ImportError: + HAVE_TE = False + +try: + import nvidia_kitchen # type: ignore[import-not-found] # pylint: disable=unused-import + + from megatron.core.extensions.kitchen import KitchenSpecProvider + + HAVE_KITCHEN = True +except ImportError: + HAVE_KITCHEN = False + + +########## +# Experimental Attention Variant Module Specs +########## + + +def get_gated_delta_net_module_spec( + config: TransformerConfig, backend: BackendSpecProvider = None +) -> ModuleSpec: + """Build module spec for GatedDeltaNet attention.""" + + if backend is None: + backend = _get_backend_spec_provider(config=config) + + rms_norm = config.normalization == 
"RMSNorm" + attention = ModuleSpec( + module=GatedDeltaNet, + submodules=GatedDeltaNetSubmodules( + in_proj=backend.column_parallel_layer_norm_linear(), + out_norm=backend.layer_norm(rms_norm=rms_norm, for_qk=False), + out_proj=backend.row_parallel_linear(), + ), + metainfo={"fuse_input_layernorm": True}, + ) + return attention def get_dsa_module_spec_for_backend( - backend: BackendSpecProvider, - qk_layernorm: Optional[bool] = False, - qk_l2_norm: Optional[bool] = False, - multi_latent_attention: Optional[bool] = False, - num_experts: Optional[int] = None, - mlp: Optional[ModuleSpec] = None, + config: TransformerConfig, backend: BackendSpecProvider = None ) -> ModuleSpec: """Helper function to get module spec for Sparse Attention.""" - assert multi_latent_attention, "Currently only MLA supports sparse attention." - assert qk_l2_norm is False, "qk_l2_norm is not supported with MLA." + assert config.multi_latent_attention, "Currently only MLA supports sparse attention." + assert config.qk_l2_norm is False, "qk_l2_norm is not supported with MLA." 
linear_q_up_proj = ( backend.column_parallel_layer_norm_linear() - if qk_layernorm + if config.qk_layernorm else backend.column_parallel_linear() ) linear_kv_up_proj = ( backend.column_parallel_layer_norm_linear() - if qk_layernorm + if config.qk_layernorm else backend.column_parallel_linear() ) @@ -76,39 +125,327 @@ def get_dsa_module_spec_for_backend( ), ) - return ModuleSpec( - module=TransformerLayer, - submodules=TransformerLayerSubmodules( - input_layernorm=backend.layer_norm(), - self_attention=attention, - self_attn_bda=get_bias_dropout_add, - pre_mlp_layernorm=backend.layer_norm() if num_experts else IdentityOp, - mlp=mlp, - mlp_bda=get_bias_dropout_add, - ), - ) + return attention -def get_experimental_attention_variant_module_spec_for_backend( - backend: BackendSpecProvider, - experimental_attention_variant: Optional[str] = None, - qk_layernorm: Optional[bool] = False, - qk_l2_norm: Optional[bool] = False, - multi_latent_attention: Optional[bool] = False, - num_experts: Optional[int] = None, - mlp: Optional[ModuleSpec] = None, +def get_experimental_attention_variant_module_spec( + config: TransformerConfig, backend: BackendSpecProvider = None ) -> ModuleSpec: - """Helper function to get module spec for Attention""" - if experimental_attention_variant == "dsa": - return get_dsa_module_spec_for_backend( - backend=backend, - qk_layernorm=qk_layernorm, - qk_l2_norm=qk_l2_norm, - multi_latent_attention=multi_latent_attention, - num_experts=num_experts, - mlp=mlp, + """Helper function to get module spec for experimental attention variant""" + + if backend is None: + backend = _get_backend_spec_provider(config=config) + + if config.experimental_attention_variant == "gated_delta_net": + return get_gated_delta_net_module_spec(config=config, backend=backend) + else: + raise ValueError( + f"Invalid experimental attention variant: {config.experimental_attention_variant}" + ) + + +########## +# Experimental GPT Decoder Block Spec +########## + + +def 
get_transformer_block_with_experimental_attention_variant_spec( + config: TransformerConfig, vp_stage: Optional[int] = None, pp_rank: Optional[int] = None +) -> TransformerBlockSubmodules: + """Build transformer block spec with experimental attention variants (e.g., linear attention). + + This function constructs a heterogeneous transformer block that supports mixing different + attention mechanisms (experimental vs standard) and MLP types (MoE vs dense) across layers. + **Note that, this API is a experimental API in the short term, and might be deprecated in the + future. In the long run, we will move to a new design that better support hybrid models.** + + Key Design: + 1. Attention and MLP patterns: The attention pattern and MLP pattern are orthogonal + and determined independently. This allows flexible combinations (e.g., linear attention + with MoE, or standard attention with dense MLP). + - Attention pattern: derived from `config.linear_attention_freq` or + `config.experimental_attention_variant`. + - MLP pattern: derived from `config.moe_layer_freq`. + + 2. Per-Layer Spec Construction: Iterates through layers, constructing transformer + layer specs based on attention and MLP patterns. + + 3. Pipeline Slicing: Extracts layer specs for the current pipeline stage. + + Args: + config: Transformer configuration containing model hyperparameters and feature flags. + vp_stage: Virtual pipeline stage index for interleaved pipeline parallelism. + pp_rank: Pipeline model parallel rank. + + Returns: + TransformerBlockSubmodules containing per-layer specs and final layer norm. + + Note: + Currently only supports transformer_engine backend. Kitchen backend can be used as a + wrapper with TE fallback for unsupported operations. 
+ """ + + backend = _get_backend_spec_provider(config=config) + + # Get attention patterns and specs + experimental_attention_pattern = [0] * config.num_layers + if is_linear_attention_variant(config.experimental_attention_variant): + experimental_attention_pattern = get_linear_attention_pattern(config=config) + elif config.experimental_attention_variant is not None: + experimental_attention_pattern = [1] * config.num_layers + + if 1 in experimental_attention_pattern: + experimental_attention_spec = get_experimental_attention_variant_module_spec( + config=config, backend=backend + ) + else: + experimental_attention_spec = None + + if 0 in experimental_attention_pattern: + standard_attention_spec = _get_self_attention_module_spec(config=config, backend=backend) + else: + standard_attention_spec = None + + # Get MLP patterns and specs + if config.num_moe_experts is not None: + moe_layer_pattern = get_moe_layer_pattern(config=config) + else: + moe_layer_pattern = [0] * config.num_layers + + if 1 in moe_layer_pattern: + moe_layer_spec = _get_moe_module_spec(config=config, backend=backend) + else: + moe_layer_spec = None + + if 0 in moe_layer_pattern: + dense_mlp_layer_spec = _get_dense_mlp_module_spec(config=config, backend=backend) + else: + dense_mlp_layer_spec = None + + # Get GPT decoder block layer specs + rms_norm = config.normalization == "RMSNorm" + layer_specs = [] + for layer_number in range(config.num_layers): + attention = ( + experimental_attention_spec + if experimental_attention_pattern[layer_number] == 1 + else standard_attention_spec + ) + mlp = moe_layer_spec if moe_layer_pattern[layer_number] == 1 else dense_mlp_layer_spec + input_layernorm = ( + IdentityOp + if attention.metainfo["fuse_input_layernorm"] + else backend.layer_norm(rms_norm=rms_norm, for_qk=False) + ) + pre_mlp_layernorm = ( + IdentityOp + if mlp.metainfo["fuse_pre_mlp_layernorm"] + else backend.layer_norm(rms_norm=rms_norm, for_qk=False) + ) + + layer_specs.append( + ModuleSpec( + 
module=TransformerLayer, + submodules=TransformerLayerSubmodules( + input_layernorm=input_layernorm, + self_attention=attention, + self_attn_bda=get_bias_dropout_add, + pre_mlp_layernorm=pre_mlp_layernorm, + mlp=mlp, + mlp_bda=get_bias_dropout_add, + ), + ) + ) + + # Slice the layer specs to only include the layers that are built in this pipeline stage. + if config.pipeline_model_parallel_layout is not None: + local_layer_ids = config.pipeline_model_parallel_layout.get_layer_id_list( + layer_type=LayerType.decoder, vp_stage=vp_stage, pp_rank=pp_rank + ) + else: + offset = get_transformer_layer_offset(config, vp_stage=vp_stage, pp_rank=pp_rank) + num_layers_to_build = get_num_layers_to_build(config, vp_stage=vp_stage, pp_rank=pp_rank) + local_layer_ids = range(offset, offset + num_layers_to_build) + + layer_specs = [layer_specs[layer_id] for layer_id in local_layer_ids] + + # Get GPT decoder block spec + gpt_decoder_block_spec = TransformerBlockSubmodules( + layer_specs=layer_specs, layer_norm=backend.layer_norm(rms_norm=rms_norm, for_qk=False) + ) + + return gpt_decoder_block_spec + + +########## +# Utilities +########## + + +def is_linear_attention_variant(experimental_attention_variant: Optional[str]) -> bool: + """Check if the experimental attention variant is a linear attention variant.""" + linear_attention_variants = ["gated_delta_net"] + return experimental_attention_variant in linear_attention_variants + + +def get_moe_layer_pattern(config: TransformerConfig) -> List[int]: + """Parse config.moe_layer_freq to get per-layer MoE pattern (1=MoE, 0=dense). + + - int N: one MoE layer every N layers (e.g., N=2 -> [1,0,1,0,...]) + - list: use directly as the pattern.""" + + if isinstance(config.moe_layer_freq, int): + # [1,0,0,...,0,1,0,0,...,0,...] 
+ moe_layer_pattern = [ + 1 if (i % config.moe_layer_freq == 0) else 0 for i in range(config.num_layers) + ] + elif isinstance(config.moe_layer_freq, list): + moe_layer_pattern = config.moe_layer_freq + assert len(moe_layer_pattern) == config.num_layers, ( + f"Invalid length of moe_layer_pattern: {len(moe_layer_pattern)}, " + f"expected {config.num_layers}, " + f"current moe layer pattern: {config.moe_layer_freq}" ) else: raise ValueError( - f"Invalid experimental attention variant: {experimental_attention_variant}" + f"Invalid moe_layer_freq: {type(config.moe_layer_freq)}, {config.moe_layer_freq}" ) + return moe_layer_pattern + + +def get_linear_attention_pattern(config: TransformerConfig) -> List[int]: + """Parse config.linear_attention_freq to get per-layer attention pattern (1=LA, 0=SDPA). + + - int N: one SDPA layer every N layers (e.g., N=4 -> [1,1,1,0,1,1,1,0,...]) + - list: use directly as the pattern.""" + + if isinstance(config.linear_attention_freq, int): + linear_attention_pattern = [ + # [1,1,...,1,0,1,1,...,1,0,...] + 0 if ((i + 1) % config.linear_attention_freq == 0) else 1 + for i in range(config.num_layers) + ] + elif isinstance(config.linear_attention_freq, list): + linear_attention_pattern = config.linear_attention_freq + assert len(linear_attention_pattern) == config.num_layers, ( + f"Invalid length of linear_attention_pattern: {len(linear_attention_pattern)}, " + f"expected {config.num_layers}, " + f"current linear attention pattern: {config.linear_attention_freq}" + ) + elif config.linear_attention_freq is None: + if not is_linear_attention_variant(config.experimental_attention_variant): + linear_attention_pattern = [0] * config.num_layers + else: + # This should be caught by config validation, but raise here as a safety check + raise ValueError( + f"Linear attention type {config.experimental_attention_variant} is specified " + "but linear_attention_freq is None. " + "Please set linear_attention_freq to specify the LA/SDPA layer pattern." 
+ ) + else: + raise ValueError( + f"Invalid linear_attention_freq: {type(config.linear_attention_freq)}," + f" {config.linear_attention_freq}" + ) + return linear_attention_pattern + + +def _get_backend_spec_provider(config: TransformerConfig) -> BackendSpecProvider: + """Get backend spec provider for experimental attention variant.""" + + assert config.transformer_impl == "transformer_engine", ( + "Experimental GPT decoder block spec only supports " + "transformer engine implementation for now." + ) + backend: BackendSpecProvider = ( + KitchenSpecProvider( + fallback=TESpecProvider(), + use_kitchen_attention=config.use_kitchen_attention, + kitchen_attention_backend=config.kitchen_attention_backend, + ) + if config.use_kitchen + else TESpecProvider() + ) + return backend + + +########## +# Spec functions for non-experimental self attention and MLP layer. +########## + + +def _get_self_attention_module_spec( + config: TransformerConfig, backend: BackendSpecProvider = None +) -> ModuleSpec: + """Get non-experimental self-attention module spec. + For hybrid models that mix experimental and non-experimental attention architectures. 
+ + Warning: This function may be deprecated in the future.""" + + if backend is None: + backend = _get_backend_spec_provider(config=config) + + from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec + + layer_spec = get_gpt_layer_with_transformer_engine_spec( + num_experts=config.num_moe_experts, + moe_grouped_gemm=config.moe_grouped_gemm, + qk_layernorm=config.qk_layernorm, + multi_latent_attention=config.multi_latent_attention, + moe_use_legacy_grouped_gemm=config.moe_use_legacy_grouped_gemm, + qk_l2_norm=config.qk_l2_norm, + use_kitchen=config.use_kitchen, + use_te_activation_func=config.use_te_activation_func, + use_kitchen_attention=config.use_kitchen_attention, + kitchen_attention_backend=config.kitchen_attention_backend, + ) + attn_spec = layer_spec.submodules.self_attention + if config.multi_latent_attention: + attn_spec.metainfo["fuse_input_layernorm"] = False + else: + attn_spec.metainfo["fuse_input_layernorm"] = backend.fuse_layernorm_and_linear() + + return attn_spec + + +def _get_dense_mlp_module_spec( + config: TransformerConfig, backend: BackendSpecProvider = None +) -> ModuleSpec: + """Get dense MLP module spec. + For hybrid models that mix dense MLP and experimental attention architectures. + + Warning: This function may be deprecated in the future.""" + + if backend is None: + backend = _get_backend_spec_provider(config=config) + + from megatron.core.models.gpt.gpt_layer_specs import get_mlp_module_spec_for_backend + + mlp_spec = get_mlp_module_spec_for_backend(backend=backend, num_experts=None) + mlp_spec.metainfo["fuse_pre_mlp_layernorm"] = backend.fuse_layernorm_and_linear() + + return mlp_spec + + +def _get_moe_module_spec( + config: TransformerConfig, backend: BackendSpecProvider = None +) -> ModuleSpec: + """Get MoE module spec. + For hybrid models that mix MoE and experimental attention architectures. 
+ + Warning: This function may be deprecated in the future.""" + + if backend is None: + backend = _get_backend_spec_provider(config=config) + + from megatron.core.models.gpt.moe_module_specs import get_moe_module_spec_for_backend + + moe_spec = get_moe_module_spec_for_backend( + backend=backend, + num_experts=config.num_moe_experts, + moe_grouped_gemm=config.moe_grouped_gemm, + moe_use_legacy_grouped_gemm=config.moe_use_legacy_grouped_gemm, + use_te_activation_func=config.use_te_activation_func, + ) + moe_spec.metainfo["fuse_pre_mlp_layernorm"] = False + return moe_spec diff --git a/megatron/core/models/gpt/gpt_layer_specs.py b/megatron/core/models/gpt/gpt_layer_specs.py index 604b14c8ffe..974e33f88e8 100755 --- a/megatron/core/models/gpt/gpt_layer_specs.py +++ b/megatron/core/models/gpt/gpt_layer_specs.py @@ -9,9 +9,6 @@ InferenceSpecProvider, LocalSpecProvider, ) -from megatron.core.models.gpt.experimental_attention_variant_module_specs import ( - get_experimental_attention_variant_module_spec_for_backend, -) from megatron.core.models.gpt.moe_module_specs import get_moe_module_spec_for_backend from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules from megatron.core.transformer.enums import AttnMaskType, LayerType @@ -179,7 +176,6 @@ def get_gpt_layer_with_transformer_engine_spec( moe_grouped_gemm: Optional[bool] = False, qk_layernorm: Optional[bool] = False, multi_latent_attention: Optional[bool] = False, - experimental_attention_variant: Optional[str] = None, fp8: Optional[str] = None, # pylint: disable=unused-argument moe_use_legacy_grouped_gemm: Optional[bool] = False, qk_l2_norm: Optional[bool] = False, @@ -197,8 +193,6 @@ def get_gpt_layer_with_transformer_engine_spec( moe_grouped_gemm (bool, optional): To use Grouped GEMM. Defaults to False. qk_layernorm (bool, optional): To use layernorm for queries/keys. Defaults to False. multi_latent_attention (bool, optional): To use MLA. Defaults to False. 
- experimental_attention_variant (str, optional): The type of experimental attention variant. - Defaults to None. fp8 (str, optional): Deprecated. For temporary Nemo compatibility. moe_use_legacy_grouped_gemm (bool, optional): Force use the legacy GroupedMLP. Defaults to False. @@ -239,17 +233,6 @@ def get_gpt_layer_with_transformer_engine_spec( use_te_activation_func=use_te_activation_func, ) - if experimental_attention_variant is not None: - return get_experimental_attention_variant_module_spec_for_backend( - backend=backend, - experimental_attention_variant=experimental_attention_variant, - qk_layernorm=qk_layernorm, - qk_l2_norm=qk_l2_norm, - multi_latent_attention=multi_latent_attention, - num_experts=num_experts, - mlp=mlp, - ) - if multi_latent_attention: assert qk_l2_norm is False, "qk_l2_norm is not supported with MLA." linear_q_up_proj = ( @@ -328,7 +311,6 @@ def get_gpt_layer_local_spec( moe_grouped_gemm: Optional[bool] = False, qk_layernorm: Optional[bool] = False, multi_latent_attention: Optional[bool] = False, - experimental_attention_variant: Optional[str] = None, fp8: Optional[str] = None, # pylint: disable=unused-argument moe_use_legacy_grouped_gemm: Optional[bool] = False, normalization: Optional[str] = None, @@ -345,8 +327,6 @@ def get_gpt_layer_local_spec( moe_grouped_gemm (bool, optional): To use Grouped GEMM. Defaults to False. qk_layernorm (bool, optional): To use layernorm for queries/keys. Defaults to False. multi_latent_attention (bool, optional): To use MLA. Defaults to False. - experimental_attention_variant (str, optional): The type of experimental attention variant. - Defaults to None. fp8 (str, optional): Deprecated. For temporary Nemo compatibility. moe_use_legacy_grouped_gemm (bool, optional): Force use the legacy GroupedMLP. Defaults to False. 
@@ -356,11 +336,6 @@ def get_gpt_layer_local_spec( ModuleSpec: Module specification with Megatron-Core modules """ - if experimental_attention_variant is not None: - raise NotImplementedError( - "Experimental attention variant is not supported with local spec yet." - ) - if use_kitchen: assert HAVE_KITCHEN backend = KitchenSpecProvider( diff --git a/megatron/core/models/gpt/moe_module_specs.py b/megatron/core/models/gpt/moe_module_specs.py index 87e4091aece..62ee4537cfc 100755 --- a/megatron/core/models/gpt/moe_module_specs.py +++ b/megatron/core/models/gpt/moe_module_specs.py @@ -61,6 +61,8 @@ def get_moe_module_spec_for_backend( # MoE module spec moe_module_spec = ModuleSpec( - module=MoELayer, submodules=MoESubmodules(experts=experts, shared_experts=shared_experts) + module=MoELayer, + submodules=MoESubmodules(experts=experts, shared_experts=shared_experts), + metainfo={"fuse_pre_mlp_layernorm": False}, ) return moe_module_spec diff --git a/megatron/core/ssm/gated_delta_net.py b/megatron/core/ssm/gated_delta_net.py new file mode 100644 index 00000000000..70e749724dc --- /dev/null +++ b/megatron/core/ssm/gated_delta_net.py @@ -0,0 +1,664 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2025, Songlin Yang, Jan Kautz, Ali Hatamizadeh. + +# Some of this code was adopted from https://github.com/huggingface/transformers +# This source code is licensed under the Apache license found in the +# LICENSE file in the root directory of this source tree. 
+ +import logging +from dataclasses import dataclass, replace +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor + +from megatron.core.dist_checkpointing import ShardedTensor +from megatron.core.dist_checkpointing.mapping import ReplicaId, ShardedTensorFactory +from megatron.core.fp8_utils import get_fp8_align_size +from megatron.core.inference.contexts import BaseInferenceContext +from megatron.core.jit import jit_fuser +from megatron.core.packed_seq_params import PackedSeqParams +from megatron.core.process_groups_config import ProcessGroupCollection +from megatron.core.tensor_parallel import get_cuda_rng_tracker +from megatron.core.transformer import TransformerConfig +from megatron.core.transformer.identity_op import IdentityOp +from megatron.core.transformer.module import MegatronModule +from megatron.core.transformer.spec_utils import ModuleSpec, build_module +from megatron.core.transformer.utils import ( + ensure_metadata_has_dp_cp_group, + make_sharded_tensors_for_checkpoint, + sharded_state_dict_default, +) +from megatron.core.utils import deprecate_inference_params, nvtx_range_pop, nvtx_range_push + +# TODO: Implement GatedDeltaNetContextParallel +# from .gated_delta_net_context_parallel import GatedDeltaNetContextParallel + +try: + from fla.modules.l2norm import l2norm + from fla.ops.gated_delta_rule import chunk_gated_delta_rule + + HAVE_FLA = True +except ImportError: + chunk_gated_delta_rule = None + + HAVE_FLA = False + +try: + from causal_conv1d import causal_conv1d_fn +except ImportError: + causal_conv1d_fn = None + causal_conv1d_update = None + + +logger = logging.getLogger(__name__) + + +@dataclass +class GatedDeltaNetSubmodules: + """ + Contains the module specs for the input linear, output norm, and output linear layers. 
+ """ + + in_proj: Union[ModuleSpec, type] = IdentityOp + out_norm: Union[ModuleSpec, type] = IdentityOp + out_proj: Union[ModuleSpec, type] = IdentityOp + + +class GatedDeltaNet(MegatronModule): + """Gated Delta Net (GDN) layer class + + GDN layer takes input with size [s, b, h] + and returns output of the same size. + """ + + def __init__( + self, + config: TransformerConfig, + submodules: GatedDeltaNetSubmodules, + layer_number: int = None, + bias: bool = False, + conv_bias: bool = False, + conv_init: Optional[float] = None, + use_qk_l2norm: bool = True, + A_init_range: Tuple[float, float] = (1, 16), + pg_collection: ProcessGroupCollection = None, + ): + """ + Args: + config: The config of the model. + submodules: Contains the module specs for the input and output linear layers. + layer_number: The layer number of this GDN layer. + bias: Whether to use bias in the linear layers. + conv_bias: Whether to use bias in the causal convolution. + conv_init: The initialization range for the causal convolution weights. + use_qk_l2norm: Whether to use L2 normalization in the kernel of the gated delta rule. + A_init_range: The initialization range for the attention weights. + pg_collection: The required process groups to use for tensor model parallel and context + parallel. + """ + + if not HAVE_FLA: + raise ImportError( + "FLA is not installed. Please install it with `pip install flash-linear-attention`." 
+ ) + + super().__init__(config) + + # Attributes from arguments + self.layer_number = layer_number + self.bias = bias + self.conv_bias = conv_bias + self.conv_init = conv_init + assert A_init_range[0] >= 0 and A_init_range[1] >= A_init_range[0] + self.A_init_range = A_init_range + self.use_qk_l2norm = use_qk_l2norm + assert pg_collection is not None, "pg_collection must be provided for GatedDeltaNet" + self.pg_collection = pg_collection + self.tp_size = self.pg_collection.tp.size() + self.sp_size = self.tp_size if config.sequence_parallel else 1 + + # Attributes from config + self.config = config + self.hidden_size = config.hidden_size + self.act_fn = config.activation_func + self.activation = self.act_fn.__name__ + self.conv_kernel_dim = config.linear_conv_kernel_dim + self.key_head_dim = config.linear_key_head_dim + self.value_head_dim = config.linear_value_head_dim + self.num_key_heads = config.linear_num_key_heads + self.num_value_heads = config.linear_num_value_heads + self.qk_dim = self.key_head_dim * self.num_key_heads + self.v_dim = self.value_head_dim * self.num_value_heads + + # Input projection (hidden_states -> q, k, v, gate, beta, alpha) + # TODO: for now, output gate is forced for GDN. + # We may remove this restriction in the future. + self.in_proj_dim = self.qk_dim * 2 + self.v_dim * 2 + self.num_value_heads * 2 + if self.config.fp8: + fp8_align_size = get_fp8_align_size(self.config.fp8_recipe) + assert self.in_proj_dim % fp8_align_size == 0, ( + "For FP8, the innermost dimension of the GDN layer " + "input projection output tensor must be a multiple of 16." 
+ ) + self.in_proj = build_module( + submodules.in_proj, + self.hidden_size, + self.in_proj_dim, + config=self.config, + init_method=self.config.init_method, + gather_output=False, + bias=bias, + skip_bias_add=False, + is_expert=False, + tp_comm_buffer_name="fc1", + tp_group=self.pg_collection.tp, + ) + + # Conv1d for QKV + self.conv_dim = self.qk_dim * 2 + self.v_dim + self.conv_dim_local_tp = self.conv_dim // self.tp_size + + # weight shape: [conv_dim, 1, d_conv] + # bias shape: [conv_dim] + self.conv1d = nn.Conv1d( + in_channels=self.conv_dim_local_tp, + out_channels=self.conv_dim_local_tp, + bias=conv_bias, + kernel_size=self.conv_kernel_dim, + groups=self.conv_dim_local_tp, + padding=self.conv_kernel_dim - 1, + device=torch.cuda.current_device(), + dtype=config.params_dtype, + ) + setattr(self.conv1d.weight, "tensor_model_parallel", True) + if conv_bias: + setattr(self.conv1d.bias, "tensor_model_parallel", True) + + # Time step projection (discretization) + self.num_v_heads_local_tp = self.num_value_heads // self.tp_size + # dt_bias parameter + self.dt_bias = nn.Parameter( + torch.empty( + self.num_v_heads_local_tp, + dtype=config.params_dtype, + device=torch.cuda.current_device(), + ) + ) + setattr(self.dt_bias, "tensor_model_parallel", True) + # A_log parameter + self.A_log = nn.Parameter( + torch.empty( + self.num_v_heads_local_tp, + dtype=config.params_dtype, + device=torch.cuda.current_device(), + ) + ) + setattr(self.A_log, "tensor_model_parallel", True) + + # Output layernorm before projection + self.out_norm = build_module( + submodules.out_norm, + config=self.config, + hidden_size=self.value_head_dim, + eps=self.config.layernorm_epsilon, + ) + + self.out_proj = build_module( + submodules.out_proj, + self.v_dim, + self.hidden_size, + config=self.config, + init_method=self.config.output_layer_init_method, + bias=bias, + input_is_parallel=True, + skip_bias_add=True, + is_expert=False, + tp_comm_buffer_name="fc2", + tp_group=self.pg_collection.tp, + ) + + 
# TODO: support CP + + self.reset_parameters() + + def reset_parameters(self): + """Reset the parameters.""" + if self.config.perform_initialization: + with get_cuda_rng_tracker().fork(): + # conv1d.weight + if self.conv_init is not None: + nn.init.uniform_(self.conv1d.weight, -self.conv_init, self.conv_init) + # dt_bias + torch.ones( + self.num_v_heads_local_tp, + out=self.dt_bias.data, + dtype=self.config.params_dtype, + device=torch.cuda.current_device(), + ) + # A_log + A = torch.empty( + self.num_v_heads_local_tp, + dtype=self.config.params_dtype, + device=torch.cuda.current_device(), + ).uniform_(*self.A_init_range) + self.A_log.data.copy_(torch.log(A)) + + def forward( + self, + hidden_states: Tensor, + attention_mask: Tensor, + key_value_states: Optional[Tensor] = None, + inference_context: Optional[BaseInferenceContext] = None, + attention_bias: Optional[Tensor] = None, + packed_seq_params: Optional[PackedSeqParams] = None, + sequence_len_offset: Optional[int] = None, + *, + inference_params: Optional[BaseInferenceContext] = None, + **kwargs, + ): + """ + Perform a forward pass through the GDN module. + + Args: + hidden_states (Tensor): Hidden states. + attention_mask (Tensor): Attention mask. + key_value_states (Optional[Tensor]): Key/value states (for cross attention). + inference_context (Optional[BaseInferenceContext]): Inference context that manages + KV cache. + attention_bias (Optional[Tensor]): Attention bias. + packed_seq_params (Optional[PackedSeqparams]): Parameters used for THD format. + sequence_len_offset (Optional[int]): Sequence length offset used for + inference CUDA graphs. + + Return: + (Tuple[Tensor, Tensor]) GDN output and bias. 
+ + """ + # TODO: Deal with attention_mask + + inference_context = deprecate_inference_params(inference_context, inference_params) + + seq_len, batch, _ = hidden_states.shape + seq_len = seq_len * self.sp_size + + if inference_context is not None: + assert ( + inference_context.is_static_batching() + ), "GDN does not currently support dynamic inference batching." + assert not self.config.sequence_parallel + # TODO: support inference + raise NotImplementedError("GDN does not support inference for now.") + + if packed_seq_params is not None: + # TODO: support packed sequence + raise NotImplementedError("GDN does not support packed sequence for now.") + + # Input projection + nvtx_range_push(suffix="in_proj") + qkvzba, _ = self.in_proj(hidden_states) + nvtx_range_pop(suffix="in_proj") + + # Transpose: s b x --> b s x + # From sbhd to bshd format + qkvzba = qkvzba.transpose(0, 1) + + # Split, reorder, and reshape the tensor into q, k, v, gate, beta, alpha + qkv, gate, beta, alpha = torch.split( + qkvzba, + [ + (self.qk_dim * 2 + self.v_dim) // self.tp_size, + self.v_dim // self.tp_size, + self.num_value_heads // self.tp_size, + self.num_value_heads // self.tp_size, + ], + dim=-1, + ) + gate = gate.reshape(batch, seq_len, -1, self.value_head_dim) + beta = beta.reshape(batch, seq_len, -1) + alpha = alpha.reshape(batch, seq_len, -1) + + # Convolution on qkv + qkv = qkv.transpose(1, 2).contiguous() # b, s, d -> b, d, s + nvtx_range_push(suffix="conv1d") + if (causal_conv1d_fn is None) or self.config.deterministic_mode: + qkv = self.act_fn(self.conv1d(qkv)[..., :seq_len]) + else: + assert self.activation in ["silu", "swish"] + qkv = causal_conv1d_fn( + x=qkv, + weight=self.conv1d.weight.squeeze(1), # d, 1, w -> d, w + bias=self.conv1d.bias, + activation=self.activation, + ) + nvtx_range_pop(suffix="conv1d") + # Split qkv into query, key, and value + qkv = qkv.transpose(1, 2) # b, d, s -> b, s, d + query, key, value = torch.split( + qkv, + [self.qk_dim // self.tp_size, 
self.qk_dim // self.tp_size, self.v_dim // self.tp_size], + dim=-1, + ) + query = query.reshape(batch, seq_len, -1, self.key_head_dim) + key = key.reshape(batch, seq_len, -1, self.key_head_dim) + value = value.reshape(batch, seq_len, -1, self.value_head_dim) + # Apply L2 norm to query and key + if self.use_qk_l2norm: + query = l2norm(query.contiguous()) + key = l2norm(key.contiguous()) + if self.num_value_heads // self.num_key_heads > 1: + query = query.repeat_interleave(self.num_value_heads // self.num_key_heads, dim=2) + key = key.repeat_interleave(self.num_value_heads // self.num_key_heads, dim=2) + + # Make contiguous + query = query.contiguous() + key = key.contiguous() + value = value.contiguous() + gate = gate.contiguous() + beta = beta.contiguous() + alpha = alpha.contiguous() + + # Calculate g and beta + nvtx_range_push(suffix="g_and_beta") + g = -self.A_log.exp() * F.softplus(alpha.float() + self.dt_bias) # In fp32 + beta = beta.sigmoid() + nvtx_range_pop(suffix="g_and_beta") + + nvtx_range_push(suffix="gated_delta_rule") + if self.config.deterministic_mode: + core_attn_out, last_recurrent_state = torch_chunk_gated_delta_rule( + query, + key, + value, + g=g, + beta=beta, + initial_state=None, + output_final_state=False, + use_qk_l2norm_in_kernel=False, + ) + else: + core_attn_out, last_recurrent_state = chunk_gated_delta_rule( + query, + key, + value, + g=g, + beta=beta, + initial_state=None, + output_final_state=False, + use_qk_l2norm_in_kernel=False, + ) + nvtx_range_pop(suffix="gated_delta_rule") + + # RMSNorm + nvtx_range_push(suffix="gated_norm") + norm_out = self._apply_gated_norm(core_attn_out, gate) + nvtx_range_pop(suffix="gated_norm") + + # Transpose: b s x --> s b x + # From bshd back to sbhd format + norm_out = norm_out.reshape(batch, seq_len, -1) + norm_out = norm_out.transpose(0, 1).contiguous() + + # Output projection + nvtx_range_push(suffix="out_proj") + out, out_bias = self.out_proj(norm_out) + nvtx_range_pop(suffix="out_proj") + + 
return out, out_bias + + @jit_fuser + def _apply_gated_norm(self, x, gate): + # Output Norm + x_dtype = x.dtype + x = x.reshape(-1, x.shape[-1]) + y = self.out_norm(x) + # Output gate + gate = gate.reshape(-1, gate.shape[-1]) + y = y * self.act_fn(gate.float()) + y = y.to(x_dtype) + return y + + def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None, tp_group=None): + """Provide a sharded state dictionary for distributed checkpointing.""" + # Guard for cases metadata is not provided + metadata = ensure_metadata_has_dp_cp_group(metadata) + + sharded_state_dict = {} + # Parameters + self._save_to_state_dict(sharded_state_dict, "", keep_vars=True) + sharded_state_dict = make_sharded_tensors_for_checkpoint( + sharded_state_dict, + prefix, + tensor_parallel_layers_axis_map={ + "A_log": 0, + "dt_bias": 0, + }, # parameters sharded across TP + sharded_offsets=sharded_offsets, + tp_group=(tp_group if tp_group is not None else self.pg_collection.tp), + dp_cp_group=metadata['dp_cp_group'], + ) + # Submodules + tp_group = tp_group if tp_group is not None else self.pg_collection.tp + for name, module in self.named_children(): + if name == "conv1d": + # Add TP sharding for Conv1d + module_sd = module.state_dict(prefix="", keep_vars=True) + tp_sharding_map = {f"weight": 0} + if self.conv_bias: + tp_sharding_map[f"bias"] = 0 + module_sharded_sd = make_sharded_tensors_for_checkpoint( + module_sd, + f"{prefix}{name}.", + tp_sharding_map, + sharded_offsets, + tp_group=tp_group, + dp_cp_group=metadata['dp_cp_group'], + ) + else: + module_sharded_sd = sharded_state_dict_default( + module, f"{prefix}{name}.", sharded_offsets, metadata, tp_group=tp_group + ) + + sharded_state_dict.update(module_sharded_sd) + + # At this point the TP sharding is correctly defined for each tensor, but some of the + # tensors must be additionally split into separate parts + in_proj_dim_local_tp = self.in_proj_dim // self.tp_size + assert 
sharded_state_dict[f"{prefix}in_proj.weight"].data.size(0) == in_proj_dim_local_tp, ( + in_proj_dim_local_tp, + sharded_state_dict[f"{prefix}in_proj.weight"], + ) + + sharded_state_dict[f"{prefix}in_proj.weight"] = _split_tensor_factory( + sharded_state_dict[f"{prefix}in_proj.weight"], + [ + self.qk_dim // self.tp_size, + self.qk_dim // self.tp_size, + self.v_dim // self.tp_size, + self.v_dim // self.tp_size, + self.num_value_heads // self.tp_size, + self.num_value_heads // self.tp_size, + ], + ["query", "key", "value", "z", "beta", "alpha"], + 0, + ) + + conv_layer_name_list = ["conv1d.weight"] + assert ( + sharded_state_dict[f"{prefix}conv1d.weight"].data.size(0) == self.conv_dim_local_tp + ), (self.conv_dim_local_tp, sharded_state_dict[f"{prefix}conv1d.weight"]) + if self.conv_bias: + conv_layer_name_list.append("conv1d.bias") + assert ( + sharded_state_dict[f"{prefix}conv1d.bias"].data.size(0) == self.conv_dim_local_tp + ), (self.conv_dim_local_tp, sharded_state_dict[f"{prefix}conv1d.bias"]) + for conv_layer_name in conv_layer_name_list: + sharded_state_dict[f"{prefix}{conv_layer_name}"] = _split_tensor_factory( + sharded_state_dict[f"{prefix}{conv_layer_name}"], + [ + self.qk_dim // self.tp_size, + self.qk_dim // self.tp_size, + self.v_dim // self.tp_size, + ], + ["query", "key", "value"], + 0, + ) + + return sharded_state_dict + + +def _split_tensor_factory( + orig_sh_ten: ShardedTensor, split_sections: List[int], split_names: List[str], split_dim: int +) -> ShardedTensorFactory: + """Builds a factory that splits a given ShardedTensor into several independent chunks.""" + assert isinstance(orig_sh_ten, ShardedTensor), type(orig_sh_ten) + orig_sh_ten_no_data = orig_sh_ten.without_data() # remove `data` reference + + if sum(split_sections) != orig_sh_ten_no_data.local_shape[split_dim]: + raise ValueError( + f"Split sections must cover the whole dimension size, " + f"got {split_sections=} vs dimensions size " + f"{orig_sh_ten_no_data.local_shape[split_dim]}" + ) 
def torch_chunk_gated_delta_rule(
    query,
    key,
    value,
    g,
    beta,
    chunk_size=64,
    initial_state=None,
    output_final_state=False,
    use_qk_l2norm_in_kernel=False,
):
    # pylint: disable=line-too-long
    '''
    Torch-native implementation of chunked gated delta rule for deterministic mode.
    Need this because FLA is not deterministic.

    Inputs are expected in [batch, seq, heads, head_dim] layout for query/key/value
    and [batch, seq, heads] for g/beta; they are transposed to head-major layout
    internally and the output is transposed back. Computation runs in fp32 and the
    result is cast back to the input dtype.

    Reference: https://github.com/huggingface/transformers/blob/144c8ce2809a2e21914017652700e1ecb450501e/src/transformers/models/qwen3_next/modeling_qwen3_next.py#L470-L547
    '''

    initial_dtype = query.dtype
    if use_qk_l2norm_in_kernel:
        query = l2norm(query, dim=-1, eps=1e-6)
        key = l2norm(key, dim=-1, eps=1e-6)
    # [b, s, h, d] -> [b, h, s, d] (and [b, s, h] -> [b, h, s]), promoted to fp32.
    query, key, value, beta, g = [
        x.transpose(1, 2).contiguous().to(torch.float32) for x in (query, key, value, beta, g)
    ]

    batch_size, num_heads, sequence_length, k_head_dim = key.shape
    v_head_dim = value.shape[-1]
    # Right-pad the sequence so it divides evenly into chunks; padding is sliced off
    # from the output at the end.
    pad_size = (chunk_size - sequence_length % chunk_size) % chunk_size
    query = F.pad(query, (0, 0, 0, pad_size))
    key = F.pad(key, (0, 0, 0, pad_size))
    value = F.pad(value, (0, 0, 0, pad_size))
    beta = F.pad(beta, (0, pad_size))
    g = F.pad(g, (0, pad_size))
    total_sequence_length = sequence_length + pad_size
    # 1/sqrt(d_k) softmax-style scaling folded into the query.
    scale = 1 / (query.shape[-1] ** 0.5)
    query = query * scale

    # Pre-scale value/key by the per-token update strength beta.
    v_beta = value * beta.unsqueeze(-1)
    k_beta = key * beta.unsqueeze(-1)
    # reshape to chunks: [b, h, S, d] -> [b, h, S/chunk, chunk, d]
    query, key, value, k_beta, v_beta = [
        x.reshape(x.shape[0], x.shape[1], -1, chunk_size, x.shape[-1])
        for x in (query, key, value, k_beta, v_beta)
    ]
    g = g.reshape(g.shape[0], g.shape[1], -1, chunk_size)
    # Upper-triangular mask INCLUDING the diagonal: zeroes non-causal positions below.
    mask = torch.triu(
        torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), diagonal=0
    )

    # chunk decay: g becomes the within-chunk cumulative log-decay; decay_mask[i, j]
    # is exp(g_i - g_j) on the causal (lower) triangle.
    g = g.cumsum(dim=-1)
    decay_mask = ((g.unsqueeze(-1) - g.unsqueeze(-2)).tril().exp().float()).tril()
    attn = -((k_beta @ key.transpose(-1, -2)) * decay_mask).masked_fill(mask, 0)
    # In-place forward substitution: row i accumulates contributions from all earlier
    # rows. Sequential and order-dependent by construction (this is what makes the
    # kernel deterministic but slow).
    for i in range(1, chunk_size):
        row = attn[..., i, :i].clone()
        sub = attn[..., :i, :i].clone()
        attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2)
    attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device)
    value = attn @ v_beta
    k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1))
    # Recurrent state carried across chunks: [b, h, d_k, d_v].
    last_recurrent_state = (
        torch.zeros(batch_size, num_heads, k_head_dim, v_head_dim).to(value)
        if initial_state is None
        else initial_state.to(value)
    )
    core_attn_out = torch.zeros_like(value)
    # Strictly upper-triangular mask (diagonal excluded) for the intra-chunk attention.
    mask = torch.triu(
        torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), diagonal=1
    )

    # for each chunk: combine intra-chunk attention with the inter-chunk recurrent
    # state, then advance the state with the chunk's decayed key/value update.
    for i in range(0, total_sequence_length // chunk_size):
        q_i, k_i, v_i = query[:, :, i], key[:, :, i], value[:, :, i]
        attn = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask, 0)
        v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state
        v_new = v_i - v_prime
        attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state
        core_attn_out[:, :, i] = attn_inter + attn @ v_new
        last_recurrent_state = (
            last_recurrent_state * g[:, :, i, -1, None, None].exp()
            + (k_i * (g[:, :, i, -1, None] - g[:, :, i]).exp()[..., None]).transpose(-1, -2) @ v_new
        )

    if not output_final_state:
        last_recurrent_state = None
    # Un-chunk, drop the padding, and restore [b, s, h, d] layout / original dtype.
    core_attn_out = core_attn_out.reshape(
        core_attn_out.shape[0], core_attn_out.shape[1], -1, core_attn_out.shape[-1]
    )
    core_attn_out = core_attn_out[:, :, :sequence_length]
    core_attn_out = core_attn_out.transpose(1, 2).contiguous().to(initial_dtype)
    return core_attn_out, last_recurrent_state
@@ -54,6 +55,7 @@ def import_module(module_path: Tuple[str]): return vars(module)[name] +# pylint: disable=missing-function-docstring def get_module(spec_or_module: Union[ModuleSpec, type], **additional_kwargs): """Returns or imports the provided module.""" # If a module clas is already provided return it as is diff --git a/megatron/core/transformer/transformer_block.py b/megatron/core/transformer/transformer_block.py index b16d88a83cd..ea4464b4784 100755 --- a/megatron/core/transformer/transformer_block.py +++ b/megatron/core/transformer/transformer_block.py @@ -821,6 +821,12 @@ def sharded_state_dict( elif isinstance(self.config.moe_layer_freq, list): non_homogeneous_layers = True + if isinstance(self.config.linear_attention_freq, int): + if self.config.linear_attention_freq > 1: + non_homogeneous_layers = True + elif isinstance(self.config.linear_attention_freq, list): + non_homogeneous_layers = True + if self.config.heterogeneous_block_specs: non_homogeneous_layers = True diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py index 4ec3e18530b..cabad4e15d7 100644 --- a/megatron/core/transformer/transformer_config.py +++ b/megatron/core/transformer/transformer_config.py @@ -192,6 +192,9 @@ class TransformerConfig(ModelParallelConfig): qk_layernorm: bool = False """Whether to apply `normalization` type of normalization to the query and key embeddings.""" + qk_l2_norm: bool = False + """Whether to apply llama 4-style qk L2 norm.""" + qk_clip: bool = False """Whether to clip the query and key weights. Needed for Muon MLA Model training.""" @@ -231,6 +234,9 @@ class TransformerConfig(ModelParallelConfig): experimental_attention_variant: Optional[str] = None """Type of attention variant to use. 
Currently support gated_delta_net and dsa.""" + #################### + # DSA + #################### dsa_indexer_n_heads: Optional[int] = None """Number of DSA indexer heads.""" @@ -247,6 +253,31 @@ class TransformerConfig(ModelParallelConfig): """Whether to use sparse DSA indexer loss. If True, the indexer loss will be computed using the top-k indices.""" + #################### + # linear attention + #################### + linear_attention_freq: Optional[Union[int, List[int]]] = None + """Frequency between LA (linear attention) layers + and SDPA (scaled dot-product attention) layers. + Accepts either: + - An integer N: Represents a (N-1):N ratio, meaning (N-1) LA layers for every 1 SDPA layer + - A list that defines a custom pattern, e.g.: [1,1,1,0,1,1,1,0,1,1,1,0]""" + + linear_conv_kernel_dim: Optional[int] = None + """Conv kernel dimension for the gated delta net.""" + + linear_key_head_dim: Optional[int] = None + """Query and key head dimension for the gated delta net.""" + + linear_value_head_dim: Optional[int] = None + """Value and gate head dimension for the gated delta net.""" + + linear_num_key_heads: Optional[int] = None + """Number of query and key heads for the gated delta net.""" + + linear_num_value_heads: Optional[int] = None + """Number of value and gate heads for the gated delta net.""" + #################### # initialization #################### @@ -874,6 +905,46 @@ def __post_init__(self): f"tensor_model_parallel_size ({self.tensor_model_parallel_size})." ) + if self.experimental_attention_variant == "gated_delta_net": + assert ( + self.linear_attention_freq is not None + ), f"linear_attention_freq must be set for linear gated_delta_net." + + # Check required parameters + assert ( + self.linear_conv_kernel_dim is not None + ), "linear_conv_kernel_dim must be set for gated delta net." + assert ( + self.linear_key_head_dim is not None + ), "linear_key_head_dim must be set for gated delta net." 
+ assert ( + self.linear_value_head_dim is not None + ), "linear_value_head_dim must be set for gated delta net." + assert ( + self.linear_num_key_heads is not None + ), "linear_num_key_heads must be set for gated delta net." + assert ( + self.linear_num_value_heads is not None + ), "linear_num_value_heads must be set for gated delta net." + assert self.linear_num_value_heads % self.linear_num_key_heads == 0, ( + f"linear_num_value_heads ({self.linear_num_value_heads}) must be a multiple of " + f"linear_num_key_heads ({self.linear_num_key_heads})." + ) + + # Check tensor parallelism compatibility + assert ( + self.linear_num_key_heads % self.tensor_model_parallel_size == 0 + ), "linear_num_key_heads must be a multiple of tensor_model_parallel_size." + assert ( + self.linear_num_value_heads % self.tensor_model_parallel_size == 0 + ), "linear_num_value_heads must be a multiple of tensor_model_parallel_size." + + # Do not support yet, but coming soon. + assert self.context_parallel_size == 1, ( + f"Gated delta net does not support context parallel for now," + f" but got {self.context_parallel_size=}." + ) + if self.fp8: # cannot support first last layer bf16 with delayed scaling if self.first_last_layers_bf16 and self.fp8_recipe == Fp8Recipe.delayed: diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py index 979fbd4d5e1..345219e6070 100644 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py @@ -321,7 +321,7 @@ def moe_freq_type(x): This allows defining arbitrary patterns of expert and dense layers. The pattern length must match the total number of transformer layers. Examples: - "([0]+[1]*23)": 1 dense layer followed by 23 experts layers + "([0]+[1]*23)": 1 dense layer followed by 23 expert layers "([1]*3+[0]*2)*2": Three expert layers followed by two dense layers, repeated twice. 
""" if isinstance(x, int): @@ -334,6 +334,31 @@ def moe_freq_type(x): # it's a single int but in str return int(x) +def la_freq_type(x): + """Frequency between LA (linear attention) layers and SDPA (scaled dot-product attention) layers. + + Accepts either: + - An integer N: Represents a (N-1):N ratio, meaning (N-1) LA layers for every 1 SDPA layer + - A string "N": Same as above, but provided as a string + - A string containing a Python list expression that defines a custom pattern, e.g.: + "([1]*3+[0]*1)*3" evaluates to [1,1,1,0,1,1,1,0,1,1,1,0] + where 1 indicates an LA layer and 0 indicates a SDPA layer. + This allows defining arbitrary patterns of LA and SDPA layers. + The pattern length must match the total number of transformer layers. + Examples: + "([0]+[1]*23)": 1 SDPA layer followed by 23 LA layers + "([1]*3+[0]*2)*2": Three LA layers followed by two SDPA layers, repeated twice. + """ + if x is None or isinstance(x, int): + return x + assert isinstance(x, str) + if '[' in x: + # it's a custom pattern + return _eval_pattern(x) + else: + # it's a single int but in str + return int(x) + def tuple_type(x): """ Convert a string to a tuple of integers. @@ -2437,7 +2462,8 @@ def _add_training_args(parser): 'processed in different batch configurations. This is more strict than deterministic-mode ' 'which only ensures bitwise identical results when the same inputs are processed in the same batch configuration. ' 'This will significantly affect speed of training and inference as the kernels are not full optimized.') - + group.add_argument('--disable-jit-fuser', action='store_true', + help='Disable the JIT fuser.') return parser @@ -3386,7 +3412,26 @@ def _add_experimental_attention_variant_args(parser): help='Coefficient for the indexer KL divergence loss. Set to 0 to disable indexer loss.') group.add_argument('--dsa-indexer-use-sparse-loss', action='store_true', help='Use sparse indexer loss. 
If set, the indexer loss will be computed using the top-k indices.') - + # Linear attention + group.add_argument('--linear-attention-freq', type=la_freq_type, default=None, + help='Frequency between LA (linear attention) layers and' + ' SDPA (scaled dot-product attention) layers. Accepts either: ' + '- An integer N: Represents a (N-1):N ratio, meaning (N-1) LA layers for every 1 SDPA layer ' + '- A string containing a Python list expression that defines a custom pattern, e.g.: ' + '"([1]*3+[0]*1)*3" evaluates to [1,1,1,0,1,1,1,0,1,1,1,0] ' + 'where 1 indicates an LA layer and 0 indicates a SDPA layer. ' + 'Examples: "([0]+[1]*23)": 1 SDPA layer followed by 23 LA layers, ' + '"([1]*3+[0]*2)*2": Three LA layers followed by two SDPA layers, repeated twice.') + group.add_argument('--linear-conv-kernel-dim', default=4, type=int, + help='Conv kernel dimension for the gated delta net.') + group.add_argument('--linear-key-head-dim', default=128, type=int, + help='Query and key head dimension for the gated delta net.') + group.add_argument('--linear-value-head-dim', default=128, type=int, + help='Value and gate head dimension for the gated delta net.') + group.add_argument('--linear-num-key-heads', default=16, type=int, + help='Number of query and key heads for the gated delta net.') + group.add_argument('--linear-num-value-heads', default=32, type=int, + help='Number of value and gate heads for the gated delta net.') return parser def _add_heterogeneous_args(parser): diff --git a/megatron/training/checkpointing.py b/megatron/training/checkpointing.py index 0cc7edcfeb3..66bfe185a3b 100644 --- a/megatron/training/checkpointing.py +++ b/megatron/training/checkpointing.py @@ -1475,13 +1475,13 @@ def load_checkpoint(ddp_model, optimizer, opt_param_scheduler, load_arg='load', ckpt_args = state_dict.get("args") if not hasattr(ckpt_args, "tensor_model_parallel_size"): - print_rank_0("WARNING: TP size not found in checkpoint args, using 0 as default.") + print_rank_0("WARNING: TP 
size not found in checkpoint args, using 1 as default.") if not hasattr(ckpt_args, "pipeline_model_parallel_size"): - print_rank_0("WARNING: PP size not found in checkpoint args, using 0 as default.") + print_rank_0("WARNING: PP size not found in checkpoint args, using 1 as default.") ckpt_tp_pp = ( - getattr(ckpt_args, "tensor_model_parallel_size", 0), - getattr(ckpt_args, "pipeline_model_parallel_size", 0), + getattr(ckpt_args, "tensor_model_parallel_size", 1), + getattr(ckpt_args, "pipeline_model_parallel_size", 1), ) run_tp_pp = ( args.tensor_model_parallel_size, diff --git a/megatron/training/global_vars.py b/megatron/training/global_vars.py index 62dbf701c1f..a718877b40c 100644 --- a/megatron/training/global_vars.py +++ b/megatron/training/global_vars.py @@ -9,6 +9,7 @@ from megatron.core import Timers from megatron.core.config import set_experimental_flag from megatron.core.energy_monitor import EnergyMonitor +from megatron.core.jit import disable_jit_fuser from megatron.core.num_microbatches_calculator import init_num_microbatches_calculator, unset_num_microbatches_calculator from megatron.training.dist_signal_handler import DistributedSignalHandler from megatron.training.tokenizer import build_tokenizer @@ -112,6 +113,9 @@ def set_global_variables(args, build_tokenizer=True): if args.exit_signal_handler: _set_signal_handler(args.exit_signal) + if args.disable_jit_fuser: + disable_jit_fuser() + def unset_global_variables(): """Unset global vars. 
diff --git a/megatron/training/training.py b/megatron/training/training.py index 9551c89bf60..0acd7a54114 100644 --- a/megatron/training/training.py +++ b/megatron/training/training.py @@ -51,6 +51,9 @@ from megatron.core import mpu, tensor_parallel +from megatron.core.models.gpt.experimental_attention_variant_module_specs import ( + is_linear_attention_variant, +) from megatron.core.utils import ( check_param_hashes_across_dp_replicas, get_attr_wrapped_model, @@ -282,9 +285,6 @@ def hybrid_flops(batch_size, seq_len, hidden_size, def transformer_flops(): """Calculate FLOPs for a standard Transformer model.""" # TODO(helenn/dnarayanan): Refactor this to reuse the helper methods. - # Attention projection size. - query_projection_size = args.kv_channels * args.num_attention_heads - query_projection_to_hidden_size_ratio = query_projection_size / args.hidden_size # Group Query Attention. if not args.group_query_attention: args.num_query_groups = args.num_attention_heads @@ -375,10 +375,9 @@ def transformer_flops(): + args.num_attention_heads * (args.qk_head_dim + args.qk_pos_emb_head_dim) + 1 ) - self_attn_term = ( + standard_self_attn_term = ( 3 * 2 # fwd(1) + bwd(2) *FMA - * num_layers * ( ## q lora + rope + q norm q_term @@ -395,29 +394,106 @@ def transformer_flops(): ## core attn + args.seq_length * (args.num_attention_heads * (args.qk_head_dim + args.qk_pos_emb_head_dim)) - / 2 + / 2 # causal mask (only half of the mask is non-zero) + args.seq_length * args.num_attention_heads * args.v_head_dim / 2 ) ) else: ## MHA or GQA - self_attn_term = ( - expansion_factor - * num_layers - * args.hidden_size - * args.hidden_size + query_projection_size = args.kv_channels * args.num_attention_heads + key_projection_size = args.kv_channels * args.num_query_groups + value_projection_size = args.kv_channels * args.num_query_groups + standard_self_attn_term = ( + 3 + * 2 # fwd(1) + bwd(2) *FMA * ( - ( - 1 - + (args.num_query_groups / args.num_attention_heads) - # # Only half of the 
attention matrix is non-zero and needs to be multiplied with V. - + (args.seq_length / args.hidden_size / 2) - ) - * query_projection_to_hidden_size_ratio + ## qkv proj + args.hidden_size + * (query_projection_size + key_projection_size + value_projection_size) + ## core attention + + query_projection_size + * args.seq_length + / 2 # causal mask (only half of the mask is non-zero) + * 2 # QK^T and (QK^T)V + ## out proj + + query_projection_size + * args.hidden_size ) ) + if is_linear_attention_variant(args.experimental_attention_variant): + # Calculate number of dense and MoE Transformer MLPs. + if isinstance(args.linear_attention_freq, int): + linear_attention_pattern = [ + # [1,1,...,1,0,1,1,...,1,0,...] + 0 if ((i + 1) % args.linear_attention_freq == 0) + else 1 for i in range(num_layers) + ] + elif isinstance(args.linear_attention_freq, list): + linear_attention_pattern = args.linear_attention_freq + assert len(linear_attention_pattern) == num_layers, ( + f"Invalid length of linear_attention_pattern: {len(linear_attention_pattern)}, " + f"expected {num_layers}, " + f"current linear attention pattern: {args.linear_attention_freq}" + ) + elif args.linear_attention_freq is None: + # This should be caught by config validation, but raise here as a safety check + raise ValueError( + f"Linear attention type {args.experimental_attention_variant} is specified " + "but linear_attention_freq is None. " + "Please set linear_attention_freq to specify the LA/SDPA layer pattern." + ) + else: + raise ValueError( + f"Invalid linear_attention_freq: {type(args.linear_attention_freq)}," + f" {args.linear_attention_freq}" + ) + num_linear_attention_layers = sum(linear_attention_pattern) + num_standard_attention_layers = num_layers - num_linear_attention_layers + + if args.experimental_attention_variant == "gated_delta_net": + # Calculate the FLOPs for the gated delta net attention. 
+ qk_head_dim = args.linear_key_head_dim + v_head_dim = args.linear_value_head_dim + num_qk_heads = args.linear_num_key_heads + num_v_heads = args.linear_num_value_heads + qk_dim = qk_head_dim * num_qk_heads + v_dim = v_head_dim * num_v_heads + linear_self_attn_term = ( + 3 + * 2 # fwd(1) + bwd(2) *FMA + * ( + ## in proj + args.hidden_size + * (2 * qk_dim + 2 * v_dim + 2 * num_v_heads) + ## conv1d + + args.linear_conv_kernel_dim + * (2 * qk_dim + v_dim) + ## gated delta rule + + num_v_heads + * (v_head_dim ** 2) + * 4 # KK^T, VK^T, S(a(I-bKK^T)), and SQ + ## out proj + + args.hidden_size + * v_dim + ) + ) + else: + raise ValueError( + "Invalid experimental_attention_variant: " + f"{args.experimental_attention_variant}" + ) + else: + num_linear_attention_layers = 0 + linear_self_attn_term = 0 + num_standard_attention_layers = num_layers + + self_attn_term = ( + linear_self_attn_term * num_linear_attention_layers + + standard_self_attn_term * num_standard_attention_layers + ) + total_floating_point_operations = ( batch_size * args.seq_length diff --git a/megatron/training/utils.py b/megatron/training/utils.py index ad070c42fb0..4bf613b7bd2 100644 --- a/megatron/training/utils.py +++ b/megatron/training/utils.py @@ -282,6 +282,7 @@ def report_memory(name): string += ' | max allocated: {}'.format(torch.cuda.max_memory_allocated() / mega_bytes) string += ' | reserved: {}'.format(torch.cuda.memory_reserved() / mega_bytes) string += ' | max reserved: {}'.format(torch.cuda.max_memory_reserved() / mega_bytes) + string += ' | device usage: {}'.format(torch.cuda.device_memory_used() / mega_bytes) if mpu.get_data_parallel_rank() == 0: print("[Rank {}] {}".format(torch.distributed.get_rank(), string), flush=True) diff --git a/pyproject.toml b/pyproject.toml index 776813cf10a..a251bbb23ab 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -78,6 +78,7 @@ dev = [ "opentelemetry-api~=1.33.1", "mamba-ssm~=2.2", "causal-conv1d~=1.5", + "flash-linear-attention~=0.3.2", 
"nv-grouped-gemm~=1.1", "megatron-energon[av_decode]~=6.0", "av", diff --git a/tests/functional_tests/shell_test_utils/run_ci_test.sh b/tests/functional_tests/shell_test_utils/run_ci_test.sh index 73fa78b97ca..20267536a0f 100644 --- a/tests/functional_tests/shell_test_utils/run_ci_test.sh +++ b/tests/functional_tests/shell_test_utils/run_ci_test.sh @@ -51,6 +51,8 @@ set -exo pipefail # Extract settings from params file TEST_TYPE=$(cat $TRAINING_PARAMS_PATH | /usr/local/bin/yq '.TEST_TYPE') +ENABLE_LIGHTWEIGHT_MODE=$(cat $TRAINING_PARAMS_PATH | + /usr/local/bin/yq '.ENV_VARS.ENABLE_LIGHTWEIGHT_MODE // "false"') MODE=$(cat $TRAINING_PARAMS_PATH | /usr/local/bin/yq '.MODE // "pretraining"') diff --git a/tests/functional_tests/test_cases/gpt/gpt3_mcore_te_tp2_pp1_gdn/model_config.yaml b/tests/functional_tests/test_cases/gpt/gpt3_mcore_te_tp2_pp1_gdn/model_config.yaml new file mode 100644 index 00000000000..ae6bb92a125 --- /dev/null +++ b/tests/functional_tests/test_cases/gpt/gpt3_mcore_te_tp2_pp1_gdn/model_config.yaml @@ -0,0 +1,81 @@ +ENV_VARS: + CUDA_DEVICE_MAX_CONNECTIONS: 1 + NVTE_ALLOW_NONDETERMINISTIC_ALGO: 0 + NCCL_ALGO: Ring + CUBLAS_WORKSPACE_CONFIG: :4096:8 + ENABLE_LIGHTWEIGHT_MODE: true +MODEL_ARGS: + # Add network size args + --untie-embeddings-and-output-weights: true + --num-layers: 6 + --hidden-size: 512 + --num-attention-heads: 8 + --group-query-attention: true + --num-query-groups: 2 + --swiglu: true + --position-embedding-type: rope + --rotary-percent: 0.5 + --no-rope-fusion: true #TODO: We can remove this once upgrading to the DEV container + --apply-layernorm-1p: true + --attention-output-gate: true + --experimental-attention-variant: gated_delta_net + --linear-attention-freq: 3 + --linear-conv-kernel-dim: 4 + --linear-key-head-dim: 64 + --linear-value-head-dim: 64 + --linear-num-key-heads: 4 + --linear-num-value-heads: 8 + # Add MoE args + --num-experts: 32 + --moe-ffn-hidden-size: 64 + --moe-shared-expert-intermediate-size: 64 + 
--moe-shared-expert-gate: true + --moe-router-load-balancing-type: aux_loss + --moe-router-topk: 8 + --disable-bias-linear: true + --moe-router-dtype: fp32 + # Add logging args + --log-params-norm: true + --log-num-zeros-in-grad: true + --log-validation-ppl-to-tensorboard: true + --log-timers-to-tensorboard: true + --tensorboard-dir: ${TENSORBOARD_PATH} + --micro-batch-size: 4 + --global-batch-size: 32 + --seq-length: 1024 + --max-position-embeddings: 1024 + --train-iters: 50 + --timing-log-level: 0 + --lr-decay-iters: 320000 + --save: ${CHECKPOINT_SAVE_PATH} + --load: ${CHECKPOINT_LOAD_PATH} + --data-path: ${DATA_PATH}/text/the_pile/shard00/my-gpt3_00_text_document + --vocab-file: ${DATA_PATH}/text/the_pile/shard00/bpe/vocab.json + --merge-file: ${DATA_PATH}/text/the_pile/shard00/bpe/merges.txt + --split: 949,50,1 + --distributed-backend: nccl + --lr: 0.00015 + --lr-decay-style: cosine + --min-lr: 1.0e-5 + --weight-decay: 1e-2 + --clip-grad: 1.0 + --lr-warmup-fraction: .01 + --log-interval: 1 + --save-interval: 25 + --eval-interval: 1000 + --eval-iters: 10 + --transformer-impl: transformer_engine + --tensor-model-parallel-size: 2 + --pipeline-model-parallel-size: 1 + --sequence-parallel: true + --untie-embeddings-and-output-weights: true + --deterministic-mode: true + --no-gradient-accumulation-fusion: true + --attention-softmax-in-fp32: true + --use-mcore-models: true + --ckpt-format: torch_dist + --data-cache-path: ${DATA_CACHE_PATH} + --bf16: true + --attention-backend: unfused + --log-memory-to-tensorboard: true +TEST_TYPE: ckpt-resume diff --git a/tests/test_utils/recipes/gpt.yaml b/tests/test_utils/recipes/gpt.yaml index d4f1828450e..90eddc55c27 100644 --- a/tests/test_utils/recipes/gpt.yaml +++ b/tests/test_utils/recipes/gpt.yaml @@ -342,6 +342,11 @@ products: platforms: [dgx_h100] - environment: [lts] scope: [nightly] + - test_case: [gpt3_mcore_te_tp2_pp1_gdn] + products: + - environment: [dev] + scope: [mr, mr-github, mr-github-slim] + platforms: 
[dgx_h100] - test_case: [gpt3_mcore_te_tp2_pp2_mla] products: - environment: [dev] diff --git a/tests/unit_tests/models/test_mamba_moe_model.py b/tests/unit_tests/models/test_mamba_moe_model.py index d3a58710907..500fc5783e2 100644 --- a/tests/unit_tests/models/test_mamba_moe_model.py +++ b/tests/unit_tests/models/test_mamba_moe_model.py @@ -136,6 +136,12 @@ "kv_channels": 128, "layernorm_epsilon": 1e-05, "layernorm_zero_centered_gamma": False, + "linear_attention_freq": None, + "linear_conv_kernel_dim": 4, + "linear_key_head_dim": 128, + "linear_num_key_heads": 16, + "linear_num_value_heads": 32, + "linear_value_head_dim": 128, "log_max_attention_logit": False, "mamba_head_dim": 64, "mamba_num_groups": 8, @@ -217,6 +223,7 @@ "qk_clip": False, "qk_clip_alpha": 0.5, "qk_clip_threshold": 100, + "qk_l2_norm": False, "qk_layernorm": False, "quant_recipe": None, "recompute_granularity": None, diff --git a/tests/unit_tests/post_training/test_modelopt_module_spec.py b/tests/unit_tests/post_training/test_modelopt_module_spec.py index ec80fcb1a72..dac96785bc0 100644 --- a/tests/unit_tests/post_training/test_modelopt_module_spec.py +++ b/tests/unit_tests/post_training/test_modelopt_module_spec.py @@ -173,6 +173,7 @@ def setup_method(self, method): moe_ffn_hidden_size=128, moe_shared_expert_intermediate_size=128, qk_layernorm=True, + qk_l2_norm=True, use_cpu_initialization=True, ) default_spec = get_gpt_decoder_block_spec( diff --git a/tests/unit_tests/ssm/test_gated_delta_net.py b/tests/unit_tests/ssm/test_gated_delta_net.py new file mode 100644 index 00000000000..1ccc70a2327 --- /dev/null +++ b/tests/unit_tests/ssm/test_gated_delta_net.py @@ -0,0 +1,329 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. 
+ +from unittest import mock + +import pytest +import torch +import torch.nn.functional as F + +from megatron.core import parallel_state +from megatron.core.models.common.embeddings.rope_utils import ( + get_pos_emb_on_this_cp_rank as get_tensor_on_this_cp_rank, +) +from megatron.core.models.gpt.experimental_attention_variant_module_specs import ( + get_experimental_attention_variant_module_spec, + get_transformer_block_with_experimental_attention_variant_spec, +) +from megatron.core.models.gpt.gpt_model import GPTModel +from megatron.core.process_groups_config import ProcessGroupCollection +from megatron.core.ssm.gated_delta_net import GatedDeltaNet +from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed +from megatron.core.transformer import TransformerConfig +from megatron.training.arguments import parse_args +from megatron.training.checkpointing import load_checkpoint, save_checkpoint +from megatron.training.global_vars import set_args +from megatron.training.training import get_model +from megatron.training.utils import unwrap_model +from tests.unit_tests.dist_checkpointing import ( + TempNamedDir, + init_basic_mock_args, + init_checkpointing_mock_args, +) +from tests.unit_tests.test_utilities import Utils + +try: + import fla + + HAVE_FLA = True +except ImportError: + HAVE_FLA = False + + +@pytest.mark.parametrize( + ("tp_size", "sp", "cp_size"), + [ + (1, False, 1), + (2, False, 1), + (2, True, 1), + # GDN does not support CP for now. Leave it for future work. 
+ ], +) +@pytest.mark.skipif(not HAVE_FLA, reason="FLA is not installed.") +@pytest.mark.internal +class TestGatedDeltaNet: + + @pytest.fixture(scope='function', autouse=True) + def setup_method(self, tp_size, sp, cp_size): + # Initialize parallel and random seed + Utils.initialize_model_parallel( + tensor_model_parallel_size=tp_size, + pipeline_model_parallel_size=1, + context_parallel_size=cp_size, + ) + model_parallel_cuda_manual_seed(123) + self.tp_size = tp_size + self.cp_size = cp_size + self.sp_size = tp_size if sp else 1 + + # Get TP and CP process groups from device mesh + tp_group = parallel_state.get_tensor_model_parallel_group() + cp_group = parallel_state.get_context_parallel_group() + pg_collection = ProcessGroupCollection(tp=tp_group, cp=cp_group) + + # Initialize model + self.transformer_config = TransformerConfig( + hidden_size=256, + linear_conv_kernel_dim=2, + linear_key_head_dim=64, + linear_value_head_dim=64, + linear_num_key_heads=4, + linear_num_value_heads=8, + num_layers=1, + normalization="RMSNorm", + use_cpu_initialization=True, + layernorm_zero_centered_gamma=True, + num_attention_heads=8, + activation_func=F.silu, + bf16=True, + tensor_model_parallel_size=tp_size, + sequence_parallel=sp, + context_parallel_size=cp_size, + experimental_attention_variant="gated_delta_net", + linear_attention_freq=[1], + transformer_impl="transformer_engine", + ) + gdn_submodules = get_experimental_attention_variant_module_spec( + config=self.transformer_config + ).submodules + + self.gdn = GatedDeltaNet( + self.transformer_config, + submodules=gdn_submodules, + layer_number=1, + bias=False, + conv_bias=False, + conv_init=1.0, + use_qk_l2norm=True, + A_init_range=(1, 16), + pg_collection=pg_collection, + ) + self.gdn = self.gdn.cuda().bfloat16() + + def teardown_method(self): + Utils.destroy_model_parallel() + + def test_gpu_forward(self): + gdn = self.gdn + + micro_batch_size = 2 + seq_length = 64 + hidden_states = torch.ones( + (seq_length // 
self.sp_size // self.cp_size, micro_batch_size, gdn.config.hidden_size), + device=torch.cuda.current_device(), + dtype=torch.bfloat16, + ) + attention_mask = None + + output, bias = gdn(hidden_states, attention_mask) + + assert output.dim() == 3, f"Output too many dimensions ({output.shape=})" + assert output.shape[0] == seq_length // self.sp_size // self.cp_size, ( + f"Output shape {output.shape[0]=} mismatch with " + f" {seq_length=} // {self.sp_size=} // {self.cp_size=}." + ) + assert ( + output.shape[1] == micro_batch_size + ), f"Output shape {output.shape[1]=} mismatch with {micro_batch_size=}" + assert ( + output.shape[2] == gdn.config.hidden_size + ), f"Output shape {output.shape[2]=} mismatch with {gdn.config.hidden_size=}" + assert ( + output.dtype == hidden_states.dtype + ), f"Output dtype {output.dtype=} mismatch with {hidden_states.dtype=}" + + +@pytest.mark.parametrize( + ("tp", "sp", "cp"), + [ + (4, False, 1), # TP w/o SP + (4, True, 1), # TP w/ SP + # CP does not support GDN for now. Add it once it is supported. 
+ ], +) +@pytest.mark.skipif(not HAVE_FLA, reason="FLA is not installed.") +def test_parallel_gated_delta_net_correctness(tmp_path_dist_ckpt, tp, sp, cp): + # Constants + seed = 123 + sequence_length = 256 + micro_batch_size = 4 + hidden_size = 128 + + # Model initialization function + def initialize_gpt_model( + config, pre_process=True, post_process=True, vp_stage=None, pg_collection=None + ): + layer_spec = get_transformer_block_with_experimental_attention_variant_spec( + config=config, vp_stage=None, pp_rank=None + ) + gpt_model = GPTModel( + config=config, + transformer_layer_spec=layer_spec, + vocab_size=128, + max_sequence_length=sequence_length, + pre_process=pre_process, + post_process=post_process, + vp_stage=vp_stage, + pg_collection=pg_collection, + ) + return gpt_model + + # Initialize baseline parallel state + Utils.initialize_model_parallel( + tensor_model_parallel_size=1, pipeline_model_parallel_size=1, context_parallel_size=1 + ) + + # Initialize input hidden states + torch.manual_seed(seed) + model_parallel_cuda_manual_seed(seed) + input_hidden_states = ( + torch.rand((sequence_length, micro_batch_size, hidden_size)) + .cuda() + .bfloat16() + .requires_grad_(True) + ) + + # Initialize transformer config + transformer_config = TransformerConfig( + hidden_size=128, + linear_conv_kernel_dim=2, + linear_key_head_dim=32, + linear_value_head_dim=32, + linear_num_key_heads=4, + linear_num_value_heads=8, + num_layers=1, + normalization="RMSNorm", + use_cpu_initialization=True, + layernorm_zero_centered_gamma=True, + num_attention_heads=8, + activation_func=F.silu, + bf16=True, + experimental_attention_variant="gated_delta_net", + linear_attention_freq=[1], + transformer_impl="transformer_engine", + ) + + with TempNamedDir(tmp_path_dist_ckpt / 'test_parallel_gdn', sync=True) as ckpt_dir: + # Set argument + mock_args = parse_args(ignore_unknown_args=True) + set_args(mock_args) + + # Initialize baseline model + init_basic_mock_args(mock_args, 1, 1, 
bf16=True) + mock_args.context_parallel_size = 1 + mock_args.sequence_parallel = 1 + gpt_model = unwrap_model(get_model(initialize_gpt_model, config=transformer_config)) + + # Initialize args and save checkpoint + init_checkpointing_mock_args(mock_args, ckpt_dir, False) + mock_args.no_save_optim = True + mock_args.no_save_rng = True + mock_args.no_load_optim = True + mock_args.no_load_rng = True + save_checkpoint(10, gpt_model, None, None, 0) + + # Calculate baseline output + attention = gpt_model[0].decoder.layers[0].self_attention + output_hidden_states_baseline, bias_hidden_states_baseline = attention( + input_hidden_states, attention_mask=None + ) + output_hidden_states_baseline.sum().backward() + + # Save baseline output + input_grad_baseline = input_hidden_states.grad.detach() + output_hidden_states_baseline = output_hidden_states_baseline.detach() + + # Initialize parallel model + Utils.destroy_model_parallel() + Utils.initialize_model_parallel( + tensor_model_parallel_size=tp, pipeline_model_parallel_size=1, context_parallel_size=cp + ) + torch.manual_seed(seed) + model_parallel_cuda_manual_seed(seed) + transformer_config.context_parallel_size = cp + transformer_config.tensor_model_parallel_size = tp + transformer_config.sequence_parallel = sp + init_basic_mock_args(mock_args, tp, 1, bf16=True) + mock_args.context_parallel_size = cp + mock_args.sequence_parallel = sp + pg_collection = ProcessGroupCollection.use_mpu_process_groups() + pg_collection.embd = parallel_state.get_embedding_group() + gpt_model = unwrap_model( + get_model(initialize_gpt_model, config=transformer_config, pg_collection=pg_collection) + ) + with mock.patch('megatron.training.checkpointing.check_checkpoint_args'): + with mock.patch('megatron.training.checkpointing.update_num_microbatches'): + load_checkpoint(gpt_model, None, None) + + # Function to get tensor on this tp and cp rank + cp_group = parallel_state.get_context_parallel_group() + tp_rank = 
parallel_state.get_tensor_model_parallel_rank() + + def get_tensor_on_this_rank(tensor): + if cp > 1: + tensor = get_tensor_on_this_cp_rank(tensor, 0, cp_group) + if tp > 1 and sp: + sp_seg = sequence_length // tp // cp + tensor = tensor[tp_rank * sp_seg : (tp_rank + 1) * sp_seg] + return tensor + + # Calculate parallel model output + input_hidden_states = get_tensor_on_this_rank(input_hidden_states) + input_hidden_states = input_hidden_states.detach().requires_grad_(True) + parallel_attention = gpt_model[0].decoder.layers[0].self_attention + output_hidden_states_parallel, bias_hidden_states_parallel = parallel_attention( + input_hidden_states, attention_mask=None + ) + output_hidden_states_parallel.sum().backward() + input_grad_parallel = input_hidden_states.grad.detach() + + # Check if the output is the same + if cp: + atol, rtol = 5e-3, 5e-3 + else: + atol, rtol = 5e-4, 5e-4 + output_hidden_states_baseline = get_tensor_on_this_rank(output_hidden_states_baseline) + input_grad_baseline = get_tensor_on_this_rank(input_grad_baseline) + + assert torch.all( + ~torch.isnan(output_hidden_states_baseline) + ), "output_hidden_states_baseline contains nan" + assert torch.all( + ~torch.isinf(output_hidden_states_baseline) + ), "output_hidden_states_baseline contains inf" + assert torch.all(~torch.isnan(input_grad_baseline)), "input_grad_baseline contains nan" + assert torch.all(~torch.isinf(input_grad_baseline)), "input_grad_baseline contains inf" + assert torch.all( + ~torch.isnan(output_hidden_states_parallel) + ), "output_hidden_states_parallel contains nan" + assert torch.all( + ~torch.isinf(output_hidden_states_parallel) + ), "output_hidden_states_parallel contains inf" + assert torch.all(~torch.isnan(input_grad_parallel)), "input_grad_parallel contains nan" + assert torch.all(~torch.isinf(input_grad_parallel)), "input_grad_parallel contains inf" + + torch.testing.assert_close( + output_hidden_states_baseline, + output_hidden_states_parallel, + atol=atol, + rtol=rtol, 
+ msg=lambda msg: f"Mismatch in output_hidden_states: {msg}", + ) + torch.testing.assert_close( + input_grad_baseline, + input_grad_parallel, + atol=atol, + rtol=rtol, + msg=lambda msg: f"Mismatch in input_grad: {msg}", + ) + + Utils.destroy_model_parallel() diff --git a/uv.lock b/uv.lock index e2db433cc28..ba5225a4c75 100644 --- a/uv.lock +++ b/uv.lock @@ -1331,6 +1331,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/18/79/1b8fa1bb3568781e84c9200f951c735f3f157429f44be0495da55894d620/filetype-1.2.0-py2.py3-none-any.whl", hash = "sha256:7ce71b6880181241cf7ac8697a2f1eb6a8bd9b429f7ad6d27b8db9ba5f1c2d25", size = 19970, upload-time = "2022-11-02T17:34:01.425Z" }, ] +[[package]] +name = "fla-core" +version = "0.3.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "einops" }, + { name = "torch", marker = "sys_platform == 'never'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/67/c6/10a1149b07e6bab45b2cb2d07f6b827716c2baf5f3404161753f25c6389b/fla_core-0.3.2.tar.gz", hash = "sha256:d38db16bc4e1c6fa8c04df442f246da1e6926a209426bc6ef703d41bfbc37c92", size = 296725, upload-time = "2025-09-10T07:43:40.155Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/7e/f5/74947b33c07682280e65adbdf17c4ee94b30232df2f728bafecf13d1d820/fla_core-0.3.2-py3-none-any.whl", hash = "sha256:e751d5a41e33eee721a6fb6588bd857f6f36e0d14719a23b1ebdbd617d307209", size = 413594, upload-time = "2025-09-10T07:43:37.786Z" }, +] + [[package]] name = "flake8" version = "7.1.0" @@ -1345,6 +1358,21 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/dc/43/d5147aadaa52558e94e024811f2f9543b4bd7203b3a9659eeb5dff9c61b3/flake8-7.1.0-py2.py3-none-any.whl", hash = "sha256:2e416edcc62471a64cea09353f4e7bdba32aeb079b6e360554c659a122b1bc6a", size = 57569, upload-time = "2024-06-15T21:37:05.342Z" }, ] +[[package]] +name = "flash-linear-attention" +version = "0.3.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + 
{ name = "datasets" }, + { name = "fla-core" }, + { name = "pytest" }, + { name = "transformers" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/84/f6/e62c1e562a288557eba7f06f168a7615813d1a227327b8beb8ba426da2c5/flash_linear_attention-0.3.2.tar.gz", hash = "sha256:9147747316c2951fed4ebeb4fa87977c05d807dc70c93b46250b68a6eb1183e2", size = 150880, upload-time = "2025-09-10T07:43:41.37Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a0/d0/35ce9eac5f52c72005095aaa12a393d2656ed7ffedf925b2381a6b76d10c/flash_linear_attention-0.3.2-py3-none-any.whl", hash = "sha256:604e73361437ba786420ab195e2caa3fd19280503761e703fa353c5ce5c65376", size = 274592, upload-time = "2025-09-10T07:43:39.107Z" }, +] + [[package]] name = "flash-mla" version = "1.0.0+9edee0c" @@ -2208,6 +2236,7 @@ dev = [ { name = "einops" }, { name = "emerging-optimizers" }, { name = "fastapi" }, + { name = "flash-linear-attention" }, { name = "flashinfer-python" }, { name = "mamba-ssm" }, { name = "megatron-energon", extra = ["av-decode"], marker = "extra == 'extra-13-megatron-core-dev'" }, @@ -2318,6 +2347,7 @@ requires-dist = [ { name = "emerging-optimizers", marker = "extra == 'lts'", git = "https://github.com/NVIDIA-NeMo/Emerging-Optimizers.git?rev=v0.1.0" }, { name = "fastapi", marker = "extra == 'dev'", specifier = "~=0.50" }, { name = "fastapi", marker = "extra == 'lts'", specifier = "~=0.50" }, + { name = "flash-linear-attention", marker = "extra == 'dev'", specifier = "~=0.3.2" }, { name = "flashinfer-python", marker = "extra == 'dev'", specifier = "~=0.5.0" }, { name = "flashinfer-python", marker = "extra == 'lts'", specifier = "~=0.5.0" }, { name = "flask-restful", marker = "extra == 'mlm'" }, From 418be3ce003cfeb1d8b5da186de26430de836074 Mon Sep 17 00:00:00 2001 From: Yuzhong Wang Date: Thu, 15 Jan 2026 19:37:37 -0800 Subject: [PATCH 2/4] refine the flops computation --- megatron/training/training.py | 59 ++++++++++++++++++++--------------- 1 file changed, 33 
insertions(+), 26 deletions(-) diff --git a/megatron/training/training.py b/megatron/training/training.py index 0acd7a54114..c25757a2db9 100644 --- a/megatron/training/training.py +++ b/megatron/training/training.py @@ -334,18 +334,15 @@ def transformer_flops(): if args.moe_shared_expert_intermediate_size is None else args.moe_shared_expert_intermediate_size ) - # SwiGLU. - gated_linear_multiplier = 3 / 2 if args.swiglu else 1 - # The 12x term below comes from the following factors; for more details, see - # "APPENDIX: FLOATING-POINT OPERATIONS" in https://arxiv.org/abs/2104.04473. # - 3x: Each GEMM in the model needs to be performed 3 times (forward pass, # backward wgrad [weight gradient], backward dgrad [data gradient]). - # - 2x: GEMMs of a particular size are stacked twice in the standard Transformer model - # architectures implemented in this codebase (e.g., h->ffn_h GEMM and ffn_h->h GEMM - # in MLP layer). + forward_backward_expansion_factor = 3 # - 2x: A GEMM of a m*n tensor with a n*k tensor requires 2mnk floating-point operations. - expansion_factor = 3 * 2 * 2 + fma_expansion_factor = 2 + # - 3x (SwiGLU enabled): h->2*ffn_h GEMM and ffn_h->h GEMM are stacked. + # - 2x (SwiGLU disabled): h->ffn_h GEMM and ffn_h->h GEMM are stacked. 
+ ffn_expansion_factor = 3 if args.swiglu else 2 if args.multi_latent_attention: assert not args.group_query_attention @@ -376,8 +373,8 @@ def transformer_flops(): + 1 ) standard_self_attn_term = ( - 3 - * 2 # fwd(1) + bwd(2) *FMA + forward_backward_expansion_factor + * fma_expansion_factor * ( ## q lora + rope + q norm q_term @@ -404,13 +401,19 @@ def transformer_flops(): query_projection_size = args.kv_channels * args.num_attention_heads key_projection_size = args.kv_channels * args.num_query_groups value_projection_size = args.kv_channels * args.num_query_groups + gate_projection_size = query_projection_size if args.attention_output_gate else 0 standard_self_attn_term = ( - 3 - * 2 # fwd(1) + bwd(2) *FMA + forward_backward_expansion_factor + * fma_expansion_factor * ( ## qkv proj args.hidden_size - * (query_projection_size + key_projection_size + value_projection_size) + * ( + query_projection_size + + key_projection_size + + value_projection_size + + gate_projection_size + ) ## core attention + query_projection_size * args.seq_length @@ -461,8 +464,8 @@ def transformer_flops(): qk_dim = qk_head_dim * num_qk_heads v_dim = v_head_dim * num_v_heads linear_self_attn_term = ( - 3 - * 2 # fwd(1) + bwd(2) *FMA + forward_backward_expansion_factor + * fma_expansion_factor * ( ## in proj args.hidden_size @@ -499,25 +502,25 @@ def transformer_flops(): * args.seq_length * ( # MLP - expansion_factor - * num_layers + forward_backward_expansion_factor + * fma_expansion_factor * args.hidden_size * ( # dense layer (deepseek v2, v3 style) - (args.ffn_hidden_size * gated_linear_multiplier) - * (num_dense_layers / num_layers) + (args.ffn_hidden_size * ffn_expansion_factor) + * num_dense_layers # routed experts - + (moe_ffn_hidden_size * num_experts_routed_to * gated_linear_multiplier) - * (num_moe_layers / num_layers) + + (moe_ffn_hidden_size * num_experts_routed_to * ffn_expansion_factor) + * num_moe_layers # Shared Experts. 
- + (shared_expert_ffn_hidden_size * gated_linear_multiplier) - * (num_moe_layers / num_layers) + + (shared_expert_ffn_hidden_size * ffn_expansion_factor) + * num_moe_layers ) # Self Attention + self_attn_term # MTP norms and proj - + 3 - * 2 + + forward_backward_expansion_factor + * fma_expansion_factor * mtp_num_layers * ( # MTP eh norm + final nrom @@ -526,7 +529,11 @@ def transformer_flops(): + 2 * args.hidden_size * args.hidden_size ) # Logit. - + 3 * 2 * args.hidden_size * args.padded_vocab_size * (mtp_num_layers + 1) + + forward_backward_expansion_factor + * fma_expansion_factor + * args.hidden_size + * args.padded_vocab_size + * (mtp_num_layers + 1) # MTP + final logit ) ) return total_floating_point_operations From b218786bee6edf15d0a458c22c73a1c3ad853957 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?oliver=20k=C3=B6nig?= Date: Fri, 16 Jan 2026 11:59:34 +0000 Subject: [PATCH 3/4] ci(fix): CI_COMMIT_BRANCH on forks MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: oliver könig --- .gitlab/scripts/build.sh | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.gitlab/scripts/build.sh b/.gitlab/scripts/build.sh index 3c42b9137d1..0f34b838384 100644 --- a/.gitlab/scripts/build.sh +++ b/.gitlab/scripts/build.sh @@ -20,6 +20,8 @@ docker buildx create --name container --driver=docker-container --use tls-enviro ADDITIONAL_PARAMS=() +CI_COMMIT_BRANCH="${CI_COMMIT_BRANCH:-$CI_MERGE_REQUEST_SOURCE_BRANCH_NAME}" + if [[ "$CI_COMMIT_BRANCH" == "ci-rebuild-mcore-nemo-image" || "$CI_COMMIT_BRANCH" == "main" || "$CI_COMMIT_BRANCH" == "dev" ]]; then ADDITIONAL_PARAMS+=("--pull") fi From f58c3106c5bf7ccc483d02bd03f3ed82c32c2dde Mon Sep 17 00:00:00 2001 From: Yuzhong Wang Date: Fri, 16 Jan 2026 07:38:41 -0800 Subject: [PATCH 4/4] remove device usage --- megatron/training/utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/megatron/training/utils.py b/megatron/training/utils.py index 4bf613b7bd2..ad070c42fb0 100644 --- 
a/megatron/training/utils.py +++ b/megatron/training/utils.py @@ -282,7 +282,6 @@ def report_memory(name): string += ' | max allocated: {}'.format(torch.cuda.max_memory_allocated() / mega_bytes) string += ' | reserved: {}'.format(torch.cuda.memory_reserved() / mega_bytes) string += ' | max reserved: {}'.format(torch.cuda.max_memory_reserved() / mega_bytes) - string += ' | device usage: {}'.format(torch.cuda.device_memory_used() / mega_bytes) if mpu.get_data_parallel_rank() == 0: print("[Rank {}] {}".format(torch.distributed.get_rank(), string), flush=True)