sangyun884 / rfpp

The codebase of our paper "Improving the Training of Rectified Flows"

Incorrect huber loss calculation #1

Open marcoppasini opened 4 months ago

marcoppasini commented 4 months ago

Thanks for sharing the code for your amazing work! I have a question about your Huber loss implementation: shouldn't huber_c scale with the square root of data_dim rather than with data_dim itself?
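
A short derivation of why the constant should grow with the square root of the dimension. It assumes that the elided loss body is the pseudo-Huber form applied to the flattened residual (not confirmed by the snippets in this thread), and that every coordinate of x − y has roughly the same magnitude δ, with d = C·H·W:

```latex
\begin{aligned}
\mathcal{L}(x, y) &= \sqrt{\lVert x - y \rVert_2^2 + c^2} - c,
\qquad \lVert x - y \rVert_2 = \sqrt{\textstyle\sum_{i=1}^{d} \delta^2} = \delta\sqrt{d} \\
c = 0.00054\,\sqrt{d} &\;\Rightarrow\; \frac{\lVert x - y \rVert_2}{c} = \frac{\delta}{0.00054} \\
c = 0.00054\,d &\;\Rightarrow\; \frac{\lVert x - y \rVert_2}{c} = \frac{\delta}{0.00054\,\sqrt{d}}
\end{aligned}
```

Since the pseudo-Huber loss is roughly quadratic when the residual norm is much smaller than c and roughly linear when it is much larger, the √d scaling keeps that switch point tied to the per-coordinate error δ alone. With linear scaling, c outgrows the residual norm as resolution increases, pushing the loss toward MSE-like behavior.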

Your code:

def loss_func_huber(x, y):
    data_dim = x.shape[1] * x.shape[2] * x.shape[3]
    huber_c = 0.00054 * data_dim
    ...

Corrected code:

import math

def loss_func_huber(x, y):
    data_dim = x.shape[1] * x.shape[2] * x.shape[3]
    huber_c = 0.00054 * math.sqrt(data_dim)
    ...
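
A self-contained sketch of the corrected constant in pure Python. Because the loss body is elided above, it assumes (hypothetically) that the loss is the pseudo-Huber form sqrt(||x - y||^2 + c^2) - c over the flattened C×H×W sample. `HUBER_BASE`, `huber_c`, and `pseudo_huber` are illustrative names, not functions from the rfpp repository.

```python
import math
from typing import Sequence

# Base constant taken from the issue; multiplying it by sqrt(d) is the proposed fix.
HUBER_BASE = 0.00054


def huber_c(shape: Sequence[int]) -> float:
    """Huber constant for one sample of shape (C, H, W): 0.00054 * sqrt(C*H*W)."""
    data_dim = math.prod(shape)
    return HUBER_BASE * math.sqrt(data_dim)


def pseudo_huber(x: Sequence[float], y: Sequence[float], c: float) -> float:
    """Hypothetical stand-in for the elided body: sqrt(||x - y||^2 + c^2) - c."""
    if len(x) != len(y):
        raise ValueError("x and y must have the same number of elements")
    sq_norm = sum((a - b) ** 2 for a, b in zip(x, y))
    return math.sqrt(sq_norm + c * c) - c


if __name__ == "__main__":
    delta = 0.01  # illustrative per-coordinate error
    for shape in [(3, 32, 32), (3, 64, 64)]:
        d = math.prod(shape)
        residual_norm = delta * math.sqrt(d)  # ||x - y|| when every coordinate differs by delta
        linear_c = HUBER_BASE * d              # original scaling from the issue
        sqrt_c = huber_c(shape)                # corrected scaling
        # ratio >> 1: roughly linear (L1-like) in ||x - y||; ratio << 1: roughly quadratic (MSE-like)
        print(shape, "linear:", round(residual_norm / linear_c, 3),
              "sqrt:", round(residual_norm / sqrt_c, 3))
```

Running it prints the ratio of the residual norm to c for 32×32 and 64×64 inputs with the same per-pixel error. With sqrt(d) scaling the ratio stays at about 18.5 for both resolutions, so the loss sits in the same regime; with linear scaling the ratio halves each time the side length doubles.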
sangyun884 commented 3 months ago

Ah, thank you for catching this! We will reflect this in our new version of the paper.