huggingface / accelerate

🚀 A simple way to launch, train, and use PyTorch models on almost any device and distributed configuration, automatic mixed precision (including fp8), and easy-to-configure FSDP and DeepSpeed support
https://huggingface.co/docs/accelerate
Apache License 2.0

Speed up imports and add a CI #2845

Closed · muellerzr closed this 3 days ago

muellerzr commented 3 weeks ago

What does this PR do?

This PR introduces a CI that uses cProfile timings and my tuna-interpreter library to perform time-based import tests.

@stas00 reported that we were taking far too long to do basic things like `accelerate launch`, and tuna can help visualize why by creating import graphs that point us to what is taking too long:

[image: tuna import graph before the changes]

I wrote a small library called tuna-interpreter that takes the best parts of tuna and turns them into something parseable, so we can run CIs off of it.
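For anyone who wants to reproduce such a graph locally, tuna can visualize CPython's built-in import-time trace. A minimal sketch of generating that trace (the file name is just a placeholder) might look like this:

    import subprocess
    import sys

    # Capture CPython's built-in import-time trace for `import accelerate`.
    # `-X importtime` prints one line per imported module to stderr, with
    # self and cumulative timings in microseconds.
    result = subprocess.run(
        [sys.executable, "-X", "importtime", "-c", "import accelerate"],
        capture_output=True,
        text=True,
    )
    with open("import_profile.log", "w") as f:
        f.write(result.stderr)

Running `tuna import_profile.log` then renders an interactive graph like the one above.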

After using the tool:

[image: tuna import graph after the changes]

We can see a decrease of ~68% in import time.

How it works:

In its current form, we compare against a baseline torch import, since Accelerate relies on torch no matter what. The rule is that we should be no more than ~20% slower than the raw torch import overall; anything more means something has slipped in or there is a timing problem.

An example test looks like the following:

    def test_base_import(self):
        output = run_import_time("import accelerate")
        with open(f"{self.tmpdir}/base_results.log", "w") as f:
            f.write(output)
        data = read_import_profile(f"{self.tmpdir}/base_results.log")
        total_time = calculate_total_time(data)
        pct_more = total_time / self.pytorch_time
        # Base import should never be more than 20% slower than raw torch import
        err_msg = f"Base import is more than 20% slower than raw torch import ({pct_more * 100:.2f}%), please check the attached `tuna` profile:\n"
        sorted_data = sort_nodes_by_total_time(data)
        paths_above_threshold = get_paths_above_threshold(sorted_data, 0.1, max_depth=7)
        err_msg += f"\n{convert_list_to_string(paths_above_threshold)}"
        self.assertLess(pct_more, 1.2, err_msg)

Essentially, we:

  1. Measure the import time of a particular Python import by running it via `-c` in a subprocess (a rough sketch of the surrounding test scaffolding follows this list).
  2. Read the profile that was generated.
  3. Take all the nodes, sort them by time, and collect every path above an arbitrary threshold. Tweak this to your own discretion, since the right threshold and max_depth vary from library to library. The key with max_depth is that it should be deep enough to get past your own imports in the slowdown trace and show which external libraries you are really calling.
  4. If the import exceeds the expected slowdown percentage, write a note saying so and list the modules that were slowing it down.
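To make this concrete, below is a rough sketch of how the surrounding test scaffolding might establish the torch baseline (`self.pytorch_time`) that the example test compares against. The helper names are the ones used in the test above; the import path and the `setUpClass` details are assumptions, not necessarily what the PR actually does:

    import tempfile
    import unittest

    # NOTE: the import path here is an assumption; these helpers come from the
    # tuna-interpreter library and are the same ones used in the test above.
    from tuna_interpreter import (
        calculate_total_time,
        read_import_profile,
        run_import_time,
    )


    class ImportTimeTester(unittest.TestCase):
        @classmethod
        def setUpClass(cls):
            # Shared scratch directory for the generated profiles.
            cls.tmpdir = tempfile.mkdtemp()
            # Profile a bare `import torch` once; every accelerate import in
            # the tests is compared against this baseline.
            output = run_import_time("import torch")
            with open(f"{cls.tmpdir}/torch_results.log", "w") as f:
                f.write(output)
            data = read_import_profile(f"{cls.tmpdir}/torch_results.log")
            cls.pytorch_time = calculate_total_time(data)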

An example failure is below, where we can clearly see what module chain had the slowdown:

E       AssertionError: 1.3515017627366224 not less than 1.2 : Base import is more than 20% slower than raw torch import (135.15%), please check the attached `tuna` profile:
E       
E       main 0.973s
E       main->accelerate 0.961s
E       main->accelerate->accelerate.accelerator 0.959s
E       main->accelerate->accelerate.accelerator->torch 0.758s
E       main->accelerate->accelerate.accelerator->torch->torch._C 0.355s
E       main->accelerate->accelerate.accelerator->torch->torch._meta_registrations 0.154s
E       main->accelerate->accelerate.accelerator->torch->torch._meta_registrations->torch._decomp 0.109s
E       main->accelerate->accelerate.accelerator->accelerate.checkpointing 0.126s
E       main->accelerate->accelerate.accelerator->accelerate.checkpointing->accelerate.utils 0.125s
E       main->accelerate->accelerate.accelerator->accelerate.checkpointing->accelerate.utils->accelerate.utils.megatron_lm 0.107s
E       main->accelerate->accelerate.accelerator->accelerate.checkpointing->accelerate.utils->accelerate.utils.megatron_lm->transformers.modeling_outputs 0.106s

tests/test_imports.py:64: AssertionError

Or:

E       AssertionError: 1.8292819779293377 not less than 1.2 : Base import is more than 20% slower than raw torch import (182.93%), please check the attached `tuna` profile:
E       
E       main 1.324s
E       main->accelerate 1.308s
E       main->accelerate->accelerate.accelerator 1.307s
E       main->accelerate->accelerate.accelerator->torch 0.706s
E       main->accelerate->accelerate.accelerator->torch->torch._C 0.327s
E       main->accelerate->accelerate.accelerator->torch->torch._meta_registrations 0.152s
E       main->accelerate->accelerate.accelerator->torch->torch._meta_registrations->torch._decomp 0.108s
E       main->accelerate->accelerate.accelerator->accelerate.checkpointing 0.527s
E       main->accelerate->accelerate.accelerator->accelerate.checkpointing->accelerate.utils 0.526s
E       main->accelerate->accelerate.accelerator->accelerate.checkpointing->accelerate.utils->accelerate.utils.fsdp_utils 0.488s
E       main->accelerate->accelerate.accelerator->accelerate.checkpointing->accelerate.utils->accelerate.utils.fsdp_utils->torch.distributed.fsdp.fully_sharded_data_parallel 0.488s
E       main->accelerate->accelerate.accelerator->accelerate.checkpointing->accelerate.utils->accelerate.utils.fsdp_utils->torch.distributed.fsdp.fully_sharded_data_parallel->torch.distributed.fsdp 0.488s

tests/test_imports.py:64: AssertionError

If there are specific issues with using tuna-interpreter, please let me know. It's a very quickly hacked-together-but-working library for what we are doing, and I'm open to improving it further after we battle-test it.

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.

@SunMarc @BenjaminBossan @sayakpaul @Titus-von-Koeller @ArthurZucker @ydshieh @LysandreJik

muellerzr commented 3 weeks ago

Let me know how I can improve this tool further so we can roll it out to anyone at HF who wants to use it 🤗

muellerzr commented 3 weeks ago

@BenjaminBossan what's missing here is the initial tuna check, which is why it's required (see the workflow).

We can eventually look at gutting it, sure. However, I think having both the visual option for further debugging and the condensed output here is valuable, which is why we should still use tuna itself.

Edit: I may be wrong here, sorry.

muellerzr commented 3 weeks ago

@BenjaminBossan I fully removed the requirement for tuna there; however, I still think it's useful for further debugging, so I left that in as part of the test case class description :)

HuggingFaceDocBuilderDev commented 3 days ago

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

ydshieh commented 3 days ago

🔥