levtelyatnikov opened this issue 3 years ago
Have you tried this? https://github.com/stevenygd/PointFlow/issues/17#issuecomment-705793859
Changing lines 64-65 in cnf.py as follows fixed it in my situation:
```python
atol = self.atol  # previously: [self.atol] * 3
rtol = self.rtol  # previously: [self.rtol] * 3
```
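It seems that the torchdiffeq library has changed its interface: passing a three-element list for `rtol`/`atol` no longer works here, whereas passing the scalar values directly does.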
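The assertion in the traceback below comes from a length check on "tupled" tolerances. As a rough illustration of why the change helps, here is a simplified, hypothetical re-implementation of that check. `check_tol` is not a torchdiffeq function, and the state lengths (3 and 7) are made up. A sequence of tolerances must have one entry per element of the state tuple, while a scalar is accepted for a state of any length. One plausible reading of the traceback (the error is raised from torchdiffeq's adjoint `backward`) is that the backward pass integrates a larger augmented state. In that case a fixed 3-element list no longer matches.

```python
import numbers
from typing import List, Sequence, Tuple, Union

Tol = Union[float, Sequence[float]]


def check_tol(name: str, tol: Tol, y0: Tuple) -> List[float]:
    """Hypothetical, simplified stand-in for torchdiffeq's tolerance check.

    In this sketch a scalar tolerance is broadcast to every element of the
    state tuple y0; a sequence must have exactly one entry per element.
    """
    if isinstance(tol, numbers.Real):
        return [float(tol)] * len(y0)
    tol = list(tol)
    if len(tol) != len(y0):
        # torchdiffeq uses a bare assert here; this sketch raises ValueError.
        raise ValueError(
            "If using tupled {} it must have the same length as the tuple y0".format(name)
        )
    return tol


if __name__ == "__main__":
    forward_state = (0.0, 0.0, 0.0)  # hypothetical 3-element state
    augmented_state = (0.0,) * 7     # hypothetical larger state in the backward pass

    # Old cnf.py style: a 3-element list only matches a 3-element state.
    print(check_tol("rtol", [1e-5] * 3, forward_state))
    try:
        check_tol("rtol", [1e-5] * 3, augmented_state)
    except ValueError as err:
        print("fails:", err)

    # Suggested fix: a scalar tolerance works for a state of any length.
    print(len(check_tol("rtol", 1e-5, augmented_state)))
```

With the edit to cnf.py, `self.atol` and `self.rtol` are passed as scalars, which corresponds to the last case in the sketch.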
Hello, I have tried all the scripts and get the same error with every one of them:
```
Traceback (most recent call last):
  File "train.py", line 272, in <module>
    main()
  File "train.py", line 268, in main
    main_worker(args.gpu, save_dir, ngpus_per_node, args)
  File "train.py", line 176, in main_worker
    out = model(inputs, optimizer, step, writer)
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/content/PointFlow/models/networks.py", line 167, in forward
    loss.backward()
  File "/usr/local/lib/python3.7/dist-packages/torch/tensor.py", line 221, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/usr/local/lib/python3.7/dist-packages/torch/autograd/__init__.py", line 132, in backward
    allow_unreachable=True) # allow_unreachable flag
  File "/usr/local/lib/python3.7/dist-packages/torch/autograd/function.py", line 89, in apply
    return self._forward_cls.backward(self, *args) # type: ignore
  File "/usr/local/lib/python3.7/dist-packages/torchdiffeq/_impl/adjoint.py", line 129, in backward
    rtol=adjoint_rtol, atol=adjoint_atol, method=adjoint_method, options=adjoint_options
  File "/usr/local/lib/python3.7/dist-packages/torchdiffeq/_impl/odeint.py", line 72, in odeint
    shapes, func, y0, t, rtol, atol, method, options, event_fn, t_is_reversed = _check_inputs(func, y0, t, rtol, atol, method, options, event_fn, SOLVERS)
  File "/usr/local/lib/python3.7/dist-packages/torchdiffeq/_impl/misc.py", line 209, in _check_inputs
    rtol = _tuple_tol('rtol', rtol, shapes)
  File "/usr/local/lib/python3.7/dist-packages/torchdiffeq/_impl/misc.py", line 115, in _tuple_tol
    assert len(tol) == len(shapes), "If using tupled {} it must have the same length as the tuple y0".format(name)
AssertionError: If using tupled rtol it must have the same length as the tuple y0
```
Done