martenlienen / torchode

A parallel ODE solver for PyTorch
https://torchode.readthedocs.io
MIT License

Non-jitted solver in readme example raises error. #19

Closed by chenweize1998 1 year ago

chenweize1998 commented 1 year ago

Hi! I found that when I don't jit the solver and call solver.solve instead of jit_solver.solve, the README example raises the following error:

Traceback (most recent call last):
  File "/apdcephfs/share_47076/weizechen/ode_albert/bmtrain/tmp.py", line 17, in <module>
    sol = solver.solve(to.InitialValueProblem(y0=y0, t_eval=t_eval))
  File "/usr/local/lib/python3.10/site-packages/torchode/adjoints.py", line 104, in solve
    dt, controller_state, f0 = step_size_controller.init(
  File "/usr/local/lib/python3.10/site-packages/torchode/step_size_controllers.py", line 346, in init
    dt0, f0 = self._select_initial_step(
  File "/usr/local/lib/python3.10/site-packages/torchode/step_size_controllers.py", line 475, in _select_initial_step
    torch.addcmul(t0, direction, dt0.to(dtype=t0.dtype)), y1, stats, args
RuntimeError: !(has_different_input_dtypes && !config.promote_inputs_to_common_dtype_ && (has_undefined_outputs || config.enforce_safe_casting_to_output_ || config.cast_common_dtype_to_outputs_)) INTERNAL ASSERT FAILED at "../aten/src/ATen/TensorIterator.cpp":405, please report a bug to PyTorch.

However, the jitted solver works just fine. The problem seems to be that direction is int64 while the other arguments to torch.addcmul are floating point. When I manually convert it with direction = direction.to(t0.dtype), the code runs normally. Is this expected?
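The workaround can be sketched in isolation as follows. This is a minimal illustration, not torchode's actual code: the tensor values and shapes are made up, and only the names t0, dt0 and direction are taken from the traceback. Whether the uncast call fails may depend on the PyTorch version, so the sketch only shows the cast version.

```python
import torch

# Hypothetical stand-ins for the tensors used in _select_initial_step;
# the values and shapes here are assumptions for illustration only.
t0 = torch.tensor([0.0, 0.0], dtype=torch.float32)
dt0 = torch.tensor([0.1, 0.2], dtype=torch.float32)
direction = torch.tensor([1, -1], dtype=torch.int64)  # integer sign of the integration direction

# Workaround from the report: cast direction to t0's dtype so that
# every input to addcmul shares one floating-point dtype.
direction = direction.to(t0.dtype)

# addcmul computes t0 + direction * dt0 elementwise.
t1 = torch.addcmul(t0, direction, dt0.to(dtype=t0.dtype))
```

After the cast, t1 has the same dtype as t0, and no mixed-dtype inputs reach the TensorIterator check that failed in the traceback.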

I'm using PyTorch 1.12.1, functorch 0.2.1, sympy 1.11.1, torchtyping 0.1.4, and the latest torchode.

martenlienen commented 1 year ago

Hi, thanks for reporting this error. For me, the code runs just fine, but I have applied your fix anyway since it is arguably more correct, and I have released it as version 0.1.7 on PyPI.