Toma0916 / GlobalWheatDetection

3 stars 0 forks source link

DANNの実装 #150

Closed Toma0916 closed 4 years ago

Toma0916 commented 4 years ago

torchvisionの中身を弄ってbackboneのfeatureもtrain時に返すようにする。

セットアップ手順

https://github.com/pytorch/vision/blob/master/torchvision/models/detection/generalized_rcnn.pyGeneralizedRCNNを次のように書き換える

    @torch.jit.unused
    def eager_outputs(self, losses, detections):
        # type: (Dict[str, Tensor], List[Dict[str, Tensor]]) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]
        if self.training:
            return losses

        return detections

    @torch.jit.unused
    def eager_outputs(self, losses, detections, features):
        # type: (Dict[str, Tensor], List[Dict[str, Tensor]]) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]
        if self.training:
            return losses, features

        return detections

にする。

forwardの後部の

        if torch.jit.is_scripting():
            if not self._has_warned:
                warnings.warn("RCNN always returns a (Losses, Detections) tuple in scripting")
                self._has_warned = True
            return (losses, detections)
        else:
            return self.eager_outputs(losses, detections)

        if torch.jit.is_scripting():
            if not self._has_warned:
                warnings.warn("RCNN always returns a (Losses, Detections) tuple in scripting")
                self._has_warned = True
            return (losses, detections)
        else:
            return self.eager_outputs(losses, detections, features['pool'])

にする。

Toma0916 commented 4 years ago

backboneにdomainのロス足せたわ

configとかロスのスケールとかはまだだけど学習回してみます

Toma0916 commented 4 years ago

サクッと見た感じちゃんとdomain loss (cross entropy lossの符号反転したやつ)いい感じに下がってそう

明日に期待

Toma0916 commented 4 years ago

@kminoda

オレンジ: 最近 trainしてたやつ 青: オレンジにdomain lodd 足した + cutmix抜いた(ドメインわからなくなると思ったので)

スクリーンショット 2020-06-25 10 39 26

スクリーンショット 2020-06-25 10 39 38

スクリーンショット 2020-06-25 10 39 45

domain loss はあまり汎化できてないんだけど、validの他のloss見る感じ結構可能性感じるぞ

Toma0916 commented 4 years ago

domain loss計算の部分、いま最低限の実装だからもっとやりようはあると思う

kminoda commented 4 years ago

実装早すぎて草

Toma0916 commented 4 years ago

たぶん同じ感じの雰囲気でOHEMも出来る

実装の速さには感動してほしい

Toma0916 commented 4 years ago

validのロス、trainに比べると高いけどSoftmax Cross Entropyで-2.5なのでどっちにしろクラスは全然当てれていない気もするね(上手く行っているという意味

kminoda commented 4 years ago

これちなみにvalid=arvalis_2?

Toma0916 commented 4 years ago

randomです

Toma0916 commented 4 years ago

arvalis_2でも回してみる

Toma0916 commented 4 years ago

回し始めた

Toma0916 commented 4 years ago

arvalis_2ってどれくらい行けば標準?

@kminoda

kminoda commented 4 years ago

俺はwbf込みで0.77くらいいってたかな

Toma0916 commented 4 years ago

その時のconfig探してみたけど見つからないから貼ってくれたりしますか?

kminoda commented 4 years ago

あごめんpush忘れた他

kminoda commented 4 years ago

https://github.com/Toma0916/GlobalWheatDetection/tree/master/experiment/harvest/koji/exp_006 これです(mlrunsの方も更新した)

Toma0916 commented 4 years ago

さんこす それと対照実験してみるかな

Toma0916 commented 4 years ago

config_all.jsonってなんだっけ?もともとあった?

configと違うところもあるみたいだけど

kminoda commented 4 years ago

俺が勝手につけた submit用に(ほぼ)全データで学習したやつ

Toma0916 commented 4 years ago

CycleGANもコミコミで一旦mergeしますね

Toma0916 commented 4 years ago

@kminoda そいえばこれ、一番上の変更しておいて!

kminoda commented 4 years ago

DANN回す時ってことよね おっけー

Toma0916 commented 4 years ago

使わなくても変えておかないとpipeline動かんと思う

kminoda commented 4 years ago

マジか おっけー