yoheikikuta / paper-reading

Notes about papers I read (in Japanese)
156 stars 4 forks source link

[2017] Focal Loss for Dense Object Detection #19

Open yoheikikuta opened 5 years ago

yoheikikuta commented 5 years ago

論文リンク

https://arxiv.org/abs/1708.02002

公開日(yyyy/mm/dd)

2017/08/07

概要

Focal Loss を導入して RetinaNet という one stage の高性能の object detection model を作ったという話。 one stage detector は objectness を測る直接の仕組みがないため大量の anchor に対して loss を計算するが、classification を考えるとこの anchor のうちほとんどが背景のような簡単な sample になっている。しかしそのような簡単な sample でも数が多いので total loss に占める割合は大きくなってしまい、興味ある sample の loss が埋もれてしまうという問題がある。 そこで、モデルの出力確率値が大きいものに関しては cross entropy の値が小さくなるような factor をかける Focal Loss を提案し、効率的に学習がなされるようにして、two stage detector をも上回るような結果を出した。Focal Loss は $ - (1-p_t)^γ \ln(p_t) $ のような形。 モデルの中身は backbone network から特徴量を抽出して Feature Pyramid Network を構築し、それぞれの pyramid の特徴量で classification と bbox regression を実施するもの。

yoheikikuta commented 5 years ago

最近の object detection 系の話をちゃんとフォローしてなかったので、その辺の知識をつけるため。 どこかで概要を見かけて、 loss function の設定に工夫があるみたいなのにちょっと興味があるので、ちゃんと自分で読んでみることにした。

yoheikikuta commented 5 years ago

まず、two stage detector と one stage detector を整理しておく。

yoheikikuta commented 5 years ago

今回の RetinaNet は one stage detector だ。

上記の性質を考えた時に問題となる点として、「大量の anchor に対して classification をして loss を算出するが、特に何も写ってないような easy な anchor が多くて、そいつらによって loss が dominate されるので意味ある loss の計算にならない」という点が挙げられる。

ちょっと ↑ の書き方だけだと分かりづらいかもしれないので、もう少し具体的な想像をしておく。

例えば anchor が 100 個あるときの loss を考える。 anchor それぞれの cross entropy の和が loss になるわけだが、このうち 90 個は全然 object を写してないただの背景の場合にどうなるだろうか。 残りの 10 個がちゃんと object が写っているものでこれは分類が難しいのでモデルの出力確率値が低くて(例えば 0.2)、 10 (- ln(0.2)) = 16 くらい loss に寄与するとする。 一方で 90 個の方は簡単なのでモデルの確率出力値が高い(例えば 0.8)としよう。この場合 loss への寄与は 90 (- ln(0.8)) = 20 くらいになる。こっちの方が高くなってしまう。 普通に anchor を作るとさらに比は悪くなって 1:1000 とかにもなるので、そこまで意味のない anchor に対する loss のせいで学習が我々の意図したものと食い違ってしまう。

先行研究では難易度が高い anchor を集めてくる Online Hard Example Mining などの手法が提案されていた。

この論文では、それと逆の発想をして、簡単な anchor の loss への寄与を抑えるように loss function を変形するだけで目的を達成しようというものだ。

yoheikikuta commented 5 years ago

ということでこの論文の主題。 classification の loss を通常の Cross Entropy (CE) から Focal Loss (FL) に変更するというものである。 なんと $ - (1-p_t)^γ \ln(p_t) $ として、p が 1 に近い場合に値が小さくなるように γ の値を大きくする、というシンプルなものである。 ここで、p_t は二値分類の y=1 のときが p で y=-1 のときが (1-p) という定義である。

γ の値が大きい場合に CE との相対的な差が大きくなって、 easy anchor の loss への寄与がより小さくなっていることが見て取れる。

これによって、簡単な classification の寄与が抑えられるので、興味ある anchor の classification をより効率的に学習していけるという算段である。

yoheikikuta commented 5 years ago

より tuning した形として overall factor として overall factor を入れた次の形を採用している。 α_t は y=1 のときが α で y=-1 のときが 1-α を取る。

この α は Focall Loss の意味ではそこまで重要でなく、これを導入すると少し性能がよくなるという程度のもの。

パラメタの値としては γ=2, α=0.25 を採用している。

classification のときは、各 anchor に対して class 毎に sigmoid で確率を計算して Focal Loss を計算し、全部の和をとる(正確には normalization として positive anchor の数で割る)という処理をして全体の loss を計算する。

yoheikikuta commented 5 years ago

モデルの詳細は以下の要素が分かっていればよく、しかもそれらは従来から知られているものである。

この図だけだとモデルの詳細としては色々落ちている部分がある。

まず、ResNet の方は最初に Conv2D や BN があり、そのあとに residual block の塊が例えば ResNet50 の場合は [3, 4, 6, 3] という数連なっている構造になっている。 ここでは residual block を C2, C3, C4, C5 としたときの C3, C4, C5 を使っている。

feature pyramid net の方は C5 -> P5 を Conv2D をかますことで作り、それを feature pyramid のよくある方法で P5 -> P4 -> P3 と作っている。 ここは実装を見た方が早そう。 fizyr/keras-retinanet では ここ が該当箇所になっている。 また、C5 からさらに 3*3 Conv2D などをかましていって P6, P7 も作っている。

あとは P3-P7 のそれぞれで図の右にあるように class subnet と box subnet を生やして classification と bbox regression を実施するということものになっている。

実際にどういう anchor を作るかというのは従来通りの手法。

あとはどの anchor に対して classification や regression を考えるのか、ということに関して classification のスコアでの threshold や Non Max Suppression などを使っている。 classification のスコアを使っているのは threshold は one stage detector のところで説明した事情である。 さらに Non Max Suppression で被りが大きい bbox を落として、さらに ground truth との IoU が大きいものが regression の計算対象になる。 この辺りの気持ちが分かっていれば、あとは 実装 を見れば細かいところまで理解できる。

yoheikikuta commented 5 years ago

しかしこうして書いてみるとちゃんと理解しようと思ったら結構大変だよな。

自分は昔 Faster R-CNN をしっかり読んだのですんなり理解できたが、object detection を勉強しようという人がいきなりこの手の論文から着手すると理解するまでに結構頑張る必要がありそう。

まあ object detection の必須構成要素なので、一回ちゃんと理解すれば two stage detector の Mask R-CNN とかもそこまで苦労せずに理解できると思う(自分はまだちゃんと読んでないけど、眺めた感じそう思った)。

yoheikikuta commented 5 years ago

備忘のためのメモを残しておく。

上述のように backbone network から C3, C4, C5 を取得しているが、実装を見るとその部分がちゃんと取得できているのか理解するまでにちょっと苦戦したのでそのメモ。


これ久しぶりに見たら特に難しいこともなく最初に読んだ時に何に苦戦したのか覚えてない...

yoheikikuta commented 5 years ago

あとは実験結果。 いくつか実験をしているが、抜粋。

yoheikikuta commented 5 years ago

性能はなかなか良くて、one stage detector の中では最上位と思っていいだろう。 自分で学習させて使ってみた感触としては、YOLOv3 と同等以上かなという感触である。