xinghaochen / TinySAM

Official PyTorch implementation of "TinySAM: Pushing the Envelope for Efficient Segment Anything Model"
Apache License 2.0
403 stars 23 forks

how to quantize the lightweight SAM model? #24

Open ranpin opened 6 months ago

ranpin commented 6 months ago

Hi, nice work! I'm trying to apply your method in an application and have some questions about the quantization.

I have looked carefully at the code in demo_quan.py and layer.py, but the model in demo_quan.py is loaded directly from already-quantized weights. How can I apply your quantization method to an instance created from an existing pre-trained SAM model?

Since it is not clear how you quantized your lightweight SAM model with the quantization layers in layer.py, could you provide a reference example of how the quantization was done? Thank you very much!

Here is the demo I wrote. It runs without errors, but the test results after quantization are close to 0. Does the model need retraining, or am I approaching this incorrectly? Hoping for your reply!

import torch
import torch.nn as nn

# Assuming the original SAM package here, since the checkpoint below is the
# official sam_vit_b one.
from segment_anything import sam_model_registry, SamPredictor
from quantization_layer.layers import InferQuantConv2d, InferQuantConvTranspose2d

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model_type = 'vit_b'
checkpoint = 'checkpoints/sam_vit_b_01ec64.pth'
model = sam_model_registry[model_type](checkpoint=checkpoint)
model.to(device)
model.eval()
predictor = SamPredictor(model)

w_bit = 8
a_bit = 8
input_size = (1, 3, 1024, 1024)  # (N, C, H, W)
n_V = input_size[2]
n_H = input_size[3]
# Quantization parameters, currently hard-coded identically for all layers
a_interval = torch.tensor(0.1)
a_bias = torch.tensor(0.0)
w_interval = torch.tensor(0.01)

# Replace the Conv2d and ConvTranspose2d layers in the model with quantized versions
def replace_with_quantized_layers(model):
    layers_to_replace = []
    for name, module in model.named_modules():
        if isinstance(module, nn.Conv2d) or isinstance(module, nn.ConvTranspose2d):
            layers_to_replace.append((name, module))
    for name, module in layers_to_replace:
        if isinstance(module, nn.Conv2d):
            quantized_module = InferQuantConv2d(
                in_channels=module.in_channels,
                out_channels=module.out_channels,
                kernel_size=module.kernel_size,
                stride=module.stride,
                padding=module.padding,
                dilation=module.dilation,
                groups=module.groups,
                bias=module.bias is not None,
                mode='quant_forward',
                w_bit=w_bit,
                a_bit=a_bit
            )
            quantized_module.get_parameter(n_V=n_V, 
                      n_H=n_H,
                      a_interval=a_interval,
                      a_bias=a_bias,
                      w_interval=w_interval)
        elif isinstance(module, nn.ConvTranspose2d):
            quantized_module = InferQuantConvTranspose2d(
                in_channels=module.in_channels,
                out_channels=module.out_channels,
                kernel_size=module.kernel_size,
                stride=module.stride,
                padding=module.padding,
                output_padding=module.output_padding,
                groups=module.groups,
                bias=module.bias is not None,
                mode='quant_forward',
                w_bit=w_bit,
                a_bit=a_bit
            )
            quantized_module.get_parameter(n_V=n_V,
                              n_H=n_H,
                              a_interval=a_interval,
                              a_bias=a_bias,
                              w_interval=w_interval)
        # Copy the pretrained weights into the new module (assuming the
        # InferQuant* layers subclass the corresponding nn modules)
        quantized_module.weight.data.copy_(module.weight.data)
        if module.bias is not None:
            quantized_module.bias.data.copy_(module.bias.data)
        # setattr with a dotted name does not replace a nested submodule,
        # so walk down to the direct parent module first
        parent = model
        *path, attr = name.split('.')
        for p in path:
            parent = getattr(parent, p)
        setattr(parent, attr, quantized_module)
    return model

quan_model = replace_with_quantized_layers(model).to(device)  # newly created layers start on CPU
print(quan_model)
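
For reference, here is how I am currently thinking about calibrating a_interval instead of hard-coding it: a minimal min/max sketch using forward hooks, run on the float model before the replacement above. This is my own guess, not taken from your repo; estimate_act_intervals and calib_batch are made-up names, and I only exercise the image encoder here.

import torch
import torch.nn as nn

@torch.no_grad()
def estimate_act_intervals(model, calib_batch, a_bit=8):
    # Record the min/max input seen by every Conv2d / ConvTranspose2d during
    # one forward pass, then map that range onto 2**a_bit uniform levels.
    # A crude min/max scheme, just to replace the hard-coded 0.1.
    stats, hooks = {}, []

    def make_hook(name):
        def hook(module, inputs, output):
            x = inputs[0]
            lo, hi = x.min().item(), x.max().item()
            if name in stats:
                lo, hi = min(lo, stats[name][0]), max(hi, stats[name][1])
            stats[name] = (lo, hi)
        return hook

    for name, module in model.named_modules():
        if isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)):
            hooks.append(module.register_forward_hook(make_hook(name)))

    # calib_batch: a preprocessed (1, 3, 1024, 1024) tensor on the model's
    # device. Only the image encoder is exercised here; the decoder's layers
    # would need their own calibration pass with prompts.
    model.image_encoder(calib_batch)
    for h in hooks:
        h.remove()

    # One uniform step so that [lo, hi] spans all 2**a_bit levels
    return {name: (hi - lo) / (2 ** a_bit - 1) for name, (lo, hi) in stats.items()}

With something like this, each layer would get its own a_interval (and perhaps a_bias from the observed minimum), and w_interval could be derived the same way from each layer's weight tensor. Is that roughly what you did for the released quantized weights, or is quantization-aware retraining also needed to recover accuracy?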