pymc-devs / pymc

Bayesian Modeling and Probabilistic Programming in Python
https://docs.pymc.io/

Sampling with Numba backend results in an IndexError #6293

Open fonnesbeck opened 1 year ago

fonnesbeck commented 1 year ago

Running a large model (too slow to sample with the C backend) raises an IndexError when using the Numba backend, enabled via:

aesara.config.mode = "NUMBA"

---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
~/GitHub/aesara/aesara/link/utils.py in streamline_default_f()
    201                 ):
--> 202                     thunk()
    203                     for old_s in old_storage:

~/GitHub/aesara/aesara/link/basic.py in thunk(fgraph, fgraph_jit, thunk_inputs, thunk_outputs)
    668         ):
--> 669             outputs = fgraph_jit(*[x[0] for x in thunk_inputs])
    670 

IndexError: index 23 is out of bounds for axis 1 with size 23

During handling of the above exception, another exception occurred:

IndexError                                Traceback (most recent call last)
~/phillies/pie/research/projections/pitchers/pitcher_proj.py in <module>
----> 1 trace, model, input_data = test_proj_model()

~/phillies/pie/research/projections/pitchers/pitcher_proj.py in test_proj_model(n_samples, n_iterations)
     12 
     13     model = pitcher_projection_model(**input_data)
---> 14     trace = sample_model(
     15         model,
     16         n_samples=n_samples,

~/phillies/pie/research/projections/pitchers/pitcher_proj.py in sample_model(model, n_samples, n_tune)
     28     LOGGER.info("Fitting model")
     29     with model:
---> 30         trace = pm.sample(n_samples,
     31         tune=n_tune, chains=2)
     32     #     trace = sample_numpyro_nuts(

~/GitHub/pymc/pymc/sampling.py in sample(draws, step, init, n_init, initvals, trace, chains, cores, tune, progressbar, model, random_seed, discard_tuned_samples, compute_convergence_checks, callback, jitter_max_retries, return_inferencedata, keep_warning_stat, idata_kwargs, mp_ctx, **kwargs)
    457             [kwargs.setdefault(k, v) for k, v in nuts_kwargs.items()]
    458         _log.info("Auto-assigning NUTS sampler...")
--> 459         initial_points, step = init_nuts(
    460             init=init,
    461             chains=chains,

~/GitHub/pymc/pymc/sampling.py in init_nuts(init, chains, n_init, model, random_seed, progressbar, jitter_max_retries, tune, initvals, **kwargs)
   1664     ]
   1665 
-> 1666     initial_points = _init_jitter(
   1667         model,
   1668         initvals,

~/GitHub/pymc/pymc/sampling.py in _init_jitter(model, initvals, seeds, jitter, jitter_max_retries)
   1558             if i < jitter_max_retries:
   1559                 try:
-> 1560                     model.check_start_vals(point)
   1561                 except SamplingError:
   1562                     # Retry with a new seed

~/GitHub/pymc/pymc/model.py in check_start_vals(self, start)
   1732                 )
   1733 
-> 1734             initial_eval = self.point_logps(point=elem)
   1735 
   1736             if not all(np.isfinite(v) for v in initial_eval.values()):

~/GitHub/pymc/pymc/model.py in point_logps(self, point, round_vals)
   1766             for factor, factor_logp in zip(
   1767                 factors,
-> 1768                 self.compile_fn(factor_logps_fn)(point),
   1769             )
   1770         }

~/GitHub/pymc/pymc/aesaraf.py in __call__(self, state)
    694 
    695     def __call__(self, state):
--> 696         return self.f(**state)
    697 
    698 

~/GitHub/aesara/aesara/compile/function/types.py in __call__(self, *args, **kwargs)
    969         try:
    970             outputs = (
--> 971                 self.vm()
    972                 if output_subset is None
    973                 else self.vm(output_subset=output_subset)

~/GitHub/aesara/aesara/link/utils.py in streamline_default_f()
    204                         old_s[0] = None
    205             except Exception:
--> 206                 raise_with_op(fgraph, node, thunk)
    207 
    208         f = streamline_default_f

~/GitHub/aesara/aesara/link/utils.py in raise_with_op(fgraph, node, thunk, exc_info, storage_map)
    532         # Some exception need extra parameter in inputs. So forget the
    533         # extra long error message in that case.
--> 534     raise exc_value.with_traceback(exc_trace)
    535 
    536 

~/GitHub/aesara/aesara/link/utils.py in streamline_default_f()
    200                     thunks, order, post_thunk_old_storage
    201                 ):
--> 202                     thunk()
    203                     for old_s in old_storage:
    204                         old_s[0] = None

~/GitHub/aesara/aesara/link/basic.py in thunk(fgraph, fgraph_jit, thunk_inputs, thunk_outputs)
    667             thunk_outputs=thunk_outputs,
    668         ):
--> 669             outputs = fgraph_jit(*[x[0] for x in thunk_inputs])
    670 
    671             for o_var, o_storage, o_val in zip(fgraph.outputs, thunk_outputs, outputs):

IndexError: index 23 is out of bounds for axis 1 with size 23
Apply node that caused the error: Elemwise{Composite{((i0 + Switch(AND(GE(i1, i2), LE(i1, i3)), i4, i5)) - ((i6 * scalar_softplus((-i7))) + i7))}}[(0, 1)](TensorConstant{0.6931471805599453}, rho_interval___interval, TensorConstant{-1.0}, (d__logp/dc_age_fb_stuff_log___logprob){1.0}, TensorConstant{-0.6931471805599453}, TensorConstant{-inf}, TensorConstant{2.0}, rho_interval__)
Toposort index: 289
Inputs types: [TensorType(float64, ()), TensorType(float64, ()), TensorType(float64, ()), TensorType(float64, ()), TensorType(float64, ()), TensorType(float32, ()), TensorType(float64, ()), TensorType(float64, ())]
Inputs shapes: [(), (), (), (), (), (23,), (2, 478, 23), (), (), (), (), (23,), (2, 478, 23), (), (), (), (), (23,), (2, 478, 23), (), (), (478, 3), (3,), (), (), (23,), (2, 478, 23), (), (), (), (), (23,), (2, 478, 23), (), (), (), (), (23,), (2, 478, 23), (), (), (3,), (3,), (2,), (), (), (), (), (), (), ()]
Inputs strides: [(), (), (), (), (), (8,), (87952, 184, 8), (), (), (), (), (8,), (87952, 184, 8), (), (), (), (), (8,), (87952, 184, 8), (), (), (24, 8), (8,), (), (), (8,), (87952, 184, 8), (), (), (), (), (8,), (87952, 184, 8), (), (), (), (), (8,), (87952, 184, 8), (), (), (8,), (8,), (8,), (), (), (), (), (), (), ()]
Inputs values: [array(0.65216304), array(0.50734942), array(-0.18043436), array(1.72874937), array(-0.76919174), 'not shown', 'not shown', array(-0.49425357), array(0.04805632), array(3.10702229), array(-0.31391575), 'not shown', 'not shown', array(-0.77787744), array(-0.02076679), array(2.64275196), array(0.50944904), 'not shown', 'not shown', array(0.42070242), array(0.77066501), 'not shown', array([0.37877663, 0.35234927, 0.90389653]), array(2.28823344), array(-0.07598847), 'not shown', 'not shown', array(-0.16014409), array(-0.89447727), array(2.74983691), array(0.68491237), 'not shown', 'not shown', array(0.53308313), array(-0.46624381), array(1.41455063), array(0.45739934), 'not shown', 'not shown', array(0.60924397), array(-0.10989646), array([ 0.73921934, -0.98177666, -0.45407343]), array([ 0.30524091, -0.0405966 ,  0.72879808]), array([-0.04702664,  0.41139767]), array(0.61398718), array(-0.15356098), array(0.50654193), array(0.44697276), array(-0.07707455), array(0.36574962), array(-2.4272219)]
Outputs clients: [['output']]

HINT: Re-running with most Aesara optimizations disabled could provide a back-trace showing when this node was created. This can be done by setting the Aesara flag 'optimizer=fast_compile'. If that does not work, Aesara optimizations can be disabled with 'optimizer=None'.
HINT: Use the Aesara flag `exception_verbosity=high` for a debug print-out and storage map footprint of this Apply node.

The model runs with the C backend, but very slowly. Note that I get the same error with nutpie, suggesting that it is an issue with Numba.


ricardoV94 commented 1 year ago

Any chance you can give us an MWE?

aseyboldt commented 1 year ago

I think I found it (got the model from Chris on Slack):

import aesara
import aesara.tensor as at
import numpy as np
import pymc as pm

with pm.Model() as model:
    x = pm.Normal("x", shape=(3, 5, 7))
    y = at.cumsum(x, axis=-1)
    pm.Deterministic("y", y[np.array([2]), np.array([4]), np.array([6])])

func = aesara.function([model["x"]], model["y"], mode="NUMBA")
func(np.zeros((3, 5, 7)))

This seems to be an issue in the Numba implementation of cumsum.
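One way to confirm that cumsum itself is at fault, independent of the model, is to compare a candidate implementation against np.cumsum over several ranks and every axis. This is a NumPy-only sketch; find_cumsum_mismatches is a hypothetical helper, and in practice candidate would wrap a function compiled with mode="NUMBA".

```python
import numpy as np


def find_cumsum_mismatches(candidate, max_ndim=4, size=3):
    """Return (shape, axis) pairs where candidate(x, axis) disagrees with np.cumsum.

    `candidate` is any callable with signature candidate(x, axis) -> array
    (hypothetical interface chosen for this sketch).
    """
    rng = np.random.default_rng(0)
    mismatches = []
    for ndim in range(1, max_ndim + 1):
        # Distinct extents per axis, e.g. (3, 4, 5), so a permuted
        # output shape cannot accidentally equal the expected one.
        shape = tuple(size + i for i in range(ndim))
        x = rng.normal(size=shape)
        for axis in range(-ndim, ndim):
            expected = np.cumsum(x, axis=axis)
            got = np.asarray(candidate(x, axis))
            # Check the shape first so allclose never sees mismatched shapes.
            if got.shape != expected.shape or not np.allclose(got, expected):
                mismatches.append((shape, axis))
    return mismatches


# The NumPy reference trivially agrees with itself.
print(find_cumsum_mismatches(lambda x, axis: np.cumsum(x, axis=axis)))  # []
```

Using distinct extents per axis is the important design choice: with a cube such as (3, 3, 3), a transposed result would pass the shape check and could only be caught by the value comparison.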

Update:

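Reproducing it directly with Aesara:

import aesara
import aesara.tensor as at
import numpy as np

x = at.dtensor3("x")
y = at.cumsum(x, axis=-1)
func = aesara.function([x], y, mode="NUMBA")
func(np.ones((3, 5, 7))).shape
# (5, 7, 3), but cumsum should preserve the input shape (3, 5, 7)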
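A permuted output shape is consistent with the failure surfacing as an IndexError rather than as wrong values: indexing written for the original shape can run past the end of a permuted axis. The NumPy-only illustration below uses the indices from the MWE above; the array named permuted merely has the reported (5, 7, 3) shape and stands in for the buggy output (it is not produced by Aesara).

```python
import numpy as np

x = np.ones((3, 5, 7))
correct = np.cumsum(x, axis=-1)
print(correct.shape)  # (3, 5, 7)

# Stand-in for the buggy backend output: same number of elements, permuted shape.
permuted = np.ones((5, 7, 3))

idx = (np.array([2]), np.array([4]), np.array([6]))
print(correct[idx])  # [7.]
try:
    permuted[idx]  # raises IndexError: index 6 on axis 2, which has size 3
except IndexError as err:
    print("IndexError:", err)
```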

Update 2: I think I figured it out. The problem is that the code assumes that x.transpose(*order).transpose(*order) restores the original axis order, which is not true in general when rank > 2. Working on a fix...
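A NumPy sketch of the permutation issue described in Update 2, assuming the implementation moves the cumsum axis to the front with order = (axis, *other_axes), accumulates along axis 0, and then transposes back. Reusing order for the second transpose reproduces the (5, 7, 3) shape reported above for a (3, 5, 7) input with axis=-1; transposing back with the inverse permutation, np.argsort(order), restores (3, 5, 7). The function cumsum_via_transpose and its fixed flag are hypothetical and are not Aesara's actual code.

```python
import numpy as np


def cumsum_via_transpose(x, axis, fixed=True):
    """Cumulative sum along `axis` by moving that axis to the front.

    With fixed=False the result is transposed back with the same `order`,
    which only restores the original layout when `order` is its own inverse.
    """
    axis = axis % x.ndim
    order = (axis, *[i for i in range(x.ndim) if i != axis])
    moved = x.transpose(order)        # the cumsum axis is now axis 0
    acc = np.cumsum(moved, axis=0)    # stand-in for a hand-written loop
    # np.argsort of a permutation gives its inverse permutation.
    back = tuple(np.argsort(order)) if fixed else order
    return acc.transpose(back)


x = np.ones((3, 5, 7))
print(cumsum_via_transpose(x, -1, fixed=False).shape)  # (5, 7, 3)
print(cumsum_via_transpose(x, -1, fixed=True).shape)   # (3, 5, 7)
```

For ndim <= 2 every such order is its own inverse, which may be why the bug went unnoticed. Even for a 3-D input, axis=0 gives the identity and axis=1 gives the swap (1, 0, 2), both self-inverse; only axis=2, whose order (2, 0, 1) is a 3-cycle, exposes it.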

ricardoV94 commented 1 year ago

@aseyboldt is this fixed on the PyTensor side?