huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Mamba excessive memory usage #30024

Closed GooseIt closed 5 months ago

GooseIt commented 7 months ago

System Info

Who can help?

@ArthurZucker

Information

Tasks

Reproduction

https://www.kaggle.com/code/gooseit/mambatransformers

Expected behavior

The amount of CUDA memory allocated by the run should be < 25 MB:
from the model: 1.4M parameters × 4 bytes (float32) × 4 (param + grad + Adam statistics) ≈ 23 MB
from the batch: 16 × 512 × 128 ≈ 1.2 MB

Not the 5.6 GB shown in the notebook.
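
A back-of-the-envelope version of that estimate, as a minimal sketch (my own numbers, assuming float32 everywhere and Adam keeping two moment buffers per parameter; the batch term is just an element count, an order-of-magnitude figure):

# Rough expected CUDA memory budget, not measured values.
n_params = 1.4e6        # parameter count quoted above
bytes_per_value = 4     # float32
copies = 4              # parameters + gradients + 2 Adam moment buffers
model_mb = n_params * bytes_per_value * copies / 1e6
print(f"model state: ~{model_mb:.0f} MB")            # ~22 MB

batch, seq_len, hidden = 16, 512, 128
n_batch_elems = batch * seq_len * hidden              # ~1e6 values, ~1 MB order of magnitude
print(f"batch activations: ~{n_batch_elems / 1e6:.1f}M values")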

famishedrover commented 7 months ago

Before I get into the internals of the Mamba implementation, this may be a useful thread.

The following snippet keeps a check on memory:

import torch  # assumes `model` is a Mamba model already moved to CUDA

# Run two forward passes, freeing the input and output tensors each time,
# then report the peak CUDA memory so far.
for ix in range(2):
    inp = torch.randint(0, 128, (2, 512)).to('cuda')
    out = model(inp, use_cache=False)
    del out
    del inp
print(torch.cuda.max_memory_allocated())

# Repeat: if memory leaked across forward passes, this peak would keep growing.
for ix in range(2):
    inp = torch.randint(0, 128, (2, 512)).to('cuda')
    out = model(inp, use_cache=False)
    del out
    del inp
print(torch.cuda.max_memory_allocated())

and I get the following output:

513157120
513157120

After only loading the modules (imports), the result of torch.cuda.max_memory_allocated() is 0 (expected).
After loading the model to RAM, the result of torch.cuda.max_memory_allocated() is 0 (expected).
After model = model.to('cuda'), the result of torch.cuda.max_memory_allocated() is 5529600, ~5.5 MB.
After a forward pass using a dummy batch as above, the result is 513157120, ~0.5 GB.

Even with the del statements, the usage remains the same.
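
For anyone wanting to reproduce those stages end to end, here is a minimal sketch (the tiny MambaConfig below is an arbitrary choice on my part, not the notebook's exact configuration):

# Measure peak CUDA memory after each stage described above.
import torch
from transformers import MambaConfig, MambaForCausalLM

def report(stage):
    print(f"{stage}: max_memory_allocated = {torch.cuda.max_memory_allocated()} bytes")

report("after imports")                           # expected: 0

config = MambaConfig(vocab_size=128, hidden_size=128, num_hidden_layers=4)  # arbitrary small config
model = MambaForCausalLM(config)                  # model still on CPU
report("after building the model on CPU")         # expected: 0

model = model.to("cuda")
report("after model.to('cuda')")                  # roughly the parameter size

inp = torch.randint(0, 128, (2, 512), device="cuda")
out = model(inp, use_cache=False)
report("after one forward pass")                  # the surprisingly large number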

GooseIt commented 7 months ago

@famishedrover

The issue does not seem to be one of those described in the thread. Rather than growing gradually each epoch, memory spikes during the forward pass, and the spike amplitude seems to be consistent between epochs.

I've updated the Kaggle notebook - please check Version 6. It now includes a torch.cuda memory profile image, as well as memory usage prints in different setups, backing my suspicion that the problem is in the model itself.
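
One way to produce that kind of profile, as a minimal sketch assuming a recent PyTorch release (the _record_memory_history recorder is technically a private API) and a `model` already on CUDA:

# Record an allocation timeline around one forward pass; the snapshot file can
# be inspected interactively at https://pytorch.org/memory_viz
import torch

torch.cuda.memory._record_memory_history(max_entries=100000)

inp = torch.randint(0, 128, (16, 512), device="cuda")
out = model(inp, use_cache=False)      # `model` assumed to be the Mamba model on CUDA

torch.cuda.memory._dump_snapshot("mamba_forward_snapshot.pickle")
torch.cuda.memory._record_memory_history(enabled=None)   # stop recording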

GooseIt commented 7 months ago

@famishedrover

I'm gently reminding you about this issue in case you've forgotten about it.

famishedrover commented 7 months ago

I agree, there is something going on with the forward pass itself. The del hack at least prevents linear growth across multiple forward passes (but a single pass still takes an unreasonable amount of memory). I will take a look at this later in the week (limited bandwidth).

v4ndi commented 7 months ago

@famishedrover @ArthurZucker @younesbelkada @amyeroberts @koayon I've encountered the same issue as described above. I tried using the transformers Mamba implementation instead of state-spaces/mamba. If necessary, I can provide example code. Please fix this issue; I would really appreciate it.

ArthurZucker commented 7 months ago

You would need to provide a reproducer. If you try the original state-space model and have the kernels, the HF model should not really change much. You should make sure you are testing equivalent things: gradient or not, fast path or not, use cache or not.
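
A minimal sketch of what such an equivalent comparison could look like (the checkpoint name and shapes below are illustrative, not taken from the notebook):

# Compare peak memory with and without gradients, keeping dtype, batch shape,
# and use_cache identical across the two runs.
import torch
from transformers import MambaForCausalLM

model = MambaForCausalLM.from_pretrained("state-spaces/mamba-130m-hf").to("cuda")  # illustrative checkpoint
inp = torch.randint(0, 1000, (16, 512), device="cuda")

torch.cuda.reset_peak_memory_stats()
with torch.no_grad():                   # inference-style forward: no activations kept
    model(inp, use_cache=False)
print("no grad :", torch.cuda.max_memory_allocated())

torch.cuda.reset_peak_memory_stats()
model(inp, use_cache=False)             # training-style forward: activations kept for backward
print("with grad:", torch.cuda.max_memory_allocated())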

GooseIt commented 6 months ago

@v4ndi Please provide the reproducer; it would be very helpful.

github-actions[bot] commented 6 months ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

ArthurZucker commented 5 months ago

related to #31116