pymc-devs / pytensor

PyTensor allows you to define, optimize, and efficiently evaluate mathematical expressions involving multi-dimensional arrays.
https://pytensor.readthedocs.io

Adding conditionals for torch #939

Open Ch0ronomato opened 3 months ago

Ch0ronomato commented 3 months ago

Description

Add the branching ops
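
A first rough sketch of what an eager dispatch for IfElse could look like (op.n_outs and the import path are assumptions based on the other backends; this is not a final design):

# Hypothetical sketch of a branching-op dispatch; not the final implementation.
# Import path assumed to mirror the JAX/Numba backend layout.
from pytensor.ifelse import IfElse
from pytensor.link.pytorch.dispatch import pytorch_funcify


@pytorch_funcify.register(IfElse)
def pytorch_funcify_IfElse(op, **kwargs):
    n_outs = op.n_outs

    def ifelse(cond, *true_and_false_values):
        # Eager branch selection: the first n_outs values belong to the
        # "then" branch, the remaining ones to the "else" branch.
        if bool(cond):
            res = true_and_false_values[:n_outs]
        else:
            res = true_and_false_values[n_outs:]
        return res if n_outs > 1 else res[0]

    return ifelse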

Ch0ronomato commented 3 months ago

Hey @ricardoV94, could I get some clarity on ScalarLoop? I was under the impression that it might just work (I don't see any explicit tests for Numba or JAX). What work is needed for ScalarLoop? Here is an example test I wrote, which may also be invalid:


import numpy as np

from pytensor import function
from pytensor.scalar import float64, int64
from pytensor.scalar.loop import ScalarLoop


def test_ScalarOp():
    n_steps = int64("n_steps")
    x0 = float64("x0")
    const = float64("const")
    x = x0 + const

    # A counted loop: start from x0 and add const at every step
    op = ScalarLoop(init=[x0], constant=[const], update=[x])
    x = op(n_steps, x0, const)

    # pytorch_mode as used by the other PyTorch backend tests
    fn = function([n_steps, x0, const], x, mode=pytorch_mode)
    np.testing.assert_allclose(fn(5, 0, 1), 5)
    np.testing.assert_allclose(fn(5, 0, 2), 10)
    np.testing.assert_allclose(fn(4, 3, -1), -1)

This fails because ScalarLoop falls through to the generic ScalarOp dispatch:

op = ScalarLoop(), node = ScalarLoop(n_steps, x0, const)
kwargs = {'input_storage': [[None], [None], [None]], 'output_storage': [[None]], 'storage_map': {ScalarLoop.0: [None], const: [None], x0: [None], n_steps: [None]}}
nfunc_spec = None

    @pytorch_funcify.register(ScalarOp)
    def pytorch_funcify_ScalarOp(op, node, **kwargs):
        """Return pytorch function that implements the same computation as the Scalar Op.

        This dispatch is expected to return a pytorch function that works on Array inputs as Elemwise does,
        even though it's dispatched on the Scalar Op.
        """

        nfunc_spec = getattr(op, "nfunc_spec", None)
        if nfunc_spec is None:
>           raise NotImplementedError(f"Dispatch not implemented for Scalar Op {op}")
E           NotImplementedError: Dispatch not implemented for Scalar Op ScalarLoop

pytensor/link/pytorch/dispatch/scalar.py:19: NotImplementedError

ricardoV94 commented 3 months ago

You haven't seen JAX/Numba code because scalar loop isn't yet supported in those backends either.

I suggest checking ScalarLoop's perform method to get an idea of how the Op works.
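
A rough sketch of what a counted-loop dispatch could look like, mirroring the perform semantics (op.fgraph, op.nout, and the import path are assumptions based on the other inner-graph Ops and the JAX/Numba backend layout, not a final implementation; the while variant is left out):

# Hypothetical ScalarLoop dispatch for the PyTorch backend; a sketch only.
from pytensor.link.pytorch.dispatch import pytorch_funcify
from pytensor.scalar.loop import ScalarLoop


@pytorch_funcify.register(ScalarLoop)
def pytorch_funcify_ScalarLoop(op, node, **kwargs):
    # Assumption: the inner update graph is available as op.fgraph and can be
    # funcified like any other FunctionGraph.
    update = pytorch_funcify(op.fgraph)
    n_carry = op.nout  # number of carried states (one per update expression)

    def scalar_loop(n_steps, *inits_and_constants):
        carry = list(inits_and_constants[:n_carry])
        constants = inits_and_constants[n_carry:]
        for _ in range(n_steps):
            # The inner graph maps the current carry (plus constants) to the next carry
            carry = update(*carry, *constants)
        return carry[0] if len(carry) == 1 else carry

    return scalar_loop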

ricardoV94 commented 2 months ago

For Blockwise you should be able to apply vmap repeatedly, once per batch dimension. If PyTorch had an equivalent to np.vectorize, that would be all we need.
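
For example, something along these lines (a sketch only, using torch.matmul as a stand-in for a Blockwise core function):

# Sketch: emulate np.vectorize-style batching by stacking torch.vmap,
# one application per leading batch dimension.
import torch


def vectorize_over_batch_dims(core_fn, n_batch_dims):
    fn = core_fn
    for _ in range(n_batch_dims):
        # Each vmap maps the current function over one more leading axis
        fn = torch.vmap(fn)
    return fn


# Hypothetical usage: a (2, 5) @ (5, 7) core operation with two batch dimensions
batched_matmul = vectorize_over_batch_dims(torch.matmul, n_batch_dims=2)
a = torch.randn(4, 3, 2, 5)
b = torch.randn(4, 3, 5, 7)
assert batched_matmul(a, b).shape == (4, 3, 2, 7)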