[Trainer/Experimental] Support fp8 backward#13121

Merged
comfyanonymous merged 1 commit into Comfy-Org:master from KohakuBlueleaf:quant-bwd
Mar 25, 2026

Conversation

@KohakuBlueleaf
Contributor

As the title says, this PR adds a quantized backward pass that uses fp8e5m2, routing the calculation through comfy-kitchen and the original quantized matmul/linear path.

It is not faster, since a fused quantize+linear kernel would be needed for that, but it does reduce the intermediate state saved in VRAM.

It should be treated as an initial, experimental implementation of fp8 backward training, not production ready, as its performance is no better than before.
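To make the VRAM claim concrete, here is a rough back-of-envelope sketch. The shapes are illustrative (not taken from the PR), and per-block scale overhead is ignored; the point is simply that an fp8 cached activation takes half the bytes of a bf16 one:

```python
# Rough arithmetic on the saving from caching fp8 instead of bf16 activations.
# Shapes below are hypothetical, chosen only for illustration.
def activation_bytes(numel: int, fp8: bool = False) -> int:
    return numel * (1 if fp8 else 2)  # bf16 = 2 B/elem, fp8 = 1 B/elem

batch, seq, hidden = 1, 4096, 3072
n = batch * seq * hidden
saved = activation_bytes(n) - activation_bytes(n, fp8=True)
print(saved // 2**20, "MiB saved per cached activation")  # 12 MiB for this shape
```

The saving compounds across every linear layer whose input is stashed for backward, which is where the reduced intermediate state comes from.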

@coderabbitai

coderabbitai bot commented Mar 23, 2026

📝 Walkthrough

This pull request introduces FP8 quantized backward computation support to the training pipeline. A module-level flag training_fp8_bwd is added to control backward computation behavior. The QuantLinearFunc in comfy/ops.py is modified to optionally use FP8-quantized tensors during backward passes instead of regular float tensors. The TrainLoraNode is updated to expose a quantized_backward input parameter that configures this behavior at training runtime.
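A minimal sketch of the caching decision the modified QuantLinearFunc has to make when stashing activations for backward. Names mirror the walkthrough; the real quantization kernels live in comfy-kitchen and are stubbed here as a tagged tuple, and the grad-skip branch reflects the reviewer's suggestion later in this thread, not necessarily the merged code:

```python
def choose_saved_input(inp, q_input, layout_type: str,
                       fp8_bwd: bool, weight_requires_grad: bool):
    """Pick what to cache for the backward pass (sketch, not the real code)."""
    if not fp8_bwd:
        return inp                   # default path: full-precision activation
    if not weight_requires_grad:
        return None                  # grad_weight is never computed; save nothing
    if layout_type.startswith("TensorCoreFP8"):
        return q_input               # forward input is already FP8, reuse as-is
    return ("fp8e4m3", inp)          # stand-in for QuantizedTensor.from_float
```

For example, an NVFP4-layout forward would fall through to the final branch and re-quantize the input to FP8 for backward, while an FP8 forward reuses its quantized input directly.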

🚥 Pre-merge checks | ✅ 3

✅ Passed checks (3 passed)

| Check name | Status | Explanation |
| --- | --- | --- |
| Title check | ✅ Passed | The title clearly and specifically summarizes the main change: adding support for FP8 backward training. It directly matches the primary feature introduced in this PR. |
| Description check | ✅ Passed | The description is directly related to the changeset, explaining the FP8 backward quantization feature, its purpose (reducing VRAM usage), and its experimental status. |
| Docstring Coverage | ✅ Passed | No functions found in the changed files to evaluate docstring coverage. Skipping docstring coverage check. |


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🧹 Nitpick comments (1)
comfy/ops.py (1)

819-829: Skip saving q_input when grad_weight is impossible.

Line 868 is the only place that reads input_mm, so the new FP8 branch does not need ctx.q_input when ctx.weight_requires_grad is false. Saving it unconditionally keeps an extra FP8 activation alive through backward in frozen-weight cases and cuts into the VRAM savings this feature is aiming for.

Possible change

```diff
-        if ctx.fp8_bwd:
-            # Cache FP8 quantized input — half the memory of bf16
-            if isinstance(q_input, QuantizedTensor) and layout_type.startswith('TensorCoreFP8'):
-                ctx.q_input = q_input  # already FP8, reuse
-            else:
-                # NVFP4 or other layout — quantize input to FP8 for backward
-                ctx.q_input = QuantizedTensor.from_float(inp, "TensorCoreFP8E4M3Layout")
+        if ctx.fp8_bwd:
+            if ctx.weight_requires_grad:
+                # Cache FP8 quantized input only when grad_weight needs it
+                if isinstance(q_input, QuantizedTensor) and layout_type.startswith("TensorCoreFP8"):
+                    ctx.q_input = q_input
+                else:
+                    ctx.q_input = QuantizedTensor.from_float(inp, "TensorCoreFP8E4M3Layout")
+            else:
+                ctx.q_input = None
             ctx.save_for_backward(weight)
```
As per coding guidelines, `comfy/**` changes should focus on memory management and performance implications in hot paths.

Also applies to: 867-868

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@comfy/ops.py` around lines 819-829: the FP8 branch currently always sets
and saves ctx.q_input even when ctx.weight_requires_grad is False, retaining an
unnecessary FP8 activation; change the logic in the fp8_bwd branch so that
ctx.q_input is only created and saved when ctx.weight_requires_grad is True
(i.e., wrap the QuantizedTensor.from_float / assignment and
ctx.save_for_backward(weight) in a check for ctx.weight_requires_grad), and when
weight_requires_grad is False ensure ctx.q_input is left None and only the
minimal tensors required for backward are saved via ctx.save_for_backward;
mirror the same conditional behavior for the non-fp8 branch to avoid retaining
q_input when grad_weight is impossible.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@comfy_extras/nodes_train.py`:
- Line 1135: The assignment comfy.model_management.training_fp8_bwd =
quantized_backward changes a process-global flag too early and is not restored
if setup/model loading fails; modify execute() so you capture the previous value
(e.g., prev = comfy.model_management.training_fp8_bwd) before setting it and
restore it in an outer finally block (or scope the flag to the execution) so
that regardless of errors the original value is reapplied; ensure the restore
happens outside the existing in_training cleanup and reference
comfy.model_management.training_fp8_bwd and QuantLinearFunc to locate the change
and the affected FP8 backward path.
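The save/restore pattern this comment asks for can be sketched as follows. The module and attribute names follow the comment (`comfy.model_management.training_fp8_bwd`), but the module is stubbed with a namespace and the training body is simulated; this is an illustration of the pattern, not the merged code:

```python
import types

# Stand-in for comfy.model_management; only the flag matters for this sketch.
model_management = types.SimpleNamespace(training_fp8_bwd=False)

def execute(quantized_backward: bool):
    # Capture the previous value before touching the process-global flag.
    prev = model_management.training_fp8_bwd
    model_management.training_fp8_bwd = quantized_backward
    try:
        # ... model loading and the training loop would run here ...
        raise RuntimeError("simulated setup failure")
    finally:
        # The original value is reapplied even if setup/model loading fails.
        model_management.training_fp8_bwd = prev
```

With the `finally` in an outer scope, a failure during setup can no longer leak `training_fp8_bwd = True` into later, unrelated executions.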



📥 Commits

Reviewing files that changed from the base of the PR and between 6265a23 and b4e0688.

📒 Files selected for processing (3)
  • comfy/model_management.py
  • comfy/ops.py
  • comfy_extras/nodes_train.py

@comfyanonymous comfyanonymous merged commit 5ebb0c2 into Comfy-Org:master Mar 25, 2026
14 checks passed