google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0
29.09k stars 2.66k forks

Spoof multiple hosts #5155

Open alewis opened 3 years ago

alewis commented 3 years ago

Hello,

It is possible to spoof multiple devices on CPU by setting xla_force_host_platform_device_count to the desired value. Is there an analogous method to spoof multiple hosts with multiple devices each?

hawkinsp commented 3 years ago

Currently, no, because we have no multihost/multiprocess CPU support yet (real or fake). It's probably not that hard to add (e.g., we could plumb in Gloo) but it hasn't been very high on our priority list.

The closest equivalent I can suggest is that you could use a multi-GPU or multi-TPU setup with one process per GPU or group of GPUs. However we haven't published an API for multi-host GPU either, yet. All the pieces are there already, they just need Python APIs.
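For readers reaching this thread later: more recent JAX releases expose a general multi-process API, jax.distributed.initialize, which is how a multi-host GPU job is wired up today. A minimal sketch of a launcher preamble (the coordinator address, process counts, and environment-variable names here are illustrative assumptions, not a fixed JAX convention):

```python
import os

# Hypothetical launcher-provided values; a real scheduler (SLURM, GKE, ...)
# would export these per process. The names and defaults are illustrative.
coordinator = os.environ.get("COORDINATOR_ADDRESS", "10.0.0.1:1234")
num_procs = int(os.environ.get("NUM_PROCESSES", "2"))
proc_id = int(os.environ.get("PROCESS_ID", "0"))

# On an actual multi-host GPU cluster, every process would then run:
#
#   import jax
#   jax.distributed.initialize(coordinator_address=coordinator,
#                              num_processes=num_procs,
#                              process_id=proc_id)
#
# after which jax.device_count() reports devices across all processes,
# while jax.local_device_count() reports only this process's GPUs.
print(coordinator, num_procs, proc_id)
```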

alewis commented 3 years ago

Yeah so basically we are just looking for a way to develop for a multi-accelerator-node setup without consuming accelerator resources during debugging. But, fair enough.

jacksonwb commented 3 years ago

For a little background: we are developing on the Cloud TPU VM alpha, working with a multi-TPU pod slice setup with one process/host per TPU, and we are investigating ways to capture the semantics of that setup (local_device_count != device_count) in a unit-testing environment without pod access...
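To make the gap concrete: with the existing CPU device-spoofing flag, all spoofed devices belong to a single process, so the multi-host condition above never holds. A minimal sketch (the device count of 8 is an arbitrary choice):

```python
import os

# Must run before the first `import jax`; 8 is an arbitrary device count.
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_force_host_platform_device_count=8"
)

import jax

# All spoofed devices live in this one process, so both counts agree,
# which is exactly why this flag alone cannot model a multi-host pod slice.
print(jax.local_device_count())  # 8
print(jax.device_count())        # 8
```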

hawkinsp commented 3 years ago

Yes, it definitely makes sense!

I think that also says you don't care much about performance; what you care about is having a reference implementation you can test against locally. Seems reasonable and not too high a bar!

jacksonwb commented 3 years ago

Exactly. Our production setup and performance tests will all happen on pod slices. We just need unit tests to work and to be able to do some local validation.

lucasliunju commented 3 years ago

Hi @jacksonwb

I am trying to use a multi-host TPU, and I think your suggestion is very important. Could you please provide some examples of how to use it? Thank you!

jakevdp commented 3 years ago

The transformations doc page mimics an 8-device machine on a single CPU; you can see how in the setup code block in the source:

  # Set up runtime to mimic an 8-core machine for pmap example below.
  # Note: XLA_FLAGS must be set before jax is first imported.
  import os
  flags = os.environ.get('XLA_FLAGS', '')
  os.environ['XLA_FLAGS'] = flags + " --xla_force_host_platform_device_count=8"
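Once the flag is in place, the spoofed devices behave like real ones for pmap. A small usage sketch (assuming the flag was set before jax was imported; the 8 devices and the doubling function are arbitrary):

```python
import os
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_force_host_platform_device_count=8"
)

import jax
import jax.numpy as jnp

# pmap maps the function across all eight spoofed CPU devices,
# one array element per device.
xs = jnp.arange(jax.device_count())
out = jax.pmap(lambda x: x * 2)(xs)
print(out)
```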
lucasliunju commented 3 years ago

Hi @jakevdp

Thanks for your suggestion. Currently, I find I cannot connect JAX to a multi-host TPU (such as v3-128).

My code is:

from jax.config import config
config.FLAGS.jax_xla_backend = "tpu_driver"
config.FLAGS.jax_backend_target = "grpc://[my tpu ip address]:8470"

I find that this works on TPU v3-8, but not on v3-32, v3-64, or v3-128.

The error is:

RuntimeError: Deadline exceeded: Failed to connect to remote server at address: grpc://10.59.106.169:8470. Error from gRPC: Deadline Exceeded. Details

I think maybe my setting is wrong.

Thank you!

jakevdp commented 3 years ago

Currently, I find I cannot connect jax and multi-host tpu (such as v3-128).

Thanks - this thread is about spoofing multiple hosts on CPU, not actually using a multi-host TPU. It looks like you already asked this question here: https://github.com/google/jax/discussions/6164; I'd stick to that thread for that particular issue: hopefully someone will know how to answer.

txctxctxc commented 2 years ago

Currently, no, because we have no multihost/multiprocess CPU support yet (real or fake). It's probably not that hard to add (e.g., we could plumb in Gloo) but it hasn't been very high on our priority list.

The closest equivalent I can suggest is that you could use a multi-GPU or multi-TPU setup with one process per GPU or group of GPUs. However we haven't published an API for multi-host GPU either, yet. All the pieces are there already, they just need Python APIs.

Hi @hawkinsp, we are very interested in John's recent work AlphaFold2. However, we have encountered a problem: JAX cannot be used in a multi-machine GPU environment. We know that JAX has not yet provided a Python API for multi-machine use on the GPU platform, but you have already implemented this functionality, and we are anxious to solve this problem now. Could you please tell us how JAX realizes multi-machine execution on the GPU platform?

fajieyuan commented 2 years ago

The closest equivalent

Hi there, we are implementing AlphaFold2, which is based on JAX, but it does not support multi-host GPU. Can you help us?

mrahtz commented 2 years ago

@hawkinsp Is the current status of this that we're still waiting for a JAX developer to implement the necessary changes? (Or is the bottleneck XLA itself, as suggested by https://github.com/google/jax/discussions/9688?)

The ability to spoof multiple hosts would be super helpful for some of the stuff we're using JAX for at DeepMind - running tests for multi-host code otherwise requires spinning up a bunch of TPUs, which is painful. So big +1 to this.

youurayy commented 4 months ago

The transformations doc page mimics an 8-device machine on a single CPU; you can see how in the setup code block in the source:

  # Set up runtime to mimic an 8-core machine for pmap example below:
  import os
  flags = os.environ.get('XLA_FLAGS', '')
  os.environ['XLA_FLAGS'] = flags + " --xla_force_host_platform_device_count=8"

This works, but don't forget jax.config.update('jax_platform_name', 'cpu') if your system also has a GPU.
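Putting both pieces together, a sketch of a test-setup preamble (the device count of 8 is arbitrary; jax_platform_name forces the CPU backend even when an accelerator is visible):

```python
import os
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_force_host_platform_device_count=8"
)

import jax
jax.config.update("jax_platform_name", "cpu")  # ignore any local GPU/TPU

# All eight spoofed devices should now be CPU devices.
devices = jax.devices()
print(len(devices), devices[0].platform)  # 8 cpu
```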

dlwh commented 4 months ago

@youurayy I think the point is to run multiple processes, not just multiple cpu devices. It's helpful for testing.