Hello! I encountered the following issue while executing the example training command, and I don't understand what caused it:
train_strategy.run_training(train_dataset, collator, metrics, stage=cfg.stage, seed=cfg.seed)
  File "/opt/cv/tianyutong/prismatic-vlms/prismatic/training/strategies/base_strategy.py", line 181, in run_training
    output: CausalLMOutputWithPast = self.vlm(
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 839, in forward
    output = self._fsdp_wrapped_module(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/cv/tianyutong/prismatic-vlms/prismatic/models/vlms/prismatic.py", line 410, in forward
    return self.llm_backbone(
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/cv/tianyutong/prismatic-vlms/prismatic/models/backbones/llm/base_llm.py", line 200, in forward
    output: CausalLMOutputWithPast = self.llm(
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 837, in forward
    shift_logits = shift_logits.view(-1, self.config.vocab_size)
RuntimeError: shape '[-1, 32001]' is invalid for input of size 389385216
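For what it may be worth, the numbers in the error suggest a vocabulary-size mismatch rather than a data problem: the logits tensor's element count (389385216) is not divisible by `config.vocab_size` (32001), so the `.view(-1, 32001)` reshape must fail. It does divide evenly by 32064, which is 32001 rounded up to a multiple of 64, so one hypothesis is that the LM head / embeddings were resized with padding while `config.vocab_size` was left at 32001. This is only a guess from the traceback; the sketch below just reproduces the arithmetic:

```python
# Diagnostic sketch (assumption: the logits width no longer matches
# config.vocab_size). The numbers come straight from the RuntimeError.
numel = 389385216  # total elements in shift_logits
vocab = 32001      # self.config.vocab_size at the crash site

# The reshape fails exactly because numel is not a multiple of vocab.
assert numel % vocab != 0

# Search for plausible "real" vocab widths near the configured one.
candidates = [v for v in range(vocab - 256, vocab + 256) if numel % v == 0]
print(candidates)  # 32064 (= 32001 padded to a multiple of 64) shows up here
```

If that hypothesis holds, checking `model.get_output_embeddings().weight.shape[0]` against `model.config.vocab_size` at runtime should confirm where the two diverge.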