aesara-devs / aesara

Aesara is a Python library for defining, optimizing, and efficiently evaluating mathematical expressions involving multi-dimensional arrays.
https://aesara.readthedocs.io

Create a Numba implementation for `MvNormalRV` #842

Open brandonwillard opened 2 years ago

brandonwillard commented 2 years ago

A Numba implementation for MvNormalRV could follow the same basic approach as https://github.com/aesara-devs/aesara/pull/841, i.e. nearly a direct translation of MvNormalRV.rng_fn.

As usual, the main challenge will involve converting the tuple-based operations into something that Numba will accept. In the case of https://github.com/aesara-devs/aesara/pull/841, that was accomplished by replacing the tuple slicing arguments with constant literals so that Numba could perform those operations at compile-time (i.e. Numba can't do them at run-time).
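To make the "compile-time constant" idea concrete, here is a plain-Python sketch of the size/parameter shape check used in the templates in this thread. It does not use Numba, so it runs anywhere. The slice bound is computed once from the parameter's number of dimensions (as Aesara would read it from the type information), and the returned function only uses that captured integer. Under Numba, such closure values are treated as constants, which is what lets the tuple slice be resolved at compile time. The names `make_shape_check`, `param_ndim`, and `size_len` are hypothetical.

```python
import numpy as np


def make_shape_check(param_ndim, size_len):
    """Build a shape checker whose tuple slice bound is fixed up front (hypothetical)."""
    # Computed once, outside the returned function, from "type" information.
    neg_ind_shape_len = -param_ndim + 1

    def check(size, param):
        # Convert the size vector into a tuple of known, fixed length.
        size_tpl = tuple(int(s) for s in size[:size_len])
        # The parameter's batch dimensions must match the trailing entries of `size`.
        if 0 < param.ndim - 1 <= len(size_tpl) and (
            size_tpl[neg_ind_shape_len:] != param.shape[:-1]
        ):
            raise ValueError(
                "shape mismatch: objects cannot be broadcast to a single shape"
            )
        # Output shape: the requested size followed by the core dimension.
        return size_tpl + param.shape[-1:]

    return check
```

For example, `make_shape_check(2, 2)(np.array([5, 3]), np.zeros((3, 4)))` gives `(5, 3, 4)`, while a parameter of shape `(2, 4)` raises the `ValueError`.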

Aside from replacing the tuple slices, broadcast_params will need to be converted—or its essential logic reproduced—in a Numba compatible form. This has all the same tuple-based complications and will likely require some clever use of numba.np.unsafe.ndarray.to_fixed_tuple and/or aesara.link.numba.dispatch.create_tuple_creator. In both cases, constant literal values derived from the Aesara type information will likely need to be used.
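As a rough picture of the "essential logic" of `broadcast_params` for the two MvNormal parameters, here is a NumPy-only sketch (the name `broadcast_mvnormal_params` is hypothetical and this is not Aesara's implementation). Only the leading batch dimensions are broadcast against each other; the core dimensions stay `(n,)` for the mean and `(n, n)` for the covariance. It relies on `np.broadcast_shapes` (NumPy 1.20 or later). In a Numba port, these shape tuples would need fixed lengths derived from the Aesara types, which is where `to_fixed_tuple` or `create_tuple_creator` would come in.

```python
import numpy as np


def broadcast_mvnormal_params(mean, cov):
    """Broadcast the batch dims of `mean` (..., n) and `cov` (..., n, n) (hypothetical)."""
    mean = np.asarray(mean)
    cov = np.asarray(cov)
    # Core (non-batch) dimensions: 1 for the mean vector, 2 for the covariance matrix.
    mean_core, cov_core = mean.shape[-1:], cov.shape[-2:]
    # Broadcast only the leading batch dimensions against each other.
    batch_shape = np.broadcast_shapes(mean.shape[:-1], cov.shape[:-2])
    return (
        np.broadcast_to(mean, batch_shape + mean_core),
        np.broadcast_to(cov, batch_shape + cov_core),
    )
```

For instance, a mean of shape `(2, 1, 3)` and a stack of four `3 x 3` covariances come out with shapes `(2, 4, 3)` and `(2, 4, 3, 3)`.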

danhphan commented 2 years ago

Hi @brandonwillard, I am happy to work on this task :D

twiecki commented 2 years ago

That's great @danhphan, thanks!

brandonwillard commented 2 years ago

> Hi @brandonwillard, I am happy to work on this task :D

Sounds good! My recommendation/approach is to gradually write the implementation in Numba only, then, once you've arrived at a working Numba implementation, start working from that in Aesara.

The real trick to making these conversions work is to use the information that's technically "constant" at the time the Numba implementation needs to be constructed by Aesara. That's how we get past most of the missing features in Numba (e.g. unsupported NumPy arguments like size in numpy.random functions) and unsupported Python functionality (e.g. construction and manipulation of tuples). Basically, we're doing something akin to macro programming that makes good use of "compile-time" constants/literals.
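Here is a small plain-Python/NumPy sketch of that "macro-like" pattern, under the assumption that the constructor receives values Aesara would know ahead of time. `make_batched_normal_sampler`, `size_len`, and `out_dtype` are hypothetical names. Instead of passing `size=` to a NumPy sampler (the kind of argument Numba may not accept), the returned function loops over a size tuple whose length was fixed when the function was built.

```python
import numpy as np


def make_batched_normal_sampler(size_len, out_dtype):
    """Hypothetical constructor: `size_len` and `out_dtype` stand in for
    constants read from Aesara type information."""

    def sampler(rng, size, loc, scale):
        # `size_len` is fixed when `sampler` is created, so the tuple length is known.
        size_tpl = tuple(int(s) for s in size[:size_len])
        res = np.empty(size_tpl, dtype=out_dtype)
        # Draw one scalar per output index rather than relying on a `size` argument.
        for index in np.ndindex(*size_tpl):
            res[index] = rng.normal(loc, scale)
        return rng, res

    return sampler
```

In an actual Numba implementation, the inner function would be jitted and the size tuple would come from `numba.np.unsafe.ndarray.to_fixed_tuple`, as in the templates in this thread.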

danhphan commented 2 years ago

Hi @brandonwillard, thank you for the suggestion. I am playing with numba, and will implement the numba version for MvNormalRV.rng_fn first :D

fbarfi commented 2 years ago

Thank you all for working on this. It's been a real problem for me. The other distribution that I have been using is PolyaGamma, and I am getting the same error. The solution might be different from MvNormal, since PolyaGamma is not a 'usual' distribution. In any case, I look forward to your solutions. Here is the final section of the error message:

 140 @singledispatch
    141 def jax_funcify(op, node=None, storage_map=None, **kwargs):
    142     """Create a JAX compatible function from an Aesara `Op`."""
--> 143     raise NotImplementedError(f"No JAX conversion for the given `Op`: {op}")

NotImplementedError: No JAX conversion for the given `Op`: _PolyaGammaLogDistFunc{get_pdf=True}

brandonwillard commented 2 years ago

The implementation discussed here is for Numba, not JAX, so solving this issue won't help. A separate issue will need to be opened for MvNormal support in JAX. Likewise, JAX support for _PolyaGammaLogDistFunc will need to be added in the codebase that defines that Op (i.e. PyMC).

danhphan commented 2 years ago

Hi @brandonwillard, I have spent this week getting my head around Python closures and decorators (singledispatch, parameterized, ...). Now I get the gist of your suggestion about the macro-like approach :D

I am testing and debugging the code of numba_funcify_CategoricalRV in #841.

My understanding is that numba_funcify_CategoricalRV is a specialisation of the generic numba_funcify function for the case where op is an aer.CategoricalRV, and that I need to write a similar function, numba_funcify_MvNormalRV, for the case where op is an aer.MvNormalRV.

Also, can we consider numba_funcify_CategoricalRV a decorator (since it returns a sampler function, if the op input here is considered a function)?

Then the sampler would be a decorated function, and it would be used to decorate the @classmethod CategoricalRV.rng_fn.

Also, I'm not sure why we need the dtype argument in the sampler function. dtype doesn't seem to be used inside this function (the free variable out_dtype is used instead), and dtype is also not an argument of CategoricalRV.rng_fn.

Thank you

brandonwillard commented 2 years ago

> My understanding is that numba_funcify_CategoricalRV is a specialisation of the generic numba_funcify function for the case where op is an aer.CategoricalRV, and that I need to write a similar function, numba_funcify_MvNormalRV, for the case where op is an aer.MvNormalRV.

Yes, exactly.
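The dispatch mechanism behind `numba_funcify` (and the `jax_funcify` in the traceback earlier in this thread) is `functools.singledispatch`, which picks an implementation based on the type of the first argument. Here is a self-contained toy version; the classes `Op`, `CategoricalRV`, and `MvNormalRV` and the function `funcify` are hypothetical stand-ins, not Aesara's objects, and the inner samplers are placeholders.

```python
from functools import singledispatch


class Op:
    """Stand-in for an Aesara `Op` base class (hypothetical)."""


class CategoricalRV(Op):
    pass


class MvNormalRV(Op):
    pass


@singledispatch
def funcify(op, node=None, **kwargs):
    # Generic fallback: used when no implementation is registered for type(op).
    raise NotImplementedError(f"No conversion for the given `Op`: {op}")


@funcify.register(CategoricalRV)
def funcify_CategoricalRV(op, node=None, **kwargs):
    def categorical_rv(rng, size, dtype, p):
        return rng, None  # placeholder: a real version would draw samples

    return categorical_rv


@funcify.register(MvNormalRV)
def funcify_MvNormalRV(op, node=None, **kwargs):
    def mvnormal_rv(rng, size, dtype, mean, cov):
        return rng, None  # placeholder: a real version would draw samples

    return mvnormal_rv
```

Calling `funcify(MvNormalRV())` returns `mvnormal_rv`, while `funcify(Op())` falls through to the generic version and raises `NotImplementedError`, just like the error message quoted earlier.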

> Also, can we consider numba_funcify_CategoricalRV a decorator (since it returns a sampler function, if the op input here is considered a function)?

It's a function constructor—in the most basic sense—but, while it is a function that returns another function, the way it's used doesn't really fit the decorator designation (i.e. it doesn't take functions as input nor "decorate" existing functions).
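The distinction can be seen in a small, generic Python example (the names are hypothetical and unrelated to Aesara). `make_scaler` is a function constructor: its input is ordinary data and it builds a new function, which is how `numba_funcify_CategoricalRV` uses its `op` and `node` inputs. `logged` is a decorator: its input is an existing function, which it wraps.

```python
def make_scaler(factor):
    """Function constructor: takes data (not a function) and returns a new function."""

    def scale(x):
        return x * factor

    return scale


def logged(func):
    """Decorator: takes an existing function and returns a wrapped version of it."""

    def wrapper(*args, **kwargs):
        wrapper.calls += 1  # count how many times the wrapped function is called
        return func(*args, **kwargs)

    wrapper.calls = 0
    return wrapper
```

`make_scaler(3)(2)` returns `6`, and `logged(make_scaler(2))` behaves like the doubling function while also counting its calls.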

> Also, I'm not sure why we need the dtype argument in the sampler function. dtype doesn't seem to be used inside this function (the free variable out_dtype is used instead), and dtype is also not an argument of CategoricalRV.rng_fn.

It may not be explicitly used, but the signature (i.e. inputs and outputs) of the returned function must match the inputs and outputs of the Apply nodes constructed by the relevant Op's Op.make_node. In this case, that Op is RandomVariable and RandomVariable.make_node returns Applys with inputs corresponding to rng, size, dtype, and the distribution arguments.
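To illustrate the point about signatures, here is a hypothetical plain-NumPy sketch (no Numba; `make_mvnormal_sampler` is an invented name). The returned function takes `rng, size, dtype, mean, cov` in the same order as the node's inputs, ignores `dtype`, and returns the pair `(rng, sample)` corresponding to the node's two outputs.

```python
import numpy as np


def make_mvnormal_sampler(out_dtype):
    """Hypothetical constructor; `out_dtype` stands in for the output type's dtype."""

    def mvnormal_rv(rng, size, dtype, mean, cov):
        # `dtype` is unused; it is in the signature only so the argument positions
        # line up with the node's inputs. The captured `out_dtype` is used instead.
        size = tuple(int(s) for s in size) or None  # empty size means "no size"
        sample = rng.multivariate_normal(mean, cov, size=size).astype(out_dtype)
        # The outputs also mirror the node: the RNG and the sample.
        return rng, sample

    return mvnormal_rv
```

For example, calling the result of `make_mvnormal_sampler(np.float32)` with `size=np.array([5])`, a length-2 mean, and a `2 x 2` covariance yields a `float32` array of shape `(5, 2)`.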

danhphan commented 2 years ago

Hi @brandonwillard, many thanks for the clarifications.

Sorry, I am quite slow and still in the learning process. Last week, I dug deep into Aesara's graph structures and tried to understand the relationship between Apply, Op, and Variable.

Previously I mistakenly thought that RandomVariable was some type of Variable, similar to TensorVariable and SharedVariable. It is clear to me now that RandomVariable is actually a subclass of Op, which can make an Apply node from input Variables. An Op also computes the values of its outputs with its perform() method (by the way, I think RandomVariable could be named RandomVariableOp to make this clearer).

Also, from the PyMC code, I can see that the Distribution class has an rv_op attribute holding a RandomVariable instance, and these form the many kinds of distributions in PyMC. The rng_fn method of RandomVariable and its subclasses is used for sampling, and we want to make it faster by using backends like JAX or Numba.

I think I need some more time to learn how Op and RandomVariable really work, and to address a few "good first issues" to learn more about the Aesara codebase first. I will come back to this issue later. In the meantime, if anyone is interested in this issue, please feel free to go ahead.

Thank you.

brandonwillard commented 2 years ago

> Previously I mistakenly thought that RandomVariable was some type of Variable, similar to TensorVariable and SharedVariable. It is clear to me now that RandomVariable is actually a subclass of Op, which can make an Apply node from input Variables.

Yes, exactly; the Variable part of the name reflects the "variable" in "random variables" from statistics/probability theory, which itself can be confusing.

> Also, from the PyMC code, I can see that the Distribution class has an rv_op attribute holding a RandomVariable instance, and these form the many kinds of distributions in PyMC. The rng_fn method of RandomVariable and its subclasses is used for sampling, and we want to make it faster by using backends like JAX or Numba.

Just to be clear, there's no need to be concerned with the PyMC side of things; a Numba implementation of MvNormalRV only needs to adhere to the expectations of Aesara and Numba for it to work with PyMC.

> I think I need some more time to learn how Op and RandomVariable really work, and to address a few "good first issues" to learn more about the Aesara codebase first. I will come back to this issue later. In the meantime, if anyone is interested in this issue, please feel free to go ahead.

Perhaps it would help to clarify what it is that you're trying to understand, because most of the things you've mentioned here shouldn't directly affect a Numba implementation of MvNormalRV.

All that's necessary for such an implementation is a Numba port of MvNormalRV.rng_fn.

You can use numba_funcify_DirichletRV as a template/example; the only thing that changes substantially is that the distribution parameters are no longer alphas, but mean and cov. For example, node.inputs[3:] now refer to the (symbolic) parameters of the MvNormalRV class, which represent the mean and covariance. Likewise, the returned function is also expected to take the same mean and covariance parameters. Here's a more explicit template for an MvNormalRV implementation—it's just numba_funcify_DirichletRV with the distribution parameters updated (and a few things renamed):

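import numba.np.unsafe.ndarray as numba_ndarray
import numpy as np

import aesara.tensor.random.basic as aer
from aesara.link.numba.dispatch import basic as numba_basic
from aesara.link.numba.dispatch.basic import numba_funcify
from aesara.tensor.basic import get_vector_length
from aesara.tensor.random.utils import broadcast_params


@numba_funcify.register(aer.MvNormalRV)
def numba_funcify_MvNormalRV(op, node, **kwargs):

    # These values are known when the Numba function is constructed, so the
    # inner function can treat them as compile-time constants
    out_dtype = node.outputs[1].type.numpy_dtype
    mean_ndim = node.inputs[3].type.ndim
    cov_ndim = node.inputs[4].type.ndim
    # A tuple (not a list) so Numba can freeze it as a closure constant
    ndims_params = tuple(op.ndims_params)

    # Number of batch dimensions once `mean` and `cov` are broadcast together
    batch_ndim = max(mean_ndim - 1, cov_ndim - 2)
    neg_ind_shape_len = -batch_ndim
    size_len = int(get_vector_length(node.inputs[1]))

    if batch_ndim > 0:

        @numba_basic.numba_njit
        def mvnormal_rv(rng, size, dtype, mean, cov):

            # TODO: This needs a Numba port
            mean, cov = broadcast_params([mean, cov], ndims_params)

            if size_len > 0:
                size_tpl = numba_ndarray.to_fixed_tuple(size, size_len)
                if (
                    0 < mean.ndim - 1 <= len(size_tpl)
                    and size_tpl[neg_ind_shape_len:] != mean.shape[:-1]
                ):
                    raise ValueError(
                        "shape mismatch: objects cannot be broadcast to a single shape"
                    )
                mean_shape = size_tpl + mean.shape[-1:]
                # Each covariance is an (n, n) matrix, so keep its last two dims
                cov_shape = size_tpl + cov.shape[-2:]
            else:
                mean_shape = mean.shape
                cov_shape = cov.shape

            res = np.empty(mean_shape, dtype=out_dtype)
            mean_bcast = np.broadcast_to(mean, mean_shape)
            cov_bcast = np.broadcast_to(cov, cov_shape)

            # TODO: `np.random.multivariate_normal` doesn't appear to be in
            # Numba's supported `np.random` subset, so it may need a port too
            for index in np.ndindex(*mean_shape[:-1]):
                res[index] = np.random.multivariate_normal(
                    mean_bcast[index], cov_bcast[index]
                )

            return (rng, res)

    else:

        @numba_basic.numba_njit
        def mvnormal_rv(rng, size, dtype, mean, cov):
            size = numba_ndarray.to_fixed_tuple(size, size_len)
            # TODO: Same `np.random.multivariate_normal` caveat as above
            return (rng, np.random.multivariate_normal(mean, cov, size))

    return mvnormal_rv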
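The template still calls `np.random.multivariate_normal`, which, as far as I know, is not among the `np.random` functions Numba supports. One possible workaround (an assumption, not necessarily what Aesara ends up doing) is to build each draw from standard normals and a Cholesky factor: if `cov = L @ L.T` and `z ~ N(0, I)`, then `mean + L @ z ~ N(mean, cov)`. The NumPy sketch below (`mvnormal_batched` is a hypothetical name) assumes every covariance is symmetric positive definite, since `np.linalg.cholesky` raises `LinAlgError` otherwise; NumPy's own sampler uses an SVD by default and also tolerates semi-definite matrices.

```python
import numpy as np


def mvnormal_batched(mean, cov, rng):
    """Draw one sample per batch index from N(mean[idx], cov[idx]) (hypothetical).

    `mean` has shape (..., n) and `cov` has shape (..., n, n), with matching
    batch dimensions (e.g. already broadcast).
    """
    res = np.empty(mean.shape, dtype=np.float64)
    for index in np.ndindex(*mean.shape[:-1]):
        # cov[index] = L @ L.T, so mean + L @ z has covariance cov[index].
        chol = np.linalg.cholesky(cov[index])
        z = rng.standard_normal(mean.shape[-1])
        res[index] = mean[index] + chol @ z
    return res
```

With many draws from a fixed mean and covariance, the sample mean and sample covariance of the output should approach the inputs, which makes for a simple sanity check.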

Regarding the RandomVariable Op, it's helpful to know that MvNormalRV.rng_fn is called by RandomVariable.perform here, since that tells you where the arguments to MvNormalRV.rng_fn originate.

More generally, when we're converting an Aesara graph into a Numba-compilable Python function, the Numba-compilable function needs to have a signature that matches the inputs and outputs received by each Op's Op.perform, and you can see what those inputs and outputs are here and here, respectively, for a RandomVariable.perform. If you want to follow those threads, we can help clarify them, but they shouldn't be necessary for this work.
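A heavily simplified, hypothetical sketch of that flow may help show where each argument comes from. `ToyMvNormalRV` is an invented stand-in, and the real `RandomVariable.perform` does more than this (for example, it may copy the RNG). The node's inputs arrive as `rng, size, dtype` followed by the distribution parameters, `rng_fn` is called with the parameters followed by `size`, and the results are written into the two output storage cells.

```python
import numpy as np


class ToyMvNormalRV:
    """Hypothetical stand-in for an `Op` like `MvNormalRV`."""

    @classmethod
    def rng_fn(cls, rng, mean, cov, size):
        return rng.multivariate_normal(mean, cov, size=size)

    def perform(self, node, inputs, outputs):
        # Inputs follow the node's layout: rng, size, dtype, then the parameters.
        rng, size, dtype, *params = inputs
        # An empty size vector means "no size", which NumPy spells as `None`.
        size = None if np.size(size) == 0 else tuple(size)
        rng_out, sample_out = outputs
        rng_out[0] = rng
        sample_out[0] = np.asarray(self.rng_fn(rng, *params, size), dtype=dtype)
```

A Numba-compilable replacement has to accept the same `(rng, size, dtype, mean, cov)` inputs and produce the same `(rng, sample)` pair.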

Aside from that, I always recommend printing graphs with aesara.dprint; that will show you the objects with which you're actually working. Also, it's best to start this kind of work with a unit test (e.g. either an existing or new one in tests.link.test_numba) and run it with pytest --pdb (and possibly --pdbcls=IPython.terminal.debugger:Pdb if you want ipdb features), so that you can walk through the code with the debugger and actually see what's happening.

danhphan commented 2 years ago

Hi @brandonwillard, thanks a lot for putting in the effort to write such a detailed explanation. This is really helpful, and it is clear to me now. I will explore ways to port broadcast_params to Numba soon.

Also, thank you for the debugging tips. I normally use set_trace (from IPython.core.debugger) with pytest -s. I will definitely try --pdbcls=IPython.terminal.debugger:Pdb :)