Toma0916 / GlobalWheatDetection

3 stars 0 forks source link

kfold #99

Closed kminoda closed 4 years ago

kminoda commented 4 years ago

train.pyの形だいぶ変えたので、多分conflictしそう

Toma0916 commented 4 years ago

ワイtrainほぼいじってないよ

kminoda commented 4 years ago

じゃおっけー

kminoda commented 4 years ago

だいたい終わった が、そもそもk-foldってこれであってますよね?

mlflowはとりあえず別々にこんな感じに出るようにしといた

スクリーンショット 2020-05-25 14 32 49
kminoda commented 4 years ago

ちなみにsourceのやつはまだ実装してないのでおいおい

Toma0916 commented 4 years ago

後で見ます!!

Toma0916 commented 4 years ago

見ます

Toma0916 commented 4 years ago

splitで乱数のシード入力してたけど、sklearnはそれない場合numpyの方のseedを参照するので消した。

source_splitの時にGroupKFoldを追加した。 ちゃんとsourceごとにいい感じに分けてくれてることを確認した。 引数でvalid_sourcesを受けるように書いてあるがこれはダミー変数でKfold時は全部sklearn.model_selection.GroupKFoldにまかせている。

def source_split_kfold(dataframe, fold_k, valid_sources=None):
    """
    ignore valid_sources when applying K-folds
    """
    image_source = dataframe[['image_id', 'source']].drop_duplicates()
    image_ids = image_source['image_id'].to_numpy()
    sources = image_source['source'].to_numpy()

    gkf = GroupKFold(n_splits=fold_k)
    split= gkf.split(image_ids, groups=sources)
    res = []
    for (train_ids, valid_ids) in strain_idsplit:
        res.append((image_ids[train_ids], image_ids[valid_ids]))
    import pdb; pdb.set_trace()
    return res
kminoda commented 4 years ago

神 後でみます

Toma0916 commented 4 years ago

もうちょっとだけ直しま

Toma0916 commented 4 years ago

スクリーンショット 2020-05-25 19 02 45

logのところ、foldをrunnameの後ろにつけるようにした。 そしてrunnameの後のランダム文字列ですがシード固定してるからコード同じだと実行時ごとに不変だったので、UTCの情報用いて擬似的に乱数を生成する感じにしました。

kminoda commented 4 years ago

すばらしい あとでみます

Toma0916 commented 4 years ago

大丈夫だったらマージしておけ