Hello! I'm trying hard to convert a ViT model to int8 using MQBench.
When I applied the AdaRound config to a ViT (from the timm library), it raised an error related to the use of nn.Parameter.
The error occurs after calibration, in the reconstruction_ptq step.
More specifically, it occurs during the fx.GraphModule conversion step, in the call to extract_subgraph().
ViT usually holds cls_token and pos_embed as bare nn.Parameter attributes, not as nn.Module submodules.
How can I avoid or solve this?
The ResNet-18 model in imagenet_example contains only nn.Module layers, so it didn't raise this error.
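To see why bare nn.Parameter attributes look different from layers after FX tracing, here is a minimal sketch using plain torch.fx (not MQBench). TinyViTEmbed is a hypothetical toy module, not timm's code; it just stores cls_token and pos_embed the way timm's ViT does. Tracing it shows those two attributes becoming get_attr nodes, while the linear layer becomes a call_module node. That extract_subgraph() is tripping over these get_attr nodes is my assumption, not something I have confirmed in MQBench's source.

```python
import torch
import torch.nn as nn
import torch.fx as fx


class TinyViTEmbed(nn.Module):
    """Hypothetical toy stand-in for a ViT embedding stage (not timm's code)."""

    def __init__(self, embed_dim: int = 8, num_patches: int = 4):
        super().__init__()
        # Stands in for the patch-embedding layer: a real nn.Module.
        self.proj = nn.Linear(16, embed_dim)
        # Bare nn.Parameter attributes, like cls_token / pos_embed in a ViT.
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))

    def forward(self, x):  # x: (B, num_patches, 16)
        x = self.proj(x)
        cls = self.cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls, x), dim=1)
        return x + self.pos_embed


gm = fx.symbolic_trace(TinyViTEmbed())
for node in gm.graph.nodes:
    # proj shows up as call_module; cls_token and pos_embed show up as get_attr.
    print(f"{node.op:14} {node.target}")
```

In ResNet-18 every weight lives inside a layer such as nn.Conv2d or nn.Linear, so its traced graph has no get_attr nodes of this kind.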
I tried wrapping the nn.Parameter in an nn.Module by creating a new nn.Module class in the timm library, like this:
original code: self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None
modified code: self.cls_token = ParameterWrapper(torch.zeros(1, 1, embed_dim)) if class_token else None
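One caveat with the wrapper approach: under default torch.fx tracing, a custom nn.Module is not treated as a leaf, so the tracer inlines its forward() and the parameter still appears as a get_attr node (now named cls_token.weight). The sketch below uses plain torch.fx, not MQBench, and a hypothetical ParameterWrapper (the post doesn't show its definition). It compares the default trace with a custom fx.Tracer whose is_leaf_module() keeps the wrapper as a single call_module node. Whether MQBench's tracing step lets you register such leaf modules, and whether extract_subgraph() then accepts the graph, is an assumption I have not verified.

```python
import torch
import torch.nn as nn
import torch.fx as fx


class ParameterWrapper(nn.Module):
    """Hypothetical wrapper: holds one nn.Parameter and returns it from forward()."""

    def __init__(self, data: torch.Tensor):
        super().__init__()
        self.weight = nn.Parameter(data)

    def forward(self) -> torch.Tensor:
        return self.weight


class TinyEmbed(nn.Module):
    """Hypothetical toy module using the wrapper for its class token."""

    def __init__(self, embed_dim: int = 8):
        super().__init__()
        self.cls_token = ParameterWrapper(torch.zeros(1, 1, embed_dim))

    def forward(self, x):  # x: (B, N, embed_dim)
        # self.cls_token is no longer a tensor, so it must be *called* here.
        cls = self.cls_token().expand(x.shape[0], -1, -1)
        return torch.cat((cls, x), dim=1)


class WrapperLeafTracer(fx.Tracer):
    """Tracer that keeps ParameterWrapper opaque instead of inlining its forward()."""

    def is_leaf_module(self, m: nn.Module, module_qualified_name: str) -> bool:
        if isinstance(m, ParameterWrapper):
            return True
        return super().is_leaf_module(m, module_qualified_name)


model = TinyEmbed()

# Default tracing: the wrapper is inlined, leaving a get_attr for cls_token.weight.
default_ops = [(n.op, str(n.target)) for n in fx.symbolic_trace(model).graph.nodes]

# Custom tracing: the wrapper stays a call_module node named cls_token.
leaf_graph = WrapperLeafTracer().trace(model)  # Tracer.trace returns an fx.Graph
leaf_gm = fx.GraphModule(model, leaf_graph)
leaf_ops = [(n.op, str(n.target)) for n in leaf_graph.nodes]

print(default_ops)
print(leaf_ops)
```

The design choice here is to change how the graph is traced rather than how the model stores its parameters; the wrapper class alone only renames the get_attr target.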