Hi! I found that when I don't JIT-compile the solver and call solver.solve instead of jit_solver.solve, the code raises the following error:
Traceback (most recent call last):
  File "/apdcephfs/share_47076/weizechen/ode_albert/bmtrain/tmp.py", line 17, in <module>
    sol = solver.solve(to.InitialValueProblem(y0=y0, t_eval=t_eval))
  File "/usr/local/lib/python3.10/site-packages/torchode/adjoints.py", line 104, in solve
    dt, controller_state, f0 = step_size_controller.init(
  File "/usr/local/lib/python3.10/site-packages/torchode/step_size_controllers.py", line 346, in init
    dt0, f0 = self._select_initial_step(
  File "/usr/local/lib/python3.10/site-packages/torchode/step_size_controllers.py", line 475, in _select_initial_step
    torch.addcmul(t0, direction, dt0.to(dtype=t0.dtype)), y1, stats, args
RuntimeError: !(has_different_input_dtypes && !config.promote_inputs_to_common_dtype_ && (has_undefined_outputs || config.enforce_safe_casting_to_output_ || config.cast_common_dtype_to_outputs_)) INTERNAL ASSERT FAILED at "../aten/src/ATen/TensorIterator.cpp":405, please report a bug to PyTorch.
However, the jitted solver works just fine. The problem seems to be that direction is int64 while the other inputs to torch.addcmul are float. When I manually convert it with direction = direction.to(t0.dtype), the code runs normally. Is this expected?
I'm using pytorch 1.12.1, functorch 0.2.1, sympy 1.11.1, torchtyping 0.1.4, and the newest torchode.
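The workaround described above can be sketched as follows. This is a minimal, standalone illustration rather than torchode's actual code: t0, direction and dt0 reuse the names from the traceback, but their values are hypothetical. The point is that casting direction to t0's float dtype gives all three torch.addcmul inputs the same dtype, so no mixed-dtype promotion is needed.

```python
import torch

# Hypothetical stand-ins for the tensors named in the traceback:
# float32 start times, an integer integration direction (+1 forward,
# -1 backward), and a float32 initial step size.
t0 = torch.tensor([0.0, 1.0])
direction = torch.tensor([1, -1])  # int64 by default
dt0 = torch.tensor([0.1, 0.2])

# Workaround: cast direction to t0's dtype so every addcmul input is float32.
direction = direction.to(t0.dtype)

# t1 = t0 + direction * dt0, computed in a single fused call.
t1 = torch.addcmul(t0, direction, dt0.to(dtype=t0.dtype))
print(t1)
```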
Hi, thanks for reporting this error. For me, the code runs just fine, but I have applied your fix anyway since it is arguably more correct, and I have released it as version 0.1.7 on PyPI.