cloneofsimo / minRF

Minimal implementation of scalable rectified flow transformers, based on SD3's approach
Apache License 2.0

Problem in loss function #5

Closed: mephisto28 closed this issue 3 months ago

mephisto28 commented 3 months ago

https://github.com/cloneofsimo/minRF/blob/72feb0c87d435e9f9d220f34f348ed66c0b6ccec/advanced/main.py#L60

I think the loss should be batchwise_mse = ((z1 - x - vtheta * texp) ** 2).mean(dim=list(range(1, len(x.shape)))), shouldn't it?
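To make the quantities in this loss concrete, here is a minimal pure-Python sketch that uses lists instead of tensors. It assumes, rather than quotes from the repo, that x is the clean sample, z1 is Gaussian noise, and the model input is zt = (1 - t) * x + t * z1. Under that assumption it checks numerically that z1 - x is the time derivative of zt at every t, and it computes the per-sample mean of (z1 - x - vtheta) ** 2, i.e. the expression above without the * texp factor. All helper names are hypothetical.

```python
import random

# Assumption: the training path is the linear interpolation
#   zt = (1 - t) * x + t * z1,
# with x the clean sample and z1 Gaussian noise (flattened to a list here).


def interpolate(x, z1, t):
    """Point on the straight path from x (t = 0) to z1 (t = 1)."""
    return [(1 - t) * xi + t * zi for xi, zi in zip(x, z1)]


def velocity_target(x, z1):
    """Analytic d(zt)/dt for the linear path; note it does not depend on t."""
    return [zi - xi for xi, zi in zip(x, z1)]


def finite_difference_velocity(x, z1, t, h=1e-5):
    """Central-difference estimate of d(zt)/dt, used only as a cross-check."""
    ahead = interpolate(x, z1, t + h)
    behind = interpolate(x, z1, t - h)
    return [(a - b) / (2 * h) for a, b in zip(ahead, behind)]


def per_sample_mse(x, z1, vtheta):
    """Mean over elements of (z1 - x - vtheta) ** 2 for a single sample."""
    return sum((zi - xi - vi) ** 2 for xi, zi, vi in zip(x, z1, vtheta)) / len(x)


if __name__ == "__main__":
    rng = random.Random(0)
    x = [rng.uniform(-1, 1) for _ in range(8)]
    z1 = [rng.gauss(0, 1) for _ in range(8)]
    t = 0.3
    exact = velocity_target(x, z1)
    approx = finite_difference_velocity(x, z1, t)
    print(max(abs(a - b) for a, b in zip(exact, approx)))  # close to zero (rounding only)
    print(per_sample_mse(x, z1, exact))  # 0.0 for a perfect velocity prediction
```

Because the finite-difference estimate agrees with z1 - x for any t, regressing vtheta onto z1 - x targets the velocity of this assumed path rather than a t-scaled displacement.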

As written, the current code makes the training target a displacement rather than a velocity. If the training target is kept as it is, the inference code could instead divide the predicted value by the timestep to obtain the velocity.
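On the inference side, a plain Euler sampler makes the question of dividing by the timestep easy to test. The sketch below assumes (it is not taken from the repo) that sampling starts from noise at t = 1 and steps toward t = 0 with the update z <- z - dt * v. make_oracle_velocity is a hypothetical stand-in for a perfectly trained model that outputs z1 - x; with that velocity the sampler lands on x without any division by t.

```python
# Sketch: Euler sampling for the assumed linear path zt = (1 - t) * x + t * z1.
# Assumptions (not quoted from the repo): sampling starts at t = 1 from noise
# and takes fixed steps of size dt = 1 / steps toward t = 0 via z <- z - dt * v.


def euler_sample(z1, velocity_fn, steps=50):
    """Integrate dz/dt = v(z, t) backward in time from t = 1 to t = 0."""
    z = list(z1)
    dt = 1.0 / steps
    for i in range(steps, 0, -1):
        t = i / steps  # current time, passed to the velocity model
        v = velocity_fn(z, t)
        z = [zi - dt * vi for zi, vi in zip(z, v)]
    return z


def make_oracle_velocity(x, z1):
    """Hypothetical 'perfect model': always returns the true velocity z1 - x."""
    target = [zi - xi for xi, zi in zip(x, z1)]
    return lambda z, t: target


if __name__ == "__main__":
    x = [0.5, -1.0, 2.0]
    z1 = [0.1, 0.3, -0.7]
    out = euler_sample(z1, make_oracle_velocity(x, z1))
    print([round(v, 6) for v in out])  # [0.5, -1.0, 2.0]
```

Swapping a real network in for velocity_fn only changes where v comes from; the update rule stays the same.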

cloneofsimo commented 3 months ago

Sorry what?

mephisto28 commented 3 months ago

> Sorry what?

If mse = (model_pred - x) ** 2, then model_pred should be z1 - vtheta * t rather than z1 - vtheta, shouldn't it?
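Assuming the linear path z_t = (1 - t) x + t z_1 (an assumption here; the linked training code is not reproduced in this thread), the two candidate forms of model_pred can be compared directly by writing v for the true velocity:

```latex
\begin{aligned}
z_t &= (1 - t)\,x + t\,z_1, \qquad v \equiv \frac{dz_t}{dt} = z_1 - x,\\
z_1 - v &= z_1 - (z_1 - x) = x,\\
z_t - t\,v &= (1 - t)\,x + t\,z_1 - t\,(z_1 - x) = x,\\
z_1 - t\,v &= z_1 - t\,(z_1 - x) = (1 - t)\,z_1 + t\,x.
\end{aligned}
```

Under this assumption, z_1 - v already equals x, so (z1 - vtheta - x) ** 2 is the same residual as (z1 - x - vtheta) ** 2, while z_1 - t v equals x only at t = 1.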

cloneofsimo commented 3 months ago

nah