axinc-ai / ailia-models

The collection of pre-trained, state-of-the-art AI models for ailia SDK
2k stars 318 forks source link

ADD vit #275

Closed kyakuno closed 3 years ago

kyakuno commented 3 years ago

https://github.com/lucidrains/vit-pytorch

mucunwuxian commented 3 years ago

リサーチ備忘📝

調査した結果につきまして、こちらに記載させて頂こうと思います。


「lucidrains/vit-pytorh」について

lucidrains/vit-pytorchについて、確認させていただきました。 通常のvitだけでなく、複数種類のvitも登録されており、素晴らしいリポジトリとなっています。

Licenseは、MITになります。

リポジトリの構成としては、以下のようになっています。

./
├── LICENSE
├── README.md
├── examples
│   └── cats_and_dogs.ipynb
├── images
│   ├── cait.png
│   ├── cross_vit.png
│   ├── cvt.png
│   ├── dino.png
│   ├── distill.png
│   ├── levit.png
│   ├── nest.png
│   ├── pit.png
│   ├── t2t.png
│   ├── twins_svt.png
│   └── vit.gif
├── setup.py
└── vit_pytorch
    ├── __init__.py
    ├── cait.py
    ├── cross_vit.py
    ├── cvt.py
    ├── deepvit.py
    ├── dino.py
    ├── distill.py
    ├── efficient.py
    ├── levit.py
    ├── local_vit.py
    ├── mpp.py
    ├── nest.py
    ├── pit.py
    ├── recorder.py
    ├── rvt.py
    ├── t2t.py
    ├── twins_svt.py
    └── vit.py

images フォルダに格納されている画像は、readmeの説明に用いられているものとなります。

vit_pytorch フォルダに格納されているモジュールは、各種vitネットワーク構造のdefや、その補助機能等になります。

example フォルダに格納されている cats_and_dogs.ipynb というファイルは、当該リポジトリのコードを用いてスクラッチで学習する方法をレクチャーしているものでした。 Kaggleにある Dogs vs. Cats Redux: Kernels Edition というデータセットにて、学習する例が実装してあります。


readmeを読んでいくと、pre-trainedモデルのリンク紹介があった為、覗いてみると以下の別リポジトリへと繋がりました。

imagehttps://github.com/rwightman/pytorch-image-models

しかし、当該リポジトリとリンク先リポジトリとを読み込んでみると、どうもvitそのもののpretrainedパラメーターは無さそうでした。 resnet等の有名なネットワークを、特徴抽出として用いるvitの場合に、その特徴抽出部分について、pretrainedモデルを用いれるというニュアンスのようでした。 或いは、readmeに記載のある学習サンプルにおいては、torchvisionのpretrainedモデルを用いて、蒸留を行う手順が記載されていたりしました。 https://github.com/lucidrains/vit-pytorch#distillation


その為、一旦ではありますが、ailia上で動かすonnxの生成元として、このリポジトリは不十分である印象を受けました。 他のリポジトリについても、調査をしてみようと思います。



「jeonsworld/ViT-pytorch」について

jeonsworld/ViT-pytorchというリポジトリについても、確認させていただきました。 こちらも☆の数が600を超えるリポジトリで、中身も充実したものでした。

Licenseは、MITになります。

リポジトリの構成としては、以下のようになっています。

./
├── LICENSE
├── README.md
├── img
│   ├── figure1.png
│   ├── figure2.png
│   └── figure3.png
├── models
│   ├── configs.py
│   ├── modeling.py
│   └── modeling_resnet.py
├── requirements.txt
├── train.py
├── utils
│   ├── data_utils.py
│   ├── dist_util.py
│   └── scheduler.py
└── visualize_attention_map.ipynb


img フォルダに格納されている画像は、readmeの説明に用いられているものとなります。

models フォルダに格納されているモジュールは、各種vitネットワーク構造のdefや、そのcofig補助機能等になります。

utils フォルダに格納されているモジュールは、学習の際に使用するloaderやscheduler等のdefとなります。

visualize_attention_map.ipynb というファイルは、当該リポジトリのコードを用いて、学習済みモデルにてfeed forwardを行うnotebook形式のチュートリアルとなっています。 学習済みモデルのダウンロードについても、コードでの実施が行われるように実装がされています。 また、attention mapによる、注目領域の可視化も行ってくれます。

image


結果も見せ方も良いコードとなっているので、こちらの踏襲すれば良いかと思ったのですが、よくよく調べてみると、このリポジトリにおけるモデルの実装は、Vision Transformer ではなく、Compact Convolutional Transformer というもののようでした。 その為、ひょっとすると、踏襲させていただくのは微妙かもしれないと思いました。


