TideDancer / interspeech21_emotion


Cannot reproduce the result, please help #5

Closed Gpwner closed 2 years ago

Gpwner commented 2 years ago

As stated in the README:

> For each fold, use the other 9 sessions as training, and test on the selected session. For example, for fold 01F, we use 01F as the test set and the remaining 9 sessions as the training set.

Training a model on 9 sessions takes a long time when I only have 2 Nvidia V100s, so I changed the code in run_emotion.py from:

train_dataset = datasets.load_dataset('csv', data_files='iemocap/iemocap_' + data_args.split_id + '.train.csv', cache_dir=model_args.cache_dir)['train']
val_dataset = datasets.load_dataset('csv', data_files='iemocap/iemocap_' + data_args.split_id + '.test.csv', cache_dir=model_args.cache_dir)['train']

to

train_dataset = datasets.load_dataset('csv', data_files='iemocap/iemocap_*.train.csv',
                                              cache_dir=model_args.cache_dir)['train']
val_dataset = datasets.load_dataset('csv', data_files='iemocap/iemocap_*.test.csv',
                                            cache_dir=model_args.cache_dir)['train']

and kept only these files in the iemocap directory:

iemocap_01F.test.csv
iemocap_01F.train.csv
iemocap_01M.test.csv
iemocap_01M.train.csv
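Note that with only those four files present, the glob patterns match both sessions for each split, so this modified run trains on the 01F and 01M training sets together and evaluates on the union of both test sets, rather than following the README's leave-one-session-out protocol. A small stdlib sketch of what the patterns select (file names taken from the list above):

```python
import fnmatch

# The four files kept in iemocap/ (from the list above).
files = ["iemocap_01F.test.csv", "iemocap_01F.train.csv",
         "iemocap_01M.test.csv", "iemocap_01M.train.csv"]

# Same patterns as the modified load_dataset calls.
train_files = fnmatch.filter(files, "iemocap_*.train.csv")
test_files = fnmatch.filter(files, "iemocap_*.test.csv")

assert train_files == ["iemocap_01F.train.csv", "iemocap_01M.train.csv"]
assert test_files == ["iemocap_01F.test.csv", "iemocap_01M.test.csv"]
```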

Then I call the predict API at the end like this:

print("######################################Starting to predict on test set #######################################")
predictions, label_ids, metrics = trainer.predict(val_dataset)
print(metrics)

Here is my run.sh:

export MODEL=wav2vec2-base
export TOKENIZER=wav2vec2-base
export ALPHA=0.1
export LR=1e-5
export ACC=4 # batch size * acc = 8
export WORKER_NUM=32

python run_emotion.py \
--output_dir=output/tmp \
--cache_dir=cache/ \
--num_train_epochs=100 \
--per_device_train_batch_size="8" \
--per_device_eval_batch_size="4" \
--gradient_accumulation_steps=$ACC \
--alpha $ALPHA \
--dataset_name emotion \
--split_id 01F \
--evaluation_strategy="steps" \
--save_total_limit="1" \
--save_steps="500" \
--eval_steps="500" \
--logging_steps="50" \
--logging_dir="log" \
--learning_rate=$LR \
--model_name_or_path=facebook/$MODEL \
--tokenizer facebook/$TOKENIZER \
--fp16 \
--preprocessing_num_workers=$WORKER_NUM \
--gradient_checkpointing true \
--dataloader_num_workers $WORKER_NUM \
--overwrite_output_dir
# --freeze_feature_extractor \

Here is the log of my loss:

{'loss':985.5858,'learning_rate':9.978205128205129e-06,'epoch':0.32}
{'loss':235.7915,'learning_rate':9.946153846153847e-06,'epoch':0.64}
{'loss':178.5328,'learning_rate':9.914102564102565e-06,'epoch':0.96}
{'loss':163.2667,'learning_rate':9.882051282051283e-06,'epoch':1.28}
{'loss':161.7683,'learning_rate':9.85e-06,'epoch':1.6}
{'loss':152.2902,'learning_rate':9.817948717948718e-06,'epoch':1.92}
{'loss':144.7367,'learning_rate':9.785897435897436e-06,'epoch':2.24}
{'loss':143.5392,'learning_rate':9.753846153846154e-06,'epoch':2.56}
{'loss':141.25,'learning_rate':9.721794871794872e-06,'epoch':2.88}
{'loss':138.1985,'learning_rate':9.689743589743592e-06,'epoch':3.21}
{'eval_loss':64.0247573852539,'eval_acc':0.3327188940092166,'eval_wer':1.0,'eval_correct':361,'eval_total':1085,'eval_strlen':1085,'eval_runtime':22.6289,'eval_samples_per_second':47.947,'eval_steps_per_second':6.01,'epoch':3.21}
{'loss':138.976,'learning_rate':9.65769230769231e-06,'epoch':3.53}
{'loss':137.0047,'learning_rate':9.625641025641026e-06,'epoch':3.85}
{'loss':135.6556,'learning_rate':9.593589743589744e-06,'epoch':4.17}
{'loss':135.4492,'learning_rate':9.561538461538462e-06,'epoch':4.49}
{'loss':134.6317,'learning_rate':9.52948717948718e-06,'epoch':4.81}
{'loss':137.0918,'learning_rate':9.497435897435898e-06,'epoch':5.13}
{'loss':136.7719,'learning_rate':9.465384615384615e-06,'epoch':5.45}
{'loss':131.8289,'learning_rate':9.433333333333335e-06,'epoch':5.77}
{'loss':136.6496,'learning_rate':9.401282051282053e-06,'epoch':6.09}
{'loss':132.0547,'learning_rate':9.369230769230771e-06,'epoch':6.41}
{'eval_loss':62.66316223144531,'eval_acc':0.34838709677419355,'eval_wer':1.0,'eval_correct':378,'eval_total':1085,'eval_strlen':1085,'eval_runtime':22.2954,'eval_samples_per_second':48.665,'eval_steps_per_second':6.1,'epoch':6.41}
{'loss':133.7481,'learning_rate':9.337179487179487e-06,'epoch':6.73}
{'loss':134.8716,'learning_rate':9.305128205128205e-06,'epoch':7.05}
{'loss':131.0463,'learning_rate':9.273076923076923e-06,'epoch':7.37}
{'loss':135.5874,'learning_rate':9.24102564102564e-06,'epoch':7.69}
{'loss':135.7136,'learning_rate':9.20897435897436e-06,'epoch':8.01}
{'loss':134.2747,'learning_rate':9.176923076923078e-06,'epoch':8.33}
{'loss':130.7124,'learning_rate':9.144871794871796e-06,'epoch':8.65}
{'loss':135.2457,'learning_rate':9.112820512820514e-06,'epoch':8.97}
{'loss':132.1405,'learning_rate':9.080769230769232e-06,'epoch':9.29}
{'loss':131.7786,'learning_rate':9.04871794871795e-06,'epoch':9.62}
{'eval_loss':61.31118392944336,'eval_acc':0.3622119815668203,'eval_wer':1.0,'eval_correct':393,'eval_total':1085,'eval_strlen':1085,'eval_runtime':22.8489,'eval_samples_per_second':47.486,'eval_steps_per_second':5.952,'epoch':9.62}
{'loss':132.8303,'learning_rate':9.016666666666666e-06,'epoch':9.94}
{'loss':126.1847,'learning_rate':8.984615384615386e-06,'epoch':10.26}
{'loss':126.1306,'learning_rate':8.952564102564104e-06,'epoch':10.58}
{'loss':119.6482,'learning_rate':8.920512820512822e-06,'epoch':10.9}
{'loss':114.1183,'learning_rate':8.88846153846154e-06,'epoch':11.22}
{'loss':107.5173,'learning_rate':8.856410256410257e-06,'epoch':11.54}
{'loss':100.6719,'learning_rate':8.824358974358975e-06,'epoch':11.86}
{'loss':93.2497,'learning_rate':8.792307692307693e-06,'epoch':12.18}
{'loss':92.375,'learning_rate':8.760256410256411e-06,'epoch':12.5}
{'loss':87.1823,'learning_rate':8.728205128205129e-06,'epoch':12.82}
{'eval_loss':37.06529998779297,'eval_acc':0.376036866359447,'eval_wer':0.8989448400102943,'eval_correct':408,'eval_total':1085,'eval_strlen':1085,'eval_runtime':22.7034,'eval_samples_per_second':47.79,'eval_steps_per_second':5.99,'epoch':12.82}
{'loss':81.6867,'learning_rate':8.696153846153847e-06,'epoch':13.14}
{'loss':82.2944,'learning_rate':8.664102564102565e-06,'epoch':13.46}
{'loss':75.9813,'learning_rate':8.632051282051283e-06,'epoch':13.78}
{'loss':72.7238,'learning_rate':8.6e-06,'epoch':14.1}
{'loss':70.9259,'learning_rate':8.567948717948719e-06,'epoch':14.42}
{'loss':67.0848,'learning_rate':8.535897435897436e-06,'epoch':14.74}
{'loss':66.4381,'learning_rate':8.503846153846154e-06,'epoch':15.06}
{'loss':65.7071,'learning_rate':8.471794871794872e-06,'epoch':15.38}
{'loss':59.4502,'learning_rate':8.43974358974359e-06,'epoch':15.71}
{'loss':60.145,'learning_rate':8.407692307692308e-06,'epoch':16.03}
{'eval_loss':25.065723419189453,'eval_acc':0.3723502304147465,'eval_wer':0.6378999742643905,'eval_correct':404,'eval_total':1085,'eval_strlen':1085,'eval_runtime':23.3066,'eval_samples_per_second':46.553,'eval_steps_per_second':5.835,'epoch':16.03}
{'loss':60.5066,'learning_rate':8.375641025641026e-06,'epoch':16.35}
{'loss':57.3628,'learning_rate':8.343589743589744e-06,'epoch':16.67}
{'loss':56.7665,'learning_rate':8.311538461538462e-06,'epoch':16.99}
{'loss':54.0981,'learning_rate':8.27948717948718e-06,'epoch':17.31}
{'loss':54.6994,'learning_rate':8.247435897435898e-06,'epoch':17.63}
{'loss':53.7669,'learning_rate':8.215384615384616e-06,'epoch':17.95}
{'loss':51.7579,'learning_rate':8.183333333333333e-06,'epoch':18.27}
{'loss':51.0638,'learning_rate':8.151282051282053e-06,'epoch':18.59}
{'loss':48.5319,'learning_rate':8.119230769230771e-06,'epoch':18.91}
{'loss':47.1916,'learning_rate':8.087179487179487e-06,'epoch':19.23}
{'eval_loss':20.137306213378906,'eval_acc':0.37142857142857144,'eval_wer':0.5117954876898001,'eval_correct':403,'eval_total':1085,'eval_strlen':1085,'eval_runtime':23.883,'eval_samples_per_second':45.43,'eval_steps_per_second':5.694,'epoch':19.23}
{'loss':47.5837,'learning_rate':8.055128205128205e-06,'epoch':19.55}
{'loss':46.6194,'learning_rate':8.023076923076923e-06,'epoch':19.87}
{'loss':45.7465,'learning_rate':7.991025641025641e-06,'epoch':20.19}
{'loss':45.2848,'learning_rate':7.958974358974359e-06,'epoch':20.51}
{'loss':44.397,'learning_rate':7.926923076923078e-06,'epoch':20.83}
{'loss':43.1611,'learning_rate':7.894871794871796e-06,'epoch':21.15}
{'loss':44.6381,'learning_rate':7.862820512820514e-06,'epoch':21.47}
{'loss':43.4354,'learning_rate':7.830769230769232e-06,'epoch':21.79}
{'loss':41.3708,'learning_rate':7.79871794871795e-06,'epoch':22.12}
{'loss':40.1245,'learning_rate':7.766666666666666e-06,'epoch':22.44}
{'eval_loss':17.12261199951172,'eval_acc':0.3695852534562212,'eval_wer':0.43441708844471133,'eval_correct':401,'eval_total':1085,'eval_strlen':1085,'eval_runtime':23.3053,'eval_samples_per_second':46.556,'eval_steps_per_second':5.836,'epoch':22.44}
{'loss':38.9329,'learning_rate':7.734615384615384e-06,'epoch':22.76}
{'loss':40.0727,'learning_rate':7.702564102564102e-06,'epoch':23.08}
{'loss':40.7538,'learning_rate':7.670512820512822e-06,'epoch':23.4}
{'loss':37.113,'learning_rate':7.63846153846154e-06,'epoch':23.72}
{'loss':37.875,'learning_rate':7.6064102564102575e-06,'epoch':24.04}
{'loss':39.5066,'learning_rate':7.574358974358975e-06,'epoch':24.36}
{'loss':38.3877,'learning_rate':7.542307692307693e-06,'epoch':24.68}
{'loss':37.8451,'learning_rate':7.510256410256411e-06,'epoch':25.0}
{'loss':36.6671,'learning_rate':7.478205128205129e-06,'epoch':25.32}
{'loss':35.7106,'learning_rate':7.446153846153846e-06,'epoch':25.64}
{'eval_loss':15.00920581817627,'eval_acc':0.38341013824884795,'eval_wer':0.37745560607360384,'eval_correct':416,'eval_total':1085,'eval_strlen':1085,'eval_runtime':22.6753,'eval_samples_per_second':47.849,'eval_steps_per_second':5.998,'epoch':25.64}
{'loss':36.3688,'learning_rate':7.414102564102564e-06,'epoch':25.96}
{'loss':35.7812,'learning_rate':7.382051282051282e-06,'epoch':26.28}
{'loss':36.6276,'learning_rate':7.350000000000001e-06,'epoch':26.6}
{'loss':34.574,'learning_rate':7.317948717948719e-06,'epoch':26.92}
{'loss':33.1534,'learning_rate':7.285897435897437e-06,'epoch':27.24}
{'loss':33.5777,'learning_rate':7.2538461538461545e-06,'epoch':27.56}
{'loss':33.5143,'learning_rate':7.221794871794872e-06,'epoch':27.88}
{'loss':33.2964,'learning_rate':7.189743589743591e-06,'epoch':28.21}
{'loss':29.894,'learning_rate':7.157692307692309e-06,'epoch':28.53}
{'loss':32.6357,'learning_rate':7.125641025641026e-06,'epoch':28.85}
{'eval_loss':13.63546371459961,'eval_acc':0.37511520737327186,'eval_wer':0.33816590889594234,'eval_correct':407,'eval_total':1085,'eval_strlen':1085,'eval_runtime':22.2424,'eval_samples_per_second':48.781,'eval_steps_per_second':6.114,'epoch':28.85}
{'loss':30.7197,'learning_rate':7.093589743589744e-06,'epoch':29.17}
{'loss':32.0781,'learning_rate':7.061538461538462e-06,'epoch':29.49}
{'loss':31.6058,'learning_rate':7.02948717948718e-06,'epoch':29.81}
{'loss':31.9951,'learning_rate':6.997435897435898e-06,'epoch':30.13}
{'loss':30.5193,'learning_rate':6.965384615384616e-06,'epoch':30.45}
{'loss':31.2168,'learning_rate':6.9333333333333344e-06,'epoch':30.77}
{'loss':30.5628,'learning_rate':6.901282051282052e-06,'epoch':31.09}
{'loss':28.9461,'learning_rate':6.86923076923077e-06,'epoch':31.41}
{'loss':29.4999,'learning_rate':6.837179487179487e-06,'epoch':31.73}
{'loss':29.8131,'learning_rate':6.805128205128205e-06,'epoch':32.05}
{'eval_loss':12.406923294067383,'eval_acc':0.3622119815668203,'eval_wer':0.3026507677790169,'eval_correct':393,'eval_total':1085,'eval_strlen':1085,'eval_runtime':23.3428,'eval_samples_per_second':46.481,'eval_steps_per_second':5.826,'epoch':32.05}
{'loss':29.4344,'learning_rate':6.773076923076923e-06,'epoch':32.37}
{'loss':28.6873,'learning_rate':6.741025641025641e-06,'epoch':32.69}
{'loss':26.8133,'learning_rate':6.70897435897436e-06,'epoch':33.01}
{'loss':28.489,'learning_rate':6.676923076923078e-06,'epoch':33.33}
{'loss':29.5473,'learning_rate':6.644871794871796e-06,'epoch':33.65}
{'loss':26.3075,'learning_rate':6.6128205128205135e-06,'epoch':33.97}
{'loss':26.6269,'learning_rate':6.580769230769231e-06,'epoch':34.29}
{'loss':26.9115,'learning_rate':6.548717948717949e-06,'epoch':34.62}
{'loss':27.4964,'learning_rate':6.516666666666666e-06,'epoch':34.94}
{'loss':27.6296,'learning_rate':6.484615384615385e-06,'epoch':35.26}
{'eval_loss':11.483007431030273,'eval_acc':0.3640552995391705,'eval_wer':0.2797460753195505,'eval_correct':395,'eval_total':1085,'eval_strlen':1085,'eval_runtime':22.1875,'eval_samples_per_second':48.901,'eval_steps_per_second':6.13,'epoch':35.26}
{'loss':25.1496,'learning_rate':6.452564102564103e-06,'epoch':35.58}
{'loss':27.7401,'learning_rate':6.420512820512821e-06,'epoch':35.9}
{'loss':26.2247,'learning_rate':6.388461538461539e-06,'epoch':36.22}
{'loss':25.9773,'learning_rate':6.356410256410257e-06,'epoch':36.54}
{'loss':26.1919,'learning_rate':6.324358974358975e-06,'epoch':36.86}
{'loss':25.4572,'learning_rate':6.2923076923076934e-06,'epoch':37.18}
{'loss':25.2353,'learning_rate':6.260256410256411e-06,'epoch':37.5}
{'loss':24.2015,'learning_rate':6.228205128205129e-06,'epoch':37.82}
{'loss':25.4023,'learning_rate':6.196153846153846e-06,'epoch':38.14}
{'loss':24.6039,'learning_rate':6.164102564102564e-06,'epoch':38.46}
{'eval_loss':10.490436553955078,'eval_acc':0.3576036866359447,'eval_wer':0.25306682679934805,'eval_correct':388,'eval_total':1085,'eval_strlen':1085,'eval_runtime':22.5484,'eval_samples_per_second':48.119,'eval_steps_per_second':6.031,'epoch':38.46}
{'loss':23.6292,'learning_rate':6.132051282051282e-06,'epoch':38.78}
{'loss':24.9531,'learning_rate':6.1e-06,'epoch':39.1}
{'loss':23.8321,'learning_rate':6.067948717948719e-06,'epoch':39.42}
{'loss':24.8373,'learning_rate':6.035897435897437e-06,'epoch':39.74}
{'loss':24.2936,'learning_rate':6.003846153846155e-06,'epoch':40.06}
{'loss':23.3554,'learning_rate':5.9717948717948725e-06,'epoch':40.38}
{'loss':23.855,'learning_rate':5.9397435897435904e-06,'epoch':40.71}
{'loss':23.7896,'learning_rate':5.907692307692308e-06,'epoch':41.03}
{'loss':23.8413,'learning_rate':5.875641025641025e-06,'epoch':41.35}
{'loss':22.8959,'learning_rate':5.843589743589744e-06,'epoch':41.67}
{'eval_loss':10.052651405334473,'eval_acc':0.3723502304147465,'eval_wer':0.2353092562408853,'eval_correct':404,'eval_total':1085,'eval_strlen':1085,'eval_runtime':22.1869,'eval_samples_per_second':48.903,'eval_steps_per_second':6.13,'epoch':41.67}
{'loss':23.2867,'learning_rate':5.811538461538462e-06,'epoch':41.99}
{'loss':23.4061,'learning_rate':5.77948717948718e-06,'epoch':42.31}
{'loss':22.8977,'learning_rate':5.747435897435898e-06,'epoch':42.63}
{'loss':22.8071,'learning_rate':5.715384615384616e-06,'epoch':42.95}
{'loss':22.3049,'learning_rate':5.683333333333334e-06,'epoch':43.27}
{'loss':22.8277,'learning_rate':5.6512820512820524e-06,'epoch':43.59}
{'loss':22.2595,'learning_rate':5.61923076923077e-06,'epoch':43.91}
{'loss':21.8767,'learning_rate':5.587179487179487e-06,'epoch':44.23}
{'loss':21.0823,'learning_rate':5.555128205128205e-06,'epoch':44.55}
{'loss':22.6813,'learning_rate':5.523076923076923e-06,'epoch':44.87}
{'eval_loss':9.504764556884766,'eval_acc':0.3631336405529954,'eval_wer':0.21394870035172,'eval_correct':394,'eval_total':1085,'eval_strlen':1085,'eval_runtime':22.4444,'eval_samples_per_second':48.342,'eval_steps_per_second':6.059,'epoch':44.87}
{'loss':21.193,'learning_rate':5.491025641025641e-06,'epoch':45.19}
{'loss':21.6409,'learning_rate':5.458974358974359e-06,'epoch':45.51}
{'loss':21.7376,'learning_rate':5.426923076923078e-06,'epoch':45.83}
{'loss':22.2447,'learning_rate':5.394871794871796e-06,'epoch':46.15}
{'loss':20.673,'learning_rate':5.362820512820514e-06,'epoch':46.47}
{'loss':22.2161,'learning_rate':5.3307692307692315e-06,'epoch':46.79}
{'loss':21.0008,'learning_rate':5.2987179487179494e-06,'epoch':47.12}
{'loss':21.8974,'learning_rate':5.2666666666666665e-06,'epoch':47.44}
{'loss':20.4602,'learning_rate':5.234615384615384e-06,'epoch':47.76}
{'loss':20.512,'learning_rate':5.202564102564102e-06,'epoch':48.08}
{'eval_loss':8.960915565490723,'eval_acc':0.38525345622119817,'eval_wer':0.20159560778931115,'eval_correct':418,'eval_total':1085,'eval_strlen':1085,'eval_runtime':22.1423,'eval_samples_per_second':49.001,'eval_steps_per_second':6.142,'epoch':48.08}
{'loss':19.1649,'learning_rate':5.170512820512821e-06,'epoch':48.4}
{'loss':20.5475,'learning_rate':5.138461538461539e-06,'epoch':48.72}
{'loss':20.9567,'learning_rate':5.106410256410257e-06,'epoch':49.04}
{'loss':20.4606,'learning_rate':5.074358974358975e-06,'epoch':49.36}
{'loss':20.111,'learning_rate':5.042307692307693e-06,'epoch':49.68}
{'loss':20.2693,'learning_rate':5.0102564102564115e-06,'epoch':50.0}
{'loss':20.4831,'learning_rate':4.9782051282051285e-06,'epoch':50.32}
{'loss':20.6807,'learning_rate':4.9461538461538464e-06,'epoch':50.64}
{'loss':19.8365,'learning_rate':4.914102564102564e-06,'epoch':50.96}
{'loss':19.9463,'learning_rate':4.882051282051282e-06,'epoch':51.28}
{'eval_loss':8.520371437072754,'eval_acc':0.3622119815668203,'eval_wer':0.18992879814703612,'eval_correct':393,'eval_total':1085,'eval_strlen':1085,'eval_runtime':22.3851,'eval_samples_per_second':48.47,'eval_steps_per_second':6.075,'epoch':51.28}
{'loss':19.1983,'learning_rate':4.85e-06,'epoch':51.6}
{'loss':19.1613,'learning_rate':4.817948717948718e-06,'epoch':51.92}
{'loss':18.5203,'learning_rate':4.785897435897436e-06,'epoch':52.24}
{'loss':19.0384,'learning_rate':4.753846153846155e-06,'epoch':52.56}
{'loss':19.3914,'learning_rate':4.721794871794872e-06,'epoch':52.88}
{'loss':18.0644,'learning_rate':4.68974358974359e-06,'epoch':53.21}
{'loss':19.1464,'learning_rate':4.6576923076923084e-06,'epoch':53.53}
{'loss':18.7813,'learning_rate':4.625641025641026e-06,'epoch':53.85}
{'loss':17.8645,'learning_rate':4.593589743589744e-06,'epoch':54.17}
{'loss':18.6739,'learning_rate':4.561538461538461e-06,'epoch':54.49}
{'eval_loss':8.00749683380127,'eval_acc':0.46543778801843316,'eval_wer':0.17963455434502873,'eval_correct':505,'eval_total':1085,'eval_strlen':1085,'eval_runtime':22.5481,'eval_samples_per_second':48.119,'eval_steps_per_second':6.032,'epoch':54.49}
{'loss':19.6353,'learning_rate':4.52948717948718e-06,'epoch':54.81}
{'loss':18.3448,'learning_rate':4.497435897435898e-06,'epoch':55.13}
{'loss':18.9602,'learning_rate':4.465384615384616e-06,'epoch':55.45}
{'loss':18.5076,'learning_rate':4.433333333333334e-06,'epoch':55.77}
{'loss':18.6794,'learning_rate':4.401282051282052e-06,'epoch':56.09}
{'loss':17.8309,'learning_rate':4.36923076923077e-06,'epoch':56.41}
{'loss':17.687,'learning_rate':4.3371794871794875e-06,'epoch':56.73}
{'loss':18.8916,'learning_rate':4.3051282051282054e-06,'epoch':57.05}
{'loss':18.4605,'learning_rate':4.273076923076923e-06,'epoch':57.37}
{'loss':17.1832,'learning_rate':4.241025641025641e-06,'epoch':57.69}
{'eval_loss':7.679691314697266,'eval_acc':0.5529953917050692,'eval_wer':0.17474478853907524,'eval_correct':600,'eval_total':1085,'eval_strlen':1085,'eval_runtime':22.3855,'eval_samples_per_second':48.469,'eval_steps_per_second':6.075,'epoch':57.69}
{'loss':18.9531,'learning_rate':4.208974358974359e-06,'epoch':58.01}
{'loss':17.9138,'learning_rate':4.176923076923077e-06,'epoch':58.33}
{'loss':18.0587,'learning_rate':4.144871794871795e-06,'epoch':58.65}
{'loss':16.804,'learning_rate':4.112820512820514e-06,'epoch':58.97}
{'loss':17.7542,'learning_rate':4.080769230769231e-06,'epoch':59.29}
{'loss':17.8513,'learning_rate':4.048717948717949e-06,'epoch':59.62}
{'loss':17.2847,'learning_rate':4.0166666666666675e-06,'epoch':59.94}
{'loss':17.415,'learning_rate':3.984615384615385e-06,'epoch':60.26}
{'loss':17.3434,'learning_rate':3.9525641025641024e-06,'epoch':60.58}
{'loss':16.7278,'learning_rate':3.92051282051282e-06,'epoch':60.9}
{'eval_loss':7.446893692016602,'eval_acc':0.5225806451612903,'eval_wer':0.1663378227674359,'eval_correct':567,'eval_total':1085,'eval_strlen':1085,'eval_runtime':22.4308,'eval_samples_per_second':48.371,'eval_steps_per_second':6.063,'epoch':60.9}
{'loss':16.8465,'learning_rate':3.888461538461539e-06,'epoch':61.22}
{'loss':16.9294,'learning_rate':3.856410256410257e-06,'epoch':61.54}
{'loss':18.5263,'learning_rate':3.824358974358975e-06,'epoch':61.86}
{'loss':17.3004,'learning_rate':3.7923076923076924e-06,'epoch':62.18}
{'loss':17.2132,'learning_rate':3.7602564102564103e-06,'epoch':62.5}
{'loss':17.0774,'learning_rate':3.7282051282051286e-06,'epoch':62.82}
{'loss':17.7892,'learning_rate':3.6961538461538465e-06,'epoch':63.14}
{'loss':16.9273,'learning_rate':3.6641025641025644e-06,'epoch':63.46}
{'loss':17.4801,'learning_rate':3.632051282051282e-06,'epoch':63.78}
{'loss':16.2311,'learning_rate':3.6000000000000003e-06,'epoch':64.1}
{'eval_loss':7.0554890632629395,'eval_acc':0.5465437788018433,'eval_wer':0.15964656429613108,'eval_correct':593,'eval_total':1085,'eval_strlen':1085,'eval_runtime':22.2933,'eval_samples_per_second':48.669,'eval_steps_per_second':6.1,'epoch':64.1}
{'loss':16.0454,'learning_rate':3.567948717948718e-06,'epoch':64.42}
{'loss':16.7367,'learning_rate':3.535897435897436e-06,'epoch':64.74}
{'loss':17.1129,'learning_rate':3.5038461538461544e-06,'epoch':65.06}
{'loss':16.575,'learning_rate':3.471794871794872e-06,'epoch':65.38}
{'loss':16.4406,'learning_rate':3.43974358974359e-06,'epoch':65.71}
{'loss':15.7366,'learning_rate':3.407692307692308e-06,'epoch':66.03}
{'loss':16.5086,'learning_rate':3.375641025641026e-06,'epoch':66.35}
{'loss':16.1201,'learning_rate':3.343589743589744e-06,'epoch':66.67}
{'loss':16.1581,'learning_rate':3.3115384615384614e-06,'epoch':66.99}
{'loss':17.1046,'learning_rate':3.2794871794871798e-06,'epoch':67.31}
{'eval_loss':6.7258687019348145,'eval_acc':0.5419354838709678,'eval_wer':0.1538989448400103,'eval_correct':588,'eval_total':1085,'eval_strlen':1085,'eval_runtime':22.7066,'eval_samples_per_second':47.783,'eval_steps_per_second':5.989,'epoch':67.31}
{'loss':15.8795,'learning_rate':3.2474358974358977e-06,'epoch':67.63}
{'loss':16.0978,'learning_rate':3.2153846153846156e-06,'epoch':67.95}
{'loss':16.4027,'learning_rate':3.183333333333334e-06,'epoch':68.27}
{'loss':16.9197,'learning_rate':3.1512820512820514e-06,'epoch':68.59}
{'loss':16.4184,'learning_rate':3.1192307692307693e-06,'epoch':68.91}
{'loss':15.8148,'learning_rate':3.0871794871794876e-06,'epoch':69.23}
{'loss':16.1548,'learning_rate':3.0551282051282055e-06,'epoch':69.55}
{'loss':16.0763,'learning_rate':3.0230769230769235e-06,'epoch':69.87}
{'loss':15.1609,'learning_rate':2.991025641025641e-06,'epoch':70.19}
{'loss':17.6325,'learning_rate':2.9589743589743593e-06,'epoch':70.51}
{'eval_loss':6.626087665557861,'eval_acc':0.5511520737327189,'eval_wer':0.14943810585914044,'eval_correct':598,'eval_total':1085,'eval_strlen':1085,'eval_runtime':22.5721,'eval_samples_per_second':48.068,'eval_steps_per_second':6.025,'epoch':70.51}
{'loss':15.4336,'learning_rate':2.926923076923077e-06,'epoch':70.83}
{'loss':15.3139,'learning_rate':2.894871794871795e-06,'epoch':71.15}
{'loss':15.661,'learning_rate':2.8628205128205134e-06,'epoch':71.47}
{'loss':15.8424,'learning_rate':2.830769230769231e-06,'epoch':71.79}
{'loss':15.5714,'learning_rate':2.798717948717949e-06,'epoch':72.12}
{'loss':15.4303,'learning_rate':2.766666666666667e-06,'epoch':72.44}
{'loss':15.7979,'learning_rate':2.734615384615385e-06,'epoch':72.76}
{'loss':15.3546,'learning_rate':2.7025641025641025e-06,'epoch':73.08}
{'loss':14.9681,'learning_rate':2.6705128205128204e-06,'epoch':73.4}
{'loss':16.3893,'learning_rate':2.6384615384615388e-06,'epoch':73.72}
{'eval_loss':6.418200492858887,'eval_acc':0.5539170506912442,'eval_wer':0.1454061937033542,'eval_correct':601,'eval_total':1085,'eval_strlen':1085,'eval_runtime':22.5247,'eval_samples_per_second':48.169,'eval_steps_per_second':6.038,'epoch':73.72}
{'loss':15.3735,'learning_rate':2.6064102564102567e-06,'epoch':74.04}
{'loss':16.033,'learning_rate':2.5743589743589746e-06,'epoch':74.36}
{'loss':16.317,'learning_rate':2.5429487179487182e-06,'epoch':74.68}
{'loss':14.449,'learning_rate':2.5108974358974357e-06,'epoch':75.0}
{'loss':16.2426,'learning_rate':2.478846153846154e-06,'epoch':75.32}
{'loss':14.9107,'learning_rate':2.446794871794872e-06,'epoch':75.64}
{'loss':16.2922,'learning_rate':2.41474358974359e-06,'epoch':75.96}
{'loss':15.3804,'learning_rate':2.3826923076923078e-06,'epoch':76.28}
{'loss':14.6739,'learning_rate':2.3506410256410257e-06,'epoch':76.6}
{'loss':15.7395,'learning_rate':2.3185897435897436e-06,'epoch':76.92}
{'eval_loss':6.264820575714111,'eval_acc':0.5658986175115207,'eval_wer':0.14403362786308654,'eval_correct':614,'eval_total':1085,'eval_strlen':1085,'eval_runtime':22.579,'eval_samples_per_second':48.054,'eval_steps_per_second':6.023,'epoch':76.92}
{'loss':14.807,'learning_rate':2.2865384615384615e-06,'epoch':77.24}
{'loss':14.7884,'learning_rate':2.25448717948718e-06,'epoch':77.56}
{'loss':14.9923,'learning_rate':2.2224358974358977e-06,'epoch':77.88}
{'loss':15.566,'learning_rate':2.1903846153846157e-06,'epoch':78.21}
{'loss':15.6809,'learning_rate':2.1583333333333336e-06,'epoch':78.53}
{'loss':14.8197,'learning_rate':2.1262820512820515e-06,'epoch':78.85}
{'loss':15.822,'learning_rate':2.0942307692307694e-06,'epoch':79.17}
{'loss':14.6595,'learning_rate':2.0621794871794873e-06,'epoch':79.49}
{'loss':14.122,'learning_rate':2.030128205128205e-06,'epoch':79.81}
{'loss':14.2127,'learning_rate':1.998076923076923e-06,'epoch':80.13}
{'eval_loss':6.134837627410889,'eval_acc':0.5649769585253456,'eval_wer':0.13734236939178177,'eval_correct':613,'eval_total':1085,'eval_strlen':1085,'eval_runtime':23.1227,'eval_samples_per_second':46.924,'eval_steps_per_second':5.882,'epoch':80.13}
{'loss':15.6788,'learning_rate':1.966025641025641e-06,'epoch':80.45}
{'loss':14.2348,'learning_rate':1.9339743589743593e-06,'epoch':80.77}
{'loss':15.1758,'learning_rate':1.901923076923077e-06,'epoch':81.09}
{'loss':15.2143,'learning_rate':1.8698717948717952e-06,'epoch':81.41}
{'loss':13.9283,'learning_rate':1.8378205128205129e-06,'epoch':81.73}
{'loss':14.7147,'learning_rate':1.805769230769231e-06,'epoch':82.05}
{'loss':14.6605,'learning_rate':1.773717948717949e-06,'epoch':82.37}
{'loss':14.3372,'learning_rate':1.7416666666666668e-06,'epoch':82.69}
{'loss':14.0261,'learning_rate':1.709615384615385e-06,'epoch':83.01}
{'loss':14.064,'learning_rate':1.6775641025641026e-06,'epoch':83.33}
{'eval_loss':6.011620044708252,'eval_acc':0.5723502304147465,'eval_wer':0.13759972548683194,'eval_correct':621,'eval_total':1085,'eval_strlen':1085,'eval_runtime':22.4682,'eval_samples_per_second':48.29,'eval_steps_per_second':6.053,'epoch':83.33}
{'loss':14.1767,'learning_rate':1.6455128205128207e-06,'epoch':83.65}
{'loss':14.9206,'learning_rate':1.6134615384615384e-06,'epoch':83.97}
{'loss':14.4842,'learning_rate':1.5814102564102565e-06,'epoch':84.29}
{'loss':14.2886,'learning_rate':1.5493589743589747e-06,'epoch':84.62}
{'loss':15.4714,'learning_rate':1.5173076923076924e-06,'epoch':84.94}
{'loss':14.5575,'learning_rate':1.4852564102564105e-06,'epoch':85.26}
{'loss':15.418,'learning_rate':1.4532051282051282e-06,'epoch':85.58}
{'loss':14.4758,'learning_rate':1.4211538461538463e-06,'epoch':85.9}
{'loss':14.4773,'learning_rate':1.3891025641025644e-06,'epoch':86.22}
{'loss':14.0715,'learning_rate':1.3570512820512821e-06,'epoch':86.54}
{'eval_loss':5.954724311828613,'eval_acc':0.5824884792626728,'eval_wer':0.13425409625117954,'eval_correct':632,'eval_total':1085,'eval_strlen':1085,'eval_runtime':22.4774,'eval_samples_per_second':48.271,'eval_steps_per_second':6.051,'epoch':86.54}
{'loss':14.8675,'learning_rate':1.3250000000000002e-06,'epoch':86.86}
{'loss':13.7117,'learning_rate':1.292948717948718e-06,'epoch':87.18}
{'loss':13.7393,'learning_rate':1.260897435897436e-06,'epoch':87.5}
{'loss':14.3481,'learning_rate':1.228846153846154e-06,'epoch':87.82}
{'loss':14.529,'learning_rate':1.1967948717948719e-06,'epoch':88.14}
{'loss':14.1848,'learning_rate':1.1647435897435898e-06,'epoch':88.46}
{'loss':15.2701,'learning_rate':1.1326923076923079e-06,'epoch':88.78}
{'loss':14.3803,'learning_rate':1.1006410256410258e-06,'epoch':89.1}
{'loss':14.3358,'learning_rate':1.0685897435897437e-06,'epoch':89.42}
{'loss':13.6485,'learning_rate':1.0365384615384616e-06,'epoch':89.74}
{'eval_loss':5.882339000701904,'eval_acc':0.5788018433179724,'eval_wer':0.1331388865059621,'eval_correct':628,'eval_total':1085,'eval_strlen':1085,'eval_runtime':22.356,'eval_samples_per_second':48.533,'eval_steps_per_second':6.083,'epoch':89.74}
{'loss':13.9803,'learning_rate':1.0044871794871795e-06,'epoch':90.06}
{'loss':14.2189,'learning_rate':9.724358974358974e-07,'epoch':90.38}
{'loss':14.3116,'learning_rate':9.403846153846156e-07,'epoch':90.71}
{'loss':14.276,'learning_rate':9.083333333333335e-07,'epoch':91.03}
{'loss':14.012,'learning_rate':8.762820512820514e-07,'epoch':91.35}
{'loss':14.4926,'learning_rate':8.442307692307693e-07,'epoch':91.67}
{'loss':13.7873,'learning_rate':8.128205128205128e-07,'epoch':91.99}
{'loss':14.0326,'learning_rate':7.807692307692307e-07,'epoch':92.31}
{'loss':14.107,'learning_rate':7.487179487179488e-07,'epoch':92.63}
{'loss':13.8864,'learning_rate':7.166666666666668e-07,'epoch':92.95}
{'eval_loss':5.7882466316223145,'eval_acc':0.5861751152073733,'eval_wer':0.1304795401904435,'eval_correct':636,'eval_total':1085,'eval_strlen':1085,'eval_runtime':22.1789,'eval_samples_per_second':48.92,'eval_steps_per_second':6.132,'epoch':92.95}
{'loss':14.0921,'learning_rate':6.846153846153847e-07,'epoch':93.27}
{'loss':14.2333,'learning_rate':6.525641025641026e-07,'epoch':93.59}
{'loss':14.3692,'learning_rate':6.205128205128206e-07,'epoch':93.91}
{'loss':14.5696,'learning_rate':5.884615384615385e-07,'epoch':94.23}
{'loss':14.0905,'learning_rate':5.564102564102564e-07,'epoch':94.55}
{'loss':14.5384,'learning_rate':5.243589743589744e-07,'epoch':94.87}
{'loss':13.9115,'learning_rate':4.923076923076923e-07,'epoch':95.19}
{'loss':15.1445,'learning_rate':4.602564102564103e-07,'epoch':95.51}
{'loss':14.4323,'learning_rate':4.2820512820512825e-07,'epoch':95.83}
{'loss':13.8939,'learning_rate':3.9615384615384616e-07,'epoch':96.15}
{'eval_loss':5.790042877197266,'eval_acc':0.584331797235023,'eval_wer':0.13159474993566098,'eval_correct':634,'eval_total':1085,'eval_strlen':1085,'eval_runtime':25.079,'eval_samples_per_second':43.263,'eval_steps_per_second':5.423,'epoch':96.15}
{'loss':14.3467,'learning_rate':3.641025641025641e-07,'epoch':96.47}
{'loss':13.8848,'learning_rate':3.320512820512821e-07,'epoch':96.79}
{'loss':14.2129,'learning_rate':3.0000000000000004e-07,'epoch':97.12}
{'loss':13.4828,'learning_rate':2.6794871794871795e-07,'epoch':97.44}
{'loss':14.8709,'learning_rate':2.3589743589743593e-07,'epoch':97.76}
{'loss':14.9602,'learning_rate':2.0384615384615384e-07,'epoch':98.08}
{'loss':14.0581,'learning_rate':1.717948717948718e-07,'epoch':98.4}
{'loss':14.1265,'learning_rate':1.3974358974358977e-07,'epoch':98.72}
{'loss':13.9862,'learning_rate':1.076923076923077e-07,'epoch':99.04}
{'loss':14.1605,'learning_rate':7.564102564102565e-08,'epoch':99.36}
{'eval_loss':5.765547752380371,'eval_acc':0.5944700460829493,'eval_wer':0.13150896457064426,'eval_correct':645,'eval_total':1085,'eval_strlen':1085,'eval_runtime':22.4697,'eval_samples_per_second':48.287,'eval_steps_per_second':6.053,'epoch':99.36}
{'loss':14.1408,'learning_rate':4.358974358974359e-08,'epoch':99.68}
{'loss':13.6663,'learning_rate':1.153846153846154e-08,'epoch':100.0}
{'train_runtime':127829.401,'train_samples_per_second':7.805,'train_steps_per_second':0.122,'train_loss':40.89300357133914,'epoch':100.0}
{'test_loss':5.772092819213867,'test_acc':0.591705069124424,'test_wer':0.13142317920562752,'test_correct':642,'test_total':1085,'test_strlen':1085,'test_runtime':86.2423,'test_samples_per_second':12.581,'test_steps_per_second':1.577}

The final test accuracy is 0.5917, which is not as good as the result in the paper.
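As an aside, the log lines above are Python dict literals, so the eval records can be pulled out with the standard library for plotting eval_acc and eval_wer against epoch; a minimal standalone sketch, not part of the repo:

```python
import ast

def split_trainer_log(lines):
    """Separate Trainer-style log lines into train-step and eval records."""
    train, evals = [], []
    for line in lines:
        line = line.strip()
        if not line.startswith("{"):
            continue  # skip non-log lines
        record = ast.literal_eval(line)  # each line is a Python dict literal
        (evals if "eval_loss" in record else train).append(record)
    return train, evals

# Two lines copied from the log above as a smoke test.
sample = [
    "{'loss':985.5858,'learning_rate':9.978205128205129e-06,'epoch':0.32}",
    "{'eval_loss':64.02,'eval_acc':0.3327,'eval_wer':1.0,'epoch':3.21}",
]
train, evals = split_trainer_log(sample)
assert len(train) == 1
assert evals[0]["eval_acc"] == 0.3327
```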

TideDancer commented 2 years ago

Hello, thanks for testing, and I'd be happy to help.

(1) In my script I use per_device_train_batch_size=2 and gradient_accumulation_steps=4, making the effective batch size 8. The reason is that I don't have GPUs with large memory (like a V100). In your case, if you set per_device_train_batch_size=8 there is no need for accumulation; otherwise your effective batch size becomes 32, which means far fewer parameter updates and worse performance. So change ACC to 1 and test again.

(2) For alpha > 0 (0.1 in your case), use learning rate 5e-5 instead of 1e-5. As stated in the paper, 1e-5 is only used when alpha = 0.

(3) Looking at your log, the WER is quite low; in my logs on 01F I typically get around 0.2. I suspect the ASR part is well trained but the classification part is under-trained. Again, this could be due to the batch size or learning rate issue above.

(4) One very tricky thing you will observe is that the evaluation loss actually starts increasing after around 20 epochs. It looks like overfitting, but the evaluation WER and accuracy keep improving. One reason could be that the training CTC loss is not the same as the evaluation metric (WER), causing this strange behavior. So when the eval loss goes up, don't stop; keep it running. The final acc I can get on 01F is around 0.74.

Hope this helps. Thanks!
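The effective-batch-size arithmetic in point (1) can be sketched as follows (a minimal illustration; `effective_batch_size` is a hypothetical helper, and it assumes the Trainer's effective batch is per-device batch × number of GPUs × accumulation steps):

```python
# Sketch of the effective-batch-size arithmetic from point (1).
# Assumption: effective batch = per_device_batch * num_gpus * acc_steps.

def effective_batch_size(per_device: int, num_gpus: int, acc_steps: int) -> int:
    return per_device * num_gpus * acc_steps

# Original recipe: batch 2, one GPU, ACC=4 -> effective batch 8.
print(effective_batch_size(2, 1, 4))  # -> 8

# per_device_train_batch_size=8 with ACC=4 on one GPU -> effective batch 32,
# i.e. 4x fewer optimizer updates per epoch than the original recipe.
print(effective_batch_size(8, 1, 4))  # -> 32
```

With multiple GPUs the `num_gpus` factor grows the effective batch further, so either the accumulation steps or the per-device batch should shrink to keep the update count comparable.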

Gpwner commented 2 years ago


Yes, it works! But I have a question: how can I load the model trained by CTCTrainer afterwards? There is no from_pretrained() on the trainer... I am really new to the Hugging Face Trainer, so can you help?

Gpwner commented 2 years ago


I am not quite sure, but it seems to work like this:

model = Wav2Vec2ForCTCnCLS.from_pretrained(
    'output/tmp/checkpoint-93500',
    cache_dir=model_args.cache_dir,
    # gradient_checkpointing=training_args.gradient_checkpointing,
    vocab_size=len(processor.tokenizer),
    cls_len=len(cls_label_map),
    alpha=model_args.alpha,
)

...

trainer = CTCTrainer(
    model=model,
    data_collator=data_collator,
    args=training_args,
    compute_metrics=compute_metrics,
    tokenizer=processor.feature_extractor
)

But I am confused: if I uncomment gradient_checkpointing=training_args.gradient_checkpointing, I get an error:

  model = Wav2Vec2ForCTCnCLS.from_pretrained(
File "/home/*/miniconda3/envs/INTERSpeech21/lib/python3.9/site-packages/transformers/modeling_utils.py", line 1402, in from_pretrained
  model = cls(config, *model_args, **model_kwargs)
TypeError: __init__() got an unexpected keyword argument 'gradient_checkpointing'

Process finished with exit code 1
TideDancer commented 2 years ago

Yes, from_pretrained() will load the trained model: Wav2Vec2ForCTCnCLS inherits from the Hugging Face PreTrainedModel class, so from_pretrained() works the same way.

As for the gradient_checkpointing part, I am not sure either; I have never used that feature. If you just leave it commented out, can you load the model and run the code smoothly?
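If gradient checkpointing is actually needed, one possible workaround (an untested sketch, assuming a transformers version where PreTrainedModel exposes gradient_checkpointing_enable(); the checkpoint path and argument objects are taken from the snippet above) is to enable it on the model after loading, rather than passing it as a from_pretrained() kwarg:

```python
# Sketch only, not from this repo: enable gradient checkpointing after
# loading, instead of passing gradient_checkpointing to from_pretrained().
model = Wav2Vec2ForCTCnCLS.from_pretrained(
    'output/tmp/checkpoint-93500',
    cache_dir=model_args.cache_dir,
    vocab_size=len(processor.tokenizer),
    cls_len=len(cls_label_map),
    alpha=model_args.alpha,
)
if training_args.gradient_checkpointing:
    # Method provided by PreTrainedModel in recent transformers releases;
    # trades compute for memory by recomputing activations in backward.
    model.gradient_checkpointing_enable()
```

This avoids the TypeError because the kwarg never reaches the model's __init__().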

Gpwner commented 2 years ago


Yes. Thanks for your help, I will close this issue.