google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

none-filled arrays #1580

Open jvdillon opened 4 years ago

jvdillon commented 4 years ago

import jax.numpy as jnp
import numpy as np

jnp.array([None]*3)
# ==> TypeError: Unexpected input type for array: <class 'NoneType'>

# Yet:
np.array([None]*3)
# ==> array([None, None, None], dtype=object)

Is this expected?

hawkinsp commented 4 years ago

Well, the error could perhaps use some work. But yes, it's working as intended: you can't store None values in jax arrays, and object arrays aren't supported.

Can you say a bit more about what you are trying to do?

jvdillon commented 4 years ago

It's rather complex, but I'd like to create a JAX array which has list semantics, and this is how I'd have preferred to initialize the object.

mattjj commented 4 years ago

I think it's unlikely that you'll be able to do that with JAX (EDIT: i.e., make a JAX array have list semantics, like a potentially heterogeneous collection without even a consistent dtype). If you want a list, why not just use a Python list? If it's something you want to stage out entirely to XLA, then you can't use list-like things, because XLA doesn't have recursive data types (or, more generally, programs with a priori unbounded memory usage).
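As an illustrative aside: a plain Python list of JAX arrays is already a pytree, so JAX's tree utilities can map a function over it even when the leaf shapes differ. A minimal sketch (the data here is made up):

import jax
import jax.numpy as jnp

# A Python list of JAX arrays is a pytree: tree_map applies a function
# to every leaf, heterogeneous shapes included.
xs = [jnp.zeros(3), jnp.ones((2, 2)), jnp.arange(5)]
doubled = jax.tree_util.tree_map(lambda x: 2 * x, xs)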

Baschdl commented 4 years ago

The current error message when trying to convert a standard NumPy object array to a DeviceArray is a bit cryptic. I think TypeError: JAX only supports number and bool dtypes, got dtype <class 'object'> would also fit for this case.

import jax.numpy as np
import numpy as np_standard

np.array([0, 1], dtype=object)
# TypeError: JAX only supports number and bool dtypes, got dtype <class 'object'>

np.array(np_standard.array([0, 1], dtype=object))
# RuntimeError: Invalid argument: Unknown NumPy type O size 8

Baschdl commented 4 years ago

Is it possible to implement an easier version, like an array of heterogeneously sized arrays with the same dtype, or would we always end up extending the smaller arrays with placeholders? Something like np.array([[0], [0, 1], [0, 0, 1]]).
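One common workaround, sketched below under the assumption that all rows share a dtype, is exactly the placeholder route: pad every row to the longest length and keep a boolean mask of the real entries.

import jax.numpy as jnp

rows = [[0], [0, 1], [0, 0, 1]]
width = max(len(r) for r in rows)
# Extend each row with a placeholder value and track which entries are real.
padded = jnp.array([r + [0] * (width - len(r)) for r in rows])
mask = jnp.array([[True] * len(r) + [False] * (width - len(r)) for r in rows])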

Baschdl commented 4 years ago

A vmap over an object array throws an even more cryptic error:

...
~/.local/lib/python3.7/site-packages/jax/core.py in __init__(self, val, weak_type)
    896     # Note: canonicalized self.dtype doesn't necessarily match self.val
    897     self.val = val
--> 898     assert self.dtype != onp.dtype('O')
    899 
    900   def __eq__(self, other):

AssertionError: 

Baschdl commented 4 years ago

> Is it possible to implement an easier version, like an array of heterogeneously sized arrays with the same dtype, or would we always end up extending the smaller arrays with placeholders? Something like np.array([[0], [0, 1], [0, 0, 1]]).

I came to a point where this is unfortunately a big bottleneck. I can implement all my computations with fixed-size arrays, but in the end I need to aggregate them in some way to use them further. I currently do something like

indices = np.arange(len(data))
return [indices[boolean_mask] for boolean_mask in boolean_masks]

to get the indices of the data points which fulfill some requirement I calculated with pure JAX code. This aggregation takes ten times longer than the actual computation, which benefits greatly from JAX. Did you intend some other way to aggregate results like this, or is this the only way? This would benefit greatly from np.array(x, dtype=object), if my aggregation is the only way.


I can speed it up a bit with multiprocessing

from itertools import repeat
from multiprocessing import Pool

def apply_bitmask(array, boolean_mask):
    return array[boolean_mask]

with Pool(3) as p:
    neighbors = p.starmap(apply_bitmask, zip(repeat(indices), boolean_masks))

but I think there could be a better internal solution.
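As a hedged aside: newer JAX versions let jnp.nonzero take static size and fill_value arguments, which makes this kind of aggregation jittable and vmappable when an upper bound on the number of matches per mask is known. A sketch with made-up data:

import jax
import jax.numpy as jnp

boolean_masks = jnp.array([[True, False, True],
                           [False, True, True]])
max_hits = 2  # assumed known upper bound on True entries per mask

# With a static size, nonzero has a fixed output shape; slots beyond the
# actual number of hits are filled with -1.
def indices_of(mask):
    return jnp.nonzero(mask, size=max_hits, fill_value=-1)[0]

neighbors = jax.vmap(indices_of)(boolean_masks)
# ==> [[0 2]
#      [1 2]]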

jekbradbury commented 4 years ago

One approach that might work for your use case would be to convert your arrays to standard NumPy arrays after JAX processing, then do the dynamic/heterogeneous parts of your algorithm in NumPy or something else. The XLA compiler that powers JAX is based around operations on rectangular, dense arrays (even outside of XLA, those are much more likely to benefit significantly from GPU/TPU acceleration than other kinds of data structures).
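A sketch of that hand-off with made-up data: run the dense part in JAX, copy the result back to host NumPy, and do the ragged indexing there.

import numpy as onp  # plain NumPy, for the ragged part
import jax.numpy as jnp

scores = jnp.arange(6.0)  # stand-in for a dense JAX computation
masks = [onp.array([True, False, True, False, False, True]),
         onp.array([False, True, True, True, False, False])]

host = onp.asarray(scores)          # copy the result back to host NumPy
ragged = [host[m] for m in masks]   # heterogeneous-length aggregation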