charlescharles opened 2 months ago
@gbruno16, would you be able to look at this? Thanks in advance!
Hi! Thank you for your valuable feedback; it's incredibly helpful as we continue to refine this implementation. Just a heads-up: this project is still a work in progress, so there are a few areas we're actively debugging and improving.
Regarding MSE as a metric for the denoiser output: it's understandable to expect it to decrease during sampling. However, the denoiser is trained to produce the best possible prediction in terms of expected MSE, which explains why the first prediction scores best on this metric. At the highest noise level the MSE-optimal prediction is essentially the conditional mean of the target, which is smooth; as the noise level decreases, the outputs commit to one particular sample from the distribution. No "new information" is provided during the sampling steps, so what we expect to see instead is an increase in the amount of detail generated from the underlying distribution, even if that is very hard to notice at lower resolutions. Note that we still haven't trained a high-resolution model.

The noisy sampling outputs may be caused by underfitting and by error accumulation over the sampling steps. Also note that the lower pressure levels are weighted less in the loss by default, so they are effectively less trained; you can adjust this using the weights in WeightedMSELoss if needed. You could check whether a "more trained" variable, such as 2m_temperature, behaves better.

In general, we're aware of some noise-accumulation issues in the model and are exploring several alternatives for how the isotropic noise is generated, as well as different training strategies, to improve sample quality. We'll provide updates as we make progress!
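To illustrate why the first (smoothest) denoiser output can score best on MSE, here is a minimal NumPy sketch under a toy assumption: each value of the truth is a predictable part mu plus unpredictable Gaussian detail with standard deviation detail_std (mu, detail_std and the array sizes are hypothetical, not taken from the project). Predicting the conditional mean gives an expected MSE of detail_std², while an equally realistic independent sample from the same distribution gives about 2·detail_std², because its detail is uncorrelated with the detail in the truth.

```python
import numpy as np

rng = np.random.default_rng(0)
n_cases, n_points = 2000, 256  # hypothetical sizes

# Toy assumption: truth = predictable part (mu) + unpredictable detail.
detail_std = 1.0
mu = rng.normal(size=(n_cases, n_points))
truth = mu + detail_std * rng.normal(size=(n_cases, n_points))

# Prediction A: the conditional mean, which minimizes expected MSE.
pred_mean = mu
# Prediction B: an independent draw from the same conditional distribution,
# i.e. realistic detail that is not aligned with the detail in the truth.
pred_sample = mu + detail_std * rng.normal(size=(n_cases, n_points))

mse_mean = float(np.mean((pred_mean - truth) ** 2))
mse_sample = float(np.mean((pred_sample - truth) ** 2))
print(f"MSE of conditional mean:   {mse_mean:.3f}")    # about detail_std**2
print(f"MSE of independent sample: {mse_sample:.3f}")  # about 2 * detail_std**2
```

The sample has the right amount of variability but is penalized for it, which is why MSE on its own favors smooth predictions over realistic samples.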
I trained a 128x64 model (with sparse=False). If I record the ~39 denoiser outputs over the course of a single sampling loop and compute their MSE loss against the target (not weighted by lambda_sigma, because that doesn't make sense during inference), the loss trajectory looks something like this:
If I visualize the denoiser output at the very first inference step, the denoiser output at the very last inference step, and also the sampling result, then indeed the denoiser output at inference step 0 looks much closer to the target:
I would expect the MSE loss of the intermediate denoiser outputs to decrease during sampling.
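The measurement described in this issue can be reproduced in miniature. The sketch below is not the project's sampler: it assumes an EDM-style deterministic Heun sampler (Karras et al., 2022) with 20 steps and an illustrative noise schedule (sigma_min=0.02, sigma_max=80, rho=7). Such a sampler makes 2·20 − 1 = 39 denoiser calls, which would be consistent with the ~39 outputs mentioned above. The trained model is replaced by a hypothetical toy_denoiser that is exactly MSE-optimal for data whose grid values are independently N(mu, s²), so it cannot underfit. All function names and parameter values are illustrative assumptions.

```python
import numpy as np

def toy_denoiser(x_noisy, sigma, mu, s):
    """Hypothetical stand-in for the trained model: the exact MSE-optimal
    denoiser when every grid value is independently N(mu, s**2)."""
    return mu + (s**2 / (s**2 + sigma**2)) * (x_noisy - mu)

def karras_sigmas(n_steps, sigma_min=0.02, sigma_max=80.0, rho=7.0):
    # Karras et al. (2022) noise levels, with a final sigma of 0 appended.
    ramp = np.linspace(0.0, 1.0, n_steps)
    inv_min, inv_max = sigma_min ** (1 / rho), sigma_max ** (1 / rho)
    sigmas = (inv_max + ramp * (inv_min - inv_max)) ** rho
    return np.append(sigmas, 0.0)

def sample_and_record(denoise, shape, n_steps, rng):
    """Deterministic Heun sampler that records every denoiser output."""
    sigmas = karras_sigmas(n_steps)
    x = sigmas[0] * rng.normal(size=shape)
    outputs = []
    for sigma, sigma_next in zip(sigmas[:-1], sigmas[1:]):
        denoised = denoise(x, sigma)
        outputs.append(denoised)
        d = (x - denoised) / sigma
        x_next = x + (sigma_next - sigma) * d  # Euler step
        if sigma_next > 0:  # Heun correction, skipped on the final step
            denoised_next = denoise(x_next, sigma_next)
            outputs.append(denoised_next)
            d_next = (x_next - denoised_next) / sigma_next
            x_next = x + (sigma_next - sigma) * 0.5 * (d + d_next)
        x = x_next
    return x, outputs

rng = np.random.default_rng(0)
shape = (64, 128)  # lat x lon, assumed to match a 128x64 grid
mu, s = 0.0, 1.0   # hypothetical mean and spread of the target distribution
target = mu + s * rng.normal(size=shape)  # one "true" outcome

sample, outputs = sample_and_record(
    lambda x, sigma: toy_denoiser(x, sigma, mu, s), shape, n_steps=20, rng=rng)
mse = [float(np.mean((out - target) ** 2)) for out in outputs]
print(len(outputs))  # 39 denoiser calls for 20 Heun steps
print(f"first: {mse[0]:.3f}, last: {mse[-1]:.3f}")
```

Even with this perfect toy denoiser, the recorded MSE starts near s² (at high noise the output is close to the mean mu) and ends near 2·s² (the output has become an independent sample of the distribution), consistent with the explanation that the denoiser is trained to be MSE-optimal.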