mllam / neural-lam

Neural Weather Prediction for Limited Area Modeling
MIT License

Normalize data on GPU #39

Open: sadamov opened this pull request 1 month ago

sadamov commented 1 month ago

Summary

This PR introduces `on_after_batch_transfer` logic so that data is normalized on the GPU after it has been transferred, instead of on the CPU before transfer.
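A minimal sketch of the idea, assuming each batch is a tuple `(init_states, target_states, forcing)` with features in the last dimension. `BatchNormalizer` is a hypothetical helper, not code from this PR. Keeping the stats as registered buffers means they move to the GPU together with the model, so when the normalizer is called from Lightning's `on_after_batch_transfer(batch, dataloader_idx)` hook, the arithmetic runs on the device the batch was just copied to. Forcing is passed through untouched here, since how it should be handled is discussed further down in this thread.

```python
import torch


class BatchNormalizer(torch.nn.Module):
    """Hypothetical helper that standardizes state tensors: (x - mean) / std.

    mean and std are registered as buffers, so they follow the module
    when it (or a parent LightningModule) is moved with .to(device).
    """

    def __init__(self, mean: torch.Tensor, std: torch.Tensor):
        super().__init__()
        self.register_buffer("mean", mean)
        self.register_buffer("std", std)

    def forward(self, batch):
        # Assumed batch layout: (init_states, target_states, forcing),
        # with one mean/std entry per state feature (last dimension).
        init_states, target_states, forcing = batch
        init_states = (init_states - self.mean) / self.std
        target_states = (target_states - self.mean) / self.std
        # Forcing is returned unchanged in this sketch.
        return init_states, target_states, forcing


# Wiring inside a pytorch_lightning.LightningModule (not imported here),
# assuming self.normalizer = BatchNormalizer(mean, std) in __init__:
#
#     def on_after_batch_transfer(self, batch, dataloader_idx):
#         return self.normalizer(batch)
```

Because the hook runs after the batch is on the device, the dataset's `__getitem__` only has to return raw tensors.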

Rationale

Normalization is faster on the GPU than on the CPU. In the current code, data is normalized in the PyTorch dataset class, in the `__getitem__` method, which can slow down training, especially on systems with few CPU cores.

Changes

Testing

Both training and evaluation were successful. The training loss on `meps_example` (3.230) is identical to the loss before the changes. The `create_parameter_weights` script ran successfully and generated the stats.
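The normalization relies on per-feature means and standard deviations. As a rough illustration only (this is not the actual `create_parameter_weights` implementation), such stats can be accumulated in a single pass over the data with running sums; `compute_feature_stats` is a hypothetical name, and the population (not sample) standard deviation is assumed.

```python
import torch


def compute_feature_stats(batches):
    """Per-feature mean and population std over an iterable of tensors.

    Each tensor has shape (..., n_features). Sums are accumulated in
    float64 so the whole dataset never has to be held in memory.
    """
    count = 0
    total = None
    total_sq = None
    for x in batches:
        flat = x.reshape(-1, x.shape[-1]).double()
        if total is None:
            total = torch.zeros(flat.shape[-1], dtype=torch.float64)
            total_sq = torch.zeros_like(total)
        total += flat.sum(dim=0)
        total_sq += (flat**2).sum(dim=0)
        count += flat.shape[0]
    mean = total / count
    # Var = E[x^2] - E[x]^2; clamp guards against tiny negative round-off.
    var = (total_sq / count - mean**2).clamp(min=0.0)
    return mean.float(), var.sqrt().float()
```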

Not-In-Scope

The normalization stats and other static features will all become Zarr archives in the future, with their paths defined in the `data_config.yaml` file.

leifdenby commented 1 month ago

This is looking good @sadamov! Can I suggest we merge #38 first, then you merge that into your branch, and then I do a review? That way we can ensure everything keeps working :)

sadamov commented 2 weeks ago

I merged the latest updates from main, @leifdenby, and replied to your suggestions, @joeloskarsson. Below I want to show that the output tensors are not affected by this change. As you suggested, I stored the last batch for both GPU and CPU normalization, with `deterministic=True` and a fixed seed:

[Screenshot: code used to store the last batch]

Then I compared the tensors like this:

[Screenshot from 2024-06-07: tensor comparison]

So, except for the forcing tensor, which is handled differently, the two approaches produce identical output.
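The screenshots are not recoverable, but a comparison of this kind could be sketched as follows. This is an assumption about the approach, not the code that was actually used: each run's last batch is saved with `torch.save` (e.g. after `pytorch_lightning.seed_everything(seed)` and `Trainer(deterministic=True)`), reloaded with `torch.load`, and compared tensor by tensor. `compare_batches`, the tensor names, and the file names are hypothetical.

```python
import torch


def compare_batches(batch_a, batch_b,
                    names=("init_states", "target_states", "forcing"),
                    rtol=1e-5, atol=1e-8):
    """Return {name: (exactly_equal, allclose)} for two stored batches."""
    results = {}
    for name, a, b in zip(names, batch_a, batch_b):
        same_shape = a.shape == b.shape
        exact = same_shape and torch.equal(a, b)
        # allclose needs compatible shapes, so only call it when they match
        close = same_shape and torch.allclose(a, b, rtol=rtol, atol=atol)
        results[name] = (exact, close)
    return results


# Hypothetical usage with batches saved from the two runs:
#     cpu_batch = torch.load("last_batch_cpu_norm.pt")
#     gpu_batch = torch.load("last_batch_gpu_norm.pt")
#     print(compare_batches(cpu_batch, gpu_batch))
```

Reporting both exact equality and `allclose` separates bit-identical results from differences that are only floating-point round-off between devices.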

joeloskarsson commented 2 weeks ago

How do you think we should proceed with this, @sadamov? If the forcing is handled differently in #54, would it make more sense to merge this after that? (I guess baking this change into #54 would just make it even bigger.) Or should we merge this first so that #54 can build on it and be adapted to use it?

I would not be happy to merge this without a fix that makes the forcing tensors match the previous implementation as well. But we could just do a quick fix for now, so that standardization is only applied to the flux dimensions of the forcing tensor?
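The quick fix suggested above could look roughly like this: standardize only the flux channels of the forcing tensor and leave every other forcing feature untouched. The channel indices (`flux_idx`) and the scalar flux mean/std are assumptions; the real indices depend on how the forcing features are stacked, and `standardize_flux_only` is a hypothetical name.

```python
import torch


def standardize_flux_only(forcing, flux_mean, flux_std, flux_idx):
    """Standardize only the flux channels of a forcing tensor.

    forcing:  tensor with forcing features in the last dimension
    flux_idx: list of flux feature indices (hypothetical; depends on
              the forcing feature layout)
    """
    out = forcing.clone()  # do not modify the input tensor in place
    out[..., flux_idx] = (forcing[..., flux_idx] - flux_mean) / flux_std
    return out
```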

sadamov commented 2 weeks ago

I propose we merge #54 first and leave this open until then. We now know that `on_after_batch_transfer` works as expected, so we can complete this PR shortly afterwards.