alibaba / TinyNeuralNetwork

TinyNeuralNetwork is an efficient and easy-to-use deep learning model compression framework.
MIT License

How to quantize ViT model with quantization aware training #374

Open Linsop2 opened 3 days ago

Linsop2 commented 3 days ago

I can train the ViT model from Hugging Face Transformers, but when converting it to a TFLite model, an error appears that I can't resolve. Below are my tinynn settings and the error message.

My Transformers version is 4.26.0.

import torch
import torch.nn as nn
import torchvision.transforms as transforms

from tinynn.graph.tracer import import_patcher
from tinynn.converter import TFLiteConverter
from tinynn.graph.quantization.quantizer import QATQuantizer

with import_patcher():
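    # importing transformers under import_patcher() lets TinyNN's tracer hook the
    # library's ops during import so the model can be captured later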
    from transformers import ViTForImageClassification

device = 'cuda'

transform = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

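# Wrapper so forward() returns the logits tensor rather than a ModelOutput object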
class ViTWrapper(nn.Module):
    def __init__(self, vit):
        super().__init__()
        self.vit = vit

    def forward(self, x):
        return self.vit(x).logits

model_url = 'google/vit-base-patch16-224'

with import_patcher():
    model = ViTForImageClassification.from_pretrained(model_url)
    Vit = ViTWrapper(model)

################ QAT SET #######################
dummy = torch.rand([1,3,224,224])
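# quantized_input_stats takes the (mean, std) of the input in the 0-255 domain;
# averaging the per-channel ImageNet stats and scaling by 255 gives roughly (114.5, 57.6)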
mean = (0.485 + 0.456 + 0.406) / 3 * 255 
std = (0.229 + 0.224 + 0.225) / 3 * 255
config = {
    'asymmetric': True,
    'per_tensor': False,
    'backend': 'fbgemm',
    'rewrite_graph': False,
    'quantized_input_stats': [(float(mean), float(std))],
}
quantizer = QATQuantizer(Vit, dummy, work_dir='out', config=config)
Vit = quantizer.quantize()
Vit.qconfig = torch.quantization.get_default_qat_qconfig("fbgemm")
Vit = torch.quantization.prepare_qat(Vit)
Vit.cuda()

##### Skip training step #####

##### Convert to TFLite step (this block sits inside the training loop) #####
    if epoch % 1 == 0:
        with torch.no_grad():
            device = 'cpu'
            Vit.eval()
            Vit.to(device)
            Vit = quantizer.convert(Vit)
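            # match the global quantized engine to the quantizer's backend before tracing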
            torch.backends.quantized.engine = quantizer.backend
            converter = TFLiteConverter(Vit, dummy, tflite_path='./qat_model.tflite', fuse_quant_dequant=True)
            converter.convert()

The error message:

Traceback (most recent call last):
  File "c:\Users\user\Desktop\ViT\tinynn_PTQ\QAT.py", line 100, in <module>
    converter.convert()
  File "C:\Users\user\anaconda3\envs\tf1\lib\site-packages\tinynn\converter\base.py", line 512, in convert
    self.init_jit_graph()
  File "C:\Users\user\anaconda3\envs\tf1\lib\site-packages\tinynn\converter\base.py", line 237, in init_jit_graph
    script = torch.jit.trace(self.model, self.dummy_input)
  File "C:\Users\user\anaconda3\envs\tf1\lib\site-packages\torch\jit\_trace.py", line 1000, in trace
    traced_func = _trace_impl(
  File "C:\Users\user\anaconda3\envs\tf1\lib\site-packages\torch\jit\_trace.py", line 695, in _trace_impl
    return trace_module(
  File "C:\Users\user\anaconda3\envs\tf1\lib\site-packages\torch\jit\_trace.py", line 1275, in trace_module
    module._c._create_method_from_trace(
  File "C:\Users\user\anaconda3\envs\tf1\lib\site-packages\torch\nn\modules\module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "C:\Users\user\anaconda3\envs\tf1\lib\site-packages\torch\nn\modules\module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "C:\Users\user\anaconda3\envs\tf1\lib\site-packages\torch\nn\modules\module.py", line 1543, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "c:\Users\user\Desktop\ViT\tinynn_PTQ\QAT.py", line 39, in forward
    return self.vit(x).logits
  File "C:\Users\user\anaconda3\envs\tf1\lib\site-packages\torch\nn\modules\module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "C:\Users\user\anaconda3\envs\tf1\lib\site-packages\torch\nn\modules\module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "C:\Users\user\anaconda3\envs\tf1\lib\site-packages\torch\nn\modules\module.py", line 1543, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "C:\Users\user\anaconda3\envs\tf1\lib\site-packages\transformers\models\vit\modeling_vit.py", line 789, in forward
    outputs = self.vit(
  File "C:\Users\user\anaconda3\envs\tf1\lib\site-packages\torch\nn\modules\module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "C:\Users\user\anaconda3\envs\tf1\lib\site-packages\torch\nn\modules\module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "C:\Users\user\anaconda3\envs\tf1\lib\site-packages\torch\nn\modules\module.py", line 1543, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "C:\Users\user\anaconda3\envs\tf1\lib\site-packages\transformers\models\vit\modeling_vit.py", line 571, in forward
    expected_dtype = self.embeddings.patch_embeddings.projection.weight.dtype
AttributeError: 'function' object has no attribute 'dtype'

peterjc123 commented 3 days ago

Just noticed that you are not using TinyNN's quantized graph rewrite, as I can see "rewrite_graph": False and torch.quantization.prepare_qat in your code. Just FYI, the TFLiteConverter won't work in most cases if you don't use our quantization tools.
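
For reference, a minimal sketch of the intended flow (based on the snippets in this thread; quantizer.quantize() already prepares the model for QAT, so no manual prepare_qat call is needed):

quantizer = QATQuantizer(model, dummy, work_dir='out',
                         config={'asymmetric': True, 'per_tensor': False, 'backend': 'fbgemm'})
qat_model = quantizer.quantize()  # rewritten model with fake-quant ops inserted

# ... run the QAT training loop on qat_model ...

qat_model.cpu().eval()
quantized = quantizer.convert(qat_model)             # produce the truly quantized model
torch.backends.quantized.engine = quantizer.backend  # match the backend before tracing
TFLiteConverter(quantized, dummy, tflite_path='./qat_model.tflite').convert()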

Linsop2 commented 2 days ago

I modified the QAT settings, but another error appears:

Traceback (most recent call last):
  File "c:\Users\user\Desktop\ViT\QAT.py", line 91, in <module>
  File "C:\Users\user\anaconda3\envs\tf1\lib\site-packages\tinynn\converter\base.py", line 513, in convert
    self.init_lowered_module()
  File "C:\Users\user\anaconda3\envs\tf1\lib\site-packages\tinynn\converter\base.py", line 299, in init_lowered_module
    torch._C._jit_pass_lower_all_tuples(graph)
RuntimeError: prim::TupleUnpack not matched to tuple construct

This is my whole code:

import torch
import torch.nn as nn
import torchvision.datasets as dsets
import torchvision.transforms as transforms
from tqdm import tqdm
from tinynn.graph.tracer import import_patcher
from tinynn.converter import TFLiteConverter
from tinynn.graph.quantization.quantizer import QATQuantizer
from torch.optim.lr_scheduler import StepLR

with import_patcher():
    from transformers import ViTForImageClassification

BATCH_SIZE = 16
LEARNING_RATE = 0.0003
EPOCH = 1
N_CLASSES = 1000

device = 'cuda'
transform = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor()
])

trainData = dsets.ImageFolder(r"C:\Users\user\Desktop\dataset\train", transform)
testData =  dsets.ImageFolder(r"C:\Users\user\Desktop\dataset\test", transform)
trainLoader = torch.utils.data.DataLoader(dataset=trainData, batch_size=BATCH_SIZE, shuffle=True)
testLoader  = torch.utils.data.DataLoader(dataset=testData, batch_size=BATCH_SIZE, shuffle=False)

class ViTWrapper(nn.Module):
    def __init__(self, vit):
        super().__init__()
        self.vit = vit

    def forward(self, x):
        return self.vit(x).logits

model_url = 'google/vit-base-patch16-224'

with import_patcher():
    model = ViTForImageClassification.from_pretrained(model_url)
    Vit = ViTWrapper(model)

print(f'Learning rate is : {LEARNING_RATE}')
print(f'Batch size is : {BATCH_SIZE}')

################ QAT SET #######################
dummy = torch.rand([1,3,224,224])
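# rewrite_graph=True lets TinyNN rewrite the model graph for quantization, as advised above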
config = {'asymmetric': True, 'per_tensor': False, 'backend': 'fbgemm', 'rewrite_graph': True}
quantizer = QATQuantizer(Vit, dummy, work_dir='out', config=config)
Vit = quantizer.quantize()
Vit.cuda()
################ QAT SET #######################

################ Training #######################
# Loss, Optimizer & Scheduler
cost = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(Vit.parameters(), lr=LEARNING_RATE, momentum=0.9, weight_decay=5e-4)
scheduler = StepLR(optimizer, step_size=1, gamma=0.1)

# Train the model
for epoch in range(1,EPOCH+1):
    Vit.train()
    avg_loss = 0
    cnt = 0
    progress_bar = tqdm(trainLoader, desc=f"Epoch {epoch}/{EPOCH}", leave=False)
    for images, labels in progress_bar:
        images = images.cuda()
        labels = labels.cuda()
        optimizer.zero_grad()
        outputs = Vit(images)
        loss = cost(outputs, labels)
        avg_loss += loss.item()
        cnt += 1
        loss.backward()
        optimizer.step()
        progress_bar.set_postfix({'loss': f"{loss.item():.4f}", 'avg_loss': f"{avg_loss/cnt:.4f}"})

    print(f"[Epoch: {epoch}] avg_loss: {avg_loss/cnt:.4f}")
    scheduler.step()
    torch.save(Vit.state_dict(), f'./epoch_qat{epoch}.pth')

################### Convert model #########################
    if epoch % EPOCH == 0:
        with torch.no_grad():
            Vit.eval()
            Vit.cpu()
            Vit = quantizer.convert(Vit)
            torch.backends.quantized.engine = quantizer.backend
            converter = TFLiteConverter(Vit, dummy, tflite_path='./qat_model.tflite')
            converter.convert()

peterjc123 commented 2 days ago

We will take a look.