Open kyakuno opened 4 months ago
作業ブランチ: https://github.com/axinc-ai/ailia-models-tflite/tree/kitazume/add_vision_transformer TensorFlow Liteを使用(--tflite)し、floatのモデルを使用(--float)している状態。
下記のリポジトリをgit clone Vision Transformer in TensorFlow 2.x https://github.com/hrithickcodes/vision_transformer_tf
READMEに書かれている requirements.txt を使用したインストールは失敗するので、動作するバージョンのパッケージを指定してインストールする。
pip install tensorflow==2.6.5
pip install matplotlib==3.6.3
pip install contourpy==1.1.1
pip install numpy==1.19.5
pip install pyyaml
TensorFlow Liteに変換するため tf.Erf が含まれないように、gelu を approximate=True にする。 vision_transformer_tf/layers/pwffn.py : Line 15
self.gelu = tf.keras.layers.Lambda(lambda x: tf.keras.activations.gelu(x, approximate=True))
vit_architectures.yaml は 224x224 に変更。
ViT-BASE16:
encoder_layers: 12
patch_embedding_dim: 768
units_in_mlp: 3072
attention_heads: 12
image_size: [224, 224, 3]
patch_size: 16
dropout_rate: 0.1
classes: 1000
class_activation: "sigmoid"
元実装のImageNet学習済みウェイトはURLが404で失われている。 下記のようなコードを作成して学習を実行。
import os
import numpy as np
from vit import viT
import tensorflow as tf
import tensorflow_datasets as tfds
from utils.loss import vit_loss
from utils.plots import plot_accuracy, plot_loss
# --- Training hyperparameters -------------------------------------------
image_size = 224  # model input resolution; must match image_size in vit_architectures.yaml
batch_size = 64
auto = tf.data.AUTOTUNE
resize_bigger = 256  # oversize before the training-time random crop back to image_size
num_classes = 5  # tf_flowers has 5 flower classes
learning_rate = 0.001
momentum = 0.9
global_clipnorm = 1.0  # gradient clipping by global norm for SGD
vit_config = "vit_architectures.yaml"
epochs = 30
validation_batch_size = 16
def preprocess_dataset(is_training=True):
    """Return an (image, label) -> (image, label) mapper for tf.data.map.

    Training mode: resize to ``resize_bigger``, random-crop back to
    ``image_size`` and random horizontal flip. Eval mode: plain resize
    to ``image_size``. Pixels are rescaled from [0, 255] to [-1, 1]
    in both modes.
    """
    def _transform(image, label):
        if not is_training:
            image = tf.image.resize(image, (image_size, image_size))
        else:
            # Oversize first so random_crop has some spatial slack.
            image = tf.image.resize(image, (resize_bigger, resize_bigger))
            image = tf.image.random_crop(image, (image_size, image_size, 3))
            image = tf.image.random_flip_left_right(image)
        # Normalize [0, 255] -> [-1, 1].
        return image / 127.5 - 1.0, label
    return _transform
def prepare_dataset(dataset, is_training=True):
    """Shuffle (training only), preprocess, batch and prefetch *dataset*."""
    if is_training:
        # Shuffle buffer of ten batches' worth of examples.
        dataset = dataset.shuffle(batch_size * 10)
    mapped = dataset.map(preprocess_dataset(is_training), num_parallel_calls=auto)
    return mapped.batch(batch_size).prefetch(auto)
# Load tf_flowers and split 90/10 into train/validation.
train_dataset, val_dataset = tfds.load(
    "tf_flowers", split=["train[:90%]", "train[90%:]"], as_supervised=True
)
train_dataset = prepare_dataset(train_dataset, is_training=True)
val_dataset = prepare_dataset(val_dataset, is_training=False)

