zhutmost / lsq-net

Unofficial implementation of LSQ-Net, a neural network quantization framework
MIT License

Quantized checkpoint is not smaller / faster #18

Closed sophia1488 closed 2 years ago

sophia1488 commented 2 years ago

Hi, first, thanks for your implementation!
It wasn't hard to apply your code to my current model! However, the dumped checkpoint is the same size as the original model's, and I wonder how to store it in less space.

Thank you and hope to get your feedback.

zhutmost commented 2 years ago

Hi Sophia, thank you for using my code. The size is unchanged because I dump the floating-point weights (rather than the quantized ones) into the checkpoint. Quantization runs on the fly, generating the quantized weights from them.

If you want to save the quantized weights, you can register them as buffers (as BatchNorm does with its running statistics), and then they will also be saved into the checkpoint. If you do so, you will find both the floating-point and the quantized weights in your checkpoint. Then you can remove the floating-point ones from it manually.
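
A minimal sketch of the buffer approach described above, not the repo's actual code. The class `BufferedQuanConv2d`, its `refresh_buffers` method, and the max-abs scale are assumptions made for illustration; LSQ itself would use its learned step size in place of that scale.

```python
import torch
import torch.nn as nn


class BufferedQuanConv2d(nn.Conv2d):
    """Hypothetical layer that keeps an int8 copy of its weight as a buffer."""

    def __init__(self, *args, bits=8, **kwargs):
        super().__init__(*args, **kwargs)
        self.q_min = -(2 ** (bits - 1))
        self.q_max = 2 ** (bits - 1) - 1
        # Buffers are saved by state_dict() but are not trained.
        self.register_buffer(
            "quantized_weight", torch.zeros_like(self.weight, dtype=torch.int8)
        )
        self.register_buffer("w_scale", torch.ones(()))

    @torch.no_grad()
    def refresh_buffers(self):
        # Assumption: a max-abs scale stands in for the learned LSQ step size.
        scale = self.weight.abs().max().clamp(min=1e-8) / self.q_max
        levels = torch.clamp(torch.round(self.weight / scale), self.q_min, self.q_max)
        self.quantized_weight.copy_(levels.to(torch.int8))
        self.w_scale.copy_(scale)


layer = BufferedQuanConv2d(3, 8, kernel_size=3)
layer.refresh_buffers()

state = layer.state_dict()
state.pop("weight")  # drop the floating-point copy manually
print(sorted(state))
```

After the `pop`, the remaining entries are the bias plus the two buffers, so the saved file no longer carries the float32 weight.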

zhutmost commented 2 years ago

I am not very familiar with deployment tools. I guess there are more convenient ways to generate a truly quantized model from its pretrained weights (I mean one where all weights are integers; right now it is only fake-quantized).

sophia1488 commented 2 years ago

Hi, thanks for your quick reply! I modified the code and it worked!

For people who want to save quantized weights: the modification is that the LSQ module returns x after round_pass together with s_scale, and I dump these two tensors instead of self.weight and quan_w_fn.s of QuanConv2d. https://github.com/zhutmost/lsq-net/blob/2c24a96be06d044fa4c7d651727f4574b8d88c86/quan/quantizer/lsq.py#L56 I also implemented functions for dumping and loading the quantized checkpoint.

In the end, the saved checkpoint is more than 70% smaller, and I can still load it for evaluation / inference. 😀

Thanks again ~ (I could also share my implementation if needed.)
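
As a rough sanity check on a reduction of that size, here is a back-of-the-envelope sketch of the per-layer arithmetic, assuming 8-bit weights stored as int8 with one float32 scale per layer. The layer shape and the helper `layer_bytes` are hypothetical.

```python
import struct

FP32_BYTES = struct.calcsize("<f")  # 4 bytes per float32 value
INT8_BYTES = struct.calcsize("<b")  # 1 byte per int8 value


def layer_bytes(num_weights, quantized):
    """Bytes needed for one layer's weights plus a single float32 scale."""
    per_weight = INT8_BYTES if quantized else FP32_BYTES
    return num_weights * per_weight + FP32_BYTES


# Hypothetical 3x3 convolution with 64 input and 64 output channels.
num_weights = 64 * 64 * 3 * 3
before = layer_bytes(num_weights, quantized=False)
after = layer_bytes(num_weights, quantized=True)
saving = 1 - after / before
print(f"{before} bytes -> {after} bytes ({saving:.1%} smaller)")
```

Each quantized layer shrinks by close to 75%. Anything left in float32 (biases, BatchNorm statistics, layers that were not quantized) pulls the whole-checkpoint figure below that.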

zhutmost commented 2 years ago

Great. Thank you, Sophia. I'm still struggling with my thesis. Maybe I will consider this feature in a few weeks (and also fix some known compatibility bugs).

zhutmost commented 2 years ago

Actually, a newer version is maintained in NeuralZip, and you can find more experiment results there. But it is developed with the PyTorch-Lightning framework rather than raw PyTorch.

sophia1488 commented 2 years ago

Hi @zhutmost, the accuracy of the quantized model is good! But I found that inference is slower after quantization. Do you have any comments on this? Below is the time it takes for inference. (Note that the inference time does not include checkpoint loading, and I use the same code for both.)

Thanks! 🙏

zhutmost commented 2 years ago

I have no idea. I cannot explain it without more details.

In my code, a validation/test epoch (which is just like inference) is much faster than a training epoch.

sophia1488 commented 2 years ago

Hi, thanks for your quick reply. Below is the code I modified from your repo. Due to the characteristics of my model, I didn't modify the code related to QuanLinear.

When dumping checkpoint,

for name, module in model.named_modules():
    if isinstance(module, QuanConv2d):
        # 1. save self.quantized_weight (convert to type int8), self.w_scale
        state_dict[f'{name}.quantized_weight'] = module.quantized_weight.to(torch.int8)
        state_dict[f'{name}.w_scale'] = module.w_scale
        # 2. remove quan_w_fn.s & weight, since there's no need to use them for inference.
        state_dict.pop(f'{name}.quan_w_fn.s', None)
        state_dict.pop(f'{name}.weight', None)

When loading checkpoint,

for name in quantized_weights:
    # so that QuanConv2d.weight will be updated calling model.load_state_dict
    state_dict[f'{name}.weight'] = state_dict[f'{name}.w_scale'] * state_dict[f'{name}.quantized_weight']   # back to float32
    state_dict.pop(f'{name}.w_scale', None)
    state_dict.pop(f'{name}.quantized_weight', None)

In class QuanConv2d,

class QuanConv2d(nn.Conv2d):
    def __init__(self, ..., inference=False):
        ...
        self.inference = inference
        ...
    def forward(self, x):
        # quantized input
        quantized_act, a_scale = self.quan_a_fn(x)
        act = quantized_act * a_scale
        # quantized weight
        if not self.inference:
            self.quantized_weight, self.w_scale = self.quan_w_fn(self.weight)
            weight = self.quantized_weight * self.w_scale
        else:
            weight = self.weight    # the saved quantized weight (in float32)
        return self._conv_forward(act, weight)

I don't understand why it'll get slower either :( Thank you.
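
One possible explanation, offered as a guess rather than a diagnosis of the code above: the checkpoint holds int8 values, but the loader multiplies them back into float32, so the convolution runs on float32 tensors exactly as before, while the activation quantizer adds divide, round, clamp, and multiply steps to every forward pass. The sketch below only illustrates the dtype round trip. `fake_quantize` is a hypothetical stand-in for the repo's quantizer, and the max-abs scale is an assumption.

```python
import torch


def fake_quantize(x, scale, q_min=-128, q_max=127):
    """Hypothetical LSQ-style quantizer: returns integer levels (as floats) and the scale."""
    levels = torch.clamp(torch.round(x / scale), q_min, q_max)
    return levels, scale


weight = torch.randn(8, 3, 3, 3)
scale = weight.abs().max() / 127  # assumption: max-abs scale instead of a learned step

levels, s = fake_quantize(weight, scale)
stored = levels.to(torch.int8)  # what the dumped checkpoint holds
restored = s * stored           # what the loader rebuilds for QuanConv2d.weight

print(stored.dtype, restored.dtype)
print(torch.allclose(restored, levels * s))
```

Because `restored` is float32, a standard `nn.Conv2d` forward does the same amount of floating-point work as the unquantized model. A real speedup would need integer kernels, which is what the deployment tools mentioned earlier in the thread are for.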

784582008 commented 2 years ago

Hi, Sophia. Is your problem solved?

sophia1488 commented 2 years ago

Honestly, it's still slow but I'm not dealing with this problem now. Thanks!

784582008 commented 2 years ago

Hi, thanks for your quick reply. Have you found a better way since then?

BYFgithub commented 1 year ago

Hello, can you share this part of the code?