asyml / texar-pytorch

Integrating the Best of TF into PyTorch, for Machine Learning, Natural Language Processing, and Text Generation. This is part of the CASL project: http://casl-project.ai/
https://asyml.io
Apache License 2.0
745 stars, 117 forks

GPU memory usage when doing beam search #314

Open tanyuqian opened 4 years ago

tanyuqian commented 4 years ago

A BART model (https://arxiv.org/pdf/1910.13461.pdf) is implemented here: https://github.com/tanyuqian/texar-pytorch/tree/master/examples/bart

It passes our tests on text classification (MNLI) and summarization (CNN/DM) with greedy decoding, but CNN/DM with beam search runs out of GPU memory on a single GTX 1080 Ti, even with batch_size=1, beam_width=2, max_decoding_length=140.
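For a rough sense of scale, here is a back-of-the-envelope estimate of how much memory the cached attention keys and values could take during beam search with the settings above. This is only a sketch: the model dimensions (12 decoder layers, hidden size 1024, fp32, source length 1024) are assumptions for a BART-large-like model, not numbers measured from this code, and `kv_cache_bytes` is a hypothetical helper, not part of texar-pytorch. It ignores model weights, activations, and allocator overhead.

```python
def kv_cache_bytes(num_layers, hidden_size, batch_size, beam_width,
                   tgt_len, src_len, bytes_per_elem=4):
    """Estimate bytes held by cached attention keys/values during beam search.

    Hypothetical helper for illustration only; not a texar-pytorch API.
    """
    # Each beam hypothesis keeps its own copy of the cached states.
    rows = batch_size * beam_width
    # One key vector plus one value vector per cached position.
    per_position = 2 * hidden_size * bytes_per_elem
    # Decoder self-attention cache grows up to the decoding length.
    self_attn = num_layers * rows * tgt_len * per_position
    # Encoder-decoder attention cache covers the (tiled) source sequence.
    cross_attn = num_layers * rows * src_len * per_position
    return self_attn + cross_attn


# Assumed BART-large-like dimensions; batch/beam/length taken from the report.
total = kv_cache_bytes(num_layers=12, hidden_size=1024,
                       batch_size=1, beam_width=2,
                       tgt_len=140, src_len=1024)
print(f"{total / 2**20:.2f} MiB")  # prints "218.25 MiB"
```

Under these assumptions the cached states come to a few hundred MiB, which by itself is well below the 11 GB of a GTX 1080 Ti. The estimate does not say where the rest of the memory goes.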

A script that reproduces the issue is here: https://github.com/tanyuqian/texar-pytorch/blob/master/examples/bart/bart_cnn.py (run it after downloading the CNN/DM data as described in the README).

Note that this fork adds two hyperparameters to TransformerDecoder ('normalize_before' and 'final_layer_norm'): https://github.com/tanyuqian/texar-pytorch/blob/master/texar/torch/modules/decoders/transformer_decoders.py#L290