Open GENZITSU opened 2 years ago
花王 + 歩行解析 といえばメリーズ。
花王は、ベビー用の紙おむつ開発の一環として、幼児の歩行動作解析に関する研究を実施してきた。また、成人の歩行動作を解析するシステムも開発し、3万人以上の歩行データをもとにした研究を進めている。
動作解析の研究で用いられるモーションキャプチャ技術には、3次元の歩行動作を正確に解析できるメリットがある一方、複数の赤外線カメラが必要になるなど、場所や時間に制約があった。
そこで花王は、スマートフォンなどのカメラで撮影した2次元の動画に関して機械学習を行い、人体の骨格点の動きを3次元データとして算出する技術を開発。
モーションキャプチャから得られrた各パラメータをお元に日齢の回帰ができるらしい。
モーションキャプチャ技術から得られた推定日齢は実日齢と高い相関関係があり、測定精度の高さが実証された。 mGCC技術から得られた推定日齢は、モーションキャプチャ技術から得られた推定日齢と高い相関関係があり、両者はほぼ同等の精度で日齢を推定できることが実証されている。
高度な機械なしに3dモーションが取れるようになると、ユーザーからデータを拾い上げるのが楽になる...? 直接的な活用法はよくわからないが、面白い技術。
アンスコの使い方についていろいろ解説されている。勉強になったのは以下。
# from https://qiita.com/_Kohei_/items/381737d0e03e24ab0f41
def __double_leading_underscore:
return something
# when naming a class attribute, invokes name mangling
# (inside class FooBar, __boo becomes _FooBar__boo; see below).
マングリングは以下がわかりやすい。
# from https://qiita.com/mounntainn/items/e3fb1a5757c9cf7ded63
class HogeHoge:
_dummy_prv = "hoge"
__almost_prv = "hoge"
class Main:
hogehoge = HogeHoge()
# 丸見え
a = hogehoge._dummy_prv
# エラー
b = hogehoge.__almost_prv
# こうすると見えちゃう
c = hogehoge._HogeHoge__almost_prv
あんまり使いどころないかもしれないけど、うっかり変数を見られないようにするためにありかも...?
サイボウズ社の2021研修資料から
代表的なOSSのライセンスがまとめられている。
地味にまとまった資料がなかったのでとても助かる。
ShapPackではSHAPライブラリのKernel SHAPと比較して,以下の三つの新たな機能を実装している.
- マルチプロセスで並列処理できる機能
- 特性関数を独自に実装して組み込める機能
- SHAP値を計算しない特徴量を指定できる機能
使い方はオリジナルshapと同様。
# from https://blog.tsurubee.tech/entry/2021/07/21/094213
import shappack
i = 2
explainer = shappack.KernelExplainer(model.predict, X_train_std[:100])
shap_value = explainer.shap_values(X_test_std[i])
マルチプロセス化できるのは恩恵大きそう。
DevOpsにおける中心技術としては、バージョン管理(VCS)、継続的インテグレーション(CI)、継続的デリバリー(CD)、システムの監視が挙げられます。MLOpsの文脈では、DevOpsの開発パイプラインに対して以下のような機能を拡張したパイプライン(MLパイプライン)が提供・開発されることが多いです:
- モデルバージョン管理(VCSの拡張)
- データバリデーション・モデル性能テスト(CIの拡張)
- モデルサービングや段階的デプロイ基盤(CDの拡張)
- モデルの継続的再訓練(CTと呼ばれることがある)
- モデルの性能監視
機械学習モデルの更新頻度やデータの環境変化が多い業態におけるシステムでは、こういったMLパイプラインを効率化・高信頼化する需要が大きいと
機械学習開発で必須になる作業である実験開発のプロセスにおいても、効率化やスケール化の需要が発生します。 MLOpsでは、機械学習のモデルの実験作業を効率化したり、実験のスケール化を可能にする基盤を用意したりする技術が扱われる
代表的な技術要素を以下に挙げてみます。
- 必要なマシンリソースの自動確保・分散学習(実験インフラ)
- 実験管理(実験設定と使用したデータなどを管理する仕組み)
- 実験ワークフローの管理(データ処理や実験実行自体の自動化)
- ハイパーパラメータ調整(実験設定探索の自動化)
- AutoML(機械学習モデルを選別する作業の自動化)
- 特徴量サービング(開発された特徴量の管理・共有)
- モデル説明・可視化(実験結果の確認の効率化)
将来の機械学習エンジニアの仕事は、問題の規格化といった仕様策定、システム化によるコスト削減、依然難易度の高い作業である機械学習システムのデバッグや、タスクのセミカスタマイズなどに中心が移っていくのかもしれません。 現在はまだMLOps技術が探索的な発展を続けている状況下でもあるため、自社の機械学習モデルのライフサイクルのどこに効率化の旨味があるかを見極めることが機械学習エンジニアの重要な役割の一つといえる
機械学習技術を導入する旨味は、実験開発という大きなコストを払うことで、多くの判断が自動化されるだけでなく、同等の機能を実現するソフトウェアよりもシステムの複雑性が大きく減じ得る点にもあります。 実験と実績により簡単に実現できるとわかった機能はなるべく自動化し、投資に対するスケールメリットを積極的に取りに行くのが得策です。
データの複雑性と上手く付き合うという側面もMLOpsの文脈で取り扱われる
データ分析・データ量への対応・高速処理という面ではデータエンジニアリングが既に解決していることが多い現状です。 一方でMLOpsでもサポートの薄い分野はまだまだ存在し、データ作成の効率化や、データの多様化への対応、データの意味に関わるデータ品質保証面については、今後の発展が期待される
データエンジニアリング的なMLOpsとしては、以下のような技術的話題があります。
- DWH・データレイクのような分析用途のデータ集約
- ETL・データパイプラインの整備(日々生成されるデータの加工や配線管理)
- データバリデーションとデータクレンジング(データの品質保証)
- データバージョン管理・実験データ管理
- アノテーションの作成・管理・効率化(Human-in-the-loopなデータ基盤)
- モデル性能監視(メトリクスだけでなく、データの追跡も必要
残る3, 4のポイントについては省略。
いくつかのポイントがまとまっていると、どこから手をつけていけるかの参考になる。
Yahoo! BEAUTYに今まで投稿されたヘアスタイル画像とタグのデータを9:1にtrainデータとtestデータに分割し、trainデータを使ってモデルを学習しました。評価は以下の3つで行いました。
- 分類評価: testデータに対するタグ分類評価(ノイジーなデータなので参考値)
- 検索ランキング評価: 検索のクリックログを使ってのオフライン評価
- 定性評価: 検索結果にモデルを反映させてみての定性評価
当初の評価値
また、タグごとの評価値や検索数などを確認し、106タグ全てではなく効果が高そうな28タグを選びABテストを実施することになりました。
元データがノイジーすぎたので、アノテーションデータの追加をしたとのこと。
最初のステップとして、自前のツールで1タグのみ1,000画像をサービス担当者にアノテーションし、検証 学習データにアノテーションデータを適用すれば精度が上がることを確認 精度向上の見込みがついたので、Yahoo!クラウドソーシングを利用し、27タグ×約1,000画像の合計26,700のデータをアノテーション
学習データへの反映と、学習時の重み調整(後述)もして、アノテーションした合計28タグの分類評価値のmAPを0.76から0.82に向上
手堅く改善してる感じがしてとても良い。
学習時の重み調整はいいアイデア。人工データ+リアルアノテーションデータの時にも使えるかもしれん。
なんとなくshopeeコンペのアイデアとか足せば良さそうにも見えるが...?
ブログの書き方がkaggleの会報共有っぽい笑
remote containersを入れた状態で以下の設定をすると、リモート上のコンテナが見れるようになる。
実際に使ってみた感想。
.ssh/configの設定では以下に注意する
リモートのdocker側では以下に気を付ける
クラウド使ってたらいかに気をつける
DDPはDPと異なりGPU間の通信が少ないので、こちらをつ可能が良いらしい。
A. DDP 使え(DDPの方が速いし PyTorch のドキュメントでも DDP が推奨されてるから)
実はDDPというものを最近知った。
こっちの方がいいんだなぁ
上の記事に関連してDDPのやり方を調べてみたが、かなり複雑な模様。
# https://qiita.com/meshidenn/items/1f50246cca075fa0fce2
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP
def train(rank, n_gpu, input_size, output_size, batch_size, train_dataset):
dist.init_process_group("gloo", rank=rank, world_size=n_gpu)
# create local model
model = Model(input_size, output_size)
# construct DDP model
model = DDP(model, device_ids=[rank])
# define loss function and optimizer
loss_fn = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.001)
sampler = DistributedSampler(train_dataset, num_replicas=n_gpu, rank=rank, shuffle=True)
rand_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=sampler)
for data, label in rand_loader:
input = data.to(rank)
output = model(input)
label = label.to(rank)
# backward pass
loss_fn(outputs, label).backward()
# update parameters
optimizer.step()
def main():
n_gpu = 2
input_size = 5
output_size = 2
batch_size = 30
data_size = 100
dataset = RandomDataset(input_size, data_size)
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '12355'
mp.spawn(train,
args=(n_gpu, input_size, output_size, batch_size, dataset),
nprocs=n_gpu,
join=True)
手早くやるにはDPの方が良いけど、効率を求めるとDDPに手を出さないとダメになりそう。
表題の通り。
一つの環境構築手段として参考になる。
ViT + SSLによる表現学習がうまくハマったとのこと。
今回ViTのSSLに使った手法はDiNO 他にも色々手法はあるんですが,全部試してダメだったので最終的にDiNOになりました. (EsViTはうまくいかず,BEiTはdVAEの学習がうまくいかなかった)
DiNOのチューニングはEpochのみで、 ~ 、SSLは長く学習すると精度が良い
長い時間SSLにかけてスクラッチでやると時間がかかるので、1600 epochまでやったものの追学習を行ったとのこと。
当初はFineTuneでViTを学習したがうまくいかず,学習初期のAccuracyがLinearSVMよりも明らかに低いことと,学習セットによってはLinearSVMの方が良くなっていることがありました.
FineTuneを行うでSSLによって得られた表現が失われてしまい,データセットにoverfitしているんじゃないかと思い,SSLの表現だけで3層のMLPを学習 3層のMLPでもすぐにoverfitするため,しないような工夫を加えることでFIneTuneを行ったモデルよりも精度が良くなります.
overfitを防ぐ工夫
- Dropout rate 0.7
- データ拡張をして画像表現を得る(RandomResizeCrop,反転,輝度,コントラスト)
- 活性化関数 ReLU
- 3層程度の浅いネットワーク
- LabelSmoothing
画像サイズも影響があるとのこと。
SSLの表現を保ったままFineTuneをすれば精度が上がるんじゃないかと考えました
マルチタスクでViTの学習もしたらしいのだが、よく理解できなかった。
DEiTと呼ばれる,でかいCNNの予測をViTで予測させる蒸留モデルがあるが,このモデルでは,CLSトークンで画像分類を行うのと同時にDistillation TokenでCNNの予測を予測します.
タスク数に合わせてCLSトークンを複製してそれぞれのトークンで分類や回帰をさせる ラベル分類とmeta分類,tech分類,ラベル回帰の4タスクを同時に学習させます. ただ,回帰はCLSトークンだけでは学習してくれないので,画像パッチも加え平均プーリングを行うことで学習が進みました. ログは残してないんですが,若干精度が上がります.
SSLで得た表現がfine tuneで破壊されるのは結構面白い。すぐ過学習するというのも面白い。
たかSSLは長い時間学習させないとダメなんだなぁ。
デフォルト値は実行時、モジュールがロードされた時ただ1度だけ評価されるため、動的な値に奇妙な振る舞いをもたらす可能性がある
ダメな例
# from https://qiita.com/nazeudon/items/7cade516dffe25e78d67
def bad_decode(data, default={}):
try:
return json.loads(data)
except ValueError:
return default
foo = bad_decode('bad data')
foo['staff'] = 5
bar = bad_decode('also bad data')
bar['meep'] = 1
foo {'staff': 5, 'meep': 1}
bar {'staff': 5, 'meep': 1}
良い例
# from https://qiita.com/nazeudon/items/7cade516dffe25e78d67
def good_decode(data, default=None):
if default is None:
default = {}
try:
return json.loads(data)
except ValueError:
return default
foo = good_decode('bad data')
foo['staff'] = 5
bar = good_decode('also bad data')
bar['meep'] = 1
foo {'staff': 5}
bar {'meep': 1}
default値の空ディクトが一度だけ評価されて、そいつが使い回されるので、変な挙動となる。
デフォルト値に変数を使うのは良くなさそう。
DeepExplainerは早いものの、環境構築が困難らしい。
単語ごとの貢献度を求めることができる。
shapでも単語ごとの貢献度だせたんだなぁ。
計算時間がちょっと長いから微妙か...?
作成された動画はこちら。
今回は2017年にGoogleから発表された Tacotron2 (+WaveGlow) を使用
YouTube上の動画の音声 と 自動字幕 を利用する 綺麗なデータではないので精度は落ちますが大量のデータを用意することができます。 YouTubeからできるだけBGMや笑い声などノイズが少ない動画を探してきていくつか利用しました。
今回は行っていませんが、音声に対してアノテーションして綺麗なデータを用意することで、より良い音声合成ができるようになります。
Audio-driven Talking Face Video Generation with Learning-based Personalized Head Pose を利用してフェイク動画を生成します。
音声と対象人物の10秒ほどの動画の二つを与えることで、対象人物がその音声を喋っている動画を生成することができます。
フェイク動画作成に関連したさまざまな手法がまとめられている。
さらに、参考にした記事も詳細に書かれているので、これを参考にしてオリジナルのものをつくれそう。
動画の生成はこのコードを使っているらしい。
10秒ほどで動画作れるのすごいな...
環境変数を使って接続の設定をすることは開発時には一般的にOKですが、本番環境では非常にお勧めできません。
環境変数で秘密情報を扱う問題点
- 環境変数はプロセスで暗黙的に利用可能のため、アクセスや利用を追跡することが困難
- アプリケーションが環境変数をデバッグやエラー報告のために出力
- 環境変数は子プロセスに引き継がれるため、意図しないアクセスが可能になる。これは、最小権限の原則を破る。
- アプリケーションがクラッシュした時、デバッグのために、環境変数をログファイルに保存することは一般的であり、これはディスク上に平文の秘密情報があることを意味する
コンテナオーケストレーションツールであるDocker SwarmやKubernetesの機能を使って秘密情報を扱うことが推奨される
Dockerfile(docker-compose.yamlではない)でENVやARGを使い環境変数を設定すると、Docker imageに変数の中身まで焼き付いてしまうため注意が必要
Docker SwarmやKubernetesを使わない方法だと、Docker Composeには秘密情報を扱うsectetsがある 詳細はhttps://docs.docker.com/compose/compose-file/compose-file-v3/ を参照
# from https://qiita.com/myabu/items/89797cddfa7225ff2b5d
version: "3.9"
services:
app:
image: busybox:latest
command: cat -n /run/secrets/my_secret
secrets:
- my_secret
secrets:
my_secret:
file: my_secret.txt
秘密情報を書いたファイルをsecretsで指定すると、secretsで指定した名前(例では、my_secret)で、コンテナの/run/secrets/以下にマウントされます。
機密情報を書いたファイルは必要だが、環境変数よりかはいいのかな..?
機密情報のファイルは暗号化しておく必要があるか。
pythonのdatetimeライブラリも提供している時間や日付の取り回しを遥かに良くした便利ライブラリ。
datetimeと異なり、time deltaや時間の換算がめちゃくちゃ簡単な模様。
# from https://qiita.com/tand826/items/8076ebf90941fd78beb3
# 自動でAsia/Tokyoを設定してくれるのでnow関数は引数不要?
>>> now = pendulum.now()
>>> print(now.timezone)
Timezone('Asia/Tokyo')
# 2019年5月1日0時0分0秒をAsia/TokyoとUTCで表示
>>> in_utc = pendulum.datetime(2019, 5, 1, 0, 0, 0)
>>> in_utc.timezone
Timezone('UTC')
>>> print(in_utc)
2019-05-01T00:00:00+00:00
>>> in_jpn = in_utc.in_timezone("Asia/Tokyo")
>>> in_jpn.timezone
Timezone('Asia/Tokyo')
>>> print(in_jpn)
2019-05-01T09:00:00+09:00
# python:5/2を表示する
>>> in_utc = pendulum.datetime(2019, 5, 1, 0, 0, 0)
>>> dt = in_utc.in_timezone("Asia/Tokyo")
>>> tomorrow = dt.add(days=1)
>>> print(tomorrow)
2019-05-02T09:00:00+09:00
# from https://qiita.com/tand826/items/8076ebf90941fd78beb3
>>> dur = pendulum.duration(days=15)
>>> print(dur)
2 weeks 1 day
>>> print(dur.weeks)
2
>>> print(dur.days)
15
>>> print(dur.in_hours())
360
挙動がわかりづらいdatetimeと違って、感覚的に使えるのがよい。
この記事ではNVIDIA Isaac SimとTransfer Learning Toolkitを用いて、シミューレションデータによる検出モデルの作り方が詳細に書かれている。
Omniverse で学習データを作成する際に重要な要素は RT コアによる反射光の再現です。 ほとんどの CNN はテクスチャを認識するのですが、反射光が忠実に再現されるとテクスチャの表現が現実世界により近づくので、シミュレーションと現実世界のギャップが小さくなります。 NVIDIA Isaac Sim では現実世界のデータとシミュレーションの差分を埋めるために、ドメインランダマイゼーションを使用しています。
ちなみに、その他のシミュレーションギャップを埋める手法には以下がある。
3 つのアプローチがあります。
- ドメイン適応: 現実のデータを用意して、シミュレーションで学習したモデルを現実のデータで適応させる
- システム同定: シミュレーションをできるだけ現実に近づける
- ドメインランダマイゼーション: 多様なランダム化されたシミュレーションバリエーションを作成し、現実が Deep Learning にとって単なる異なるバリエーションと見なされるようにする
NVIDIA Isaac Simでドメインランダマイゼーションを用いるには以下のようにするらしい。
この記事によると ドメインランダマイゼーションというのはさまざまな物理パラメータをランダムに設定することのよう。
学習時には毎回微妙に異なる条件を設定する、「ドメイン・ランダマイゼーション」と呼ばれる技術を用いているそうです。 この「微妙に異なる条件」というのは、立方体や背景の色が異なるというだけでなく、ロボットハンドの動作スピードや立方体の重さ、立方体とロボットハンドの間に起きる摩擦係数に至るまで、あらゆる要素をランダム化してDactylを学習させたとのこと。
シミュレーション x 機械学習の相性はやはりバチクソいいんだな。
シミュレーションと現実のギャップを埋める方法に工夫が今後は重要なのかも。
ドメインランダマイゼーションも適当にやると学習時間べらぼうに伸びるしなぁ。
ある程度はシミュレーションであとは現実データでfine tuneとかもありかも?
その際、破滅的な忘却を起こさないようにするための術とかも工夫しないとか...?
考えるべきことは多い...
Pytorch-Lightning, Gydra, Mlflow, Poetryを用いた機械学習テンプレート
ファイル構成はこんな感じ。
参考にしたい。
google 謹製のJAX / Flaxの特徴やサンプルコードの紹介
既存のDL系ライブラリとの比較もされている。
trainingが早くなるのであれば嬉しい。
今後の隆盛に期待
herokuではアプリの容量が500MB以下という制約があるが、普通のpytorchは重いというこでの、対応。
対応方法を見てみると「pytorchのGPU対応実装が大きく、これはheroku上では不要なため、CPU用のバージョンを利用すると良い」という記事を見つけた。(参照記事)
ただし、インストールするさいは --find-links オプションをして以下のように指定するする必要がある。 この記事を確認して、herokuでは torch==1.8.0+cpu を使用することにした。
# from https://zenn.dev/piruty/articles/2315fd9f09103b0738ff
requirements.txt
--find-links https://download.pytorch.org/whl/torch_stable.html
torch==1.8.0+cpu; python_full_version >= "3.6.2"
CPUにすると結構軽くなるんだなぁ
対策内容は以下
MIDDLEWAREの最初で、直近60秒以内に35回アクセスで404ページ。45回アクセスでブラックリスト。 ※但し、googlebotはホワイトリストに登録する ・ブラックリスト、ホワイトリスト登録で、ライン通知する ・ライントークンはここで発行 https://notify-bot.line.me/ja/
DjangoのMiddlewareMixinで以下のように実装したとこのと。
# from https://qiita.com/kin292929/items/92aa0f6f5e1fbca553ee
(settings.py)
# 一番最初になるように設定。
MIDDLEWARE = [
'base.middle.middle.AppMiddle',
'django.middleware.security.SecurityMiddleware',
....
# ライントークン
LINE_NOTIFY_TOKEN = 'XXXXXXXXXX'
# from https://qiita.com/kin292929/items/92aa0f6f5e1fbca553ee
(base.middle.middle.py)
class AppMiddle(MiddlewareMixin):
@staticmethod
def process_request(request):
# アクセスログ登録
if not settings.DEBUG: AccessLogService.insert_access_log(request)
else: return
request_util = RequestUtil(request)
ip = request_util.get_ip()
control_ip_list = cache.get(CONTROL_IP_LIST, CONTROL_IP_LIST_DEFAULT)
# ホワイトリスト
if ip in control_ip_list[WHITE_IP_LIST]: return
# ブラックリスト
if ip in control_ip_list[BLACK_IP_LIST]:
# ブラックリスト対象のIPで期限を確認
if cache.get('black_ip_' + ip) is None:
# 期限切れならブラックリストから除外
control_ip_list[BLACK_IP_LIST].remove(ip)
cache.set(CONTROL_IP_LIST, control_ip_list)
else:
raise Http404("not found")
ip_time_list = cache.get(ip, [])
time_temp = time.time()
# 60秒前の更新記録を削除
while ip_time_list and (time_temp - ip_time_list[-1]) > 60:
ip_time_list.pop()
ip_time_list.insert(0, time_temp)
cache.set(ip, ip_time_list, timeout=60)
if len(ip_time_list) > 45:
control_ip_list[BLACK_IP_LIST].append(ip)
# 該当IPで365日使用不可とする
cache.set(CONTROL_IP_LIST, control_ip_list, timeout=60 * 60 * 24 * 365)
cache.set('black_ip_' + ip, '', timeout=60 * 60 * 24 * 365)
lineUtils.send_line_notify(f'ブラックリストに登録しました IP:{ip}')
if len(ip_time_list) > 35:
if is_google_bot(ip):
# googlebotのIPをホワイトリストに登録する
control_ip_list[WHITE_IP_LIST].append(ip)
cache.set(CONTROL_IP_LIST, control_ip_list, timeout=60 * 60 * 24 * 365)
lineUtils.send_line_notify(f'ホワイトリストに登録しました IP:{ip}')
else:
raise Http404("not found")
# googlebotか判定する
def is_google_bot(ip):
# googlebotか判定
try:
host = subprocess.run(['host', ip], stdout=subprocess.PIPE).stdout.decode().replace('\n', '')
return host.endswith(('googlebot.com', 'google.com'))
except:
return False
公開して間もないサイトでも攻撃されることがあるようなので、しっかり対策していきたい。
表題の通り。
かなり詳細に書かれているので、記事をまんまコピーすれば、自分用のbotを作ることも可能。
ありがたい。
天望デッキおよびソラマチ商店街に設置したドローンポートから本ドローンを完全自律飛行させ、以下の項目を検証し、有用性を確認しました。 (1)スムーズなフロア間移動 (2)巡回ルート上のチェックポイント通過 (3)飛行中のリアルタイム映像配信 (4)飛行中のAIによる人物検知
自動巣走行にはVSLAMを用いているようだ。
ALSOKは、画像巡回を可能にするドローンを開発しました。本ドローンには以下の特徴があり、屋外およびGPSによる飛行が困難な屋内において人の手を介さない完全自律運用が可能です。 (1)搭載した4Kカメラを用いて全方向の画像処理(Visual SLAM※)をリアルタイムに行い、屋内での完全自律飛行が可能 (2)離隔距離最小50cmの狭い空間での飛行が可能 (3)障害物を自動で回避可能 (4)充電ポートに自動で離発着および充電が可能 (5)リアルタイムに遠隔地へカメラ映像を送信可能 (6)AIエッジコンピュータを搭載し、ドローン単体で人物などの検出が可能
屋内での自立飛行って結構すごいことでは?
Data Parallelの計算効率を維持しながら、計算に用いないパラメータを削除することによってメモリの使用効率を上げる方法がZero Redundancy Model(ZeRO)です。また、ZeROでパラメータを削減すると同時に、GPUのメモリを必要とする計算の一部をCPU及びメインメモリで計算する(CPU-Offload)ことに依ってGPUのメモリを節約するZeRO-Offloadがあります。
ここでは、各GPUの担当箇所以外のoptimizer stateの削除や必要のない勾配、必要のないパラメータなどを削除していく。
ZeRO-Offloadはパラメータの更新をCPU及びメインメモリ上で行うことにより、optimizer stateをGPUメモリから削減します。
パラメータの更新をCPUで行うことになり、計算速度の劣化が気になるところかもしれませんが、CPUでの処理の高速化のための工夫もしています。例えばSIMD命令による並列化、loop unrolling、OMPのmultithredingによるコア並列化などです。これにより、学習の1イテレーションあたりではほぼ同等の速度で実行できる
ブログ中では、bert-baseのがくしゅをTitan RTX 2枚のサーバーで学習させることに成功している。
(元の論文だと、16 TPUで学習)
今後でかいモデルを学習させる時に必要になってくるかもしれない。
議会中の“スマホいじり”を可視化するAIツール。機械学習と顔認証技術でサボり議員を暴露
コメント
一歩間違えればディストピアにつながりかねないが面白い。
出典
元記事