jeffhataws opened this issue 2 months ago
Hmm, interesting. It seems like `xmp.spawn` does not surface the assertion error from the spawned process. Do you know why?
Sorry, I haven't had time to look into it more. The last time I looked, the threads were stuck at a futex.
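(For anyone else debugging this: one way to see where the hung workers are stuck is to dump their stacks with py-spy. The PID below is a placeholder.)

```sh
# Dump the Python and native stacks of a hung worker; <PID> is a placeholder.
py-spy dump --pid <PID> --native
```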
OK, if you remove the `xm.rendezvous("end")`, this is going to work. The issue is that the current implementation of the multiprocess runner does not handle an exception from a subprocess until all processes have finished. I am not an expert on this, but I asked a chatbot and it suggested updating the code to something like:
```python
# New imports needed at the top of the module:
import multiprocessing
import queue


def run_multiprocess(fn: Callable[..., R],
                     *args,
                     start_method: str = 'spawn',
                     **kwargs) -> Dict[int, R]:
  """Runs `fn` on all devices available to PjRt.

  Spawns one process per physical device (e.g. TPU chip).

  Args:
    fn: Function to run on all devices
    args: args to pass to `fn`
    start_method: The Python `multiprocessing` process creation method.
      Default: `spawn`
    kwargs: kwargs to pass to `fn`

  Returns:
    Dict of the form {device_ordinal: return_value}, where
    return_value is the result of calling `fn`.

  Raises:
    Exception: If any subprocess raises an exception.
  """
  if torch_xla._XLAC._xla_runtime_is_initialized():
    raise RuntimeError('Runtime is already initialized. Do not use the XLA '
                       'device before calling xmp.spawn.')

  # Determine the number of processes
  if plugins.using_dynamic_plugins():
    num_processes = plugins.default().physical_chip_count()
  elif runtime.device_type() == 'TPU':
    num_processes = tpu.num_local_processes()
  elif runtime.device_type() == 'CUDA':
    num_processes = gpu.num_local_processes()
  elif runtime.device_type() == 'NEURON':
    num_processes = neuron.num_local_processes()
  else:
    num_processes = 1

  # Create a queue for error reporting. A Manager queue is used because a
  # plain multiprocessing.Queue cannot be pickled into executor workers.
  error_queue = multiprocessing.Manager().Queue()

  with concurrent.futures.ProcessPoolExecutor(
      max_workers=num_processes,
      mp_context=multiprocessing.get_context(start_method)) as executor:
    mp_fn = functools.partial(
        _run_thread_per_device,
        local_world_size=num_processes,
        fn=functools.partial(fn, *args, **kwargs),
        initializer_fn=initialize_multiprocess,
        error_queue=error_queue)
    futures = [executor.submit(mp_fn, i) for i in range(num_processes)]

    results = []
    while futures:
      # Check the error queue for exceptions reported by subprocesses
      try:
        error = error_queue.get_nowait()
        # If we get here, an error was reported
        executor.shutdown(wait=False)
        raise error
      except queue.Empty:
        # No error reported, continue processing
        pass

      # Wait for the next future to complete
      done, futures = concurrent.futures.wait(
          futures, timeout=0.1, return_when=concurrent.futures.FIRST_COMPLETED)
      for future in done:
        try:
          results.append(future.result())
        except Exception:
          executor.shutdown(wait=False)
          raise  # Re-raise the exception to stop the main process

  replica_results = list(
      itertools.chain.from_iterable(result.items() for result in results))
  return _merge_replica_results(replica_results)


def _run_thread_per_device(device_ordinal: int,
                           local_world_size: int,
                           fn: Callable[[], R],
                           initializer_fn: Callable[[int, int], None],
                           error_queue) -> Dict[int, R]:
  try:
    initializer_fn(device_ordinal, local_world_size)
    result = fn()
    return {device_ordinal: result}
  except Exception as e:
    error_queue.put(e)
    raise  # Re-raise the exception to stop this process
```
which adds a try/except in `_run_thread_per_device` and appends the exception to a queue, while the main process keeps checking whether a new exception has appeared in the queue. At first glance this should work. @will-cromar wdyt?
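For reference, here is a minimal, self-contained sketch of that pattern outside of torch_xla (`worker` and `run_all` are made-up names): the child pushes its exception onto a queue, and the parent polls the queue and tears down the remaining children as soon as anything arrives. Using `multiprocessing.Process` directly makes it possible to `terminate()` a rank that is still blocked at a rendezvous, which `ProcessPoolExecutor` does not expose.

```python
import multiprocessing
import queue
import time


def worker(rank: int, error_queue) -> None:
  try:
    if rank == 0:
      raise AssertionError(f'intentional failure on rank {rank}')
    # Simulate a rank that is blocked at a rendezvous.
    while True:
      time.sleep(1)
  except Exception as e:
    # Report the exception to the parent before dying.
    error_queue.put(e)
    raise


def run_all(num_workers: int = 2) -> None:
  ctx = multiprocessing.get_context('spawn')
  error_queue = ctx.Queue()
  procs = [
      ctx.Process(target=worker, args=(i, error_queue))
      for i in range(num_workers)
  ]
  for p in procs:
    p.start()

  while any(p.is_alive() for p in procs):
    try:
      error = error_queue.get(timeout=0.1)
    except queue.Empty:
      continue
    # A child reported an exception: tear down the stragglers and re-raise.
    for p in procs:
      p.terminate()
      p.join()
    raise error


if __name__ == '__main__':
  run_all()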
🐛 Bug
When I run two-worker multiprocess training with `xmp.spawn`, where one rank waits at a rendezvous while another hits an assertion, I see a hang at the assertion. Interrupting with Ctrl-C still leaves two zombie processes behind. When running with torchrun, the run errors out properly.
To Reproduce
Create a file `test_assert_hang.py` with the following content:
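(The original snippet is not reproduced above; a minimal sketch of the `xmp.spawn` path, assuming rank 0 trips an assertion while the other rank waits at `xm.rendezvous("end")`, would look something like this.)

```python
# Hedged sketch of the repro: rank 0 fails an assertion while the
# other rank reaches the final rendezvous and waits for it.
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp


def _mp_fn(index):
  device = xm.xla_device()  # touch the device so the runtime initializes
  assert index != 0, f'intentional assertion failure on rank {index}'
  xm.rendezvous('end')


if __name__ == '__main__':
  xmp.spawn(_mp_fn)
```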
Run the `xmp.spawn` hang case with CPU on 2 workers:
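(A hedged guess at the command; `CPU_NUM_DEVICES` is an assumption about how the worker count was set.)

```sh
# Hedged reconstruction; the exact environment variables may differ.
PJRT_DEVICE=CPU CPU_NUM_DEVICES=2 python test_assert_hang.py
```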
Run the `torchrun` no-hang case with CPU on 2 workers:
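(Again a hedged guess at the command.)

```sh
# Hedged reconstruction; the exact environment variables may differ.
PJRT_DEVICE=CPU torchrun --nproc_per_node=2 test_assert_hang.py
```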
Expected behavior
The run should exit with the assertion message rather than hang.
Environment
Additional context