Hi Sophia, thank you for using my code. That's because I dump the floating-point weights (rather than the quantized ones) into the checkpoint; the quantization runs on the fly and generates the quantized weights.
If you want to save the quantized weights, you can declare them as buffers (like BatchNorm does), and then they will also be saved into the checkpoint. If you do so, you will find both the floating-point and the quantized weights in your checkpoint, and you can then manually remove the floating-point ones.
I am not very familiar with deployment tools. I guess there are more convenient ways to generate a truly quantized model (I mean one whose weights are all integers; right now it is only fake-quantized) together with its pretrained weights.
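For example, a minimal sketch of the buffer idea (simplified for illustration; the constructor and attribute names here are not exactly the ones in my repo):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class QuanConv2dWithBuffer(nn.Conv2d):
    def __init__(self, *args, quan_w_fn=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.quan_w_fn = quan_w_fn
        # A buffer (like BatchNorm's running_mean) is stored in state_dict,
        # but it is not a trainable Parameter.
        self.register_buffer('quantized_weight', torch.zeros_like(self.weight))

    def forward(self, x):
        # Quantize on the fly as before, but assign the result to the buffer
        # so that it also ends up in the checkpoint.
        self.quantized_weight = self.quan_w_fn(self.weight)
        return F.conv2d(x, self.quantized_weight, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)
```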
Hi, thanks for your quick reply! I modified the code and it worked!
For people who wanna save quantized weights, the modification is basically that the LSQ module returns `x` after `round_pass` together with `s_scale`, and I dump these two tensors instead of `self.weight` and `quan_w_fn.s` of `QuanConv2d`.
https://github.com/zhutmost/lsq-net/blob/2c24a96be06d044fa4c7d651727f4574b8d88c86/quan/quantizer/lsq.py#L56
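Roughly, the modified quantizer `forward` looks like this (a simplified sketch of my change: I've dropped the per-channel branch, and the variable names may differ slightly from the repo):

```python
def forward(self, x):
    s_grad_scale = 1.0 / ((self.thd_pos * x.numel()) ** 0.5)
    s_scale = grad_scale(self.s, s_grad_scale)
    x = torch.clamp(x / s_scale, self.thd_neg, self.thd_pos)
    x = round_pass(x)     # integer-valued tensor (still float dtype)
    return x, s_scale     # instead of returning x * s_scale
```

`QuanConv2d` then multiplies the two back together during training, so training behaves the same as before.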
And I implemented functions for dumping & loading the quantized checkpoint.
Finally, the saved checkpoint is reduced by more than 70%, and I could also load the checkpoint for evaluation / inference. 😀
Thanks again ~ (I could also share my implementation if needed.)
Great. Thank you, Sophia. I'm still struggling with my thesis. Maybe I will consider this feature in a few weeks (and also fix some known compatibility bugs).
Actually, a newer version is maintained in NeuralZip, and you can find more experimental results there. But it is developed with the PyTorch-Lightning framework, rather than raw PyTorch.
Hi @zhutmost, the performance of the quantized model is good! But I found that inference is slower after quantization; do you have any comments on this? Below is the time it takes for inference. (Note that the inference time does not include checkpoint loading, and I use the same code.)
Thanks! 🙏
I have no idea. I cannot explain it without more details.
In my code, the validation/test epoch (which is just like inference) is much faster than a training epoch.
Hi, thanks for your quick reply. Below is the code I modified from your repo. Due to the characteristics of my model, I didn't modify the code related to `QuanLinear`.
When dumping the checkpoint:
for name, module in model.named_modules():
    if type(module) == QuanConv2d:
        # 1. save self.quantized_weight (converted to int8) and self.w_scale
        state_dict[f'{name}.quantized_weight'] = module.quantized_weight.to(torch.int8)
        state_dict[f'{name}.w_scale'] = module.w_scale
        # 2. remove quan_w_fn.s & weight, since there's no need to use them for inference
        state_dict.pop(f'{name}.quan_w_fn.s', None)
        state_dict.pop(f'{name}.weight', None)
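(For reference, this loop modifies a copy of `model.state_dict()` in place, and the result is written out with `torch.save`; the wrapper function name and `filepath` argument below are just for illustration.)

```python
def dump_quantized_checkpoint(model, filepath):
    state_dict = model.state_dict()
    # ... the "for name, module in model.named_modules()" loop above goes here ...
    torch.save(state_dict, filepath)
```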
When loading the checkpoint:
for name in quantized_weights:
    # so that QuanConv2d.weight will be updated when calling model.load_state_dict
    state_dict[f'{name}.weight'] = state_dict[f'{name}.w_scale'] * state_dict[f'{name}.quantized_weight']  # back to float32
    state_dict.pop(f'{name}.w_scale', None)
    state_dict.pop(f'{name}.quantized_weight', None)
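(The loading side sits in a similar wrapper, sketched roughly below: `quantized_weights` is simply the list of module names that have a `.quantized_weight` entry in the checkpoint, and `strict=False` tolerates the `quan_w_fn.s` keys that were removed when dumping. The function name is again just for illustration.)

```python
def load_quantized_checkpoint(model, filepath):
    state_dict = torch.load(filepath, map_location='cpu')
    quantized_weights = [k[:-len('.quantized_weight')]
                         for k in list(state_dict) if k.endswith('.quantized_weight')]
    # ... the "for name in quantized_weights" loop above goes here ...
    model.load_state_dict(state_dict, strict=False)
```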
In class `QuanConv2d`:
class QuanConv2d(nn.Conv2d):
    def __init__(self, ..., inference=False):
        ...
        self.inference = inference
        ...

    def forward(self, x):
        # quantized input
        quantized_act, a_scale = self.quan_a_fn(x)
        act = quantized_act * a_scale
        # quantized weight
        if not self.inference:
            self.quantized_weight, self.w_scale = self.quan_w_fn(self.weight)
            weight = self.quantized_weight * self.w_scale
        else:
            weight = self.weight  # the saved quantized weight (in float32)
        return self._conv_forward(act, weight)
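At evaluation time I then build the model with `inference=True` and load the converted checkpoint, something like this (the builder function and file name are hypothetical; `load_quantized_checkpoint` refers to the sketch above):

```python
model = build_model(inference=True)   # hypothetical factory that passes the flag to QuanConv2d
load_quantized_checkpoint(model, 'checkpoint_quantized.pth')
model.eval()
with torch.no_grad():
    inputs = torch.randn(1, 3, 224, 224)   # example input shape
    outputs = model(inputs)
```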
I don't understand why it'll get slower either :( Thank you.
Hi Sophia, is your problem solved?
Honestly, it's still slow but I'm not dealing with this problem now. Thanks!
Hi, thanks for your quick reply. Do you have a better way now?
> Hi, thanks for your quick reply! I modified the code and it worked!
> For people who wanna save quantized weights, the modification is basically that the LSQ module returns `x` after `round_pass` together with `s_scale`, and I dump these two tensors instead of `self.weight` and `quan_w_fn.s` of `QuanConv2d`.
> And I implemented functions for dumping & loading the quantized checkpoint. Finally, the saved checkpoint is reduced by more than 70%, and I could also load the checkpoint for evaluation / inference. 😀
> Thanks again ~ (I could also share my implementation if needed.)
Hello, can you share this part of the code?
Hi, first, thanks for your implementation!
It's not too hard to apply your code to my current model! However, the dumped checkpoint has the same size as the original model, and I wonder how to store it using less space.
Thank you, and I hope to get your feedback.