TRI-ML / prismatic-vlms

A flexible and efficient codebase for training visually-conditioned language models (VLMs)
MIT License

training error #15

Closed: tayton42 closed this issue 5 months ago

tayton42 commented 5 months ago

Hello! I hit the following error while running the example training command, and I don't understand what is causing it:

    train_strategy.run_training(train_dataset, collator, metrics, stage=cfg.stage, seed=cfg.seed)
  File "/opt/cv/tianyutong/prismatic-vlms/prismatic/training/strategies/base_strategy.py", line 181, in run_training
    output: CausalLMOutputWithPast = self.vlm(
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 839, in forward
    output = self._fsdp_wrapped_module(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/cv/tianyutong/prismatic-vlms/prismatic/models/vlms/prismatic.py", line 410, in forward
    return self.llm_backbone(
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/cv/tianyutong/prismatic-vlms/prismatic/models/backbones/llm/base_llm.py", line 200, in forward
    output: CausalLMOutputWithPast = self.llm(
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 837, in forward
    shift_logits = shift_logits.view(-1, self.config.vocab_size)
RuntimeError: shape '[-1, 32001]' is invalid for input of size 389385216
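The numbers in this RuntimeError can be checked directly. `view(-1, 32001)` only succeeds when the tensor's element count is an exact multiple of 32001, and 389385216 is not. It is, however, exactly 32064 × 12144, and 32064 is 32001 rounded up to the next multiple of 64. That is consistent with the logits' last dimension being 32064 while `config.vocab_size` reports 32001, i.e. a padded embedding/LM-head size that the config does not reflect. This is an inference from the arithmetic, not something confirmed in the thread. The padding multiple of 64 and the `round_up` helper below are assumptions for illustration. Because Llama's loss uses `shift_logits = logits[..., :-1, :]`, the 12144 rows would then correspond to batch size × (sequence length − 1).

```python
# Sanity-check the numbers in the RuntimeError raised in modeling_llama.py.
# Assumption (hypothetical): the LM head's output dimension was padded up to a
# multiple of 64, so the logits are wider than config.vocab_size.

def round_up(n: int, multiple: int) -> int:
    """Return the smallest multiple of `multiple` that is >= n."""
    return -(-n // multiple) * multiple

numel = 389_385_216   # "input of size ..." from the traceback
vocab_size = 32_001   # self.config.vocab_size, as shown in the traceback

# view(-1, vocab_size) only works if numel is an exact multiple of vocab_size.
print(numel % vocab_size == 0)  # False, hence the RuntimeError

# Try the vocab size rounded up to a multiple of 64 instead.
padded = round_up(vocab_size, 64)
rows, rem = divmod(numel, padded)
print(padded, rows, rem)  # 32064 12144 0
```

If the remainder with the padded width is zero, as it is here, the mismatch is most likely between the model's actual output width and `config.vocab_size`, not a problem with the data.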
siddk commented 5 months ago

Can you confirm the version of HF Transformers you’re using?

tayton42 commented 5 months ago


It was an issue with the version of HF Transformers I had installed. Thank you!
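To answer the version question asked above, the installed version can be read from package metadata without importing the library itself. Below is a minimal sketch using the standard-library `importlib.metadata`. The helper name `installed_version` is hypothetical. The thread does not say which Transformers version fixed the error, so compare the output against whatever version the repository's install instructions specify.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional

def installed_version(package: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if it is absent."""
    try:
        return version(package)
    except PackageNotFoundError:
        return None

if __name__ == "__main__":
    # "transformers" is the PyPI distribution name for HF Transformers.
    print("transformers:", installed_version("transformers"))
    print("torch:", installed_version("torch"))
```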