**Ch0ronomato** opened 3 months ago
Hey @ricardoV94, could I get some clarity on `ScalarLoop`? I was under the impression that it might just work (I don't see any explicit tests for Numba or JAX). What work is needed for `ScalarLoop`? Here is an example test I wrote, which may also be invalid:
```python
import numpy as np

from pytensor import function
from pytensor.scalar import float64, int64
from pytensor.scalar.loop import ScalarLoop


def test_ScalarOp():
    n_steps = int64("n_steps")
    x0 = float64("x0")
    const = float64("const")
    x = x0 + const

    # Loop body: carry x0 forward, adding `const` on each step
    op = ScalarLoop(init=[x0], constant=[const], update=[x])
    x = op(n_steps, x0, const)

    # pytorch_mode: the Mode that compiles through the PyTorch linker
    # (as defined in the PyTorch backend tests)
    fn = function([n_steps, x0, const], x, mode=pytorch_mode)
    np.testing.assert_allclose(fn(5, 0, 1), 5)
    np.testing.assert_allclose(fn(5, 0, 2), 10)
    np.testing.assert_allclose(fn(4, 3, -1), -1)
```
Running it fails with:

```
op = ScalarLoop(), node = ScalarLoop(n_steps, x0, const)
kwargs = {'input_storage': [[None], [None], [None]], 'output_storage': [[None]], 'storage_map': {ScalarLoop.0: [None], const: [None], x0: [None], n_steps: [None]}}
nfunc_spec = None

    @pytorch_funcify.register(ScalarOp)
    def pytorch_funcify_ScalarOp(op, node, **kwargs):
        """Return pytorch function that implements the same computation as the Scalar Op.

        This dispatch is expected to return a pytorch function that works on Array inputs as Elemwise does,
        even though it's dispatched on the Scalar Op.
        """
        nfunc_spec = getattr(op, "nfunc_spec", None)
        if nfunc_spec is None:
>           raise NotImplementedError(f"Dispatch not implemented for Scalar Op {op}")
E           NotImplementedError: Dispatch not implemented for Scalar Op ScalarLoop

pytensor/link/pytorch/dispatch/scalar.py:19: NotImplementedError
```
**ricardoV94** commented

You haven't seen JAX/Numba code because `ScalarLoop` isn't supported in those backends yet either.
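Supporting it would mean registering a `ScalarLoop` case on the backend's funcify dispatcher, i.e. the branch that is missing in the traceback above. A minimal sketch of the shape that could take for PyTorch, assuming the inner update graph is reachable as `op.fgraph`, that funcifying it yields a callable returning the updated carry values, and ignoring the optional `until` condition (the attribute names and the `basic` import path are inferred, not the actual API):

```python
from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
from pytensor.scalar.loop import ScalarLoop


@pytorch_funcify.register(ScalarLoop)
def pytorch_funcify_ScalarLoop(op, node, **kwargs):
    # Assumption: the inner update graph is exposed as `op.fgraph` and
    # funcifies into a callable (carry..., constants...) -> (new carry...)
    inner_fn = pytorch_funcify(op.fgraph, **kwargs)
    n_carry = len(op.fgraph.outputs)  # assumed: one output per carried value

    def scalar_loop(n_steps, *carry_and_constants):
        carry = list(carry_and_constants[:n_carry])
        constants = carry_and_constants[n_carry:]
        # Plain Python loop over the step count; fine for eager execution,
        # though a traced/compiled backend would want a structured loop op.
        for _ in range(int(n_steps)):
            carry = list(inner_fn(*carry, *constants))
        return carry[0] if n_carry == 1 else tuple(carry)

    return scalar_loop
```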
I suggest checking the `perform` method to get an idea of how the Op works.
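For the test above, the semantics that `perform` implements boil down to feeding the carried state back through the update expression `n_steps` times. A plain-Python sketch of just those semantics (not the actual implementation):

```python
def scalar_loop_semantics(n_steps, x0, const):
    # ScalarLoop(init=[x0], constant=[const], update=[x0 + const]):
    # the carried value is threaded through the update n_steps times.
    x = x0
    for _ in range(n_steps):
        x = x + const
    return x


assert scalar_loop_semantics(5, 0, 1) == 5
assert scalar_loop_semantics(5, 0, 2) == 10
assert scalar_loop_semantics(4, 3, -1) == -1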
For Blockwise you should be able to apply vmap repeatedly, once per batch dimension. If they had an equivalent to np.vectorize, that would be all we need. A sketch of the idea follows.
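To illustrate the "vmap once per batch dimension" idea with `torch.vmap` (the core function here is a made-up example, not Blockwise itself):

```python
import torch


def core_add(x, y):
    # Hypothetical "core" op acting on 0-d tensors
    return x + y


# Wrap once per batch dimension, innermost first
batched = core_add
for _ in range(2):  # two batch dimensions
    batched = torch.vmap(batched)

x = torch.arange(12.0).reshape(3, 4)
y = torch.ones(3, 4)
print(batched(x, y).shape)  # torch.Size([3, 4])
```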
**Description**

Add the branching ops