
feat: add FlashOptim optimizer integration #1492

Merged
hemildesai merged 5 commits into main from feature/flashoptim-integration
Mar 10, 2026

Conversation

@hemildesai
Contributor

@hemildesai hemildesai commented Mar 9, 2026

Summary

Closes #1447

Convergence run - https://wandb.ai/Nemo-automodel/automodel-flashoptim

Need to pin to >=0.1.3 due to databricks/flashoptim#3

  • Add FlashOptim (FlashAdamW) as a supported optimizer, reducing optimizer memory footprint via compressed master weights (8-bit/16-bit correction terms)
  • Add example config for Qwen3 MoE 30B fine-tuning with FlashAdamW (qwen3_moe_30b_te_packed_sequence_flashoptim.yaml)
  • Add L2 CI test (L2_FlashOptim_DCP_Roundtrip) that verifies DCP checkpoint save/load fidelity for FlashAdamW's compressed optimizer states
  • Add flashoptim>=0.1.3 as a dependency in pyproject.toml

Details

FlashOptim integration:
FlashOptim >= 0.1.3 provides native DTensor support, enabling seamless integration with PyTorch DCP (Distributed Checkpoint) and FSDP2. The FlashAdamW optimizer uses quantized correction terms to reduce optimizer memory while maintaining training quality.
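The PR does not show FlashOptim's compression scheme, only that it stores 8-bit/16-bit correction terms. As a rough illustration of why quantized optimizer state saves memory while staying close to the full-precision values, here is a minimal, hypothetical symmetric int8 round-trip sketch (not FlashOptim's actual algorithm):

```python
def quantize_int8(values):
    """Symmetric per-block int8 quantization: store one scale + int8 codes."""
    scale = max(abs(v) for v in values) / 127 or 1.0  # avoid zero scale
    codes = [round(v / scale) for v in values]        # each fits in int8
    return scale, codes

def dequantize_int8(scale, codes):
    """Recover approximate float values from the compact representation."""
    return [c * scale for c in codes]

values = [0.013, -0.248, 0.091, 0.5, -0.002]
scale, codes = quantize_int8(values)
restored = dequantize_int8(scale, codes)

# Reconstruction error is bounded by half a quantization step.
assert all(abs(v - r) <= scale / 2 + 1e-12 for v, r in zip(values, restored))
```

Storing one float scale plus one byte per element in place of 32-bit floats is the kind of trade-off that makes the checkpoint-fidelity test below worth having: the states are lossy, so save/load must be validated against a loss threshold rather than bitwise equality.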

DCP roundtrip test:
The L2 test trains a model for N steps with FlashAdamW, saves model + optimizer state via DCP, loads the checkpoint into a fresh model/optimizer, and compares continued-training losses between the original and resumed runs. This validates that the quantized optimizer states survive the DCP save/load roundtrip within a configurable loss delta threshold.
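The roundtrip logic above can be sketched without any distributed machinery. This is a toy stand-in (the real test uses FlashAdamW, FSDP2, and DCP): a deterministic "training" loop whose losses depend only on mutable state, checkpointed mid-run and resumed.

```python
import copy

def train_steps(state, n):
    """Toy deterministic 'training' loop; losses depend only on state."""
    losses = []
    for _ in range(n):
        state["step"] += 1
        state["ema"] = 0.9 * state["ema"] + 0.1 / state["step"]
        losses.append(state["ema"])
    return losses

# Original run: warmup, checkpoint, then continue training.
state = {"step": 0, "ema": 1.0}
train_steps(state, 5)
checkpoint = copy.deepcopy(state)   # stands in for the DCP save
reference = train_steps(state, 5)   # continued-training losses

# Resumed run: load the checkpoint into fresh state and train again.
resumed = train_steps(copy.deepcopy(checkpoint), 5)

# With exact state restoration the losses match step for step; the real
# test allows a configurable delta because quantized optimizer states
# may round-trip inexactly.
assert reference == resumed
```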

The test uses synthetic random data (no dataset download dependency) and runs as part of the existing L2_HF_DCP CI job.

Test plan

  • L2_HF_DCP CI job passes (includes test_flashoptim_dcp_roundtrip)
  • Manual: torchrun --nproc-per-node=2 tests/functional_tests/checkpoint/test_flashoptim_dcp_roundtrip.py --model <model_path>

🤖 Generated with Claude Code

@copy-pr-bot

copy-pr-bot bot commented Mar 9, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.


@hemildesai hemildesai force-pushed the feature/flashoptim-integration branch from 380e30a to 8433588 Compare March 9, 2026 02:31
@hemildesai
Contributor Author

/ok to test 8433588

@hemildesai hemildesai force-pushed the feature/flashoptim-integration branch from fcd88cf to b08abc0 Compare March 9, 2026 06:22
@hemildesai
Contributor Author

/ok to test b08abc0

- Rewrite L2 test to run two separate torchrun invocations (train +
  resume) so dataloader state restores correctly in a fresh process
- Validator script compares training.jsonl logs: checks num_label_tokens
  match and losses are within threshold after checkpoint resume
- Add llama3_2_1b_squad_flashoptim.yaml config for CI (bf16 model dtype)
- Fix Qwen3 MoE flashoptim config: torch_fp32 rms_norm, rope_fusion
  disabled, remove unsupported foreach param

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: hemildesai <hemild@nvidia.com>
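The validator described in the commit message above could look something like the following sketch. The field names ("loss", "num_label_tokens") and the overlap logic are assumptions for illustration, not the repo's actual training.jsonl schema:

```python
import json

def validate_logs(reference_path, resumed_path, loss_threshold=0.05):
    """Compare a resumed run's training.jsonl against the original run's.

    Checks that label-token counts match (same data order after resume)
    and that per-step losses agree within a threshold.
    """
    with open(reference_path) as f:
        reference = [json.loads(line) for line in f if line.strip()]
    with open(resumed_path) as f:
        resumed = [json.loads(line) for line in f if line.strip()]
    # Compare only the overlapping post-resume steps.
    for ref, res in zip(reference[-len(resumed):], resumed):
        if ref["num_label_tokens"] != res["num_label_tokens"]:
            raise ValueError("num_label_tokens mismatch: data order changed")
        if abs(ref["loss"] - res["loss"]) > loss_threshold:
            raise ValueError(f"loss delta exceeds {loss_threshold}")
    return True
```

Running the train and resume phases as two separate torchrun invocations, as the commit does, means the resumed process exercises the cold-start path: fresh dataloader, fresh optimizer, state restored entirely from the checkpoint.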
@hemildesai hemildesai force-pushed the feature/flashoptim-integration branch from b08abc0 to 507f7a3 Compare March 9, 2026 19:04
@hemildesai hemildesai changed the title [DO NOT MERGE] feat: add FlashOptim optimizer integration feat: add FlashOptim optimizer integration Mar 9, 2026
@hemildesai
Contributor Author

/ok to test 507f7a3

Contributor

@akoumpa akoumpa left a comment


LGTM, thank you @hemildesai

@hemildesai hemildesai enabled auto-merge (squash) March 10, 2026 02:30
@hemildesai hemildesai merged commit 03b8f91 into main Mar 10, 2026
55 of 57 checks passed
@hemildesai hemildesai deleted the feature/flashoptim-integration branch March 10, 2026 13:36