尚、補足としまして、当該リポジトリを用いた学習については、事前学習済みモデルを用いて学習する方針が推奨されておりますが、逆にスクラッチでの学習を行おうとした場合には、一部コードの改変を要するようでありました。 https://github.com/jeonsworld/ViT-pytorch/issues/7

その為、スクラッチで学習をする場合には、lucidrains/vit-pytorchの方が適しているかもしれません。



「google-research/vision_transformer」について

lucidrains/vit-pytorchのリポジトリを見ていると、officialリポジトリが以下である、という記載がありました。 https://github.com/google-research/vision_transformer

このリポジトリは、google-researchによるJAXの実装になります。 notebookによる実装もあり、Colab上にてTPUを使いながら、学習から推論までを一通り実行することができました。 可視化についても、論文の方針と合致しています。 https://colab.research.google.com/github/google-research/vision_transformer/blob/master/vit_jax.ipynb


構造は以下のようになっています。

./
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── figure1.png
├── vit_jax
│   ├── __init__.py
│   ├── checkpoint.py
│   ├── checkpoint_test.py
│   ├── configs.py
│   ├── flags.py
│   ├── hyper.py
│   ├── input_pipeline.py
│   ├── logging.py
│   ├── models.py
│   ├── models_resnet.py
│   ├── models_test.py
│   ├── momentum_clip.py
│   ├── requirements.txt
│   ├── train.py
│   └── train_test.py
└── vit_jax.ipynb

models.py に、vitネットワーク構造の大枠が定義されています。 vit_jax.ipynb が、notebookによるスクリプトになります。

Licenseは、Apache-2.0になります。


このリポジトリのモデルを、onnxエクスポートできれば、ailiaへの搭載が可能と考えたのですが、jaxからonnxのエクスポートについて、情報が見つかりませんでした。 当該リポジトリ上のreadmeやissueにonnxというキーワードが無く、「jax onnx」などで検索をしても解決策は見当たらない形です。


しかし、色々調べてみると、JAX版のgoogle researchによるvitを、pytorchへとconvertしてくれているリポジトリを見つけました。 その為、そのリポジトリにて、pytorch版の学習済みモデルを作った後、onnxへのconvertを実施すれば、originalのvit構造を、ailiaに移植できると考えました。



「lukemelas / PyTorch-Pretrained-ViT」について

JAX版のgoogle researchによるvitを、pytorchへとconvertしてくれているリポジトリが以下になります。 https://github.com/lukemelas/PyTorch-Pretrained-ViT/issues?q=is%3Aissue+

尚、リポジトリに格納されているモジュール外に、Colabのスクリプトモジュールもあり、そのリンクはreadmeに記載されています。 https://colab.research.google.com/drive/1muZ4QFgVfwALgqmrfOkp7trAvqDemckO?usp=sharing


リポジトリの構造は以下の様になっています。

./
├── README.md
├── examples
│   ├── imagenet
│   │   ├── README.md
│   │   ├── data
│   │   │   └── README.md
│   │   └── main.py
│   └── simple
│       ├── example.ipynb
│       ├── imagenet-21k-labels.py
│       ├── imagenet.synset.obtain_synset_list
│       ├── img.jpg
│       ├── img2.jpg
│       ├── labels_map.txt
│       └── labels_map_21k.txt
├── jax_to_pytorch
│   ├── README.md
│   ├── convert.py
│   ├── explore-conversion-21k.ipynb
│   ├── explore-conversion.ipynb
│   ├── jax_weights
│   │   └── download.sh
│   └── weights
├── pytorch_pretrained_vit
│   ├── __init__.py
│   ├── configs.py
│   ├── model.py
│   ├── transformer.py
│   └── utils.py
└── setup.py

作られて7ヶ月のリポジトリであるからか、実施内容の性質上か、Licenseについての記載はありませんでした。


convertの方法は非常にシンプルで、jaxで保存されたモデルを、pytorchのstate_dictへと変換するもののようでした。 構造が似ているために、微調整をすれば変換ができるようです。 https://github.com/lukemelas/PyTorch-Pretrained-ViT/blob/master/jax_to_pytorch/convert.py

この変換を経て、pytorchモデルからonnxエクスポートを行えば、googleresearchオリジナルのvitがailiaへと搭載できそうです。


mucunwuxian commented 3 years ago

リサーチ備忘2📝

