Yale-LILY / ConvoSumm


Issues using BART model for inference #6

Open · AADeLucia opened this issue 2 years ago

AADeLucia commented 2 years ago

I am trying to use scripts/prep.sh and scripts/inference.py to load the BART checkpoint /reddit_vanilla_actual/checkpoint_best.pt for inference. I have run into many issues, mostly related to package versions and the model's extended 2048 source positions.
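For reference, one way to confirm that the checkpoint really carries 2048 source positions is to inspect its positional-embedding table directly. This is a minimal sketch; the "model" and "encoder.embed_positions.weight" keys assume a standard fairseq BART state dict:

import torch

# Load only the state dict on CPU; no GPU is needed for this check.
state = torch.load(
    "/home/aadelucia/ConvoSumm/checkpoints/reddit_vanilla_actual/checkpoint_best.pt",
    map_location="cpu",
)
# fairseq's learned positional embeddings reserve two extra rows
# (padding index plus offset), so roughly 2050 rows indicates a
# 2048-position model, versus ~1026 for stock BART.
print(state["model"]["encoder.embed_positions.weight"].shape)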

Environment:

pytorch                   1.7.1           py3.8_cuda10.2.89_cudnn7.6.5_0    pytorch

I initially tried installing fairseq from source to access the examples module, but then I saw that this repo ships its own copy of fairseq, so I installed that version according to the instructions here:

cd fairseq
pip install --editable ./
python setup.py build develop

I binarized val.source and val.target and am running inference as follows:

python scripts/inference.py /home/aadelucia/ConvoSumm/checkpoints/reddit_vanilla_actual checkpoint_best.pt /home/aadelucia/ConvoSumm/alexandra_test/data_processed /home/aadelucia/ConvoSumm/alexandra_test/data/val.source /home/aadelucia/ConvoSumm/alexandra_test/inference_output.txt 4 1 80 120 1 2048 ./misc/encoder.json ./misc/vocab.bpe

And I get the following error:

Traceback (most recent call last):
  File "scripts/inference.py", line 42, in <module>
    hypotheses_batch = bart.sample(slines, beam=beam, lenpen=lenpen, min_len=min_len, no_repeat_ngram_size=3)
  File "/home/aadelucia/ConvoSumm/code/fairseq/fairseq/hub_utils.py", line 132, in sample
    batched_hypos = self.generate(tokenized_sentences, beam, verbose, **kwargs)
  File "/home/aadelucia/ConvoSumm/code/fairseq/fairseq/models/bart/hub_interface.py", line 108, in generate
    return super().generate(
  File "/home/aadelucia/ConvoSumm/code/fairseq/fairseq/hub_utils.py", line 171, in generate
    for batch in self._build_batches(tokenized_sentences, skip_invalid_size_inputs):
  File "/home/aadelucia/ConvoSumm/code/fairseq/fairseq/hub_utils.py", line 258, in _build_batches
    batch_iterator = self.task.get_batch_iterator(
  File "/home/aadelucia/ConvoSumm/code/fairseq/fairseq/tasks/fairseq_task.py", line 244, in get_batch_iterator
    batch_sampler = dataset.batch_by_size(
  File "/home/aadelucia/ConvoSumm/code/fairseq/fairseq/data/fairseq_dataset.py", line 145, in batch_by_size
    return data_utils.batch_by_size(
  File "/home/aadelucia/ConvoSumm/code/fairseq/fairseq/data/data_utils.py", line 337, in batch_by_size
    return batch_by_size_vec(
  File "fairseq/data/data_utils_fast.pyx", line 20, in fairseq.data.data_utils_fast.batch_by_size_vec
  File "fairseq/data/data_utils_fast.pyx", line 27, in fairseq.data.data_utils_fast.batch_by_size_vec
AssertionError: Sentences lengths should not exceed max_tokens=1024

Am I using the wrong version of a package? Is there something extra needed for this to work?
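For context, the assertion comes from fairseq's batching step, which requires every example to fit within max_tokens; when max_tokens is not set at load time, the hub interface falls back to a 1024-token default, which 2048-position inputs necessarily exceed. The following is a sketch of the constraint being enforced, not fairseq's actual implementation:

def check_lengths(src_lengths, max_tokens=1024):
    # With 2048-token inputs and the default max_tokens of 1024,
    # this is exactly the assertion seen in the traceback above.
    assert max(src_lengths) <= max_tokens, (
        f"Sentences lengths should not exceed max_tokens={max_tokens}"
    )

check_lengths([2048])  # raises AssertionError, mirroring the failure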

AADeLucia commented 2 years ago

Never mind, this seems to work when I pass in max_tokens=max_source_positions in scripts/inference.py:

from fairseq.models.bart import BARTModel

bart = BARTModel.from_pretrained(
    model_dir,
    checkpoint_file=model_file,
    data_name_or_path=bin_folder,
    gpt2_encoder_json=encoder_file,
    gpt2_vocab_bpe=vocab_file,
    max_source_positions=max_source_positions,
    max_tokens=max_source_positions,  # raise the batching cap to match the 2048-position inputs
)
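With the model loaded this way, a quick smoke test might look like the sketch below. The mapping of the command-line numbers 4/1/80 to beam/lenpen/min_len is an assumption; the keyword arguments themselves match the sample() call shown in the traceback:

bart.cuda()
bart.eval()

with open("/home/aadelucia/ConvoSumm/alexandra_test/data/val.source") as f:
    slines = [line.strip() for line in f][:4]  # small smoke-test batch

# no_repeat_ngram_size=3 matches the call in scripts/inference.py.
hypotheses = bart.sample(slines, beam=4, lenpen=1.0, min_len=80,
                         no_repeat_ngram_size=3)
for hypo in hypotheses:
    print(hypo)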