
[Pytorch] Fix backward_dw cuda graph order #2376

Merged
ksivaman merged 7 commits into NVIDIA:main from Wohox:pingtian/fix_wgrad_cuda_graph_order
Nov 25, 2025

Conversation

@Wohox
Contributor

@Wohox Wohox commented Nov 13, 2025

Description

Based on #1948
Fixes the CUDA graph capture order of backward_dw graphs when delay_wgrad_compute is enabled. The user may delay the wgrad compute to the end of the overlapped forward layers, so the capture order should also be moved accordingly.

Overlap Pattern:

f_1 --------------> f_2 -> ... ->f_5 | (end of fwd layers)
b_2 -> b_2_wgrad -> b_1 -------------|-> b_1_wgrad

Design
Allow float values in _order: for model chunk i, the corresponding wgrad order value is -i-0.5.

Example

ORDER [1, 1, 1, 1, 2, -2, 2, -2, 2, -2, 2, -2, 1, -1, 1, -1, 1, -1, 1, -1, 2, -2, 2, -2, 2, -2, 2, -2, -1, -1, -1, -1]
ORDER after overlap_moe_expert_parallel_comm [1, 1, 1, 1, 2, 3, 2, -3, 3, -2, 2, -3, 3, -2, 2, -3, 3, -2, 1, -3, -2, 1, -1, 1, -1, 1, -1, 2, -1, 3, 2, -3, 3, -2, 2, -3, 3, -2, 2, -3, 3, -2, -3, -2, -1, -1, -1, -1]
Fixed ORDER after overlap_moe_expert_parallel_comm [1, 1, 1, 1, 2, 3, 2, -3, -3.5, 3, -2, -2.5, 2, -3, -3.5, 3, -2, -2.5, 2, -3, -3.5, 3, -2, -2.5, 1, -3, -3.5, -2, -2.5, 1, -1, -1.5, 1, -1, -1.5, 1, -1, -1.5, 2, -1(dgrad), 3, -1.5(delayed wgrad), 2, -3, -3.5, 3, -2, -2.5, 2, -3, -3.5, 3, -2, -2.5, 2, -3, -3.5, 3, -2, -2.5, -3, -3.5, -2, -2.5, -1, -1.5, -1, -1.5, -1, -1.5, -1, -1.5]
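As a rough sketch of the scheme above (a hypothetical helper, not the PR's actual code), an order entry can be classified by checking whether it is positive, a negative integer, or a negative float sitting 0.5 below its dgrad value:

```python
from math import ceil

def decode_order_entry(c_id):
    """Classify one _order entry (illustrative sketch, not TE's code)."""
    if c_id > 0:
        return ("forward", int(c_id))
    if ceil(c_id) == c_id:
        # Negative integer: dgrad of chunk |c_id|.
        return ("dgrad", int(-c_id))
    # Negative float: delayed wgrad, always 0.5 below the dgrad value.
    assert ceil(c_id) - c_id == 0.5
    return ("wgrad", int(-ceil(c_id)))

assert decode_order_entry(1) == ("forward", 1)
assert decode_order_entry(-1) == ("dgrad", 1)
assert decode_order_entry(-1.5) == ("wgrad", 1)
```

Here -1.5 maps to the wgrad of chunk 1, matching the fixed ORDER list above.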

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Support float values in the cuda graph _order input, representing the backward_dw of the ceil(order) chunk.
  • During graph capture, when a wgrad graph is identified, no input/output tensor space is allocated.

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@Wohox
Contributor Author

Wohox commented Nov 13, 2025

@timmoon10 Could you help review this or assign to someone? Thanks! Cc @buptzyb @lhb8125

@greptile-apps
Contributor

greptile-apps bot commented Nov 13, 2025

Greptile Overview

Greptile Summary

Fixes CUDA graph capture order when delay_wgrad_compute is enabled by allowing float values in the _order list (e.g., -1.5 represents wgrad for chunk 1). This ensures wgrad graphs are captured at the correct position when weight gradient computation is delayed to the end of overlapped forward layers.

Key changes:

  • Float values in _order now represent delayed wgrad computation (value c_id - 0.5 for chunk ceil(c_id))
  • Added validation to ensure number of wgrad entries matches dgrad entries per chunk
  • Modified graph capture logic to handle wgrad entries separately from dgrad
  • When delay_wgrad_compute=True, enforces that each model chunk has exactly 1 layer
  • Wgrad capture now correctly adjusts index and skips incrementing bwd_idx counter

Confidence Score: 4/5

  • This PR is safe to merge with minimal risk - the fix is well-scoped to the delayed wgrad case with proper validation
  • The implementation correctly handles the delayed wgrad computation case with proper validation logic. The constraint that delay_wgrad_compute requires exactly 1 layer per chunk addresses previous concerns about the loop iteration. The validation ensures wgrad/dgrad count matches, preventing configuration errors. Minor deduction due to lack of tests and the complexity of the validation logic which could benefit from additional edge case coverage.
  • No files require special attention - the single changed file has appropriate validation and follows the existing code patterns

Important Files Changed

File Analysis

Filename Score Overview
transformer_engine/pytorch/graph.py 4/5 Fixes CUDA graph capture order for delayed wgrad computation by supporting float values in _order (e.g., -1.5 for wgrad of chunk 1). Adds validation to ensure wgrad/dgrad count matches.

Sequence Diagram

sequenceDiagram
    participant Order as Order List
    participant Parser as Order Parser
    participant Validator as Wgrad Validator
    participant Capture as Graph Capture

    Note over Order: Example: [1, -1, -1.5]
    Order->>Parser: Process order entries
    
    alt Forward pass (c_id > 0)
        Parser->>Capture: Capture forward graph<br/>for chunk c_id
        Note right of Capture: Standard forward capture
    else Backward dgrad (c_id < 0, ceil(c_id)==c_id)
        Parser->>Validator: Check if wgrad validation needed
        alt First dgrad with need_bwd_dw_graph
            Validator->>Validator: Count future dgrad entries (c_id)
            Validator->>Validator: Count future wgrad entries (c_id-0.5)
            alt Counts match
                Validator->>Validator: Mark validation[i] = True
            else Counts mismatch
                Validator->>Validator: Mark validation[i] = False
                Validator->>Order: Assert fails: counts must match
            end
        end
        Parser->>Capture: Capture dgrad graph<br/>at per_callable_bwd_idx
        Parser->>Parser: Increment bwd_idx[m_chunk]
    else Backward wgrad (c_id < 0, ceil(c_id)!=c_id)
        Note right of Parser: Float value detected (e.g., -1.5)
        Parser->>Parser: m_chunk = -ceil(c_id) - 1
        Parser->>Parser: per_callable_bwd_idx -= num_layers_per_chunk
        Parser->>Capture: Capture wgrad graph only<br/>(backward_dw)
        Note right of Capture: No dgrad, only weight gradient
        Parser->>Parser: Continue (skip bwd_idx increment)
    end

Contributor

@greptile-apps greptile-apps bot left a comment

1 file reviewed, no comments

Contributor

@greptile-apps greptile-apps bot left a comment

1 file reviewed, 1 comment

Contributor

@greptile-apps greptile-apps bot left a comment

1 file reviewed, 1 comment

Comment on lines +524 to +545
if wgrad_validation_list[i] is None:
    same_bwd_c_id_list = [i]
    num_wgrad_c_id = 0
    for idx in range(i + 1, len(_order)):
        if _order[idx] > 0:
            continue
        if _order[idx] == c_id:
            same_bwd_c_id_list.append(idx)
        if _order[idx] + 0.5 == c_id:
            num_wgrad_c_id += 1
        if len(same_bwd_c_id_list) == num_wgrad_c_id:
            for same_c_id_idx in same_bwd_c_id_list:
                wgrad_validation_list[same_c_id_idx] = True
            break
        elif len(same_bwd_c_id_list) < num_wgrad_c_id:
            # It's impossible to have more wgrad than dgrad.
            wgrad_validation_list[i] = False
            break
assert wgrad_validation_list[i], (
    f"Number of wgrad graph({num_wgrad_c_id}) doesn't match number "
    f"of dgrad graphs ({len(same_bwd_c_id_list)}) for chunk {c_id}."
)
Contributor

logic: The validation loop can exit without setting wgrad_validation_list[i], leaving it as None. This happens when the loop completes without finding a match (neither len(same_bwd_c_id_list) == num_wgrad_c_id nor len(same_bwd_c_id_list) < num_wgrad_c_id is true). The assertion on line 542 will then fail with a confusing error since it checks truthiness of None.

Add an else clause after the loop to handle this case:

Suggested change
if wgrad_validation_list[i] is None:
    same_bwd_c_id_list = [i]
    num_wgrad_c_id = 0
    for idx in range(i + 1, len(_order)):
        if _order[idx] > 0:
            continue
        if _order[idx] == c_id:
            same_bwd_c_id_list.append(idx)
        if _order[idx] + 0.5 == c_id:
            num_wgrad_c_id += 1
        if len(same_bwd_c_id_list) == num_wgrad_c_id:
            for same_c_id_idx in same_bwd_c_id_list:
                wgrad_validation_list[same_c_id_idx] = True
            break
        elif len(same_bwd_c_id_list) < num_wgrad_c_id:
            # It's impossible to have more wgrad than dgrad.
            wgrad_validation_list[i] = False
            break
assert wgrad_validation_list[i], (
    f"Number of wgrad graph({num_wgrad_c_id}) doesn't match number "
    f"of dgrad graphs ({len(same_bwd_c_id_list)}) for chunk {c_id}."
)
if wgrad_validation_list[i] is None:
    same_bwd_c_id_list = [i]
    num_wgrad_c_id = 0
    for idx in range(i + 1, len(_order)):
        if _order[idx] > 0:
            continue
        if _order[idx] == c_id:
            same_bwd_c_id_list.append(idx)
        if _order[idx] + 0.5 == c_id:
            num_wgrad_c_id += 1
        if len(same_bwd_c_id_list) == num_wgrad_c_id:
            for same_c_id_idx in same_bwd_c_id_list:
                wgrad_validation_list[same_c_id_idx] = True
            break
        elif len(same_bwd_c_id_list) < num_wgrad_c_id:
            # It's impossible to have more wgrad than dgrad.
            wgrad_validation_list[i] = False
            break
    else:
        # Loop completed without break - mismatch found
        wgrad_validation_list[i] = False
assert wgrad_validation_list[i], (
    f"Number of wgrad graph({num_wgrad_c_id}) doesn't match number "
    f"of dgrad graphs ({len(same_bwd_c_id_list)}) for chunk {c_id}."
)

Contributor Author

  1. None means the current check doesn't find the corresponding wgrad chunk_id, and hence should raise an assertion.
  2. assert None and assert False both trigger the assertion.
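For illustration only, the counting rule this validation enforces (equal numbers of dgrad and wgrad entries per chunk) can be sketched as a standalone helper; the name check_wgrad_counts and the simplified scan are assumptions, not the PR's code:

```python
def check_wgrad_counts(order, c_id):
    """Return True if chunk c_id (a negative integer dgrad id) has as many
    wgrad entries (value c_id - 0.5) as dgrad entries (value c_id)."""
    num_dgrad = sum(1 for v in order if v == c_id)
    num_wgrad = sum(1 for v in order if v + 0.5 == c_id)
    return num_dgrad == num_wgrad

# Two microbatches of chunk 1, each dgrad followed by its delayed wgrad.
assert check_wgrad_counts([1, -1, -1.5, 1, -1, -1.5], -1)
# A dgrad without a matching wgrad entry fails the check.
assert not check_wgrad_counts([1, -1, -1.5, 1, -1], -1)
```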

Contributor

@greptile-apps greptile-apps bot left a comment

1 file reviewed, no comments

Contributor

@greptile-apps greptile-apps bot left a comment

1 file reviewed, no comments

@ksivaman
Member

@lhb8125 Could you take a look?

@ksivaman
Member

@Wohox Could you sign-off your commits (docs ref)? Thanks!

@Wohox
Contributor Author

Wohox commented Nov 13, 2025

@Wohox Could you sign-off your commits (docs ref)? Thanks!

@ksivaman My commits are all signed; the bots' commits are not. Do you mean I should force-push all the changes as 1 signed commit?

Contributor

@greptile-apps greptile-apps bot left a comment

1 file reviewed, 2 comments

Comment on lines +548 to +569
elif ceil(c_id) != c_id:
    per_callable_bwd_idx -= _num_layers_per_chunk[m_chunk]
    assert is_training, "Only training mode supports backward_dw."
    # If no one module needs the backward_dw, the bwd_dw_graph will be empty.
    # So skip capturing it. For backward_dw, the order value is c_id - 0.5 to indicate
    # the specific order of backward_dw.
    assert ceil(c_id) - c_id == 0.5, (
        "The order diff of wgrad and dgrad must be 0.5, "
        f"get {ceil(c_id) - c_id}."
    )
    assert need_bwd_dw_graph[
        per_callable_bwd_idx
    ], "No module needs wgrad computation but get float in order"
    bwd_dw_graph = bwd_dw_graphs[per_callable_bwd_idx]
    with _graph_context_wrapper(bwd_dw_graph, pool=mempool):
        for module in visited_te_modules[per_callable_bwd_idx]:
            if (
                hasattr(module, "need_backward_dw")
                and module.need_backward_dw()
            ):
                module.backward_dw()
    continue
Contributor

logic: The wgrad handling is inside the for l_no loop (line 514), but the continue on line 569 only skips to the next iteration of this inner loop. This means wgrad graphs will be captured multiple times (once for each layer in _num_layers_per_chunk[m_chunk]), which seems incorrect since wgrad should only be captured once per chunk.

Consider using break instead of continue, or restructure the code to handle wgrad entries outside the layer loop.

Contributor Author

There is only 1 layer in each model chunk when delay wgrad is enabled, so using continue and break is exactly the same.

        f"of dgrad graphs ({len(same_bwd_c_id_list)}) for chunk {c_id}."
    )
elif ceil(c_id) != c_id:
    per_callable_bwd_idx -= _num_layers_per_chunk[m_chunk]
Contributor

logic: The index adjustment per_callable_bwd_idx -= _num_layers_per_chunk[m_chunk] appears incorrect. This will cause wgrad graphs to be captured from the wrong microbatch or with invalid indices.

When processing a wgrad entry (e.g., -1.5), bwd_idx[m_chunk] has already been incremented by the previous dgrad entry. The correction should likely be at the microbatch level in the index formula, not by subtracting the number of layers. For example, use (bwd_idx[m_chunk] - 1) in the index calculation instead of modifying per_callable_bwd_idx after the fact.

Contributor Author

(bwd_idx[m_chunk] - 1) is mathematically the same as subtracting _num_layers_per_chunk[m_chunk]

Contributor

You're absolutely right! Let me trace through the logic to show why they're mathematically equivalent in this context.

Looking at the code structure:

  1. per_callable_bwd_idx is calculated as:

    per_callable_bwd_idx = (_prefix_num_layers[m_chunk] * num_microbatches) + (
        bwd_idx[m_chunk] * _num_layers_per_chunk[m_chunk] + l_no
    )
  2. When we have delay_wgrad_compute enabled, each model chunk has exactly one layer (_num_layers_per_chunk[m_chunk] == 1).

  3. The bwd_idx[m_chunk] gets incremented after processing each dgrad entry.

So when we reach a wgrad entry (where ceil(c_id) != c_id), the bwd_idx[m_chunk] has already been incremented by the previous dgrad processing.

Given that _num_layers_per_chunk[m_chunk] == 1 in the delay_wgrad_compute case:

  • (bwd_idx[m_chunk] - 1) * _num_layers_per_chunk[m_chunk] becomes (bwd_idx[m_chunk] - 1) * 1 = bwd_idx[m_chunk] - 1
  • bwd_idx[m_chunk] * _num_layers_per_chunk[m_chunk] - _num_layers_per_chunk[m_chunk] becomes bwd_idx[m_chunk] * 1 - 1 = bwd_idx[m_chunk] - 1

You're correct that both approaches yield the same result mathematically. The current implementation works correctly for the delay_wgrad_compute case. My original comment incorrectly identified this as a bug when it's actually functioning as intended.

The key insight is that when delay_wgrad_compute is enabled, _num_layers_per_chunk[m_chunk] is always 1, making the subtraction equivalent to decrementing the microbatch index, which is the correct behavior for accessing the corresponding forward pass tensors for the wgrad computation.
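The equivalence argued above can be checked numerically; both index formulas below are sketches with made-up argument names, not the function's real signature:

```python
def idx_post_subtract(prefix_layers, num_microbatches, bwd_idx, num_layers, l_no):
    # PR's approach: compute with the already-incremented bwd_idx,
    # then subtract the per-chunk layer count.
    base = prefix_layers * num_microbatches
    return base + (bwd_idx * num_layers + l_no) - num_layers

def idx_decremented(prefix_layers, num_microbatches, bwd_idx, num_layers, l_no):
    # Reviewer's suggestion: use (bwd_idx - 1) directly in the formula.
    base = prefix_layers * num_microbatches
    return base + ((bwd_idx - 1) * num_layers + l_no)

# The two agree for every microbatch index (algebraically identical).
for bwd_idx in range(1, 9):
    assert idx_post_subtract(4, 8, bwd_idx, 1, 0) == idx_decremented(4, 8, bwd_idx, 1, 0)
```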

@Wohox Wohox force-pushed the pingtian/fix_wgrad_cuda_graph_order branch from 2603712 to b197819 Compare November 14, 2025 00:32
Signed-off-by: Pingtian Li <pingtianl@nvidia.com>
@Wohox Wohox force-pushed the pingtian/fix_wgrad_cuda_graph_order branch from b197819 to bbbb173 Compare November 14, 2025 00:33
@Wohox
Contributor Author

Wohox commented Nov 14, 2025

@Wohox Could you sign-off your commits (docs ref)? Thanks!

DCO passes now, thanks for the heads up.

Contributor

@greptile-apps greptile-apps bot left a comment

1 file reviewed, 2 comments

Comment on lines +548 to +569
elif ceil(c_id) != c_id:
    per_callable_bwd_idx -= _num_layers_per_chunk[m_chunk]
    assert is_training, "Only training mode supports backward_dw."
    # If no one module needs the backward_dw, the bwd_dw_graph will be empty.
    # So skip capturing it. For backward_dw, the order value is c_id - 0.5 to indicate
    # the specific order of backward_dw.
    assert ceil(c_id) - c_id == 0.5, (
        "The order diff of wgrad and dgrad must be 0.5, "
        f"get {ceil(c_id) - c_id}."
    )
    assert need_bwd_dw_graph[
        per_callable_bwd_idx
    ], "No module needs wgrad computation but get float in order"
    bwd_dw_graph = bwd_dw_graphs[per_callable_bwd_idx]
    with _graph_context_wrapper(bwd_dw_graph, pool=mempool):
        for module in visited_te_modules[per_callable_bwd_idx]:
            if (
                hasattr(module, "need_backward_dw")
                and module.need_backward_dw()
            ):
                module.backward_dw()
    continue
Contributor

logic: Wgrad graphs will be captured multiple times (once per layer in the chunk) because this code is inside the for l_no loop (line 514). The continue on line 569 only skips to the next layer iteration, not out of the chunk.

For a chunk with 3 layers, the wgrad graph would be captured 3 times with indices calculated as:

  • Layer 0: per_callable_bwd_idx = base + (bwd_idx[m_chunk] * 3 + 0) - 3
  • Layer 1: per_callable_bwd_idx = base + (bwd_idx[m_chunk] * 3 + 1) - 3
  • Layer 2: per_callable_bwd_idx = base + (bwd_idx[m_chunk] * 3 + 2) - 3

The wgrad should only be captured once per chunk. Move this block outside the layer loop or use break instead of continue.

Contributor Author

There is only 1 layer in each model chunk when delay wgrad is enabled, so using continue and break is exactly the same.

Comment on lines +518 to +547
if ceil(c_id) == c_id and need_bwd_dw_graph[per_callable_bwd_idx]:
    # Check if bwd graph has corresponding wgrad graph:
    # Number of dgrad backward graphs should be equal to number of
    # wgrad backward graphs.
    # Note: For MCore, the validation rule is more strict (the next backward
    # of dgrad graph must be corresponding wgrad graph).
    if wgrad_validation_list[i] is None:
        same_bwd_c_id_list = [i]
        num_wgrad_c_id = 0
        for idx in range(i + 1, len(_order)):
            if _order[idx] > 0:
                continue
            if _order[idx] == c_id:
                same_bwd_c_id_list.append(idx)
            if _order[idx] + 0.5 == c_id:
                num_wgrad_c_id += 1
            if len(same_bwd_c_id_list) == num_wgrad_c_id:
                for same_c_id_idx in same_bwd_c_id_list:
                    wgrad_validation_list[same_c_id_idx] = True
                break
            if len(same_bwd_c_id_list) < num_wgrad_c_id:
                # It's impossible to have more wgrad than dgrad.
                wgrad_validation_list[i] = False
                break
        if wgrad_validation_list[i] is None:
            wgrad_validation_list[i] = False
    assert wgrad_validation_list[i], (
        f"Number of wgrad graph({num_wgrad_c_id}) doesn't match number "
        f"of dgrad graphs ({len(same_bwd_c_id_list)}) for chunk {c_id}."
    )
Contributor

logic: Validation only runs if the first layer (last in reverse iteration) needs wgrad due to need_bwd_dw_graph[per_callable_bwd_idx] check. If only later layers need wgrad, validation is skipped entirely. The validation logic should check if ANY layer in the chunk needs wgrad, not just per_callable_bwd_idx.

Contributor Author

@buptzyb Do you think we should validate that num_layers_per_chunk is an all-1 array when any of the model chunks has need_bwd_dw_graph?

Contributor

Makes sense. I think this is necessary.

Contributor Author

For the delayed wgrad case, the layer number is guaranteed to be all 1s, so only the first layer needs to be checked.

@Wohox
Contributor Author

Wohox commented Nov 14, 2025

@ksivaman The pytorch build failed due to ERROR: Could not install packages due to an OSError: [Errno 28] No space left on device. Can you help check this? Thanks!

@Wohox
Contributor Author

Wohox commented Nov 14, 2025

/te-ci pytorch

1 similar comment
@Wohox
Contributor Author

Wohox commented Nov 17, 2025

/te-ci pytorch

Contributor

@greptile-apps greptile-apps bot left a comment

1 file reviewed, 3 comments

        f"of dgrad graphs ({len(same_bwd_c_id_list)}) for chunk {c_id}."
    )
elif ceil(c_id) != c_id:
    per_callable_bwd_idx -= _num_layers_per_chunk[m_chunk]
Contributor

logic: The index calculation uses bwd_idx[m_chunk] which was already incremented by the previous dgrad entry. For order [1, -1, -1.5], when processing -1.5, bwd_idx[0] is already 1 (from -1), so this calculates index for microbatch 1 then subtracts layers, giving wrong index.

Use (bwd_idx[m_chunk] - 1) in the formula on line 515-516 instead:

Suggested change
per_callable_bwd_idx -= _num_layers_per_chunk[m_chunk]
per_callable_bwd_idx = (_prefix_num_layers[m_chunk] * num_microbatches) + (
    (bwd_idx[m_chunk] - 1) * _num_layers_per_chunk[m_chunk] + l_no
) - _num_layers_per_chunk[m_chunk]

Contributor Author

(bwd_idx[m_chunk] - 1) is mathematically the same as subtracting _num_layers_per_chunk[m_chunk]

Contributor

You're absolutely right! Looking at the code more carefully, I can see that per_callable_bwd_idx is calculated using bwd_idx[m_chunk] which gets incremented after each backward pass. So when we're processing the wgrad (float value), bwd_idx[m_chunk] has already been incremented from the previous dgrad pass.

The current line:

per_callable_bwd_idx -= _num_layers_per_chunk[m_chunk]

Is indeed mathematically equivalent to using (bwd_idx[m_chunk] - 1) in the original calculation on lines 529-531, since:

  • Original: (_prefix_num_layers[m_chunk] * num_microbatches) + (bwd_idx[m_chunk] * _num_layers_per_chunk[m_chunk] + l_no)
  • With (bwd_idx[m_chunk] - 1): (_prefix_num_layers[m_chunk] * num_microbatches) + ((bwd_idx[m_chunk] - 1) * _num_layers_per_chunk[m_chunk] + l_no)
  • Current approach: Calculate with bwd_idx[m_chunk], then subtract _num_layers_per_chunk[m_chunk]

Both approaches achieve the same result of using the previous microbatch index. The current implementation is actually more straightforward since it reuses the same calculation pattern and then adjusts the result, rather than modifying the formula itself.

Your suggested fix would work, but the current approach in line 563 is equivalent and arguably cleaner since it maintains consistency with the calculation pattern used elsewhere in the function.

Comment on lines +548 to +569
elif ceil(c_id) != c_id:
    per_callable_bwd_idx -= _num_layers_per_chunk[m_chunk]
    assert is_training, "Only training mode supports backward_dw."
    # If no one module needs the backward_dw, the bwd_dw_graph will be empty.
    # So skip capturing it. For backward_dw, the order value is c_id - 0.5 to indicate
    # the specific order of backward_dw.
    assert ceil(c_id) - c_id == 0.5, (
        "The order diff of wgrad and dgrad must be 0.5, "
        f"get {ceil(c_id) - c_id}."
    )
    assert need_bwd_dw_graph[
        per_callable_bwd_idx
    ], "No module needs wgrad computation but get float in order"
    bwd_dw_graph = bwd_dw_graphs[per_callable_bwd_idx]
    with _graph_context_wrapper(bwd_dw_graph, pool=mempool):
        for module in visited_te_modules[per_callable_bwd_idx]:
            if (
                hasattr(module, "need_backward_dw")
                and module.need_backward_dw()
            ):
                module.backward_dw()
    continue
Contributor

logic: Wgrad capture happens inside the for l_no loop, so it runs once per layer in the chunk. For a 3-layer chunk, wgrad gets captured 3 times with different indices:

  • l_no=2: per_callable_bwd_idx = base + (bwd_idx * 3 + 2) - 3
  • l_no=1: per_callable_bwd_idx = base + (bwd_idx * 3 + 1) - 3
  • l_no=0: per_callable_bwd_idx = base + (bwd_idx * 3 + 0) - 3

Move this entire wgrad block outside the layer loop to capture once per chunk, OR use break instead of continue.

Contributor Author

There is only 1 layer in each model chunk when delay wgrad is enabled, so using continue and break is exactly the same.

per_callable_bwd_idx = (_prefix_num_layers[m_chunk] * num_microbatches) + (
    bwd_idx[m_chunk] * _num_layers_per_chunk[m_chunk] + l_no
)
if ceil(c_id) == c_id and need_bwd_dw_graph[per_callable_bwd_idx]:
Contributor

logic: Validation only runs if need_bwd_dw_graph[per_callable_bwd_idx] is true, which depends on the current layer (l_no). If the first layer in reverse order (last layer of the chunk) doesn't need wgrad but other layers do, validation is skipped entirely.

Check if ANY layer in the chunk needs wgrad before running validation, not just the current layer.

Contributor Author

  1. For the delayed wgrad case, the layer number is guaranteed to be all 1s, so only the first layer needs to be checked.
  2. If the current layer doesn't have wgrad, the check should be skipped. It checks for c_id - 0.5 only when there are delayed wgrad TE modules.

@Wohox
Contributor Author

Wohox commented Nov 17, 2025

@ksivaman Can you help take a look at the CI error? I reran the CI and still got no space left on disk. This MR is important for DSV3 performance benchmarking, so please help review, thanks.

Wohox and others added 2 commits November 16, 2025 23:37
Signed-off-by: Pingtian Li <pingtianl@nvidia.com>
Contributor

@greptile-apps greptile-apps bot left a comment

1 file reviewed, 3 comments

Comment on lines +562 to +583
elif ceil(c_id) != c_id:
    per_callable_bwd_idx -= _num_layers_per_chunk[m_chunk]
    assert is_training, "Only training mode supports backward_dw."
    # If no one module needs the backward_dw, the bwd_dw_graph will be empty.
    # So skip capturing it. For backward_dw, the order value is c_id - 0.5 to indicate
    # the specific order of backward_dw.
    assert ceil(c_id) - c_id == 0.5, (
        "The order diff of wgrad and dgrad must be 0.5, "
        f"get {ceil(c_id) - c_id}."
    )
    assert need_bwd_dw_graph[
        per_callable_bwd_idx
    ], "No module needs wgrad computation but get float in order"
    bwd_dw_graph = bwd_dw_graphs[per_callable_bwd_idx]
    with _graph_context_wrapper(bwd_dw_graph, pool=mempool):
        for module in visited_te_modules[per_callable_bwd_idx]:
            if (
                hasattr(module, "need_backward_dw")
                and module.need_backward_dw()
            ):
                module.backward_dw()
    continue
Contributor

logic: wgrad capture runs inside the for l_no loop (line 528) - once per layer in the chunk. For a chunk with 3 layers, this captures the same wgrad graph 3 times with different indices due to the loop iteration. Move this entire block outside the layer loop

Contributor Author

There is only 1 layer in each model chunk when delay wgrad is enabled, so using continue and break is exactly the same.

        f"of dgrad graphs ({len(same_bwd_c_id_list)}) for chunk {c_id}."
    )
elif ceil(c_id) != c_id:
    per_callable_bwd_idx -= _num_layers_per_chunk[m_chunk]
Contributor

logic: index calculation is incorrect: bwd_idx[m_chunk] was already incremented by the previous dgrad entry, so this subtracts layers from the wrong microbatch index. When delay_wgrad_compute=True, _num_layers_per_chunk[m_chunk] is always 1, making this per_callable_bwd_idx -= 1, but the real issue is using the wrong base bwd_idx

Contributor Author

bwd_idx is incremented by 1, so subtracting 1 before multiplying by _num_layers_per_chunk[m_chunk] is mathematically the same as subtracting _num_layers_per_chunk[m_chunk].

Comment on lines +532 to +561
if ceil(c_id) == c_id and need_bwd_dw_graph[per_callable_bwd_idx]:
    # Check if bwd graph has corresponding wgrad graph:
    # Number of dgrad backward graphs should be equal to number of
    # wgrad backward graphs.
    # Note: For MCore, the validation rule is more strict (the next backward
    # of dgrad graph must be corresponding wgrad graph).
    if wgrad_validation_list[i] is None:
        same_bwd_c_id_list = [i]
        num_wgrad_c_id = 0
        for idx in range(i + 1, len(_order)):
            if _order[idx] > 0:
                continue
            if _order[idx] == c_id:
                same_bwd_c_id_list.append(idx)
            if _order[idx] + 0.5 == c_id:
                num_wgrad_c_id += 1
            if len(same_bwd_c_id_list) == num_wgrad_c_id:
                for same_c_id_idx in same_bwd_c_id_list:
                    wgrad_validation_list[same_c_id_idx] = True
                break
            if len(same_bwd_c_id_list) < num_wgrad_c_id:
                # It's impossible to have more wgrad than dgrad.
                wgrad_validation_list[i] = False
                break
        if wgrad_validation_list[i] is None:
            wgrad_validation_list[i] = False
    assert wgrad_validation_list[i], (
        f"Number of wgrad graph({num_wgrad_c_id}) doesn't match number "
        f"of dgrad graphs ({len(same_bwd_c_id_list)}) for chunk {c_id}."
    )
Contributor

logic: validation only runs if need_bwd_dw_graph[per_callable_bwd_idx] is true for the current l_no layer. Since layers are processed in reverse order, if the last layer of a chunk doesn't need wgrad but earlier layers do, validation is skipped entirely. Check if ANY layer in the chunk needs wgrad before running validation

Contributor Author

For the delayed wgrad case, the layer number is guaranteed to be all 1s, so only the first layer needs to be checked.

Contributor

@greptile-apps greptile-apps bot left a comment

1 file reviewed, no comments

Member

@ksivaman ksivaman left a comment

CI pending

@ksivaman
Member

/te-ci pytorch

Contributor

@greptile-apps greptile-apps bot left a comment

1 file reviewed, no comments

@ksivaman ksivaman merged commit b3c2505 into NVIDIA:main Nov 25, 2025
11 of 13 checks passed
Wohox added a commit to Wohox/TransformerEngine that referenced this pull request Jan 7, 2026
* fix backward_dw cuda graph order

Signed-off-by: Pingtian Li <pingtianl@nvidia.com>

* add validation for num_layers_per_chunk

Signed-off-by: Pingtian Li <pingtianl@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Pingtian Li <pingtianl@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Wohox added a commit to Wohox/TransformerEngine that referenced this pull request Jan 9, 2026
Wohox added a commit to Wohox/TransformerEngine that referenced this pull request Jan 9, 2026