[PyTorch] Support delay_wgrad_compute cudagraph#1948

Merged
ksivaman merged 6 commits into NVIDIA:main from buptzyb:robinz/cudagraph_dw on Oct 24, 2025

@buptzyb
Contributor

@buptzyb buptzyb commented Jul 14, 2025

Description

Some TE modules support delayed weight-gradient (wgrad) computation. When it is enabled, they skip wgrad in the normal forward-backward pass and compute it only when the user calls their backward_dw() method. This PR supports that pattern in the make_graphed_callables() API: in addition to the forward and backward graphs, it captures a new backward_dw graph. The new graph is set as an attribute on the returned graphed callable, so the user can retrieve and execute it when needed.

The backward_dw graph is captured only if at least one TE module's need_backward_dw() returns True.
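
To make the delayed wgrad pattern concrete, here is a minimal pure-Python sketch of a scalar "linear layer" that defers its weight gradient. It is not Transformer Engine code: the class name `ToyDelayedWgradLinear` and its internals are hypothetical, and only the method names `backward_dw()` and `need_backward_dw()` and the `delay_wgrad_compute` flag mirror the names used in this PR.

```python
class ToyDelayedWgradLinear:
    """Scalar layer y = w * x that can defer its weight gradient.

    Hypothetical illustration only; not the Transformer Engine implementation.
    """

    def __init__(self, weight, delay_wgrad_compute=False):
        self.weight = weight
        self.weight_grad = 0.0
        self.delay_wgrad_compute = delay_wgrad_compute
        self._saved_input = None
        self._pending = []  # (input, grad_output) pairs waiting for wgrad

    def need_backward_dw(self):
        # True when wgrad is deferred and backward_dw() must be called later.
        return self.delay_wgrad_compute

    def forward(self, x):
        self._saved_input = x
        return self.weight * x

    def backward(self, grad_output):
        # The input gradient (dgrad) is always computed right away.
        grad_input = self.weight * grad_output
        if self.delay_wgrad_compute:
            # Stash what wgrad needs instead of computing it now.
            self._pending.append((self._saved_input, grad_output))
        else:
            self.weight_grad += self._saved_input * grad_output
        return grad_input

    def backward_dw(self):
        # Compute every deferred weight gradient.
        while self._pending:
            x, grad_output = self._pending.pop()
            self.weight_grad += x * grad_output
```

With `delay_wgrad_compute=True`, `backward()` returns the input gradient but leaves `weight_grad` untouched until `backward_dw()` runs, which lets the caller schedule the weight-gradient work separately from the rest of the backward pass.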

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Added a TransformerEngineBaseModule.need_backward_dw() method.
  • Added bwd_dw_graphs for the delayed wgrad computation cudagraph.
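
As a rough sketch of how the two changes above might fit together, the toy code below attaches an extra callable to the returned object only when some module reports `need_backward_dw()`. Everything here is a stand-in: `ToyModule` and `toy_make_graphed_callable` are hypothetical, no CUDA graph is captured, and the attribute name `backward_dw` on the returned callable is an assumption, since the description only says the graph is set as an attribute.

```python
class ToyModule:
    """Hypothetical module that only tracks whether wgrad is deferred."""

    def __init__(self, delay_wgrad_compute):
        self.delay_wgrad_compute = delay_wgrad_compute
        self.dw_calls = 0

    def need_backward_dw(self):
        return self.delay_wgrad_compute

    def backward_dw(self):
        self.dw_calls += 1  # placeholder for the real wgrad work


def toy_make_graphed_callable(modules):
    """Stand-in for graph capture; returns a plain Python callable."""

    def graphed(x):
        return x  # placeholder for replaying the forward graph

    # Only modules with deferred wgrad take part in the extra graph.
    deferred = [m for m in modules if m.need_backward_dw()]
    if deferred:
        def replay_backward_dw():
            for m in deferred:
                m.backward_dw()

        # Assumed attribute name; the real one may differ.
        graphed.backward_dw = replay_backward_dw
    return graphed
```

A caller would run the usual forward and backward, then invoke the attached callable (if present) at the point where the deferred weight gradients are needed.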

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@buptzyb buptzyb force-pushed the robinz/cudagraph_dw branch from 0d8ddcf to 91c3c49 Compare July 14, 2025 08:35
@buptzyb buptzyb force-pushed the robinz/cudagraph_dw branch from 75b6a88 to a2a3c76 Compare July 28, 2025 08:55
@buptzyb buptzyb force-pushed the robinz/cudagraph_dw branch from a2a3c76 to cc0cad0 Compare August 5, 2025 01:05
@buptzyb buptzyb force-pushed the robinz/cudagraph_dw branch 2 times, most recently from 0be616d to 601026a Compare September 2, 2025 02:04
@buptzyb buptzyb force-pushed the robinz/cudagraph_dw branch from 601026a to 72d7389 Compare September 8, 2025 10:42
@buptzyb buptzyb force-pushed the robinz/cudagraph_dw branch from 72d7389 to 045870e Compare September 28, 2025 11:07
@nvMelissa nvMelissa added the community-contribution PRs from external contributor outside the core maintainers, representing community-driven work. label Oct 9, 2025
Signed-off-by: Robin Zhang <robinz@nvidia.com>
@buptzyb buptzyb force-pushed the robinz/cudagraph_dw branch from 045870e to 3bde7ef Compare October 13, 2025 06:39
@buptzyb
Contributor Author

buptzyb commented Oct 15, 2025

Hi @timmoon10, could you help review this? Thanks! Cc @Wohox @lhb8125

@lhb8125
Contributor

lhb8125 commented Oct 21, 2025

LGTM, thanks!

@lhb8125
Contributor

lhb8125 commented Oct 21, 2025

/te-ci pytorch

@buptzyb
Contributor Author

buptzyb commented Oct 21, 2025

Hi @timmoon10, could you take a look or assign this to someone? Thanks!

@ksivaman ksivaman self-requested a review October 21, 2025 13:08
@ksivaman
Member

/te-ci pytorch

buptzyb and others added 2 commits October 22, 2025 06:06
Signed-off-by: Robin Zhang <robinz@nvidia.com>
@ksivaman
Member

/te-ci pytorch

Signed-off-by: Robin Zhang <robinz@nvidia.com>
@lhb8125
Contributor

lhb8125 commented Oct 24, 2025

/te-ci pytorch

@lhb8125
Contributor

lhb8125 commented Oct 24, 2025

@ksivaman The CI looks good now, could you review it?

Wohox pushed a commit to Wohox/TransformerEngine that referenced this pull request Jan 9, 2026
* support cudagraph dw

Signed-off-by: Robin Zhang <robinz@nvidia.com>

* fix lint

Signed-off-by: Robin Zhang <robinz@nvidia.com>

* fix ci

Signed-off-by: Robin Zhang <robinz@nvidia.com>

---------

Signed-off-by: Robin Zhang <robinz@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>