Thank you very much for your kind work. When I was running the training example, I encountered the following error. It seems that the dimension of the conditioning feature does not match the expected dimension during cross-attention with the input.
File "/mnt/workspace/Vista/vwm/models/diffusion.py", line 198, in forward
loss = self.loss_fn(self.model, self.denoiser, self.conditioner, x, batch) # go to StandardDiffusionLoss
File "/root/miniconda3/envs/t2v/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/root/miniconda3/envs/t2v/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/mnt/workspace/Vista/vwm/modules/diffusionmodules/loss.py", line 60, in forward
return self._forward(network, denoiser, cond, input)
File "/mnt/workspace/Vista/vwm/modules/diffusionmodules/loss.py", line 93, in _forward
model_output = denoiser(network, noised_input, sigmas, cond, cond_mask)
File "/root/miniconda3/envs/t2v/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/root/miniconda3/envs/t2v/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/mnt/workspace/Vista/vwm/modules/diffusionmodules/denoiser.py", line 35, in forward
return (network(noised_input * c_in, c_noise, cond, cond_mask, self.num_frames) * c_out + noised_input * c_skip)
File "/root/miniconda3/envs/t2v/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/root/miniconda3/envs/t2v/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/mnt/workspace/Vista/vwm/modules/diffusionmodules/wrappers.py", line 32, in forward
return self.diffusion_model(
File "/root/miniconda3/envs/t2v/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/root/miniconda3/envs/t2v/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/mnt/workspace/Vista/vwm/modules/diffusionmodules/video_model.py", line 475, in forward
h = module(
File "/root/miniconda3/envs/t2v/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/root/miniconda3/envs/t2v/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/mnt/workspace/Vista/vwm/modules/diffusionmodules/openaimodel.py", line 48, in forward
x = layer(x, context, time_context, num_frames)
File "/root/miniconda3/envs/t2v/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/root/miniconda3/envs/t2v/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/mnt/workspace/Vista/vwm/modules/video_attention.py", line 282, in forward
x = block(x, context=spatial_context)
File "/root/miniconda3/envs/t2v/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/root/miniconda3/envs/t2v/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/mnt/workspace/Vista/vwm/modules/attention.py", line 510, in forward
return checkpoint(self._forward, x, context)
File "/root/miniconda3/envs/t2v/lib/python3.10/site-packages/torch/_compile.py", line 24, in inner
return torch._dynamo.disable(fn, recursive)(*args, **kwargs)
File "/root/miniconda3/envs/t2v/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 489, in _fn
return fn(*args, **kwargs)
File "/root/miniconda3/envs/t2v/lib/python3.10/site-packages/torch/_dynamo/external_utils.py", line 17, in inner
return fn(*args, **kwargs)
File "/root/miniconda3/envs/t2v/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 482, in checkpoint
return CheckpointFunction.apply(function, preserve, *args)
File "/root/miniconda3/envs/t2v/lib/python3.10/site-packages/torch/autograd/function.py", line 553, in apply
return super().apply(*args, **kwargs) # type: ignore[misc]
File "/root/miniconda3/envs/t2v/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 261, in forward
outputs = run_function(*args)
File "/mnt/workspace/Vista/vwm/modules/attention.py", line 521, in _forward
x = self.attn2(self.norm2(x), context=context, additional_tokens=additional_tokens) + x
File "/root/miniconda3/envs/t2v/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/root/miniconda3/envs/t2v/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/mnt/workspace/Vista/vwm/modules/attention.py", line 345, in forward
k = self.to_k(context)
File "/root/miniconda3/envs/t2v/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/root/miniconda3/envs/t2v/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/root/miniconda3/envs/t2v/lib/python3.10/site-packages/torch/nn/modules/linear.py", line 116, in forward
return F.linear(input, self.weight, self.bias)
RuntimeError: mat1 and mat2 shapes cannot be multiplied (25x3456 and 1024x320)
Sorry for the inconvenience — I forgot to turn on action_control in nusc_train.yaml. It should be set to "True" so that the original cross-attention dimension is expanded to accommodate the extra action inputs. Please use the latest config.
Thank you very much for your kind work. When I was running the training example, I encountered the following error. It seems that the dimension of the conditioning feature does not match the expected dimension during cross-attention with the input.