A kind reminder: do not use PyTorch 1.12.0 for this project. It throws this error:
assert not step_t.is_cuda, "If capturable=False, state_steps should not be CUDA tensors."
AssertionError: If capturable=False, state_steps should not be CUDA tensors.
I switched to torch 1.11.0 and it works.
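In case it helps anyone, here is a rough sketch of a guard you could run before training so the problem shows up early instead of at the first optimizer step. The helper names `is_affected_torch_version` and `check_installed_torch` are my own and not part of this project or of PyTorch. It only flags exactly 1.12.0, since that is the only version I saw fail.

```python
def is_affected_torch_version(version: str) -> bool:
    """Return True if the version string denotes PyTorch 1.12.0."""
    # Drop a local build suffix such as "+cu113" before comparing.
    base = version.split("+", 1)[0]
    return base == "1.12.0"


def check_installed_torch() -> None:
    """Raise early if the installed PyTorch is the version reported to fail."""
    import torch  # imported here so the helper above works without torch

    if is_affected_torch_version(torch.__version__):
        raise RuntimeError(
            "PyTorch 1.12.0 detected; torch 1.11.0 is reported to work."
        )
```

Calling `check_installed_torch()` at the top of the training script would be one way to use it. I have not checked other versions, so adjust the comparison if you find more that fail.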