shon-otmazgin / fastcoref

MIT License

Fine-tuning model on data with no coreference #35

Open ianbstewart opened 1 year ago

ianbstewart commented 1 year ago

Another edge case I've found, this time with data that has no coreference. During training, when a batch is generated for the model, the `gold_clusters` feature appears to be set to `None` if no document in the batch has any coreference. This causes an error later, when the trainer tries to convert it to a tensor. Example below:

```python
from fastcoref import TrainingArgs, CorefTrainer, FCoref

pretrained_model = FCoref(device='cpu')

# get toy data with no coreference
sample_text = [
    'This package is very fast!',
    'We are so happy to see you!',
]
train_data_file = 'train_file_with_clusters.jsonlines'
preds = pretrained_model.predict(
    texts=sample_text,
    output_file=train_data_file
)

output_dir = 'trained_model/'
args = TrainingArgs(
    output_dir=output_dir,
    overwrite_output_dir=True,
    model_name_or_path=pretrained_model.model.roberta.name_or_path,
    device='cpu',
    epochs=10,
    logging_steps=10,
    eval_steps=1,
)

trainer = CorefTrainer(
    args=args,
    train_file=train_data_file,
    dev_file=train_data_file,
)
trainer.model = pretrained_model.model
trainer.train()
```

Error:

```
File /local_disk0/.ephemeral_nfs/cluster_libraries/python/lib/python3.9/site-packages/fastcoref/trainer.py:195, in CorefTrainer.train(self)
    193 batch['input_ids'] = torch.tensor(batch['input_ids'], device=self.device)
    194 batch['attention_mask'] = torch.tensor(batch['attention_mask'], device=self.device)
--> 195 batch['gold_clusters'] = torch.tensor(batch['gold_clusters'], device=self.device)
    196 if 'leftovers' in batch:
    197     batch['leftovers']['input_ids'] = torch.tensor(batch['leftovers']['input_ids'], device=self.device)

RuntimeError: Could not infer dtype of NoneType
```
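The failure mode can be sketched in plain Python. The helper below is a hypothetical stand-in for the batch-collation step, not fastcoref's code. It assumes the collator pads each document's gold clusters to a rectangular `[documents][clusters][mentions]` shape and falls back to `None` when no document in the batch has a cluster, which is what the description and traceback above suggest ends up in `torch.tensor()`. The padding span `(0, 0)` and the `min_clusters` workaround are assumptions; whether fastcoref's loss would treat an all-padding dummy cluster as "no coreference" has not been checked.

```python
from typing import List, Optional, Tuple

Span = Tuple[int, int]   # (start, end) offsets of one mention
Cluster = List[Span]     # one coreference chain
PAD_SPAN: Span = (0, 0)  # ASSUMED padding value, not taken from fastcoref


def pad_batch_clusters(
    batch: List[List[Cluster]], min_clusters: int = 0
) -> Optional[List[List[List[Span]]]]:
    """Hypothetical collation step: pad every document's clusters to a
    rectangular [documents][clusters][mentions] shape.

    With min_clusters=0, a batch with no coreference returns None, which is
    the value that would later fail inside torch.tensor().
    With min_clusters>=1, all-PAD_SPAN dummy clusters are used instead.
    """
    max_clusters = max([len(doc) for doc in batch] + [min_clusters])
    max_size = max(
        [len(c) for doc in batch for c in doc] + [1 if min_clusters else 0]
    )
    if max_clusters == 0 or max_size == 0:
        return None  # the "no coreference in this batch" path
    padded = []
    for doc in batch:
        # pad each existing cluster to max_size mentions
        rows = [list(c) + [PAD_SPAN] * (max_size - len(c)) for c in doc]
        # add dummy clusters so every document has max_clusters rows
        rows += [[PAD_SPAN] * max_size for _ in range(max_clusters - len(doc))]
        padded.append(rows)
    return padded
```

With `min_clusters=1`, a cluster-free batch such as `[[], []]` becomes `[[[(0, 0)]], [[(0, 0)]]]` instead of `None`, so the tensor conversion at least receives a well-shaped value.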

shon-otmazgin commented 1 year ago

I'm not sure what the issue is here, but I made these two changes:

  1. `output_dir = 'trained_model'` instead of `output_dir = 'trained_model/'`
  2. `model_name_or_path='distilroberta-base'` instead of `model_name_or_path=pretrained_model.model.roberta.name_or_path`

and everything worked.

shon-otmazgin commented 1 year ago

@ianbstewart, can we close this, since there has been no activity here?

ianbstewart commented 1 year ago

Thanks for the update! I still get the same error after making the changes you specified. Might this be a version difference? I am using fastcoref==2.1.5.

shon-otmazgin commented 1 year ago

Could you try the latest version (fastcoref 2.1.6) in a brand-new environment?

nizaress commented 7 months ago

I have the same problem as the OP with the latest version (2.1.6), even with the changes specified above: `RuntimeError: Could not infer dtype of NoneType`
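To check in advance whether a training file contains any coreference at all (if none of its documents do, every batch would take the `None` path described in the issue), a small counter over the jsonlines content is enough. This is a sketch that assumes each line is a JSON object with a `clusters` list, as in the file written by `FCoref.predict(..., output_file=...)` in the original example; `count_docs_without_clusters` is a hypothetical helper, not part of fastcoref.

```python
import json
from typing import Iterable, Tuple


def count_docs_without_clusters(lines: Iterable[str]) -> Tuple[int, int]:
    """Return (docs_without_clusters, total_docs) for jsonlines content.

    ASSUMPTION: each non-blank line is a JSON object with a "clusters" list.
    """
    empty = total = 0
    for line in lines:
        if not line.strip():
            continue  # skip blank lines, e.g. a trailing newline
        doc = json.loads(line)
        total += 1
        if not doc.get("clusters"):  # missing or empty list: no coreference
            empty += 1
    return empty, total
```

Because it accepts any iterable of lines, it can be called directly on an open file, e.g. `with open('train_file_with_clusters.jsonlines', encoding='utf-8') as f: print(count_docs_without_clusters(f))`.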