
Refactor Wan Model Training & Add Wan-VACE Training Support#352

Open
ninatu wants to merge 1 commit into main from ninatu/wan_training

Conversation

@ninatu
Collaborator

@ninatu ninatu commented Mar 11, 2026

This PR introduces several improvements and fixes to Wan model training, and adds support for training Wan-VACE models.

Key changes include:

  1. Bug fixes:

    • Resolved a training-mode bug when dropout > 0 (e.g., ensured the rngs parameter is passed to layer_forward for gradient checkpointing with dropout).
    • Fixed prepare_sample_fn usage for the 'tfrecord' dataset type.
    • Addressed checkpoint-loading issues with larger TPU slices and different topologies for Wan 2.1.
    • Corrected timestep sampling for the continuous sampling mode.
  2. Config updates:

    • Ensured adam_weight_decay is a float.
    • Added tensorboard_dir parameter for logging.
    • Now uses config.learning_rate instead of a hardcoded value.
    • Set default dropout to 0.0 in WAN configs (instead of 0.1).
  3. Wan-VACE Support:

    • Refactoring: Common training components (initialization, scheduler, TFLOPs calculation, training/eval loops) have been abstracted into a new BaseWanTrainer ABC to improve code structure and reusability.
    • Added new scripts (train_wan_vace.py), trainer (wan_vace_trainer.py), and checkpointing logic (wan_vace_checkpointing_2_1.py) to enable training of WAN-VACE models.
  4. New Features:

    • Introduced config.disable_training_weights to optionally disable mid-point loss weighting.
    • Added logging for max_grad_norm and max_abs_grad.
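The dropout fix in item 1 can be sketched as follows. This is a hypothetical minimal example, not the PR's actual layer_forward: the point is that an explicit RNG must be threaded through the checkpointed (remat'd) function so the dropout mask can be recomputed consistently on the backward pass.

```python
import jax
import jax.numpy as jnp

# Hypothetical layer: matmul followed by dropout. The real layer_forward in
# this PR is different; this only illustrates passing rngs through remat.
def layer_forward(params, x, rng, deterministic):
    y = x @ params["w"]
    if not deterministic:
        keep = jax.random.bernoulli(rng, 0.9, y.shape)  # dropout rate 0.1
        y = jnp.where(keep, y / 0.9, 0.0)
    return y

# Threading rng through the checkpointed function lets the backward pass
# recompute the same dropout mask that was used on the forward pass.
layer_forward_ckpt = jax.checkpoint(layer_forward, static_argnums=(3,))

params = {"w": jnp.ones((4, 4))}
x = jnp.ones((2, 4))
rng = jax.random.PRNGKey(0)

loss, grads = jax.value_and_grad(
    lambda p: layer_forward_ckpt(p, x, rng, False).sum()
)(params)
```

If the RNG is captured implicitly instead of passed as an argument, gradient checkpointing with dropout > 0 can silently produce inconsistent masks between forward and recomputation, which matches the failure mode this PR fixes.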

@ninatu ninatu requested a review from entrpn as a code owner March 11, 2026 14:35
@entrpn
Collaborator

entrpn commented Mar 12, 2026

As this is a fairly large refactor:

  • @prishajain1 can you do a review of the checkpointing changes?

  • @susanbao can you take a quick look at the training changes?

```python
)

max_logging.log("Restoring WAN checkpoint")
restored_checkpoint = self.checkpoint_manager.restore(
```
Collaborator

Is this replicating the sharding across devices? If so, would this be able to load on a Trillium TPU with 32 GB of HBM?

Collaborator Author

The mesh is created using CPU devices, not TPU devices, so the model is loaded into host RAM rather than HBM.
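The pattern being described can be sketched like this (names are illustrative, not the PR's actual code): build the restore mesh over CPU devices so the checkpoint is materialized in host RAM instead of TPU HBM.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Build a mesh from CPU devices; arrays placed through this mesh live in
# host RAM, so restoring a large checkpoint does not touch accelerator HBM.
cpu_devices = np.array(jax.devices("cpu"))
cpu_mesh = Mesh(cpu_devices, axis_names=("restore",))

# Fully replicated spec: every device in the CPU mesh holds the whole array.
replicated = NamedSharding(cpu_mesh, PartitionSpec())
weights = jax.device_put(jnp.ones((8, 8)), replicated)
```

Because the restore target is host memory, HBM capacity (e.g., 32 GB on Trillium) does not bound the checkpoint size at load time; the weights can be re-sharded onto TPUs afterwards.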

```python
try:
    state[path].value = device_put_replicated(val, sharding)
except Exception as e:
    max_logging.log(f"Failed to device_put_replicated {path}: {e}")
```
Collaborator

Under what conditions is the exception code executed?

Collaborator Author

This code path is executed when weights are not fully available on all hosts, which occurs when a checkpoint is loaded in a multi-host training or inference environment. Without process_allgather in this case, the code raises: `ValueError: When the second argument to device_put is a Device, the first argument must be a fully addressable array or a non-addressable array with a single device sharding.`
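The multi-host guard being described can be sketched as follows (a simplified single-array example, assuming `jax.experimental.multihost_utils.process_allgather`; on a single host the gather is effectively a copy):

```python
import jax
import jax.numpy as jnp
from jax.experimental.multihost_utils import process_allgather

val = jnp.arange(4.0)

# On multi-host restores each process may hold only its local shard.
# Gathering first makes the array fully addressable on every host, so
# device_put can replicate it without the ValueError described above.
full = process_allgather(val, tiled=True)  # concatenates shards along axis 0
replicated = jax.device_put_replicated(full, jax.local_devices())
```

`device_put_replicated` then places one full copy on each local device, which is the behavior the try/except in the snippet falls back to logging about when it fails.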

@entrpn
Collaborator

entrpn commented Apr 1, 2026

@ninatu overall looks good; can you squash your commits and run the linter, and this should be good to go.

entrpn previously approved these changes Apr 1, 2026

Co-authored-by: martinarroyo <martinarroyo@google.com>
@ninatu
Collaborator Author

ninatu commented Apr 1, 2026

@entrpn, thanks, I squashed! It requires approval again.

