The above code produces a NaN loss. In particular, after two iterations `substate1` becomes `tensor([-1., -2., 0., nan])`. However, if `substate1` is calculated manually, the result is correctly evaluated to `tensor([-1., -2., 0., 0.])`.
The last element should evaluate to 0, not NaN.
cc @vincentqb @jbschlosser @albanD