引き続き調査した結果につきまして、こちらに記載させて頂きます。


各リポジトリの関係性について

先程までのリサーチを経て、概ねonnxエクスポートの方針は決まったのですが、ここで1つ想定外の事象が確認されました。 それは、google-research/vision_transformerにて、実装されているvitが論文の図の通りではなく、compact convolutional transformerのような実装であることです。

具体的には、画像パッチの作成方法を、ストライドの広いconvolutionで行っています。 バラバラになった画像のパッチを可視化すると、デモとして面白いと思ったのですが、それが叶わない次第です。


論文に記載されているネットワーク構造は以下です。

論文の図(リポジトリのreadmeにもこの図が載っている) image


一方で、実際に実装されているネットワーク構造は、こちらに近いものです。

compact convolutional transformer image

(作られるパッチ画像768枚の内の1枚、pixabayの画像にて) image


このことに関連する質問のようなものが、リポジトリのissueにもありました。 (真意は掴みづらいのですが…。) https://github.com/google-research/vision_transformer/issues/93

また、google-researchオリジナルのvitリポジトリ内を、rearrangeで検索しても検出がされませんでした。


実際に、compact convolutional transformerのような実装がされてる該当コード部分は、以下となります。 https://github.com/google-research/vision_transformer/blob/96b6a636902eddab9ce93fdfe05eaa8b3997210e/vit_jax/models.py#L197

(233行目)

    n, h, w, c = x.shape

    # We can merge s2d+emb into a single conv; it's the same.
    x = nn.Conv(
        x,
        hidden_size, patches.size,
        strides=patches.size,
        padding='VALID',
        name='embedding')

    # Here, x is a grid of embeddings.

# We can merge s2d+emb into a single conv; it's the same.と、コメントにも意図がしっかり記載されています。


これらの調査結果から、リポジトリの踏襲を行うという観点においては、compact convolutional transformerのような実装がされたモデルを、onnxエクスポートすることが順当かと思われました。



尚、論文に記載されるような画像をバラバラにして、パッチを得る方針しては、このissueのstart地点でもあるlucidrains/vit-pytorchにて行われているようでした。

該当コードは以下です。 https://github.com/lucidrains/vit-pytorch/blob/60ad4e266ecf1df52968934f571dfe6acd736035/vit_pytorch/vit.py#L81

(93行目)

        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
            nn.Linear(patch_dim, dim),
        )

これが論文の通りの実装かと思われます。 ただし、このリポジトリには、学習済みモデルが用意されていない為、この論文に則ったネットワーク構造にて、推論を実施することが難しい状況です。 実施するとすれば、imagenetで学習を実施する必要があります。



尚、attention mapについて実装をしているのが、jeonsworld/ViT-pytorchになります。

以下のような結果を、サンプルスクリプトから得ることができます。 image

また、改めて見てみると、こちらのリポジトリで用いられている重みについても、GoogleのOfficialと書かれています。 恐らく、google-researchのvision_transformerから移植したという意味かと思われます。 image


尚、保存されている重みの形式が、.npzとなっているのですが、これはlukemelas/PyTorch-Pretrained-ViTでの保存形式と合致します。 恐らく、同様の移植をしたものと思われます。


ただし、jeonsworld/ViT-pytorchについては、若干オリジナルのロジックが加えられているような旨が、リポジトリのissueにて伺えます。 (一例:https://github.com/jeonsworld/ViT-pytorch/issues/5

この点は、或いは、学習についてのみかもしれません。 予測精度について、大きなズレは無さそうな印象です。


より精緻に、google-researchの推論機能を移植するという意図ですと、lukemelas/PyTorch-Pretrained-ViTにて、jaxの重みをpytorchに変換した方が良いのかもしれません。 その上で、jeonsworld/ViT-pytorchにて実装されているattention mapの表示を行うと、デモとして親切かもしれません。


mucunwuxian commented 3 years ago

(tips)

patch imagesの可視化(高圧縮畳み込み特徴[224x224 → 14x14]) image


torchvision transformの代替コード

from PIL import Image

# read image
im = Image.open('hoge.png')

# transform
x = im.resize((224, 224), Image.BILINEAR)  # HWC
x = np.array(x).astype(np.float32)
x = x / 255
x = x - np.array([0.5, 0.5, 0.5])[np.newaxis, np.newaxis, :]  # mean
x = x / np.array([0.5, 0.5, 0.5])[np.newaxis, np.newaxis, :]  # std
x = x.transpose(2, 0, 1)  # CHW
x = x[np.newaxis, :, :, :]  # BCHW