huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for PyTorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

How to use fine-tuned BART for prediction? #3853

Closed · riacheruvu closed this 4 years ago

riacheruvu commented 4 years ago

❓ Questions & Help

Details

I fine-tuned the BART model on a custom summarization dataset, using the transformers/examples/summarization/bart/finetune.py and transformers/examples/summarization/bart/run_train.sh files in the repository for training (which generated three checkpointepoch=*.ckpt files) and for prediction (which generated a .txt file with the test loss scores).

I have two questions on using this model for prediction:

Thank you for your time!

riacheruvu commented 4 years ago

Thank you, @claudiatin, and thank you for sharing your code!

gmlander commented 4 years ago

@claudiatin thanks for providing your code. I was able to load a fine-tuned version of facebook/bart-large-cnn into a pipeline, originally via a far hackier approach and then via your method as well.

The problem I'm running into, which it sounds like you may have hit as well, is that the predictions from the pipeline after fine-tuning come out as pure gibberish, so something is being lost in translation. Example below:

'redistributionestonestoneston Hag Hag resultant resultant ' 'resultantestoneston redistribution redistribution Hag Hag pressuring ' 'pressuring redistribution redistribution alternate alternate alternate ' 'pressuring pressuring Hag Hagestoneston Champions Champions Champions ' 'redistribution redistribution sil sil sil redistribution redistributionbelt ' 'redistribution redistributioniopiopiop redistribution redistribution carved ' 'carved carved Hag Hag sil sil pressuring pressuring carved carved ' 'compartment compartment compartment redistribution redistribution Voyager ' 'Voyager Voyager redistribution redistribution pressuring pressuring '

I used the finetune.py script on the CNN tiny dataset, via the tiny version of the bash script in the examples folder. I even attempted the fine-tuning with a nearly zero (1e-10) learning rate, so that I knew I wasn't significantly changing the model. This still led to gibberish predictions.

I also tried a version where I loaded the pretrained model into the pipeline, saved it with pipeline.model.save_pretrained("path/to/dir"), and then, in a new session, reloaded it using the second portion of the code provided by @claudiatin plus bart_loaded = pipeline(task='summarization', model=model, device=0, tokenizer=tokenizer).

This produced correct predictions, although I did notice a significant difference in inference time on the same test article (~3 seconds vs ~20 seconds). The only difference I could find between the config.json and pytorch_model.bin that came out of save_pretrained() and those derived from the finetune.py checkpoint is that the save_pretrained() config.json contains the added key/value "architectures": ["BartForConditionalGeneration"]. I made this change to the config generated from my fine-tuned model, but it did not fix the gibberish generation problem.
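
For concreteness, the round trip I mean looks roughly like the sketch below; "path/to/dir" is just a placeholder directory, device=0 assumes a GPU, and the stock facebook/bart-large-cnn weights stand in for a fine-tuned checkpoint:

    # Sketch of the save_pretrained() round trip described above.
    from transformers import BartForConditionalGeneration, BartTokenizer, pipeline

    # First session: build a summarization pipeline and save its model and tokenizer.
    summarizer = pipeline(task='summarization', model='facebook/bart-large-cnn', device=0)
    summarizer.model.save_pretrained('path/to/dir')
    summarizer.tokenizer.save_pretrained('path/to/dir')

    # New session: reload the saved weights and tokenizer, then rebuild the pipeline.
    model = BartForConditionalGeneration.from_pretrained('path/to/dir')
    tokenizer = BartTokenizer.from_pretrained('path/to/dir')
    bart_loaded = pipeline(task='summarization', model=model, device=0, tokenizer=tokenizer)

    print(bart_loaded('My friends are cool but they eat too many carbs.'))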

@sshleifer , any ideas?

claudiatin commented 4 years ago

@gmlander, yes, I have the same gibberish issue, and it's not clear to me how to solve it. It would be nice to know how.

stale[bot] commented 4 years ago

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.

mriganktiwari commented 3 years ago

I found a solution. The model.generate() function is necessary to extract the predictions. I defined a separate function in the SummarizationTrainer() class that calls self.model.generate(), and was then able to use tokenizer.decode() on the outputs.

I was encountering issues when using self.tokenizer, so I assume using 'bart-large-cnn' tokenizer for similar custom summarization datasets is okay.

@prabalbansal, I'm not sure if the same method will apply to T5, but it could work for predicting for a single instance, per one of your questions.

My code is below:

    # Generate summary token ids for the given input ids, then decode them to text.
    def text_predictions(self, input_ids):
        generated_ids = self.model.generate(
            input_ids=input_ids,
            num_beams=1,
            max_length=80,
            repetition_penalty=2.5,
            length_penalty=1.0,
            early_stopping=True,
        )
        preds = [
            self.tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True)
            for g in generated_ids
        ]
        return preds
...
    # Optionally, predict on dev set and write to output_dir
    if args.do_predict:
        # See https://github.com/huggingface/transformers/issues/3159
        # pl use this format to create a checkpoint:
        # https://github.com/PyTorchLightning/pytorch-lightning/blob/master\
        # /pytorch_lightning/callbacks/model_checkpoint.py#L169
        tokenizer = BartTokenizer.from_pretrained('bart-large-cnn')
        ARTICLE_TO_SUMMARIZE = "My friends are cool but they eat too many carbs."
        inputs = tokenizer.batch_encode_plus([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors='pt')['input_ids']
        checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "checkpointepoch=*.ckpt"), recursive=True)))
        model = model.load_from_checkpoint(checkpoints[-1])
        model.eval()
        model.freeze()
        outputs = model.text_predictions(inputs)
        print(outputs)

Thank you for the help, @sshleifer !

Hi @riacheruvu, I am facing a similar issue while tokenizing a piece of text in the QAGS repo. Line 133 of https://github.com/W4ngatang/qags/blob/master/qg_utils.py gives me the same error, which is due to tokenizer.decode() encountering a NoneType object. I would appreciate any help you can offer. Please see the error log below:

Traceback (most recent call last):
    File "qg_utils.py", line 169, in <module>
        sys.exit(main(sys.argv[1:]))
    File "qg_utils.py", line 166, in main
        extract_gen_from_fseq_log(args.data_file, args.out_dir)
    File "qg_utils.py", line 142, in extract_gen_from_fseq_log
        gen = tokenizer.decode(tok_ids)
    File "/home/test/miniconda3/envs/qags/lib/python3.6/site-packages/transformers/tokenization_utils_base.py", line 3113, in decode
        **kwargs,
    File "/home/test/miniconda3/envs/qags/lib/python3.6/site-packages/transformers/tokenization_utils.py", line 753, in _decode
        sub_texts.append(self.convert_tokens_to_string(current_sub_text))
    File "/home/test/miniconda3/envs/qags/lib/python3.6/site-packages/transformers/models/gpt2/tokenization_gpt2.py", line 264, in convert_tokens_to_string
        text = "".join(tokens)
TypeError: sequence item 0: expected str instance, NoneType found

riacheruvu commented 3 years ago

Hi @mriganktiwari, in my case, I needed to pass the output of model.generate() to tokenizer.decode() to solve this issue. I was on an older version of the transformers library at the time, so this might not be true today.

You could consider first using model.generate() with tok_ids, followed by tokenizer.decode(). I could be wrong, and I'm not sure what the input data_file consists of, but I would try this to see if it helps.
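
Roughly, the order of operations would look like the sketch below; the model and tokenizer names and the input text are placeholders, since I don't know what the QAGS data_file contains:

    # Sketch: run generate() first, then decode the generated ids (not the raw input ids).
    # The model/tokenizer names and the input sentence are placeholders, not from the QAGS code.
    from transformers import BartForConditionalGeneration, BartTokenizer

    tokenizer = BartTokenizer.from_pretrained('facebook/bart-large-cnn')
    model = BartForConditionalGeneration.from_pretrained('facebook/bart-large-cnn')

    tok_ids = tokenizer(['My friends are cool but they eat too many carbs.'],
                        max_length=1024, truncation=True, return_tensors='pt')['input_ids']

    generated_ids = model.generate(tok_ids, num_beams=4, max_length=80, early_stopping=True)
    print(tokenizer.decode(generated_ids[0], skip_special_tokens=True))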