jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Documentation and examples for multithreaded/multicore parallelism using XLA:CPU #8345

Open amir-saadat opened 2 years ago

amir-saadat commented 2 years ago

The parallelism documentation is mostly focused on multiple TPU/GPU devices. It would be great if the JAX team added concrete explanations of how JAX uses parallelism when the runtime consists of multiple CPU cores. Ideally, the documentation would cover both within-op parallelism (what happens when x = jnp.ones((5000, 5000)); jnp.dot(x, x) is executed) and the pmap parallel primitive.
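For concreteness, the within-op case I mean is roughly the following (a minimal sketch, assuming a default CPU-only installation where JAX exposes a single XLA:CPU device):

import jax
import jax.numpy as jnp

print(jax.devices())                    # by default a single CPU device

x = jnp.ones((5000, 5000))
y = jnp.dot(x, x).block_until_ready()   # one op on one device; the CPU backend
                                        # may still use several cores internally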

Specifically, if one has allocated a certain number of cores, what is the most meaningful way to use pmap? For example, does one need to declare the same number of XLA devices before importing jax, like this:

import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=_NUM_ALLOCATED_DEVICES"
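For instance, a complete toy version of what I have in mind (a sketch only; the 8 below is just a placeholder for however many cores were actually allocated, and f is a stand-in function):

# The flag must be set before jax is imported for it to take effect.
import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import jax
import jax.numpy as jnp

n_dev = jax.device_count()                 # should now report 8

def f(x):                                  # toy per-device computation
    return jnp.sin(x) * 2.0

xs = jnp.arange(n_dev, dtype=jnp.float32)  # leading dim == number of devices
ys = jax.pmap(f)(xs)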

However, this causes an issue if the size of the mapped dimension of the input array is larger than the number of allocated devices. In that case, are you supposed to reshape the input to match the number of XLA devices? That also raises the question of how to parallelize the inner dimension (a nested pmap?). I have seen a discussion in another thread, but I couldn't draw a concrete conclusion from it. I also tried other mapping options; in my use case, however, they take much longer than simply setting xla_force_host_platform_device_count to a number large enough that pmap doesn't complain about the mapped dimension exceeding the number of XLA devices.
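The reshaping workaround I am referring to would look something like this (a sketch, assuming the leading dimension happens to be an exact multiple of the device count, with f again a stand-in):

import jax
import jax.numpy as jnp

n_dev = jax.device_count()

def f(x):                                          # hypothetical per-element function
    return x ** 2 + 1.0

data = jnp.arange(4 * n_dev, dtype=jnp.float32)    # length divisible by n_dev
chunks = data.reshape(n_dev, -1)                   # (n_devices, chunk_size)
out = jax.pmap(jax.vmap(f))(chunks).reshape(-1)    # pmap outer, vmap inner, flatten back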

As another example, is it expected to get good performance if only one physical core is allocated but multiple XLA devices are defined before importing jax? In my tests, this results in a considerably faster runtime than performing the operation sequentially (looping over the mapped dimension of the input).

mohamad-amin commented 2 years ago

Hey @amir-saadat

Have you tried using jit(vmap(f), backend='cpu') without pmap? In my use case, weirdly, jit(vmap(f)) uses all the available cores at 100% utilization (according to htop, at least), whereas pmap uses only one core for a long time and eventually the process gets killed (I assume because of memory issues). So my questions are:

tavin commented 2 years ago

I wish pmap -- or something like it -- would just figure out how to split your thousands/millions of array elements across the number of devices (CPU cores) you have, instead of requiring you to manually reshape the input array with an extra leading dimension matching those devices/cores (and to guarantee the array length is a multiple of that). It seems to me the docs could at least make this behavior clearer.
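To spell out the dance I mean (just a sketch with a made-up per-element function; the point is the pad/reshape/trim bookkeeping, not the function itself):

import math

import jax
import jax.numpy as jnp

def f(x):                                                    # made-up per-element function
    return x * 2.0

n_dev = jax.device_count()
data = jnp.arange(1001, dtype=jnp.float32)                   # not a multiple of n_dev

chunk = math.ceil(data.shape[0] / n_dev)
padded = jnp.pad(data, (0, chunk * n_dev - data.shape[0]))   # pad up to a multiple
out = jax.pmap(jax.vmap(f))(padded.reshape(n_dev, chunk))
out = out.reshape(-1)[: data.shape[0]]                       # drop the padding again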

doctor-phil commented 11 months ago

I am using my CPU as a test, since I don't have access to an NVIDIA GPU. Can someone tell me whether pmap will also require me to know exactly how many threads the GPU has and to split the problem up like this?

FWIW, so far, jit(vmap) with backend='cpu' is much faster than pmap with --xla_force_host_platform_device_count=501. (Obviously I don't have 500 CPU cores, but my test case has 500 array elements to parallelize.)
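For context, this is roughly what I compared (f stands in for my real function, and the device count is spoofed before jax is first imported):

import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=501"

import jax
import jax.numpy as jnp

def f(x):                                           # stand-in for my real function
    return jnp.sin(x) * x

xs = jnp.linspace(0.0, 1.0, 500)                    # 500 elements to parallelize

fast = jax.jit(jax.vmap(f), backend='cpu')(xs)      # vectorized on one CPU device
slow = jax.pmap(f)(xs)                              # one element per spoofed device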

If pmap works (albeit slowly) on my CPU while spoofing the number of cores, can I count on it working when I pay for access to a GPU? The full array will have hundreds of thousands of elements; do I need to break them up into batches for the GPU to work properly?
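In case it helps clarify what I'm asking: the batching I have in mind would be something like this (a sketch only; batch_size and f are placeholders, and I don't know whether this is actually necessary on a GPU):

import jax
import jax.numpy as jnp

def f(x):                                 # placeholder per-element function
    return jnp.tanh(x) * x

batched_f = jax.jit(jax.vmap(f))
data = jnp.arange(300_000, dtype=jnp.float32)
batch_size = 50_000                       # divides the total evenly, so every
                                          # chunk has the same shape (no recompiles)

outs = [batched_f(data[i:i + batch_size]) for i in range(0, data.shape[0], batch_size)]
result = jnp.concatenate(outs)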

cgarciae commented 11 months ago

@doctor-phil you would need more than 1 GPU for pmap to be useful.

doctor-phil commented 11 months ago

@cgarciae I must be misunderstanding the documentation. If I only have 1 GPU, do I only need to use vmap?

cgarciae commented 11 months ago

@doctor-phil correct