[Trainer/Experimental] Support fp8 backward #13121
comfyanonymous merged 1 commit into Comfy-Org:master
Conversation
📝 Walkthrough
This pull request introduces FP8 quantized backward computation support to the training pipeline, gated behind a module-level flag (`comfy.model_management.training_fp8_bwd`, referenced in the review comments below).
🚥 Pre-merge checks: ✅ 3 passed
Actionable comments posted: 1
🧹 Nitpick comments (1)
comfy/ops.py (1)
819-829: Skip saving `q_input` when `grad_weight` is impossible.
Line 868 is the only place that reads `input_mm`, so the new FP8 branch does not need `ctx.q_input` when `ctx.weight_requires_grad` is false. Saving it unconditionally keeps an extra FP8 activation alive through backward in frozen-weight cases and cuts into the VRAM savings this feature is aiming for. As per coding guidelines, `comfy/**` changes should focus on memory management and performance implications in hot paths.
Possible change:

```diff
-        if ctx.fp8_bwd:
-            # Cache FP8 quantized input — half the memory of bf16
-            if isinstance(q_input, QuantizedTensor) and layout_type.startswith('TensorCoreFP8'):
-                ctx.q_input = q_input  # already FP8, reuse
-            else:
-                # NVFP4 or other layout — quantize input to FP8 for backward
-                ctx.q_input = QuantizedTensor.from_float(inp, "TensorCoreFP8E4M3Layout")
+        if ctx.fp8_bwd:
+            if ctx.weight_requires_grad:
+                # Cache FP8 quantized input only when grad_weight needs it
+                if isinstance(q_input, QuantizedTensor) and layout_type.startswith("TensorCoreFP8"):
+                    ctx.q_input = q_input
+                else:
+                    ctx.q_input = QuantizedTensor.from_float(inp, "TensorCoreFP8E4M3Layout")
+            else:
+                ctx.q_input = None
         ctx.save_for_backward(weight)
```

Also applies to: 867-868
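To make the reviewer's memory argument concrete, here is a toy stand-in for the autograd context showing the conditional save. The helper names are hypothetical and byte strings stand in for real tensor storage; the actual `comfy/ops.py` code differs.

```python
class Ctx:
    """Toy stand-in for torch.autograd's ctx object."""
    def __init__(self):
        self.q_input = None
        self.saved = ()

    def save_for_backward(self, *tensors):
        self.saved = tensors

def quantize_fp8(values):
    # Placeholder quantizer: FP8 stores 1 byte/element vs 2 for bf16.
    return bytes(len(values))

def forward_save(ctx, inp, weight, weight_requires_grad):
    if weight_requires_grad:
        # grad_weight needs the input, so keep the FP8 copy alive.
        ctx.q_input = quantize_fp8(inp)
    else:
        # Frozen weight: grad_weight is never computed, so keep nothing extra.
        ctx.q_input = None
    ctx.save_for_backward(weight)
    return ctx
```

In the frozen-weight case nothing beyond the weight is retained through backward, which is exactly the VRAM saving the comment asks to preserve.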
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `comfy/ops.py` around lines 819-829: the FP8 branch currently always sets and saves `ctx.q_input` even when `ctx.weight_requires_grad` is False, retaining an unnecessary FP8 activation. Change the logic in the `fp8_bwd` branch so that `ctx.q_input` is only created and saved when `ctx.weight_requires_grad` is True (i.e., wrap the `QuantizedTensor.from_float` assignment and `ctx.save_for_backward(weight)` in a check for `ctx.weight_requires_grad`), and when `weight_requires_grad` is False ensure `ctx.q_input` is left None and only the minimal tensors required for backward are saved via `ctx.save_for_backward`. Mirror the same conditional behavior for the non-fp8 branch to avoid retaining `q_input` when `grad_weight` is impossible.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `comfy_extras/nodes_train.py`:
- Line 1135: The assignment `comfy.model_management.training_fp8_bwd = quantized_backward` changes a process-global flag too early and is not restored if setup or model loading fails. Modify `execute()` to capture the previous value (e.g., `prev = comfy.model_management.training_fp8_bwd`) before setting it, and restore it in an outer `finally` block (or scope the flag to the execution) so that the original value is reapplied regardless of errors. Ensure the restore happens outside the existing `in_training` cleanup; `comfy.model_management.training_fp8_bwd` and `QuantLinearFunc` locate the change and the affected FP8 backward path.
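The save-and-restore pattern the comment asks for can be sketched with a stand-in module object. The flag name `training_fp8_bwd` follows the review; the real `execute()` signature in `comfy_extras/nodes_train.py` differs.

```python
from types import SimpleNamespace

# Stand-in for comfy.model_management, which holds the process-global flag.
model_management = SimpleNamespace(training_fp8_bwd=False)

def execute(quantized_backward, run_training):
    """Scope the global FP8-backward flag to a single execution."""
    prev = model_management.training_fp8_bwd      # capture before mutating
    model_management.training_fp8_bwd = quantized_backward
    try:
        # Model loading / training setup may raise at any point here.
        return run_training()
    finally:
        # Outer finally: the original value is reapplied even on error,
        # independent of any inner in_training cleanup.
        model_management.training_fp8_bwd = prev
```

Even if `run_training` raises during model loading, the flag returns to its previous value, so later jobs in the same process are unaffected.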
📒 Files selected for processing (3)
comfy/model_management.py, comfy/ops.py, comfy_extras/nodes_train.py
As the title says, this PR adds a quantized backward pass that uses fp8e5m2, reusing comfy-kitchen and the original quantized matmul/linear path for the calculation.
It is not faster yet, since that would require a fused quantize+linear kernel, but it reduces the intermediate state saved in VRAM.
It should be treated as an initial, experimental implementation of fp8 backward training, not production ready, since its performance is no better than before.
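For context on the fp8e5m2 format the PR targets: e5m2 shares binary16's 5-bit exponent (bias 15) but keeps only 2 of its 10 mantissa bits, so a round-trip can be sketched in pure Python by encoding to binary16 and rounding away the low mantissa bits. This is an illustration only; the PR itself relies on comfy-kitchen's quantized matmul/linear kernels, not this code.

```python
import struct

def to_e5m2(x: float) -> float:
    """Round x to the nearest fp8 e5m2 value (sketch; ignores NaN edge cases)."""
    h = struct.unpack('<H', struct.pack('<e', x))[0]  # binary16 bit pattern
    drop = 8                        # 10 mantissa bits -> 2
    round_bit = 1 << (drop - 1)
    sticky = h & (round_bit - 1)
    h_rounded = h & ~((1 << drop) - 1)
    # Round-to-nearest-even on the dropped bits; a carry may bump the exponent.
    if (h & round_bit) and (sticky or ((h_rounded >> drop) & 1)):
        h_rounded = (h_rounded + (1 << drop)) & 0xFFFF
    return struct.unpack('<e', struct.pack('<H', h_rounded))[0]

# With 2 mantissa bits the spacing between representable values is a
# quarter of the leading power of two, which is why fp8 activations are
# cheap to store but coarse.
print(to_e5m2(1.0), to_e5m2(3.3), to_e5m2(0.1))  # prints: 1.0 3.5 0.09375
```

The coarse spacing shown here is the trade-off behind the VRAM savings: each cached activation element shrinks from 2 bytes (bf16) to 1 byte at the cost of precision in the backward pass.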