ModelTC / MQBench

Model Quantization Benchmark
Apache License 2.0

How to apply advanced_ptq to ViT? #273

Open dedoogong opened 1 month ago

dedoogong commented 1 month ago

Hello! I'm trying to convert a ViT to int8 using MQBench. When I apply an AdaRound config to a ViT from the timm library, I get an error related to the use of nn.Parameter. The error occurs after calibration, in the ptq_reconstruction step; more precisely, it happens during the fx.GraphModule conversion when extract_subgraph() is called.
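For reference, this is roughly the flow I run, adapted from the imagenet PTQ example. The config values and model name here are just what I'm using locally and may not match the repo's defaults exactly:

    import torch
    import timm
    from easydict import EasyDict
    from mqbench.prepare_by_platform import prepare_by_platform, BackendType
    from mqbench.utils.state import enable_calibration, enable_quantization
    from mqbench.advanced_ptq import ptq_reconstruction

    # placeholder calibration batches (real data in my actual run)
    cali_data = [torch.randn(32, 3, 224, 224) for _ in range(16)]

    # build the fp32 ViT from timm and let MQBench trace it / insert fake-quant nodes
    model = timm.create_model('vit_base_patch16_224', pretrained=True).eval()
    model = prepare_by_platform(model, BackendType.Tensorrt)

    # calibration pass: collect activation statistics
    enable_calibration(model)
    with torch.no_grad():
        for batch in cali_data:
            model(batch)

    # AdaRound-style reconstruction -- this is where extract_subgraph() fails for me
    adaround_config = EasyDict({
        'pattern': 'block',
        'scale_lr': 4.0e-5,
        'warm_up': 0.2,
        'weight': 0.01,
        'max_count': 10000,
        'b_range': [20, 2],
        'keep_gpu': True,
        'round_mode': 'learned_hard_sigmoid',
        'prob': 1.0,
    })
    model = ptq_reconstruction(model, cali_data, adaround_config)

    # quantized inference / evaluation
    enable_quantization(model)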

A ViT typically holds cls_token and pos_embed as nn.Parameter attributes, not nn.Module submodules. How can I avoid or work around this case?

  1. The ResNet-18 model in imagenet_example contains only nn.Module submodules, so it does not hit this error.
  2. I tried wrapping the nn.Parameter in an nn.Module by adding a new class to the timm library, like this:

    
    class ParameterWrapper(nn.Module):
        def __init__(self, param):
            super(ParameterWrapper, self).__init__()
            self.param = nn.Parameter(param)

        def forward(self):
            return self.param

Original code: self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None

Modified code: self.cls_token = ParameterWrapper(torch.zeros(1, 1, embed_dim)) if class_token else None (the forward() call sites change accordingly; see the sketch below).
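With the wrapper, the code that used the parameter directly has to call the module instead, so that torch.fx records a call_module node rather than a get_attr node. Roughly what I changed (simplified from timm's VisionTransformer; I wrapped pos_embed the same way, and the exact method layout depends on the timm version):

    # inside timm's VisionTransformer (simplified)
    def _pos_embed(self, x):
        # was: self.cls_token.expand(x.shape[0], -1, -1)
        cls_tok = self.cls_token().expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tok, x), dim=1)
        # was: x = x + self.pos_embed
        x = x + self.pos_embed()
        return self.pos_drop(x)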


However, even with this change, it raises another error after extract_subgraph(), in the subgraph_reconstruction() step.

Please help me!