microsoft / SDR

Self-Supervised Document-to-Document Similarity Ranking via Contextualized Language Models and Hierarchical Inference

Error when I run the training command on the video games dataset #5

Open abhi1nandy2 opened 2 years ago

When I run

python sdr_main.py --dataset_name video_games

I get the following error:

Traceback (most recent call last):
  File "sdr_main.py", line 80, in <module>
    main()
  File "sdr_main.py", line 28, in main
    main_train(model_class_pointer, hyperparams,parser)
  File "sdr_main.py", line 72, in main_train
    trainer.fit(model)
  File "/home/j20200059/miniconda3/envs/SDR/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 510, in fit
    results = self.accelerator_backend.train()
  File "/home/j20200059/miniconda3/envs/SDR/lib/python3.7/site-packages/pytorch_lightning/accelerators/accelerator.py", line 57, in train
    return self.train_or_test()
  File "/home/j20200059/miniconda3/envs/SDR/lib/python3.7/site-packages/pytorch_lightning/accelerators/accelerator.py", line 74, in train_or_test
    results = self.trainer.train()
  File "/home/j20200059/miniconda3/envs/SDR/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 532, in train
    self.run_sanity_check(self.get_model())
  File "/home/j20200059/miniconda3/envs/SDR/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 721, in run_sanity_check
    self.reset_val_dataloader(ref_model)
  File "/home/j20200059/miniconda3/envs/SDR/lib/python3.7/site-packages/pytorch_lightning/trainer/data_loading.py", line 287, in reset_val_dataloader
    self.num_val_batches, self.val_dataloaders = self._reset_eval_dataloader(model, 'val')
  File "/home/j20200059/miniconda3/envs/SDR/lib/python3.7/site-packages/pytorch_lightning/trainer/data_loading.py", line 207, in _reset_eval_dataloader
    dataloaders = self.request_dataloader(getattr(model, loader_name))
  File "/home/j20200059/miniconda3/envs/SDR/lib/python3.7/site-packages/pytorch_lightning/trainer/data_loading.py", line 310, in request_dataloader
    dataloader = dataloader_fx()
  File "/scratch/j20200059/SDR/models/doc_similarity_pl_template.py", line 182, in val_dataloader
    return self.dataloader(mode="val")
  File "/scratch/j20200059/SDR/models/SDR/SDR.py", line 171, in dataloader
    batch_size=self.hparams.val_batch_size,
  File "/scratch/j20200059/SDR/models/SDR/SDR_utils.py", line 16, in __init__
    super(MPerClassSamplerDeter, self).__init__(labels, m, batch_size, int(length_before_new_iter))
  File "/home/j20200059/miniconda3/envs/SDR/lib/python3.7/site-packages/pytorch_metric_learning/samplers/m_per_class_sampler.py", line 32, in __init__
    ), "m * (number of unique labels) must be >= batch_size"
AssertionError: m * (number of unique labels) must be >= batch_size
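
For reference, here is a minimal sketch of the condition that the assertion message describes. It does not import or reproduce the sampler from `pytorch_metric_learning`. The helper names `sampler_args_ok` and `largest_valid_batch_size` and the toy labels are hypothetical, made up for illustration. In the traceback the batch size passed to the sampler is `self.hparams.val_batch_size`, so the validation split may simply have too few unique labels for the configured value.

```python
def sampler_args_ok(labels, m, batch_size):
    """Hypothetical helper: evaluates the condition named in the assertion
    message, m * (number of unique labels) >= batch_size."""
    return m * len(set(labels)) >= batch_size


def largest_valid_batch_size(labels, m):
    """Hypothetical helper: the largest batch_size that still satisfies
    the condition above for the given labels and m."""
    return m * len(set(labels))


# Toy data (an assumption, not taken from the video_games dataset):
# three unique labels, m = 2 samples per class.
labels = ["a", "a", "b", "b", "c"]
m = 2

print(sampler_args_ok(labels, m, batch_size=8))   # 2 * 3 = 6, which is less than 8
print(sampler_args_ok(labels, m, batch_size=6))   # 2 * 3 = 6, which equals 6
print(largest_valid_batch_size(labels, m))
```

If the check fails for the validation labels, lowering the validation batch size to at most `m * (number of unique labels)` should satisfy this particular assertion. Whether that is the intended configuration for this repository is not something the traceback confirms.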