fastnlp / fastNLP

fastNLP: A Modularized and Extensible NLP Framework. Currently still in incubation.
https://gitee.com/fastnlp/fastNLP
Apache License 2.0

Clearing the cache after every epoch #332

Closed ChawlaAvi closed 3 years ago

ChawlaAvi commented 3 years ago

Hi

I am using the Trainer class to train my model, passing it the model and the embedding (the same setup as shown in the comments of the Trainer class).

Here, I also specify the number of epochs to run. When I use a BERT model in the embeddings, it keeps throwing a CUDA out-of-memory error unless I set the batch size to 4 or lower. Such a small batch size drastically increases the training time. One workaround is to clear the GPU cache after each epoch; I have verified that this prevents the error, but for other reasons I can't apply the technique that way in my original code.

Is there any option available to clear the cache automatically after every epoch?

This is how the trainer has been declared.

trainer = Trainer(data_bundle.get_dataset('train'), model, optimizer,
                  batch_size=batch_size, sampler=BucketSampler(), num_workers=2,
                  n_epochs=100, dev_data=data_bundle.get_dataset('dev'),
                  metrics=SpanFPreRecMetric(tag_vocab=data_bundle.get_vocab('target'),
                                            encoding_type=encoding_type),
                  dev_batch_size=batch_size * 5, callbacks=callbacks, device=device,
                  test_use_tqdm=False, use_tqdm=True, print_every=300,
                  save_path="saved_models/")

In short: is there any way to run the model for 100 epochs and also execute torch.cuda.empty_cache() after each epoch?

xuyige commented 3 years ago

Thank you for your issue!

  1. How long is your input sequence on average? Empirically, BERT-base with seq_len=512 limits the batch size to 6 or lower, and BERT-large with seq_len=512 limits it to 1 (on a single 12 GB GPU).
  2. If you are not fine-tuning the BERT model, you can set BertEmbedding.requires_grad=False.
  3. If you want to execute torch.cuda.empty_cache(), you can write a callback and pass it into callbacks, as sketched below.
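
A minimal sketch of such a callback, assuming fastNLP's Callback class exposes an on_epoch_end hook (as in fastNLP 0.5.x); the name EmptyCacheCallback is only illustrative:

import torch
from fastNLP import Callback

class EmptyCacheCallback(Callback):
    """Illustrative callback: release cached GPU memory after every epoch."""
    def on_epoch_end(self):
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

It can then be passed along with the other callbacks when constructing the Trainer, e.g. callbacks = callbacks + [EmptyCacheCallback()].
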
ChawlaAvi commented 3 years ago

Hi, thanks for your response 👍

  1. The average sentence length in my case is around 15 words, and the maximum is around 100.
  2. I want to fine-tune the model, so unfortunately I can't set that to False.
  3. Yes, this is something I can do and was also looking for.

Thank you so much @xuyige.