Lightning-AI / pytorch-lightning

Pretrain, finetune ANY AI model of ANY size on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

Refactor `_SubprocessScriptLauncher` process launching strategy #17248

Open ymohamedahmed opened 1 year ago

ymohamedahmed commented 1 year ago

Outline & Motivation

Hi all, firstly thanks for your work on the library, it's great!

Existing code

The _SubprocessScriptLauncher results in a process hierarchy as follows, for N GPUs:

LocalRank=0
├── LocalRank=1
├── LocalRank=2
├── LocalRank=3
├── ...
└── LocalRank=N-1

where LocalRank=0 launches each of LocalRank=1 to LocalRank=N-1 via subprocess.Popen.
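
For context, a heavily simplified sketch of that launching step (the real command construction lives in _basic_subprocess_cmd / _hydra_subprocess_cmd; the values here are illustrative only):

import os
import subprocess
import sys

num_processes = 4  # e.g. one process per GPU
procs = []
for local_rank in range(1, num_processes):  # rank 0 is the current process itself
    env = os.environ.copy()
    env["LOCAL_RANK"] = str(local_rank)
    # re-launch the same training script for each remaining rank
    procs.append(subprocess.Popen([sys.executable] + sys.argv, env=env))
# rank 0 then carries on into training; the children are not monitored afterwards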

Issues

This can make a few things challenging, namely:

Proposal

There are two main approaches to rectify this:

  1. Add a thread in LocalRank=0 that checks the status of the remaining processes and tidies up all the processes on a failure. Not sure I like this, however, since the thread isn't strictly speaking guaranteed to execute due to the GIL.
  2. Refactor process structure to be something like
    Watcher
    ├── LocalRank=0
    ├── LocalRank=1
    ├── LocalRank=2
    ├── LocalRank=3
    ├── ...
    └── LocalRank=N-1

    where the role of the Watcher process is to a) launch the processes for each of the local ranks, b) monitor the status of those processes, and c) terminate all the local-rank processes when any one of them terminates (and exit with the same exit code).

This would address all of the issues above and we could also explicitly set LOCAL_RANK=0 in the zeroth process.

Pitch

I feel that proposal 2 would add some robustness to multi-GPU training, especially in the cases described above.

I would be happy to submit a PR that does this either in a new _Launcher implementation or as a refactor of _SubprocessScriptLauncher, if this sounds reasonable. Thanks!

Additional context

No response

cc @awaelchli @justusschock @tchaton @borda @carmocca

ymohamedahmed commented 1 year ago

Related issues:

awaelchli commented 1 year ago

@ymohamedahmed I think your concerns are valid, but I'm not sure we should really try to "rebuild" torchrun/torchelastic. I've been thinking for a while that we could eventually replace the implementation in the SubprocessLauncher with a simple call to torchrun. We already do this for the "lightning run model" command in Fabric. WDYT?

ymohamedahmed commented 1 year ago

Yeah that's fair, totally understand that you don't want to diverge too much here from Torch. AFAIK torchrun wouldn't address the above, but torchelastic might.

Namely the local_elastic_agent, but it's multiprocessing-based, so it would bring all the usual "<blah> can't be pickled" headaches and thereby wouldn't match the existing subprocess-based _SubprocessScriptLauncher.

Edit: I think torchrun would solve this since torch.multiprocessing explicitly propagates errors between processes, but naturally it would be fundamentally different to the existing subprocess-based approach.
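
For reference, this is the error propagation being referred to: torch.multiprocessing.spawn joins its workers and re-raises a child failure in the parent. A minimal sketch:

import torch.multiprocessing as mp


def worker(rank: int) -> None:
    if rank == 1:
        raise RuntimeError("boom")


if __name__ == "__main__":
    try:
        mp.spawn(worker, nprocs=2)  # joins both workers
    except Exception as err:
        # torch surfaces the child's traceback in the parent process
        print(f"worker failed: {err}")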

ymohamedahmed commented 1 year ago

In case it helps @awaelchli, DeepSpeed implements the proposal above here so there is prior art!

awaelchli commented 1 year ago

Hey @ymohamedahmed Maybe I'm misinformed here but I thought that torchelastic, the previous standalone package, was deprecated and then moved into torch as torchrun. https://pytorch.org/docs/stable/elastic/run.html (elastic launch).

For your proposal under "Defunct/zombie processes", is this the same as #16410 / #16204? If you feel very confident in these, I would be happy to review a proof of concept as a separate subclass of _Launcher.

ymohamedahmed commented 1 year ago

Hey @awaelchli, yes you're right, they have been merged!

Re. https://github.com/Lightning-AI/lightning/issues/16410 / https://github.com/Lightning-AI/lightning/pull/16204 - it would achieve a similar outcome, but it wouldn't detect deadlocks and would be more similar to the DeepSpeed approach.

It would be conceptually a bit simpler (i.e. no need to touch files on each rank etc.); it would just use the native functionality in subprocess to check if a process has terminated or not.

I will get back to you with a PoC of what it would look like as a subclass of _Launcher. Thanks 👍

ymohamedahmed commented 1 year ago

@awaelchli

Something like the following is what I had in mind, largely just copied from subprocess_script.py with some minor tweaks:


import os
import signal
import subprocess
import sys
import time
from typing import Any, Callable, List, Optional

# _Launcher, ClusterEnvironment, _SIGNUM, _HYDRA_AVAILABLE, log,
# _basic_subprocess_cmd and _hydra_subprocess_cmd are the same names
# already used in subprocess_script.py


class _ManagedSubprocessScriptLauncher(_Launcher):

    def __init__(self, cluster_environment: ClusterEnvironment, num_processes: int, num_nodes: int, process_sampling_period: float) -> None:
        super().__init__()
        self.cluster_environment = cluster_environment
        self.num_processes = num_processes
        self.num_nodes = num_nodes
        self.procs: List[subprocess.Popen] = []
        self.process_sampling_period = process_sampling_period

    def launch(self, function: Callable, *args: Any, trainer: Optional["pl.Trainer"] = None, **kwargs: Any) -> Any:
        if not self.cluster_environment.creates_processes_externally:
            self._call_children_scripts()

        # blocks until a worker fails or all workers finish successfully
        self._monitor_processes(self.process_sampling_period)

        # Q: what are we supposed to do with `function` here?

    def _call_children_scripts(self) -> None: 
        self._check_can_spawn_children()
        self.procs = []  # reset in case it's called twice

        # DDP Environment variables
        os.environ["MASTER_ADDR"] = self.cluster_environment.main_address
        os.environ["MASTER_PORT"] = str(self.cluster_environment.main_port)

        # allow the user to pass the node rank
        os.environ["NODE_RANK"] = str(self.cluster_environment.node_rank())
        os.environ["LOCAL_RANK"] = str(self.cluster_environment.local_rank())
        os.environ["WORLD_SIZE"] = f"{self.num_processes * self.num_nodes}"

        for local_rank in range(self.num_processes):  # CHANGE: launch all N ranks; the watcher itself does not train
            env_copy = os.environ.copy()
            env_copy["LOCAL_RANK"] = f"{local_rank}"

            # remove env var if global seed not set
            if os.environ.get("PL_GLOBAL_SEED") is None and "PL_GLOBAL_SEED" in env_copy:
                del env_copy["PL_GLOBAL_SEED"]

            hydra_in_use = False
            cwd: Optional[str] = None
            if _HYDRA_AVAILABLE:
                from hydra.core.hydra_config import HydraConfig

                hydra_in_use = HydraConfig.initialized()

            if hydra_in_use:
                command, cwd = _hydra_subprocess_cmd(local_rank)
            else:
                command = _basic_subprocess_cmd()

            new_process = subprocess.Popen(command, env=env_copy, cwd=cwd)
            self.procs.append(new_process)

    def _monitor_processes(self, sampling_period: float) -> None:
        living_processes = set(self.procs)
        while len(living_processes):
            for process in list(living_processes):
                if process.poll() is not None:
                    # the process has terminated
                    living_processes.remove(process)
                    if process.returncode != 0:
                        # a worker failed: bring everything down with its exit code
                        self.kill(signal.SIGTERM, process.returncode)
            time.sleep(sampling_period)

    def kill(self, signum: _SIGNUM, exitcode: int) -> None:
        for proc in self.procs:
            log.info(f"pid {os.getpid()} killing {proc.pid} with {signum}")
            # this skips subprocesses already terminated
            proc.send_signal(signum)
        sys.exit(exitcode)

Notice that _monitor_processes won't return until either a worker fails or all processes terminate successfully.

I'm not sure I totally understand how we should behave with respect to the function arg passed into launch in this case.

If we launch the script, will it get called on LOCAL_RANK=0 anyway?

Let me know if that makes sense. If so, I'll raise a proper PR with tests + docs ⚡ :grin:

awaelchli commented 1 year ago

Hey

Our launcher only creates N-1 new processes. The main process that creates them then also participates in training, so in total there are only ever N processes. The main process can't block: it can't run _monitor_processes in a blocking way, otherwise the other N-1 processes will all wait and the program stalls.

The design here is very intentional and I'm pretty sure we don't want to change this. Having exactly N processes all in sync is easy to reason about and makes transferring/broadcasting state straightforward. The user does not need to distinguish between a special main process and worker processes.

I think you'll have to find a different way of monitoring the processes. You could perhaps use a Thread that runs alongside all the worker processes.
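
Something along these lines, for example (a rough sketch only; the helper name and wiring are illustrative, not an existing Lightning API):

import os
import signal
import subprocess
import threading
import time
from typing import List


def _monitor_children(procs: List[subprocess.Popen], sampling_period: float = 5.0) -> None:
    # runs in a daemon thread inside the rank-0 process, alongside training
    while True:
        for proc in procs:
            if proc.poll() is not None and proc.returncode != 0:
                # a child failed: signal the surviving children, then the main process
                for other in procs:
                    if other.poll() is None:
                        other.send_signal(signal.SIGTERM)
                os.kill(os.getpid(), signal.SIGTERM)
                return
        time.sleep(sampling_period)


# in the launcher, after the child processes have been created:
# threading.Thread(target=_monitor_children, args=(self.procs,), daemon=True).start()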

carmocca commented 1 year ago

Regarding using torchrun, we could also replace/refactor our launchers to use the torch.distributed.elastic APIs as much as possible, which will give us all their features that we haven't implemented for free but still not do a black-box call to torchrun
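
For illustration, the elastic APIs can already be driven programmatically; a minimal sketch (the rendezvous settings below are placeholders):

from torch.distributed.launcher.api import LaunchConfig, elastic_launch


def train(dummy_arg: str) -> int:
    # the agent sets LOCAL_RANK, RANK, WORLD_SIZE, MASTER_ADDR, MASTER_PORT, etc.
    return 0


config = LaunchConfig(
    min_nodes=1,
    max_nodes=1,
    nproc_per_node=4,  # one worker per GPU
    run_id="demo",
    rdzv_backend="c10d",
    rdzv_endpoint="localhost:29400",
    max_restarts=0,
)
results = elastic_launch(config, train)("some-arg")  # dict of local rank -> return value

Note that with a callable entrypoint the agent spawns workers via torch.multiprocessing; IIRC passing a command string instead makes it use subprocesses, which is closer to the current launcher.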

ymohamedahmed commented 1 year ago

Hey both,

@awaelchli I thought launcher.launch was called once per process, in which case the above would work as long as I added something like

if os.environ.get("LOCAL_RANK") is None:
    self._monitor_processes(sampling_period)
else:
    function(*args, **kwargs)

in the launch function. Then the blocking function call wouldn’t be a problem?

Please correct me if I'm wrong. FWIW torch.multiprocessing uses N+1 processes as well in order to resolve the aforementioned issues.

Using N processes with a thread per process is possible but more complicated. Mostly because we can't distinguish between a successfully terminated process and a failed process (since only the parent can read the exit code of a process in UNIX). As a result, we would need to include an additional mechanism to avoid terminating all the processes prematurely if a single process succeeds.

Something like,

command = command + f" && touch /tmp/succeeded/{local_rank}"  # requires running the command through a shell
subprocess.Popen(command, shell=True, ...)

which would touch the above file iff the worker has succeeded. Each of the N threads would then check, for every other process, that it is either still running or that its success file exists. If that condition is not satisfied, the thread triggers an exit. You could also use such a mechanism to communicate particular exit codes if that's desirable for users.
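
A rough sketch of that per-rank watchdog thread (the file layout and the way peer PIDs are shared are assumptions for illustration, e.g. via environment variables):

import os
import signal
import time
from typing import Dict

SUCCESS_DIR = "/tmp/succeeded"  # each rank's command touches SUCCESS_DIR/<rank> on success


def _pid_alive(pid: int) -> bool:
    # only the parent can reap the exit code, but any process can check liveness
    try:
        os.kill(pid, 0)
        return True
    except OSError:
        return False


def _watch_peers(peer_pids: Dict[int, int], sampling_period: float = 5.0) -> None:
    # runs as a daemon thread in every rank's process
    while True:
        for rank, pid in peer_pids.items():
            succeeded = os.path.exists(os.path.join(SUCCESS_DIR, str(rank)))
            if not _pid_alive(pid) and not succeeded:
                # a peer died without touching its success file: tear ourselves down
                os.kill(os.getpid(), signal.SIGTERM)
                return
        time.sleep(sampling_period)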

IMHO this becomes more complicated than the N+1 process mechanism which would largely not be of any concern for the users. I disagree that it would complicate broadcasting state, especially since this is already used by deepspeed and torch.multiprocessing successfully! Further this allows us to use native psutil/OS functionality rather than having to use files. Just my two cents though! :grin:

@carmocca I think it's great that Lightning provides a DDP launching mechanism that doesn't rely on Python multiprocessing. I think most of the distributed.elastic API is geared towards torch.multiprocessing, which some users may find problematic (hence DDPStrategy being the default!). I'll have to take another look at the elastic API though :smile:

tbenst commented 1 year ago

Just ran into this issue when trying to set up DistributedSampler myself. This fails on LocalRank=0 with "RuntimeError: Default process group has not been initialized, please make sure to call init_process_group.", but likely succeeds on LocalRank=1, where the setup has been done.
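
For context, DistributedSampler looks up the default process group when num_replicas/rank aren't passed, which is why it fails before init_process_group has run; a minimal reproduction:

import torch
from torch.utils.data import DistributedSampler, TensorDataset

dataset = TensorDataset(torch.arange(10))

# Without an initialized default process group this raises:
# RuntimeError: Default process group has not been initialized, ...
# sampler = DistributedSampler(dataset)

# Passing num_replicas and rank explicitly avoids the lookup entirely
sampler = DistributedSampler(dataset, num_replicas=2, rank=0)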

I like proposal 2, or better yet integrating torchrun, which has quickly become the standard. Standardizing that you must launch a PyTorch Lightning script with torchrun for DDP seems like the solution most in step with the PyTorch ecosystem.

Edit: for anyone who finds this, here's the cause of my issue: https://github.com/Lightning-AI/lightning/discussions/7573#discussioncomment-883581