NielsRogge / Transformers-Tutorials

This repository contains demos I made with the Transformers library by HuggingFace.
MIT License
8.42k stars 1.32k forks source link

Shape mismatch issue in Fine-tuning of Idefics2 Tutorial #434

Closed chang-changiti closed 2 weeks ago

chang-changiti commented 3 weeks ago

Referenced file: https://github.com/NielsRogge/Transformers-Tutorials/blob/master/Idefics2/Fine_tune_Idefics2_for_JSON_extraction_use_cases_(PyTorch_Lightning).ipynb

Hi! When I try to replicate the above-mentioned tutorial, I get the following error. It appears to be raised during the evaluation step on the validation dataset (inside `model.generate` in `validation_step`), but I can't figure out the root cause.

`RuntimeError: shape mismatch: value tensor of shape [128, 4096] cannot be broadcast to indexing result of shape [0, 4096]`

Full stack trace:

File /local_disk0/.ephemeral_nfs/envs/pythonEnv-836bb24d-b821-4dba-9f1e-156ae8dec084/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py:544, in Trainer.fit(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path) 542 self.state.status = TrainerStatus.RUNNING 543 self.training = True --> 544 call._call_and_handle_interrupt( 545 self, self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path 546 ) File /local_disk0/.ephemeral_nfs/envs/pythonEnv-836bb24d-b821-4dba-9f1e-156ae8dec084/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py:44, in _call_and_handle_interrupt(trainer, trainer_fn, *args, kwargs) 42 if trainer.strategy.launcher is not None: 43 return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, *kwargs) ---> 44 return trainer_fn(args, kwargs) 46 except _TunerExitException: 47 _call_teardown_hook(trainer) File /local_disk0/.ephemeral_nfs/envs/pythonEnv-836bb24d-b821-4dba-9f1e-156ae8dec084/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py:580, in Trainer._fit_impl(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path) 573 assert self.state.fn is not None 574 ckpt_path = self._checkpoint_connector._select_ckpt_path( 575 self.state.fn, 576 ckpt_path, 577 model_provided=True, 578 model_connected=self.lightning_module is not None, 579 ) --> 580 self._run(model, ckpt_path=ckpt_path) 582 assert self.state.stopped 583 self.training = False File /local_disk0/.ephemeral_nfs/envs/pythonEnv-836bb24d-b821-4dba-9f1e-156ae8dec084/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py:987, in Trainer._run(self, model, ckpt_path) 982 self._signal_connector.register_signal_handlers() 984 # ---------------------------- 985 # RUN THE TRAINER 986 # ---------------------------- --> 987 results = self._run_stage() 989 # ---------------------------- 990 # POST-Training CLEAN UP 991 # ---------------------------- 992 log.debug(f"{self.class.name}: trainer 
tearing down") File /local_disk0/.ephemeral_nfs/envs/pythonEnv-836bb24d-b821-4dba-9f1e-156ae8dec084/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py:1033, in Trainer._run_stage(self) 1031 self._run_sanity_check() 1032 with torch.autograd.set_detect_anomaly(self._detect_anomaly): -> 1033 self.fit_loop.run() 1034 return None 1035 raise RuntimeError(f"Unexpected state {self.state}") File /local_disk0/.ephemeral_nfs/envs/pythonEnv-836bb24d-b821-4dba-9f1e-156ae8dec084/lib/python3.10/site-packages/lightning/pytorch/loops/fit_loop.py:205, in _FitLoop.run(self) 203 try: 204 self.on_advance_start() --> 205 self.advance() 206 self.on_advance_end() 207 self._restarting = False File /local_disk0/.ephemeral_nfs/envs/pythonEnv-836bb24d-b821-4dba-9f1e-156ae8dec084/lib/python3.10/site-packages/lightning/pytorch/loops/fit_loop.py:363, in _FitLoop.advance(self) 361 with self.trainer.profiler.profile("run_training_epoch"): 362 assert self._data_fetcher is not None --> 363 self.epoch_loop.run(self._data_fetcher) File /local_disk0/.ephemeral_nfs/envs/pythonEnv-836bb24d-b821-4dba-9f1e-156ae8dec084/lib/python3.10/site-packages/lightning/pytorch/loops/training_epoch_loop.py:141, in _TrainingEpochLoop.run(self, data_fetcher) 139 try: 140 self.advance(data_fetcher) --> 141 self.on_advance_end(data_fetcher) 142 self._restarting = False 143 except StopIteration: File /local_disk0/.ephemeral_nfs/envs/pythonEnv-836bb24d-b821-4dba-9f1e-156ae8dec084/lib/python3.10/site-packages/lightning/pytorch/loops/training_epoch_loop.py:295, in _TrainingEpochLoop.on_advance_end(self, data_fetcher) 291 if not self._should_accumulate(): 292 # clear gradients to not leave any unused memory during validation 293 call._call_lightning_module_hook(self.trainer, "on_validation_model_zero_grad") --> 295 self.val_loop.run() 296 self.trainer.training = True 297 self.trainer._logger_connector._first_loop_iter = first_loop_iter File 
/local_disk0/.ephemeral_nfs/envs/pythonEnv-836bb24d-b821-4dba-9f1e-156ae8dec084/lib/python3.10/site-packages/lightning/pytorch/loops/utilities.py:182, in _no_grad_context.._decorator(self, *args, kwargs) 180 context_manager = torch.no_grad 181 with context_manager(): --> 182 return loop_run(self, *args, *kwargs) File /local_disk0/.ephemeral_nfs/envs/pythonEnv-836bb24d-b821-4dba-9f1e-156ae8dec084/lib/python3.10/site-packages/lightning/pytorch/loops/evaluation_loop.py:135, in _EvaluationLoop.run(self) 133 self.batch_progress.is_last_batch = data_fetcher.done 134 # run step hooks --> 135 self._evaluation_step(batch, batch_idx, dataloader_idx, dataloader_iter) 136 except StopIteration: 137 # this needs to wrap the `_stepcall too (not justnext) fordataloader_iter` support 138 break File /local_disk0/.ephemeral_nfs/envs/pythonEnv-836bb24d-b821-4dba-9f1e-156ae8dec084/lib/python3.10/site-packages/lightning/pytorch/loops/evaluation_loop.py:396, in _EvaluationLoop._evaluation_step(self, batch, batch_idx, dataloader_idx, dataloader_iter) 390 hook_name = "test_step" if trainer.testing else "validation_step" 391 step_args = ( 392 self._build_step_args_from_hook_kwargs(hook_kwargs, hook_name) 393 if not using_dataloader_iter 394 else (dataloader_iter,) 395 ) --> 396 output = call._call_strategy_hook(trainer, hook_name, step_args) 398 self.batch_progress.increment_processed() 400 if using_dataloader_iter: 401 # update the hook kwargs now that the step method might have consumed the iterator File /local_disk0/.ephemeral_nfs/envs/pythonEnv-836bb24d-b821-4dba-9f1e-156ae8dec084/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py:309, in _call_strategy_hook(trainer, hook_name, args, kwargs) 306 return None 308 with trainer.profiler.profile(f"[Strategy]{trainer.strategy.class.name}.{hook_name}"): --> 309 output = fn(*args, kwargs) 311 # restore current_fx when nested context 312 pl_module._current_fx_name = prev_fx_name File 
/local_disk0/.ephemeral_nfs/envs/pythonEnv-836bb24d-b821-4dba-9f1e-156ae8dec084/lib/python3.10/site-packages/lightning/pytorch/strategies/strategy.py:412, in Strategy.validation_step(self, *args, *kwargs) 410 if self.model != self.lightning_module: 411 return self._forward_redirection(self.model, self.lightning_module, "validation_step", args, kwargs) --> 412 return self.lightning_module.validation_step(*args, kwargs) File , line 36, in Idefics2ModelPLModule.validation_step(self, batch, batch_idx, dataset_idx) 33 input_ids, attention_mask, pixel_values, pixel_attention_mask, answers = batch 35 # autoregressively generate token IDs ---> 36 generated_ids = model.generate(input_ids=input_ids, attention_mask=attention_mask, 37 pixel_values=pixel_values, pixel_attention_mask=pixel_attention_mask, 38 max_new_tokens=768) 39 # turn them back into text, chopping of the prompt 40 # important: we don't skip special tokens here, because we want to see them in the output 41 predictions = self.processor.batch_decode(generated_ids[:, input_ids.size(1):], skip_special_tokens=True) File /databricks/python/lib/python3.10/site-packages/torch/utils/_contextlib.py:115, in context_decorator..decorate_context(*args, *kwargs) 112 @functools.wraps(func) 113 def decorate_context(args, kwargs): 114 with ctx_factory(): --> 115 return func(*args, kwargs) File /local_disk0/.ephemeral_nfs/envs/pythonEnv-836bb24d-b821-4dba-9f1e-156ae8dec084/lib/python3.10/site-packages/transformers/generation/utils.py:1896, in GenerationMixin.generate(self, inputs, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, assistant_model, streamer, negative_prompt_ids, negative_prompt_attention_mask, kwargs) 1888 input_ids, model_kwargs = self._expand_inputs_for_generation( 1889 input_ids=input_ids, 1890 expand_size=generation_config.num_return_sequences, 1891 is_encoder_decoder=self.config.is_encoder_decoder, 1892 model_kwargs, 1893 ) 1895 # 13. 
run sample (it degenerates to greedy search when generation_config.do_sample=False) -> 1896 result = self._sample( 1897 input_ids, 1898 logits_processor=prepared_logits_processor, 1899 logits_warper=prepared_logits_warper, 1900 stopping_criteria=prepared_stopping_criteria, 1901 generation_config=generation_config, 1902 synced_gpus=synced_gpus, 1903 streamer=streamer, 1904 model_kwargs, 1905 ) 1907 elif generation_mode in (GenerationMode.BEAM_SAMPLE, GenerationMode.BEAM_SEARCH): 1908 # 11. prepare logits warper 1909 prepared_logits_warper = ( 1910 self._get_logits_warper(generation_config, device=input_ids.device) 1911 if generation_config.do_sample 1912 else None 1913 ) File /local_disk0/.ephemeral_nfs/envs/pythonEnv-836bb24d-b821-4dba-9f1e-156ae8dec084/lib/python3.10/site-packages/transformers/generation/utils.py:2633, in GenerationMixin._sample(self, input_ids, logits_processor, stopping_criteria, generation_config, synced_gpus, streamer, logits_warper, model_kwargs) 2630 model_inputs = self.prepare_inputs_for_generation(input_ids, model_kwargs) 2632 # forward pass to get next token -> 2633 outputs = self( 2634 model_inputs, 2635 return_dict=True, 2636 output_attentions=output_attentions, 2637 output_hidden_states=output_hidden_states, 2638 ) 2640 if synced_gpus and this_peer_finished: 2641 continue # don't waste resources running the code we don't need File /databricks/python/lib/python3.10/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, *kwargs) 1496 # If we don't have any hooks, we want to skip the rest of the logic in 1497 # this function, and just call forward. 
1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks 1499 or _global_backward_pre_hooks or _global_backward_hooks 1500 or _global_forward_hooks or _global_forward_pre_hooks): -> 1501 return forward_call(args, kwargs) 1502 # Do not call functions when jit is used 1503 full_backward_hooks, non_full_backward_hooks = [], [] File /databricks/python/lib/python3.10/site-packages/accelerate/hooks.py:165, in add_hook_to_module..new_forward(module, *args, kwargs) 163 output = module._old_forward(*args, *kwargs) 164 else: --> 165 output = module._old_forward(args, kwargs) 166 return module._hf_hook.post_forward(module, output) File /local_disk0/.ephemeral_nfs/envs/pythonEnv-836bb24d-b821-4dba-9f1e-156ae8dec084/lib/python3.10/site-packages/transformers/models/idefics2/modeling_idefics2.py:1829, in Idefics2ForConditionalGeneration.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, pixel_values, pixel_attention_mask, image_hidden_states, labels, use_cache, output_attentions, output_hidden_states, return_dict) 1826 return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1828 # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) -> 1829 outputs = self.model( 1830 input_ids=input_ids, 1831 attention_mask=attention_mask, 1832 position_ids=position_ids, 1833 past_key_values=past_key_values, 1834 inputs_embeds=inputs_embeds, 1835 pixel_values=pixel_values, 1836 pixel_attention_mask=pixel_attention_mask, 1837 image_hidden_states=image_hidden_states, 1838 use_cache=use_cache, 1839 output_attentions=output_attentions, 1840 output_hidden_states=output_hidden_states, 1841 return_dict=return_dict, 1842 ) 1844 hidden_states = outputs[0] 1845 logits = self.lm_head(hidden_states) File /databricks/python/lib/python3.10/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, kwargs) 1496 # If we don't have any hooks, we 
want to skip the rest of the logic in 1497 # this function, and just call forward. 1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks 1499 or _global_backward_pre_hooks or _global_backward_hooks 1500 or _global_forward_hooks or _global_forward_pre_hooks): -> 1501 return forward_call(*args, *kwargs) 1502 # Do not call functions when jit is used 1503 full_backward_hooks, non_full_backward_hooks = [], [] File /databricks/python/lib/python3.10/site-packages/accelerate/hooks.py:165, in add_hook_to_module..new_forward(module, args, kwargs) 163 output = module._old_forward(*args, *kwargs) 164 else: --> 165 output = module._old_forward(args, **kwargs) 166 return module._hf_hook.post_forward(module, output) File /local_disk0/.ephemeral_nfs/envs/pythonEnv-836bb24d-b821-4dba-9f1e-156ae8dec084/lib/python3.10/site-packages/transformers/models/idefics2/modeling_idefics2.py:1656, in Idefics2Model.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, pixel_values, pixel_attention_mask, image_hidden_states, use_cache, output_attentions, output_hidden_states, return_dict) 1651 image_hidden_states = image_hidden_states.to(dtype=self.dtype, device=input_ids.device) 1653 if past_seen_tokens == 0 and inputs_embeds is not None and image_hidden_states is not None: 1654 # When we generate, we don't want to replace the potential image_token_id that we generated by images 1655 # that simply don't exist -> 1656 inputs_embeds = self.inputs_merger( 1657 input_ids=input_ids, 1658 inputs_embeds=inputs_embeds, 1659 image_hidden_states=image_hidden_states, 1660 ) 1662 outputs = self.text_model( 1663 inputs_embeds=inputs_embeds, 1664 attention_mask=attention_mask, (...) 
1669 return_dict=return_dict, 1670 ) 1672 if return_legacy_cache: File /local_disk0/.ephemeral_nfs/envs/pythonEnv-836bb24d-b821-4dba-9f1e-156ae8dec084/lib/python3.10/site-packages/transformers/models/idefics2/modeling_idefics2.py:1542, in Idefics2Model.inputs_merger(self, input_ids, inputs_embeds, image_hidden_states) 1540 new_inputs_embeds = inputs_embeds.clone() 1541 reshaped_image_hidden_states = image_hidden_states.view(-1, vision_hidden_size) -> 1542 new_inputs_embeds[special_image_token_mask] = reshaped_image_hidden_states 1543 return new_inputs_embeds

chang-changiti commented 2 weeks ago

Found the fix: install the latest version of the `transformers` package (v4.41.2 at the time of writing).