thibmonsel / torchdde

Neural network compatible DDEs
https://thibmonsel.github.io/torchdde/
Apache License 2.0
9 stars, 0 forks

`ys` dtype in _integrate_ode #33

Closed: ReHoss closed this issue 2 months ago

ReHoss commented 2 months ago

This line creates ys with data type torch.float32 by default. Consequently, integration with torch.float64 inputs can't be performed because floating-point precisions get mixed.

https://github.com/thibmonsel/torchdde/blob/b28e4ef90e2c40bcfa5fc58b5b39ddca91f79164/torchdde/integrate.py#L386

thibmonsel commented 2 months ago

Do you have an MWE (minimal working example)?

thibmonsel commented 2 months ago

This torch.empty call doesn't impose any dtype, since dtype=None here; in that case PyTorch uses the global default dtype (https://pytorch.org/docs/stable/generated/torch.empty.html).

If you want to use float64, please set:

```python
torch.set_default_dtype(torch.float64)
```

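A minimal sketch of the behaviour described in the two comments above, using only standard PyTorch calls: with dtype=None, torch.empty falls back to torch.get_default_dtype(), so torch.set_default_dtype changes the dtype of every tensor created afterwards without an explicit dtype. The variable names are illustrative, and the sketch restores the previous default because the setting is process-wide.

```python
import torch

# Remember the current global default so it can be restored afterwards
previous_default = torch.get_default_dtype()

try:
    # dtype=None (the default) means torch.empty uses torch.get_default_dtype()
    before = torch.empty(3).dtype
    torch.set_default_dtype(torch.float64)
    after = torch.empty(3).dtype
finally:
    # The setting is global to the process, so put it back for other code
    torch.set_default_dtype(previous_default)

print(before, after)
```

On a fresh interpreter, before is float32 and after is float64.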
ReHoss commented 2 months ago

The code won't work here if the data type of ys differs from the data type of ts or y0. I ran into this flaw.

I don't think the right solution is to force the user to set the global default dtype. For instance, a user could manipulate float32 data in one part of the code and float64 data in another.

In my opinion, a dtype uniformity check should be performed on input arguments such as ts and y0 (and also t0 and t1). Then, within the scope of this function, the dtype of ys should match the dtype of the inputs.

https://github.com/thibmonsel/torchdde/blob/b28e4ef90e2c40bcfa5fc58b5b39ddca91f79164/torchdde/integrate.py#L386
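The check proposed above could look roughly like the following sketch. check_uniform_dtype and allocate_output are hypothetical helpers, not torchdde functions; t0 and t1 are assumed to be tensors, and the (len(ts), *y0.shape) layout of ys is an assumption made only for illustration.

```python
import torch


def check_uniform_dtype(**tensors: torch.Tensor) -> torch.dtype:
    """Return the dtype shared by all given tensors, or raise TypeError.

    Hypothetical helper for illustration; not part of torchdde.
    """
    dtypes = {name: t.dtype for name, t in tensors.items()}
    unique = set(dtypes.values())
    if len(unique) != 1:
        raise TypeError(f"inputs must share a single dtype, got {dtypes}")
    return unique.pop()


def allocate_output(ts: torch.Tensor, y0: torch.Tensor,
                    t0: torch.Tensor, t1: torch.Tensor) -> torch.Tensor:
    """Allocate ys with the inputs' common dtype and y0's device.

    Assumed layout (len(ts), *y0.shape); torchdde's real layout may differ.
    """
    dtype = check_uniform_dtype(ts=ts, y0=y0, t0=t0, t1=t1)
    return torch.empty((ts.shape[0], *y0.shape), dtype=dtype, device=y0.device)
```

The idea is to fail early with a message naming the offending arguments, rather than later inside the solver loop.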

thibmonsel commented 2 months ago

This is intended: y0 defines the dtype of ys. PyTorch's default dtype is float32, so you have to specify and define all your tensors accordingly.

ReHoss commented 2 months ago

> This is intended: y0 defines the dtype of ys. PyTorch's default dtype is float32, so you have to specify and define all your tensors accordingly.

I don't see where in the code the dtype of ys is defined. As far as I can tell, only the device is.

thibmonsel commented 2 months ago

Your suggestion is to add dtype=y0.dtype?
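For reference, a sketch of two standard ways to make the allocation follow y0, next to the plain torch.empty call that ignores it (the shapes and n_steps are arbitrary assumptions): passing dtype=y0.dtype explicitly, or using Tensor.new_empty, which inherits dtype and device from the tensor it is called on.

```python
import torch

y0 = torch.randn(4, 2, dtype=torch.float64)
n_steps = 10

# Option 1: pass dtype and device explicitly
ys_explicit = torch.empty((n_steps, *y0.shape), dtype=y0.dtype, device=y0.device)

# Option 2: Tensor.new_empty inherits dtype and device from y0 by default
ys_inherited = y0.new_empty((n_steps, *y0.shape))

# Only the device is passed: dtype falls back to the global default
ys_default = torch.empty((n_steps, *y0.shape), device=y0.device)

print(ys_explicit.dtype, ys_inherited.dtype, ys_default.dtype)
```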

ReHoss commented 2 months ago

> Your suggestion is to add dtype=y0.dtype?

From above:

The associated unit test should catch a dtype error if the input dtypes are not uniform (e.g. y0 as float32 and ts as float64). It should also pass without setting the global default dtype with torch.set_default_dtype.
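A sketch of such a test with unittest. Since torchdde's exact integrate signature is not shown in this thread, toy_integrate below is a hypothetical stand-in (explicit Euler on an ODE, not a DDE solver) that only mimics the dtype contract being discussed: reject mismatched ts/y0 dtypes and allocate ys from y0.

```python
import unittest

import torch


def toy_integrate(f, ts: torch.Tensor, y0: torch.Tensor) -> torch.Tensor:
    """Hypothetical stand-in for an integrator: explicit Euler on an ODE.

    Not torchdde's API; it only illustrates the dtype contract under test.
    """
    if ts.dtype != y0.dtype:
        raise TypeError(f"ts ({ts.dtype}) and y0 ({y0.dtype}) must share a dtype")
    ys = y0.new_empty((ts.shape[0], *y0.shape))  # same dtype/device as y0
    ys[0] = y0
    for i in range(1, ts.shape[0]):
        dt = ts[i] - ts[i - 1]
        ys[i] = ys[i - 1] + dt * f(ts[i - 1], ys[i - 1])
    return ys


class TestDtype(unittest.TestCase):
    def test_mixed_dtypes_raise(self):
        ts = torch.linspace(0.0, 1.0, 5, dtype=torch.float64)
        y0 = torch.ones(3, dtype=torch.float32)
        with self.assertRaises(TypeError):
            toy_integrate(lambda t, y: -y, ts, y0)

    def test_float64_without_global_default(self):
        # No torch.set_default_dtype call: the dtype comes from the inputs alone
        ts = torch.linspace(0.0, 1.0, 5, dtype=torch.float64)
        y0 = torch.ones(3, dtype=torch.float64)
        ys = toy_integrate(lambda t, y: -y, ts, y0)
        self.assertEqual(ys.dtype, torch.float64)
        self.assertEqual(tuple(ys.shape), (5, 3))
```

Run it with python -m unittest or python -m pytest on the file that contains it.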

thibmonsel commented 2 months ago

The latest commit (i.e. https://github.com/thibmonsel/torchdde/commit/a659f8c62e9274a86e6bdbad722e4f2e3666f265) should fix the issue. If so, please close the issue; if not, please provide an MWE.

For information:

```python
>>> import torch
>>> a = torch.arange(10, dtype=torch.float64)
>>> b = 2 * torch.arange(10, dtype=torch.float32)
>>> c = a * b
>>> c.dtype
torch.float64
```

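The promotion shown above applies to elementwise operations, but not every PyTorch operation promotes, and writing into a preallocated buffer casts silently. A small sketch using only standard PyTorch calls (tensor names are arbitrary); it does not claim to reproduce the exact failure reported in this issue.

```python
import torch

a64 = torch.arange(4, dtype=torch.float64)
b32 = torch.arange(4, dtype=torch.float32)

# Elementwise arithmetic follows PyTorch's type-promotion rules
prod_dtype = (a64 * b32).dtype
print(prod_dtype)                      # torch.float64
print(torch.result_type(a64, b32))     # torch.float64

# Matrix multiplication does not promote: mixed float64/float32 operands raise
m64 = torch.ones(2, 2, dtype=torch.float64)
m32 = torch.ones(2, 2, dtype=torch.float32)
matmul_error = None
try:
    m64 @ m32
except RuntimeError as err:
    matmul_error = err
print("matmul raised:", matmul_error is not None)

# Assigning into a preallocated float32 buffer silently casts float64 values down
buf = torch.empty(3, dtype=torch.float32)
buf[:] = torch.tensor([0.1, 0.2, 0.3], dtype=torch.float64)
print(buf.dtype)                       # torch.float32
```

The matmul case is the one a neural-network vector field would typically hit, since nn.Linear layers are created with the global default dtype.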
thibmonsel commented 2 months ago

Closing this issue since there's no more activity.