
Bug: NVFP4 quantized UNET models fail with shape mismatch when loaded via UNETLoader (convert_old_quants called after prefix strip) #13214

@codeman101

Description


When loading an NVFP4-quantized model as a standalone UNET file via UNETLoader (or any path through load_diffusion_model / load_diffusion_model_state_dict), all quantized layers silently lose their quantization config and fail at inference time with a matrix shape mismatch.

Error

RuntimeError: mat1 and mat2 shapes cannot be multiplied (4992x4096 and 2048x4096)

The packed uint8 NVFP4 storage (half the columns of the logical weight) gets loaded as a raw bfloat16 tensor with the wrong shape, then passed directly to F.linear.
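The dimension arithmetic can be reproduced in a minimal numpy sketch (shapes taken from the error above; the real code path uses torch and `F.linear`, which raises `RuntimeError` instead of `ValueError`, but the inner-dimension mismatch is the same):

```python
import numpy as np

# Logical weight: [4096, 4096]. NVFP4 packs two 4-bit values per
# uint8 byte, so the stored tensor has half the columns: [4096, 2048].
x = np.zeros((4992, 4096))                          # activations
w_packed = np.zeros((4096, 2048), dtype=np.uint8)   # packed storage

# F.linear computes x @ W.T. With the packed tensor mistaken for the
# logical weight, the inner dimensions no longer agree:
# (4992, 4096) @ (2048, 4096) -> mismatch
try:
    _ = x @ w_packed.T
except ValueError as e:
    print("shape mismatch:", e)
```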

Root Cause

In comfy/sd.py, load_diffusion_model_state_dict() strips the model.diffusion_model. prefix from all state dict keys before calling convert_old_quants():

# Current (broken) order:
diffusion_model_prefix = model_detection.unet_prefix_from_state_dict(sd)
temp_sd = comfy.utils.state_dict_prefix_replace(sd, {diffusion_model_prefix: ""}, filter_keys=True)
if len(temp_sd) > 0:
    sd = temp_sd

if custom_operations is None:
    sd, metadata = comfy.utils.convert_old_quants(sd, "", metadata=metadata)  # ← too late

convert_old_quants() reads _quantization_metadata from the safetensors metadata and adds <layer>.comfy_quant keys to the state dict. But these layer names come from the metadata JSON, which still has the full model.diffusion_model. prefix. Meanwhile the state dict keys have already had that prefix stripped.

Result: comfy_quant keys look like model.diffusion_model.transformer_blocks.2.attn1.to_q.comfy_quant but the actual weight key is transformer_blocks.2.attn1.to_q.weight — no match. _load_from_state_dict never finds the quant config, so layout_type stays None and the raw packed uint8 weight is loaded without the QuantizedTensor wrapper.
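The mismatch can be shown with a toy state dict (key names taken from the example above; the lookup logic is illustrative, not ComfyUI's actual `_load_from_state_dict`):

```python
# After prefix stripping, the weight keys are bare:
sd = {"transformer_blocks.2.attn1.to_q.weight": "packed-uint8-data"}

# But _quantization_metadata still uses the original prefixed names,
# so convert_old_quants adds comfy_quant keys under those names:
sd["model.diffusion_model.transformer_blocks.2.attn1.to_q.comfy_quant"] = {
    "layout_type": "nvfp4",
}

# The layer looks for a comfy_quant key next to its own weight key
# and finds nothing, so layout_type stays None:
layer_prefix = "transformer_blocks.2.attn1.to_q."
quant_cfg = sd.get(layer_prefix + "comfy_quant")
print(quant_cfg)  # None
```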

Note: load_checkpoint_guess_config() (Load Checkpoint node) does not have this bug because it calls convert_old_quants before prefix stripping.

Fix

Move the convert_old_quants call to before state_dict_prefix_replace, and pass the detected prefix so the metadata keys get stripped along with everything else:

# Fixed order: convert quants while keys still carry the prefix
diffusion_model_prefix = model_detection.unet_prefix_from_state_dict(sd)

if custom_operations is None:
    sd, metadata = comfy.utils.convert_old_quants(sd, diffusion_model_prefix, metadata=metadata)

temp_sd = comfy.utils.state_dict_prefix_replace(sd, {diffusion_model_prefix: ""}, filter_keys=True)
if len(temp_sd) > 0:
    sd = temp_sd

After this change, state_dict_prefix_replace strips model.diffusion_model. from both the weight keys and the newly-added comfy_quant keys, so they match correctly in _load_from_state_dict.
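With the fixed ordering, one strip pass keeps the pair aligned. A minimal sketch (the helper below is a stand-in for `state_dict_prefix_replace(..., filter_keys=True)`, not its real implementation):

```python
def strip_prefix(sd, prefix):
    # Stand-in for state_dict_prefix_replace with filter_keys=True:
    # keep only keys carrying the prefix, with the prefix removed.
    return {k[len(prefix):]: v for k, v in sd.items() if k.startswith(prefix)}

prefix = "model.diffusion_model."

# After convert_old_quants runs first, the weight and its comfy_quant
# key both still carry the prefix:
sd = {
    prefix + "transformer_blocks.2.attn1.to_q.weight": "packed-uint8-data",
    prefix + "transformer_blocks.2.attn1.to_q.comfy_quant": {"layout_type": "nvfp4"},
}

sd = strip_prefix(sd, prefix)
# Now the quant config sits next to its weight under matching names:
print(sorted(sd))
```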

Affected Models

Any NVFP4-quantized model loaded as a standalone UNET (not a full checkpoint) whose safetensors file contains _quantization_metadata and whose state dict has a model.diffusion_model. prefix. Confirmed with LTX-2.3 22B dev NVFP4 (ltx-2.3-22b-dev-nvfp4.safetensors).

Environment

  • ComfyUI 0.18.1
  • PyTorch 2.9.1+cu130
  • RTX 5080 (SM 12.0)
