OpenDriveLab / Vista

A Generalizable World Model for Autonomous Driving
https://vista-demo.github.io
Apache License 2.0

RuntimeError: mat1 and mat2 shapes cannot be multiplied (25x3456 and 1024x320) #8

Closed: freeEntropy closed this issue 2 weeks ago

freeEntropy commented 2 weeks ago

Thank you very much for your work. When I ran the training example, I encountered the following error. It seems that the dimension of the conditioning features does not match the dimension the cross-attention layer expects for its context input.

  File "/mnt/workspace/Vista/vwm/models/diffusion.py", line 198, in forward
    loss = self.loss_fn(self.model, self.denoiser, self.conditioner, x, batch)  # go to StandardDiffusionLoss
  File "/root/miniconda3/envs/t2v/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/root/miniconda3/envs/t2v/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/mnt/workspace/Vista/vwm/modules/diffusionmodules/loss.py", line 60, in forward
    return self._forward(network, denoiser, cond, input)
  File "/mnt/workspace/Vista/vwm/modules/diffusionmodules/loss.py", line 93, in _forward
    model_output = denoiser(network, noised_input, sigmas, cond, cond_mask)
  File "/root/miniconda3/envs/t2v/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/root/miniconda3/envs/t2v/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/mnt/workspace/Vista/vwm/modules/diffusionmodules/denoiser.py", line 35, in forward
    return (network(noised_input * c_in, c_noise, cond, cond_mask, self.num_frames) * c_out + noised_input * c_skip)
  File "/root/miniconda3/envs/t2v/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/root/miniconda3/envs/t2v/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/mnt/workspace/Vista/vwm/modules/diffusionmodules/wrappers.py", line 32, in forward
    return self.diffusion_model(
  File "/root/miniconda3/envs/t2v/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/root/miniconda3/envs/t2v/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/mnt/workspace/Vista/vwm/modules/diffusionmodules/video_model.py", line 475, in forward
    h = module(
  File "/root/miniconda3/envs/t2v/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/root/miniconda3/envs/t2v/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/mnt/workspace/Vista/vwm/modules/diffusionmodules/openaimodel.py", line 48, in forward
    x = layer(x, context, time_context, num_frames)
  File "/root/miniconda3/envs/t2v/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/root/miniconda3/envs/t2v/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/mnt/workspace/Vista/vwm/modules/video_attention.py", line 282, in forward
    x = block(x, context=spatial_context)
  File "/root/miniconda3/envs/t2v/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/root/miniconda3/envs/t2v/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/mnt/workspace/Vista/vwm/modules/attention.py", line 510, in forward
    return checkpoint(self._forward, x, context)
  File "/root/miniconda3/envs/t2v/lib/python3.10/site-packages/torch/_compile.py", line 24, in inner
    return torch._dynamo.disable(fn, recursive)(*args, **kwargs)
  File "/root/miniconda3/envs/t2v/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 489, in _fn
    return fn(*args, **kwargs)
  File "/root/miniconda3/envs/t2v/lib/python3.10/site-packages/torch/_dynamo/external_utils.py", line 17, in inner
    return fn(*args, **kwargs)
  File "/root/miniconda3/envs/t2v/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 482, in checkpoint
    return CheckpointFunction.apply(function, preserve, *args)
  File "/root/miniconda3/envs/t2v/lib/python3.10/site-packages/torch/autograd/function.py", line 553, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/root/miniconda3/envs/t2v/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 261, in forward
    outputs = run_function(*args)
  File "/mnt/workspace/Vista/vwm/modules/attention.py", line 521, in _forward
    x = self.attn2(self.norm2(x), context=context, additional_tokens=additional_tokens) + x
  File "/root/miniconda3/envs/t2v/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/root/miniconda3/envs/t2v/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/mnt/workspace/Vista/vwm/modules/attention.py", line 345, in forward
    k = self.to_k(context)
  File "/root/miniconda3/envs/t2v/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/root/miniconda3/envs/t2v/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/miniconda3/envs/t2v/lib/python3.10/site-packages/torch/nn/modules/linear.py", line 116, in forward
    return F.linear(input, self.weight, self.bias)
RuntimeError: mat1 and mat2 shapes cannot be multiplied (25x3456 and 1024x320)
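
The last frame is the key one. `to_k` is an `nn.Linear`, whose weight is stored as `(out_features, in_features)`, and `F.linear` computes `input @ weight.T`. So the second matrix in the message (`1024x320`) is the transposed weight. This cross-attention layer was therefore built for a 1024-dimensional context that it projects to 320 features, but the context it received has 3456 features per row. The stdlib-only sketch below mimics that shape rule. `linear_output_shape` is a hypothetical helper, not part of PyTorch or Vista, and it only models inputs with two or more dimensions.

```python
def linear_output_shape(input_shape, in_features, out_features):
    """Return the output shape of nn.Linear(in_features, out_features) for an input shape.

    Simplified model of the shape rule behind torch.nn.functional.linear,
    which computes input @ weight.T + bias. The weight is stored as
    (out_features, in_features), so weight.T is (in_features, out_features)
    and is what PyTorch reports as "mat2" in its error message.
    Only inputs with two or more dimensions are modelled here.
    """
    if len(input_shape) < 2:
        raise ValueError("this sketch only models inputs with >= 2 dimensions")
    *leading, features = input_shape
    if features != in_features:
        # Leading dims are folded into rows before the matrix multiply,
        # so "mat1" is reported as (product of leading dims) x features.
        rows = 1
        for d in leading:
            rows *= d
        raise RuntimeError(
            "mat1 and mat2 shapes cannot be multiplied "
            f"({rows}x{features} and {in_features}x{out_features})"
        )
    return (*leading, out_features)


# The failing to_k projection: built for a 1024-dim context, projecting to 320.
try:
    linear_output_shape((25, 3456), in_features=1024, out_features=320)
except RuntimeError as err:
    print(err)  # mat1 and mat2 shapes cannot be multiplied (25x3456 and 1024x320)

# With a context of the expected width, the projection succeeds.
print(linear_output_shape((25, 1024), in_features=1024, out_features=320))  # (25, 320)
```

Reading the message this way points at a mismatch between how wide the conditioner's output is and how wide the network's cross-attention was configured to be, rather than at a bug inside the attention code itself.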

Little-Podi commented 2 weeks ago

Sorry for the inconvenience. I forgot to turn on `action_control` in `nusc_train.yaml`. It should be set to `True`, which expands the original cross-attention dimension to accommodate the extra action inputs. Please use the latest config.
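
The fix is a config change, but the failure mode generalizes. The conditioner and the network's cross-attention layers appear to be configured separately, so a flag that widens one side without the other only surfaces deep inside `F.linear`. Below is a minimal, hypothetical sketch of a flag-controlled context width plus an early consistency check. `CrossAttnContextConfig`, `check_context_width`, and the widths 1024 + 128 are illustrative assumptions, not Vista's API. The additive widening (action embeddings concatenated along the feature axis) is also an assumption about how the action inputs are incorporated.

```python
from dataclasses import dataclass


@dataclass
class CrossAttnContextConfig:
    # Hypothetical field names; they do not come from Vista's config schema.
    base_context_dim: int          # context width used without action inputs
    action_context_dim: int        # extra width reserved for action inputs
    action_control: bool = False   # stands in for action_control in nusc_train.yaml

    @property
    def context_dim(self) -> int:
        # Assumption: action embeddings are concatenated to the base context
        # along the feature axis, so enabling the flag widens the context.
        if self.action_control:
            return self.base_context_dim + self.action_context_dim
        return self.base_context_dim


def check_context_width(cfg: CrossAttnContextConfig, conditioner_width: int) -> None:
    """Fail fast, before the first forward pass, if the widths disagree."""
    if conditioner_width != cfg.context_dim:
        raise ValueError(
            f"conditioner produces a {conditioner_width}-dim context, but "
            f"cross-attention expects {cfg.context_dim} "
            f"(action_control={cfg.action_control}); check the action_control setting"
        )


cfg = CrossAttnContextConfig(base_context_dim=1024, action_context_dim=128, action_control=False)
try:
    check_context_width(cfg, conditioner_width=1152)
except ValueError as err:
    print(err)  # names the flag instead of failing inside F.linear

cfg.action_control = True
check_context_width(cfg, conditioner_width=1152)  # passes: 1024 + 128 == 1152
```

Running a check like this once at startup turns a deep traceback into a one-line message that names the setting to inspect.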