huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Speedup module imports #26308

Closed apoorvkh closed 1 year ago

apoorvkh commented 1 year ago

Feature request

Can we please consider importing the deepspeed module when needed, rather than in the import header of trainer.py?

Motivation

When deepspeed is installed, from transformers import Trainer takes a long time!

On my system that's 9 seconds!

>>> import timeit; timeit.timeit("from transformers import Trainer")
[2023-09-20 23:49:13,899] [INFO] [real_accelerator.py:158:get_accelerator] Setting ds_accelerator to cuda (auto detect)
8.906949461437762

I believe this import is the culprit. As we can see, it takes 8.5 seconds of the load time.

https://github.com/huggingface/transformers/blob/e3a4bd2bee212a2d0fd9f03b27fe7bfc1debe42d/src/transformers/trainer.py#L217

>>> timeit.timeit("from accelerate.utils import DeepSpeedSchedulerWrapper")
[2023-09-20 23:45:53,185] [INFO] [real_accelerator.py:158:get_accelerator] Setting ds_accelerator to cuda (auto detect)
8.525534554384649

This is quite cumbersome, because all scripts that import Trainer (e.g. even for typing) are impacted!
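A side note on measurement: `timeit.timeit` runs its statement one million times by default, and every repetition after the first hits the `sys.modules` cache, so the figures above are dominated by the first import. A minimal sketch of a cold-import timer that launches a fresh interpreter is shown below. The helper name `time_cold_import` is hypothetical, and `import json` is only a stand-in so the snippet runs without transformers installed.

```python
import subprocess
import sys
import time


def time_cold_import(statement: str) -> float:
    """Return wall-clock seconds needed to run `statement` in a fresh interpreter.

    A new process guarantees nothing is cached in sys.modules, so the result
    includes interpreter startup plus the full cost of the import.
    """
    start = time.perf_counter()
    subprocess.run([sys.executable, "-c", statement], check=True)
    return time.perf_counter() - start


# Baseline: interpreter startup with no extra imports.
baseline = time_cold_import("pass")
# Stand-in statement; replace with "from transformers import Trainer" locally.
elapsed = time_cold_import("import json")
print(f"startup: {baseline:.3f}s, startup + import: {elapsed:.3f}s")
```

Subtracting the baseline from the measured value gives a rough estimate of the import cost alone; single runs are noisy, so repeating the measurement a few times is advisable.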

Your contribution

Happy to submit a PR. We could make this a class variable or just import it directly at both places it's used.

https://github.com/huggingface/transformers/blob/e3a4bd2bee212a2d0fd9f03b27fe7bfc1debe42d/src/transformers/trainer.py#L2437-L2439

https://github.com/huggingface/transformers/blob/e3a4bd2bee212a2d0fd9f03b27fe7bfc1debe42d/src/transformers/trainer.py#L2508-L2514
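The "import it directly where it's used" option could look roughly like the sketch below. It is not the actual trainer.py code: `SchedulerUser` is a hypothetical stand-in for a class like Trainer, and the stdlib `fractions` module stands in for the expensive dependency so the snippet runs anywhere. The `TYPE_CHECKING` guard keeps the name available for annotations without paying the import cost at runtime.

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Evaluated only by static type checkers, never at runtime,
    # so it adds nothing to the import time of this module.
    from fractions import Fraction


class SchedulerUser:
    """Hypothetical stand-in for a class such as Trainer."""

    def make_ratio(self, numerator: int, denominator: int) -> "Fraction":
        # Deferred import: the cost is paid on the first call to this
        # method rather than when the enclosing module is imported.
        from fractions import Fraction

        return Fraction(numerator, denominator)


ratio = SchedulerUser().make_ratio(1, 2)
print(ratio)
```

After the first call the module sits in `sys.modules`, so repeating the import statement inside the method is only a dictionary lookup.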

ArthurZucker commented 1 year ago

Hey! Thanks for opening this issue. Are you using main? There was a PR recently to fix this, see #26090 and #26106.

apoorvkh commented 1 year ago

I am indeed using main (specifically, transformers[deepspeed] at commit 382ba67)!

apoorvkh commented 1 year ago

The code I mentioned above runs directly in the header of trainer.py. And, if I understand correctly, accelerate is not covered by the lazy imports in #26090.

https://github.com/huggingface/transformers/blob/e3a4bd2bee212a2d0fd9f03b27fe7bfc1debe42d/src/transformers/trainer.py#L203-L217

ArthurZucker commented 1 year ago

Cc @younesbelkada, I think you mentioned that accelerate is the bottleneck that we can't get rid of, no?

younesbelkada commented 1 year ago

Hi @apoorvkh, now that https://github.com/huggingface/accelerate/pull/1963 has been merged in accelerate, I think you can switch to accelerate main and see if it resolves your issue.

apoorvkh commented 1 year ago

Hey, thanks! I think that commit (https://github.com/huggingface/accelerate/commit/5dec654aaea0c92d4ccb7ad389fc33adcbbf79fc) reduces the runtime for the import from 8-9 seconds to 3-4 seconds (on my machine). That is still not ideal but is certainly more tolerable.

younesbelkada commented 1 year ago

Thanks!
Hm, I see, ok. I am curious which module takes so much time to import. Would you be able to run a quick benchmark with tuna and share the results here?

# benchmark
python -X importtime -c "import transformers" 2> transformers-import-profile.log

# visualize
tuna <path to log file>
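
If tuna is not at hand, the same `-X importtime` output can be summarized with a few lines of Python. The sketch below assumes the stderr format CPython prints for this flag (`import time: <self us> | <cumulative us> | <module>`); the helper name `slowest_imports` is hypothetical, and `import json` is a stand-in statement so the snippet runs without transformers installed.

```python
import subprocess
import sys


def slowest_imports(statement: str, top: int = 10) -> list[tuple[str, int]]:
    """Run `statement` under -X importtime and return the `top` modules
    with the largest cumulative import time, in microseconds."""
    proc = subprocess.run(
        [sys.executable, "-X", "importtime", "-c", statement],
        capture_output=True,
        text=True,
        check=True,
    )
    rows = []
    for line in proc.stderr.splitlines():
        if not line.startswith("import time:"):
            continue  # unrelated stderr output, e.g. library log lines
        parts = line[len("import time:"):].split("|")
        if len(parts) != 3:
            continue
        try:
            cumulative_us = int(parts[1])
        except ValueError:
            continue  # the header row has text instead of numbers
        # Leading spaces in the module column encode nesting; drop them.
        rows.append((parts[2].strip(), cumulative_us))
    rows.sort(key=lambda row: row[1], reverse=True)
    return rows[:top]


# Stand-in statement; use "from transformers import Trainer" locally.
for name, micros in slowest_imports("import json", top=5):
    print(f"{micros / 1e6:8.3f}s  {name}")
```

Cumulative times include nested imports, so a parent package and its slowest child will both appear near the top of the list.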

apoorvkh commented 1 year ago

For sure. That's a nice tool!

Really quickly, I found that from transformers import Trainer in particular takes about 4 seconds to import, whereas import transformers is actually faster (< 1 second).

We can see the result for from transformers import Trainer below:

[image: tuna import-time profile for from transformers import Trainer]

Also, for from transformers import TrainingArguments:

[image: tuna import-time profile for from transformers import TrainingArguments]

And we can compare to import transformers:

[image: tuna import-time profile for import transformers]

Seems like accelerate is no longer the biggest culprit. A lot of time is also spent importing torch.

My point is that we sometimes import these tools just for typing purposes, or in an interactive terminal for later use. From a developer perspective, it would be more convenient to have fast imports and move the time-consuming parts to the moment we actually init/use the modules (and are expecting to spend the time). Thanks!
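
One generic way to get this "pay on first use" behavior is the standard library's `importlib.util.LazyLoader`, sketched below following the recipe in the importlib documentation. This is only an illustration of the idea, not how transformers or accelerate implement their lazy imports; the helper name `lazy_import` is an assumption, and `colorsys` is a lightweight stand-in for a heavy module.

```python
import importlib.util
import sys


def lazy_import(name: str):
    """Return a module object whose code only executes on first attribute access."""
    if name in sys.modules:
        return sys.modules[name]  # already loaded, nothing to defer
    spec = importlib.util.find_spec(name)
    if spec is None or spec.loader is None:
        raise ModuleNotFoundError(f"No module named {name!r}")
    # Wrap the real loader so that execution is postponed.
    loader = importlib.util.LazyLoader(spec.loader)
    spec.loader = loader
    module = importlib.util.module_from_spec(spec)
    sys.modules[name] = module
    loader.exec_module(module)  # sets up the deferral; does not run the module body
    return module


# Cheap: the module body has not been executed yet (unless it was already loaded).
colorsys = lazy_import("colorsys")
# The first attribute access triggers the real import.
hue, saturation, value = colorsys.rgb_to_hsv(1.0, 1.0, 1.0)
print(hue, saturation, value)
```

The trade-off is that import errors and import latency surface at first use rather than at the top of the script, which can make failures harder to trace.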
