Z-ZHHH / CVPR23-DML

8 stars 1 forks source link

如何重現cifar100的結果 #3

Open jspss95082 opened 4 months ago

jspss95082 commented 4 months ago

您好,目前我的研究需要重新訓練出DML+在cifar100的結果,但是使用你們的code並使用Focal loss及Center loss訓練wide resnet後test的結果與你們提供的ckpt test結果相差甚遠,可以請問你們是怎麼訓練出你們提供的weight的嗎?

Z-ZHHH commented 4 months ago

您好,很抱歉造成困擾,我在代碼整理中沒有顧及到訓練代碼,其中訓練代碼需要稍作修改。 主要問題可能在Wide ResNet模型加載過程中默認是

def __init__(self, depth, num_classes, widen_factor=1, dropRate=0.0, use_norm=False, feature_norm=False)

是沒有使用任何的歸一化操作,兩個參數的含義分別是對分類器歸一化和對特征歸一化,可以參考推理代碼對train.py作修改 https://github.com/Z-ZHHH/CVPR23-DML/blob/68c117f0e2960d1ebe5e6e3e40cafe3e25f46fb2/test_DML%2B.py#L112C1-L132C1 例如,centerloss訓練需要use_norm=True, feature_norm=True等。

jspss95082 commented 4 months ago

您好,很抱歉造成困擾,我在代碼整理中沒有顧及到訓練代碼,其中訓練代碼需要稍作修改。 主要問題可能在Wide ResNet模型加載過程中默認是

def __init__(self, depth, num_classes, widen_factor=1, dropRate=0.0, use_norm=False, feature_norm=False)

是沒有使用任何的歸一化操作,兩個參數的含義分別是對分類器歸一化和對特征歸一化,可以參考推理代碼對train.py作修改 https://github.com/Z-ZHHH/CVPR23-DML/blob/68c117f0e2960d1ebe5e6e3e40cafe3e25f46fb2/test_DML%2B.py#L112C1-L132C1 例如,centerloss訓練需要use_norm=True, feature_norm=True等。

您好,首先先感謝您的快速回覆

關於你提到的這部分,我先前就有對這部分做修改,我有將centerloss的分類器部分設use_norm=True, feature_norm=True,且focal loss分類器use_norm=True, feature_norm=False,依照你們在test_DML+.py的設定。另外我也有發現在common/FocalLoss.py 的forward部分也有錯誤,但是在修正以上問題後訓練依舊無法在200epochs時達到與你們相近的結果,想請問你們有沒有其他部分與github上的code有不同?謝謝

Z-ZHHH commented 4 months ago

測試代碼是完全一樣的,訓練代碼我記憶中是一樣的。 結果不同可能是超參數的原因,在實驗過程中可調節的超參數包括Focal Loss的 \gamma 以及CenterLoss的權重值(如 reference ),另外還包括訓練的輪數,可能也和種子點相關。 在這段時間DDL後我會嘗試重新整理訓練代碼,對代碼的問題感到抱歉。 短時的修改有僅使用CE損失餘弦分類器會對OOD基線方法有一定的促進(或者使用logitnorm訓練代碼,DML的代碼結構即基於LogitNorm),結合CenterLoss和FocalLoss的訓練部分可能會需要超參的調整。

jspss95082 commented 4 months ago

測試代碼是完全一樣的,訓練代碼我記憶中是一樣的。 結果不同可能是超參數的原因,在實驗過程中可調節的超參數包括Focal Loss的 \gamma 以及CenterLoss的權重值(如 reference ),另外還包括訓練的輪數,可能也和種子點相關。 在這段時間DDL後我會嘗試重新整理訓練代碼,對代碼的問題感到抱歉。 短時的修改有僅使用CE損失餘弦分類器會對OOD基線方法有一定的促進(或者使用logitnorm訓練代碼,DML的代碼結構即基於LogitNorm),結合CenterLoss和FocalLoss的訓練部分可能會需要超參的調整。

謝謝您的回覆,想請問關於Focal Loss的 \gamma 以及CenterLoss的權重值,你們有任何印象要往哪裡修改嗎,另外使用CE損失餘弦分類器的部分,應該要取代focal loss的classifier還是center loss的classifier呢?謝謝

Z-ZHHH commented 4 months ago

Focal Loss的 \gamma 我试验了1、2、3、4共四个值,CenterLoss的權重值使用了0-0.5中的多个值,多是0.1以下的值,CE損失餘弦分類器的部分指只使用CE损失训练,模型使用余弦分类器,从头训练。也可以用repo里的两个ckpt作初始化。