HuiZhang0812 / DiffusionAD


problems with the loss updating #28

Closed. CarlosNacher closed this issue 8 months ago.

CarlosNacher commented 8 months ago

Hi @HuiZhang0812 ,

First of all, thanks a lot for your work and for sharing it here!

I noticed something that I think could be a coding error. In train.py, lines 136-141, the per-batch training losses are accumulated into the epoch totals:

```python
train_loss += loss.item()
tbar.set_description('Epoch:%d, Train loss: %.3f' % (epoch, train_loss))

train_smL1_loss += smL1_loss.item()
train_focal_loss+=5*focal_loss.item()
train_noise_loss+=noise_loss.item()
```

I suppose these sums were meant to be averaged later, in lines 144-149, by dividing each one by the number of steps in the epoch (derived from the loop index i). However, that step seems to have been forgotten, so the summed epoch losses are what gets appended to the lists.

I think the following block:

```python
if epoch % 10 ==0  and epoch > 0:
    train_loss_list.append(round(train_loss,3))
    train_smL1_loss_list.append(round(train_smL1_loss,3))
    train_focal_loss_list.append(round(train_focal_loss,3))
    train_noise_loss_list.append(round(train_noise_loss,3))
    loss_x_list.append(int(epoch))
```

should be replaced with:

```python
if epoch % 10 == 0 and epoch > 0:
    # i is the last batch index of the epoch, so i + 1 is the number of steps
    train_loss_list.append(round(train_loss / (i + 1), 3))
    train_smL1_loss_list.append(round(train_smL1_loss / (i + 1), 3))
    train_focal_loss_list.append(round(train_focal_loss / (i + 1), 3))
    train_noise_loss_list.append(round(train_noise_loss / (i + 1), 3))
    loss_x_list.append(int(epoch))
```

(This reuses the existing loop variable i, whose final value plus one is the number of steps in the epoch; there are obviously other ways to achieve the same result.)
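One such alternative is to count the steps explicitly instead of relying on i after the loop. The sketch below is a minimal, self-contained illustration of the accumulate-then-average pattern; it assumes plain Python floats in place of `loss.item()`, and `epoch_mean_losses` is a hypothetical helper, not part of DiffusionAD.

```python
from typing import Dict, Iterable, Tuple


def epoch_mean_losses(
    batches: Iterable[Tuple[float, float, float, float]],
) -> Dict[str, float]:
    """Accumulate per-batch losses over one epoch and return their means.

    Each tuple stands in for (loss, smL1_loss, focal_loss, noise_loss)
    as produced by loss.item() in a real loop (hypothetical layout).
    """
    totals = {"train": 0.0, "smL1": 0.0, "focal": 0.0, "noise": 0.0}
    steps = 0  # explicit step counter instead of reusing the loop index i
    for loss, sml1, focal, noise in batches:
        totals["train"] += loss
        totals["smL1"] += sml1
        totals["focal"] += 5 * focal  # same 5x weighting as in train.py
        totals["noise"] += noise
        steps += 1
    if steps == 0:
        raise ValueError("epoch contained no batches")
    # Divide each running sum by the number of steps to get the epoch mean.
    return {name: round(total / steps, 3) for name, total in totals.items()}


if __name__ == "__main__":
    print(epoch_mean_losses([(1.0, 0.2, 0.1, 0.7), (0.8, 0.1, 0.1, 0.6)]))
```

Counting steps separately also guards against an empty epoch, where i would be undefined and dividing by (i + 1) would fail.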

In addition, these loss lists are not saved or logged anywhere afterwards (I assume this is a deliberate decision).
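If the histories should outlive the run, one option is to write them to a CSV file at the end of training. This is only a sketch: `save_loss_history` and the file name `loss_history.csv` are hypothetical, and only the list names mirror those in train.py.

```python
import csv
from pathlib import Path
from typing import List


def save_loss_history(
    path: Path,
    epochs: List[int],
    train: List[float],
    smL1: List[float],
    focal: List[float],
    noise: List[float],
) -> None:
    """Write one CSV row per logged epoch (hypothetical helper)."""
    # zip() would silently truncate mismatched lists, so check lengths first.
    if len({len(epochs), len(train), len(smL1), len(focal), len(noise)}) != 1:
        raise ValueError("all loss lists must have the same length")
    with path.open("w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["epoch", "train_loss", "smL1_loss", "focal_loss", "noise_loss"])
        for row in zip(epochs, train, smL1, focal, noise):
            writer.writerow(row)
```

At the end of train.py this could be called as `save_loss_history(Path("loss_history.csv"), loss_x_list, train_loss_list, train_smL1_loss_list, train_focal_loss_list, train_noise_loss_list)`.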

Again, thank you very much for your great work in developing this SOTA algorithm and even more for sharing it with the community.

HuiZhang0812 commented 8 months ago

Thank you for your suggestion. The code you mentioned in lines 144-149 is designed to observe the variation of the loss, so summation or averaging might both be feasible, it's just a matter of scaling.

CarlosNacher commented 8 months ago

> Thank you for your suggestion. The code you mentioned in lines 144-149 is designed to observe the variation of the loss, so summation or averaging might both be feasible, it's just a matter of scaling.

I totally agree that it is a matter of scale. The average is usually preferred because it is more representative of "the loss of the epoch". However, I agree that if all you want to do is observe the variation, the scale doesn't matter.
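To make the scaling argument concrete: when every epoch has the same number of steps N, the summed loss is exactly N times the mean, so both curves show the same relative changes from one logged epoch to the next. The sketch below checks this with illustrative per-batch values (made-up numbers, not DiffusionAD results). The equivalence assumes N is constant; if the number of batches per epoch changed, the sums would no longer be a fixed multiple of the means.

```python
from typing import List, Tuple


def epoch_sum_and_mean(per_batch: List[float]) -> Tuple[float, float]:
    """Return (sum, mean) of one epoch's per-batch losses."""
    total = sum(per_batch)
    return total, total / len(per_batch)


def relative_changes(values: List[float]) -> List[float]:
    """Fractional change between consecutive logged values."""
    return [(b - a) / a for a, b in zip(values, values[1:])]


if __name__ == "__main__":
    # Illustrative losses for three epochs, each with N = 4 steps.
    epochs = [[1.0, 0.9, 0.8, 0.7], [0.6, 0.6, 0.5, 0.5], [0.4, 0.4, 0.3, 0.3]]
    pairs = [epoch_sum_and_mean(e) for e in epochs]
    sums = [s for s, _ in pairs]
    means = [m for _, m in pairs]
    # Both lists should print the same fractional changes.
    print(relative_changes(sums))
    print(relative_changes(means))
```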

Thank you very much for the clarification and for your huge contribution.