Status: Closed — WuJiunShiung closed this issue 1 year ago.
Hello @WuJiunShiung, I'm closing this issue in favor of #5114, where the problem was originally reported. We're currently trying to address it (see #5197). The bug fix will probably be released in Haystack 1.18.0.
In the meantime, if you need to fine-tune your model, you can downgrade to Haystack 1.16.1.
**Describe the bug**

I tried the code in the tutorial "Fine-tuning a model on your own data". When running the distillation code, I encountered an error on the first of these two lines:

```python
student.distil_prediction_layer_from(teacher, data_dir="data/squad20", train_filename="dev-v2.0.json", use_gpu=True)
student.save(directory="my_distilled_model")
```
Error message
```text
RuntimeError                              Traceback (most recent call last)
Cell In[15], line 11
      8 student = FARMReader(model_name_or_path="huawei-noah/TinyBERT_General_6L_768D", use_gpu=True)
     10 student.distil_intermediate_layers_from(teacher, data_dir=".", train_filename="augmented_dataset.json", use_gpu=True)
---> 11 student.distil_prediction_layer_from(teacher, data_dir="data/squad20", train_filename="dev-v2.0.json", use_gpu=True)
     13 student.save(directory="my_distilled_model")

File ~/miniforge3/envs/Transformers/lib/python3.10/site-packages/haystack/nodes/reader/farm.py:564, in FARMReader.distil_prediction_layer_from(self, teacher_model, data_dir, train_filename, dev_filename, test_filename, use_gpu, devices, batch_size, teacher_batch_size, n_epochs, learning_rate, max_seq_len, warmup_proportion, dev_split, evaluate_every, save_dir, num_processes, use_amp, checkpoint_root_dir, checkpoint_every, checkpoints_to_keep, caching, cache_path, distillation_loss_weight, distillation_loss, temperature, processor, grad_acc_steps, early_stopping)
    499 """
    500 Fine-tune a model on a QA dataset using logit-based distillation. You need to provide a teacher model that is already finetuned on the dataset
    501 and a student model that will be trained using the teacher's logits. The idea of this is to increase the accuracy of a lightweight student model.
   (...)
    558 :return: None
    559 """
    560 send_event(
    561     event_name="Training",
    562     event_properties={"class": self.__class__.__name__, "function_name": "distil_prediction_layer_from"},
    563 )
--> 564 return self._training_procedure(
    565     data_dir=data_dir,
    566     train_filename=train_filename,
    567     dev_filename=dev_filename,
    568     test_filename=test_filename,
    569     use_gpu=use_gpu,
    570     devices=devices,
    571     batch_size=batch_size,
    572     n_epochs=n_epochs,
    573     learning_rate=learning_rate,
    574     max_seq_len=max_seq_len,
    575     warmup_proportion=warmup_proportion,
    576     dev_split=dev_split,
    577     evaluate_every=evaluate_every,
    578     save_dir=save_dir,
    579     num_processes=num_processes,
    580     use_amp=use_amp,
    581     checkpoint_root_dir=checkpoint_root_dir,
    582     checkpoint_every=checkpoint_every,
    583     checkpoints_to_keep=checkpoints_to_keep,
    584     teacher_model=teacher_model,
    585     teacher_batch_size=teacher_batch_size,
    586     caching=caching,
    587     cache_path=cache_path,
    588     distillation_loss_weight=distillation_loss_weight,
    589     distillation_loss=distillation_loss,
    590     temperature=temperature,
    591     processor=processor,
    592     grad_acc_steps=grad_acc_steps,
    593     early_stopping=early_stopping,
    594     distributed=False,
    595 )

File ~/miniforge3/envs/Transformers/lib/python3.10/site-packages/haystack/nodes/reader/farm.py:260, in FARMReader._training_procedure(self, data_dir, train_filename, dev_filename, test_filename, use_gpu, devices, batch_size, n_epochs, learning_rate, max_seq_len, warmup_proportion, dev_split, evaluate_every, save_dir, num_processes, use_amp, checkpoint_root_dir, checkpoint_every, checkpoints_to_keep, teacher_model, teacher_batch_size, caching, cache_path, distillation_loss_weight, distillation_loss, temperature, tinybert, processor, grad_acc_steps, early_stopping, distributed, doc_stride, max_query_length)
    255 # 2. Create a DataSilo that loads several datasets (train/dev/test), provides DataLoaders for them
    256 # and calculates a few descriptive statistics of our datasets
    257 if (
    258     teacher_model and not tinybert
    259 ):  # checks if teacher model is passed as parameter, in that case assume model distillation is used
--> 260     data_silo = DistillationDataSilo(
    261         teacher_model,
    262         teacher_batch_size or batch_size,
    263         device=devices[0],
    264         processor=processor,
    265         batch_size=batch_size,
    266         distributed=distributed,
    267         max_processes=num_processes,
    268         caching=caching,
    269         cache_path=cache_path,
    270     )
    271 else:  # caching would need too much memory for tinybert distillation so in that case we use the default data silo
    272     data_silo = DataSilo(
    273         processor=processor,
    274         batch_size=batch_size,
   (...)
    278         cache_path=cache_path,
    279     )

File ~/miniforge3/envs/Transformers/lib/python3.10/site-packages/haystack/modeling/data_handler/data_silo.py:750, in DistillationDataSilo.__init__(self, teacher_model, teacher_batch_size, device, processor, batch_size, eval_batch_size, distributed, automatic_loading, max_processes, caching, cache_path)
    748 self.device = device
    749 max_processes = 1  # fix as long as multithreading is not working with teacher attribute
--> 750 super().__init__(
    751     max_processes=max_processes,
    752     processor=processor,
    753     batch_size=batch_size,
    754     eval_batch_size=eval_batch_size,
    755     distributed=distributed,
    756     automatic_loading=automatic_loading,
    757     caching=caching,
    758     cache_path=cache_path,
    759 )

File ~/miniforge3/envs/Transformers/lib/python3.10/site-packages/haystack/modeling/data_handler/data_silo.py:104, in DataSilo.__init__(self, processor, batch_size, eval_batch_size, distributed, automatic_loading, max_multiprocessing_chunksize, max_processes, multiprocessing_strategy, caching, cache_path)
     99     loaded_from_cache = True
    101 if not loaded_from_cache and automatic_loading:
    102     # In most cases we want to load all data automatically, but in some cases we rather want to do this
    103     # later or load from dicts instead of file
--> 104     self._load_data()

File ~/miniforge3/envs/Transformers/lib/python3.10/site-packages/haystack/modeling/data_handler/data_silo.py:165, in DataSilo._load_data(self, train_dicts, dev_dicts, test_dicts)
    163     train_file = self.processor.data_dir / self.processor.train_filename
    164     logger.info("Loading train set from: %s ", train_file)
--> 165     self.data["train"], self.tensor_names = self._get_dataset(train_file)
    166 else:
    167     logger.info("No train set is being loaded")

File ~/miniforge3/envs/Transformers/lib/python3.10/site-packages/haystack/modeling/data_handler/data_silo.py:818, in DistillationDataSilo._get_dataset(self, filename, dicts)
    816 corresponding_chunks.append(i)
    817 if len(batch) == self.teacher_batch_size:
--> 818     self._pass_batches(
    819         batch, corresponding_chunks, teacher_outputs, tensor_names
    820     )  # doing forward pass on teacher model
    821     batch = []
    822     corresponding_chunks = []

File ~/miniforge3/envs/Transformers/lib/python3.10/site-packages/haystack/modeling/data_handler/data_silo.py:785, in DistillationDataSilo._pass_batches(self, batch, corresponding_chunks, teacher_outputs, tensor_names)
    783 with torch.inference_mode():
    784     batch_transposed = zip(*batch)  # transpose dimensions (from batch, features, ... to features, batch, ...)
--> 785     batch_transposed_list = [torch.stack(b) for b in batch_transposed]  # create tensors for each feature
    786     batch_dict = {
    787         key: tensor.to(self.device) for key, tensor in zip(tensor_names, batch_transposed_list)
    788     }  # create input dict
    789     y = self._run_teacher(batch=batch_dict)  # run teacher model

File ~/miniforge3/envs/Transformers/lib/python3.10/site-packages/haystack/modeling/data_handler/data_silo.py:785, in <listcomp>(.0)
    783 with torch.inference_mode():
    784     batch_transposed = zip(*batch)  # transpose dimensions (from batch, features, ... to features, batch, ...)
--> 785     batch_transposed_list = [torch.stack(b) for b in batch_transposed]  # create tensors for each feature
    786     batch_dict = {
    787         key: tensor.to(self.device) for key, tensor in zip(tensor_names, batch_transposed_list)
    788     }  # create input dict
    789     y = self._run_teacher(batch=batch_dict)  # run teacher model

RuntimeError: stack expects each tensor to be equal size, but got [5, 2] at entry 0 and [6, 2] at entry 2
```
I was simply running the code from this tutorial. How can I fix this problem? Thank you!