andyljones opened this issue 4 years ago
If 2 x max_value x max_repeats is an acceptable memory cost, you could follow your existing strategy but build and process the max_value x max_repeats x max_repeats x 2 array incrementally (e.g. in chunks of 1000 along the max_value axis). This is reasonable if that incremental loop is driven from Python, but it wouldn't work if you want to jit through it. Under jit, each chunk's addition to the result list would have to be allocated at its upper bound of chunk_size x max_repeats x max_repeats. Across all chunks, that means we'd need the whole array's worth of memory again.
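Here is a minimal sketch of that chunked strategy. It assumes the ragged array is stored as a dense (max_value + 1) x max_repeats table padded with -1, one table per list, which is where the 2 x max_value x max_repeats memory cost comes from. The names ragged_table, chunk_candidates and pairs_chunked are hypothetical. The jitted part only ever sees fixed-shape chunks. The data-dependent compaction (a[valid]) happens in NumPy on the host, and that is the step you can't jit through.

```python
import numpy as onp
import jax
import jax.numpy as jnp


def ragged_table(vals, max_value, max_repeats):
    """Row v holds the indices i with vals[i] == v, padded with -1 (hypothetical helper)."""
    vals = jnp.asarray(vals)
    order = jnp.argsort(vals)
    sorted_vals = vals[order]
    counts = jnp.zeros(max_value + 1, jnp.int32).at[vals].add(1)
    starts = jnp.cumsum(counts) - counts                     # first sorted position of each value
    rank = jnp.arange(vals.shape[0]) - starts[sorted_vals]   # position within its run of equal values
    table = jnp.full((max_value + 1, max_repeats), -1, jnp.int32)
    return table.at[sorted_vals, rank].set(order.astype(jnp.int32))


@jax.jit
def chunk_candidates(a_rows, b_rows):
    """All (A index, B index) combinations for a chunk of values, as fixed (chunk, R, R) arrays."""
    chunk, r = a_rows.shape
    a = jnp.broadcast_to(a_rows[:, :, None], (chunk, r, r))
    b = jnp.broadcast_to(b_rows[:, None, :], (chunk, r, r))
    valid = (a >= 0) & (b >= 0)   # both slots hold real indices rather than padding
    return a, b, valid


def pairs_chunked(A, B, max_value, max_repeats, chunk_size=1000):
    ta = ragged_table(A, max_value, max_repeats)
    tb = ragged_table(B, max_value, max_repeats)
    out = []
    for lo in range(0, max_value + 1, chunk_size):   # the incremental loop runs in Python
        a, b, valid = chunk_candidates(ta[lo:lo + chunk_size], tb[lo:lo + chunk_size])
        a, b, valid = onp.asarray(a), onp.asarray(b), onp.asarray(valid)
        out.append(onp.stack([a[valid], b[valid]], axis=-1))   # data-dependent size, on the host
    return onp.concatenate(out)
```

Peak intermediate memory per step is three chunk_size x max_repeats x max_repeats arrays, so it is set by chunk_size rather than max_value. If the last chunk is shorter, it costs one extra compilation.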
This seems like a hard problem for JAX/XLA in part because of sparsity and related dynamism—the output you're looking for is essentially an adjacency list representation of a sparse matrix that's infeasible to represent densely—so I wonder if additional sparse primitives could be helpful here (I can't think of how to apply scatter/gather on their own, though).
The "if-looping-were-cheap" approach unfortunately also seems like a fundamentally serial/single-threaded approach, while JAX is only able to accelerate things on GPUs and TPUs by exploiting parallelism in something like a "single program multiple data" way. Maybe it would be helpful to try to think about what kind of algorithm we'd want to use here on a highly multithreaded CPU or even a distributed-memory cluster?
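Here is one data-parallel formulation, offered as a sketch rather than anything proposed in the thread. Once each list is sorted and counted, value v contributes counts_A[v] * counts_B[v] pairs. A prefix sum over those sizes gives every value a fixed output range. Each output slot k can then find its own pair independently: a searchsorted finds v, and a divide/modulo gives the position within the A-run and the B-run. Every slot is an independent work item, which suits SIMD-style hardware. The catch for jit is that the total number of pairs must be static. This sketch therefore computes it on the host first and recompiles whenever it changes. count_pairs, _sorted_runs and pairs_parallel are hypothetical names.

```python
from functools import partial

import numpy as onp
import jax
import jax.numpy as jnp


def count_pairs(A, B, max_value):
    """Host-side total number of matching pairs, used as the static output size."""
    ca = onp.bincount(A, minlength=max_value + 1)
    cb = onp.bincount(B, minlength=max_value + 1)
    return int((ca * cb).sum())


def _sorted_runs(vals, max_value):
    order = jnp.argsort(vals)                                  # indices grouped by value
    counts = jnp.zeros(max_value + 1, jnp.int32).at[vals].add(1)
    starts = jnp.cumsum(counts) - counts                       # start of each value's run
    return order, counts, starts


@partial(jax.jit, static_argnames=("max_value", "total"))
def pairs_parallel(A, B, max_value, total):
    a_order, a_counts, a_starts = _sorted_runs(A, max_value)
    b_order, b_counts, b_starts = _sorted_runs(B, max_value)
    sizes = a_counts * b_counts                   # pairs contributed by each value
    ends = jnp.cumsum(sizes)                      # exclusive end of each value's output range
    k = jnp.arange(total)                         # one independent work item per output pair
    v = jnp.searchsorted(ends, k, side="right")   # value whose range contains slot k
    j = k - (ends[v] - sizes[v])                  # offset within that value's range
    ia = a_starts[v] + j // b_counts[v]           # position in sorted A
    ib = b_starts[v] + j % b_counts[v]            # position in sorted B
    return jnp.stack([a_order[ia], b_order[ib]], axis=-1)
```

Usage: `pairs_parallel(jnp.asarray(A), jnp.asarray(B), max_value=max_value, total=count_pairs(A, B, max_value))`.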
Thanks very much, James! I ended up doing something similar to what you suggest, taking chunks. I chunked by grouping together values that have the same number of repeats, though, which has the advantage of making the inner op much more SIMD-y:
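from jax import numpy as np

def value_counts(vals, max_value):
    # Histogram of vals. jax.ops.index_add was removed from later JAX releases;
    # counts.at[vals].add(1) is the equivalent indexed update.
    counts = np.zeros((max_value+1,), dtype=np.int32)
    return counts.at[vals].add(1)

def pairs(A, B, max_value, max_repeats):
    A, B = np.asarray(A), np.asarray(B)
    # Sort each list and count how often each value occurs in it.
    A_order = np.argsort(A)
    A_sorted = A[A_order]
    A_counts = value_counts(A_sorted, max_value)
    B_order = np.argsort(B)
    B_sorted = B[B_order]
    B_counts = value_counts(B_sorted, max_value)
    pairs = []
    # Handle all values with the same (A_count, B_count) together, so every value
    # in the group yields a fixed A_count x B_count block of pairs.
    for A_count in range(1, max_repeats+1):
        for B_count in range(1, max_repeats+1):
            mask = (A_counts == A_count) & (B_counts == B_count)
            n = int(mask.sum())             # number of values in this group
            s, = mask[A_sorted].nonzero()   # positions in A_sorted, A_count per value
            t, = mask[B_sorted].nonzero()   # positions in B_sorted, B_count per value
            ps = np.stack([
                np.repeat(s.reshape(n, A_count, 1), B_count, 2),
                np.repeat(t.reshape(n, 1, B_count), A_count, 1)], -1).reshape(-1, 2)
            pairs.append(ps)
    pairs = np.concatenate(pairs)
    # Map sorted positions back to indices into the original A and B.
    pairs = np.stack([A_order[pairs[..., 0]], B_order[pairs[..., 1]]], -1)
    return pairs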
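To see why grouping by repeat count makes the inner op SIMD-friendly, here is the reshape/repeat step from the loop above applied to a toy group. It assumes two values that each occur twice in A and twice in B; the positions s and t are made up for illustration. Every value in the group produces the same A_count x B_count block, so the whole group is a single fixed-shape operation.

```python
from jax import numpy as np

# Positions (in A_sorted / B_sorted) of two values that each occur twice in A and twice in B.
s = np.array([0, 1, 4, 5])
t = np.array([2, 3, 6, 7])
n, A_count, B_count = 2, 2, 2

ps = np.stack([
    np.repeat(s.reshape(n, A_count, 1), B_count, 2),   # each A-position repeated across the B run
    np.repeat(t.reshape(n, 1, B_count), A_count, 1)],  # each B-position repeated across the A run
    -1).reshape(-1, 2)
print(ps.tolist())
# [[0, 2], [0, 3], [1, 2], [1, 3], [4, 6], [4, 7], [5, 6], [5, 7]]
```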
The upside is that I've cut the memory usage down to the bare minimum; the downside is that I'm making a hundred times as many calls into JAX.
I expected those hundred-times-as-many calls to be covered by the work being done on the GPU, but in practice it worked out to 1.5s-ish on million-element lists. That seemed kinda slow for an RTX 2080, so I ported the code to PyTorch and got an instant 20x speedup:
import time

from jax import numpy as np
import numpy as onp
import torch

class JAX:
    def value_counts(self, vals, max_value):
        # jax.ops.index_add was removed from later JAX releases; .at[].add is the equivalent.
        counts = np.zeros((max_value+1,), dtype=np.int32)
        return counts.at[vals].add(1)

    def pairs(self, A, B, max_value, max_repeats):
        A, B = np.asarray(A), np.asarray(B)
        A_order = np.argsort(A)
        A_sorted = A[A_order]
        A_counts = self.value_counts(A_sorted, max_value)
        B_order = np.argsort(B)
        B_sorted = B[B_order]
        B_counts = self.value_counts(B_sorted, max_value)
        pairs = []
        for A_count in range(1, max_repeats+1):
            for B_count in range(1, max_repeats+1):
                mask = (A_counts == A_count) & (B_counts == B_count)
                n = int(mask.sum())
                s, = mask[A_sorted].nonzero()
                t, = mask[B_sorted].nonzero()
                ps = np.stack([
                    np.repeat(s.reshape(n, A_count, 1), B_count, 2),
                    np.repeat(t.reshape(n, 1, B_count), A_count, 1)], -1).reshape(-1, 2)
                pairs.append(ps)
        pairs = np.concatenate(pairs)
        pairs = np.stack([A_order[pairs[..., 0]], B_order[pairs[..., 1]]], -1)
        # Copying to the host also waits for the computation to finish.
        return onp.asarray(pairs)

class Torch:
    def value_counts(self, vals, max_value):
        ones = vals.new_ones(len(vals), dtype=torch.int32)
        counts = vals.new_zeros((max_value+1,), dtype=torch.int32)
        counts.index_add_(0, vals, ones)
        return counts

    def pairs(self, A, B, max_value, max_repeats):
        A, B = torch.tensor(A).cuda(), torch.tensor(B).cuda()
        A_order = torch.argsort(A)
        A_sorted = A[A_order]
        A_counts = self.value_counts(A_sorted, max_value)
        B_order = torch.argsort(B)
        B_sorted = B[B_order]
        B_counts = self.value_counts(B_sorted, max_value)
        pairs = []
        for A_count in range(1, max_repeats+1):
            for B_count in range(1, max_repeats+1):
                mask = (A_counts == A_count) & (B_counts == B_count)
                n = int(mask.sum())
                s = mask[A_sorted].nonzero()
                t = mask[B_sorted].nonzero()
                ps = torch.stack([
                    torch.repeat_interleave(s.reshape(n, A_count, 1), B_count, 2),
                    torch.repeat_interleave(t.reshape(n, 1, B_count), A_count, 1)], -1).reshape(-1, 2)
                pairs.append(ps)
        pairs = torch.cat(pairs)
        pairs = torch.stack([A_order[pairs[..., 0]], B_order[pairs[..., 1]]], -1)
        return pairs.cpu().numpy()

def random_problem():
    max_value = int(1e6)
    length = int(1e6)
    onp.random.seed(1)
    A = onp.random.choice(onp.arange(max_value), length)
    B = onp.random.choice(onp.arange(max_value), length)
    max_repeats = max(onp.bincount(A).max(), onp.bincount(B).max())
    return A, B, max_value, max_repeats

def profile(lib=None, repeats=5):
    if lib is None:
        profile(Torch(), repeats)
        profile(JAX(), repeats)
        return
    name = type(lib).__name__
    prob = random_problem()
    print(f'{name}: warming up')
    lib.pairs(*prob)
    print(f'{name}: benchmarking')
    start = time.time()
    for _ in range(repeats):
        lib.pairs(*prob)
    end = time.time()
    print(f'{name}: {(end - start)/repeats:.2f}s per')

profile()
# Torch: warming up
# Torch: benchmarking
# Torch: 0.06s per
# JAX: warming up
# JAX: benchmarking
# JAX: 1.19s per
Now, I think I've avoided the usual pitfall of benchmarking without moving the result back to the CPU, but this seems extreme. Setting CUDA_LAUNCH_BLOCKING=1 and running snakeviz on the warmed-up libs, I get this breakdown for Torch:
[snakeviz profile: Torch]
and this for JAX:
[snakeviz profile: JAX]
It looks like Torch dispatches directly into fairly chunky hand-rolled C code, while JAX does a lot of compilation on the fly. This is exactly the kind of work jax.jit should hoist to the top, but trying to jit it raises an error about the mask indices not being static.
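One way to see where the on-the-fly compilation comes from: jax.jit traces and compiles once per distinct input shape and dtype, then reuses the cached result. Eager jax.numpy calls are, roughly, compiled and cached per operation and shape in a similar way. In the loop above, nonzero returns arrays whose lengths depend on the data, so the downstream reshape, repeat and stack calls keep seeing new shapes. In the small illustration below, the hypothetical traces list records each time the Python body actually runs:

```python
import jax
import jax.numpy as jnp

traces = []

@jax.jit
def double(x):
    traces.append(x.shape)   # Python side effect: runs only while tracing
    return 2 * x

double(jnp.zeros(3))
double(jnp.ones(3))    # same shape and dtype: cached, no new trace
double(jnp.zeros(4))   # new shape: traced and compiled again
print(traces)  # [(3,), (4,)]
```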
Now, presumably the indices need to be static because JAX's JIT'r expects arrays of fixed shape. There's an offhand comment in the docs about UnshapedArrays, but I can't find a description of what UnshapedArrays are, how to set the level of abstraction, or how well levels of abstraction other than ShapedArrays are supported.
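The "not static" error comes from nonzero. Under jit, JAX only sees abstract ShapedArray inputs (shape and dtype, no values), so it cannot express an output whose length depends on the values. Newer JAX versions accept a static upper bound through the size= and fill_value= arguments of jnp.nonzero, and unused slots get fill_value. These arguments may postdate the JAX version used in this thread. Here is a minimal sketch, with nonzero_padded as a hypothetical name:

```python
import jax
import jax.numpy as jnp

@jax.jit
def nonzero_padded(mask):
    # Output length is fixed at mask.shape[0], which is known at trace time; unused slots are -1.
    idx, = jnp.nonzero(mask, size=mask.shape[0], fill_value=-1)
    return idx

print(nonzero_padded(jnp.array([False, True, False, True])).tolist())  # [1, 3, -1, -1]
```

The price is that every padded slot is computed and stored. This is the same pay-for-the-upper-bound trade-off as jitting through the chunked approach.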
I think this might play into your comments on sparse matrices? My intuition is that sparse work involves a lot of arrays of wildly varying dimension, which are more suitably thought of as Unshaped than Shaped?
That aside, two concrete questions: how can I finagle JAX's JIT'r into getting Torch-level performance on this problem? Or am I fundamentally using a screwdriver on a nail here?
Thanks so much for the kind words and the excellent writeup, @andyljones! And indeed this is the perfect place to bring up these issues so we can think about them together.
I'm looking forward to digging in as well when I can! For now, I just wanted to say thanks for raising this :)
I've got a question about using JAX efficiently. I think this is the right place to ask, but let me know if Stack Overflow or some forum or chat channel would be better suited.
I want to write a function that takes two arrays of values between 0 and N and outputs all pairs of indices with matching values, a la
Each value can turn up at most a small, known number of times (~10) in each list, while the lists are ~a million items long and the maximum value is ~five million.
Now the naive way to do this would be to test every value of A against every value of B, and use a mask to pick out the ones I'm after:
Problem is of course that with million-item lists, there are a trillion comparisons to make.
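Here is a sketch of the naive approach as described; the original snippet isn't preserved, and naive_pairs is a hypothetical name. Broadcasting A against B builds a len(A) x len(B) boolean mask. That mask is exactly the trillion-element object that makes this infeasible at a million items per list.

```python
import jax.numpy as jnp

def naive_pairs(A, B):
    mask = A[:, None] == B[None, :]   # len(A) x len(B) comparisons
    i, j = jnp.nonzero(mask)          # row-major indices of the matching pairs
    return jnp.stack([i, j], axis=-1)

A = jnp.array([1, 2, 1, 3])
B = jnp.array([1, 1, 3, 0])
print(naive_pairs(A, B).tolist())  # [[0, 0], [0, 1], [2, 0], [2, 1], [3, 2]]
```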
The better way to go about things is to make use of that small-number-of-repeats property. If looping were cheap, I'd count-sort each list, then loop over the sorted lists in parallel and spit out the pairs as I went. That'd be the ideal solution: linear time, small time constant, minimal extra memory. But by my primitive understanding of JAX, looping over a million items just ain't on.
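For reference, the if-looping-were-cheap algorithm might look like this in plain Python/NumPy. This is a hypothetical sketch, and counting_sort_order and pairs_merge are illustrative names. It runs in O(n + max_value + number of pairs) time. However, each step of the merge depends on the previous one, so it does not map naturally onto JAX's data-parallel model.

```python
import numpy as onp


def counting_sort_order(vals, max_value):
    """Indices of vals ordered by value, via a counting sort (O(n + max_value))."""
    counts = onp.bincount(vals, minlength=max_value + 1)
    next_slot = onp.cumsum(counts) - counts   # where each value's run starts
    order = onp.empty(len(vals), dtype=onp.int64)
    for i, v in enumerate(vals):
        order[next_slot[v]] = i
        next_slot[v] += 1
    return order


def pairs_merge(A, B, max_value):
    """Walk both sorted lists in step, emitting the cross product of each matching run."""
    a_ord = counting_sort_order(A, max_value)
    b_ord = counting_sort_order(B, max_value)
    out = []
    i = j = 0
    while i < len(A) and j < len(B):
        va, vb = A[a_ord[i]], B[b_ord[j]]
        if va < vb:
            i += 1
        elif va > vb:
            j += 1
        else:
            i_end, j_end = i, j
            while i_end < len(A) and A[a_ord[i_end]] == va:   # end of A's run of va
                i_end += 1
            while j_end < len(B) and B[b_ord[j_end]] == va:   # end of B's run of va
                j_end += 1
            for p in range(i, i_end):
                for q in range(j, j_end):
                    out.append((int(a_ord[p]), int(b_ord[q])))
            i, j = i_end, j_end
    return out
```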
Instead, what I've got so far is a max_value x max_repeats ragged array that maps each value to the indices with that value, plus a max_value x max_repeats x max_repeats x 2 array of possible pairs. In all its glory,
Now that's torturously long, but it works! And reasonably quickly, if I use hawkinsp's shiny new cumsum implementation. The implicit O(n log n) sticks in my throat a little, but it's a minor issue beside the memory consumption. Building out max_repeats x max_repeats (~100) copies of a million-long list blows through my GPU memory in a way that the ideal if-looping-were-cheap approach wouldn't.
So: is there a better way? I'm not after working code, more pointers like 'think about using scatter and gather in this clever way'. I've read through the docs, but seeing as I've only been playing with JAX since this morning, I've likely misunderstood a lot of stuff.
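To put rough numbers on the memory cost of the two arrays described above, this uses the sizes from the post (max_value ~5 million, max_repeats ~10) and assumes int32 indices:

```python
max_value = 5_000_000   # from the post
max_repeats = 10        # from the post
bytes_per_index = 4     # assumption: int32 indices

ragged_bytes = max_value * max_repeats * bytes_per_index                  # per list
pairs_bytes = max_value * max_repeats * max_repeats * 2 * bytes_per_index
print(ragged_bytes / 1e9, pairs_bytes / 1e9)  # 0.2 4.0  (GB)
```

That's roughly 4 GB for the pairs array alone, before any intermediates. By contrast, the output only needs one entry per actual match.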
If there isn't a better way, welp, time to write some CustomCalls.
As some additional context: this springs from a subproblem I've hit while implementing a fast multipole method in JAX. Fast multipole methods split the universe into cells, and I need to generate the list of all pairs of masses and measurement points that cohabit the same cell.
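This is how the multipole subproblem maps onto the matching problem. Label every mass and every measurement point with the integer id of its cell. The cell ids then play the role of the values in A and B, and matching ids are exactly the cohabiting pairs. The sketch below assumes a single uniform grid of cells_per_side^3 cells over a cubic box; a real fast multipole method works over a hierarchy of levels. cell_ids is a hypothetical helper.

```python
import jax.numpy as jnp

def cell_ids(positions, box_size, cells_per_side):
    """Flattened cell index for each (x, y, z) position in [0, box_size)^3."""
    ijk = jnp.floor(positions / box_size * cells_per_side).astype(jnp.int32)
    ijk = jnp.clip(ijk, 0, cells_per_side - 1)   # guard points sitting exactly on the upper face
    i, j, k = ijk[:, 0], ijk[:, 1], ijk[:, 2]
    return (i * cells_per_side + j) * cells_per_side + k   # ids in [0, cells_per_side**3)

masses = jnp.array([[0.1, 0.1, 0.1], [0.9, 0.9, 0.9]])
points = jnp.array([[0.6, 0.1, 0.1], [0.2, 0.3, 0.4]])
A = cell_ids(masses, 1.0, 2)   # [0, 7]
B = cell_ids(points, 1.0, 2)   # [4, 0]
```

Here max_value is cells_per_side**3 - 1, and mass 0 and point 1 share cell 0.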
That aside, I've gotta say folks: JAX is damned impressive. Today I've gotten a performance bump entirely disproportionate to the effort put in, and that's the highest praise I can think of for a numerics library. Great work all of you!