vit = viT(vit_size="ViT-BASE16", num_classes=num_classes, config_path=vit_config)
optimizer = tf.keras.optimizers.SGD(
    learning_rate=learning_rate,
    momentum=momentum,
    global_clipnorm=global_clipnorm,
)
# Keep only the weights of the best epoch by validation accuracy.
# (Fixed the "chekpoint" typo and dropped the needless f-string prefix.)
checkpoint = tf.keras.callbacks.ModelCheckpoint(
    os.path.join("training_weights", "ViT-BASE16_tf_flowers"),
    monitor="val_acc",
    save_best_only=True,
    save_weights_only=True,
)
vit.compile(optimizer=optimizer, loss=vit_loss, metrics=["acc"])
history = vit.fit(
    train_dataset,
    validation_data=val_dataset,
    shuffle=True,
    validation_batch_size=validation_batch_size,
    callbacks=[checkpoint],
    epochs=epochs,
)
保存した Checkpoint を使用した tflite 変換は下記のように実行可能。
import os

from vit import viT
import tensorflow as tf
from utils.general import load_config

num_classes = 5

VIT_CONFIG = load_config("vit_architectures.yaml")
model = viT("ViT-BASE16", num_classes)
# expect_partial(): the checkpoint was saved with save_weights_only=True, so
# optimizer slots present/absent mismatches are expected and harmless here.
model.load_weights(os.path.join("training_weights", "ViT-BASE16_tf_flowers")).expect_partial()
# Build the graph with a fixed batch of 1 so the converter sees concrete
# shapes ([1] + image_size from the YAML config, i.e. [1, 224, 224, 3]).
model.compute_output_shape(input_shape=[1] + VIT_CONFIG["ViT-BASE16"]["image_size"])

converter = tf.lite.TFLiteConverter.from_keras_model(model)
vit_tflite = converter.convert()
# Use a context manager so the file handle is always closed; the original
# open(...).write(...) left the handle dangling.
with open("vision_transformer_float.tflite", "wb") as f:
    f.write(vit_tflite)
学習に使用している tf_flowers から100イメージを代表データセットとして量子化を実行。 ※ PCの実行時間の関係で100にしている
import os
from vit import viT
import tensorflow as tf
from utils.general import load_config
import tensorflow_datasets as tfds
# Rebuild the trained model exactly as in the float conversion script.
num_classes = 5
VIT_CONFIG = load_config("vit_architectures.yaml")
model = viT("ViT-BASE16", num_classes)
# expect_partial(): weights-only checkpoint, optimizer state absence is fine.
model.load_weights(os.path.join("training_weights", "ViT-BASE16_tf_flowers")).expect_partial()
# Build the graph with batch size 1 so the converter gets concrete shapes.
model.compute_output_shape(input_shape = [1] + VIT_CONFIG["ViT-BASE16"]["image_size"])
model.summary()
print(os.linesep)
# Load the full tf_flowers train split as the calibration source.
ds = tfds.load("tf_flowers", as_supervised=True)
ds_train = ds["train"]
print(ds_train.cardinality().numpy())
image_size = 224  # must match the model's input resolution
def preprocess_dataset():
    """Return a mapper resizing images to ``image_size`` and rescaling to [-1, 1]."""
    def _transform(image, label):
        resized = tf.image.resize(image, (image_size, image_size))
        # Same normalization as training: [0, 255] -> [-1, 1].
        return resized / 127.5 - 1.0, label
    return _transform
def prepare_dataset(dataset):
    """Apply the inference-time preprocessing to every element in parallel."""
    mapper = preprocess_dataset()
    return dataset.map(mapper, num_parallel_calls=tf.data.AUTOTUNE)
# Calibration images go through the exact inference-time preprocessing.
pp_ds = prepare_dataset(ds_train)
def representative_data_gen():
    """Yield 100 single-image batches for post-training quantization calibration."""
    for image, _label in pp_ds.batch(1).take(100):
        yield [image]
converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_data_gen
# Force full integer quantization: conversion fails if any op has no int8 kernel.
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
# Quantize the model I/O too so the runtime interface is uint8.
converter.inference_input_type = tf.uint8
converter.inference_output_type = tf.uint8
vit_tflite = converter.convert()
# The original snippet discarded the converted bytes; persist them to disk
# like the float conversion script does.
with open("vision_transformer_int8.tflite", "wb") as f:
    f.write(vit_tflite)
下記のvitをtfliteに変換する。 https://github.com/taki0112/vit-tensorflow