keras-team / keras-io

Keras documentation, hosted live at
Apache License 2.0

Multi-GPU distributed training with PyTorch #1838

Open · bouachalazhar opened 2 months ago

bouachalazhar commented 2 months ago

Issue Type

Documentation Bug



Keras Version

Keras 3.2.1

Custom Code


OS Platform and Distribution

Linux Ubuntu 22.04.3

Python version

No response

GPU model and memory

No response

Current Behavior?

Downloading data from
11490434/11490434 ━━━━━━━━━━━━━━━━━━━━ 0s 0us/step
/usr/lib/python3.10/multiprocessing/ RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock. = os.fork()
x_train shape: (60000, 28, 28, 1)
ProcessRaisedException                    Traceback (most recent call last)
<ipython-input-14-31fe06a97fdd> in <cell line: 1>()
      1 if __name__ == "__main__":
      2     # We use the "fork" method rather than "spawn" to support notebooks
----> 3     torch.multiprocessing.start_processes(
      4         per_device_launch_fn,
      5         args=(num_gpu,),

/usr/local/lib/python3.10/dist-packages/torch/multiprocessing/ in join(self, timeout)
    156         msg = "\n\n-- Process %d terminated with the following error:\n" % error_index
    157         msg += original_trace
--> 158         raise ProcessRaisedException(msg, error_index,


-- Process 0 terminated with the following error:
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/multiprocessing/", line 68, in _wrap
    fn(i, *args)
  File "<ipython-input-13-9a6bcf1473c9>", line 47, in per_device_launch_fn
    model = get_model()
  File "<ipython-input-9-4d33f81f3022>", line 14, in get_model
    x = keras.layers.Conv2D(filters=12, kernel_size=3, padding="same", use_bias=False)(
  File "/usr/local/lib/python3.10/dist-packages/keras/src/utils/", line 122, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/", line 489, in _fn
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/cuda/", line 288, in _lazy_init
    raise RuntimeError(
RuntimeError: Cannot re-initialize CUDA in forked subprocess. To use CUDA with multiprocessing, you must use the 'spawn' start method
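
The last line of the child traceback names the remedy: the "spawn" start method. The following is a minimal sketch of that launch pattern, not the tutorial's code. `torch.multiprocessing` wraps Python's standard `multiprocessing` module, so the sketch uses the stdlib `get_context("spawn")` and runs without a GPU. `per_rank_fn` and `launch` are hypothetical names. With "spawn", each worker is a fresh interpreter that imports the worker function from the main module, so no CUDA state is inherited from the parent. The trade-off, and presumably why the tutorial comment says it uses "fork" to support notebooks, is that the worker must live in an importable file run as a script, not in a notebook cell.

```python
import multiprocessing as mp
import os


def per_rank_fn(rank, world_size):
    # Hypothetical per-device worker. In a real training script this is where
    # the model would be built and CUDA touched for the first time, inside
    # the child process rather than the parent.
    return f"rank {rank}/{world_size} running in pid {os.getpid()}"


def launch(world_size):
    # "spawn" starts each worker as a fresh Python interpreter, so it does not
    # inherit CUDA (or JAX thread) state from the parent the way "fork" does.
    ctx = mp.get_context("spawn")
    with ctx.Pool(processes=world_size) as pool:
        # starmap blocks until all workers finish and returns results in rank order.
        return pool.starmap(per_rank_fn, [(rank, world_size) for rank in range(world_size)])


if __name__ == "__main__":
    # The __main__ guard is required with "spawn": each worker re-imports this
    # module, and the guard keeps it from launching workers recursively.
    for line in launch(2):
        print(line)
```

In the tutorial's own launcher, the analogous change would presumably be passing `start_method="spawn"` to `torch.multiprocessing.start_processes` and running the code as a script outside the notebook. This is an assumption, not a confirmed fix for the Keras torch backend.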

Standalone code to reproduce the issue or tutorial link

Relevant log output

No response