Closed YuffieHuang closed 2 years ago
@vgaraujov Hi. Thank you for sharing the code with us. I think I have found where the problem is.
In utils/train.py, lines 45 and 46, the loss and accuracy are calculated by dividing the sum of total loss and accuracy by the length of the data loader. However, in lines 28 and 29, the total loss and accuracy are added three times since Acc and loss both have the shape of (3, 1).
I guess we need to further divide the final accuracy and loss by three? I don't have a good explanation for this yet. However, when I counted how many times the total loss and accuracy are incremented, the count was exactly three times the length of the data loader.
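To make the arithmetic concrete, here is a minimal sketch of the accumulation pattern described above. It is not the repository's code: `epoch_average`, `per_batch_values`, and the sample numbers are hypothetical, and it assumes each batch yields one accuracy value per prediction timestep (three here).

```python
# Hypothetical stand-in for the epoch-level averaging discussed above;
# names and numbers are illustrative, not taken from utils/train.py.

def epoch_average(per_batch_values, timesteps=None):
    """Average a per-batch metric over one epoch.

    per_batch_values: list with one entry per batch; each entry is a list
        holding one metric value per prediction timestep.
    timesteps: if given, the result is also divided by this number, so the
        metric is a per-timestep mean rather than a per-batch sum.
    """
    total = 0.0
    for values in per_batch_values:
        # Every timestep's value is added, so each batch contributes
        # len(values) terms to the running total.
        total += sum(values)
    mean = total / len(per_batch_values)  # divides by the number of batches only
    if timesteps is not None:
        mean /= timesteps  # assumed fix: also divide by the number of timesteps
    return mean


# Two batches, three timesteps each; every individual accuracy is within [0, 1].
batches = [[0.9, 0.8, 0.7], [1.0, 0.9, 0.8]]

summed = epoch_average(batches)                 # sum over timesteps, can exceed 1
averaged = epoch_average(batches, timesteps=3)  # stays within [0, 1]

print(f"divided by batches only: {summed:.2f}")
print(f"divided by batches and timesteps: {averaged:.2f}")
```

With these sample values, dividing by the number of batches alone gives a figure above 1, while also dividing by the number of timesteps brings it back into the [0, 1] range.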
Hi @YuffieHuang, sorry for the delayed reply. I just checked the code. You are right; final_acc and final_loss should be divided by 3 or by the timestep variable. I think I did it that way because DeepMind's loss graph shows the sum of the individual losses.
Thank you for your comment. Let me know if you find other errors.
@vgaraujov Thanks for the clarification! Just a friendly reminder, final_acc and final_loss need to be modified in both utils/train.py and utils/validation.py.
Hi, I'm training a CPC model with a customized dataset (rather than BookCorpus) and multiple GPUs. I first tried the default learning rate of 2e-4 in main.py. However, the training loss (around 15.5) and training accuracy (approximately 0.03) have barely improved after 800 epochs (each epoch has 267,264 steps). The validation loss is about 15.8, and the validation accuracy is about 0.025.
I then tried increasing the learning rate for the Adam optimizer in main.py to 3e-4, 4e-4, 5e-4, and 1e-3, with everything else unchanged. Strangely, the training loss drops to 0.9 and the training accuracy rises to 2.6 after only 80 epochs. I'm confused about how the accuracy can go beyond 1. The validation loss goes up to 66, and the validation accuracy is only about 0.03, so the model is clearly overfitting.
Do you have any idea why the accuracy can be larger than 1? I'm also checking the validation function in utils/validation.py. Please let me know if you have any ideas. Thank you so much!