aesara-devs / aesara

Aesara is a Python library for defining, optimizing, and efficiently evaluating mathematical expressions involving multi-dimensional arrays.
https://aesara.readthedocs.io

Implement JAX conversions for RandomFunction/RandomVariable Ops #146

Closed ferrine closed 3 years ago

ferrine commented 4 years ago

Random functions do not work with the JAX linker:

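import theano
import theano.sandbox.jax_linker
import theano.tensor.shared_randomstreams

# Build a mode that compiles through the JAX linker
opts = theano.gof.Query(include=[None], exclude=["cxx_only", "BlasOpt"])
jax_mode = theano.compile.mode.Mode(theano.sandbox.jax_linker.JAXLinker(), opts)
r = theano.tensor.shared_randomstreams.RandomStreams()
rfj = theano.function([], r.normal(), mode=jax_mode)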

raises

NotImplementedError: No JAX conversion for the given `Op`: RandomFunction{normal}
brandonwillard commented 4 years ago

The RandomFunction Op is in the process of being replaced via #137. Once #137 is merged, we can implement JAX conversions for RandomVariable Ops.

michaelosthege commented 3 years ago

Pinging @ferrine because the RandomVariable was merged.

ferrine commented 3 years ago

Cool

kc611 commented 3 years ago

We can use jax.random as the JAX counterpart of numpy.random.RandomState, which RandomVariable uses internally, but it seems somewhat limited (distributions such as Binomial and HalfNormal are not present). I think most of the work here would be getting the shapes and arguments right.

If no one else is working on this, I'd like to take it on. An example usage would be helpful too: something like https://github.com/pymc-devs/aesara/issues/146#issue-737867223 but with concrete inputs and outputs, so we can check whether the implementation works. (RandomVariable is fairly new, so I'll need to study its interface.)
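To make the comparison concrete, here is a minimal sketch (not Aesara code) of the same draw in both libraries. The `half_normal` helper is hypothetical; it only illustrates that a distribution missing from jax.random can sometimes be built from one that is present. The two libraries use different generators, so the numbers themselves will not match between them.

```python
import jax
import jax.numpy as jnp
import numpy as np

# NumPy: the generator object carries hidden, mutable state.
rng = np.random.RandomState(123)
np_draws = rng.normal(loc=0.0, scale=1.0, size=3)

# JAX: the state is an explicit key that the caller passes in and splits.
key = jax.random.PRNGKey(123)
key, subkey = jax.random.split(key)
jax_draws = jax.random.normal(subkey, shape=(3,))


def half_normal(key, scale, shape):
    """Hypothetical helper: |N(0, 1)| scaled by `scale` is a half-normal draw."""
    return scale * jnp.abs(jax.random.normal(key, shape=shape))


hn_draws = half_normal(subkey, 2.0, (3,))
```

Note that `size=3` in NumPy corresponds to `shape=(3,)` in JAX, which is the kind of argument mapping the conversion would need to handle.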

twiecki commented 3 years ago

@kc611 Would be great to get some help here, it's yours!

kc611 commented 3 years ago

I tried a simple straightforward implementation (without shape handling, to see the arguments and their datatypes) as follows:

@jax_funcify.register(RandomVariable)
def jax_funcify_RandomVariable(op):
    name = op.name
    def random_variable(rng, size, dtype, *args):
        smpl_value = getattr(jax.random, name)(jax.random.PRNGKey(0), size)
        return (dtype, smpl_value)

    return random_variable

And upon running the following code:

import aesara
import aesara.tensor as aet
from aesara.tensor.random.basic import normal
import jax

jax_mode = aesara.compile.Mode(aesara.link.jax.jax_linker.JAXLinker())
a = aet.fscalar()
b = normal(size=3)
c = a + b
f = aesara.function([a], [c], mode=jax_mode)
print(f(2.5))

It throws the following exception:

TypeError: Argument 'RandomState(MT19937)' of type <class 'numpy.random.mtrand.RandomState'> is not a valid JAX type
Apply node that caused the error: Elemwise{Add}[(0, 1)](InplaceDimShuffle{x}.0, normal_rv.out)
Toposort index: 2
Inputs types: [TensorType(float32, (True,)), TensorType(float64, vector)]
Inputs shapes: [(), 'No shapes']
Inputs strides: [(), 'No strides']
Inputs values: [array(2.5, dtype=float32), RandomState(MT19937) at 0x7F0EF518EA40]
Outputs clients: [['output']]

It seems that JAX does not accept rng (of type RandomState) as an input argument. Is there any way around this? Should rng be made a data member of the RandomVariable class rather than an input of the Apply node? (Its only job seems to be to act as a seed generator.)
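The failure can be reproduced outside Aesara. The sketch below is an illustration only; `add_one` is a made-up function, and the exact wording of the error message may differ between JAX versions.

```python
import jax
import numpy as np


@jax.jit
def add_one(rng, x):
    # `rng` is never used; the call fails before this body is traced.
    return x + 1.0


rng = np.random.RandomState(0)

try:
    add_one(rng, 2.5)
    error = None
except TypeError as exc:
    # jax.jit rejects arguments that it cannot treat as arrays.
    error = exc

print(type(error).__name__)
```

So any RandomState that reaches the jitted function as a positional argument triggers the error, whether or not the function uses it.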

kc611 commented 3 years ago

Tagging @brandonwillard, since he'll have the most insight on this.

brandonwillard commented 3 years ago

I'll need to see the full trace to make sense of it.

kc611 commented 3 years ago

Here goes:

Traceback (most recent call last):
  File "/home/kaustubh/Desktop/Codeground/ddsgfdg/newiss.py", line 11, in <module>
    print(f(2.5))
  File "/home/kaustubh/Desktop/Github/aesara/aesara/compile/function/types.py", line 974, in __call__
    self.fn()
  File "/home/kaustubh/Desktop/Github/aesara/aesara/link/utils.py", line 181, in streamline_default_f
    thunk()
  File "/home/kaustubh/Desktop/Github/aesara/aesara/link/jax/jax_linker.py", line 80, in thunk
    outputs = [
  File "/home/kaustubh/Desktop/Github/aesara/aesara/link/jax/jax_linker.py", line 81, in <listcomp>
    jax_impl_jit(*[x[0] for x in thunk_inputs])
jax.traceback_util.FilteredStackTrace: TypeError: Argument 'RandomState(MT19937)' of type <class 'numpy.random.mtrand.RandomState'> is not a valid JAX type

The stack trace above excludes JAX-internal frames.
The following is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/kaustubh/Desktop/Github/aesara/aesara/link/utils.py", line 181, in streamline_default_f
    thunk()
  File "/home/kaustubh/Desktop/Github/aesara/aesara/link/jax/jax_linker.py", line 80, in thunk
    outputs = [
  File "/home/kaustubh/Desktop/Github/aesara/aesara/link/jax/jax_linker.py", line 81, in <listcomp>
    jax_impl_jit(*[x[0] for x in thunk_inputs])
  File "/home/kaustubh/.local/lib/python3.8/site-packages/jax/traceback_util.py", line 139, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/kaustubh/.local/lib/python3.8/site-packages/jax/api.py", line 209, in f_jitted
    _check_arg(arg)
  File "/home/kaustubh/.local/lib/python3.8/site-packages/jax/api.py", line 2117, in _check_arg
    raise TypeError("Argument '{}' of type {} is not a valid JAX type"
TypeError: Argument 'RandomState(MT19937)' of type <class 'numpy.random.mtrand.RandomState'> is not a valid JAX type

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/kaustubh/Desktop/Codeground/ddsgfdg/newiss.py", line 11, in <module>
    print(f(2.5))
  File "/home/kaustubh/Desktop/Github/aesara/aesara/compile/function/types.py", line 974, in __call__
    self.fn()
  File "/home/kaustubh/Desktop/Github/aesara/aesara/link/utils.py", line 185, in streamline_default_f
    raise_with_op(fgraph, node, thunk)
  File "/home/kaustubh/Desktop/Github/aesara/aesara/link/utils.py", line 508, in raise_with_op
    raise exc_value.with_traceback(exc_trace)
  File "/home/kaustubh/Desktop/Github/aesara/aesara/link/utils.py", line 181, in streamline_default_f
    thunk()
  File "/home/kaustubh/Desktop/Github/aesara/aesara/link/jax/jax_linker.py", line 80, in thunk
    outputs = [
  File "/home/kaustubh/Desktop/Github/aesara/aesara/link/jax/jax_linker.py", line 81, in <listcomp>
    jax_impl_jit(*[x[0] for x in thunk_inputs])
  File "/home/kaustubh/.local/lib/python3.8/site-packages/jax/traceback_util.py", line 139, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/kaustubh/.local/lib/python3.8/site-packages/jax/api.py", line 209, in f_jitted
    _check_arg(arg)
  File "/home/kaustubh/.local/lib/python3.8/site-packages/jax/api.py", line 2117, in _check_arg
    raise TypeError("Argument '{}' of type {} is not a valid JAX type"
TypeError: Argument 'RandomState(MT19937)' of type <class 'numpy.random.mtrand.RandomState'> is not a valid JAX type
Apply node that caused the error: Elemwise{Add}[(0, 1)](InplaceDimShuffle{x}.0, normal_rv.out)
Toposort index: 2
Inputs types: [TensorType(float32, (True,)), TensorType(float64, vector)]
Inputs shapes: [(), 'No shapes']
Inputs strides: [(), 'No strides']
Inputs values: [array(2.5, dtype=float32), RandomState(MT19937) at 0x7F23E955FA40]
Outputs clients: [['output']]

HINT: Re-running with most Aesara optimization disabled could give you a back-trace of when this node was created. This can be done with by setting the Aesara flag 'optimizer=fast_compile'. If that does not work, Aesara optimizations can be disabled with 'optimizer=None'.
HINT: Use the Aesara flag 'exception_verbosity=high' for a debugprint and storage map footprint of this apply node.

Note: Leaving rng out of the Apply node's inputs avoids this error, and the function then performs as expected.

brandonwillard commented 3 years ago

OK, I think this particular case may require special handling for a new data type (i.e. RandomStates).

This sort of change/addition might start somewhere in the input-conversion conditions within aesara.link.jax.jax_dispatch.compose_jax_funcs. That loop is basically where "primitive" Types are handled, but, currently, only TensorTypes are explicitly supported. All other types are passed through to JAX as-is.

In other words, when an Aesara Variable has no "owner", there's no graph to climb, and in these cases we could be dealing with something like a straight numpy.ndarray that was represented by a Constant. Constant arrays simply need to be converted into a JAX array via jnp.array, but, in this case, we're dealing with a RandomState-typed object that doesn't have a dtype field/property, so it generates a JAX-able function that returns the object as-is (i.e. here). I'm assuming that this is the source of the error.

We should make generic functions that dispatch on i.type in this branch. The conversion function registered to the Type of a RandomState could then return None and avoid this error; however, there is a corresponding JAX type/parameter for this Aesara/NumPy RandomState parameter, and we should attempt to make a connection between the two.

If we can convert a NumPy RandomState object to its corresponding JAX type, that would be perfect, because our conversion function would only need to do that and everything would work as intended. I doubt that's the case, so we might need to create our own conversion function. Also, those RandomState objects often come from shared Aesara variables, so we're likely to encounter the exact same RandomState multiple times, which means we cannot shallowly copy or recreate those corresponding JAX objects; they need to have the same internal state, or, better yet, be the exact same JAX RNG objects. In other words, we'll probably need to keep a map/cache of RandomState-to-JAX-RNG-objects.

kc611 commented 3 years ago

Thanks for the deep dive. I got the gist of it. The conversion can be done as follows:

NumPy's RandomState wraps the MT19937 Mersenne Twister, which has a period of 2^19937 - 1. Its state can be read with get_state(), which returns a tuple whose second element is an array of 624 32-bit unsigned integers. JAX, meanwhile, uses the PRNG in jax.random to convert a seed into a key, which ensures reproducibility, but the default key holds only two 32-bit unsigned integers, i.e. 2^(32*2) possible keys. So we can generate a PRNGKey by applying some operation to the array returned by get_state(): something that reduces the 624 elements down to two, or we could simply take the first two elements and use them as a seed. This way, as long as NumPy's state stays the same, its JAX conversion will also stay the same (the same seed gives the same key).
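One way the reduction could look, as a sketch only. The function name `key_words_from_state` is hypothetical, and hashing with SHA-256 is a choice made for this example, not something proposed in the thread; it is used so that every state word influences the result, unlike taking the first two elements.

```python
import hashlib

import numpy as np


def key_words_from_state(rng):
    """Reduce a RandomState's MT19937 state to two uint32 words."""
    _, mt_words, pos, _, _ = rng.get_state()
    # Hash the state words plus the position within them, then keep the
    # first 8 bytes of the digest as two 32-bit unsigned integers.
    payload = mt_words.tobytes() + int(pos).to_bytes(4, "little")
    digest = hashlib.sha256(payload).digest()
    return np.frombuffer(digest[:8], dtype=np.uint32)


same_a = key_words_from_state(np.random.RandomState(2021))
same_b = key_words_from_state(np.random.RandomState(2021))
different = key_words_from_state(np.random.RandomState(2022))
```

Two generators in the same state give the same pair of words, and the pair has the same layout (two uint32 values) as a default JAX key.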

kc611 commented 3 years ago

That loop is basically where "primitive" Types are handled, but, currently, only TensorTypes are explicitly supported. All other types are passed through to JAX as-is.

@brandonwillard I suspect that similar as-is passing is being done somewhere else too. I implemented a primitive generic dispatch at the code you pointed out in compose_jax_funcs, as follows:

    for i in out_node.inputs:
        if i in fgraph_inputs:
            # This input is a top-level input (i.e. an input to the
            # `FunctionGraph` in which this `out_node` resides)
            idx = fgraph_inputs.index(i)
            input_f = get_jax_data_func(i.type, i, idx)

        elif i.owner is None:
            # This input is something like a `aesara.graph.basic.Constant`
            input_f = get_jax_data_func(i.type, i)

where the dispatch functions are defined as follows:

@singledispatch
def get_jax_data_func(i_type, i, idx=None):
    def jax_data_func(*inputs, i=i, idx=idx):
        i_dtype = getattr(i, "dtype", None)
        i_data = i.data if idx is None else inputs[idx]

        if i_dtype is None:
            return i_data
        else:
            return jnp.array(i_data, dtype=jnp.dtype(i_dtype))

    return jax_data_func

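@get_jax_data_func.register(RandomStateType)
def get_jax_data_func_RandomState(i_type, i, idx=None):
    def random_state(*inputs, i=i, idx=idx):
        # `i` is an Aesara variable, so fetch the underlying RandomState first
        rng = i.data if idx is None else inputs[idx]
        # get_state()[0] is the string "MT19937"; the state words are in [1]
        return jax.random.PRNGKey(int(rng.get_state()[1][0]))
    return random_state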

This still throws the same exception. It seems that the exception is thrown before these composed functions are actually run (maybe while the inputs are being converted to JAX format?).
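One possible reading of the traceback earlier in the thread: the TypeError is raised from `_check_arg` inside `f_jitted`, i.e. while `jax.jit` validates the arguments it was called with, before any composed function body runs. If that is right, the conversion would have to happen before the jitted call. The sketch below illustrates the idea outside Aesara; `to_jax_input` and `sample_plus` are hypothetical names.

```python
import jax
import numpy as np


def to_jax_input(value):
    """Hypothetical boundary conversion applied before calling a jitted function."""
    if isinstance(value, np.random.RandomState):
        # Assumption: seed a key from the first MT19937 state word.
        return jax.random.PRNGKey(int(value.get_state()[1][0]))
    return value


@jax.jit
def sample_plus(key, a):
    # Inside the jitted function the first argument is already a JAX key.
    return a + jax.random.normal(key, shape=(3,))


raw_inputs = [np.random.RandomState(0), 2.5]
result = sample_plus(*[to_jax_input(v) for v in raw_inputs])
```

Here the RandomState never reaches `jax.jit`, so the argument check passes and the sampling happens inside the compiled function.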

brandonwillard commented 3 years ago

Looks like you have the right idea, but I'll probably need to walk through an explicit implementation in order to help further.

At this point, you should put in a (draft) PR. We can discuss these details there.