Open · brandonwillard opened 2 years ago
Hi @brandonwillard I am happy to work on this task :D
That's great @danhphan, thanks!
> Hi @brandonwillard I am happy to work on this task :D
Sounds good! My recommendation/approach is to gradually write the implementation in Numba only, then, once you've arrived at a working Numba implementation, start working from that in Aesara.
The real trick to making these conversions work is to use the information that's technically "constant" at the time the Numba implementation needs to be constructed by Aesara. That's how we get past most of the missing features in Numba (e.g. unsupported NumPy arguments like `size` in `numpy.random` functions) and unsupported Python functionality (e.g. construction and manipulation of `tuple`s). Basically, we're doing something akin to macro programming that makes good use of "compile-time" constants/literals.
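To make that concrete, here is a minimal, hypothetical sketch (not Aesara code) of the "compile-time constant" idea: a plain Python value that is fixed when the Numba function is constructed can be closed over, and Numba will specialize the compiled function for that value.

```python
import numba
import numpy as np


def make_take_first_n(n):
    # `n` is an ordinary Python int that is fixed ("constant") when this
    # constructor runs, so the njit-compiled closure is specialized for it.
    @numba.njit
    def take_first_n(x):
        return x[:n]

    return take_first_n


take_first_2 = make_take_first_n(2)
print(take_first_2(np.arange(5)))  # [0 1]
```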
Hi @brandonwillard, thank you for the suggestion. I am playing with `numba`, and will implement the Numba version of `MvNormalRV.rng_fn` first :D
Thank you all for working on this. It's been a real problem for me. The other distribution that I have been using is PolyaGamma, and I am getting the same error. The solution will probably differ from MvNormal's, as PolyaGamma is not a 'usual' distribution. In any case, I look forward to your solutions. Here is the final section of the error message:
```
    140 @singledispatch
    141 def jax_funcify(op, node=None, storage_map=None, **kwargs):
    142     """Create a JAX compatible function from an Aesara `Op`."""
--> 143     raise NotImplementedError(f"No JAX conversion for the given `Op`: {op}")

NotImplementedError: No JAX conversion for the given `Op`: _PolyaGammaLogDistFunc{get_pdf=True}
```
The implementation discussed here is for Numba, not JAX, so solving this issue won't help. A separate issue will need to be opened for `MvNormal` support in JAX. Likewise, JAX support for `_PolyaGammaLogDistFunc` will need to be added in the codebase that defines that `Op` (i.e. PyMC).
Hi @brandonwillard, I have spent this week getting my head around Python closures and decorators (`singledispatch`, parameterized, ...). Now I have the gist of your suggestion about the macro-like approach :D

I am testing and debugging the code of `numba_funcify_CategoricalRV` in #841.

My understanding is that `numba_funcify_CategoricalRV` is a specialisation of the generic `numba_funcify` function for when `op` equals `aer.CategoricalRV`. And I need to write a similar function, `numba_funcify_MvNormalRV`, for the case where `op` equals `aer.MvNormalRV`.

Also, can we consider `numba_funcify_CategoricalRV` a decorator? (as it returns a `sampler` function, if the `op` input here is considered a function?) Then `sampler` would be a decorated function, used to decorate the `@classmethod` `CategoricalRV.rng_fn`.

Besides, I am not sure why we need the `dtype` argument in the `sampler` function. `dtype` does not seem to be used inside this function (the free variable `out_dtype` is used instead), and `dtype` is also not an argument of the `CategoricalRV.rng_fn` function.

Thank you
> My understanding is that `numba_funcify_CategoricalRV` is a specialisation of the generic `numba_funcify` function for when `op` equals `aer.CategoricalRV`. And I need to write a similar function, `numba_funcify_MvNormalRV`, for the case where `op` equals `aer.MvNormalRV`.
Yes, exactly.
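For anyone following along, here is a self-contained toy version (not Aesara's actual code) of that `singledispatch` pattern: the generic function dispatches on the type of the `Op` it receives, and each registered specialization handles one `Op` class.

```python
from functools import singledispatch


# Stand-ins for the real Op classes
class CategoricalRV: ...
class MvNormalRV: ...


@singledispatch
def funcify(op, **kwargs):
    # Fallback when no specialization is registered for type(op)
    raise NotImplementedError(f"No conversion for the given Op: {op}")


@funcify.register(CategoricalRV)
def funcify_CategoricalRV(op, **kwargs):
    return "categorical sampler"


@funcify.register(MvNormalRV)
def funcify_MvNormalRV(op, **kwargs):
    return "mvnormal sampler"


print(funcify(MvNormalRV()))  # mvnormal sampler
```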
> Also, can we consider `numba_funcify_CategoricalRV` a decorator? (as it returns a `sampler` function, if the `op` input here is considered a function?)
It's a function constructor—in the most basic sense—but, while it is a function that returns another function, the way it's used doesn't really fit the decorator designation (i.e. it doesn't take functions as input nor "decorate" existing functions).
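A toy contrast, for illustration only:

```python
def make_adder(n):  # function constructor: its input is data, not a function
    def adder(x):
        return x + n
    return adder


def logged(fn):  # decorator: its input *is* a function
    def wrapper(*args):
        print(f"calling {fn.__name__}{args}")
        return fn(*args)
    return wrapper


add2 = make_adder(2)


@logged
def square(x):
    return x * x


# square(4) prints "calling square(4,)" before the results "5 16" appear
print(add2(3), square(4))
```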
> Besides, I am not sure why we need the `dtype` argument in the `sampler` function. `dtype` does not seem to be used inside this function (the free variable `out_dtype` is used instead), and `dtype` is also not an argument of the `CategoricalRV.rng_fn` function.
It may not be explicitly used, but the signature (i.e. inputs and outputs) of the returned function must match the inputs and outputs of the `Apply` nodes constructed by the relevant `Op`'s `Op.make_node`. In this case, that `Op` is `RandomVariable`, and `RandomVariable.make_node` returns `Apply`s with inputs corresponding to `rng`, `size`, `dtype`, and the distribution arguments.
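In other words, the returned function has to look something like the following illustrative, non-Numba sketch (with a single made-up distribution parameter `p`; this is not actual Aesara code):

```python
import numpy as np


def sampler(rng, size, dtype, p):
    # `dtype` must appear in the signature to match the Apply node's
    # inputs, even if the body only uses a closed-over `out_dtype`.
    draws = rng.binomial(1, p, size=tuple(size))
    # The outputs likewise mirror the Apply node's outputs: (rng, draws)
    return (rng, draws)


rng = np.random.default_rng(123)
_, draws = sampler(rng, np.array([3]), np.dtype("int64"), 0.5)
print(draws)  # e.g. [1 0 1]
```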
Hi @brandonwillard, many thanks for the clarifications.

Sorry, I am quite slow and still in the learning process. Last week, I dug deep into the Aesara graph structures and tried to understand the relation between `Apply`, `Op`, and `Variable`.

Previously I mistakenly thought that `RandomVariable` was some type of `Variable`, similar to `TensorVariable` and `SharedVariable`. It is clear to me now that `RandomVariable` is actually a subclass of `Op`, which can make an `Apply` node from inputs of `Variable`s. An `Op` can also generate output `Variable`s with its `perform()` function. (I think `RandomVariable` should be named `RandomVariableOp` to make this clearer, by the way.)

Besides, from the PyMC code, I can see that the `Distribution` class has an `rv_op` instance of `RandomVariable`, which then forms many kinds of distributions in PyMC. The `rng_fn` function in `RandomVariable` and its subclasses is used for sampling purposes, and we want to make it faster by using a backend like JAX or Numba.

I think I need some more time to learn how `Op` and `RandomVariable` really work, and to address several "good first issues" to learn more about the Aesara codebase first. I will come back to this issue later. In the meantime, if anyone is interested in this issue, please feel free to go ahead.

Thank you.
> Previously I mistakenly thought that `RandomVariable` was some type of `Variable`, similar to `TensorVariable` and `SharedVariable`. It is clear to me now that `RandomVariable` is actually a subclass of `Op`, which can make an `Apply` node from inputs of `Variable`s.
Yes, exactly; the `Variable` part of the name reflects the "variable" in "random variables" from statistics/probability theory, which itself can be confusing.
> Besides, from the PyMC code, I can see that the `Distribution` class has an `rv_op` instance of `RandomVariable`, which then forms many kinds of distributions in PyMC. The `rng_fn` function in `RandomVariable` and its subclasses is used for sampling purposes, and we want to make it faster by using a backend like JAX or Numba.
Just to be clear, there's no need to be concerned with the PyMC side of things; a Numba implementation of `MvNormalRV` only needs to adhere to the expectations of Aesara and Numba for it to work with PyMC.
> I think I need some more time to learn how `Op` and `RandomVariable` really work, and to address several "good first issues" to learn more about the Aesara codebase first. I will come back to this issue later. In the meantime, if anyone is interested in this issue, please feel free to go ahead.
Perhaps it would help to clarify what it is that you're trying to understand, because most of the things you've mentioned here shouldn't directly affect a Numba implementation of `MvNormalRV`.

All that's necessary for such an implementation is a Numba port of `MvNormalRV.rng_fn`.

You can use `numba_funcify_DirichletRV` as a template/example; the only thing that changes substantially is that the distribution parameters are no longer `alphas`, but `mean` and `cov`. For example, `node.inputs[3:]` now refer to the (symbolic) parameters of the `MvNormalRV` class, which represent the mean and covariance. Likewise, the returned function is also expected to take the same mean and covariance parameters. Here's a more explicit template for an `MvNormalRV` implementation; it's just `numba_funcify_DirichletRV` with the distribution parameters updated (and a few things renamed):
```python
@numba_funcify.register(aer.MvNormalRV)
def numba_funcify_MvNormalRV(op, node, **kwargs):
    out_dtype = node.outputs[1].type.numpy_dtype
    mean_ndim = node.inputs[3].type.ndim
    neg_ind_shape_len = -mean_ndim + 1
    size_len = int(get_vector_length(node.inputs[1]))

    if mean_ndim > 1:

        @numba_basic.numba_njit
        def mvnormal_rv(rng, size, dtype, mean, cov):
            # TODO: This needs a Numba port.  The parameter ndims
            # (i.e. `MvNormalRV.ndims_params`) are (1, 2).
            mean, cov = broadcast_params([mean, cov], (1, 2))

            if size_len > 0:
                size_tpl = numba_ndarray.to_fixed_tuple(size, size_len)
                if (
                    0 < mean.ndim - 1 <= len(size_tpl)
                    and size_tpl[neg_ind_shape_len:] != mean.shape[:-1]
                ):
                    raise ValueError(
                        "shape mismatch: objects cannot be broadcast to a single shape"
                    )
                mean_shape = size_tpl + mean.shape[-1:]
                # The covariance has two trailing "event" dimensions
                cov_shape = size_tpl + cov.shape[-2:]
            else:
                mean_shape = mean.shape
                cov_shape = cov.shape

            res = np.empty(mean_shape, dtype=out_dtype)
            mean_bcast = np.broadcast_to(mean, mean_shape)
            cov_bcast = np.broadcast_to(cov, cov_shape)

            # Draw one multivariate sample per batch index
            for index in np.ndindex(*mean_shape[:-1]):
                res[index] = np.random.multivariate_normal(
                    mean_bcast[index], cov_bcast[index]
                )

            return (rng, res)

    else:

        @numba_basic.numba_njit
        def mvnormal_rv(rng, size, dtype, mean, cov):
            size = numba_ndarray.to_fixed_tuple(size, size_len)
            return (rng, np.random.multivariate_normal(mean, cov, size))

    return mvnormal_rv
```
Regarding the `RandomVariable` `Op`, it's helpful to know that `MvNormalRV.rng_fn` is called by `RandomVariable.perform` here, since that tells you where the arguments to `MvNormalRV.rng_fn` originate.

More generally, when we're converting an Aesara graph into a Numba-compilable Python function, the Numba-compilable function needs to have a signature that matches the inputs and outputs received by each `Op`'s `Op.perform`, and you can see what those inputs and outputs are here and here, respectively, for a `RandomVariable.perform`. If you want to follow those threads, we can help clarify them, but they shouldn't be necessary for this work.
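As a rough illustration (a deliberate simplification, not Aesara's actual `perform` code), the flow is something like:

```python
import numpy as np


def perform_sketch(rng_fn, inputs):
    # Unpack the Apply node's inputs and forward the distribution
    # parameters to `rng_fn`, loosely mirroring RandomVariable.perform
    rng, size, dtype, *dist_params = inputs
    draws = rng_fn(rng, *dist_params, size)
    return (rng, draws)


rng = np.random.default_rng(0)
_, draws = perform_sketch(
    lambda rng, mean, cov, size: rng.multivariate_normal(mean, cov, size),
    [rng, (4,), np.dtype("float64"), np.zeros(2), np.eye(2)],
)
print(draws.shape)  # (4, 2)
```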
Aside from that, I always recommend printing graphs with `aesara.dprint`; that will show you the objects with which you're actually working. Also, it's best to start this kind of work with a unit test (e.g. either an existing or new one in `tests.link.test_numba`) and run it with `pytest --pdb` (and possibly `--pdbcls=IPython.terminal.debugger:Pdb` if you want `ipdb` features), so that you can walk through the code with the debugger and actually see what's happening.
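For example, a minimal sketch (assuming `at.random.multivariate_normal` is available in your Aesara version):

```python
import numpy as np
import aesara
import aesara.tensor as at

# Print the graph of an MvNormal draw to see the actual Apply/Op objects
X = at.random.multivariate_normal(np.zeros(2), np.eye(2), size=(3,))
aesara.dprint(X)
```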
Hi @brandonwillard, thanks a lot for putting in the effort to write such a detailed explanation. This is really helpful, and it is clear to me now. I will explore how to port `broadcast_params` to Numba soon.

Also, thank you for the debugging tips. I normally use `set_trace` (from `IPython.core.debugger`) with `pytest -s`. I will definitely try `--pdbcls=IPython.terminal.debugger:Pdb` :)
> A Numba implementation for `MvNormalRV` could follow the same basic approach as https://github.com/aesara-devs/aesara/pull/841, i.e. nearly a direct translation of `MvNormalRV.rng_fn`.
>
> As usual, the main challenge will involve converting the `tuple`-based operations into something that Numba will accept. In the case of https://github.com/aesara-devs/aesara/pull/841, that was accomplished by replacing the `tuple` slicing arguments with constant literals so that Numba could perform those operations at compile-time (i.e. Numba can't do them at run-time).
>
> Aside from replacing the `tuple` slices, `broadcast_params` will need to be converted—or its essential logic reproduced—in a Numba-compatible form. This has all the same `tuple`-based complications and will likely require some clever use of `numba.np.unsafe.ndarray.to_fixed_tuple` and/or `aesara.link.numba.dispatch.create_tuple_creator`. In both cases, constant literal values derived from the Aesara type information will likely need to be used.
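As a small standalone illustration of the `to_fixed_tuple` piece: its length argument must be a compile-time constant, which is exactly why the constant values derived from Aesara's type information matter.

```python
import numba
import numpy as np
from numba.np.unsafe.ndarray import to_fixed_tuple

# The tuple length is determined outside the compiled function (in Aesara,
# from `get_vector_length` on the symbolic `size` input), so Numba sees it
# as a compile-time constant.
size_len = 2


@numba.njit
def size_to_tuple(size):
    return to_fixed_tuple(size, size_len)


print(size_to_tuple(np.array([3, 4])))  # (3, 4)
```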