kiccho1101 / kaggle_global_wheat_detection

My solution for Kaggle Global Wheat Detection Competition.
1 stars 0 forks source link

Debug model forwarding #7

Closed kktsubota closed 4 years ago

kktsubota commented 4 years ago

When running notebook/cv.ipynb, we got a following error.

Fitter prepared. Device is cpu

2020-07-12T22:45:16.055624
LR: 0.0002
Train Step 0/675, summary_loss: 0.00000, time: 0.26056
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-4-42e71780afd2> in <module>
     20 )
     21 
---> 22 fitter.fit(train_loader, valid_loader)

~/Documents/src/github.com/kiccho1101/kaggle_global_wheat_detection/src/factories/fitter.py in fit(self, train_loader, valid_loader)
     59 
     60             start = time.time()
---> 61             summary_loss = self._train_one_epoch(train_loader)
     62 
     63             self.log(

~/Documents/src/github.com/kiccho1101/kaggle_global_wheat_detection/src/factories/fitter.py in _train_one_epoch(self, train_loader)
    109             self.optimizer.zero_grad()
    110 
--> 111             loss, _, _ = self.model(images, bboxes, labels)
    112             loss.backward()
    113             summary_loss.update(loss.detach().item(), batch_size)

~/.local/share/virtualenvs/kaggle_global_wheat_detection-Hxp-F21z/lib/python3.8/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    548             result = self._slow_forward(*input, **kwargs)
    549         else:
--> 550             result = self.forward(*input, **kwargs)
    551         for hook in self._forward_hooks.values():
    552             hook_result = hook(self, input, result)

TypeError: forward() takes 3 positional arguments but 4 were given
kktsubota commented 4 years ago

原因は特定できたと思う。ライブラリのバージョンの違いだと思う。

直接的な原因は src/factories/model.py が返すDetBenchTrainのインスタンスの forward

    def forward(self, x, target):
        class_out, box_out = self.model(x)
        cls_targets, box_targets, num_positives = self.anchor_labeler.batch_label_anchors(
            x.shape[0], target['bbox'], target['cls'])
        ....

と実装されているせい。bboxes, labelsじゃなくて target['bbox'], target['cls']。 ライブラリのバージョンの違いでAPIが変わっていると思われる。 他にも直す必要があるらしく、直しきるよりは、元コードが使っているバージョンを使う方が楽そう。 kaggle上で公開されているコードをダウンロードしてきてそれをレポジトリに直接追加 / submoduleで追加するのが良さそう。

kiccho1101 commented 4 years ago

なるほど、effdetのバージョンが違ってたのか。

kiccho1101 commented 4 years ago

ちなみに元コードはこれ https://www.kaggle.com/kiccho11/training-efficientdet/edit

kktsubota commented 4 years ago

そのURL見れない…

ちなみにefficientdetではなく、オレオレ実装のtimm-efficientdet-pytorchのようね。 →違うか、githubで公開されているのがそれってことか。 https://www.kaggle.com/shonenkov/timm-efficientdet-pytorch image

kiccho1101 commented 4 years ago

なるほど。

Screen Shot 2020-07-18 at 10 30 30

1つ目のセルのこのプロセスを実行する必要があるのかな

kiccho1101 commented 4 years ago

1つ目のセルのこのプロセスを実行する必要があるのかな

だめだった

kktsubota commented 4 years ago

ああ、誤解が生じてしまったかもしれない。 間違いなくAPIが変わっているのは effdet で実装されている DetBenchTrain です。 他のライブラリ (timm, pycocotools) のバージョンはどうある必要があるかは分からない。

kktsubota commented 4 years ago

言っていることが逆になって申し訳ないけどREADMEを見る限り最新モデルの方が0.012も高いので、これで頑張るならeffdet は最新のものを使うべきかも… https://github.com/rwightman/efficientdet-pytorch

D0 モデルに関して、Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] での評価が 2020-05-04: 0.324 2020-05-27: 0.331 2020-06-04: 0.336251

kiccho1101 commented 4 years ago

なるほどなるほど。ありがとう。理解した。 最新実装のDetBenchTrainを使うのがそんなに難しそうでなければ、そっちのほうが良さそうだね

kiccho1101 commented 4 years ago

その場合はforwardの際にモデルに入力する形式が変わってるから、https://github.com/kiccho1101/kaggle_global_wheat_detection/blob/f9552d46117a084be91ccb53e9aac89569589d9a/src/factories/fitter.py#L103-L111 この部分を変更するべきなのかな?

kktsubota commented 4 years ago

まずはそこ。 そこを直したらlossの形式も辞書になったよってエラーが出て、 そこを直したらdatasetでbboxが空のときのエラーが出た -- effdet関係あるんだろうか? という現状

kiccho1101 commented 4 years ago

なるほど。まあgithubからSOTAモデル持ってきて変数の形式合わせて動くようにするっていうプロセスはこれからも無限にありそうだから、やってみる〜 githubにレポジトリにサンプルコードみたいなのあるはずだから、探してみる

kktsubota commented 4 years ago

ありがとうー、よろしくー。

kiccho1101 commented 4 years ago

ごめん、もしわかったらでいいんだけど、だいたいこのファイル見ればわかりそうとかあったりする?

kktsubota commented 4 years ago

そこまで丁寧にメンテナンスされているようなライブラリではないため、 エラーが出たときにはライブラリの中身を見ながらデバッグするよりないのかも。

一応今出ているエラーの話はnotebookの下の方で話がされている模様。 ページ内で "I thought the effdet package you are using is a months old"で検索してもらえると。 https://www.kaggle.com/shonenkov/training-efficientdet

validationのimg_scaleのエラーの回避方法についても描いてあるけど、 関数をさかのぼってみると予測用にresizeしていたbounding boxサイズを元画像の解像度に戻すような処理をするパラメータであることが分かるため、元画像の解像度である1024 / 学習に利用した解像度 の値を入れておけばよいのかな… https://github.com/rwightman/efficientdet-pytorch/blob/5332cfac6d82f135c080986b8df116104450afdb/effdet/anchors.py#L219 関数をたどるときはgithubのsourcegraphが便利。

bounding boxがemptyな問題とかはこのあたりが参考になるのかも? https://www.kaggle.com/c/global-wheat-detection/discussion/153191

kiccho1101 commented 4 years ago

ありがとう! notebookのコメントの方法で,fitまではエラーなく回るようになった

kiccho1101 commented 4 years ago
Screen Shot 2020-07-18 at 12 45 54

コメント欄にあったちっちゃいバグ https://github.com/kiccho1101/kaggle_global_wheat_detection/pull/8/commits/38f14d7844dc32e23f346b3ba23c45cbbb330807