[Pytorch] Fix backward_dw cuda graph order #2376
Conversation
@timmoon10 Could you help review this or assign to someone? Thanks! Cc @buptzyb @lhb8125

Greptile Summary: Fixes the CUDA graph capture order when `delay_wgrad_compute` is enabled.
Confidence Score: 4/5
Sequence Diagram

```mermaid
sequenceDiagram
    participant Order as Order List
    participant Parser as Order Parser
    participant Validator as Wgrad Validator
    participant Capture as Graph Capture
    Note over Order: Example: [1, -1, -1.5]
    Order->>Parser: Process order entries
    alt Forward pass (c_id > 0)
        Parser->>Capture: Capture forward graph<br/>for chunk c_id
        Note right of Capture: Standard forward capture
    else Backward dgrad (c_id < 0, ceil(c_id)==c_id)
        Parser->>Validator: Check if wgrad validation needed
        alt First dgrad with need_bwd_dw_graph
            Validator->>Validator: Count future dgrad entries (c_id)
            Validator->>Validator: Count future wgrad entries (c_id-0.5)
            alt Counts match
                Validator->>Validator: Mark validation[i] = True
            else Counts mismatch
                Validator->>Validator: Mark validation[i] = False
                Validator->>Order: Assert fails: counts must match
            end
        end
        Parser->>Capture: Capture dgrad graph<br/>at per_callable_bwd_idx
        Parser->>Parser: Increment bwd_idx[m_chunk]
    else Backward wgrad (c_id < 0, ceil(c_id)!=c_id)
        Note right of Parser: Float value detected (e.g., -1.5)
        Parser->>Parser: m_chunk = -ceil(c_id) - 1
        Parser->>Parser: per_callable_bwd_idx -= num_layers_per_chunk
        Parser->>Capture: Capture wgrad graph only<br/>(backward_dw)
        Note right of Capture: No dgrad, only weight gradient
        Parser->>Parser: Continue (skip bwd_idx increment)
    end
```
```python
if wgrad_validation_list[i] is None:
    same_bwd_c_id_list = [i]
    num_wgrad_c_id = 0
    for idx in range(i + 1, len(_order)):
        if _order[idx] > 0:
            continue
        if _order[idx] == c_id:
            same_bwd_c_id_list.append(idx)
        if _order[idx] + 0.5 == c_id:
            num_wgrad_c_id += 1
        if len(same_bwd_c_id_list) == num_wgrad_c_id:
            for same_c_id_idx in same_bwd_c_id_list:
                wgrad_validation_list[same_c_id_idx] = True
            break
        elif len(same_bwd_c_id_list) < num_wgrad_c_id:
            # It's impossible to have more wgrad than dgrad.
            wgrad_validation_list[i] = False
            break
assert wgrad_validation_list[i], (
    f"Number of wgrad graph({num_wgrad_c_id}) doesn't match number "
    f"of dgrad graphs ({len(same_bwd_c_id_list)}) for chunk {c_id}."
)
```
logic: The validation loop can exit without setting `wgrad_validation_list[i]`, leaving it as `None`. This happens when the loop completes without finding a match (neither `len(same_bwd_c_id_list) == num_wgrad_c_id` nor `len(same_bwd_c_id_list) < num_wgrad_c_id` is true). The assertion on line 542 will then fail with a confusing error since it checks the truthiness of `None`.

Add an `else` clause after the loop to handle this case:
```python
if wgrad_validation_list[i] is None:
    same_bwd_c_id_list = [i]
    num_wgrad_c_id = 0
    for idx in range(i + 1, len(_order)):
        if _order[idx] > 0:
            continue
        if _order[idx] == c_id:
            same_bwd_c_id_list.append(idx)
        if _order[idx] + 0.5 == c_id:
            num_wgrad_c_id += 1
        if len(same_bwd_c_id_list) == num_wgrad_c_id:
            for same_c_id_idx in same_bwd_c_id_list:
                wgrad_validation_list[same_c_id_idx] = True
            break
        elif len(same_bwd_c_id_list) < num_wgrad_c_id:
            # It's impossible to have more wgrad than dgrad.
            wgrad_validation_list[i] = False
            break
    else:
        # Loop completed without break - mismatch found
        wgrad_validation_list[i] = False
assert wgrad_validation_list[i], (
    f"Number of wgrad graph({num_wgrad_c_id}) doesn't match number "
    f"of dgrad graphs ({len(same_bwd_c_id_list)}) for chunk {c_id}."
)
```
- `None` means the current check didn't find the corresponding wgrad chunk id, and hence should raise the assertion.
- `assert None` and `assert False` both trigger the assertion.
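The point about `None` can be seen in a minimal standalone sketch of the count-based validation (a hypothetical, simplified version of the diff above, not the PR's actual code): an entry that never gets matched stays `None`, and the truthiness check fails for `None` exactly as it does for `False`.

```python
from math import ceil

def validate_wgrad_counts(order):
    """For each dgrad entry c_id (negative integer), the number of later entries
    with the same c_id must equal the number of later wgrad entries at c_id - 0.5."""
    validation = [None] * len(order)
    for i, c_id in enumerate(order):
        if c_id > 0 or ceil(c_id) != c_id:
            continue  # skip forward entries and wgrad (float) entries
        if validation[i] is None:
            same_bwd = [i]
            num_wgrad = 0
            for idx in range(i + 1, len(order)):
                if order[idx] > 0:
                    continue
                if order[idx] == c_id:
                    same_bwd.append(idx)
                if order[idx] + 0.5 == c_id:
                    num_wgrad += 1
                if len(same_bwd) == num_wgrad:
                    for j in same_bwd:
                        validation[j] = True
                    break
                if len(same_bwd) < num_wgrad:
                    # It's impossible to have more wgrad than dgrad.
                    validation[i] = False
                    break
        # Both None (no match ever found) and False fail the truthiness check.
        assert validation[i], f"wgrad/dgrad count mismatch for chunk {c_id}"
    return True
```

For example, `validate_wgrad_counts([1, -1, -1.5])` passes, while `validate_wgrad_counts([1, -1])` raises `AssertionError` because the `-1` entry never finds a matching `-1.5`.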
@lhb8125 Could you take a look?
```python
elif ceil(c_id) != c_id:
    per_callable_bwd_idx -= _num_layers_per_chunk[m_chunk]
    assert is_training, "Only training mode supports backward_dw."
    # If no module needs the backward_dw, the bwd_dw_graph will be empty,
    # so skip capturing it. For backward_dw, the order value is c_id - 0.5
    # to indicate the specific order of backward_dw.
    assert ceil(c_id) - c_id == 0.5, (
        "The order diff of wgrad and dgrad must be 0.5, "
        f"get {ceil(c_id) - c_id}."
    )
    assert need_bwd_dw_graph[
        per_callable_bwd_idx
    ], "No module needs wgrad computation but get float in order"
    bwd_dw_graph = bwd_dw_graphs[per_callable_bwd_idx]
    with _graph_context_wrapper(bwd_dw_graph, pool=mempool):
        for module in visited_te_modules[per_callable_bwd_idx]:
            if (
                hasattr(module, "need_backward_dw")
                and module.need_backward_dw()
            ):
                module.backward_dw()
    continue
```
logic: The wgrad handling is inside the `for l_no` loop (line 514), but the `continue` on line 569 only skips to the next iteration of this inner loop. This means wgrad graphs will be captured multiple times (once for each layer in `_num_layers_per_chunk[m_chunk]`), which seems incorrect since wgrad should only be captured once per chunk. Consider using `break` instead of `continue`, or restructure the code to handle wgrad entries outside the layer loop.

There is only one layer in each model chunk if delay wgrad is enabled, so using `continue` and `break` is exactly the same.
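The reply above can be illustrated with a toy loop (hypothetical names, not the PR's code): when the chunk has exactly one layer, `continue` and `break` produce identical behavior, and they only diverge for multi-layer chunks.

```python
def capture_chunk(num_layers, use_break):
    # Toy stand-in for the per-layer capture loop: records which layer
    # indices the wgrad branch would run for.
    captured = []
    for l_no in range(num_layers):
        captured.append(l_no)
        if use_break:
            break
        continue  # falls through to the next layer iteration
    return captured

# One layer per chunk (the delay_wgrad_compute case): identical behavior.
assert capture_chunk(1, use_break=True) == capture_chunk(1, use_break=False) == [0]
# Three layers per chunk: continue would capture the wgrad branch three times.
assert capture_chunk(3, use_break=True) == [0]
assert capture_chunk(3, use_break=False) == [0, 1, 2]
```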
```python
            f"of dgrad graphs ({len(same_bwd_c_id_list)}) for chunk {c_id}."
        )
    elif ceil(c_id) != c_id:
        per_callable_bwd_idx -= _num_layers_per_chunk[m_chunk]
```
logic: The index adjustment `per_callable_bwd_idx -= _num_layers_per_chunk[m_chunk]` appears incorrect. This will cause wgrad graphs to be captured from the wrong microbatch or with invalid indices. When processing a wgrad entry (e.g., -1.5), `bwd_idx[m_chunk]` has already been incremented by the previous dgrad entry. The correction should likely be at the microbatch level in the index formula, not by subtracting the number of layers. For example, use `(bwd_idx[m_chunk] - 1)` in the index calculation instead of modifying `per_callable_bwd_idx` after the fact.

`(bwd_idx[m_chunk] - 1)` is mathematically the same as subtracting `_num_layers_per_chunk[m_chunk]`.
You're absolutely right! Let me trace through the logic to show why they're mathematically equivalent in this context.

Looking at the code structure:

1. `per_callable_bwd_idx` is calculated as:

   ```python
   per_callable_bwd_idx = (_prefix_num_layers[m_chunk] * num_microbatches) + (
       bwd_idx[m_chunk] * _num_layers_per_chunk[m_chunk] + l_no
   )
   ```

2. When `delay_wgrad_compute` is enabled, each model chunk has exactly one layer (`_num_layers_per_chunk[m_chunk] == 1`).
3. `bwd_idx[m_chunk]` gets incremented after processing each dgrad entry.

So when we reach a wgrad entry (where `ceil(c_id) != c_id`), `bwd_idx[m_chunk]` has already been incremented by the previous dgrad processing.

Given that `_num_layers_per_chunk[m_chunk] == 1` in the `delay_wgrad_compute` case:

- `(bwd_idx[m_chunk] - 1) * _num_layers_per_chunk[m_chunk]` becomes `(bwd_idx[m_chunk] - 1) * 1 = bwd_idx[m_chunk] - 1`
- `bwd_idx[m_chunk] * _num_layers_per_chunk[m_chunk] - _num_layers_per_chunk[m_chunk]` becomes `bwd_idx[m_chunk] * 1 - 1 = bwd_idx[m_chunk] - 1`

You're correct that both approaches yield the same result mathematically. The current implementation works correctly for the `delay_wgrad_compute` case. My original comment incorrectly identified this as a bug when it's actually functioning as intended.

The key insight is that when `delay_wgrad_compute` is enabled, `_num_layers_per_chunk[m_chunk]` is always 1, making the subtraction equivalent to decrementing the microbatch index, which is the correct behavior for accessing the corresponding forward pass tensors for the wgrad computation.
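A quick arithmetic check (with hypothetical values, mirroring the index formula from the discussion) confirms the two forms agree for any layer count, not only 1, since `(b - 1) * n + l = b * n + l - n`:

```python
def idx_via_decrement(prefix, num_mb, bwd_idx, layers_per_chunk, l_no):
    # Compute with the already-incremented bwd_idx, then subtract one
    # chunk's worth of layers (the approach taken in the PR).
    idx = prefix * num_mb + (bwd_idx * layers_per_chunk + l_no)
    return idx - layers_per_chunk

def idx_via_prev_microbatch(prefix, num_mb, bwd_idx, layers_per_chunk, l_no):
    # Use (bwd_idx - 1) directly in the formula (the reviewer's suggestion).
    return prefix * num_mb + ((bwd_idx - 1) * layers_per_chunk + l_no)

for args in [(0, 4, 1, 1, 0), (2, 8, 3, 1, 0), (5, 2, 2, 3, 1)]:
    assert idx_via_decrement(*args) == idx_via_prev_microbatch(*args)
```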
```python
if ceil(c_id) == c_id and need_bwd_dw_graph[per_callable_bwd_idx]:
    # Check if bwd graph has corresponding wgrad graph:
    # Number of dgrad backward graphs should be equal to number of
    # wgrad backward graphs.
    # Note: For MCore, the validation rule is more strict (the next backward
    # of dgrad graph must be corresponding wgrad graph).
    if wgrad_validation_list[i] is None:
        same_bwd_c_id_list = [i]
        num_wgrad_c_id = 0
        for idx in range(i + 1, len(_order)):
            if _order[idx] > 0:
                continue
            if _order[idx] == c_id:
                same_bwd_c_id_list.append(idx)
            if _order[idx] + 0.5 == c_id:
                num_wgrad_c_id += 1
            if len(same_bwd_c_id_list) == num_wgrad_c_id:
                for same_c_id_idx in same_bwd_c_id_list:
                    wgrad_validation_list[same_c_id_idx] = True
                break
            if len(same_bwd_c_id_list) < num_wgrad_c_id:
                # It's impossible to have more wgrad than dgrad.
                wgrad_validation_list[i] = False
                break
        if wgrad_validation_list[i] is None:
            wgrad_validation_list[i] = False
    assert wgrad_validation_list[i], (
        f"Number of wgrad graph({num_wgrad_c_id}) doesn't match number "
        f"of dgrad graphs ({len(same_bwd_c_id_list)}) for chunk {c_id}."
    )
```
logic: Validation only runs if the first layer (last in reverse iteration) needs wgrad due to the `need_bwd_dw_graph[per_callable_bwd_idx]` check. If only later layers need wgrad, validation is skipped entirely. The validation logic should check if ANY layer in the chunk needs wgrad, not just `per_callable_bwd_idx`.

@buptzyb Do you think we should validate that `num_layers_per_chunk` is an all-1 array when any of the model chunks has `need_bwd_dw_graph`?

Makes sense. I think this is necessary.

For the delay-wgrad case, the layer number is guaranteed to be all 1s, therefore only the first layer needs to be checked.
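The validation agreed on above could be sketched as follows (a hypothetical helper with assumed names, not the PR's actual code): if any captured callable needs a wgrad graph, every model chunk must contain exactly one layer.

```python
def check_layers_per_chunk(num_layers_per_chunk, need_bwd_dw_graph):
    # delay_wgrad_compute implies one layer per model chunk; enforce it
    # whenever any module needs a backward_dw (wgrad) graph.
    if any(need_bwd_dw_graph):
        assert all(n == 1 for n in num_layers_per_chunk), (
            "delay_wgrad_compute requires one layer per model chunk, "
            f"got {num_layers_per_chunk}"
        )

check_layers_per_chunk([1, 1], [True, False])   # passes
check_layers_per_chunk([3, 2], [False, False])  # passes: no wgrad graphs needed
```

A mixed configuration such as `check_layers_per_chunk([2, 1], [True, True])` would raise `AssertionError`.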
@ksivaman The pytorch build failed due to no space left on disk.
/te-ci pytorch

/te-ci pytorch
1 file reviewed, 3 comments
```python
per_callable_bwd_idx = (_prefix_num_layers[m_chunk] * num_microbatches) + (
    bwd_idx[m_chunk] * _num_layers_per_chunk[m_chunk] + l_no
)
if ceil(c_id) == c_id and need_bwd_dw_graph[per_callable_bwd_idx]:
logic: Validation only runs if `need_bwd_dw_graph[per_callable_bwd_idx]` is true, which depends on the current layer (`l_no`). If the first layer in reverse order (the last layer of the chunk) doesn't need wgrad but other layers do, validation is skipped entirely. Check if ANY layer in the chunk needs wgrad before running validation, not just the current layer.

- For the delay-wgrad case, the layer number is guaranteed to be all 1s, therefore only the first layer needs to be checked.
- If the current layer doesn't have wgrad, then the check should be skipped. It checks for `c_id - 0.5` only when there are delay-wgrad TE modules.
@ksivaman Can you help take a look at the CI error? I reran the CI and still got no space left on disk. This MR is important for DSV3 performance benchmarking; please help review, thanks.
1 file reviewed, 3 comments
1 file reviewed, no comments

/te-ci pytorch
* fix backward_dw cuda graph order
* add validation for num_layers_per_chunk
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)

Signed-off-by: Pingtian Li <pingtianl@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Description
Based on #1948
Fixes the CUDA graph capture order of `backward_dw` graphs when `delay_wgrad_compute` is enabled: the user may delay the wgrad compute to the end of the overlapped forward layers, so the capture order should be moved accordingly.

Overlap Pattern:
Design
Allow float values in `_order`: for model chunk `i`, its corresponding wgrad order is `-i-0.5`.

Example
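As an illustration (this schedule is made up for the example, not taken from the PR), a two-chunk order with delayed wgrad could be interpreted like this, where positive entries are forward passes, negative integers are dgrad passes, and negative floats are the corresponding `backward_dw` (wgrad) passes:

```python
from math import ceil

# Hypothetical capture order: chunk 1 and 2 forward, then dgrad/wgrad
# pairs in reverse chunk order. Wgrad for chunk i is encoded as -i - 0.5.
order = [1, 2, -2, -2.5, -1, -1.5]

events = []
for c_id in order:
    if c_id > 0:
        events.append(("forward", c_id))
    elif ceil(c_id) == c_id:
        events.append(("dgrad", -c_id))
    else:
        # Float entry: wgrad of chunk ceil(c_id), i.e. -ceil(c_id) here.
        events.append(("wgrad", -ceil(c_id)))

print(events)
```

Each dgrad entry is immediately followed by the matching wgrad entry for the same chunk, which is the pairing the validation logic in this PR enforces (by counts, not strict adjacency).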
Type of change
Changes
Please list the changes introduced in this PR:
Allow float values in the `_order` input, representing the `backward_dw` of the `ceil(order)` chunk.

Checklist: