[PyTorch] Support delay_wgrad_compute cudagraph #1948
Merged
ksivaman merged 6 commits into NVIDIA:main on Oct 24, 2025
Signed-off-by: Robin Zhang <robinz@nvidia.com>
Contributor
Author
Hi @timmoon10, could you help review this? Thanks! Cc @Wohox @lhb8125
Contributor
LGTM, thanks!
Contributor
/te-ci pytorch
Contributor
Author
Hi @timmoon10, could you take a look or assign this to someone? Thanks!
Member
/te-ci pytorch
Member
/te-ci pytorch
Contributor
/te-ci pytorch
Contributor
@ksivaman The CI looks good now, could you review it?
ksivaman approved these changes on Oct 24, 2025
Wohox pushed a commit to Wohox/TransformerEngine that referenced this pull request on Jan 9, 2026
* support cudagraph dw
* fix lint
* fix ci
Signed-off-by: Robin Zhang <robinz@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Description
Some TE modules support delayed wgrad (weight gradient) computation. When it is enabled, they do not compute wgrad during the normal forward-backward pass; instead, wgrad is calculated when the user calls their backward_dw() method. This PR supports that pattern in the make_graphed_callables() API: besides the forward and backward graphs, it captures a new backward_dw graph. The new graph is set as an attribute on the returned graphed callable, so the user can retrieve and execute it when needed.
The backward_dw graph is not captured if no TE module has need_backward_dw() returning True.
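To make the delayed wgrad pattern from the Description concrete, here is a minimal plain-Python sketch. It is not Transformer Engine code and performs no CUDA graph capture: `ToyLinear`, `make_toy_callable`, and the small matrix helpers are hypothetical stand-ins, and exposing the delayed step as a `backward_dw` attribute on the returned callable is an assumption made for illustration. Only the method names `backward_dw()` and `need_backward_dw()` are taken from the PR text.

```python
# Conceptual sketch only: plain-Python stand-ins, not Transformer Engine code.

def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def transpose(m):
    """Transpose a matrix given as a list of rows."""
    return [list(col) for col in zip(*m)]


class ToyLinear:
    """Hypothetical layer computing y = x @ W with optional delayed wgrad."""

    def __init__(self, weight, delay_wgrad_compute=False):
        self.weight = weight
        self.delay_wgrad_compute = delay_wgrad_compute
        self.wgrad = None
        self._x = None
        self._pending = None  # (input, grad_output) kept for a later backward_dw()

    def need_backward_dw(self):
        return self.delay_wgrad_compute

    def forward(self, x):
        self._x = x
        return matmul(x, self.weight)

    def backward(self, grad_out):
        # dgrad is always computed here because the previous layer needs it.
        dgrad = matmul(grad_out, transpose(self.weight))
        if self.delay_wgrad_compute:
            self._pending = (self._x, grad_out)  # defer wgrad
        else:
            self.wgrad = matmul(transpose(self._x), grad_out)
        return dgrad

    def backward_dw(self):
        # Compute the deferred weight gradient, if one is pending.
        if self._pending is None:
            return
        x, grad_out = self._pending
        self.wgrad = matmul(transpose(x), grad_out)
        self._pending = None


def make_toy_callable(modules):
    """Hypothetical analogue of a graphed callable (no real graph capture)."""

    def run(x):
        for m in modules:
            x = m.forward(x)
        return x

    def run_backward(grad):
        for m in reversed(modules):
            grad = m.backward(grad)
        return grad

    run.backward = run_backward
    delayed = [m for m in modules if m.need_backward_dw()]
    if delayed:  # the extra step exists only if some module needs it
        def run_backward_dw():
            for m in delayed:
                m.backward_dw()
        run.backward_dw = run_backward_dw
    return run


layer = ToyLinear([[1.0, 2.0], [3.0, 4.0]], delay_wgrad_compute=True)
fn = make_toy_callable([layer])
y = fn([[1.0, 1.0]])
dx = fn.backward([[1.0, 0.0]])
wgrad_before = layer.wgrad  # still None: wgrad has been deferred
fn.backward_dw()            # run the delayed wgrad step on demand
wgrad_after = layer.wgrad
```

In this toy, deferring wgrad lets the caller decide when the weight-gradient work runs, which mirrors the scheduling freedom the separate backward_dw graph is meant to provide.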
Changes
Please list the changes introduced in this PR:
* TransformerEngineBaseModule.need_backward_dw() method.
* bwd_dw_graphs for the delayed wgrad computation cudagraph.
Checklist: