axinc-ai / ailia-models-tflite

Quantized version of model library
24 stars 2 forks source link

ADD vit #84

Open kyakuno opened 4 months ago

kyakuno commented 4 months ago

下記のvitをtfliteに変換する。 https://github.com/taki0112/vit-tensorflow

Kitazume-Ax commented 3 months ago

作業ブランチ: https://github.com/axinc-ai/ailia-models-tflite/tree/kitazume/add_vision_transformer TensorFlow Liteを使用(--tflite)し、floatのモデルを使用(--float)している状態。

Kitazume-Ax commented 3 months ago

下記のPRを作成済み。 https://github.com/axinc-ai/ailia-models-tflite/pull/85

Kitazume-Ax commented 3 months ago

変換に使用した ViT TensorFlow 実装

下記のリポジトリをgit clone Vision Transformer in TensorFlow 2.x https://github.com/hrithickcodes/vision_transformer_tf

READMEに書かれている requirements.txt を使用したインストールは失敗するので、動作するバージョンのパッケージを指定してインストールする。

pip install tensorflow==2.6.5
pip install matplotlib==3.6.3
pip install contourpy==1.1.1
pip install numpy==1.19.5
pip install pyyaml

TensorFlow Liteに変換するため tf.Erf が含まれないように、gelu を approximate=True にする。 vision_transformer_tf/layers/pwffn.py : Line 15

self.gelu = tf.keras.layers.Lambda(lambda x: tf.keras.activations.gelu(x, approximate=True))

vit_architectures.yaml は 224x224 に変更。

ViT-BASE16: 
    encoder_layers: 12
    patch_embedding_dim: 768
    units_in_mlp: 3072
    attention_heads: 12
    image_size: [224, 224, 3]
    patch_size: 16
    dropout_rate: 0.1
    classes: 1000
    class_activation: "sigmoid"
Kitazume-Ax commented 3 months ago

tf_flowers の学習

元実装のImageNet学習済みウェイトはURLが404で失われている。 下記のようなコードを作成して学習を実行。

import os
import numpy as np
from vit import viT
import tensorflow as tf
import tensorflow_datasets as tfds
from utils.loss import vit_loss
from utils.plots import plot_accuracy, plot_loss

# --- data pipeline settings ---
image_size = 224            # final model input resolution (matches vit_architectures.yaml)
batch_size = 64             # training batch size
auto = tf.data.AUTOTUNE     # let tf.data choose the parallelism level
resize_bigger = 256         # resize target before the random 224x224 crop (train only)
num_classes = 5             # tf_flowers has 5 classes

# --- training hyper-parameters ---
learning_rate = 0.001
momentum = 0.9              # SGD momentum
global_clipnorm = 1.0       # clip gradients by global norm
vit_config = "vit_architectures.yaml"   # architecture definition file
epochs = 30
validation_batch_size = 16

def preprocess_dataset(is_training=True):
    """Build the per-example map() function for the input pipeline.

    Training: resize to `resize_bigger`, random-crop to `image_size` and
    random horizontal flip. Evaluation: plain resize to `image_size`.
    Either way pixels are scaled from [0, 255] into [-1, 1].
    """
    def _transform(img, label):
        if not is_training:
            img = tf.image.resize(img, (image_size, image_size))
        else:
            # augment: resize larger, then crop + flip
            img = tf.image.resize(img, (resize_bigger, resize_bigger))
            img = tf.image.random_crop(img, (image_size, image_size, 3))
            img = tf.image.random_flip_left_right(img)
        return img / 127.5 - 1.0, label

    return _transform

def prepare_dataset(dataset, is_training=True):
    """Shuffle (training only), preprocess, batch and prefetch a dataset."""
    ds = dataset.shuffle(batch_size * 10) if is_training else dataset
    ds = ds.map(preprocess_dataset(is_training), num_parallel_calls=auto)
    return ds.batch(batch_size).prefetch(auto)

# Load tf_flowers and split it 90/10 into train / validation.
train_dataset, val_dataset = tfds.load("tf_flowers", split=["train[:90%]", "train[90%:]"], as_supervised=True)
train_dataset = prepare_dataset(train_dataset, is_training=True)
val_dataset = prepare_dataset(val_dataset, is_training=False)

vit = viT(vit_size="ViT-BASE16", num_classes=num_classes, config_path=vit_config)

# SGD with momentum and global gradient-norm clipping.
optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate, momentum=momentum, global_clipnorm=global_clipnorm)

# Keep only the best weights by validation accuracy.
# Fixes: variable was misspelled "chekpoint"; the path segment was a
# placeholder-free f-string — a plain string literal is equivalent.
checkpoint = tf.keras.callbacks.ModelCheckpoint(
    os.path.join("training_weights", "ViT-BASE16_tf_flowers"),
    monitor="val_acc", save_best_only=True, save_weights_only=True)

vit.compile(optimizer=optimizer, loss=vit_loss, metrics=["acc"])

history = vit.fit(train_dataset,
                  validation_data=val_dataset,
                  shuffle=True,
                  validation_batch_size=validation_batch_size,
                  callbacks=[checkpoint],
                  epochs=epochs)
Kitazume-Ax commented 3 months ago

tflite変換

保存した Checkpoint を使用した tflite 変換は下記のように実行可能。

import os
from vit import viT
import tensorflow as tf 
from utils.general import load_config

num_classes = 5
VIT_CONFIG = load_config("vit_architectures.yaml")

# Rebuild the architecture and restore the trained checkpoint.
model = viT("ViT-BASE16", num_classes)
model.load_weights(os.path.join("training_weights", "ViT-BASE16_tf_flowers")).expect_partial()
# Trace the model with a fixed batch-1 input shape so the converter sees a static graph.
model.compute_output_shape(input_shape = [1] + VIT_CONFIG["ViT-BASE16"]["image_size"])

converter = tf.lite.TFLiteConverter.from_keras_model(model)
vit_tflite = converter.convert()
# Fix: use a context manager so the file handle is flushed and closed
# (the original open(...).write(...) leaked the handle).
with open("vision_transformer_float.tflite", "wb") as f:
    f.write(vit_tflite)
Kitazume-Ax commented 3 months ago

int8 量子化 tflite 変換

学習に使用している tf_flowers から100イメージを代表データセットとして量子化を実行。 ※ PCの実行時間の関係で100にしている

import os
from vit import viT
import tensorflow as tf 
from utils.general import load_config
import tensorflow_datasets as tfds

num_classes = 5
VIT_CONFIG = load_config("vit_architectures.yaml")

# Rebuild the architecture and restore the tf_flowers checkpoint.
model = viT("ViT-BASE16", num_classes)
model.load_weights(os.path.join("training_weights", "ViT-BASE16_tf_flowers")).expect_partial()
# Trace the model with a fixed batch-1 input shape so it can be converted.
model.compute_output_shape(input_shape = [1] + VIT_CONFIG["ViT-BASE16"]["image_size"])

model.summary()
print(os.linesep)

# Source of the representative dataset: the same tf_flowers training split.
ds = tfds.load("tf_flowers", as_supervised=True)
ds_train = ds["train"]
print(ds_train.cardinality().numpy())

image_size = 224  # model input resolution
def preprocess_dataset():
    """Return a map() function: resize to 224x224 and scale pixels to [-1, 1]."""
    def _transform(img, label):
        resized = tf.image.resize(img, (image_size, image_size))
        return resized / 127.5 - 1.0, label

    return _transform

def prepare_dataset(dataset):
    """Apply the preprocessing transform across the dataset in parallel."""
    return dataset.map(preprocess_dataset(), num_parallel_calls=tf.data.AUTOTUNE)

pp_ds = prepare_dataset(ds_train)

def representative_data_gen():
    """Yield 100 preprocessed batch-1 images for the quantizer to calibrate on."""
    for image_batch, _label in pp_ds.batch(1).take(100):
        yield [image_batch]

# Full-integer (int8) post-training quantization.
converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_data_gen
# Restrict to int8 builtin ops only; conversion fails if an op cannot be quantized.
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
# Quantize the model's input/output tensors to uint8 as well.
converter.inference_input_type = tf.uint8
converter.inference_output_type = tf.uint8
vit_tflite = converter.convert()