Random functions do not work with jax linker
Closed by ferrine 3 years ago
The RandomFunction Op is in the process of being replaced via #137. Once #137 is merged, we can implement JAX conversions for RandomVariable Ops.
Pinging @ferrine because the RandomVariable was merged.
Cool
We can use jax.random as the JAX counterpart for numpy.random.RandomState, which RandomVariable uses internally, but it seems somewhat limited (distributions like Binomial and HalfNormal are not present). I think most of the work here would be getting the shapes and arguments right.
If anyone isn't working on this, I'd like to take it on.
An example usage would be helpful too, something like https://github.com/pymc-devs/aesara/issues/146#issue-737867223, but with concrete inputs and outputs to see whether the implementation works. (The RandomVariable is fairly new, so I'll need to study its interface.)
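For context, the basic difference is that numpy.random.RandomState carries mutable state that advances on each call, while jax.random is stateless and requires an explicit PRNGKey for every draw. A minimal sketch:
import numpy as np
import jax

# NumPy: a stateful generator; successive calls advance its internal state.
rng = np.random.RandomState(42)
np_samples = rng.normal(size=3)

# JAX: a stateless PRNG; the key is passed explicitly and split before reuse.
key = jax.random.PRNGKey(42)
key, subkey = jax.random.split(key)
jax_samples = jax.random.normal(subkey, shape=(3,))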
@kc611 Would be great to get some help here, it's yours!
I tried a simple, straightforward implementation (without shape handling, just to see the arguments and their datatypes) as follows:
@jax_funcify.register(RandomVariable)
def jax_funcify_RandomVariable(op):
    name = op.name

    def random_variable(rng, size, dtype, *args):
        # Draw from the matching sampler in jax.random, using a fixed key for now
        smpl_value = getattr(jax.random, name)(jax.random.PRNGKey(0), size)
        return (dtype, smpl_value)

    return random_variable
And upon running the following code:
import aesara
import aesara.tensor as aet
from aesara.tensor.random.basic import normal
import jax
jax_mode = aesara.compile.Mode(aesara.link.jax.jax_linker.JAXLinker())
a = aet.fscalar()
b = normal(size=3)
c = a + b
f = aesara.function([a], [c], mode=jax_mode)
print(f(2.5))
It throws the following exception:
TypeError: Argument 'RandomState(MT19937)' of type <class 'numpy.random.mtrand.RandomState'> is not a valid JAX type
Apply node that caused the error: Elemwise{Add}[(0, 1)](InplaceDimShuffle{x}.0, normal_rv.out)
Toposort index: 2
Inputs types: [TensorType(float32, (True,)), TensorType(float64, vector)]
Inputs shapes: [(), 'No shapes']
Inputs strides: [(), 'No strides']
Inputs values: [array(2.5, dtype=float32), RandomState(MT19937) at 0x7F0EF518EA40]
Outputs clients: [['output']]
It seems that JAX is not accepting rng (of type RandomState) as an input argument. Is there any way around this? Should rng be made a class data member of RandomVariable rather than being one of the Apply node's inputs? (Its only job seems to be acting as a seed generator.)
Tagging @brandonwillard, since he'll be the person with the most insight on this.
I'll need to see the full trace to make sense of it.
Here goes:
Traceback (most recent call last):
File "/home/kaustubh/Desktop/Codeground/ddsgfdg/newiss.py", line 11, in <module>
print(f(2.5))
File "/home/kaustubh/Desktop/Github/aesara/aesara/compile/function/types.py", line 974, in __call__
self.fn()
File "/home/kaustubh/Desktop/Github/aesara/aesara/link/utils.py", line 181, in streamline_default_f
thunk()
File "/home/kaustubh/Desktop/Github/aesara/aesara/link/jax/jax_linker.py", line 80, in thunk
outputs = [
File "/home/kaustubh/Desktop/Github/aesara/aesara/link/jax/jax_linker.py", line 81, in <listcomp>
jax_impl_jit(*[x[0] for x in thunk_inputs])
jax.traceback_util.FilteredStackTrace: TypeError: Argument 'RandomState(MT19937)' of type <class 'numpy.random.mtrand.RandomState'> is not a valid JAX type
The stack trace above excludes JAX-internal frames.
The following is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/kaustubh/Desktop/Github/aesara/aesara/link/utils.py", line 181, in streamline_default_f
thunk()
File "/home/kaustubh/Desktop/Github/aesara/aesara/link/jax/jax_linker.py", line 80, in thunk
outputs = [
File "/home/kaustubh/Desktop/Github/aesara/aesara/link/jax/jax_linker.py", line 81, in <listcomp>
jax_impl_jit(*[x[0] for x in thunk_inputs])
File "/home/kaustubh/.local/lib/python3.8/site-packages/jax/traceback_util.py", line 139, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/kaustubh/.local/lib/python3.8/site-packages/jax/api.py", line 209, in f_jitted
_check_arg(arg)
File "/home/kaustubh/.local/lib/python3.8/site-packages/jax/api.py", line 2117, in _check_arg
raise TypeError("Argument '{}' of type {} is not a valid JAX type"
TypeError: Argument 'RandomState(MT19937)' of type <class 'numpy.random.mtrand.RandomState'> is not a valid JAX type
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/home/kaustubh/Desktop/Codeground/ddsgfdg/newiss.py", line 11, in <module>
print(f(2.5))
File "/home/kaustubh/Desktop/Github/aesara/aesara/compile/function/types.py", line 974, in __call__
self.fn()
File "/home/kaustubh/Desktop/Github/aesara/aesara/link/utils.py", line 185, in streamline_default_f
raise_with_op(fgraph, node, thunk)
File "/home/kaustubh/Desktop/Github/aesara/aesara/link/utils.py", line 508, in raise_with_op
raise exc_value.with_traceback(exc_trace)
File "/home/kaustubh/Desktop/Github/aesara/aesara/link/utils.py", line 181, in streamline_default_f
thunk()
File "/home/kaustubh/Desktop/Github/aesara/aesara/link/jax/jax_linker.py", line 80, in thunk
outputs = [
File "/home/kaustubh/Desktop/Github/aesara/aesara/link/jax/jax_linker.py", line 81, in <listcomp>
jax_impl_jit(*[x[0] for x in thunk_inputs])
File "/home/kaustubh/.local/lib/python3.8/site-packages/jax/traceback_util.py", line 139, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/kaustubh/.local/lib/python3.8/site-packages/jax/api.py", line 209, in f_jitted
_check_arg(arg)
File "/home/kaustubh/.local/lib/python3.8/site-packages/jax/api.py", line 2117, in _check_arg
raise TypeError("Argument '{}' of type {} is not a valid JAX type"
TypeError: Argument 'RandomState(MT19937)' of type <class 'numpy.random.mtrand.RandomState'> is not a valid JAX type
Apply node that caused the error: Elemwise{Add}[(0, 1)](InplaceDimShuffle{x}.0, normal_rv.out)
Toposort index: 2
Inputs types: [TensorType(float32, (True,)), TensorType(float64, vector)]
Inputs shapes: [(), 'No shapes']
Inputs strides: [(), 'No strides']
Inputs values: [array(2.5, dtype=float32), RandomState(MT19937) at 0x7F23E955FA40]
Outputs clients: [['output']]
HINT: Re-running with most Aesara optimization disabled could give you a back-trace of when this node was created. This can be done with by setting the Aesara flag 'optimizer=fast_compile'. If that does not work, Aesara optimizations can be disabled with 'optimizer=None'.
HINT: Use the Aesara flag 'exception_verbosity=high' for a debugprint and storage map footprint of this apply node.
Note: not having the rng in the inputs of the Apply node doesn't cause this error, and the function performs as expected.
OK, I think this particular case may require special handling for a new data type (i.e. RandomStates).
This sort of change/addition might start somewhere in the input-conversion conditions within aesara.link.jax.jax_dispatch.compose_jax_funcs. That loop is basically where "primitive" Types are handled but, currently, only TensorTypes are explicitly supported. All other types are passed through to JAX as-is.
In other words, when an Aesara Variable has no "owner", there's no graph to climb and, in these cases, we could be dealing with something like a plain numpy.ndarray that was represented by a Constant. Constant arrays simply need to be converted into a JAX array via jnp.array but, in this case, we're dealing with a RandomState-typed object that doesn't have a dtype field/property, so it generates a JAX-able function that returns the object as-is (i.e. here). I'm assuming that this is the source of the error.
We should make generic functions that dispatch on i.type in this branch. The conversion function registered to the Type of a RandomState could then return None and avoid this error; however, there is a corresponding JAX type/parameter for this Aesara/NumPy RandomState parameter, and we should attempt to make a connection between the two.
If we can convert a NumPy RandomState object to its corresponding JAX type, that would be perfect, because our conversion function would only need to do that and everything would work as intended. I doubt that's the case, so we might need to create our own conversion function. Also, those RandomState objects often come from shared Aesara variables, so we're likely to encounter the exact same RandomState multiple times, which means we cannot shallowly copy or recreate the corresponding JAX objects; they need to have the same internal state or, better yet, be the exact same JAX RNG objects. In other words, we'll probably need to keep a map/cache of RandomState-to-JAX-RNG objects.
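To make the caching point concrete, here is a rough sketch of one possible approach (the names and the seed derivation below are hypothetical, not existing Aesara APIs): key the cache on the identity of the RandomState so that a shared RandomState encountered multiple times always maps to the same JAX key object.
import jax

# Hypothetical module-level cache: id(RandomState) -> JAX PRNG key.
_rng_to_jax_key = {}


def get_or_create_jax_key(rng_state):
    cache_key = id(rng_state)
    if cache_key not in _rng_to_jax_key:
        # However the RandomState-to-seed conversion ends up being defined,
        # it should only run once per distinct RandomState object; here we
        # simply use the first word of the MT19937 state as a stand-in seed.
        seed = int(rng_state.get_state()[1][0])
        _rng_to_jax_key[cache_key] = jax.random.PRNGKey(seed)
    return _rng_to_jax_key[cache_key]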
Thanks for the deep dive, I got the gist of it. The conversion can be done as follows:
NumPy's RandomState allows for roughly 2^(32*623) - 1 possible states, which can be accessed via get_state() as an array of 624 32-bit unsigned integers. Meanwhile, JAX uses the PRNGs in jax.random to convert a seed into a key and ensures reproducibility, but it only allows for 2^(32*2) - 1 possible states. So we can generate a PRNGKey by applying some operation to the array returned by get_state() (something that reduces the state words down to only two, or we could simply take the first two elements and use them as a seed). This way, if NumPy's state stays the same, we can ensure that its JAX conversion will also be the same (the same seed gives the same key for the PRNGs).
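A minimal sketch of that conversion, assuming we simply fold the MT19937 state words down to a single 32-bit seed (the exact reduction is a design choice; taking the first one or two words would also work):
import numpy as np
import jax


def random_state_to_prng_key(rng_state):
    # get_state() returns ('MT19937', key_array, pos, has_gauss, cached_gaussian);
    # key_array holds the generator's 32-bit state words.
    _, state_words, *_ = rng_state.get_state()
    # Fold the state words into one 32-bit seed: same state -> same seed -> same key.
    seed = int(np.bitwise_xor.reduce(state_words))
    return jax.random.PRNGKey(seed)


rng = np.random.RandomState(123)
key = random_state_to_prng_key(rng)  # deterministic for a given RandomState state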
That loop is basically where "primitive" Types are handled, but, currently, only TensorTypes are explicitly supported. All other types are passed through to JAX as-is.
@brandonwillard I suspect that a similar as-is pass-through is happening somewhere else too. I implemented a primitive generic dispatch at the spot you pointed out in compose_jax_funcs as follows:
for i in out_node.inputs:
    if i in fgraph_inputs:
        # This input is a top-level input (i.e. an input to the
        # `FunctionGraph` in which this `out_node` resides)
        idx = fgraph_inputs.index(i)
        input_f = get_jax_data_func(i.type, i, idx)
    elif i.owner is None:
        # This input is something like a `aesara.graph.basic.Constant`
        input_f = get_jax_data_func(i.type, i)
where the dispatch functions are defined as follows:
@singledispatch
def get_jax_data_func(i_type, i, idx=None):
    def jax_data_func(*inputs, i=i, idx=idx):
        i_dtype = getattr(i, "dtype", None)
        i_data = i.data if idx is None else inputs[idx]
        if i_dtype is None:
            return i_data
        else:
            return jnp.array(i_data, dtype=jnp.dtype(i_dtype))

    return jax_data_func


@get_jax_data_func.register(RandomStateType)
def get_jax_data_func_RandomState(i_type, i, idx=None):
    def random_state(*inputs, i=i, idx=idx):
        return jax.random.PRNGKey(i.get_state()[0])

    return random_state
This still throws the same exception. It seems that the exception is raised before these composed functions are actually run. (Maybe while converting the inputs to JAX format?)
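That timing matches jax.jit's argument validation: a jitted function checks its arguments before the traced body ever runs, and passing a raw RandomState should trigger essentially the same TypeError outside of Aesara, e.g.:
import numpy as np
import jax


@jax.jit
def passthrough(x):
    return x


# Fails during argument checking, before the function body is traced or run:
# TypeError: Argument ... of type <class 'numpy.random.mtrand.RandomState'>
# is not a valid JAX type
passthrough(np.random.RandomState(0))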
Looks like you have the right idea, but I'll probably need to walk through an explicit implementation in order to help further.
At this point, you should put in a (draft) PR. We can discuss these details there.