dallasfoster closed this 1 month ago
Currently there's no mechanism to ensure the noise vector is on the same device as the input tensor.
I'd suggest converting floats to tensors in the __init__ methods.
Then, inside each __call__, the amplitude is moved onto x.device if the two do not match at call time.
Because the workflow determines the device, we shouldn't assume that the user will know what the device is at the time of construction.
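A minimal sketch of the suggested pattern, assuming PyTorch. The class name GaussianNoiseSketch and its noise_amplitude default are hypothetical illustrations, not the Earth2Studio API. The float is turned into a device-less tensor at construction, and the device is resolved only when the workflow calls the perturbation with a concrete x.

```python
from typing import Union

import torch


class GaussianNoiseSketch:
    """Hypothetical additive Gaussian perturbation (illustrative only)."""

    def __init__(self, noise_amplitude: Union[float, torch.Tensor] = 0.05):
        # Convert a Python float into a tensor once, at construction time.
        # No device is chosen here: the workflow decides that later.
        if not isinstance(noise_amplitude, torch.Tensor):
            noise_amplitude = torch.tensor(noise_amplitude)
        self.noise_amplitude = noise_amplitude

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        amplitude = self.noise_amplitude
        # Move the amplitude onto the input's device only if it does not match.
        if amplitude.device != x.device:
            amplitude = amplitude.to(x.device)
        # randn_like samples on x.device with x's dtype and shape.
        noise = torch.randn_like(x)
        return x + amplitude * noise
```

Because construction never touches a device, the same object can be built once and reused whether the workflow later runs on CPU or GPU.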
Thanks for the catch, just pushed an update to include these suggestions.
LGTM
/blossom-ci
Earth2Studio Pull Request
Description
This is a backward-compatible PR that adds the ability to pass broadcastable tensors as noise amplitudes in relevant perturbation methods.
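To illustrate what "broadcastable" means here, a minimal sketch assuming PyTorch; the helper perturb is hypothetical and not the Earth2Studio API. A plain float still works (the backward-compatible path), while a tensor such as one amplitude per channel, shaped (1, C, 1, 1), broadcasts against an input of shape (B, C, H, W).

```python
from typing import Union

import torch


def perturb(x: torch.Tensor, noise_amplitude: Union[float, torch.Tensor]) -> torch.Tensor:
    # Hypothetical helper: accepts a float (old behaviour) or a tensor that
    # broadcasts against x (new behaviour), e.g. one amplitude per channel.
    if not isinstance(noise_amplitude, torch.Tensor):
        noise_amplitude = torch.tensor(noise_amplitude)
    noise_amplitude = noise_amplitude.to(x.device)
    # broadcast_shapes raises RuntimeError if the shapes are incompatible;
    # also reject amplitudes that would enlarge the output beyond x's shape.
    if torch.broadcast_shapes(noise_amplitude.shape, x.shape) != x.shape:
        raise ValueError("noise_amplitude must broadcast to x.shape")
    return x + noise_amplitude * torch.randn_like(x)


# Usage: leave channel 0 untouched, full noise on channel 1, half on channel 2.
x = torch.zeros(2, 3, 8, 8)
amp = torch.tensor([0.0, 1.0, 0.5]).view(1, 3, 1, 1)
y = perturb(x, amp)
```

Checking the broadcast result against x.shape keeps the output shape identical to the input, which is what a perturbation should preserve.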
Checklist