huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Memory increment and release when loading model via PretrainedModel.from_pretrained #18782

Closed tobyych closed 2 years ago

tobyych commented 2 years ago

Issue

I was trying to understand the memory usage when loading a Hugging Face model. I found that when loading the model via AutoModelForMaskedLM.from_pretrained("bert-base-uncased"), the resulting increment in memory was (1) larger than the cached BERT model on disk (859 MB vs. 421 MB) and (2) not fully released when the variable was deleted. On the other hand, if I just do torch.load("[path to cached model]"), the memory allocated and released matched, and the numbers were very close to the size on disk. May I know why there is such a difference in behaviour?

Code to reproduce the issue

import torch
from transformers import AutoModelForMaskedLM

from memory_profiler import profile

@profile
def hf_load():
    # bert-base-uncased: 421MB on disk
    model = AutoModelForMaskedLM.from_pretrained("bert-base-uncased")
    del model

@profile
def direct_load():
    model = torch.load('/home/toby/.cache/huggingface/transformers/a8041bf617d7f94ea26d15e218abd04afc2004805632abc0ed2066aa16d50d04.faf6ea826ae9c5867d12b22257f9877e6b8367890837bd60f7c54a29633f7f2f', map_location='cpu')
    del model
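
For reference, this is roughly how I ran these under memory_profiler (a minimal sketch; repro.py is just a placeholder name):

# repro.py -- defines the two @profile functions above
if __name__ == "__main__":
    hf_load()
    direct_load()

# line-by-line memory profiling:
#   python -m memory_profiler repro.py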

Profile

hf_load:

Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
    13    241.9 MiB    241.9 MiB           1   @profile
    14                                         def hf_load():
    15                                             # bert-base-uncased: 421MB on disk
    16   1100.4 MiB    858.5 MiB           1       model = AutoModelForMaskedLM.from_pretrained("bert-base-uncased")
    17    683.0 MiB   -417.4 MiB           1       del model

direct_load:

Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
    19    240.6 MiB    240.6 MiB           1   @profile
    20                                         def direct_load():
    21    661.4 MiB    420.8 MiB           1       model = torch.load('/home/toby/.cache/huggingface/transformers/a8041bf617d7f94ea26d15e218abd04afc2004805632abc0ed2066aa16d50d04.faf6ea826ae9c5867d12b22257f9877e6b8367890837bd60f7c54a29633f7f2f', map_location='cpu')
    22    241.7 MiB   -419.7 MiB           1       del model

To supplement, I also observed that when running hf_load multiple times, the memory usage pattern was rather hard to explain.

Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
    19    239.7 MiB    239.7 MiB           1   @profile
    20                                         def multiple_hf_load():
    21    681.0 MiB    441.3 MiB           1       hf_load()
    22   1008.8 MiB    327.9 MiB           1       hf_load()
    23    919.4 MiB    -89.4 MiB           1       hf_load()
    24    993.6 MiB     74.2 MiB           1       hf_load()
    25    995.9 MiB      2.2 MiB           1       hf_load()
    26    992.9 MiB     -3.0 MiB           1       hf_load()
Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
    34    240.8 MiB    240.8 MiB           1   @profile
    35                                         def multiple_direct_load():
    36    242.0 MiB      1.1 MiB           1       direct_load()
    37    241.9 MiB     -0.1 MiB           1       direct_load()
    38    241.9 MiB      0.0 MiB           1       direct_load()
    39    241.9 MiB      0.0 MiB           1       direct_load()
    40    241.9 MiB      0.0 MiB           1       direct_load()
    41    241.9 MiB      0.0 MiB           1       direct_load()

Memory increased during the first two calls, but did not keep increasing from the third call onwards. I wonder how this could be explained.

P.S. I also attached the corresponding case for multiple direct_load calls above; no increment was observed.

Supplementary information

OS: 5.10.60.1-microsoft-standard-WSL2, 4.15.0-1113-azure #126~16.04.1-Ubuntu
Python: 3.8.12
PyTorch: 1.11.0
Transformers: 4.21.2

@LysandreJik

LysandreJik commented 2 years ago

@ydshieh has been looking into memory leaks as well recently and might have some insights for you!

ydshieh commented 2 years ago

Hi @tobyych

Could you also try adding gc.collect() after del model in both loading methods, and see what memory usage you get once gc.collect() is done? You will have to import gc.
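
Something along these lines (a minimal sketch based on your snippet above):

import gc

from transformers import AutoModelForMaskedLM
from memory_profiler import profile

@profile
def hf_load():
    # bert-base-uncased: 421MB on disk
    model = AutoModelForMaskedLM.from_pretrained("bert-base-uncased")
    del model
    gc.collect()  # force a collection right after dropping the reference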

tobyych commented 2 years ago

Hi @ydshieh,

Tried to add gc.collect() after del model.

For single hf_load,

Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
    14    245.5 MiB    245.5 MiB           1   @profile
    15                                         def hf_load():
    16                                             # bert-base-uncased: 421MB on disk
    17   1103.9 MiB    858.4 MiB           1       model = AutoModelForMaskedLM.from_pretrained("bert-base-uncased")
    18    686.2 MiB   -417.7 MiB           1       del model
    19    266.5 MiB   -419.7 MiB           1       gc.collect()

For single direct_load,

Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
    32    240.8 MiB    240.8 MiB           1   @profile
    33                                         def direct_load():
    34    661.5 MiB    420.7 MiB           1       model = torch.load('/home/toby/.cache/huggingface/transformers/a8041bf617d7f94ea26d15e218abd04afc2004805632abc0ed2066aa16d50d04.faf6ea826ae9c5867d12b22257f9877e6b8367890837bd60f7c54a29633f7f2f', map_location='cpu')
    35    241.9 MiB   -419.6 MiB           1       del model
    36    241.9 MiB      0.0 MiB           1       gc.collect()

For multiple hf_load,

Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
    22    241.4 MiB    241.4 MiB           1   @profile
    23                                         def multiple_hf_load():
    24    263.1 MiB     21.6 MiB           1       hf_load()
    25    921.1 MiB    658.1 MiB           1       hf_load()
    26    263.3 MiB   -657.9 MiB           1       hf_load()
    27    263.3 MiB      0.0 MiB           1       hf_load()
    28    263.3 MiB      0.0 MiB           1       hf_load()
    29    263.3 MiB      0.0 MiB           1       hf_load()

For multiple direct_load,

Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
    38    240.8 MiB    240.8 MiB           1   @profile
    39                                         def multiple_direct_load():
    40    241.9 MiB      1.1 MiB           1       direct_load()
    41    242.0 MiB      0.1 MiB           1       direct_load()
    42    242.0 MiB      0.0 MiB           1       direct_load()
    43    242.0 MiB      0.0 MiB           1       direct_load()
    44    242.0 MiB      0.0 MiB           1       direct_load()
    45    248.0 MiB      6.0 MiB           1       direct_load()

With the explicit garbage collection, a single hf_load released the remaining memory that wasn't released previously. However, I'm still not sure why it used roughly twice the memory of direct_load.

As for multiple hf_load, the results looked much better after adding the explicit garbage collection, as the memory used after several loads dropped to 263.3 MiB instead of the ~1000 MiB seen previously.

Any clue as to what wasn't collected by the GC previously?

ydshieh commented 2 years ago

Hi, @tobyych

Glad to know gc.collect() works.

I don't know what wasn't collected by the GC previously. In general, (I believe) it's not easy to know exactly what the GC has done and at which point in time. If you are able to investigate and share your findings, that would be great.
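
One possible way to investigate (just a sketch, I haven't run it for this exact case): tell the GC to keep whatever it would otherwise free in gc.garbage, then group those objects by type.

import gc
from collections import Counter

from transformers import AutoModelForMaskedLM

gc.set_debug(gc.DEBUG_SAVEALL)  # collected objects are kept in gc.garbage instead of being freed

model = AutoModelForMaskedLM.from_pretrained("bert-base-uncased")
del model
gc.collect()

# count what ended up in gc.garbage, grouped by type name
print(Counter(type(o).__name__ for o in gc.garbage).most_common(20))

gc.set_debug(0)       # restore default GC behaviour
gc.garbage.clear()    # drop the saved objects so they can actually be freed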

I will try to check the double memory part.

tobyych commented 2 years ago

@ydshieh, I tried to inspect the objects still tracked by the GC between del model and gc.collect() using the code below, where I aggregated the sizes of the tracked Python objects by their modules. I filtered for objects from the transformers package; this might be a starting point for you to see which parts of the code are involved.

import gc
import sys
from collections import defaultdict

from transformers import AutoModelForMaskedLM

def hf_load():
    # bert-base-uncased: 421MB on disk
    model = AutoModelForMaskedLM.from_pretrained("bert-base-uncased")
    del model
    # aggregate sizes (in bytes) of objects still tracked by the GC, keyed by module
    d = defaultdict(int)
    for o in gc.get_objects():
        try:
            d[o.__module__] += sys.getsizeof(o)
        except Exception:  # objects without a usable __module__ attribute
            d["others"] += sys.getsizeof(o)
    # report only the objects coming from the transformers package
    for k, v in d.items():
        if type(k) is str and k.startswith("transformers"):
            print(k, v)
    gc.collect()

This printed the following (sizes in bytes):

transformers.modeling_utils 19120
transformers.utils.doc 1224
transformers.utils.versions 408
transformers.utils.logging 6664
transformers.utils.import_utils 21552
transformers.utils.generic 10363
transformers.utils.hub 7928
transformers.utils 136
transformers.dependency_versions_check 136
transformers.utils.dummy_speech_objects 2400
transformers.utils.dummy_tensorflow_text_objects 1200
transformers.utils.dummy_sentencepiece_and_speech_objects 1200
transformers.utils.dummy_timm_objects 4008
transformers.utils.dummy_scatter_objects 6136
transformers.utils.dummy_tf_objects 390408
transformers.utils.dummy_flax_objects 187200
transformers.dynamic_module_utils 1224
transformers.tokenization_utils_base 21309
transformers.tokenization_utils 6480
transformers.models.t5.tokenization_t5 3104
transformers.convert_slow_tokenizer 42888
transformers.tokenization_utils_fast 4328
transformers.models.t5.tokenization_t5_fast 1744
transformers.configuration_utils 6640
transformers.models.auto.configuration_auto 7184
transformers.models.auto.auto_factory 22544
transformers.models.auto.modeling_auto 27936
transformers.onnx.utils 1656
transformers.onnx.config 9288
transformers.models.bert.configuration_bert 2232
transformers.models.albert.configuration_albert 2400
transformers.models.bart.configuration_bart 3216
transformers.models.big_bird.configuration_big_bird 2400
transformers.models.roberta.configuration_roberta 2400
transformers.models.camembert.configuration_camembert 2264
transformers.models.convbert.configuration_convbert 2400
transformers.models.data2vec.configuration_data2vec_text 2400
transformers.models.deberta.configuration_deberta 2672
transformers.models.deberta_v2.configuration_deberta_v2 2672
transformers.models.distilbert.configuration_distilbert 2400
transformers.models.electra.configuration_electra 2400
transformers.models.xlm.configuration_xlm 2400
transformers.models.flaubert.configuration_flaubert 2400
transformers.models.fnet.configuration_fnet 1200
transformers.models.funnel.configuration_funnel 1744
transformers.models.ibert.configuration_ibert 2400
transformers.models.layoutlm.configuration_layoutlm 2672
transformers.models.longformer.configuration_longformer 1200
transformers.models.luke.configuration_luke 1200
transformers.models.mbart.configuration_mbart 3216
transformers.models.megatron_bert.configuration_megatron_bert 1200
transformers.models.mobilebert.configuration_mobilebert 2400
transformers.models.mpnet.configuration_mpnet 1200
transformers.models.mvp.configuration_mvp 1200
transformers.models.nezha.configuration_nezha 1200
transformers.models.nystromformer.configuration_nystromformer 1200
transformers.feature_extraction_utils 5256
transformers.models.perceiver.configuration_perceiver 2672
transformers.models.qdqbert.configuration_qdqbert 1200
transformers.models.reformer.configuration_reformer 1200
transformers.models.rembert.configuration_rembert 1200
transformers.models.roformer.configuration_roformer 2400
transformers.models.squeezebert.configuration_squeezebert 2400
transformers.models.tapas.configuration_tapas 1200
transformers.models.wav2vec2.configuration_wav2vec2 1336
transformers.models.xlm_roberta.configuration_xlm_roberta 2264
transformers.models.xlm_roberta_xl.configuration_xlm_roberta_xl 2400
transformers.models.yoso.configuration_yoso 1200
transformers.activations 14496
transformers.modeling_outputs 45024
transformers.deepspeed 3896
transformers.generation_beam_constraints 9944
transformers.generation_beam_search 6568
transformers.generation_logits_process 25728
transformers.generation_stopping_criteria 6752
transformers.pytorch_utils 2288
transformers.generation_utils 17464
transformers.models.bert.modeling_bert 41152
ydshieh commented 2 years ago

@tobyych

I opened PR #18832, which could solve this issue. Note that this is not a real memory issue, however: the GC usually makes its own decisions about when to collect. But it's not bad if we can release some memory earlier.
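
As a rough illustration of the general idea of "releasing memory earlier" (not the actual PR diff, just the pattern): drop intermediate objects such as a loaded state dict as soon as they have been consumed, and trigger a collection instead of waiting for the GC to get around to it.

import gc

import torch
from torch import nn

def load_weights_and_trim(model: nn.Module, checkpoint_path: str) -> nn.Module:
    # Hypothetical helper, not the transformers implementation: load_state_dict copies
    # the values into the model's own parameters, so the state dict is a second full
    # copy of the weights that can be dropped as soon as it has been consumed.
    state_dict = torch.load(checkpoint_path, map_location="cpu")
    model.load_state_dict(state_dict)
    del state_dict
    gc.collect()  # release the temporary copy now rather than at some later GC pass
    return model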

tobyych commented 2 years ago

Thanks @ydshieh!