Open msaroufim opened 3 months ago
Ran your snippet on my machine, got this error
While copying the parameter named "pooler.dense.weight", whose dimensions in the model are torch.Size([768, 768]) and whose dimensions in the checkpoint are torch.Size([768, 768]), an exception occurred : ('AffineQuantizedTensor dispatch: attempting to run aten.copy_.default, this is not supported',).
I think it's just that load_state_dict() calls weight.copy_(state_dict["weight"]), so we only need to implement aten.copy_.default. It needs to handle all combinations:
fp32_weight.copy_(int8_weight)
int8_weight.copy_(fp32_weight)
int8_weight.copy_(int8_weight)
A snippet from my code for 8-bit Adam, might be useful
@DynamicInt8.implements([aten.add_.Tensor, aten.mul_.Tensor, aten.addcmul_.default, aten.addcdiv_.default, aten.lerp_.Scalar, aten.copy_.default])
def _(func, *args, **kwargs):
    out = func(*[x.dequantize() if isinstance(x, DynamicInt8) else x for x in args], **kwargs)

    # args[0] is the original quantized tensor to be updated in-place
    if isinstance(args[0], DynamicInt8):
        out = DynamicInt8.from_float(out, args[0].group_size)
        args[0].int_data.copy_(out.int_data)
        args[0].scale.copy_(out.scale)
        args[0].zero_point.copy_(out.zero_point)

        # return the original quantized tensor with updated values
        out = args[0]

    return out
(On a second look, int8_weight.copy_(int8_weight) can skip the dequant + quant step. We might also need to check whether it works when the other tensor is fp16/bf16.)
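To make the three combinations concrete, here is a toy sketch in plain Python with no PyTorch dependency. `ToyInt8`, its `from_float`/`dequantize` methods, and the free function `copy_` are hypothetical stand-ins invented for illustration; they are not torchao APIs, and a real implementation would live in the tensor subclass's `aten.copy_.default` handler. It uses simple symmetric per-tensor quantization, which is an assumption, not necessarily what `AffineQuantizedTensor` does.

```python
from dataclasses import dataclass
from typing import List, Union


@dataclass
class ToyInt8:
    """Hypothetical stand-in for an int8 weight: integer codes plus one scale."""
    int_data: List[int]
    scale: float

    @classmethod
    def from_float(cls, values: List[float]) -> "ToyInt8":
        # Symmetric per-tensor quantization into [-127, 127] (an assumption).
        max_abs = max((abs(v) for v in values), default=0.0)
        scale = max_abs / 127 if max_abs > 0 else 1.0
        codes = [max(-127, min(127, round(v / scale))) for v in values]
        return cls(codes, scale)

    def dequantize(self) -> List[float]:
        return [q * self.scale for q in self.int_data]


Weight = Union[List[float], ToyInt8]


def copy_(dst: Weight, src: Weight) -> Weight:
    """In-place copy covering every quantized/float combination."""
    if isinstance(dst, ToyInt8) and isinstance(src, ToyInt8):
        # int8 <- int8: copy codes and scale directly, no dequant + quant round trip
        dst.int_data[:] = src.int_data
        dst.scale = src.scale
    elif isinstance(dst, ToyInt8):
        # int8 <- float: quantize the source, then overwrite the destination fields
        quantized = ToyInt8.from_float(src)
        dst.int_data[:] = quantized.int_data
        dst.scale = quantized.scale
    elif isinstance(src, ToyInt8):
        # float <- int8: dequantize into the existing float destination
        dst[:] = src.dequantize()
    else:
        # float <- float: plain element-wise copy
        dst[:] = src
    return dst


original = [0.5, -1.0, 0.25]
int8_src = ToyInt8.from_float(original)
fp32_dst = [0.0, 0.0, 0.0]
int8_dst = ToyInt8.from_float([0.0, 0.0, 0.0])

copy_(fp32_dst, int8_src)  # like fp32_weight.copy_(int8_weight)
copy_(int8_dst, int8_src)  # like int8_weight.copy_(int8_weight)
requantized = copy_(ToyInt8.from_float([0.0, 0.0, 0.0]), [1.0, 2.0, -4.0])  # like int8_weight.copy_(fp32_weight)
```

The int8-to-int8 branch is the fast path mentioned above: it never touches float values, so it cannot introduce extra rounding error.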
@gau-nernst did it error out on this line?
model = BertModel.from_pretrained('bert-base-uncased', state_dict=state_dict)
or this line:
# This fails!
state_dict = torch.load('bert_state_dict.pth')
if it's the former, I think the API we can use is the following:
# remove state_dict arg
model = BertModel.from_pretrained('bert-base-uncased')
model.load_state_dict(state_dict, assign=True)
(this is tested in https://github.com/pytorch/ao/blob/c2f9b84604536a72804787001c1b63daae792ee9/test/quantization/test_quant_api.py#L605) and is what we intend people to use in the end, although I feel overriding copy_ also works; it's just one extra thing to implement for a tensor subclass
@jerryzh168 I didn't comment out anything, I ran it as is. Though my error came from the following line
model = BertModel.from_pretrained('bert-base-uncased', state_dict=state_dict)
torch.load('bert_state_dict.pth') doesn't give any errors (I was expecting it to give an error if it is run in a new session, without having imported torchao, but that works too)
Using assign=True is interesting. Personally I haven't used it much. But I can imagine that most users, for now, either don't know about it or are more used to the usual .load_state_dict().
Normally for me, when I want to load weights into a model, I would apply transformations (if any) to the "base" model so it is exactly the same as when its state dict was exported, then load the state dict. In this case, I would expect to quantize the newly initialized model before loading the saved state dict. It feels like there is too much "magic" if I do fp32_model.load_state_dict(int8_state_dict, assign=True) and the model then becomes int8. So in my comment above, when we do fp32_model.load_state_dict(int8_state_dict), I was expecting to dequantize the int8 state dict to fp32 and load it into the fp32 model. We should probably discuss what the expected behavior should be in this case?
I didn't know that state_dict() will export the subclass tensor directly, instead of the underlying "storage" tensors. Kinda neat in some sense, though it might be unexpected.
fp32_model.load_state_dict(int8_state_dict, assign=True) then the model becomes int8
Yeah, I understand the surprise, although you could also argue this is an advantage of using tensor subclasses: we don't need to touch the model structure at all. This is possible because the model definition is unchanged; we only changed the "dtype" (in a generalized sense) of the weight tensors. It's like asking: is it surprising if, when I do fp32_model.load_state_dict(fp16_state_dict, assign=True), I load fp16 weights into my model and now my model is an fp16 model?
If a "dtype" model is defined as a model whose weights all have that "dtype", this seems like a reasonable transformation.
I feel the surprise might come from the fact that traditionally we have treated quantized models as a separate concept (category) from fp32 models, while now we are kind of blurring the boundary with the tensor subclass implementation. But please let me know if you have further thoughts on this.
@jerryzh168 That makes sense to me! It's indeed a shift in the way of thinking about model loading, so it will definitely take some time to get used to. I think we just need to be very clear to users about the behavior of .load_state_dict(assign=True), which is different from .load_state_dict(assign=False), especially when a tensor subclass is used.
Another question, about backward compatibility: does this mean that to load weights saved as a tensor subclass, the installed implementation of the tensor subclass must be compatible with the saved weights? (And also the imports? I'm not 100% sure how pickle works.) We might need a mechanism to load tensor subclass weights when the tensor subclass definition is not available (the library is not installed, or we make breaking changes), if such a mechanism does not exist yet.
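On the pickle part of the question: pickle stores a reference to a class (its module and qualified name), not the class definition, so the defining module must be importable at load time. The sketch below shows this with the standard library only. `toy_quant_lib` and `QuantizedWeight` are hypothetical names standing in for a library that defines a tensor subclass; this says nothing about how torchao or torch.load handle the case in practice.

```python
import pickle
import sys
import types

# Hypothetical stand-in for an installed library that defines a tensor subclass.
toy_lib = types.ModuleType("toy_quant_lib")


class QuantizedWeight:
    def __init__(self, int_data, scale):
        self.int_data = int_data
        self.scale = scale


# Register the class under the fake library so pickle records
# "toy_quant_lib.QuantizedWeight" as its import path.
QuantizedWeight.__module__ = "toy_quant_lib"
QuantizedWeight.__qualname__ = "QuantizedWeight"
toy_lib.QuantizedWeight = QuantizedWeight
sys.modules["toy_quant_lib"] = toy_lib

payload = pickle.dumps(QuantizedWeight([12, -7, 3], 0.05))

# While the "library" is importable, loading works.
restored = pickle.loads(payload)
restored_ok = isinstance(restored, QuantizedWeight)

# Simulate a fresh environment where the library is not installed.
del sys.modules["toy_quant_lib"]
try:
    pickle.loads(payload)
    load_error = None
except ImportError as exc:
    load_error = exc

print(restored_ok, type(load_error).__name__)
```

The same reasoning covers breaking changes: if the class still imports but its fields or constructor changed, the payload may load into an object that no longer matches what the new code expects, which is why some versioning or fallback mechanism would be needed.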
Right now ao works just fine to quantize an arbitrary HF model.
However, this simple workflow is failing, meaning we don't really interop well with the rest of the HF ecosystem.