NVIDIA / earth2studio

Open-source deep-learning framework for exploring, building and deploying AI weather/climate workflows.
https://nvidia.github.io/earth2studio/
Apache License 2.0

Add tensor noise amplitude parameters for perturbation methods #93

Closed: dallasfoster closed this pull request 1 month ago.

dallasfoster commented 1 month ago

Earth2Studio Pull Request

Description

This is a backward-compatible PR that lets the relevant perturbation methods accept broadcastable tensors, in addition to floats, as noise amplitudes.
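As a rough illustration of what a "broadcastable tensor amplitude" means, the sketch below scales Gaussian noise by an amplitude that may be either a plain float (the old behavior) or a tensor that broadcasts against the input, for example one amplitude per variable. The helper name `gaussian_perturb` and the `(batch, variable, lat, lon)` layout are assumptions for illustration only, not the earth2studio API.

```python
from typing import Union

import torch


def gaussian_perturb(
    x: torch.Tensor, noise_amplitude: Union[float, torch.Tensor]
) -> torch.Tensor:
    """Hypothetical helper: add N(0, 1) noise scaled by a float or broadcastable tensor."""
    # torch.as_tensor turns a float into a 0-dim tensor and leaves an existing
    # tensor alone (copying it only if its dtype or device differs from x).
    amp = torch.as_tensor(noise_amplitude, dtype=x.dtype, device=x.device)
    # Standard broadcasting applies the amplitude elementwise; incompatible
    # shapes raise a RuntimeError instead of silently misbehaving.
    return x + amp * torch.randn_like(x)


# Usage: a per-variable amplitude of shape (1, C, 1, 1) for a (B, C, H, W) input.
x = torch.zeros(2, 3, 4, 8)
amp = torch.tensor([0.0, 1.0, 0.5]).view(1, 3, 1, 1)
y = gaussian_perturb(x, amp)  # variable 0 has zero amplitude and stays unchanged
```

A scalar float still works exactly as before, which is what keeps a change like this backward compatible.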

NickGeneva commented 1 month ago

Currently there's no mechanism to ensure the noise vector is on the same device as the input tensor.

I'd suggest converting floats to tensors in the inits.

Then, inside each call function, the amplitude can be moved onto the same device as `x.device` if it does not already match.

Because the workflow determines the device, we shouldn't assume that the user will know what the device is at the time of construction.
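A minimal sketch of the pattern suggested above, assuming a perturbation class whose `__init__` stores the amplitude and whose `__call__` receives the input tensor `x`. The class name `ScaledNoise` and its signature are hypothetical; the actual earth2studio perturbation methods also take and return coordinate systems, which are omitted here.

```python
from typing import Union

import torch


class ScaledNoise:
    """Hypothetical perturbation: returns x + amplitude * N(0, 1)."""

    def __init__(self, noise_amplitude: Union[float, torch.Tensor]):
        # Normalize floats to tensors once, at construction time.
        # No device is chosen here: the workflow decides it later.
        if isinstance(noise_amplitude, torch.Tensor):
            self.noise_amplitude = noise_amplitude
        else:
            self.noise_amplitude = torch.tensor(noise_amplitude)

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        # Move the amplitude onto x's device only when it does not already match,
        # and keep the moved copy so later calls on the same device skip the transfer.
        if self.noise_amplitude.device != x.device:
            self.noise_amplitude = self.noise_amplitude.to(x.device)
        return x + self.noise_amplitude * torch.randn_like(x)
```

Deferring the device move to call time means the same object works whether the workflow later runs on CPU or GPU, without the user knowing the device at construction.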

dallasfoster commented 1 month ago

Thanks for the catch; I just pushed an update that includes these suggestions.

NickGeneva commented 1 month ago

LGTM

dallasfoster commented 1 month ago

/blossom-ci