huggingface / trl

Train transformer language models with reinforcement learning.
http://hf.co/docs/trl
Apache License 2.0

Error with SFT of LLaVA-Next #1785

Closed GohioAC closed 1 month ago

GohioAC commented 3 months ago

I'm trying to instruction tune llava-next models following the llava_vsft.py examples shared for llava-1.5.

python vsft.py \
    --dataset_name="HuggingFaceH4/llava-instruct-mix-vsft" \
    --model_name_or_path="llava-hf/llava-v1.6-mistral-7b-hf" \
    --report_to="tensorboard" \
    --learning_rate=2e-5 \
    --lr_scheduler_type="cosine" \
    --per_device_train_batch_size=8 \
    --gradient_accumulation_steps=1 \
    --output_dir="data/vsft-llava-1.5-7b-hf" \
    --logging_steps=1 \
    --num_train_epochs=1 \
    --gradient_checkpointing \
    --remove_unused_columns=False \
    --torch_dtype=float16 \
    --fp16=True \
    --max_seq_length=4096 \
    --attn_implementation="flash_attention_2"

The run keeps failing on an 8xH100 VM with the following error:

RuntimeError: Input tensor at index 1 has invalid shape [8, 2785, 32064], but expected [8, 2889, 32064]

The full code and error stack trace are available in this gist.

qgallouedec commented 2 months ago

Hi, sorry for the delay. Can you double-check the command? When I run it, I get:

Traceback (most recent call last):
  File "/fsx/qgallouedec/trl-2/vsft.py", line 122, in <module>
    trainer.train()
  File "/fsx/qgallouedec/trl-2/trl/trainer/sft_trainer.py", line 440, in train
    output = super().train(*args, **kwargs)
  File "/fsx/qgallouedec/miniconda3/envs/trl/lib/python3.10/site-packages/transformers/trainer.py", line 1932, in train
    return inner_training_loop(
  File "/fsx/qgallouedec/miniconda3/envs/trl/lib/python3.10/site-packages/transformers/trainer.py", line 2314, in _inner_training_loop
    _grad_norm = self.accelerator.clip_grad_norm_(
  File "/fsx/qgallouedec/miniconda3/envs/trl/lib/python3.10/site-packages/accelerate/accelerator.py", line 2269, in clip_grad_norm_
    self.unscale_gradients()
  File "/fsx/qgallouedec/miniconda3/envs/trl/lib/python3.10/site-packages/accelerate/accelerator.py", line 2219, in unscale_gradients
    self.scaler.unscale_(opt)
  File "/fsx/qgallouedec/miniconda3/envs/trl/lib/python3.10/site-packages/torch/cuda/amp/grad_scaler.py", line 307, in unscale_
    optimizer_state["found_inf_per_device"] = self._unscale_grads_(
  File "/fsx/qgallouedec/miniconda3/envs/trl/lib/python3.10/site-packages/torch/cuda/amp/grad_scaler.py", line 229, in _unscale_grads_
    raise ValueError("Attempting to unscale FP16 gradients.")
ValueError: Attempting to unscale FP16 gradients.
  0%|          | 0/32395 [00:05<?, ?it/s]    

Also, please share the versions of trl, transformers and torch.
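One way to collect the requested versions is the standard library's importlib.metadata. This is only a minimal sketch; the helper name report_versions is made up for illustration and is not part of trl or transformers.

```python
from importlib.metadata import PackageNotFoundError, version


def report_versions(packages=("trl", "transformers", "torch")):
    """Return {package: version string}, or 'not installed' when missing."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = "not installed"
    return found


if __name__ == "__main__":
    for name, ver in report_versions().items():
        print(f"{name}: {ver}")
```

Pasting the printed lines into the issue gives maintainers the exact environment to reproduce against.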

GohioAC commented 2 months ago

I have double checked the command, code and output. Versions are as follows:

trl: 0.9.4
transformers: 4.41.2
torch: 2.3.1
cuda: 12.4
python: 3.10.14

GohioAC commented 2 months ago

@qgallouedec any update on this? I created a new environment with the latest versions of trl and transformers but am still facing the same issue.

github-actions[bot] commented 1 month ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

TonyJiang17 commented 1 month ago

@GohioAC were you able to resolve it? I got a similar issue.

qgallouedec commented 1 month ago

Make sure to update to the dev version of transformers:

pip install git+https://github.com/huggingface/transformers.git

TonyJiang17 commented 1 month ago

Thanks, I updated to the latest dev version of transformers but still ran into the issue. Could you shed more light on why it occurs? For context, I am trying to reproduce this notebook to finetune LLaVA-Next-Video (https://colab.research.google.com/drive/1dTdro-k7NFqRgGq5-TlGHM-6k2sYQhXp?usp=sharing#scrollTo=5da82ca2-e1db-4a7c-878f-aeed972ba9e6) @qgallouedec

qgallouedec commented 1 month ago

Unfortunately, videos aren't supported yet.

TonyJiang17 commented 1 month ago

@qgallouedec Hmmm, I believe this tutorial notebook (https://colab.research.google.com/drive/1dTdro-k7NFqRgGq5-TlGHM-6k2sYQhXp?usp=sharing#scrollTo=5da82ca2-e1db-4a7c-878f-aeed972ba9e6) on LLaVA-Next-Video finetuning was made by a member of the LLaVA Hugging Face team, Raushan Turganbay... not sure if you know. Thanks

qgallouedec commented 1 month ago

Thanks for the info. cc @zucchini-nlp. But the notebook doesn't use trl.

TonyJiang17 commented 1 month ago

Thanks, it doesn't; I didn't know this was specific to TRL. @qgallouedec @zucchini-nlp any help with this error in the notebook would be much appreciated! @zucchini-nlp any pointers on the library versions used would be helpful. Thanks

zucchini-nlp commented 1 month ago

@TonyJiang17 hey! Yes, we had an issue with llava-next-video recently and a fix was added in the latest patch release. Can you make sure that you have the latest version and check whether model inference works? I'd guess the error above affects both generation and training.

TonyJiang17 commented 1 month ago

Hey @zucchini-nlp, thanks for replying. I made sure I am using the latest patch release, 4.44.2, and ran the notebook, but I still ran into the following error when I tried to finetune the model... Inference works.

RuntimeError: Input tensor at index 1 has invalid shape [2, 1402, 32064], but expected [2, 1407, 32064]

I am running with a batch size of just 2 and 8 frames. I made sure each input_ids is padded to a max length of 256. Is there some issue with the number of tokens used per frame? I assume the 1407 came from 255 + 12 * 12 * 8, so there should be 144 tokens per frame? Any help would be much appreciated!
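The arithmetic behind that guess (255 text tokens plus 8 frames of 12 * 12 = 144 tokens each) can be written out as a quick check. This is only a sketch: it assumes each padded 256-token sequence contains exactly one video placeholder token that gets replaced by the video features, which is not confirmed anywhere in this thread, and merged_length is a made-up helper name.

```python
def merged_length(text_len, num_frames, tokens_per_frame, num_placeholders=1):
    """Sequence length after video features replace the placeholder token(s)."""
    return text_len - num_placeholders + num_frames * tokens_per_frame


# Assumed values: 256 padded tokens, 8 frames, 12 * 12 tokens per frame.
expected = merged_length(text_len=256, num_frames=8, tokens_per_frame=12 * 12)
print(expected)         # 1407
print(expected - 1402)  # 5 positions missing in the shorter tensor
```

Under these assumptions the expected length matches the 1407 in the error, so the open question is why one tensor ends up 5 positions shorter.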

zucchini-nlp commented 1 month ago

@TonyJiang17 oke, let me check this

zucchini-nlp commented 1 month ago

@TonyJiang17 the example notebook works for me with the latest transformers; I tried it on a tiny subset of the ShareGPT4Video data. I guess you're using your own dataset for tuning. Can you share what the dataset looks like after collating, along with the whole traceback, so I can help you?

TonyJiang17 commented 1 month ago

Hi @zucchini-nlp certainly, and thanks again for helping!

I am actually also using a tiny subset of the ShareGPT4Video data (I only loaded the mixit portion of it). Below is more information on the dataset after collating. I really just used your code. Please let me know if you need more information. [image] I didn't change any code elsewhere either.

I actually first ran into this "tensors not on the same device" bug. I am using an AWS SageMaker notebook instance with access to 4 A10 GPUs.

The bug traceback is pasted below.

RuntimeError                              Traceback (most recent call last)
Cell In[32], line 1
----> 1 trainer.train()

File ~/SageMaker/llava-next-env/lib/python3.10/site-packages/transformers/trainer.py:1938, in Trainer.train(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)
   1936         hf_hub_utils.enable_progress_bars()
   1937 else:
-> 1938     return inner_training_loop(
   1939         args=args,
   1940         resume_from_checkpoint=resume_from_checkpoint,
   1941         trial=trial,
   1942         ignore_keys_for_eval=ignore_keys_for_eval,
   1943     )

File ~/SageMaker/llava-next-env/lib/python3.10/site-packages/transformers/trainer.py:2279, in Trainer._inner_training_loop(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)
   2276     self.control = self.callback_handler.on_step_begin(args, self.state, self.control)
   2278 with self.accelerator.accumulate(model):
-> 2279     tr_loss_step = self.training_step(model, inputs)
   2281 if (
   2282     args.logging_nan_inf_filter
   2283     and not is_torch_xla_available()
   2284     and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step))
   2285 ):
   2286     # if loss is nan or inf simply add the average of previous logged losses
   2287     tr_loss += tr_loss / (1 + self.state.global_step - self._globalstep_last_logged)

File ~/SageMaker/llava-next-env/lib/python3.10/site-packages/transformers/trainer.py:3318, in Trainer.training_step(self, model, inputs)
   3315     return loss_mb.reduce_mean().detach().to(self.args.device)
   3317 with self.compute_loss_context_manager():
-> 3318     loss = self.compute_loss(model, inputs)
   3320 del inputs
   3321 if (
   3322     self.args.torch_empty_cache_steps is not None
   3323     and self.state.global_step % self.args.torch_empty_cache_steps == 0
   3324 ):

File ~/SageMaker/llava-next-env/lib/python3.10/site-packages/transformers/trainer.py:3363, in Trainer.compute_loss(self, model, inputs, return_outputs)
   3361 else:
   3362     labels = None
-> 3363 outputs = model(**inputs)
   3364 # Save past state if it exists
   3365 # TODO: this needs to be fixed and made cleaner later.
   3366 if self.args.past_index >= 0:

File ~/SageMaker/llava-next-env/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File ~/SageMaker/llava-next-env/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File ~/SageMaker/llava-next-env/lib/python3.10/site-packages/accelerate/utils/operations.py:820, in convert_outputs_to_fp32.<locals>.forward(*args, **kwargs)
    819 def forward(*args, **kwargs):
--> 820     return model_forward(*args, **kwargs)

File ~/SageMaker/llava-next-env/lib/python3.10/site-packages/accelerate/utils/operations.py:808, in ConvertOutputsToFp32.__call__(self, *args, **kwargs)
    807 def __call__(self, *args, **kwargs):
--> 808     return convert_to_fp32(self.model_forward(*args, **kwargs))

File ~/SageMaker/llava-next-env/lib/python3.10/site-packages/torch/amp/autocast_mode.py:16, in autocast_decorator.<locals>.decorate_autocast(*args, **kwargs)
     13 @functools.wraps(func)
     14 def decorate_autocast(*args, **kwargs):
     15     with autocast_instance:
---> 16         return func(*args, **kwargs)

File ~/SageMaker/llava-next-env/lib/python3.10/site-packages/peft/peft_model.py:771, in PeftModel.forward(self, *args, **kwargs)
    769 with self._enable_peft_forward_hooks(*args, **kwargs):
    770     kwargs = {k: v for k, v in kwargs.items() if k not in self.special_peft_forward_args}
--> 771     return self.get_base_model()(*args, **kwargs)

File ~/SageMaker/llava-next-env/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File ~/SageMaker/llava-next-env/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File ~/SageMaker/llava-next-env/lib/python3.10/site-packages/accelerate/hooks.py:170, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
    168         output = module._old_forward(*args, **kwargs)
    169 else:
--> 170     output = module._old_forward(*args, **kwargs)
    171 return module._hf_hook.post_forward(module, output)

File ~/SageMaker/llava-next-env/lib/python3.10/site-packages/transformers/models/llava_next_video/modeling_llava_next_video.py:937, in LlavaNextVideoForConditionalGeneration.forward(self, input_ids, pixel_values, pixel_values_videos, image_sizes, attention_mask, position_ids, past_key_values, inputs_embeds, vision_feature_layer, vision_feature_select_strategy, labels, use_cache, output_attentions, output_hidden_states, return_dict)
    934         attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
    935         position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
--> 937 outputs = self.language_model(
    938     attention_mask=attention_mask,
    939     position_ids=position_ids,
    940     past_key_values=past_key_values,
    941     inputs_embeds=inputs_embeds,
    942     use_cache=use_cache,
    943     output_attentions=output_attentions,
    944     output_hidden_states=output_hidden_states,
    945     return_dict=return_dict,
    946 )
    948 logits = outputs[0]
    950 loss = None

File ~/SageMaker/llava-next-env/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File ~/SageMaker/llava-next-env/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File ~/SageMaker/llava-next-env/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py:1189, in LlamaForCausalLM.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, cache_position)
   1186 return_dict = return_dict if return_dict is not None else self.config.use_return_dict
   1188 # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
-> 1189 outputs = self.model(
   1190     input_ids=input_ids,
   1191     attention_mask=attention_mask,
   1192     position_ids=position_ids,
   1193     past_key_values=past_key_values,
   1194     inputs_embeds=inputs_embeds,
   1195     use_cache=use_cache,
   1196     output_attentions=output_attentions,
   1197     output_hidden_states=output_hidden_states,
   1198     return_dict=return_dict,
   1199     cache_position=cache_position,
   1200 )
   1202 hidden_states = outputs[0]
   1203 if self.config.pretraining_tp > 1:

File ~/SageMaker/llava-next-env/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File ~/SageMaker/llava-next-env/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File ~/SageMaker/llava-next-env/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py:977, in LlamaModel.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, cache_position)
    974 hidden_states = inputs_embeds
    976 # create position embeddings to be shared across the decoder layers
--> 977 position_embeddings = self.rotary_emb(hidden_states, position_ids)
    979 # decoder layers
    980 all_hidden_states = () if output_hidden_states else None

File ~/SageMaker/llava-next-env/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File ~/SageMaker/llava-next-env/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File ~/SageMaker/llava-next-env/lib/python3.10/site-packages/accelerate/hooks.py:170, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
    168         output = module._old_forward(*args, **kwargs)
    169 else:
--> 170     output = module._old_forward(*args, **kwargs)
    171 return module._hf_hook.post_forward(module, output)

File ~/SageMaker/llava-next-env/lib/python3.10/site-packages/torch/utils/_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    112 @functools.wraps(func)
    113 def decorate_context(*args, **kwargs):
    114     with ctx_factory():
--> 115         return func(*args, **kwargs)

File ~/SageMaker/llava-next-env/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py:209, in LlamaRotaryEmbedding.forward(self, x, position_ids)
    207 device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
    208 with torch.autocast(device_type=device_type, enabled=False):
--> 209     freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
    210     emb = torch.cat((freqs, freqs), dim=-1)
    211     cos = emb.cos()

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:3! (when checking argument for argument mat2 in method wrapper_CUDA_bmm)
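
The cuda:0 / cuda:3 mismatch is what one would expect if the model's layers are spread over several GPUs. As a way to inspect that, here is a small sketch that summarizes a module-to-device mapping of the kind transformers stores in model.hf_device_map when a model is loaded with device_map. The sample mapping below is invented for illustration, and summarize_device_map is a hypothetical helper, not part of any library.

```python
from collections import defaultdict


def summarize_device_map(device_map):
    """Group module names by the device they were placed on."""
    by_device = defaultdict(list)
    for module_name, device in device_map.items():
        by_device[str(device)].append(module_name)
    return dict(by_device)


# Invented example; on a real model you would pass model.hf_device_map.
sample_map = {
    "vision_tower": 0,
    "language_model.model.layers.0": 0,
    "language_model.model.layers.31": 3,
    "language_model.lm_head": 3,
}

summary = summarize_device_map(sample_map)
for device, modules in summary.items():
    print(device, modules)
print("sharded across devices:", len(summary) > 1)
```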

TonyJiang17 commented 1 month ago

@zucchini-nlp After I removed the device_map = "auto" parameter when loading the pretrained model and lowered the batch size to 2, I no longer ran into the above bug. I think it's just a workaround, as I think it's now using a single GPU rather than the multi-GPU setup...

Regardless, after I removed device_map = "auto" and ran the code again, I ran into the following tensor shape mismatch error, similar to the original error in this issue thread.

The bug traceback is pasted below.

RuntimeError                              Traceback (most recent call last)
Cell In[17], line 1
----> 1 trainer.train()

File ~/SageMaker/llava-next-env/lib/python3.10/site-packages/transformers/trainer.py:1938, in Trainer.train(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)
   1936         hf_hub_utils.enable_progress_bars()
   1937 else:
-> 1938     return inner_training_loop(
   1939         args=args,
   1940         resume_from_checkpoint=resume_from_checkpoint,
   1941         trial=trial,
   1942         ignore_keys_for_eval=ignore_keys_for_eval,
   1943     )

File ~/SageMaker/llava-next-env/lib/python3.10/site-packages/transformers/trainer.py:2279, in Trainer._inner_training_loop(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)
   2276     self.control = self.callback_handler.on_step_begin(args, self.state, self.control)
   2278 with self.accelerator.accumulate(model):
-> 2279     tr_loss_step = self.training_step(model, inputs)
   2281 if (
   2282     args.logging_nan_inf_filter
   2283     and not is_torch_xla_available()
   2284     and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step))
   2285 ):
   2286     # if loss is nan or inf simply add the average of previous logged losses
   2287     tr_loss += tr_loss / (1 + self.state.global_step - self._globalstep_last_logged)

File ~/SageMaker/llava-next-env/lib/python3.10/site-packages/transformers/trainer.py:3318, in Trainer.training_step(self, model, inputs)
   3315     return loss_mb.reduce_mean().detach().to(self.args.device)
   3317 with self.compute_loss_context_manager():
-> 3318     loss = self.compute_loss(model, inputs)
   3320 del inputs
   3321 if (
   3322     self.args.torch_empty_cache_steps is not None
   3323     and self.state.global_step % self.args.torch_empty_cache_steps == 0
   3324 ):

File ~/SageMaker/llava-next-env/lib/python3.10/site-packages/transformers/trainer.py:3363, in Trainer.compute_loss(self, model, inputs, return_outputs)
   3361 else:
   3362     labels = None
-> 3363 outputs = model(**inputs)
   3364 # Save past state if it exists
   3365 # TODO: this needs to be fixed and made cleaner later.
   3366 if self.args.past_index >= 0:

File ~/SageMaker/llava-next-env/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File ~/SageMaker/llava-next-env/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File ~/SageMaker/llava-next-env/lib/python3.10/site-packages/torch/nn/parallel/data_parallel.py:186, in DataParallel.forward(self, *inputs, **kwargs)
    184 replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
    185 outputs = self.parallel_apply(replicas, inputs, module_kwargs)
--> 186 return self.gather(outputs, self.output_device)

File ~/SageMaker/llava-next-env/lib/python3.10/site-packages/torch/nn/parallel/data_parallel.py:203, in DataParallel.gather(self, outputs, output_device)
    202 def gather(self, outputs: Any, output_device: Union[int, torch.device]) -> Any:
--> 203     return gather(outputs, output_device, dim=self.dim)

File ~/SageMaker/llava-next-env/lib/python3.10/site-packages/torch/nn/parallel/scatter_gather.py:105, in gather(outputs, target_device, dim)
    102 # Recursive function calls like this create reference cycles.
    103 # Setting the function to None clears the refcycle.
    104 try:
--> 105     res = gather_map(outputs)
    106 finally:
    107     gather_map = None  # type: ignore[assignment]

File ~/SageMaker/llava-next-env/lib/python3.10/site-packages/torch/nn/parallel/scatter_gather.py:96, in gather.<locals>.gather_map(outputs)
     94 if not all(len(out) == len(d) for d in outputs):
     95     raise ValueError('All dicts must have the same number of keys')
---> 96 return type(out)((k, gather_map([d[k] for d in outputs]))
     97                  for k in out)
     98 if _is_namedtuple(out):
     99     return type(out)._make(map(gather_map, zip(*outputs)))

File <string>:9, in __init__(self, loss, logits, past_key_values, hidden_states, attentions, image_hidden_states)

File ~/SageMaker/llava-next-env/lib/python3.10/site-packages/transformers/utils/generic.py:390, in ModelOutput.__post_init__(self)
    387 # if we provided an iterator as first field and the iterator is a (key, value) iterator
    388 # set the associated fields
    389 if first_field_iterator:
--> 390     for idx, element in enumerate(iterator):
    391         if (
    392             not isinstance(element, (list, tuple))
    393             or not len(element) == 2
    394             or not isinstance(element[0], str)
    395         ):
    396             if idx == 0:
    397                 # If we do not have an iterator of key/values, set it as attribute

File ~/SageMaker/llava-next-env/lib/python3.10/site-packages/torch/nn/parallel/scatter_gather.py:96, in <genexpr>(.0)
     94 if not all(len(out) == len(d) for d in outputs):
     95     raise ValueError('All dicts must have the same number of keys')
---> 96 return type(out)((k, gather_map([d[k] for d in outputs]))
     97                  for k in out)
     98 if _is_namedtuple(out):
     99     return type(out)._make(map(gather_map, zip(*outputs)))

File ~/SageMaker/llava-next-env/lib/python3.10/site-packages/torch/nn/parallel/scatter_gather.py:90, in gather.<locals>.gather_map(outputs)
     88 out = outputs[0]
     89 if isinstance(out, torch.Tensor):
---> 90     return Gather.apply(target_device, dim, *outputs)
     91 if out is None:
     92     return None

File ~/SageMaker/llava-next-env/lib/python3.10/site-packages/torch/autograd/function.py:539, in Function.apply(cls, *args, **kwargs)
    536 if not torch._C._are_functorch_transforms_active():
    537     # See NOTE: [functorch vjp and autograd interaction]
    538     args = _functorch.utils.unwrap_dead_wrappers(args)
--> 539     return super().apply(*args, **kwargs)  # type: ignore[misc]
    541 if cls.setup_context == _SingleLevelFunction.setup_context:
    542     raise RuntimeError(
    543         "In order to use an autograd.Function with functorch transforms "
    544         "(vmap, grad, jvp, jacrev, ...), it must override the setup_context "
    545         "staticmethod. For more details, please see "
    546         "https://pytorch.org/docs/master/notes/extending.func.html"
    547     )

File ~/SageMaker/llava-next-env/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:75, in Gather.forward(ctx, target_device, dim, *inputs)
     73 ctx.unsqueezed_scalar = False
     74 ctx.input_sizes = tuple(i.size(ctx.dim) for i in inputs)
---> 75 return comm.gather(inputs, ctx.dim, ctx.target_device)

File ~/SageMaker/llava-next-env/lib/python3.10/site-packages/torch/nn/parallel/comm.py:231, in gather(tensors, dim, destination, out)
    227     warnings.warn(
    228         'Using -1 to represent CPU tensor is deprecated. Please use a '
    229         'device object or string instead, e.g., "cpu".')
    230 destination = _get_device_index(destination, allow_cpu=True, optional=True)
--> 231 return torch._C._gather(tensors, dim, destination)
    232 else:
    233     if destination is not None:

RuntimeError: Input tensor at index 1 has invalid shape [2, 1402, 32064], but expected [2, 1407, 32064]
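
This traceback goes through torch.nn.parallel.DataParallel, whose gather step concatenates the per-GPU outputs along the batch dimension, so every other dimension has to match. The sketch below only imitates that shape rule in plain Python to show why [2, 1402, 32064] and [2, 1407, 32064] cannot be combined; the real check lives inside PyTorch, and first_gather_mismatch is a made-up name.

```python
def first_gather_mismatch(shapes, dim=0):
    """Return (index, shape, expected_shape) for the first shape that cannot be
    concatenated with shapes[0] along `dim`, or None if all are compatible."""
    reference = list(shapes[0])
    for index, shape in enumerate(shapes[1:], start=1):
        expected = list(reference)
        if len(shape) == len(reference):
            # The size along `dim` may differ; everything else must match.
            expected[dim] = shape[dim]
        if list(shape) != expected:
            return index, list(shape), expected
    return None


# Logits shapes per replica, taken from the error message in this thread.
per_gpu_logits = [[2, 1407, 32064], [2, 1402, 32064]]
print(first_gather_mismatch(per_gpu_logits))
# (1, [2, 1402, 32064], [2, 1407, 32064])
```

If this reading is right, the replicas are producing logits with different sequence lengths, so comparing the merged sequence length per sample is a reasonable next thing to check.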