pymc-devs / pymc

Bayesian Modeling and Probabilistic Programming in Python
https://docs.pymc.io/
Other

Missing static information for jax backend for models with missing values #7387

Closed aseyboldt closed 6 days ago

aseyboldt commented 6 days ago

Description

The following example fails in the JAX backend when it tries to allocate a new array (`AllocEmpty`) whose shape is not known statically:

import pymc as pm
import numpy as np
import pytensor
import pytensor.tensor as pt

# Minimal reproduction: an observed variable with one missing value and a named dim.
with pm.Model(coords={"a": [0, 1, 2]}) as model:
    y = pm.Normal("y", 0, observed=[0, 0, np.nan], dims="a")

from pymc.model.transform.optimization import freeze_dims_and_data

# Fix the dim lengths so that backends like JAX can see static shapes.
model = freeze_dims_and_data(model, dims=model.coords.keys())

# Express the unobserved values (including the imputed "y") in terms of the value variables.
outputs = pytensor.clone_replace(model.unobserved_value_vars, model.rvs_to_values)
func = pytensor.function(model.value_vars, outputs, mode="JAX")

# There is one missing entry, so y_unobserved has length 1.
func(np.ones(1))
Calling `func` raises:

File /tmp/tmp5pyizklf:5, in jax_funcified_fgraph(y_unobserved, a)
      3 tensor_variable = deepcopyop(y_unobserved)
      4 # AllocEmpty{dtype='float64'}(a)
----> 5 tensor_variable_1 = allocempty(a)
      6 # AdvancedSetSubtensor(AllocEmpty{dtype='float64'}.0, y_unobserved, [False False  True])
      7 tensor_variable_2 = advancedincsubtensor(tensor_variable_1, y_unobserved, tensor_constant)

File ~/git/pytensor/pytensor/link/jax/dispatch/tensor_basic.py:39, in jax_funcify_AllocEmpty.<locals>.allocempty(*shape)
     38 def allocempty(*shape):
---> 39     return jnp.empty(shape, dtype=op.dtype)

File ~/micromamba/envs/dev-cuda/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py:2785, in empty(shape, dtype, device)
   2783 if (m := _check_forgot_shape_tuple("empty", shape, dtype)): raise TypeError(m)
   2784 dtypes.check_user_dtype_supported(dtype, "empty")
-> 2785 return zeros(shape, dtype, device=device)

File ~/micromamba/envs/dev-cuda/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py:2765, in zeros(shape, dtype, device)
   2763 if (m := _check_forgot_shape_tuple("zeros", shape, dtype)): raise TypeError(m)
   2764 dtypes.check_user_dtype_supported(dtype, "zeros")
-> 2765 shape = canonicalize_shape(shape)
   2766 return lax.full(shape, 0, _jnp_dtype(dtype), sharding=_normalize_to_sharding(device))

File ~/micromamba/envs/dev-cuda/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py:100, in canonicalize_shape(shape, context)
     98   return core.canonicalize_shape((shape,), context)  # type: ignore
     99 else:
--> 100   return core.canonicalize_shape(shape, context)

File ~/micromamba/envs/dev-cuda/lib/python3.11/site-packages/jax/_src/core.py:1647, in canonicalize_shape(shape, context)
   1645 except TypeError:
   1646   pass
-> 1647 raise _invalid_shape_error(shape, context)

TypeError: Shapes must be 1D sequences of concrete values of integer type, got (Traced<ShapedArray(int64[])>with<DynamicJaxprTrace(level=1/0)>,).
If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.
The error occurred while tracing the function jax_funcified_fgraph at /tmp/tmp5pyizklf:1 for jit. This concrete value was not available in Python because it depends on the value of the argument a.
Apply node that caused the error: DeepCopyOp(y_unobserved)
Toposort index: 0
Inputs types: [TensorType(float64, shape=(1,))]
Inputs shapes: [(1,), ()]
Inputs strides: [(8,), ()]
Inputs values: [array([1.]), array(3)]
Outputs clients: [['output']]

HINT: Re-running with most PyTensor optimizations disabled could provide a back-trace showing when this node was created. This can be done by setting the PyTensor flag 'optimizer=fast_compile'. If that does not work, PyTensor optimizations can be disabled with 'optimizer=None'.
HINT: Use the PyTensor flag `exception_verbosity=high` for a debug print-out and storage map footprint of this Apply node.

The `dprint` of the compiled function:

DeepCopyOp [id A] <Vector(float64, shape=(1,))> 0
 └─ y_unobserved [id B] <Vector(float64, shape=(1,))>
AdvancedSetSubtensor [id C] <Vector(float64, shape=(?,))> 'y' 3
 ├─ AdvancedSetSubtensor [id D] <Vector(float64, shape=(?,))> 2
 │  ├─ AllocEmpty{dtype='float64'} [id E] <Vector(float64, shape=(?,))> 1
 │  │  └─ a [id F] <Scalar(int64, shape=())>
 │  ├─ y_unobserved [id B] <Vector(float64, shape=(1,))>
 │  └─ [False False  True] [id G] <Vector(bool, shape=(3,))>
 ├─ [0. 0.] [id H] <Vector(float64, shape=(2,))>
 └─ [ True  True False] [id I] <Vector(bool, shape=(3,))>

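I think the problem is that `AllocEmpty` doesn't have a static shape: its length is the symbolic scalar `a` (the length of dim "a"), so when JAX traces the function, the shape passed to `jnp.empty` is a tracer rather than a concrete integer.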
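The same failure can be reproduced in plain JAX, without PyMC. In this minimal sketch, `alloc` is a hypothetical helper that imitates the `jnp.empty(shape, dtype=...)` call made by PyTensor's `AllocEmpty` dispatch in the traceback above. Under `jax.jit` an ordinary argument is traced, so it cannot be used as a shape; marking it static with `static_argnums`, as the JAX error message suggests, makes it work.

```python
import jax
import jax.numpy as jnp


def alloc(n):
    # Hypothetical stand-in for the AllocEmpty dispatch: build an array of shape (n,).
    return jnp.empty((n,), dtype=jnp.float32)


# Under jit, `n` is an abstract tracer, so its value is unknown while tracing
# and JAX rejects it as a shape entry with a TypeError.
try:
    jax.jit(alloc)(3)
    failed = False
except TypeError:
    failed = True

# Declaring `n` static makes it a concrete Python int during tracing.
out = jax.jit(alloc, static_argnums=0)(3)
print(failed, out.shape)
```

In the PyMC graph the analogous fix is not `static_argnums` but making sure the length is already a constant before the graph reaches JAX, which is presumably what freezing the dims is meant to achieve.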

Everything works fine if no dimensions are specified:

with pm.Model() as model:
    y = pm.Normal("y", 0, observed=[0, 0, np.nan])

In this case the shape of `AllocEmpty` is known statically (`shape=(3,)`):

DeepCopyOp [id A] <Vector(float64, shape=(1,))> 0
 └─ y_unobserved [id B] <Vector(float64, shape=(1,))>
AdvancedSetSubtensor [id C] <Vector(float64, shape=(?,))> 'y' 3
 ├─ AdvancedSetSubtensor [id D] <Vector(float64, shape=(?,))> 2
 │  ├─ AllocEmpty{dtype='float64'} [id E] <Vector(float64, shape=(3,))> 1
 │  │  └─ 3 [id F] <Scalar(int64, shape=())>
 │  ├─ y_unobserved [id B] <Vector(float64, shape=(1,))>
 │  └─ [False False  True] [id G] <Vector(bool, shape=(3,))>
 ├─ [0. 0.] [id H] <Vector(float64, shape=(2,))>
 └─ [ True  True False] [id I] <Vector(bool, shape=(3,))>
ricardoV94 commented 6 days ago

`a` was not properly replaced by `freeze_dims_and_data`; it shouldn't be in the graph anymore. `AllocEmpty` indeed does not have a static shape in that first graph. Seems like a PyMC issue. Is this on the main branch?

aseyboldt commented 6 days ago

Also happens on main

ricardoV94 commented 6 days ago

Found the problem