stevenygd / PointFlow

PointFlow: 3D Point Cloud Generation with Continuous Normalizing Flows
https://www.guandaoyang.com/PointFlow/
MIT License

Train script doesn't work #23

Open levtelyatnikov opened 3 years ago

levtelyatnikov commented 3 years ago

Hello, I have tried all of the training scripts, and every one of them fails with the same error:

```
Traceback (most recent call last):
  File "train.py", line 272, in <module>
    main()
  File "train.py", line 268, in main
    main_worker(args.gpu, save_dir, ngpus_per_node, args)
  File "train.py", line 176, in main_worker
    out = model(inputs, optimizer, step, writer)
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/content/PointFlow/models/networks.py", line 167, in forward
    loss.backward()
  File "/usr/local/lib/python3.7/dist-packages/torch/tensor.py", line 221, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/usr/local/lib/python3.7/dist-packages/torch/autograd/__init__.py", line 132, in backward
    allow_unreachable=True)  # allow_unreachable flag
  File "/usr/local/lib/python3.7/dist-packages/torch/autograd/function.py", line 89, in apply
    return self._forward_cls.backward(self, *args)  # type: ignore
  File "/usr/local/lib/python3.7/dist-packages/torchdiffeq/_impl/adjoint.py", line 129, in backward
    rtol=adjoint_rtol, atol=adjoint_atol, method=adjoint_method, options=adjoint_options
  File "/usr/local/lib/python3.7/dist-packages/torchdiffeq/_impl/odeint.py", line 72, in odeint
    shapes, func, y0, t, rtol, atol, method, options, event_fn, t_is_reversed = _check_inputs(func, y0, t, rtol, atol, method, options, event_fn, SOLVERS)
  File "/usr/local/lib/python3.7/dist-packages/torchdiffeq/_impl/misc.py", line 209, in _check_inputs
    rtol = _tuple_tol('rtol', rtol, shapes)
  File "/usr/local/lib/python3.7/dist-packages/torchdiffeq/_impl/misc.py", line 115, in _tuple_tol
    assert len(tol) == len(shapes), "If using tupled {} it must have the same length as the tuple y0".format(name)
AssertionError: If using tupled rtol it must have the same length as the tuple y0
Done
```

stevenygd commented 3 years ago

Have you tried this? https://github.com/stevenygd/PointFlow/issues/17#issuecomment-705793859

SymenYang commented 3 years ago

Changing lines 64-65 in cnf.py as follows fixed it in my case:

```python
atol = self.atol  # previously: [self.atol] * 3
rtol = self.rtol  # previously: [self.rtol] * 3
```

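It seems that the torchdiffeq library has changed its interface: a tolerance list sized for the three-element state no longer matches the state tuple it is checked against, whereas a scalar tolerance skips that length check.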
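The traceback shows the assertion being raised during the adjoint backward pass (`adjoint.py` → `odeint` → `_check_inputs` → `_tuple_tol`), so the tuple being integrated there evidently does not have the three elements that `[self.rtol] * 3` assumes. The sketch below is a simplified, hypothetical re-creation of that check, modeled only on the assertion message. It is not torchdiffeq's actual code: `per_state_tol` and `list_tol_fails` are illustrative names, and the five-tensor state is an arbitrary example rather than the real length of the augmented adjoint state.

```python
# Hypothetical, simplified re-creation of the tolerance check behind the
# AssertionError above. Modeled on the message from torchdiffeq/_impl/misc.py;
# NOT torchdiffeq's actual implementation.
from typing import Sequence, Tuple, Union

Tol = Union[float, Sequence[float]]


def per_state_tol(name: str, tol: Tol, n_states: int) -> Tuple[float, ...]:
    """Return one tolerance per tensor in the ODE state tuple."""
    if isinstance(tol, (int, float)):
        # A scalar tolerance is reused for every tensor, whatever the tuple length.
        return (float(tol),) * n_states
    tol = tuple(float(t) for t in tol)
    # A list/tuple tolerance needs exactly one entry per state tensor.
    # (torchdiffeq uses an `assert`; raising explicitly also works under `python -O`.)
    if len(tol) != n_states:
        raise AssertionError(
            "If using tupled {} it must have the same length as the tuple y0".format(name)
        )
    return tol


def list_tol_fails(n_states: int) -> bool:
    """True if the old cnf.py style ([rtol] * 3) is rejected for n_states tensors."""
    try:
        per_state_tol("rtol", [1e-5] * 3, n_states)
    except AssertionError:
        return True
    return False


if __name__ == "__main__":
    # Three state tensors: the three-element list is accepted.
    print(list_tol_fails(3))
    # A longer state tuple (5 is an arbitrary example) rejects the same list...
    print(list_tol_fails(5))
    # ...while a scalar tolerance works for any tuple length.
    print(per_state_tol("rtol", 1e-5, 5))
```

This is why passing `self.atol` and `self.rtol` as plain scalars, as in the change above, avoids the error regardless of how many tensors the solver integrates.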