During training, I encountered the following error:
Traceback (most recent call last):
File "/home/gukaifeng/桌面/StyleFlow-master/train_flow.py", line 97, in <module>
loss.backward()
File "/home/gukaifeng/anaconda3/lib/python3.9/site-packages/torch/_tensor.py", line 363, in backward
torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
File "/home/gukaifeng/anaconda3/lib/python3.9/site-packages/torch/autograd/__init__.py", line 173, in backward
Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
File "/home/gukaifeng/anaconda3/lib/python3.9/site-packages/torch/autograd/function.py", line 253, in apply
return user_fn(self, *args)
File "/home/gukaifeng/anaconda3/lib/python3.9/site-packages/torchdiffeq/_impl/adjoint.py", line 126, in backward
aug_state = odeint(
File "/home/gukaifeng/anaconda3/lib/python3.9/site-packages/torchdiffeq/_impl/odeint.py", line 72, in odeint
shapes, func, y0, t, rtol, atol, method, options, event_fn, t_is_reversed = _check_inputs(func, y0, t, rtol, atol, method, options, event_fn, SOLVERS)
File "/home/gukaifeng/anaconda3/lib/python3.9/site-packages/torchdiffeq/_impl/misc.py", line 207, in _check_inputs
rtol = _tuple_tol('rtol', rtol, shapes)
File "/home/gukaifeng/anaconda3/lib/python3.9/site-packages/torchdiffeq/_impl/misc.py", line 115, in _tuple_tol
assert len(tol) == len(shapes), "If using tupled {} it must have the same length as the tuple y0".format(name)
AssertionError: If using tupled rtol it must have the same length as the tuple y0
During training, I encountered the following error:
Traceback (most recent call last):
File "/home/gukaifeng/桌面/StyleFlow-master/train_flow.py", line 97, in <module>
loss.backward()
File "/home/gukaifeng/anaconda3/lib/python3.9/site-packages/torch/_tensor.py", line 363, in backward
torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
File "/home/gukaifeng/anaconda3/lib/python3.9/site-packages/torch/autograd/__init__.py", line 173, in backward
Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
File "/home/gukaifeng/anaconda3/lib/python3.9/site-packages/torch/autograd/function.py", line 253, in apply
return user_fn(self, *args)
File "/home/gukaifeng/anaconda3/lib/python3.9/site-packages/torchdiffeq/_impl/adjoint.py", line 126, in backward
aug_state = odeint(
File "/home/gukaifeng/anaconda3/lib/python3.9/site-packages/torchdiffeq/_impl/odeint.py", line 72, in odeint
shapes, func, y0, t, rtol, atol, method, options, event_fn, t_is_reversed = _check_inputs(func, y0, t, rtol, atol, method, options, event_fn, SOLVERS)
File "/home/gukaifeng/anaconda3/lib/python3.9/site-packages/torchdiffeq/_impl/misc.py", line 207, in _check_inputs
rtol = _tuple_tol('rtol', rtol, shapes)
File "/home/gukaifeng/anaconda3/lib/python3.9/site-packages/torchdiffeq/_impl/misc.py", line 115, in _tuple_tol
assert len(tol) == len(shapes), "If using tupled {} it must have the same length as the tuple y0".format(name)
AssertionError: If using tupled rtol it must have the same length as the tuple y0