pytorch / xla

Enabling PyTorch on XLA Devices (e.g. Google TPU)
https://pytorch.org/xla

_setup_workers fails to match worker.host_port #2670

Closed tyoc213 closed 3 years ago

tyoc213 commented 3 years ago

I'm passing XRT_WORKERS as localservice:0;grpc://localhost:40934

(screenshot of the error)

Which is weird, because a plain string obviously has neither host_port nor worker_name. I can get it past that point with this patch:

diff --git a/torch_xla/distributed/xla_multiprocessing.py b/torch_xla/distributed/xla_multiprocessing.py
index e67d92d8..6a792502 100644
--- a/torch_xla/distributed/xla_multiprocessing.py
+++ b/torch_xla/distributed/xla_multiprocessing.py
@@ -154,13 +154,14 @@ def _setup_workers(num_devices):
         wcfg), 'World size ({}) must match the configured workers ({})'.format(
             world_size, len(wcfg))
     for h, worker in enumerate(wcfg):
-      m = re.match(r'(.*):(\d+)$', worker.host_port)
+      name, port = worker.split(":")
+      m = re.match(r'(.*):(\d+)$', worker)
       if not m:
         raise RuntimeError('Bad worker HOST:PORT format: {}'.format(
             worker.host_port))
       for i in range(0, num_devices):
         gindex = h * num_devices + i
-        workers.append('{}:{};grpc://{}:{}'.format(worker.worker_name, gindex,
+        workers.append('{}:{};grpc://{}:{}'.format(name, gindex,
                                                    m.group(1),
                                                    int(m.group(2)) + i))

Then it runs until:

    xmp.spawn(map_fn, args=(flags,), nprocs=nprocs, start_method='fork')
  File "/home/tyoc213/Documents/github/pytorch/xla/torch_xla/distributed/xla_multiprocessing.py", line 385, in spawn
    pf_cfg = _pre_fork_setup(nprocs)
  File "/home/tyoc213/Documents/github/pytorch/xla/torch_xla/distributed/xla_multiprocessing.py", line 196, in _pre_fork_setup
    dev_count, dev_kind = _get_devices_per_worker()
  File "/home/tyoc213/Documents/github/pytorch/xla/torch_xla/distributed/xla_multiprocessing.py", line 88, in _get_devices_per_worker
    raise RuntimeError('Missing TPU or GPU configuration')
RuntimeError: Missing TPU or GPU configuration

That is, it gets through def _get_devices_per_worker() once: the first time, num_gpus = os.environ.get(xenv.GPU_NUM_DEVICES, None) correctly comes back as 1, but the second time it doesn't, so it ends up raising RuntimeError('Missing TPU or GPU configuration').

I see, the second and third times happen because I call fit again :)... It does work if I just hardcode the value to 1; I don't know why the env var gets "deleted" between calls. In the same spawned function I have something like

xla_learner.fit(1)
xla_learner.fit(1)
xla_learner.fit(1)    

So I guess that is why it requests that env var 3 times...
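
A minimal sketch of that workaround, assuming the single-GPU setup from this issue; fit_with_env is just an illustrative helper, not something from torch_xla:

import os

# Workaround sketch: re-set GPU_NUM_DEVICES before every fit() call, since
# something between calls appears to consume the variable and
# _get_devices_per_worker() then raises 'Missing TPU or GPU configuration'.
def fit_with_env(learner, epochs, num_gpus=1):
    os.environ['GPU_NUM_DEVICES'] = str(num_gpus)
    learner.fit(epochs)

fit_with_env(xla_learner, 1)
fit_with_env(xla_learner, 1)
fit_with_env(xla_learner, 1)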

JackCaoG commented 3 years ago

If I remember correctly, if you want to use the XRT_WORKERS config, you also need to manually set up XRT_DEVICE_MAP, like

export XRT_DEVICE_MAP="CPU:0;/job:localservice/replica:0/task:0/device:XLA_CPU:0|GPU:0;/job:localservice/replica:0/task:0/device:XLA_GPU:0|GPU:1;/job:localservice/replica:0/task:0/device:XLA_GPU:1|GPU:2;/job:localservice/replica:0/task:0/device:XLA_GPU:2|GPU:3;/job:localservice/replica:0/task:0/device:XLA_GPU:3"
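
For the single local GPU in this issue, a sketch of how the two variables might be paired (written with os.environ instead of export; the device-map entry is adapted from the four-GPU example above and the port is the one from the report, so treat both as assumptions):

import os

# Hedged sketch: pair XRT_WORKERS with a matching XRT_DEVICE_MAP for one
# local CPU and one local GPU, following the pattern of the 4-GPU example.
os.environ['XRT_WORKERS'] = 'localservice:0;grpc://localhost:40934'
os.environ['XRT_DEVICE_MAP'] = (
    'CPU:0;/job:localservice/replica:0/task:0/device:XLA_CPU:0'
    '|GPU:0;/job:localservice/replica:0/task:0/device:XLA_GPU:0')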

Or you could simply use GPU_NUM_DEVICES, and everything else (worker info) should be set up automatically; you can check https://github.com/pytorch/xla/blob/3eaee46ef679cc6a0f1f694bd0a007dbfd09c51b/.circleci/test.sh#L8

You can also check https://github.com/pytorch/xla/blob/3eaee46ef679cc6a0f1f694bd0a007dbfd09c51b/third_party/xla_client/computation_client.cc#L271 to see how we handle the different ways to configure XLA:

XRT_TPU_CONFIG --> ParseEnvBasedTpuClusterConfig
TPU_NUM_DEVICES/GPU_NUM_DEVICES --> ParseEnvDeviceCounts
XRT_WORKERS + XRT_DEVICE_MAP --> ParseEnvDevices
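
Purely as an illustration of that mapping (it mirrors the list above, not the actual logic in computation_client.cc, and the order of the checks is an assumption):

import os

# Illustrative only: report which configuration path from the mapping above
# would apply, based on which environment variables are set.
def xrt_config_mode(env=os.environ):
    if 'XRT_TPU_CONFIG' in env:
        return 'ParseEnvBasedTpuClusterConfig'
    if 'TPU_NUM_DEVICES' in env or 'GPU_NUM_DEVICES' in env:
        return 'ParseEnvDeviceCounts'
    if 'XRT_WORKERS' in env and 'XRT_DEVICE_MAP' in env:
        return 'ParseEnvDevices'
    return None

print(xrt_config_mode())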
stale[bot] commented 3 years ago

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.