Yale-LILY / SummerTime

An open-source text summarization toolkit for non-experts. EMNLP'2021 Demo
https://arxiv.org/abs/2108.12738
Apache License 2.0

self.model needs to be moved to GPU in BartModel #90

Closed: JingrongFeng closed this issue 3 years ago

JingrongFeng commented 3 years ago

In SummerTime_midway_showcase_08_28.ipynb, I created a BartModel on the GPU and ran inference as follows:

# st_model is SummerTime's model module, imported earlier in the notebook
bart_model = st_model.BartModel(device='cuda')
documents = [
    """xxx"""
]
bart_model.summarize(documents)

I got the following error:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-4-5f5c5eef9ea1> in <module>()
      3 ]
      4 
----> 5 sample_model.summarize(documents)

/content/SummerTime/summertime/model/single_doc/bart_model.py in summarize(self, corpus, queries)
     27             corpus, truncation=True, padding="longest", return_tensors="pt"
     28         ).to(self.device)
---> 29         encoded_summaries = self.model.generate(**batch)
     30         summaries = self.tokenizer.batch_decode(
     31             encoded_summaries, skip_special_tokens=True

/usr/local/lib/python3.7/dist-packages/torch/autograd/grad_mode.py in decorate_context(*args, **kwargs)
     25         def decorate_context(*args, **kwargs):
     26             with self.__class__():
---> 27                 return func(*args, **kwargs)
     28         return cast(F, decorate_context)
     29 

/usr/local/lib/python3.7/dist-packages/transformers/generation_utils.py in generate(self, input_ids, max_length, min_length, do_sample, early_stopping, num_beams, temperature, top_k, top_p, repetition_penalty, bad_words_ids, bos_token_id, pad_token_id, eos_token_id, length_penalty, no_repeat_ngram_size, encoder_no_repeat_ngram_size, num_return_sequences, max_time, decoder_start_token_id, use_cache, num_beam_groups, diversity_penalty, prefix_allowed_tokens_fn, output_attentions, output_hidden_states, output_scores, return_dict_in_generate, forced_bos_token_id, forced_eos_token_id, remove_invalid_values, **model_kwargs)
    925         if self.config.is_encoder_decoder:
    926             # add encoder_outputs to model_kwargs
--> 927             model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(input_ids, model_kwargs)
    928 
    929             # set input_ids as decoder_input_ids

/usr/local/lib/python3.7/dist-packages/transformers/generation_utils.py in _prepare_encoder_decoder_kwargs_for_generation(self, input_ids, model_kwargs)
    410                 argument: value for argument, value in model_kwargs.items() if not argument.startswith("decoder_")
    411             }
--> 412             model_kwargs["encoder_outputs"]: ModelOutput = encoder(input_ids, return_dict=True, **encoder_kwargs)
    413         return model_kwargs
    414 

/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    887             result = self._slow_forward(*input, **kwargs)
    888         else:
--> 889             result = self.forward(*input, **kwargs)
    890         for hook in itertools.chain(
    891                 _global_forward_hooks.values(),

/usr/local/lib/python3.7/dist-packages/transformers/models/bart/modeling_bart.py in forward(self, input_ids, attention_mask, head_mask, inputs_embeds, output_attentions, output_hidden_states, return_dict)
    748 
    749         if inputs_embeds is None:
--> 750             inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
    751 
    752         embed_pos = self.embed_positions(input_shape)

/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    887             result = self._slow_forward(*input, **kwargs)
    888         else:
--> 889             result = self.forward(*input, **kwargs)
    890         for hook in itertools.chain(
    891                 _global_forward_hooks.values(),

/usr/local/lib/python3.7/dist-packages/torch/nn/modules/sparse.py in forward(self, input)
    156         return F.embedding(
    157             input, self.weight, self.padding_idx, self.max_norm,
--> 158             self.norm_type, self.scale_grad_by_freq, self.sparse)
    159 
    160     def extra_repr(self) -> str:

/usr/local/lib/python3.7/dist-packages/torch/nn/functional.py in embedding(input, weight, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse)
   1914         # remove once script supports set_grad_enabled
   1915         _no_grad_embedding_renorm_(weight, input, max_norm, norm_type)
-> 1916     return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
   1917 
   1918 

RuntimeError: Input, output and indices must be on the current device
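The last frame shows that the failure happens in the token-embedding lookup: the input ids and the embedding weights are on different devices. As a debugging aid, a hypothetical helper (not part of SummerTime or transformers) can compare the device of each tensor in a tokenized batch against the model's parameters before generate() is called. It assumes all parameters share one device, and a toy nn.Embedding stands in for BART:

```python
import torch
from torch import nn


def find_device_mismatches(model: nn.Module, batch: dict) -> list:
    """Return the keys of tensors in `batch` that are not on the model's device.

    Hypothetical debugging helper, not part of SummerTime or transformers.
    Assumes every parameter of `model` lives on the same device, which holds
    for a model moved as a whole with `.to(device)`.
    """
    model_device = next(model.parameters()).device
    return [
        name
        for name, value in batch.items()
        # Only tensors have a device; skip any other values in the batch.
        if isinstance(value, torch.Tensor) and value.device != model_device
    ]


# Toy stand-in for BART's token embedding: 100-token vocabulary, 8-dim vectors.
embedding = nn.Embedding(100, 8)
batch = {
    "input_ids": torch.tensor([[1, 2, 3]]),
    "attention_mask": torch.ones(1, 3, dtype=torch.long),
}
print(find_device_mismatches(embedding, batch))
```

In the failing notebook, the batch had been moved to CUDA while the model's parameters were still on the CPU, so a check like this would have listed both input_ids and attention_mask.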

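Changing the following line https://github.com/Yale-LILY/SummerTime/blob/211ee0676b0c48f35d0ef797f8c692dd5f0b7aae/summertime/model/single_doc/bart_model.py#L20 to

self.model = BartForConditionalGeneration.from_pretrained(model_name).to(self.device)

would fix it: summarize() already moves the tokenized batch to self.device, but the model loaded by from_pretrained stays on the CPU, so the embedding lookup receives CUDA indices and CPU weights.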
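A minimal sketch of the same pattern, assuming a toy PyTorch model in place of BartForConditionalGeneration so it runs without downloading weights. The constructor moves the model to self.device once, and summarize() sends its inputs to that same device. ToySummarizer is a hypothetical name, not SummerTime's API:

```python
import torch
from torch import nn


class ToySummarizer:
    """Hypothetical stand-in for SummerTime's BartModel, illustrating the fix from #90.

    The real class loads BartForConditionalGeneration; a tiny embedding + linear
    stack keeps this sketch self-contained and runnable without downloads.
    """

    def __init__(self, device: str = "cpu"):
        self.device = torch.device(device)
        model = nn.Sequential(nn.Embedding(100, 8), nn.Linear(8, 100))
        # The fix: put the model on the same device the inputs will be sent to.
        self.model = model.to(self.device)

    @torch.no_grad()
    def summarize(self, input_ids: torch.Tensor) -> torch.Tensor:
        # Mirrors `.to(self.device)` on the tokenizer batch in bart_model.py.
        input_ids = input_ids.to(self.device)
        logits = self.model(input_ids)  # shape: (batch, seq_len, vocab)
        return logits.argmax(dim=-1)    # shape: (batch, seq_len)


# Fall back to the CPU when no GPU is available.
device = "cuda" if torch.cuda.is_available() else "cpu"
summarizer = ToySummarizer(device=device)
out = summarizer.summarize(torch.tensor([[1, 2, 3]]))
print(out.shape, out.device)
```

The CPU fallback keeps the sketch runnable on any machine. With the real model, the only change needed is the `.to(self.device)` call on the object returned by from_pretrained.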

niansong1996 commented 3 years ago

Thanks a lot for the comment! We've fixed this in #91. Feel free to reopen the issue if you have any questions.

JingrongFeng commented 3 years ago

Thanks so much for the quick fix and this handy toolkit!