mit-han-lab / tinyengine

[NeurIPS 2020] MCUNet: Tiny Deep Learning on IoT Devices; [NeurIPS 2021] MCUNetV2: Memory-Efficient Patch-based Inference for Tiny Deep Learning; [NeurIPS 2022] MCUNetV3: On-Device Training Under 256KB Memory
https://mcunet.mit.edu
MIT License

Torch->TFlite Converter? #6

Closed · travisjayday closed 2 years ago

travisjayday commented 2 years ago

Your example .tflite files in the /assets folder seem to have been generated by a custom tool. At least, their description field in the binary is TinyNeuralNetwork Converted. instead of the standard MLIR Converted. or TOCO Converted. produced by TensorFlow's tf.lite.TFLiteConverter. Is this correct?
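
(For reference, one way to inspect this description field is with the flatbuffers bindings from the `tflite` package on PyPI; a minimal sketch, with an illustrative file path:)

import tflite  # pip install tflite (flatbuffers bindings for the TFLite schema)

# Read the .tflite flatbuffer and print its description string.
with open('assets/model.tflite', 'rb') as f:  # illustrative path
    buf = f.read()
model = tflite.Model.GetRootAsModel(buf, 0)
desc = model.Description()
print(desc.decode() if desc else None)  # e.g. 'TinyNeuralNetwork Converted.'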

We're trying to convert our own Proxyless models but are having trouble doing so because of restricted op support in the code generator. Are there plans to open-source a torch->tflite converter?

In the original mcunet submodule (the old MCUNet repo), there is some TensorFlow 1.x code to convert a ProxylessNAS network to TFLite. Do you have updated code for this, and updated Proxyless models? This ties in with #5.

Thanks!

tonylins commented 2 years ago

Hi, thanks for reaching out. In the original version, we manually "rewrote" the model in TensorFlow and copied the weights over. It works, but the manual conversion is quite complicated, so we later switched to a third-party conversion tool that we find quite handy: https://github.com/alibaba/TinyNeuralNetwork
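
(For context, the manual approach boils down to re-creating each layer in TensorFlow and copying the weights over, minding the layout difference: PyTorch stores conv kernels as OIHW while TensorFlow expects HWIO. A minimal sketch with a hypothetical layer:)

import torch
import tensorflow as tf

# Hypothetical example: copy one PyTorch conv layer into an equivalent Keras layer.
torch_conv = torch.nn.Conv2d(3, 16, kernel_size=3, bias=True)
tf_conv = tf.keras.layers.Conv2D(16, 3, use_bias=True)
tf_conv.build((None, 32, 32, 3))  # NHWC input shape

# Reorder the kernel from (out, in, kH, kW) to (kH, kW, in, out) before copying.
kernel = torch_conv.weight.detach().permute(2, 3, 1, 0).numpy()
bias = torch_conv.bias.detach().numpy()
tf_conv.set_weights([kernel, bias])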

To export to TFLite with the tool, you can use the following code snippet:

import os
import argparse

from tqdm import tqdm

import torch
import torch.backends.cudnn as cudnn
import torch.nn.functional as F
import torch.utils.data
from torchvision import datasets, transforms

from mcunet.model_zoo import build_model
from mcunet.utils import AverageMeter, accuracy, count_net_flops, count_parameters

# Evaluation / export settings
parser = argparse.ArgumentParser()
# net setting
parser.add_argument('--net_id', type=str, help='net id of the model')
# data loader setting
parser.add_argument('--dataset', default='imagenet', type=str, choices=['imagenet', 'vww'])
parser.add_argument('--data-dir', default=os.path.expanduser('/dataset/imagenet/val'),
                    help='path to the validation data')
parser.add_argument('--batch-size', type=int, default=128,
                    help='input batch size for validation and calibration')
parser.add_argument('-j', '--workers', default=8, type=int, metavar='N',
                    help='number of data loading workers')

args = parser.parse_args()

torch.backends.cudnn.benchmark = True
device = 'cuda'

def build_val_data_loader(resolution):
    normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    kwargs = {'num_workers': args.workers, 'pin_memory': True}

    if args.dataset == 'imagenet':
        val_transform = transforms.Compose([
            transforms.Resize(int(resolution * 256 / 224)),
            transforms.CenterCrop(resolution),
            transforms.ToTensor(),
            normalize
        ])
    elif args.dataset == 'vww':
        val_transform = transforms.Compose([
            transforms.Resize((resolution, resolution)),  # if center crop, the person might be excluded
            transforms.ToTensor(),
            normalize
        ])
    else:
        raise NotImplementedError
    val_dataset = datasets.ImageFolder(args.data_dir, transform=val_transform)
    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, **kwargs)
    return val_loader

def validate(model, val_loader):
    model.eval()
    val_loss = AverageMeter()
    val_top1 = AverageMeter()

    with tqdm(total=len(val_loader), desc='Validate') as t:
        with torch.no_grad():
            for data, target in val_loader:
                data, target = data.to(device), target.to(device)

                output = model(data)
                val_loss.update(F.cross_entropy(output, target).item())
                top1 = accuracy(output, target, topk=(1,))[0]
                val_top1.update(top1.item(), n=data.shape[0])
                t.set_postfix({'loss': val_loss.avg,
                               'top1': val_top1.avg})
                t.update(1)

    return val_top1.avg

def main():
    model, resolution, description = build_model(args.net_id, pretrained=True)
    model = model.to(device)
    model.eval()
    val_loader = build_val_data_loader(resolution)

    from tinynn.converter import TFLiteConverter
    from tinynn.graph.quantization.quantizer import PostQuantizer
    from tinynn.util.train_util import DLContext
    from tinynn.util.cifar10 import calibrate

    # Provide a viable input for the model
    dummy_input = torch.rand((1, 3, resolution, resolution))

    # Set up asymmetric, per-channel post-training quantization
    quantizer = PostQuantizer(model, dummy_input, work_dir='ptq', config={'asymmetric': True, 'per_tensor': False})
    ptq_model = quantizer.quantize()

    # Move model to the appropriate device
    ptq_model.to(device)

    context = DLContext()
    context.device = device

    # Use the validation loader for calibration; cap at 100 batches
    context.train_loader = context.val_loader = val_loader
    context.max_iteration = 100

    # Post quantization calibration
    calibrate(ptq_model, context)

    with torch.no_grad():
        ptq_model.eval()
        ptq_model.cpu()
        # Materialize the quantized modules, then export an int8 TFLite model
        ptq_model = torch.quantization.convert(ptq_model)
        torch.backends.quantized.engine = quantizer.backend
        converter = TFLiteConverter(ptq_model, dummy_input, tflite_path='ptq/ptq.tflite',
                                    quantize_target_type='int8')
        converter.convert()

    # Profile the floating-point model (MACs and parameter count)
    total_macs = count_net_flops(model, [1, 3, resolution, resolution])
    total_params = count_parameters(model)
    print(' * FLOPs: {:.4}M, param: {:.4}M'.format(total_macs / 1e6, total_params / 1e6))

    acc = validate(model, val_loader)  # accuracy of the floating-point model
    print(' * Accuracy: {:.2f}%'.format(acc))

if __name__ == '__main__':
    main()
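
(For reference, the script above would be invoked along these lines; the script name and net_id are illustrative, see the MCUNet model zoo for the valid IDs:)

python export_tflite.py --net_id mcunet-in3 --dataset imagenet --data-dir /dataset/imagenet/val

The exported ptq/ptq.tflite should then be usable with the code generator, subject to its op support.
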
travisjayday commented 2 years ago

Thanks for the example code!