pymc-devs / pytensor

PyTensor allows you to define, optimize, and efficiently evaluate mathematical expressions involving multi-dimensional arrays.
https://pytensor.readthedocs.io

ENH: Native support for dims in tensors and tensor operations #954

Open · wd60622 opened this issue 1 month ago

wd60622 commented 1 month ago

Before

import pytensor.tensor as pt

# Need to keep track of shapes and axis order manually
a = pt.vector("a", shape=(2,))
b = pt.vector("b", shape=(3,))

# a + b fails due to incompatible broadcasting;
# a new axis / transpose has to be added by hand
result = a + b[:, None]

After

a = pt.vector("a", dims="channel")
b = pt.vector("b", dims="geo")

result = a + b
# result.type is TensorType(float64, dims=("channel", "geo"))  # xarray-like ordering of dims
# The + operation handles the transpose based on dims; the same would apply to other elementwise operations
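
For comparison, xarray already implements exactly these dims-based semantics, which is a useful reference for what the proposal would do (a minimal sketch, assuming xarray is installed):

import numpy as np
import xarray as xr

a = xr.DataArray(np.array([1.0, 2.0]), dims=("channel",))
b = xr.DataArray(np.array([10.0, 20.0, 30.0]), dims=("geo",))

result = a + b  # aligned by dim names, no manual axis handling
print(result.dims)   # ('channel', 'geo')
print(result.shape)  # (2, 3)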

Context for the issue:

Use of the Prior class in PyMC-Marketing, and its potential usefulness elsewhere and in PyMC directly:

dist = Prior(
    "Normal", 
    # Variables are automatically transposed before passing to PyMC distributions
    mu=Prior("Normal", dims="geo"), 
    sigma=Prior("HalfNormal", dims="geo"), 
    dims=("geo", "channel"), 
)

References:
PyMC-Marketing auto-broadcasting handling: https://github.com/pymc-labs/pymc-marketing/blob/main/pymc_marketing/prior.py#L131-L168
PyMC Discussion: https://github.com/pymc-devs/pymc/discussions/7416

wd60622 commented 1 month ago

Not sure how eval would work in the case where shapes are not provided.

import numpy as np
import pytensor.tensor as pt

a = pt.vector("a", dims="channel")
b = pt.vector("b", dims="geo")
result = a + b
# raise since shapes are unknown?
result.eval({"a": np.array([1, 2, 3]), "b": np.array([1, 2])})

# shape might be required for dims?
a = pt.vector("a", dims="channel", shape=(3,))
b = pt.vector("b", dims="geo", shape=(2,))
ricardoV94 commented 1 month ago

> Not sure how eval would work in the case where shapes are not provided.
>
> a = pt.vector("a", dims="channel")
> b = pt.vector("b", dims="geo")
> result = a + b
> # raise since shapes are unknown?
> result.eval({"a": np.array([1, 2, 3]), "b": np.array([1, 2])})
> # shape might be required for dims?
> a = pt.vector("a", dims="channel", shape=(3,))
> b = pt.vector("b", dims="geo", shape=(2,))

It should be optional whenever knowing the shapes at runtime is enough; for addition, the static shape isn't needed.
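
For reference, ordinary PyTensor broadcasting already works this way; shapes can stay unknown until eval:

import numpy as np
import pytensor.tensor as pt

a = pt.vector("a")  # static shape unknown
b = pt.vector("b")
out = a + b[:, None]
# Shapes are only needed at runtime:
print(out.eval({"a": np.array([1.0, 2.0]), "b": np.array([1.0, 2.0, 3.0])}).shape)  # (3, 2)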

ricardoV94 commented 1 month ago

There's a draft PR that started on this idea: https://github.com/pymc-devs/pytensor/pull/407

ricardoV94 commented 1 month ago

Also we probably want to use different types for dimmed and regular variables since they have completely different semantics.
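
A hypothetical sketch (plain Python, not the actual types from the draft PR) of that semantic difference: compatibility for dimmed variables depends on dim names, not on axis order:

from dataclasses import dataclass

@dataclass(frozen=True)
class DimmedTensorType:
    dtype: str
    dims: tuple[str, ...]

    def compatible_with(self, other: "DimmedTensorType") -> bool:
        # Unlike positional TensorTypes, the order of dims doesn't matter
        return self.dtype == other.dtype and set(self.dims) == set(other.dims)

assert DimmedTensorType("float64", ("geo", "channel")).compatible_with(
    DimmedTensorType("float64", ("channel", "geo"))
)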

ricardoV94 commented 1 month ago

There's also the question of what the output should be: xarrays? If the dims order is arbitrary, users don't know what they're getting, but building an xarray for the output (if not for intermediate operations, since our backends obviously don't support that) could be costly. Unless numpy arrays can be wrapped in xarray DataArrays without copy costs.

Maybe a simpler object in between xarray and np arrays?
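
For what it's worth, the wrap itself appears to be copy-free; a quick check (assuming xarray is installed):

import numpy as np
import xarray as xr

x = np.zeros((2, 3))
da = xr.DataArray(x, dims=("channel", "geo"))
# The DataArray wraps the same buffer; no copy is made
assert np.shares_memory(da.values, x)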

wd60622 commented 1 month ago

> There's also the question of what the output should be: xarrays? If the dims order is arbitrary, users don't know what they're getting, but building an xarray for the output (if not for intermediate operations, since our backends obviously don't support that) could be costly. Unless numpy arrays can be wrapped in xarray DataArrays without copy costs.
>
> Maybe a simpler object in between xarray and np arrays?

I'd think that it wouldn't be xarray. That seems like a pretty large dependency to add.

I would think the dims would follow the order of operations:

# PR #407 syntax?
a = px.as_xtensor_variable("a", dims=("channel",))
b = px.as_xtensor_variable("b", dims=("geo",))
result1 = a + b  # (channel, geo)
result2 = b + a  # (geo, channel)

Couldn't this be constructed so that result1.owner is just Add(ExpandDims{axis=1}.0, ExpandDims{axis=0}.0), with the expand-dims inserted based on the dims of the operands?
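
Something like that lowering can already be spelled out by hand with the existing API; a rough sketch, with plain tensors standing in for the dimmed variables:

import pytensor.tensor as pt

a = pt.vector("a")  # stands in for dims=("channel",)
b = pt.vector("b")  # stands in for dims=("geo",)

# What a dims-aware Add could lower to:
result1 = pt.expand_dims(a, 1) + pt.expand_dims(b, 0)  # axes: (channel, geo)
result2 = pt.expand_dims(b, 1) + pt.expand_dims(a, 0)  # axes: (geo, channel)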

ricardoV94 commented 1 month ago

> I would think the dims would follow the order of operations

In my draft PR I sorted dims alphabetically; I don't yet know what will work out better, tbh. We definitely don't want to compute both a + b and b + a in the final graph, since they are identical modulo a transposition, but our regular backend should be able to merge those operations, so we may not need to worry. At a minimum we need predictable function outputs, even if everything in the middle can be done in whatever order we want.
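
A minimal sketch of that alphabetical rule (a hypothetical helper, not the draft PR's actual code):

def output_dims(*operand_dims: tuple[str, ...]) -> tuple[str, ...]:
    # Union of all input dims, sorted alphabetically so the result is deterministic
    return tuple(sorted({d for dims in operand_dims for d in dims}))

assert output_dims(("channel",), ("geo",)) == ("channel", "geo")
assert output_dims(("geo",), ("channel",)) == ("channel", "geo")  # same output either way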

Btw, order can have an impact on performance, as indicated by our incredibly slow sum along axis 0 xD: https://github.com/pymc-devs/pytensor/issues/935
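
A quick way to observe that effect locally (illustrative only; timings vary by machine):

import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.matrix("x")
sum_axis0 = pytensor.function([x], x.sum(axis=0))
sum_axis1 = pytensor.function([x], x.sum(axis=1))
data = np.random.rand(5000, 5000)
# Compare e.g. with timeit: sum_axis0(data) vs sum_axis1(data)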

But we definitely don't need to worry about that yet.