Refactor Wan Model Training & Add Wan-VACE Training Support #352
Conversation
As this is a fairly large refactor:
```python
max_logging.log("Restoring WAN checkpoint")
restored_checkpoint = self.checkpoint_manager.restore(
```
Is this replicating the sharding across devices? If so, would this be able to load on a Trillium TPU with 32 GB of HBM?
The mesh is created using CPU devices, not TPU devices. So the model is being loaded into RAM.
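For context, a minimal sketch of what a mesh built from CPU devices can look like in JAX (illustrative names, not the PR's actual code): restoring onto such a mesh places the weights in host RAM rather than TPU HBM.

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Illustrative sketch: a 1-D mesh over host CPU devices, so restored
# weights land in host RAM rather than device HBM.
cpu_devices = np.array(jax.devices("cpu"))
mesh = Mesh(cpu_devices, axis_names=("data",))

# A fully replicated sharding on that CPU mesh.
replicated = NamedSharding(mesh, PartitionSpec())
```

The weights can later be re-sharded onto the TPU mesh once the topology-aware shardings are known.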
```python
try:
    state[path].value = device_put_replicated(val, sharding)
except Exception as e:
    max_logging.log(f"Failed to device_put_replicated {path}: {e}")
```
Under what conditions is the exception code executed?
This code is executed when weights are not fully available on all hosts, which occurs when a checkpoint is loaded in a multi-host training or inference environment. Without `process_allgather` in this case, the code raises the following error: `ValueError: When the second argument to device_put is a Device, the first argument must be a fully addressable array or a non-addressable array with a single device sharding.`
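A hedged sketch of that fallback path, with an illustrative helper name (not the PR's code): gathering the value across processes first makes it fully addressable, so `jax.device_put` accepts it.

```python
import jax
import numpy as np
from jax.experimental import multihost_utils

# Illustrative helper: make a host-local value fully addressable on
# every process before replicating it onto devices.
def replicate_across_hosts(val, sharding):
    # On a single host this is effectively a no-op; on multiple hosts
    # it materializes the full array everywhere, avoiding the
    # "must be a fully addressable array" ValueError from device_put.
    gathered = multihost_utils.process_allgather(val, tiled=True)
    return jax.device_put(gathered, sharding)

sharding = jax.sharding.SingleDeviceSharding(jax.devices()[0])
out = replicate_across_hosts(np.ones((4,), np.float32), sharding)
```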
@ninatu overall looks good; can you squash your commits and run the linter, and this should be good to go.
This PR introduces several improvements and fixes to Wan model training, and adds support for training Wan-VACE models. Key changes include:
1. Bug fixes:
* Resolved training mode bug when dropout > 0 (e.g., ensured rngs parameter is passed to layer_forward for gradient checkpointing with dropout)
* Fixed prepare_sample_fn usage for 'tfrecord' dataset type.
* Addressed checkpoint loading issues with larger TPU slices and different topologies for Wan 2.1.
* Corrected timestep sampling for continuous sampling.
2. Config updates:
* Ensured adam_weight_decay is a float.
* Added tensorboard_dir parameter for logging.
* Now uses config.learning_rate instead of a hardcoded value.
* Set default dropout to 0.0 in WAN configs (instead of 0.1).
3. Wan-VACE Support:
* Refactoring: Common training components (initialization, scheduler, TFLOPs calculation, training/eval loops) have been abstracted into a new BaseWanTrainer ABC to improve code structure and reusability.
* Added new scripts (train_wan_vace.py), trainer (wan_vace_trainer.py), and checkpointing logic (wan_vace_checkpointing_2_1.py) to enable training of WAN-VACE models.
4. New Features:
* Introduced config.disable_training_weights to optionally disable mid-point loss weighting.
* Added logging for max_grad_norm and max_abs_grad.
Co-authored-by: martinarroyo <martinarroyo@google.com>
Force-pushed 8c82c74 to f9f9506.
@entrpn, thanks, I squashed! It requires approval again.