Open sadamov opened 1 month ago
This is looking good @sadamov! Can I suggest we merge #38 first, you merge that into your branch and then I do a review? That way we can ensure everything keeps working :)
I merged with the latest updates from main @leifdenby and commented on your suggestions @joeloskarsson. In the following I want to show that the output tensors were not affected by this change. As you suggested, I stored the last batch for both GPU and CPU normalization, with deterministic=True and a set seed.
Then I compared the tensors like this:
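The snippet from the original comment is not preserved here. As a rough sketch only, a comparison of this kind could look like the following. The batch layout (init_states, target_states, forcing), the helper name compare_batches, and the toy tensors are assumptions for illustration, not the code that was actually run.

```python
import torch


def compare_batches(batch_a, batch_b, names, atol=0.0):
    """Return {name: bool} telling whether each pair of tensors matches."""
    results = {}
    for name, a, b in zip(names, batch_a, batch_b):
        # Bring both tensors to the CPU so batches stored from different
        # devices can be compared directly.
        a, b = a.detach().cpu(), b.detach().cpu()
        results[name] = a.shape == b.shape and torch.allclose(
            a, b, rtol=0.0, atol=atol
        )
    return results


# Toy stand-ins for the two stored batches (hypothetical shapes and values).
torch.manual_seed(0)
init_states = torch.randn(2, 2, 5, 3)
target_states = torch.randn(2, 4, 5, 3)
forcing = torch.randn(2, 4, 5, 6)

names = ("init_states", "target_states", "forcing")
batch_cpu_norm = (init_states, target_states, forcing)
# The forcing tensor is deliberately perturbed to mimic a mismatch.
batch_gpu_norm = (init_states.clone(), target_states.clone(), forcing + 1.0)

comparison = compare_batches(batch_cpu_norm, batch_gpu_norm, names)
print(comparison)
```

With real data the two batches would be loaded from disk (for example with torch.load) instead of being constructed in memory, and a small nonzero atol may be appropriate if bitwise equality is not expected.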
So, except for the forcing tensor, which is handled differently, the two approaches produce identical output.
How do you think we should progress with this @sadamov? If the forcing is handled differently in #54, would it make more sense to try to merge this after that? (I guess baking this change into #54 would just make it even bigger.) Or should we merge this first so that #54 can build on it and be adapted to use it?
I would not be happy to merge this until the forcing tensors also match the previous implementation. But we could just do a quick fix for now so that the standardization is only applied to the flux dimensions of the forcing tensor?
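To make the suggested quick fix concrete, here is a hedged sketch of standardizing only selected feature columns of a forcing tensor. The function name standardize_flux_only, the flux indices, and the statistics are hypothetical; which dimensions actually hold the flux is not stated in this thread.

```python
import torch


def standardize_flux_only(forcing, flux_idx, flux_mean, flux_std):
    """Standardize only the feature columns listed in flux_idx."""
    out = forcing.clone()  # leave the input tensor untouched
    out[..., flux_idx] = (forcing[..., flux_idx] - flux_mean) / flux_std
    return out


# Toy forcing tensor with 3 features; columns 0 and 2 are assumed to be flux.
forcing = torch.tensor([[10.0, 0.5, 20.0], [30.0, -0.5, 40.0]])
flux_idx = [0, 2]  # hypothetical indices
result = standardize_flux_only(
    forcing, flux_idx, flux_mean=torch.tensor(20.0), flux_std=torch.tensor(10.0)
)
print(result)
```

All other forcing features pass through unchanged, so they would keep matching whatever the previous implementation produced for them.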
I propose to merge #54 first and leave this open until then. We now know that on_after_batch_transfer works as expected, so this PR can be implemented shortly afterwards.
Summary
This PR introduces on_after_batch_transfer logic to normalize data on GPU instead of on CPU before transfer.
Rationale
Normalization is faster on GPU than on CPU. In the current code the data is normalized in the __getitem__ method of the PyTorch dataset class, potentially slowing down training, especially on systems with fewer CPU cores.
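The idea can be illustrated with a minimal sketch, which is not the code in this PR. It assumes a batch of (init_states, target_states, forcing) and per-feature statistics, and uses a plain class as a stand-in for the model. In PyTorch Lightning the hook with this signature is defined on the LightningModule and is called after the batch has been moved to the device. The class name OnDeviceNormalizer and the toy values are hypothetical.

```python
import torch


class OnDeviceNormalizer:
    """Stand-in for the model class; with Lightning this would be a LightningModule."""

    def __init__(self, state_mean, state_std):
        # In a real module these would typically be registered as buffers
        # so that they follow the module to its device.
        self.state_mean = state_mean
        self.state_std = state_std

    def on_after_batch_transfer(self, batch, dataloader_idx):
        init_states, target_states, forcing = batch
        # The batch already lives on the target device at this point,
        # so the arithmetic below runs there rather than in the data loader.
        mean = self.state_mean.to(init_states.device)
        std = self.state_std.to(init_states.device)
        init_states = (init_states - mean) / std
        target_states = (target_states - mean) / std
        # Forcing is passed through unchanged to keep the sketch short.
        return init_states, target_states, forcing


device = "cuda" if torch.cuda.is_available() else "cpu"
model = OnDeviceNormalizer(torch.tensor([1.0, 2.0]), torch.tensor([2.0, 4.0]))
raw = torch.tensor([[3.0, 6.0], [1.0, 2.0]], device=device)
batch = (raw, raw.clone(), torch.zeros(2, 1, device=device))
init_n, target_n, forcing_out = model.on_after_batch_transfer(batch, dataloader_idx=0)
print(init_n)
```

The dataset then only has to return raw tensors, which keeps the per-sample work in the data loader workers small.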
Changes
on_after_batch_transfer in the ar_model.py script
create_parameter_weights.py script to work with the new changes (not reloading standardized dataset)
Testing
Both training and evaluation were successful. The training loss of 3.230 on the meps_example is identical to the value before the changes. The create_parameter_weights script was run and successfully generated the stats.
Not-In-Scope
The normalization stats and other static features will all become zarr archives in the future, with their paths defined in the data_config.yaml file.