Add support for 4bit JoyCaption #344
Conversation
tl;dr: bitsandbytes is quantizing an nn.NonDynamicallyQuantizableLinear output projection in JoyCaption. We revert it back to an nn.Linear (as nn.NonDynamicallyQuantizableLinear is not a public type). We also set the dtype in a few more places where we previously didn't, or didn't always.
load_in_4bit=True,
bnb_4bit_quant_type='nf4',
bnb_4bit_compute_dtype=self.dtype,
bnb_4bit_quant_storage=self.dtype,
Could you explain why you added this?
This is something I don't fully understand to be honest, but the investigation I've done so far suggests it's not currently making anything worse. If that is incorrect, this part of the PR can be made to only happen for JoyCaption fairly trivially. Let me know if you would like to go that route, and I'll make the change.
My theory is that without specifying our quant storage dtype to match our model dtype, the parts of the model that aren't quantizable (just the attention head? I'm not sure) are still expecting data in the model-native dtype. That's just a rough theory though; I didn't dig into the why once the problem went away so easily.
Without this fix, when running JoyCaption in 4bit, we get a type mismatch error between uint8 and bf16. With this change, the quantized data is stored in the same dtype as the model, which makes the error go away. For the record, the valid storage types are int8, uint8, fp16, bf16, fp32, and fp64, and the default is torch.uint8 as per above.
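To make that concrete, here is roughly what the quantization config ends up looking like. This is a sketch rather than the code as it appears in the repo; `model_dtype` stands in for whatever `self.dtype` resolves to:

```python
import torch
from transformers import BitsAndBytesConfig

model_dtype = torch.bfloat16  # stand-in for self.dtype

quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=model_dtype,  # de-quantized matmuls run in bf16
    bnb_4bit_quant_storage=model_dtype,  # packed 4-bit params stored as bf16 instead of the default torch.uint8
)
```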
I looked at the video-memory usage with and without this change and it doesn't seem to increase (on models you can run without the change). The parameter's doc comment ("the storage type to pack the quantized 4-bit params") suggested to me that multiple tensors may be packed into this type, which makes sense for them to support given they're normally storing 4-bit values in an 8-bit data type. My local test runs also showed no difference in how long it takes to caption images on a 4bit model with and without this change.
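In case it's useful, the comparison I ran was roughly of this shape (a sketch only; `load_joycaption` and `caption_one_image` are made-up stand-ins for the project's own loading and captioning code):

```python
import torch

def peak_vram_gib(quantization_config):
    # Measure peak allocated VRAM for one load + one captioning pass.
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    model = load_joycaption(quantization_config)  # hypothetical: the project's loader
    caption_one_image(model)                      # hypothetical: a single captioning pass
    return torch.cuda.max_memory_allocated() / 2**30
```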
# If our out_proj was converted into a nn.Linear4bit, replace
# it with the original nn.Linear. JoyCaption's out-projection
# layer is not dynamically quantizable.
if isinstance(attention.out_proj, bitsandbytes.nn.Linear4bit):
What would be the case where this is not true? Isn't this fixed for the given model?
I think no, it would currently always be true. This check was an afterthought I added while repackaging this for the PR.
Basically, my thought process was: "what if [the model authors] fix whatever it is about the model that BNB is currently unhappy with, so that this hack is no longer needed and instead causes the model to break and/or silently misbehave?" I wanted one more layer of shielding against that possibility, and since this code should only run once per run, it seemed fairly safe to add.
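For context, the swap is conceptually something like the following. This is a sketch of the idea rather than the PR's code; the function name is made up, and recovering the weights via `dequantize_4bit` is one possible approach, not necessarily the one used here:

```python
import torch.nn as nn
import bitsandbytes as bnb

def restore_out_proj(attention, dtype):
    # Only undo the conversion if bitsandbytes actually quantized out_proj;
    # if a future model revision avoids the problem, this is a no-op.
    if not isinstance(attention.out_proj, bnb.nn.Linear4bit):
        return
    quantized = attention.out_proj
    restored = nn.Linear(
        quantized.in_features,
        quantized.out_features,
        bias=quantized.bias is not None,
        device=quantized.weight.device,
        dtype=dtype,
    )
    # One way to get the original weights back: unpack the 4-bit tensor.
    restored.weight.data = bnb.functional.dequantize_4bit(
        quantized.weight.data, quantized.weight.quant_state
    ).to(dtype)
    if quantized.bias is not None:
        restored.bias.data = quantized.bias.data.to(dtype)
    attention.out_proj = restored
```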
Another way to do this might be to load the model normally, but without the weights[1], to avoid double-loading the model, inspect whether out_proj is a torch.nn.NonDynamicallyQuantizableLinear before BNB mutates it, and use that to guide our decision to swap a bitsandbytes.nn.Linear4bit back to a torch.nn.Linear. That solution seemed like a lot more work for something that might never happen.
[1] Is this possible via the torch meta device?
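To partly answer my own footnote: I believe the meta device does allow this. Something along these lines should build the module tree without allocating any weight storage, so the pre-quantization class of each out_proj can be inspected (a sketch; the checkpoint id is a placeholder and AutoModel stands in for whatever class the loader actually uses):

```python
import torch
from transformers import AutoConfig, AutoModel

config = AutoConfig.from_pretrained("some/joycaption-checkpoint")  # placeholder id
with torch.device("meta"):
    # Parameters are created on the meta device, so no weights are materialized.
    skeleton = AutoModel.from_config(config)

# Report the pre-bitsandbytes class of every out_proj in the model.
for name, module in skeleton.named_modules():
    if name.endswith("out_proj"):
        print(name, type(module).__name__)
```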
works well on a 16 GB card. thank you!