pytorch / ao

PyTorch native quantization and sparsity for training and inference
BSD 3-Clause "New" or "Revised" License

HF checkpoint integration story #456

Open msaroufim opened 3 months ago

msaroufim commented 3 months ago

Right now ao works just fine for quantizing an arbitrary HF model.

However, this simple workflow fails, which means we don't really interoperate well with the rest of the HF ecosystem:

from transformers import BertModel
import torch
import torchao

model = BertModel.from_pretrained('bert-base-uncased')
model = torchao.quantize(model, torchao.quantization.quant_api.int8_weight_only())

state_dict = model.state_dict()

# Save state dict for later use
# This works
torch.save(state_dict, 'bert_state_dict.pth')

# Load the saved state dict
# This fails!
state_dict = torch.load('bert_state_dict.pth')

# Initialize the model
model = BertModel.from_pretrained('bert-base-uncased', state_dict=state_dict)
assert "Affine" in str(model.state_dict())
gau-nernst commented 3 months ago

Ran your snippet on my machine, got this error

While copying the parameter named "pooler.dense.weight", whose dimensions in the model are torch.Size([768, 768]) and whose dimensions in the checkpoint are torch.Size([768, 768]), an exception occurred : ('AffineQuantizedTensor dispatch: attempting to run aten.copy_.default, this is not supported',).

I think it's just that load_state_dict() calls weight.copy_(state_dict["weight"]), so we only need to implement aten.copy_.default. We'd need to handle all the combinations of quantized and non-quantized source/destination tensors.

Here is a snippet from my code for 8-bit Adam that might be useful (aten here is torch.ops.aten, and DynamicInt8 is my quantized tensor subclass):

@DynamicInt8.implements([aten.add_.Tensor, aten.mul_.Tensor, aten.addcmul_.default, aten.addcdiv_.default, aten.lerp_.Scalar, aten.copy_.default])
def _(func, *args, **kwargs):
    out = func(*[x.dequantize() if isinstance(x, DynamicInt8) else x for x in args], **kwargs)

    # args[0] is the original quantized tensor to be updated in-place
    if isinstance(args[0], DynamicInt8):
        out = DynamicInt8.from_float(out, args[0].group_size)
        args[0].int_data.copy_(out.int_data)
        args[0].scale.copy_(out.scale)
        args[0].zero_point.copy_(out.zero_point)

        # return the original quantized tensor with updated values
        out = args[0]

    return out

(On a second look, int8_weight.copy_(int8_weight) could skip the dequant + quant step. We might also need to check whether it works when the other tensor is fp16/bf16.)
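To make this concrete, here is a rough, untested sketch of a dedicated aten.copy_.default handler that skips the round trip when both sides are the same subclass. It reuses the hypothetical DynamicInt8 subclass and its implements decorator from the snippet above, and would replace the generic registration of aten.copy_.default there:

import torch

aten = torch.ops.aten

@DynamicInt8.implements([aten.copy_.default])
def _(func, *args, **kwargs):
    dst, src = args[0], args[1]

    # int8 -> int8: copy the inner tensors directly, no dequant/requant round trip
    if isinstance(dst, DynamicInt8) and isinstance(src, DynamicInt8):
        dst.int_data.copy_(src.int_data)
        dst.scale.copy_(src.scale)
        dst.zero_point.copy_(src.zero_point)
        return dst

    # fp32/fp16/bf16 source into an int8 destination: quantize, then copy the components
    if isinstance(dst, DynamicInt8):
        tmp = DynamicInt8.from_float(src.float(), dst.group_size)
        dst.int_data.copy_(tmp.int_data)
        dst.scale.copy_(tmp.scale)
        dst.zero_point.copy_(tmp.zero_point)
        return dst

    # int8 source into a plain tensor destination: dequantize first
    return dst.copy_(src.dequantize())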

jerryzh168 commented 3 months ago

@gau-nernst did you error out in this line?

model = BertModel.from_pretrained('bert-base-uncased', state_dict=state_dict)

or this line:

# This fails!
state_dict = torch.load('bert_state_dict.pth')

if it's the former, I think the API we can use is the following:

# remove state_dict arg
model = BertModel.from_pretrained('bert-base-uncased')
model.load_state_dict(state_dict, assign=True)

(this is tested in https://github.com/pytorch/ao/blob/c2f9b84604536a72804787001c1b63daae792ee9/test/quantization/test_quant_api.py#L605), and it is what we intend people to use in the end.
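Putting it together, a rough end-to-end sketch of that round trip (untested; it follows the torchao.quantize / int8_weight_only spelling from the snippet at the top of the issue, which may differ across versions):

from transformers import BertModel
import torch
import torchao
from torchao.quantization.quant_api import int8_weight_only

# quantize and save the state dict (which contains the tensor subclasses)
model = BertModel.from_pretrained('bert-base-uncased')
model = torchao.quantize(model, int8_weight_only())
torch.save(model.state_dict(), 'bert_state_dict.pth')

# later: build a fresh float model and load with assign=True, which swaps in
# the quantized tensors from the checkpoint instead of copying them into the
# existing float parameters
model = BertModel.from_pretrained('bert-base-uncased')
state_dict = torch.load('bert_state_dict.pth')
model.load_state_dict(state_dict, assign=True)
assert "Affine" in str(model.state_dict())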

Although I feel overriding copy_ would also work, it's just a bit of extra work for the tensor subclass.

gau-nernst commented 3 months ago

@jerryzh168 I didn't comment out anything; I ran it as-is. The error came from the following line:

model = BertModel.from_pretrained('bert-base-uncased', state_dict=state_dict)

torch.load('bert_state_dict.pth') doesn't give any errors. (I was expecting it to fail when run in a new session without having imported torchao, but that works too.)

Using assign=True is interesting. Personally I haven't used it much, but I can imagine that most users, for now, either don't know about it or are more used to the usual .load_state_dict().

Normally, when I want to load weights into a model, I apply the transformations (if any) to the "base" model so that it is exactly the same as when its state dict was exported, and then load the state dict. In this case, I would expect to quantize the newly initialized model before loading the saved state dict. It feels like there is too much "magic" if I do fp32_model.load_state_dict(int8_state_dict, assign=True) and the model becomes int8. So in my comment above, when we do fp32_model.load_state_dict(int8_state_dict), I was expecting the int8 state dict to be dequantized to fp32 and loaded into the fp32 model. We should probably discuss what the expected behavior should be in this case.
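For concreteness, the workflow I have in mind would look something like this (a rough sketch using the same APIs as the snippet at the top of the issue; it relies on an aten.copy_.default override like the one sketched earlier so that the usual load_state_dict can copy int8 into int8):

from transformers import BertModel
import torch
import torchao
from torchao.quantization.quant_api import int8_weight_only

# apply the same transformation to the freshly initialized model as was
# applied to the model whose state dict was exported
model = BertModel.from_pretrained('bert-base-uncased')
model = torchao.quantize(model, int8_weight_only())

# now both sides are quantized, so the usual load_state_dict (which dispatches
# to aten.copy_.default under the hood) copies int8 weights into int8 weights
state_dict = torch.load('bert_state_dict.pth')
model.load_state_dict(state_dict)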

I didn't know that state_dict() exports the subclass tensors directly instead of the underlying "storage" tensors. Kinda neat in some sense, though it might be unexpected.

jerryzh168 commented 2 months ago

fp32_model.load_state_dict(int8_state_dict, assign=True) then the model becomes int8

Yeah, I understand the surprise, although you could also argue this is an advantage of using a tensor subclass: we don't need to touch the model structure at all. This is possible because the model definition is not changed; we only changed the (generalized version of the) "dtype" of the weight tensors. It's like asking: is it surprising that when I do fp32_model.load_state_dict(fp16_state_dict, assign=True), I load fp16 weights into my model and now my model is an fp16 model?
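The same pattern with plain dtypes, as a tiny sketch (assign=True needs a reasonably recent PyTorch):

import torch
import torch.nn as nn

fp32_model = nn.Linear(8, 8)
fp16_state_dict = {k: v.half() for k, v in fp32_model.state_dict().items()}

# assign=True keeps the properties of the tensors from the state dict,
# so after loading the model's weights are fp16
fp32_model.load_state_dict(fp16_state_dict, assign=True)
assert fp32_model.weight.dtype == torch.float16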

if "dtype" model is defined as a model with all "dtype" weights, I feel this seems to be a reasonable transformation.

I feel the surprise might come from the fact that traditionally we treat a quantized model as a separate concept (category) from fp32 models, while now we are blurring that boundary with the tensor subclass implementation. But please let me know if you have further thoughts on this.

gau-nernst commented 2 months ago

@jerryzh168 That makes sense to me! It's indeed a shift in the way of thinking about model loading, so it will definitely take some time to get used to. I think we just need to be very clear with users about the behavior of .load_state_dict(assign=True), which is different from .load_state_dict(assign=False), especially when tensor subclasses are used.

Another question, about backward compatibility: does this mean that, to load weights saved with a tensor subclass, the implementation of the tensor subclass must be compatible with the saved weights (and the imports too? I'm not 100% sure how pickle works)? We might need a mechanism to load tensor-subclass weights when the tensor subclass definition is not available (e.g. the library is not installed, or we have made breaking changes), if such a mechanism does not exist yet.
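One possible stopgap, just to sketch the idea (not an existing API; it assumes the quantized weights are AffineQuantizedTensor instances with a dequantize() method, as the error message above suggests): keep a "portable" copy of the checkpoint by dequantizing the subclass tensors back to plain tensors before saving, so the file can be unpickled even without the subclass definition available.

import torch
from torchao.dtypes import AffineQuantizedTensor

def to_plain_state_dict(state_dict):
    # dequantize every tensor subclass back to a plain torch.Tensor so that
    # unpickling the checkpoint later only depends on core torch
    return {k: v.dequantize() if isinstance(v, AffineQuantizedTensor) else v
            for k, v in state_dict.items()}

# 'model' is the quantized model from earlier in the thread
torch.save(to_plain_state_dict(model.state_dict()), 'bert_state_dict_fp32.pth')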