Closed PgLoLo closed 3 years ago
I believe this is related to batch discussions here: https://github.com/dgasmith/opt_einsum/issues/95
You can supply arrays with different shapes, but not a different ndim, to a ContractExpression, and they will be evaluated with the same contraction path, i.e. you could vary the sizes 1024, 16, etc. Alternatively, you could generate the path for one set of shapes and use it to build many different contract expressions (though it might no longer be the best path for those other shapes!).
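To make the two options above concrete, here is a minimal sketch. The matrix-chain equation and all shapes are made up for illustration and are not from this issue; it assumes `optimize` accepts an explicit path as returned by `contract_path`.

```python
import numpy as np
import opt_einsum as oe

# Illustrative equation (an assumption, chosen so the path is non-trivial).
eq = 'ab,bc,cd->ad'

# Option 1: build the expression once from example shapes.
expr = oe.contract_expression(eq, (8, 16), (16, 32), (32, 4))

# Same ndim per operand, different sizes: the stored path is reused as-is.
a, b, c = np.random.randn(5, 6), np.random.randn(6, 7), np.random.randn(7, 3)
small = expr(a, b, c)

# Option 2: compute a path once and pass it explicitly when building
# expressions for other shapes (it may not be the best path for them).
path, _ = oe.contract_path(eq, a, b, c)
expr_big = oe.contract_expression(eq, (100, 6), (6, 200), (200, 3), optimize=path)
big = expr_big(np.random.randn(100, 6), np.random.randn(6, 200), np.random.randn(200, 3))
```

Note that few checks are performed when the expression is called, so passing arrays whose ndim differs from the original shapes is not supported.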
opt_einsum also works nicely with jax (which has an efficient version of numpy.vectorize):
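import jax
import numpy as np
import opt_einsum as oe

# expression for a plain 1-D dot product
eq = 'a,a->'
expr = oe.contract_expression(eq, (16,), (16,))
# vectorize over any leading batch dimensions
vexpr = jax.numpy.vectorize(expr, signature='(a),(a)->()')
x = np.random.randn(7, 32, 1024, 16)
y = np.random.randn(16)
vexpr(x, y).shape
# (7, 32, 1024)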
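If jax is not available, a different technique (plain reshaping, not vectorization) gives a similar result: collapse all batch axes into one, call a 2-D expression, and restore the batch axes. This is only a sketch; `batched_dot` is a hypothetical helper name, and the equation `'bi,i->b'` is my own choice for the flattened case.

```python
import numpy as np
import opt_einsum as oe

# 2-D expression: one flattened batch axis plus the contracted axis.
expr = oe.contract_expression('bi,i->b', (1024, 16), (16,))

def batched_dot(x, y):
    """Hypothetical helper: contract the last axis of x with y for any batch ndim."""
    flat = x.reshape(-1, x.shape[-1])            # collapse all batch axes into one
    return expr(flat, y).reshape(x.shape[:-1])   # restore the batch axes

x = np.random.randn(7, 32, 1024, 16)
y = np.random.randn(16)
out = batched_dot(x, y)
```

This only covers the case where the batch axes sit on one operand and can be flattened cheaply; non-contiguous inputs may incur a copy in `reshape`.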
Is it possible to construct a contract_expression with optional batch dimensions? Consider this example:
contract('...i,i->...', a, b)
The expression above can be used with different shapes of the variable a. But if I construct a contract_expression in the following way:
contract_expression('...i,i->...', (1024, 16), (16,))
it accepts only a 2-dimensional tensor as its first argument.