NVIDIA / earth2studio

Open-source deep-learning framework for exploring, building and deploying AI weather/climate workflows.
https://nvidia.github.io/earth2studio/
Apache License 2.0

🐛[BUG]: SphericalGaussian amplitude tensor not set to correct device #99

Closed NickGeneva closed 1 month ago

NickGeneva commented 1 month ago

Version

main

On which installation method(s) does this occur?

Source

Describe the issue

Using a noise method like this:

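import torch
from earth2studio.perturbation.spherical import SphericalGaussian

# Per-variable noise amplitude, created on the CPU by default
noise_amp = torch.zeros(73, 1, 1)
noise_amp[4] = 0.01 # t2m
perturbation = SphericalGaussian(noise_amp)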

This produces an error in a workflow that's running on the GPU. The amplitude tensor should be moved to the same device as x in the call function. https://github.com/NVIDIA/earth2studio/blob/main/earth2studio/perturbation/spherical.py#L84

Should check other perturbations as well.
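
For reference, here is a minimal sketch of the device handling being asked for. ToyPerturbation is a hypothetical stand-in, not the earth2studio implementation, and the input shape is made up for illustration. The only point is that the amplitude follows x onto whatever device x is on before the two are combined.

```python
import torch


class ToyPerturbation:
    """Hypothetical stand-in for a perturbation; not the earth2studio code."""

    def __init__(self, noise_amplitude: torch.Tensor):
        # The amplitude is stored as given, typically on the CPU.
        self.noise_amplitude = noise_amplitude

    @torch.inference_mode()
    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        # Move the amplitude to the device of x before using it.
        amp = self.noise_amplitude.to(x.device)
        # Plain Gaussian noise here; the real method uses a spherical field.
        noise = torch.randn_like(x)
        return x + amp * noise


# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

noise_amp = torch.zeros(73, 1, 1)  # created on the CPU
noise_amp[4] = 0.01  # only perturb channel 4

# Assumed layout for illustration: (batch, variable, lat, lon)
x = torch.zeros(1, 73, 8, 16, device=device)
y = ToyPerturbation(noise_amp)(x)
```

Without the .to(x.device) call, the same code would raise a device-mismatch RuntimeError whenever x is on the GPU and the amplitude is still on the CPU.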

dallasfoster commented 1 month ago

Can you post a trace? This should be covered by https://github.com/NVIDIA/earth2studio/blob/3a0bc75e519381565d61a51ef167362cdcda3b62/earth2studio/perturbation/spherical.py#L114.

NickGeneva commented 1 month ago

That's what I thought, and indeed that is the case. The trace was:

Traceback (most recent call last):
  File "/code/earth2-nims/nims/sfno/e2e/sfno_ensemble_studio.py", line 254, in <module>
    nim_workflow(
  File "/code/earth2-nims/nims/sfno/e2e/sfno_ensemble_studio.py", line 192, in nim_workflow
    x, coords = perturbation(x, coords)
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/earth2studio/perturbation/spherical.py", line 109, in __call__
    return x + self.noise_amplitude * noise, coords
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

BUT somehow Python was still referencing an old version of the package instead of the one I was trying to install from source. I think it's some issue with pip vs. pip3 or something... False alarm!
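
For anyone who hits the same thing, a quick way to see which copy of a package the interpreter actually picks up is to ask importlib. This is a generic sketch using only the standard library; describe_install is a hypothetical helper name, not part of earth2studio.

```python
import importlib.util
from importlib import metadata


def describe_install(package: str) -> dict:
    """Report where `package` would be imported from and its installed version."""
    # find_spec returns None when a top-level package cannot be found.
    spec = importlib.util.find_spec(package)
    origin = spec.origin if spec is not None else None
    try:
        # Version recorded in the installed distribution's metadata.
        version = metadata.version(package)
    except metadata.PackageNotFoundError:
        version = None
    return {"package": package, "origin": origin, "version": version}


# If the origin points at site-packages or dist-packages rather than the
# source checkout, Python is still importing the old installation.
print(describe_install("earth2studio"))
```

Installing with python -m pip install rather than a bare pip or pip3 is one way to make sure the package lands in the environment of the interpreter that will run the workflow.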