feat(optimizer): add FlashAdamW optimizer integration #4229
meinie0826 wants to merge 1 commit into NVIDIA:main
Conversation
This PR has been automatically converted to draft because all PRs must start as drafts. When you are ready for review, click Ready for Review to begin the review process.
See the contribution guide for more details.
Pull request overview
Integrates the flashoptim library’s FlashAdamW optimizer into Megatron-Core by adding configuration/CLI support, wiring optimizer creation into the optimizer factory, and introducing unit tests for FlashAdamW behavior.
Changes:
- Added FlashAdamW argument parsing and validation, and included `flashadamw` in the `--optimizer` choices.
- Extended `OptimizerConfig` with FlashAdamW-specific knobs (master weight bits, quantization, checkpoint compression).
- Implemented FlashAdamW instantiation in the optimizer factory and added a dedicated unit test module.
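The new CLI surface can be illustrated with a small argparse mirror. This is a hypothetical reproduction: the flag spellings below are inferred from the `OptimizerConfig` field names added by this PR (`flashadamw_master_weight_bits`, etc.), not confirmed against `megatron/training/arguments.py`, and the `choices` list is abbreviated.

```python
import argparse

# Hypothetical mirror of the flags this PR adds; names are assumptions derived
# from the OptimizerConfig fields shown in the diff below.
parser = argparse.ArgumentParser()
parser.add_argument('--optimizer', default='adam',
                    choices=['adam', 'sgd', 'flashadamw'])  # choices abbreviated
parser.add_argument('--flashadamw-master-weight-bits', type=int, default=16)
parser.add_argument('--flashadamw-quantize', action='store_true')
parser.add_argument('--flashadamw-compress-state-dict', action='store_true')

args = parser.parse_args(
    ['--optimizer', 'flashadamw',
     '--flashadamw-master-weight-bits', '8',
     '--flashadamw-quantize']
)
```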
Reviewed changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| `tests/unit_tests/optimizer/test_flashadamw.py` | Adds FlashAdamW-focused unit tests (config defaults, basic CUDA runs, checkpoint roundtrips, and memory/state size checks). |
| `megatron/training/arguments.py` | Adds FlashAdamW CLI flags and validation; registers `flashadamw` as a supported optimizer choice. |
| `megatron/core/optimizer/optimizer_config.py` | Introduces FlashAdamW fields in `OptimizerConfig`. |
| `megatron/core/optimizer/__init__.py` | Wires FlashAdamW into optimizer creation and routes it through the "standard optimizer" path. |
```python
elif config.optimizer == 'flashadamw':
    try:
        from flashoptim import FlashAdamW
    except ImportError:
        raise ImportError(
            "FlashAdamW optimizer requires flashoptim >= 0.1.3. "
            "Install it with: pip install 'flashoptim>=0.1.3'"
        )
    optimizer = FlashAdamW(
        param_groups,
        lr=config.lr,
        betas=(config.adam_beta1, config.adam_beta2),
        eps=config.adam_eps,
        weight_decay=config.weight_decay,
        master_weight_bits=config.flashadamw_master_weight_bits,
        quantize=config.flashadamw_quantize,
        compress_state_dict=config.flashadamw_compress_state_dict,
    )
```
FlashAdamW is created with BF16 model parameters here, but the standard Megatron wrapping that happens later for bf16/fp16 will wrap this optimizer with Float16OptimizerWithFloat16Params, which replaces param_group['params'] with FP32 clones (creating external master weights). That defeats FlashAdamW’s internal master-weight splitting, likely breaks FlashAdamW dtype assumptions, and contradicts the validation message that FlashAdamW manages its own master weights. Consider special-casing flashadamw so it is wrapped with a MegatronOptimizer that does not create FP32 parameter copies (e.g., FP32Optimizer or a dedicated wrapper), and ensure incompatible flags like use_distributed_optimizer/fp16 scaling are handled explicitly in-core (not only via arguments.py).
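The reviewer's suggested routing could be sketched as follows. `Float16OptimizerWithFloat16Params` and `FP32Optimizer` are the existing Megatron-Core wrapper names, but this helper function is illustrative of the suggestion, not code from the PR:

```python
# Hypothetical sketch of the reviewer's suggestion: route flashadamw to a wrapper
# that does not create FP32 parameter copies. The function and its string returns
# are illustrative; only the wrapper class names come from Megatron-Core.
def select_wrapper(optimizer_name: str, bf16: bool, fp16: bool) -> str:
    """Pick a MegatronOptimizer wrapper for the inner optimizer."""
    if optimizer_name == 'flashadamw':
        # FlashAdamW manages its own (split) master weights, so it must not be
        # wrapped by the mixed-precision wrapper that clones params to FP32.
        return 'FP32Optimizer'  # or a dedicated no-master-copy wrapper
    if bf16 or fp16:
        return 'Float16OptimizerWithFloat16Params'
    return 'FP32Optimizer'
```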
```python
import copy
import tempfile
from pathlib import Path
from unittest.mock import MagicMock, patch
```
Unused imports: MagicMock and patch are imported but never referenced in this test module. Please remove them to keep the test file clean and avoid lint failures in environments that enforce unused-import checks.
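The complaint is easy to confirm mechanically. The snippet below is illustrative (CI would normally rely on a linter rule such as ruff's F401) and uses a reduced stand-in for the test module's source:

```python
import ast

# Detect top-level imports never referenced by name. Handles the simple case
# flagged in this review; a real linter (e.g. ruff F401) covers more cases.
src = (
    "import copy\n"
    "from unittest.mock import MagicMock, patch\n"
    "x = copy.deepcopy([1, 2])\n"
)
tree = ast.parse(src)
imported, used = set(), set()
for node in ast.walk(tree):
    if isinstance(node, ast.ImportFrom):
        imported.update(a.asname or a.name for a in node.names)
    elif isinstance(node, ast.Import):
        imported.update((a.asname or a.name).split('.')[0] for a in node.names)
    elif isinstance(node, ast.Name):
        used.add(node.id)
unused = sorted(imported - used)
print(unused)  # ['MagicMock', 'patch']
```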
What does this PR do?
Integrates FlashAdamW from the `flashoptim` library (>= 0.1.3) into Megatron-Core's optimizer infrastructure.
FlashAdamW reduces optimizer memory from 16 bytes/param (BF16 model + fp32 master weight + fp32 `exp_avg` + fp32 `exp_avg_sq`) to ~7 bytes/param via:
- `exp_avg`
- sqrt-transformed INT8 `exp_avg_sq`

Closes #4171
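A back-of-envelope check on the baseline figure (a sketch: the four components listed above sum to 14 B/param, so the quoted 16 B/param only works out if a BF16 gradient is also counted, which is an assumption here, not something the summary states):

```python
# Baseline AdamW footprint per parameter, in bytes. The bf16_grad entry is an
# assumption needed to reach the PR's quoted 16 B/param; the other four rows
# are the components the PR summary lists explicitly.
baseline = {
    'bf16_param': 2,
    'bf16_grad': 2,        # assumption: implied but not listed in the summary
    'fp32_master': 4,
    'fp32_exp_avg': 4,
    'fp32_exp_avg_sq': 4,
}
total = sum(baseline.values())
print(total)  # 16; FlashAdamW targets ~7 B/param via its quantized states
```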
Contribution process
Pre-checks
Code review
Feel free to message or comment @mcore-oncall to help accelerate your merge into main. The less complex your PR is, the faster it will be approved and merged!
All PRs start as draft. If you open a non-draft PR, it will be automatically converted to draft.
Step 1: Mark PR as "Ready for Review"
`.github/CODEOWNERS`. Final Review might get declined if these requirements are not fulfilled.
Step 2: Final Review
For PRs that change `megatron/core`, once all expert reviewers have approved, the Final Review label is applied automatically and final reviewers are assigned. For PRs outside `megatron/core`, this step is skipped.
Step 3: Approved
Once all required reviewers have approved, the Approved label is applied automatically.
Merge
Any member of mcore-engineers will be able to merge your PR.
For MRs into `dev` branch
The proposed review process for the `dev` branch is under active discussion. MRs are mergeable after one approval by either eharper@nvidia.com or zijiey@nvidia.com.