jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Efficiently pairing indices #2664

Open andyljones opened 4 years ago

andyljones commented 4 years ago

I've got a question about using JAX efficiently. I think this is the right place to ask, but say if StackOverflow or some forum or chat channel would be better suited.

I want to write a function that takes two arrays of values between 0 and N and outputs all pairs of indices with matching values, a la

A = [3, 1, 1]
B = [3, 3]
pairs = [
  [0, 0], 
  [0, 1]]

Each value can turn up at most a small, known number of times (~10) in each list, while the lists are a ~million items long, and the maximum value is ~five million.

Now the naive way to do this would be to test every value of A against every value of B, and use a mask to pick out the ones I'm after:

from jax import numpy as np  # plain numpy works just as well for this baseline

def naive(A, B):
    mask = (A[:, None] == B[None, :])
    indices = np.stack(np.meshgrid(
        np.arange(len(A)), 
        np.arange(len(B)), indexing='ij'), -1)
    return indices[mask]

Problem is of course that with million-item lists, there are a trillion comparisons to make.

The better way to go about things is to make use of that small-number-of-repeats property. If looping were cheap, I'd count-sort each list and then loop over the sorted lists in parallel and spit out the pairs as I went. That'd be the ideal solution - linear time, small time constant, minimal extra memory - but by my primitive understanding of JAX, looping over a million items just ain't on.
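
For concreteness, here's roughly what that idealized loop would look like in plain Python (just a sketch to pin down the target, not something meant to run under JAX; pairs_serial is an illustrative name):

def pairs_serial(A_sorted, A_order, B_sorted, B_order):
    # illustrative sketch only: A_sorted/B_sorted are the sorted lists,
    # A_order/B_order are the argsorts mapping sorted positions back to original indices
    out, i, j = [], 0, 0
    while i < len(A_sorted) and j < len(B_sorted):
        if A_sorted[i] < B_sorted[j]:
            i += 1
        elif A_sorted[i] > B_sorted[j]:
            j += 1
        else:
            # a run of equal values: emit every cross pair, then step past the run
            v, j0 = A_sorted[i], j
            while i < len(A_sorted) and A_sorted[i] == v:
                j = j0
                while j < len(B_sorted) and B_sorted[j] == v:
                    out.append((A_order[i], B_order[j]))
                    j += 1
                i += 1
    return out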

Instead, here's what I've got so far, in all its glory:

import jax
from jax import numpy as np

def repeats(values):
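    # occurrence index of each element among equal values, e.g.
    # repeats([3, 1, 1]) -> [0, 0, 1] (first 3, first 1, second 1)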
    argsort = np.argsort(values)
    ordered = values[argsort]
    last  = np.concatenate([ordered[1:] != ordered[:-1], np.array([True])])
    cum_replicas = np.concatenate([
        np.asarray([-1,]),
        np.arange(len(values))[last]])
    n_replicas = np.diff(cum_replicas)

    steps = np.ones(len(values)-1, dtype=n_replicas.dtype)
    steps = jax.ops.index_add(steps, last[:-1], -n_replicas[:-1])
    steps = np.concatenate([np.array([0]), steps])
    ordered_index = np.cumsum(steps)

    counts = np.empty_like(ordered_index)
    counts = jax.ops.index_update(counts, argsort, ordered_index)

    return counts

def ragged(values, max_value, max_repeats):
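    # (max_value+1, max_repeats) table whose row v lists the positions at which
    # value v occurs in `values`, padded out with -1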
    reps = repeats(values)
    indices = np.full((max_value+1, max_repeats), -1)
    indices = jax.ops.index_update(indices, (values, reps), np.arange(len(values)))
    return indices

def pairs(A, B, max_value, max_repeats):
    A_ragged = ragged(A, max_value, max_repeats)
    B_ragged = ragged(B, max_value, max_repeats)

    ps = np.stack([
        np.repeat(A_ragged[:, None, :], max_repeats, -2),
        np.repeat(B_ragged[:, :, None], max_repeats, -1)], -1)
    return ps[(ps > -1).all(-1)]

A = np.array([3, 1, 1])
B = np.array([3, 3])

pairs(A, B, 3, 2)
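# should give [[0 0], [0 1]], i.e. the pairs from the example at the top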

Now that's torturously long, but it works! And reasonably quickly if I use @hawkinsp's shiny new cumsum implementation. The implicit O(n log n) sticks in my throat a little, but it's a minor issue next to the memory consumption: building out max_repeats x max_repeats (~100) copies of a million-long list blows through my GPU memory in a way that the ideal if-looping-were-cheap approach wouldn't.

So: is there a better way? I'm not after working code, more pointers like 'think about using scatter and gather in this clever way'. I've read through the docs but seeing as I've only been playing with JAX since this morning I've likely misunderstood a lot of stuff.

If there isn't a better way, welp, time to write some CustomCalls.

As some additional context: this springs from a subproblem I've hit on while implementing a fast-multipole method in JAX. Fast multipole methods split the universe into cells, and I need to generate the list of all pairs of masses and measurement points that cohabit the same cell.

That aside, I've gotta say folks: JAX is damned impressive. Today I've gotten a performance bump entirely disproportionate to the effort put in, and that's the highest praise I can think of for a numerics library. Great work all of you!

jekbradbury commented 4 years ago

If 2 x max_value x max_repeats is an acceptable memory cost, you could follow your existing strategy but build and process the max_value x max_repeats x max_repeats x 2 array incrementally (in, e.g., chunks of 1000 along the max_value axis). This is reasonable if that incremental loop is done from Python but wouldn't work if you want to jit through it (because each chunk would result in an addition to the result list bounded only by chunk_size x max_repeats x max_repeats, which means we'd need the whole array's worth of memory again).
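
As a rough sketch of that chunked loop (illustrative only: pairs_chunked and the chunk_size knob are made-up names, and it reuses ragged() and the jax.numpy-as-np import from the snippet above):

import numpy as onp  # host-side numpy for accumulating the (small) per-chunk results

def pairs_chunked(A, B, max_value, max_repeats, chunk_size=1000):
    A_ragged = ragged(A, max_value, max_repeats)  # (max_value+1, max_repeats)
    B_ragged = ragged(B, max_value, max_repeats)

    out = []
    for start in range(0, max_value + 1, chunk_size):
        a = A_ragged[start:start + chunk_size]
        b = B_ragged[start:start + chunk_size]
        # chunk x max_repeats x max_repeats x 2 block of candidate index pairs
        ps = np.stack([
            np.repeat(a[:, None, :], max_repeats, -2),
            np.repeat(b[:, :, None], max_repeats, -1)], -1)
        # keep only real pairs (the padding is -1) and pull them back to the host
        out.append(onp.asarray(ps[(ps > -1).all(-1)]))
    return onp.concatenate(out)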

This seems like a hard problem for JAX/XLA in part because of sparsity and related dynamism—the output you're looking for is essentially an adjacency list representation of a sparse matrix that's infeasible to represent densely—so I wonder if additional sparse primitives could be helpful here (I can't think of how to apply scatter/gather on their own, though).

The "if-looping-were-cheap" approach unfortunately also seems fundamentally serial/single-threaded, while JAX is only able to accelerate things on GPUs and TPUs by exploiting parallelism in something like a "single program multiple data" way. Maybe it would be helpful to try to think about what kind of algorithm we'd want to use here on a highly multithreaded CPU or even a distributed-memory cluster?

andyljones commented 4 years ago

Thanks very much James! I ended up doing something similar to what you suggest, taking chunks. I chunked over values that share the same repeat count though, which has the advantage of making the inner op much more SIMD-y:

def value_counts(vals, max_value):
    counts = np.zeros((max_value+1,), dtype=np.int32)
    return jax.ops.index_add(counts, vals, 1)

def pairs(A, B, max_value, max_repeats):
    A_order = np.argsort(A)
    A_sorted = A[A_order]
    A_counts = value_counts(A_sorted, max_value)

    B_order = np.argsort(B)
    B_sorted = B[B_order]
    B_counts = value_counts(B_sorted, max_value)

    pairs = []
    for A_count in range(1, max_repeats+1):
        for B_count in range(1, max_repeats+1):
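            # values that occur exactly A_count times in A and B_count times in B;
            # the stack/repeat below crosses every A occurrence with every B occurrence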
            mask = (A_counts == A_count) & (B_counts == B_count)
            s, = mask[A_sorted].nonzero()
            t, = mask[B_sorted].nonzero()

            ps = np.stack([
                    np.repeat(s.reshape(mask.sum(), A_count, 1), B_count, 2),
                    np.repeat(t.reshape(mask.sum(), 1, B_count), A_count, 1)], -1).reshape(-1, 2)
            pairs.append(ps)
    pairs = np.concatenate(pairs)
    pairs = np.stack([A_order[pairs[..., 0]], B_order[pairs[..., 1]]], -1)
    return pairs

The upside is that I've cut the memory usage down to the bare minimum; the downside is I'm making a hundred times as many calls into JAX.

I expected those hundred-times-as-many calls to be hidden by the work being done on the GPU, but in practice it worked out to 1.5s-ish on 1-million-element lists. That seemed kinda slow for an RTX 2080, so I ported the code to Pytorch and got an instant 20x speedup:

import jax 
from jax import numpy as np
import numpy as onp
import torch
import time

class JAX:

    def value_counts(self, vals, max_value):
        counts = np.zeros((max_value+1,), dtype=np.int32)
        return jax.ops.index_add(counts, vals, 1)

    def pairs(self, A, B, max_value, max_repeats):
        A, B = np.asarray(A), np.asarray(B)

        A_order = np.argsort(A)
        A_sorted = A[A_order]
        A_counts = self.value_counts(A_sorted, max_value)

        B_order = np.argsort(B)
        B_sorted = B[B_order]
        B_counts = self.value_counts(B_sorted, max_value)

        pairs = []
        for A_count in range(1, max_repeats+1):
            for B_count in range(1, max_repeats+1):
                mask = (A_counts == A_count) & (B_counts == B_count)
                s, = mask[A_sorted].nonzero()
                t, = mask[B_sorted].nonzero()

                ps = np.stack([
                        np.repeat(s.reshape(mask.sum(), A_count, 1), B_count, 2),
                        np.repeat(t.reshape(mask.sum(), 1, B_count), A_count, 1)], -1).reshape(-1, 2)
                pairs.append(ps)
        pairs = np.concatenate(pairs)
        pairs = np.stack([A_order[pairs[..., 0]], B_order[pairs[..., 1]]], -1)
        return onp.asarray(pairs)

class Torch:

    def value_counts(self, vals, max_value):
        ones = vals.new_ones(len(vals), dtype=torch.int32)
        counts = vals.new_zeros((max_value+1,), dtype=torch.int32)
        counts.index_add_(0, vals, ones)
        return counts

    def pairs(self, A, B, max_value, max_repeats):
        A, B = torch.tensor(A).cuda(), torch.tensor(B).cuda()

        A_order = torch.argsort(A)
        A_sorted = A[A_order]
        A_counts = self.value_counts(A_sorted, max_value)

        B_order = torch.argsort(B)
        B_sorted = B[B_order]
        B_counts = self.value_counts(B_sorted, max_value)

        pairs = []
        for A_count in range(1, max_repeats+1):
            for B_count in range(1, max_repeats+1):
                mask = (A_counts == A_count) & (B_counts == B_count)
                s = mask[A_sorted].nonzero()
                t = mask[B_sorted].nonzero()

                ps = torch.stack([
                        torch.repeat_interleave(s.reshape(mask.sum(), A_count, 1), B_count, 2),
                        torch.repeat_interleave(t.reshape(mask.sum(), 1, B_count), A_count, 1)], -1).reshape(-1, 2)
                pairs.append(ps)
        pairs = torch.cat(pairs)
        pairs = torch.stack([A_order[pairs[..., 0]], B_order[pairs[..., 1]]], -1)
        return pairs.cpu().numpy()

def random_problem():
    max_value = int(1e6)
    length = int(1e6)
    onp.random.seed(1)
    A = onp.random.choice(onp.arange(max_value), length)
    B = onp.random.choice(onp.arange(max_value), length)
    max_repeats = max(onp.bincount(A).max(), onp.bincount(B).max())
    return A, B, max_value, max_repeats

def profile(lib=None, repeats=5):
    if lib is None:
        profile(Torch())
        profile(JAX())
        return

    name = type(lib).__name__
    prob = random_problem()

    print(f'{name}: warming up')
    lib.pairs(*prob)

    print(f'{name}: benchmarking')
    start = time.time()
    for _ in range(repeats):
        lib.pairs(*prob)
    end = time.time()
    print(f'{name}: {(end - start)/repeats:.2f}s per')

profile()
# Torch: warming up
# Torch: benchmarking
# Torch: 0.06s per
# JAX: warming up
# JAX: benchmarking
# JAX: 1.19s per

Now I think I've avoided the usual pitfall of benchmarking-without-moving-the-result-back-to-the-CPU, but this seems extreme. Setting CUDA_LAUNCH_BLOCKING=1 and running snakeviz on the warmed-up libs, I get this breakdown for Torch:

(snakeviz profile screenshot for Torch)

and this for JAX:

(snakeviz profile screenshot for JAX)

It looks like Torch dispatches directly into fairly chunky hand-rolled C code, while JAX does a lot of compilation on the fly. This is exactly the kind of work jax.jit should hoist to the top, but trying to jit it raises an error about the mask indices not being static.
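
For reference, the jit attempt looks roughly like this (illustrative only; fast_pairs is just a name, and marking the Python-int arguments static is the obvious first thing to try):

fast_pairs = jax.jit(pairs, static_argnums=(2, 3))  # max_value, max_repeats are Python ints
fast_pairs(A, B, max_value, max_repeats)
# tracing still fails: .nonzero() and the reshape(mask.sum(), ...) calls have output
# shapes that depend on the data, which jit's fixed-shape (ShapedArray) abstraction
# can't express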

Now presumably the indices need to be static because JAX's JIT'r expects arrays of fixed dimension. There's an offhand comment in the docs about UnshapedArrays, but I can't find a description of what UnshapedArrays are, or how to set the level of abstraction, or how well levels of abstraction other than ShapedArrays are supported.

I think this might play into your comments on sparse matrices? My intuition is that sparse work involves a lot of arrays of wildly varying dimension, which are more suitably thought of as Unshaped than Shaped?

That aside, two concrete questions: how can I finagle JAX's JIT'r into getting Torch-level performance on this problem? Or am I fundamentally using a screwdriver on a nail here?

mattjj commented 4 years ago

Thanks so much for the kind words and the excellent writeup, @andyljones ! And indeed this is the perfect place to bring up these issues so we can think about them together.

I'm looking forward to digging in as well when I can! For now, I just wanted to say thanks for raising this :)