ymohamedahmed opened this issue 1 year ago
@ymohamedahmed I think your concerns are valid, but I'm not sure we should really try to "rebuild" torchrun/torchelastic. I've been thinking for a while that we could eventually replace the implementation in the SubprocessLauncher with a simple call to torchrun. We already do this for the "lightning run model" command in Fabric. WDYT?
Yeah that's fair, totally understand that you don't want to diverge too much here from Torch. AFAIK `torchrun` wouldn't address the above, but torchelastic might. Namely the `local_elastic_agent`, but it's multiprocessing-based, so it would bring all the headaches of `<blah> can't be pickled` etc., and thereby wouldn't match the existing `_SubprocessScriptLauncher`.
Edit: I think `torchrun` would solve this, since `torch.multiprocessing` explicitly propagates errors between processes, but naturally it would be fundamentally different from the existing subprocess-based approach.
In case it helps @awaelchli, DeepSpeed implements the proposal above here so there is prior art!
Hey @ymohamedahmed Maybe I'm misinformed here but I thought that torchelastic, the previous standalone package, was deprecated and then moved into torch as torchrun. https://pytorch.org/docs/stable/elastic/run.html (elastic launch).
For your proposal under "Defunct/zombie processes", is this the same as #16410 / #16204? If you feel very confident in these, I would be happy to review a proof of concept as a separate subclass of _Launcher.
Hey @awaelchli, yes you're right, they have been merged!
Re. https://github.com/Lightning-AI/lightning/issues/16410 / https://github.com/Lightning-AI/lightning/pull/16204: it would achieve a similar outcome, but it wouldn't detect deadlocks and would be more similar to the DeepSpeed approach. It would be conceptually a bit simpler (i.e. no need to touch files on each rank etc.); it would just use the native functionality in `subprocess` to check whether a process has terminated.
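To make that `subprocess`-based check concrete, here is a minimal stdlib-only sketch; the child command is just a stand-in for a failed worker, not Lightning code:

```python
import subprocess
import sys
import time

# Spawn a child that exits with a non-zero code (a stand-in for a failed worker).
proc = subprocess.Popen([sys.executable, "-c", "import sys; sys.exit(3)"])

# poll() returns None while the child is alive and the exit code once it has
# terminated, so the parent can detect failure without touching files on each rank.
while proc.poll() is None:
    time.sleep(0.1)

print(proc.returncode)  # the parent reads the exit code directly
```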
I will get back to you with a PoC of what it would look like as a subclass of `_Launcher`. Thanks 👍
@awaelchli
Something like the following is what I had in mind, largely just copied from `subprocess_script.py` with some minor tweaks:
```python
import os
import signal
import subprocess
import sys
import time
from typing import Any, Callable, List, Optional


class _ManagedSubprocessScriptLauncher(_Launcher):
    def __init__(
        self,
        cluster_environment: ClusterEnvironment,
        num_processes: int,
        num_nodes: int,
        process_sampling_period: float,
    ) -> None:
        super().__init__()
        self.cluster_environment = cluster_environment
        self.num_processes = num_processes
        self.num_nodes = num_nodes
        self.procs: List[subprocess.Popen] = []
        self.process_sampling_period = process_sampling_period

    def launch(self, function: Callable, *args: Any, trainer: Optional["pl.Trainer"] = None, **kwargs: Any) -> Any:
        if not self.cluster_environment.creates_processes_externally:
            self._call_children_scripts()
        self._monitor_processes(self.process_sampling_period)
        # Q: what are we supposed to do with `function` here?

    def _call_children_scripts(self) -> None:
        self._check_can_spawn_children()
        self.procs = []  # reset in case it's called twice

        # DDP environment variables
        os.environ["MASTER_ADDR"] = self.cluster_environment.main_address
        os.environ["MASTER_PORT"] = str(self.cluster_environment.main_port)
        # allow the user to pass the node rank
        os.environ["NODE_RANK"] = str(self.cluster_environment.node_rank())
        os.environ["LOCAL_RANK"] = str(self.cluster_environment.local_rank())
        os.environ["WORLD_SIZE"] = f"{self.num_processes * self.num_nodes}"

        for local_rank in range(self.num_processes):  # CHANGE: launch rank 0 as a subprocess too
            env_copy = os.environ.copy()
            env_copy["LOCAL_RANK"] = f"{local_rank}"

            # remove env var if global seed not set
            if os.environ.get("PL_GLOBAL_SEED") is None and "PL_GLOBAL_SEED" in env_copy:
                del env_copy["PL_GLOBAL_SEED"]

            hydra_in_use = False
            cwd: Optional[str] = None
            if _HYDRA_AVAILABLE:
                from hydra.core.hydra_config import HydraConfig

                hydra_in_use = HydraConfig.initialized()

            if hydra_in_use:
                command, cwd = _hydra_subprocess_cmd(local_rank)
            else:
                command = _basic_subprocess_cmd()

            new_process = subprocess.Popen(command, env=env_copy, cwd=cwd)
            self.procs.append(new_process)

    def _monitor_processes(self, sampling_period: float) -> None:
        living_processes = set(self.procs)
        while len(living_processes):
            for process in list(living_processes):
                if process.poll() is not None:
                    # process has died
                    living_processes.remove(process)
                    if process.returncode != 0:
                        self.kill(signal.SIGTERM, process.returncode)
            time.sleep(sampling_period)

    def kill(self, signum: _SIGNUM, exitcode: int) -> None:
        for proc in self.procs:
            log.info(f"pid {os.getpid()} killing {proc.pid} with {signum}")
            # this skips subprocesses already terminated
            proc.send_signal(signum)
        sys.exit(exitcode)
```
Notice that `_monitor_processes` won't return until either a worker fails or all processes terminate successfully.

I'm not sure I totally understand how we should behave with respect to the `function` arg passed into `launch` in this case. If we launch the script, will it get called on `LOCAL_RANK=0` anyway?
Let me know if that makes sense. If so, will raise a proper PR with tests + docs ⚡ :grin:
Hey
Our launcher only creates N-1 new processes. The main process that creates them then also participates in training, so in total there are only ever N processes. The main process can't block: it can't run `_monitor_processes` in a blocking way, otherwise the other N-1 processes will all wait and the program stalls.
The design here is very intentional and I'm pretty sure we don't want to change this. Having exactly N processes all in sync is easy to reason about and makes transferring/broadcasting state straightforward. The user does not need to distinguish between a special main process and worker processes.
I think you'll have to find a different way of monitoring the processes. You could perhaps use a Thread that runs alongside all the worker processes.
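A rough stdlib-only sketch of that thread-based suggestion; the dummy worker commands and the tear-everything-down policy are illustrative assumptions, not Lightning's actual launcher code:

```python
import subprocess
import sys
import threading
import time


def _monitor(procs, sampling_period=0.1):
    """Watch sibling worker processes from a daemon thread.

    Because this runs in a thread, the main process can keep participating
    in training instead of blocking on the monitor loop.
    """
    living = set(procs)
    while living:
        for proc in list(living):
            if proc.poll() is not None:
                living.remove(proc)
                if proc.returncode != 0:
                    # A worker failed: tear everything down (placeholder policy).
                    for p in procs:
                        p.terminate()
                    return


# Example: two dummy "workers", one of which fails.
procs = [
    subprocess.Popen([sys.executable, "-c", "import time; time.sleep(0.2)"]),
    subprocess.Popen([sys.executable, "-c", "import sys; sys.exit(1)"]),
]
watcher = threading.Thread(target=_monitor, args=(procs,), daemon=True)
watcher.start()
# ... the main process would run its own training here ...
watcher.join()
```

Note the GIL concern from the original proposal still applies here: the monitor thread is only scheduled when the main thread yields, so detection latency is not guaranteed.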
Regarding `torchrun`: we could also replace/refactor our launchers to use the `torch.distributed.elastic` APIs as much as possible, which would give us all of their features that we haven't implemented for free, while still not doing a black-box call to `torchrun`.
Hey both,
@awaelchli I thought `launcher.launch` was called once per process, in which case the above would work as long as I added something like

```python
if os.environ.get("LOCAL_RANK") is None:
    self._monitor_processes(sampling_period)
else:
    function(blah)
```

in the `launch` function. Then the blocking function call wouldn't be a problem?
Please correct me if I'm wrong. FWIW `torch.multiprocessing` uses N+1 processes as well, in order to resolve the aforementioned issues.
Using N processes with a thread per process is possible but more complicated. Mostly because we can't distinguish between a successfully terminated process and a failed process (since only the parent can read the exit code of a process in UNIX). As a result, we would need to include an additional mechanism to avoid terminating all the processes prematurely if a single process succeeds.
Something like,

```python
command = command + f" && touch /tmp/succeeded/{local_rank}"
subprocess.Popen(command, shell=True, ...)
```

which would touch the above file IFF the worker has succeeded. Each of the N monitoring threads then checks, for every other process, that it is either still running or that its marker file exists. If that condition is not satisfied, it triggers an exit. You could also use such a mechanism to communicate particular exit codes if that's desirable for users.
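To illustrate, a self-contained sketch of that marker-file idea, assuming a POSIX shell and using a temp directory instead of the `/tmp/succeeded` path above (all names here are illustrative):

```python
import os
import subprocess
import sys
import tempfile

# Each rank touches a marker file if, and only if, its command succeeds.
marker_dir = tempfile.mkdtemp()
procs = []
for local_rank, code in enumerate([0, 1]):  # rank 0 succeeds, rank 1 fails
    worker = f"{sys.executable} -c 'import sys; sys.exit({code})'"
    command = f"{worker} && touch {marker_dir}/{local_rank}"
    procs.append(subprocess.Popen(command, shell=True))

for p in procs:
    p.wait()

# A sibling process can't read these exit codes (only the parent can in UNIX),
# but it can check the marker files to tell success apart from failure.
succeeded = {rank for rank in range(2) if os.path.exists(f"{marker_dir}/{rank}")}
print(succeeded)
```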
IMHO this becomes more complicated than the N+1-process mechanism, which would largely not be of any concern to users. I disagree that it would complicate broadcasting state, especially since it is already used successfully by `deepspeed` and `torch.multiprocessing`! Further, it allows us to use native `psutil`/OS functionality rather than having to use files. Just my two cents though! :grin:
@carmocca I think it's great that Lightning provides a DDP launching mechanism that doesn't rely on Python multiprocessing. I think most of the `distributed.elastic` API is geared towards `torch.multiprocessing`, which some users may find problematic (hence `DDPStrategy` being the default!). I'll have to take another look at the elastic API though :smile:
Just ran into this issue when trying to set up `DistributedSampler` myself. This fails on `LOCAL_RANK=0` with `RuntimeError: Default process group has not been initialized, please make sure to call init_process_group.`, but likely succeeds on `LOCAL_RANK=1`, where the setup has been done.

I like proposal 2, or better yet integrating `torchrun`, which has quickly become the standard. Standardizing that you must launch a PyTorch Lightning script with `torchrun` for DDP seems like the solution most in step with the PyTorch ecosystem.

Edit: for anyone that finds this, here's the cause of my issue: https://github.com/Lightning-AI/lightning/discussions/7573#discussioncomment-883581
Outline & Motivation
Hi all, firstly thanks for your work on the library, it's great!
Existing code

The `_SubprocessScriptLauncher` results in a process hierarchy as follows, for N GPUs: `LocalRank=0` launches each of `LocalRank=1` to `LocalRank=N-1` via `subprocess.Popen`.

Issues

This can make a few things challenging, namely:

- If any of `LocalRank=1` to `N-1` fail, the process enters a zombie state, as `LocalRank=0` (its parent) never reads the exit code. This means that a container running multi-GPU training will not exit on the failure of a non-zero rank until DDP times out. For users where termination leads to freeing of expensive compute resources this is problematic.
- The `LOCAL_RANK` environment variable is only set on `LocalRank=1..N-1`; you can only infer `LocalRank=0` from the lack of an environment variable. This makes any code relying on this less robust, especially if it also runs in CPU-only environments where the environment variable is correctly missing.
- If `LocalRank=0` fails, then `LocalRank=1,..,N-1` are reassigned as children of `pid=1`. This makes cleaning up the processes a bit more challenging.

Proposal

There are two main approaches to rectify this:

1. A thread on `LocalRank=0` that checks the status of the remaining processes and tidies up all the processes on a failure. Not sure I like this, however, since the thread isn't strictly speaking guaranteed to execute due to the GIL.
2. A separate `Watcher` process, whose role is to a) launch the processes for each of the local ranks, b) monitor the status of those processes, and c) terminate all the local-rank processes upon termination of a single one of them (and exit with the same exit code).

This would address all of the issues above, and we could also explicitly set `LOCAL_RANK=0` in the zeroth process.

Pitch

I feel that proposal 2 would add some robustness to multi-GPU training, especially in the cases described above.

I would be happy to submit a PR that does this either in a new `_Launcher` implementation or as a refactor of `_SubprocessScriptLauncher`, if this sounds reasonable. Thanks!

Additional context
No response
cc @awaelchli @justusschock @tchaton @borda @carmocca