Closed: ReHoss closed this issue 2 months ago.
Do you have a MWE?
This `torch.empty` doesn't impose any type, since `dtype=None` by default (https://pytorch.org/docs/stable/generated/torch.empty.html). If you want to use `float64`, please set `torch.set_default_dtype(torch.float64)`.
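A minimal demonstration of this behavior: with `dtype=None`, `torch.empty` resolves to whatever the global default dtype is at call time.

```python
import torch

# With dtype=None, torch.empty follows the global default dtype (float32).
a = torch.empty(3)
print(a.dtype)  # torch.float32

# Changing the global default changes what dtype=None resolves to.
torch.set_default_dtype(torch.float64)
b = torch.empty(3)
print(b.dtype)  # torch.float64

# Restore the default so the rest of the session is unaffected.
torch.set_default_dtype(torch.float32)
```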
The code won't work if the data type of `ys` differs from that of `ts` or `y0`; I came across this flaw. I don't think the right solution is to force the user to set the global tensor type: for instance, the user could manipulate `float32` data in one part of the code and `float64` data in another.
In my opinion, a data-type uniformity check should be performed on the input arguments such as `ts` and `y0` (and also `t0` and `t1`). Then, within the scope of this function, the data type of `ys` should match the data type of the inputs.
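A sketch of what such a check could look like. `check_dtype_uniformity` is a hypothetical helper, not part of torchdde; it rejects mixed-precision inputs up front instead of relying on the global default dtype.

```python
import torch

def check_dtype_uniformity(ts, y0, t0=None, t1=None):
    # Hypothetical helper (not in torchdde): require all tensor inputs
    # to share one dtype, and return that dtype for downstream use.
    tensors = [t for t in (ts, y0, t0, t1) if isinstance(t, torch.Tensor)]
    dtypes = {t.dtype for t in tensors}
    if len(dtypes) > 1:
        raise TypeError(f"mixed input dtypes: {sorted(str(d) for d in dtypes)}")
    return dtypes.pop()

ts = torch.linspace(0.0, 1.0, 5, dtype=torch.float64)
y0 = torch.zeros(2, dtype=torch.float64)
print(check_dtype_uniformity(ts, y0))  # torch.float64
```

The returned dtype could then be used to allocate `ys`, so the output always matches the inputs.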
This is intended: `y0` defines `ys`'s type. PyTorch's default type is `float32`, so you have to specify and define all your tensors accordingly.
> This is intended: `y0` defines `ys`'s type. PyTorch's default type is `float32`, so you have to specify and define all your tensors accordingly.
I don't see where in the code it defines `ys`'s data type. To me, it defines only the `device`.
Your suggestion is to add `dtype=y0.dtype`?
> Your suggestion is to add `dtype=y0.dtype`?
From above: define `ys`'s datatype accordingly, possibly with `dtype=y0.dtype`.
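A minimal sketch of what that change amounts to, assuming `ys` is preallocated with `torch.empty` inside the solver (the shape and step count here are illustrative):

```python
import torch

y0 = torch.zeros(3, dtype=torch.float64)
n_steps = 10

# Without an explicit dtype, ys would follow the global default (float32)
# and writing float64 states into it would silently downcast them.
ys = torch.empty((n_steps, *y0.shape), dtype=y0.dtype, device=y0.device)
print(ys.dtype)  # torch.float64
```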
The associated unit test should capture a `dtype` error if the input data types are not uniform (e.g. `y0` as `float32` and `ts` as `float64`). It should also pass without setting the global data type via `torch.set_default_dtype`.
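A sketch of such a test. `solve` here is a hypothetical stand-in for torchdde's integration entry point, included only so the test is self-contained; the real test would call the library instead.

```python
import torch

def solve(ts, y0):
    # Hypothetical stand-in for torchdde's integrate: enforce uniform dtypes
    # and allocate the output to match the inputs.
    if ts.dtype != y0.dtype:
        raise TypeError(f"dtype mismatch: ts is {ts.dtype}, y0 is {y0.dtype}")
    return torch.empty((len(ts), *y0.shape), dtype=y0.dtype)

def test_mixed_dtypes_raise():
    ts = torch.linspace(0.0, 1.0, 4, dtype=torch.float64)
    y0 = torch.zeros(2, dtype=torch.float32)
    try:
        solve(ts, y0)
    except TypeError:
        pass
    else:
        raise AssertionError("expected a TypeError for mixed dtypes")

def test_float64_without_global_default():
    # Must pass without calling torch.set_default_dtype(torch.float64).
    ts = torch.linspace(0.0, 1.0, 4, dtype=torch.float64)
    y0 = torch.zeros(2, dtype=torch.float64)
    assert solve(ts, y0).dtype == torch.float64

test_mixed_dtypes_raise()
test_float64_without_global_default()
```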
The latest commit (i.e. https://github.com/thibmonsel/torchdde/commit/a659f8c62e9274a86e6bdbad722e4f2e3666f265) should fix the issue. If so, please close the issue. If not, please provide a MWE.
For information, PyTorch's type promotion means mixed-dtype arithmetic promotes to the wider type rather than raising:
>>> import torch
>>> a = torch.arange(10, dtype=torch.float64)
>>> b = 2 * torch.arange(10, dtype=torch.float32)
>>> c = a * b
>>> c.dtype
torch.float64
Closing this issue since there's no more activity.
This line imposes the data type `torch.float32` by default. Consequently, integration with `torch.float64` can't be performed because of the floating-point precision mix: https://github.com/thibmonsel/torchdde/blob/b28e4ef90e2c40bcfa5fc58b5b39ddca91f79164/torchdde/integrate.py#L386
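A short illustration of the precision mix this causes, assuming the buffer is allocated with the default dtype as in the linked line: writing `float64` states into a `float32` buffer silently discards precision.

```python
import torch

# Buffer allocated under the default dtype (float32), as in the linked
# torch.empty call.
ys = torch.empty(3)
state = torch.tensor([1.0 + 1e-10, 2.0, 3.0], dtype=torch.float64)

ys[:] = state  # silently downcast to float32 on assignment
print(ys.dtype)  # torch.float32

# The 1e-10 offset is below float32 resolution, so it is lost.
print(float(ys[0]) == float(state[0]))  # False
```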