pymc-devs / nutpie

Python wrapper for nuts-rs
MIT License

Nutpie fails to compile #40

Closed: bpajusco closed this issue 1 year ago

bpajusco commented 1 year ago

I'm trying to use nutpie to speed up an ICM GP model. I've run the model with PyMC without errors, but when I call `nutpie.compile_pymc_model(icm)` I get the following error message:

```python
TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Failed in nopython mode pipeline (step: nopython frontend)
No implementation of function Function(.elemwise at 0x18dfcdea0>) found for signature:

 >>> elemwise(readonly array(float64, 0d, C), readonly array(float64, 0d, C), float64, array(float64, 0d, C))
```

For completeness, here is the full output:

```python
TypingError                               Traceback (most recent call last)
File :2

File ~/opt/anaconda3/envs/fast/lib/python3.10/site-packages/nutpie/compile_pymc.py:121, in compile_pymc_model(model, **kwargs)
    116 user_data = make_user_data(logp_fn_pt, shared_data)
    118 logp_numba_raw, c_sig = _make_c_logp_func(
    119     n_dim, logp_fn, user_data, shared_logp, shared_data
    120 )
--> 121 logp_numba = numba.cfunc(c_sig, **kwargs)(logp_numba_raw)
    123 def expand_draw(x, seed, chain, draw, *, shared_data):
    124     return expand_fn(x, **{name: shared_data[name] for name in shared_expand})[0]

File ~/opt/anaconda3/envs/fast/lib/python3.10/site-packages/numba/core/decorators.py:282, in cfunc..wrapper(func)
    280 if cache:
    281     res.enable_caching()
--> 282 res.compile()
    283 return res

File ~/opt/anaconda3/envs/fast/lib/python3.10/site-packages/numba/core/compiler_lock.py:35, in _CompilerLock.__call__.._acquire_compile_lock(*args, **kwargs)
     32 @functools.wraps(func)
     33 def _acquire_compile_lock(*args, **kwargs):
     34     with self:
---> 35         return func(*args, **kwargs)

File ~/opt/anaconda3/envs/fast/lib/python3.10/site-packages/numba/core/ccallback.py:67, in CFunc.compile(self)
     64 cres = self._cache.load_overload(self._sig,
     65                                  self._targetdescr.target_context)
     66 if cres is None:
---> 67     cres = self._compile_uncached()
     68     self._cache.save_overload(self._sig, cres)
     69 else:

File ~/opt/anaconda3/envs/fast/lib/python3.10/site-packages/numba/core/ccallback.py:81, in CFunc._compile_uncached(self)
     78 sig = self._sig
     80 # Compile native function as well as cfunc wrapper
---> 81 return self._compiler.compile(sig.args, sig.return_type)

File ~/opt/anaconda3/envs/fast/lib/python3.10/site-packages/numba/core/dispatcher.py:129, in _FunctionCompiler.compile(self, args, return_type)
    127     return retval
    128 else:
--> 129     raise retval

File ~/opt/anaconda3/envs/fast/lib/python3.10/site-packages/numba/core/dispatcher.py:139, in _FunctionCompiler._compile_cached(self, args, return_type)
    136     pass
    138 try:
--> 139     retval = self._compile_core(args, return_type)
    140 except errors.TypingError as e:
    141     self._failed_cache[key] = e

File ~/opt/anaconda3/envs/fast/lib/python3.10/site-packages/numba/core/dispatcher.py:152, in _FunctionCompiler._compile_core(self, args, return_type)
    149 flags = self._customize_flags(flags)
    151 impl = self._get_implementation(args, {})
--> 152 cres = compiler.compile_extra(self.targetdescr.typing_context,
    153                               self.targetdescr.target_context,
    154                               impl,
    155                               args=args, return_type=return_type,
    156                               flags=flags, locals=self.locals,
    157                               pipeline_class=self.pipeline_class)
    158 # Check typing error if object mode is used
    159 if cres.typing_error is not None and not flags.enable_pyobject:

File ~/opt/anaconda3/envs/fast/lib/python3.10/site-packages/numba/core/compiler.py:716, in compile_extra(typingctx, targetctx, func, args, return_type, flags, locals, library, pipeline_class)
    692 """Compiler entry point
    693 
    694 Parameter
   (...)
    712     compiler pipeline
    713 """
    714 pipeline = pipeline_class(typingctx, targetctx, library,
    715                           args, return_type, flags, locals)
--> 716 return pipeline.compile_extra(func)

File ~/opt/anaconda3/envs/fast/lib/python3.10/site-packages/numba/core/compiler.py:452, in CompilerBase.compile_extra(self, func)
    450 self.state.lifted = ()
    451 self.state.lifted_from = None
--> 452 return self._compile_bytecode()

File ~/opt/anaconda3/envs/fast/lib/python3.10/site-packages/numba/core/compiler.py:520, in CompilerBase._compile_bytecode(self)
    516 """
    517 Populate and run pipeline for bytecode input
    518 """
    519 assert self.state.func_ir is None
--> 520 return self._compile_core()

File ~/opt/anaconda3/envs/fast/lib/python3.10/site-packages/numba/core/compiler.py:499, in CompilerBase._compile_core(self)
    497 self.state.status.fail_reason = e
    498 if is_final_pipeline:
--> 499     raise e
    500 else:
    501     raise CompilerError("All available pipelines exhausted")

File ~/opt/anaconda3/envs/fast/lib/python3.10/site-packages/numba/core/compiler.py:486, in CompilerBase._compile_core(self)
    484 res = None
    485 try:
--> 486     pm.run(self.state)
    487     if self.state.cr is not None:
    488         break

File ~/opt/anaconda3/envs/fast/lib/python3.10/site-packages/numba/core/compiler_machinery.py:368, in PassManager.run(self, state)
    365 msg = "Failed in %s mode pipeline (step: %s)" % \
    366     (self.pipeline_name, pass_desc)
    367 patched_exception = self._patch_error(msg, e)
--> 368 raise patched_exception

File ~/opt/anaconda3/envs/fast/lib/python3.10/site-packages/numba/core/compiler_machinery.py:356, in PassManager.run(self, state)
    354 pass_inst = _pass_registry.get(pss).pass_inst
    355 if isinstance(pass_inst, CompilerPass):
--> 356     self._runPass(idx, pass_inst, state)
    357 else:
    358     raise BaseException("Legacy pass in use")

File ~/opt/anaconda3/envs/fast/lib/python3.10/site-packages/numba/core/compiler_lock.py:35, in _CompilerLock.__call__.._acquire_compile_lock(*args, **kwargs)
     32 @functools.wraps(func)
     33 def _acquire_compile_lock(*args, **kwargs):
     34     with self:
---> 35         return func(*args, **kwargs)

File ~/opt/anaconda3/envs/fast/lib/python3.10/site-packages/numba/core/compiler_machinery.py:311, in PassManager._runPass(self, index, pss, internal_state)
    309     mutated |= check(pss.run_initialization, internal_state)
    310 with SimpleTimer() as pass_time:
--> 311     mutated |= check(pss.run_pass, internal_state)
    312 with SimpleTimer() as finalize_time:
    313     mutated |= check(pss.run_finalizer, internal_state)

File ~/opt/anaconda3/envs/fast/lib/python3.10/site-packages/numba/core/compiler_machinery.py:273, in PassManager._runPass..check(func, compiler_state)
    272 def check(func, compiler_state):
--> 273     mangled = func(compiler_state)
    274     if mangled not in (True, False):
    275         msg = ("CompilerPass implementations should return True/False. "
    276                "CompilerPass with name '%s' did not.")

File ~/opt/anaconda3/envs/fast/lib/python3.10/site-packages/numba/core/typed_passes.py:105, in BaseTypeInference.run_pass(self, state)
     99 """
    100 Type inference and legalization
    101 """
    102 with fallback_context(state, 'Function "%s" failed type inference'
    103                       % (state.func_id.func_name,)):
    104     # Type inference
--> 105     typemap, return_type, calltypes, errs = type_inference_stage(
    106         state.typingctx,
    107         state.targetctx,
    108         state.func_ir,
    109         state.args,
    110         state.return_type,
    111         state.locals,
    112         raise_errors=self._raise_errors)
    113     state.typemap = typemap
    114     # save errors in case of partial typing

File ~/opt/anaconda3/envs/fast/lib/python3.10/site-packages/numba/core/typed_passes.py:83, in type_inference_stage(typingctx, targetctx, interp, args, return_type, locals, raise_errors)
     81 infer.build_constraint()
     82 # return errors in case of partial typing
---> 83 errs = infer.propagate(raise_errors=raise_errors)
     84 typemap, restype, calltypes = infer.unify(raise_errors=raise_errors)
     86 # Output all Numba warnings

File ~/opt/anaconda3/envs/fast/lib/python3.10/site-packages/numba/core/typeinfer.py:1086, in TypeInferer.propagate(self, raise_errors)
   1083 force_lit_args = [e for e in errors
   1084                   if isinstance(e, ForceLiteralArg)]
   1085 if not force_lit_args:
-> 1086     raise errors[0]
   1087 else:
   1088     raise reduce(operator.or_, force_lit_args)

TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Failed in nopython mode pipeline (step: nopython frontend)
No implementation of function Function(.elemwise at 0x18dfcdea0>) found for signature:

 >>> elemwise(readonly array(float64, 0d, C), readonly array(float64, 0d, C), float64, array(float64, 0d, C))

There are 2 candidate implementations:
  - Of which 2 did not match due to:
  Overload in function 'numba_funcify_Elemwise..ov_elemwise': File: pytensor/link/numba/dispatch/elemwise.py: Line 687.
    With argument(s): '(readonly array(float64, 0d, C), readonly array(float64, 0d, C), float64, array(float64, 0d, C))':
   Rejected as the implementation raised a specific error:
     TypingError: Failed in nopython mode pipeline (step: nopython frontend)
   No implementation of function Function() found for signature:

    >>> _vectorized(type(CPUDispatcher()), Literal[str](gASVCAAAAAAAAAAoKSkpKXSULg== ), Literal[str](gASVBAAAAAAAAAAphZQu ), Literal[str](gASVDQAAAAAAAACMB2Zsb2F0NjSUhZQu ), Literal[str](gASVCQAAAAAAAABLAEsChpSFlC4= ), StarArgTuple(readonly array(float64, 0d, C), readonly array(float64, 0d, C), float64, array(float64, 0d, C)))

   There are 2 candidate implementations:
     - Of which 1 did not match due to:
     Intrinsic in function '_vectorized': File: pytensor/link/numba/dispatch/elemwise.py: Line 466.
       With argument(s): '(type(CPUDispatcher()), unicode_type, unicode_type, unicode_type, unicode_type, StarArgTuple(readonly array(float64, 0d, C), readonly array(float64, 0d, C), float64, array(float64, 0d, C)))':
      Rejected as the implementation raised a specific error:
        TypingError: input_bc_patterns must be literal.
     raised from /Users/bpajusco/opt/anaconda3/envs/fast/lib/python3.10/site-packages/pytensor/link/numba/dispatch/elemwise.py:486
     - Of which 1 did not match due to:
     Intrinsic in function '_vectorized': File: pytensor/link/numba/dispatch/elemwise.py: Line 466.
       With argument(s): '(type(CPUDispatcher()), Literal[str](gASVCAAAAAAAAAAoKSkpKXSULg== ), Literal[str](gASVBAAAAAAAAAAphZQu ), Literal[str](gASVDQAAAAAAAACMB2Zsb2F0NjSUhZQu ), Literal[str](gASVCQAAAAAAAABLAEsChpSFlC4= ), StarArgTuple(readonly array(float64, 0d, C), readonly array(float64, 0d, C), float64, array(float64, 0d, C)))':
      Rejected as the implementation raised a specific error:
        TypingError: Inputs to elemwise must be arrays.
     raised from /Users/bpajusco/opt/anaconda3/envs/fast/lib/python3.10/site-packages/pytensor/link/numba/dispatch/elemwise.py:514

   During: resolving callee type: Function()
   During: typing of call at /Users/bpajusco/opt/anaconda3/envs/fast/lib/python3.10/site-packages/pytensor/link/numba/dispatch/elemwise.py (648)

   File "../../../../opt/anaconda3/envs/fast/lib/python3.10/site-packages/pytensor/link/numba/dispatch/elemwise.py", line 648:
       def elemwise_wrapper(*inputs):
           return _vectorized(
           ^

  raised from /Users/bpajusco/opt/anaconda3/envs/fast/lib/python3.10/site-packages/numba/core/typeinfer.py:1086

During: resolving callee type: Function(.elemwise at 0x18dfcdea0>)
During: typing of call at /var/folders/z8/q5v42wdj6j1fxt991j966lqm0000gn/T/tmp0kknc9dd (155)

File "../../../../../../var/folders/z8/q5v42wdj6j1fxt991j966lqm0000gn/T/tmp0kknc9dd", line 155:
def numba_funcified_fgraph(_joined_variables):
    # Elemwise{Composite{((i0 - (i1 * i2)) - i3)}}[(0, 2)](TensorConstant{-234.32932596719152}, TensorConstant{0.5}, InplaceDimShuffle{}.0, Sum{acc_dtype=float64}.0)
    tensor_variable_63 = elemwise_28(tensor_constant_27, tensor_constant_28, tensor_variable_60, tensor_variable_42)
    ^

During: resolving callee type: type(CPUDispatcher())
During: typing of call at /Users/bpajusco/opt/anaconda3/envs/fast/lib/python3.10/site-packages/nutpie/compile_pymc.py (265)

During: resolving callee type: type(CPUDispatcher())
During: typing of call at /Users/bpajusco/opt/anaconda3/envs/fast/lib/python3.10/site-packages/nutpie/compile_pymc.py (265)

File "../../../../opt/anaconda3/envs/fast/lib/python3.10/site-packages/nutpie/compile_pymc.py", line 265:
def extract_shared(x, user_data_):
    return inner(x)
    ^
```

And here is the PyMC model for reference:

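```python
import numpy as np
import pymc as pm
import pytensor.tensor as at  # assumed alias: the original snippet does not show its imports

# X (input in column 0, output index in column 1), Y and n_outputs are defined elsewhere;
# get_icm is posted further down in this thread.
with pm.Model() as icm:
    # Priors
    λ = pm.Gamma("λ", alpha=2, beta=0.5)
    kernel = pm.gp.cov.ExpQuad(input_dim=2, ls=λ, active_dims=[0])
    σ = pm.HalfNormal("σ", sigma=3)

    # Get ICM kernel
    W = pm.Normal("W", mu=0, sigma=3, shape=(n_outputs, 2), initval=np.random.randn(n_outputs, 2))
    κ = pm.Gamma("κ", alpha=1.5, beta=1, shape=n_outputs)
    B = pm.Deterministic("B", at.dot(W, W.T) + at.diag(κ))

    cov_icm = get_icm(input_dim=2, kernel=kernel, B=B, active_dims=[1])

    gp = pm.gp.Marginal(cov_func=cov_icm)
    y_ = gp.marginal_likelihood("f", X, Y, sigma=σ)
```
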
twiecki commented 1 year ago

@bpajusco Thanks for reporting. Can you format your message better using ``` ?

bpajusco commented 1 year ago

Done, and sorry for not doing it properly the first time around.

aseyboldt commented 1 year ago

That looks like an incorrect numba implementation of one of the pytensor ops. Can you share the `get_icm` function so that I can debug locally? If not, it would help to see the output of `pytensor.dprint(icm.logp_dlogp_function()._pytensor_function)`.

Independently of nutpie vs. JAX etc., if you want to speed this up, could you perhaps use the Woodbury identity to avoid factorizing $\operatorname{diag}(\kappa) + WW^T$? (Assuming that's what happens in `get_icm`.)
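A minimal NumPy sketch of the Woodbury suggestion, assuming the matrix has the form $D + WW^T$ with $D = \operatorname{diag}(d)$ and $W$ of shape $n \times k$, $k \ll n$. The identity $(D + WW^T)^{-1} b = D^{-1}b - D^{-1}W\,(I_k + W^T D^{-1} W)^{-1} W^T D^{-1} b$ only needs a $k \times k$ solve, and the matrix determinant lemma gives $\log\det(D + WW^T) = \log\det(I_k + W^T D^{-1} W) + \sum_i \log d_i$. The helpers `woodbury_solve` and `woodbury_logdet` are hypothetical, not part of PyMC or pytensor. In the ICM model the full GP covariance is the input kernel times the indexed $B$ plus noise, so this sketch applies to $B$-like matrices rather than being a drop-in replacement for the marginal likelihood.

```python
import numpy as np


def woodbury_solve(d, W, b):
    """Solve (diag(d) + W @ W.T) x = b for 1-D b using only a k x k solve.

    d: (n,) positive diagonal entries, W: (n, k) with k much smaller than n.
    """
    Dinv_b = b / d                                   # D^{-1} b
    Dinv_W = W / d[:, None]                          # D^{-1} W, shape (n, k)
    capacitance = np.eye(W.shape[1]) + W.T @ Dinv_W  # I_k + W^T D^{-1} W, shape (k, k)
    return Dinv_b - Dinv_W @ np.linalg.solve(capacitance, W.T @ Dinv_b)


def woodbury_logdet(d, W):
    """log det(diag(d) + W @ W.T) via the matrix determinant lemma."""
    capacitance = np.eye(W.shape[1]) + W.T @ (W / d[:, None])
    _, logdet_capacitance = np.linalg.slogdet(capacitance)
    return logdet_capacitance + np.sum(np.log(d))


# Compare against the dense computation on a small random example
rng = np.random.default_rng(0)
n, k = 6, 2
d = rng.uniform(0.5, 2.0, size=n)
W = rng.normal(size=(n, k))
b = rng.normal(size=n)
A = np.diag(d) + W @ W.T
print(np.allclose(woodbury_solve(d, W, b), np.linalg.solve(A, b)))
print(np.isclose(woodbury_logdet(d, W), np.linalg.slogdet(A)[1]))
```

The dense path costs $O(n^3)$ for the factorization, while the sketch above costs $O(nk^2 + k^3)$.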

bpajusco commented 1 year ago

Here is the output of `pytensor.dprint()`:

```
Sum{acc_dtype=float64} [id A] '__logp' 115
 |MakeVector{dtype='float64'} [id B] 114
   |Elemwise{Composite{(Switch(i0, ((i1 - (i2 * i3)) + Switch(i4, i5, i6)), i5) + i6)}}[(0, 3)] [id C] 'λ_log___logprob' 113
   | |Elemwise{ge,no_inplace} [id D] 14
   | | |Elemwise{exp,no_inplace} [id E] 'λ_log___log' 1
   | | | |λ_log__ [id F]
   | | |TensorConstant{0.0} [id G]
   | |TensorConstant{-1.3862943611198906} [id H]
   | |TensorConstant{0.5} [id I]
   | |Elemwise{exp,no_inplace} [id E] 'λ_log___log' 1
   | |Elemwise{eq,no_inplace} [id J] 13
   | | |Elemwise{exp,no_inplace} [id E] 'λ_log___log' 1
   | | |TensorConstant{0} [id K]
   | |TensorConstant{-inf} [id L]
   | |λ_log__ [id F]
   |Elemwise{Composite{(Switch(i0, ((i1 - i2) + Switch(i3, i4, (i5 * i6))), i4) + i6)}}[(0, 2)] [id M] 'η_log___logprob' 102
   | |Elemwise{ge,no_inplace} [id N] 11
   | | |Elemwise{exp,no_inplace} [id O] 'η_log___log' 0
   | | | |η_log__ [id P]
   | | |TensorConstant{0.0} [id G]
   | |TensorConstant{-0.6931471805599454} [id Q]
   | |Elemwise{exp,no_inplace} [id O] 'η_log___log' 0
   | |Elemwise{eq,no_inplace} [id R] 10
   | | |Elemwise{exp,no_inplace} [id O] 'η_log___log' 0
   | | |TensorConstant{0} [id K]
   | |TensorConstant{-inf} [id L]
   | |TensorConstant{2.0} [id S]
   | |η_log__ [id P]
   |Elemwise{Composite{(Switch(i0, ((i1 * sqr((i2 * i3))) - i4), i5) + i6)}}[(0, 3)] [id T] 'σ_log___logprob' 101
   | |Elemwise{ge,no_inplace} [id U] 16
   | | |Elemwise{exp,no_inplace} [id V] 'σ_log___log' 2
   | | | |σ_log__ [id W]
   | | |TensorConstant{0.0} [id G]
   | |TensorConstant{-0.5} [id X]
   | |TensorConstant{0.3333333333333333} [id Y]
   | |Elemwise{exp,no_inplace} [id V] 'σ_log___log' 2
   | |TensorConstant{1.3244036413128373} [id Z]
   | |TensorConstant{-inf} [id L]
   | |σ_log__ [id W]
   |Sum{acc_dtype=float64} [id BA] 20
   | |Elemwise{Composite{((i0 * sqr((i1 * i2))) - i3)}} [id BB] 'sigma > 0' 8
   |   |TensorConstant{(1, 1) of -0.5} [id BC]
   |   |TensorConstant{(1, 1) of ..3333333333} [id BD]
   |   |W [id BE]
   |   |TensorConstant{(1, 1) of ..8218727822} [id BF]
   |Sum{acc_dtype=float64} [id BG] 108
   | |Elemwise{Composite{(Switch(i0, ((i1 - i2) + Switch(i3, i4, (i5 * i6))), i4) + i6)}}[(0, 2)] [id BH] 'κ_log___logprob' 106
   |   |Elemwise{ge,no_inplace} [id BI] 19
   |   | |Elemwise{exp,no_inplace} [id BJ] 'κ_log___log' 4
   |   | | |κ_log__ [id BK]
   |   | |TensorConstant{(1,) of 0.0} [id BL]
   |   |TensorConstant{(1,) of 0...3763524526} [id BM]
   |   |Elemwise{exp,no_inplace} [id BJ] 'κ_log___log' 4
   |   |Elemwise{eq,no_inplace} [id BN] 18
   |   | |Elemwise{exp,no_inplace} [id BJ] 'κ_log___log' 4
   |   | |TensorConstant{(1,) of 0} [id BO]
   |   |TensorConstant{(1,) of -inf} [id BP]
   |   |TensorConstant{(1,) of 0.5} [id BQ]
   |   |κ_log__ [id BK]
   |Elemwise{Composite{Switch(i0, ((i1 - (i2 * i3)) - i4), i5)}}[(0, 3)] [id BR] 'posdef' 70
     |All [id BS] 47
     | |Elemwise{gt,no_inplace} [id BT] 43
     |   |ExtractDiag{offset=0, axis1=0, axis2=1, view=False} [id BU] 41
     |   | |Cholesky{lower=True, destructive=False, on_error='nan'} [id BV] 39
     |   |   |Elemwise{Composite{((i0 * i1 * i2) + i3 + i4)}} [id BW] 38
     |   |     |Elemwise{Composite{exp((i0 * clip(i1, i2, i3)))}}[(0, 1)] [id BX] 37
     |   |     | |TensorConstant{(1, 1) of -0.5} [id BC]
     |   |     | |Elemwise{Add}[(0, 0)] [id BY] 35
     |   |     | | |Dot22Scalar [id BZ] 31
     |   |     | | | |Elemwise{true_div,no_inplace} [id CA] 22
     |   |     | | | | |TensorConstant{[[ 0.]
 [ ...]
 [84.]]} [id CB]
     |   |     | | | | |InplaceDimShuffle{x,x} [id CC] 12
     |   |     | | | |   |Elemwise{exp,no_inplace} [id E] 'λ_log___log' 1
     |   |     | | | |InplaceDimShuffle{1,0} [id CD] 27
     |   |     | | | | |Elemwise{true_div,no_inplace} [id CA] 22
     |   |     | | | |TensorConstant{-2.0} [id CE]
     |   |     | | |InplaceDimShuffle{0,x} [id CF] 34
     |   |     | | | |Elemwise{sqr,no_inplace} [id CG] 30
     |   |     | | |   |InplaceDimShuffle{0} [id CH] 26
     |   |     | | |     |Elemwise{true_div,no_inplace} [id CA] 22
     |   |     | | |InplaceDimShuffle{x,0} [id CI] 33
     |   |     | |   |Elemwise{sqr,no_inplace} [id CG] 30
     |   |     | |TensorConstant{(1, 1) of 0.0} [id CJ]
     |   |     | |TensorConstant{(1, 1) of inf} [id CK]
     |   |     |Elemwise{sqr,no_inplace} [id CL] 21
     |   |     | |InplaceDimShuffle{x,x} [id CM] 9
     |   |     |   |Elemwise{exp,no_inplace} [id O] 'η_log___log' 0
     |   |     |AdvancedSubtensor [id CN] 32
     |   |     | |SpecifyShape [id CO] 29
     |   |     | | |Gemm{inplace} [id CP] 25
     |   |     | | | |AllocDiag{offset=0, axis1=0, axis2=1} [id CQ] 17
     |   |     | | | | |Elemwise{exp,no_inplace} [id BJ] 'κ_log___log' 4
     |   |     | | | |(d__logp/dη_log___log_jacobian){1.0} [id CR]
     |   |     | | | |W [id BE]
     |   |     | | | |InplaceDimShuffle{1,0} [id CS] 'W.T' 3
     |   |     | | | | |W [id BE]
     |   |     | | | |(d__logp/dη_log___log_jacobian){1.0} [id CR]
     |   |     | | |TensorConstant{3} [id CT]
     |   |     | | |TensorConstant{3} [id CT]
     |   |     | |TensorConstant{[[0]
 [0]
.. [2]
 [2]]} [id CU]
     |   |     | |TensorConstant{[[0 0 0 0 ..
  2 2 2]]} [id CV]
     |   |     |AllocDiag{offset=0, axis1=0, axis2=1} [id CW] 28
     |   |     | |Alloc [id CX] 24
     |   |     |   |Elemwise{sqr,no_inplace} [id CY] 15
     |   |     |   | |Elemwise{exp,no_inplace} [id V] 'σ_log___log' 2
     |   |     |   |TensorConstant{255} [id CZ]
     |   |     |TensorConstant{[[1.e-06 0..0 1.e-06]]} [id DA]
     |   |TensorConstant{(1,) of 0} [id BO]
     |TensorConstant{-234.32932596719152} [id DB]
     |TensorConstant{0.5} [id I]
     |InplaceDimShuffle{} [id DC] 67
     | |Sum{axis=[1], acc_dtype=float64} [id DD] 64
     |   |Elemwise{Sqr}[(0, 0)] [id DE] 62
     |     |InplaceDimShuffle{1,0} [id DF] 58
     |       |SolveTriangular{lower=True, trans=0, unit_diagonal=False, check_finite=True} [id DG] 56
     |         |Elemwise{Switch}[(0, 1)] [id DH] 54
     |         | |InplaceDimShuffle{x,x} [id DI] 51
     |         | | |All [id BS] 47
     |         | |Cholesky{lower=True, destructive=False, on_error='nan'} [id BV] 39
     |         | |TensorConstant{(1, 1) of 1} [id DJ]
     |         |TensorConstant{[[-0.68584..56159017]]} [id DK]
     |Sum{acc_dtype=float64} [id DL] 49
     | |Elemwise{Log}[(0, 0)] [id DM] 45
     |   |ExtractDiag{offset=0, axis1=0, axis2=1, view=False} [id BU] 41
     |TensorConstant{-inf} [id L]
Elemwise{Composite{(Switch(i0, (i1 * i2), i3) + Switch(i4, i5, Switch(i0, i6, i5)) + (i7 * i2) + i6)}}[(0, 7)] [id DN] 'λ_log___grad' 112
 |Elemwise{ge,no_inplace} [id D] 14
 |TensorConstant{-0.5} [id X]
 |Elemwise{exp,no_inplace} [id E] 'λ_log___log' 1
 |TensorConstant{0} [id K]
 |Elemwise{eq,no_inplace} [id J] 13
 |TensorConstant{0.0} [id G]
 |(d__logp/dη_log___log_jacobian){1.0} [id CR]
 |InplaceDimShuffle{} [id DO] 111
   |Sum{axis=[0], acc_dtype=float64} [id DP] 110
     |Elemwise{Composite{((-((((i0 * ((i1 * i2 * i3) + (i4 * i2 * i5)) * i6) / i7) + i8) * i6)) / i9)}}[(0, 3)] [id DQ] 109
       |TensorConstant{(1, 1) of 2.0} [id DR]
       |TensorConstant{(1, 1) of -0.5} [id BC]
       |Elemwise{sqr,no_inplace} [id CL] 21
       |InplaceDimShuffle{0,x} [id DS] 99
       | |Sum{axis=[1], acc_dtype=float64} [id DT] 90
       |   |Elemwise{mul,no_inplace} [id DU] 85
       |     |Elemwise{Composite{AND(GE(i0, i1), LE(i0, i2))}} [id DV] 36
       |     | |Elemwise{Add}[(0, 0)] [id BY] 35
       |     | |TensorConstant{(1, 1) of 0.0} [id CJ]
       |     | |TensorConstant{(1, 1) of inf} [id CK]
       |     |Elemwise{Composite{Switch(i0, (((i1 + i2) * i3) - i4), i5)}} [id DW] 81
       |     | |Elemwise{Invert}[(0, 0)] [id DX] 50
       |     | | |InplaceDimShuffle{x,x} [id DY] 46
       |     | |   |Any [id DZ] 42
       |     | |     |Elemwise{isnan,no_inplace} [id EA] 40
       |     | |       |Cholesky{lower=True, destructive=False, on_error='nan'} [id BV] 39
       |     | |SolveTriangular{lower=False, trans=0, unit_diagonal=False, check_finite=True} [id EB] 77
       |     | | |InplaceDimShuffle{1,0} [id EC] 55
       |     | | | |Elemwise{switch,no_inplace} [id ED] 53
       |     | | |   |Elemwise{Invert}[(0, 0)] [id DX] 50
       |     | | |   |Cholesky{lower=True, destructive=False, on_error='nan'} [id BV] 39
       |     | | |   |TensorConstant{(1, 1) of 1} [id DJ]
       |     | | |InplaceDimShuffle{1,0} [id EE] 76
       |     | |   |SolveTriangular{lower=False, trans=0, unit_diagonal=False, check_finite=True} [id EF] 75
       |     | |     |InplaceDimShuffle{1,0} [id EC] 55
       |     | |     |Elemwise{Composite{((i0 * i1) - i2)}}[(0, 0)] [id EG] 74
       |     | |       |Dot22 [id EH] 69
       |     | |       | |InplaceDimShuffle{1,0} [id EI] 66
       |     | |       | | |Elemwise{Composite{Switch(i0, (i1 + Switch(i2, (i3 * i4), i5)), i6)}}[(0, 1)] [id EJ] 63
       |     | |       | |   |Elemwise{Invert}[(0, 0)] [id DX] 50
       |     | |       | |   |IncSubtensor{InplaceSet;:int64:, :int64:} [id EK] 52
       |     | |       | |   | |Alloc [id EL] 5
       |     | |       | |   | | |TensorConstant{(1, 1) of 0.0} [id EM]
       |     | |       | |   | | |TensorConstant{255} [id EN]
       |     | |       | |   | | |TensorConstant{255} [id EN]
       |     | |       | |   | |AllocDiag{offset=0, axis1=0, axis2=1} [id EO] 48
       |     | |       | |   | | |Elemwise{true_div} [id EP] 44
       |     | |       | |   | |   |TensorConstant{(1,) of -1.0} [id EQ]
       |     | |       | |   | |   |ExtractDiag{offset=0, axis1=0, axis2=1, view=False} [id BU] 41
       |     | |       | |   | |ScalarConstant{255} [id ER]
       |     | |       | |   | |ScalarConstant{255} [id ER]
       |     | |       | |   |InplaceDimShuffle{x,x} [id DI] 51
       |     | |       | |   |Dot22Scalar [id ES] 61
       |     | |       | |   | |SolveTriangular{lower=False, trans=0, unit_diagonal=False, check_finite=True} [id ET] 60
       |     | |       | |   | | |InplaceDimShuffle{1,0} [id EU] 57
       |     | |       | |   | | | |Elemwise{Switch}[(0, 1)] [id DH] 54
       |     | |       | |   | | |Elemwise{mul,no_inplace} [id EV] 59
       |     | |       | |   | |   |TensorConstant{(1, 1) of 2.0} [id DR]
       |     | |       | |   | |   |TensorConstant{(1, 1) of -0.5} [id BC]
       |     | |       | |   | |   |SolveTriangular{lower=True, trans=0, unit_diagonal=False, check_finite=True} [id DG] 56
       |     | |       | |   | |InplaceDimShuffle{1,0} [id DF] 58
       |     | |       | |   | |TensorConstant{-1.0} [id EW]
       |     | |       | |   |TensorConstant{[[1. 0. 0...1. 1. 1.]]} [id EX]
       |     | |       | |   |TensorConstant{(1, 1) of 0.0} [id CJ]
       |     | |       | |   |TensorConstant{(1, 1) of 1} [id DJ]
       |     | |       | |Elemwise{switch,no_inplace} [id ED] 53
       |     | |       |TensorConstant{[[1. 1. 1...0. 0. 1.]]} [id EY]
       |     | |       |InplaceDimShuffle{1,0} [id EZ] 73
       |     | |         |AllocDiag{offset=0, axis1=0, axis2=1} [id FA] 72
       |     | |           |Elemwise{Mul}[(0, 1)] [id FB] 71
       |     | |             |TensorConstant{(1,) of 0.5} [id BQ]
       |     | |             |ExtractDiag{offset=0, axis1=0, axis2=1, view=False} [id FC] 68
       |     | |               |Dot22 [id FD] 65
       |     | |                 |InplaceDimShuffle{1,0} [id EC] 55
       |     | |                 |Elemwise{Composite{Switch(i0, (i1 + Switch(i2, (i3 * i4), i5)), i6)}}[(0, 1)] [id EJ] 63
       |     | |InplaceDimShuffle{1,0} [id FE] 79
       |     | | |SolveTriangular{lower=False, trans=0, unit_diagonal=False, check_finite=True} [id EB] 77
       |     | |TensorConstant{[[1. 0. 0...1. 1. 1.]]} [id EX]
       |     | |AllocDiag{offset=0, axis1=0, axis2=1} [id FF] 80
       |     | | |ExtractDiag{offset=0, axis1=0, axis2=1, view=False} [id FG] 78
       |     | |   |SolveTriangular{lower=False, trans=0, unit_diagonal=False, check_finite=True} [id EB] 77
       |     | |TensorConstant{[[nan]]} [id FH]
       |     |AdvancedSubtensor [id CN] 32
       |     |Elemwise{Composite{exp((i0 * clip(i1, i2, i3)))}}[(0, 1)] [id BX] 37
       |TensorConstant{(1, 1) of -0.5} [id BC]
       |InplaceDimShuffle{0,x} [id FI] 100
       | |Sum{axis=[0], acc_dtype=float64} [id FJ] 91
       |   |Elemwise{mul,no_inplace} [id DU] 85
       |TensorConstant{[[ 0.]
 [ ...]
 [84.]]} [id CB]
       |InplaceDimShuffle{x,x} [id CC] 12
       |InplaceDimShuffle{0,x} [id FK] 107
       | |CGemv{inplace} [id FL] 103
       |   |CGemv{inplace} [id FM] 93
       |   | |AllocEmpty{dtype='float64'} [id FN] 7
       |   | | |TensorConstant{255} [id FO]
       |   | |(d__logp/dη_log___log_jacobian){1.0} [id CR]
       |   | |Elemwise{Mul}[(0, 1)] [id FP] 86
       |   | | |Elemwise{Composite{AND(GE(i0, i1), LE(i0, i2))}} [id DV] 36
       |   | | |Elemwise{Composite{Switch(i0, (((i1 + i2) * i3) - i4), i5)}} [id DW] 81
       |   | | |Elemwise{sqr,no_inplace} [id CL] 21
       |   | | |AdvancedSubtensor [id CN] 32
       |   | | |Elemwise{Composite{exp((i0 * clip(i1, i2, i3)))}}[(0, 1)] [id BX] 37
       |   | |InplaceDimShuffle{0} [id CH] 26
       |   | |TensorConstant{0.0} [id FQ]
       |   |(d__logp/dη_log___log_jacobian){1.0} [id CR]
       |   |InplaceDimShuffle{1,0} [id FR] 92
       |   | |Elemwise{Mul}[(0, 1)] [id FP] 86
       |   |InplaceDimShuffle{0} [id CH] 26
       |   |(d__logp/dη_log___log_jacobian){1.0} [id CR]
       |Elemwise{sqr,no_inplace} [id FS] 23
         |InplaceDimShuffle{x,x} [id CC] 12
Elemwise{Composite{(Switch(i0, (-i1), i2) + Switch(i3, i2, Switch(i0, i4, i2)) + (i5 * i6 * i1 * i1) + i7)}}[(0, 6)] [id FT] 'η_log___grad' 98
 |Elemwise{ge,no_inplace} [id N] 11
 |Elemwise{exp,no_inplace} [id O] 'η_log___log' 0
 |TensorConstant{0} [id K]
 |Elemwise{eq,no_inplace} [id R] 10
 |TensorConstant{2.0} [id S]
 |TensorConstant{2.0} [id S]
 |Sum{acc_dtype=float64} [id FU] 89
 | |Elemwise{mul} [id FV] 84
 |   |Elemwise{Composite{Switch(i0, (((i1 + i2) * i3) - i4), i5)}} [id DW] 81
 |   |Elemwise{Composite{exp((i0 * clip(i1, i2, i3)))}}[(0, 1)] [id BX] 37
 |   |AdvancedSubtensor [id CN] 32
 |(d__logp/dη_log___log_jacobian){1.0} [id CR]
Elemwise{Composite{(Switch(i0, (i1 * i2 * i2), i3) + (i4 * i5 * i2 * i2) + i6)}}[(0, 5)] [id FW] 'σ_log___grad' 97
 |Elemwise{ge,no_inplace} [id U] 16
 |TensorConstant{-0.1111111111111111} [id FX]
 |Elemwise{exp,no_inplace} [id V] 'σ_log___log' 2
 |TensorConstant{0} [id K]
 |TensorConstant{2.0} [id S]
 |Sum{acc_dtype=float64} [id FY] 88
 | |ExtractDiag{offset=0, axis1=0, axis2=1, view=False} [id FZ] 83
 |   |Elemwise{Composite{Switch(i0, (((i1 + i2) * i3) - i4), i5)}} [id DW] 81
 |(d__logp/dη_log___log_jacobian){1.0} [id CR]
Gemm{inplace} [id GA] 'W_grad' 105
 |Gemm{no_inplace} [id GB] 96
 | |W [id BE]
 | |(d__logp/dη_log___log_jacobian){1.0} [id CR]
 | |AdvancedIncSubtensor{inplace=True,  set_instead_of_inc=False} [id GC] 87
 | | |Alloc [id GD] 6
 | | | |TensorConstant{(1, 1) of 0.0} [id EM]
 | | | |TensorConstant{3} [id GE]
 | | | |TensorConstant{3} [id GE]
 | | |Elemwise{mul} [id GF] 82
 | | | |Elemwise{Composite{Switch(i0, (((i1 + i2) * i3) - i4), i5)}} [id DW] 81
 | | | |Elemwise{Composite{exp((i0 * clip(i1, i2, i3)))}}[(0, 1)] [id BX] 37
 | | | |Elemwise{sqr,no_inplace} [id CL] 21
 | | |TensorConstant{[[0]
 [0]
.. [2]
 [2]]} [id CU]
 | | |TensorConstant{[[0 0 0 0 ..
  2 2 2]]} [id CV]
 | |W [id BE]
 | |TensorConstant{-0.1111111111111111} [id FX]
 |(d__logp/dη_log___log_jacobian){1.0} [id CR]
 |InplaceDimShuffle{1,0} [id GG] 95
 | |AdvancedIncSubtensor{inplace=True,  set_instead_of_inc=False} [id GC] 87
 |W [id BE]
 |(d__logp/dη_log___log_jacobian){1.0} [id CR]
Elemwise{Composite{(Switch(i0, (-i1), i2) + Switch(i3, i2, Switch(i0, i4, i2)) + (i5 * i1) + i6)}} [id GH] 'κ_log___grad' 104
 |Elemwise{ge,no_inplace} [id BI] 19
 |Elemwise{exp,no_inplace} [id BJ] 'κ_log___log' 4
 |TensorConstant{(1,) of 0} [id BO]
 |Elemwise{eq,no_inplace} [id BN] 18
 |TensorConstant{(1,) of 0.5} [id BQ]
 |ExtractDiag{offset=0, axis1=0, axis2=1, view=False} [id GI] 94
 | |AdvancedIncSubtensor{inplace=True,  set_instead_of_inc=False} [id GC] 87
 |TensorConstant{(1,) of 1.0} [id GJ]
```

And here is the `get_icm` function:

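```python
import pymc as pm


def get_icm(input_dim, kernel, W=None, kappa=None, B=None, active_dims=None, name='ICM'):
    # Coregion uses B directly if given, otherwise B = W W^T + diag(kappa); `name` is unused here
    coreg = pm.gp.cov.Coregion(input_dim=input_dim, W=W, kappa=kappa, B=B, active_dims=active_dims)
    # The ICM covariance is the product of the input kernel and the coregionalization kernel
    icm_cov = kernel * coreg
    return icm_cov
```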

And yes, that is indeed what `get_icm` does.
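To illustrate what `kernel * coreg` computes, here is a NumPy sketch under the assumption that, as in the model above, column 0 of `X` holds the input and column 1 the integer output index. `exp_quad` and `icm_cov` are hypothetical helpers, not PyMC functions. The pairwise lookup `B[task[:, None], task[None, :]]` is the kind of integer indexing that appears as the `AdvancedSubtensor` node in the dprint output, whose index constants are a column and a row of output labels. As a sanity check, when every output is observed at the same inputs and rows are sorted by output, the result equals the Kronecker product $B \otimes K_x$.

```python
import numpy as np


def exp_quad(x1, x2, ls):
    """Squared-exponential kernel exp(-0.5 * (x1 - x2)**2 / ls**2) on 1-D inputs."""
    diff = x1[:, None] - x2[None, :]
    return np.exp(-0.5 * (diff / ls) ** 2)


def icm_cov(X, B, ls):
    """ICM covariance: input kernel on column 0 times B looked up by the labels in column 1."""
    x = X[:, 0]
    task = X[:, 1].astype(int)
    K_x = exp_quad(x, x, ls)
    # Pairwise lookup of the coregionalization matrix: B_full[i, j] = B[task[i], task[j]]
    B_full = B[task[:, None], task[None, :]]
    return K_x * B_full


# Three outputs observed at the same four inputs, rows sorted by output index
n_outputs, x = 3, np.linspace(0.0, 1.0, 4)
X = np.column_stack([np.tile(x, n_outputs), np.repeat(np.arange(n_outputs), x.size)])
rng = np.random.default_rng(1)
W = rng.normal(size=(n_outputs, 2))
kappa = rng.gamma(1.5, size=n_outputs)
B = W @ W.T + np.diag(kappa)
K = icm_cov(X, B, ls=0.3)
print(np.allclose(K, np.kron(B, exp_quad(x, x, 0.3))))
```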

aseyboldt commented 1 year ago

@bpajusco This should be fixed on pytensor main now. You still won't get great performance, however, because your graph still contains two Ops (`SolveTriangular` and `AdvancedSubtensor`) that do not have a numba implementation and have to call back into Python.
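For context on where `SolveTriangular` enters: reading the `posdef` branch of the logp graph posted above, the GP marginal likelihood is evaluated as a zero-mean multivariate normal log-density through a Cholesky factor $L$ of the covariance, $-\tfrac{n}{2}\log 2\pi - \tfrac12\lVert L^{-1}y\rVert^2 - \sum_i \log L_{ii}$. The constant `-234.32932596719152` in that graph equals $-\tfrac{255}{2}\log 2\pi$, consistent with 255 observations. The NumPy sketch below mirrors that computation; `forward_substitution` and `mvn_logp_chol` are hypothetical helpers, and the loop only shows what a lower-triangular solve computes, not how pytensor or numba implement it.

```python
import numpy as np


def forward_substitution(L, y):
    """Solve L z = y for a lower-triangular L, one row at a time."""
    z = np.zeros(y.shape[0])
    for i in range(y.shape[0]):
        z[i] = (y[i] - L[i, :i] @ z[:i]) / L[i, i]
    return z


def mvn_logp_chol(K, y):
    """Zero-mean multivariate normal log-density computed through a Cholesky factor of K."""
    L = np.linalg.cholesky(K)        # K = L @ L.T with L lower-triangular
    z = forward_substitution(L, y)   # z = L^{-1} y, so z @ z == y @ K^{-1} @ y
    n = y.shape[0]
    return -0.5 * n * np.log(2 * np.pi) - 0.5 * z @ z - np.sum(np.log(np.diag(L)))


rng = np.random.default_rng(2)
A = rng.normal(size=(5, 5))
K = A @ A.T + 5 * np.eye(5)
y = rng.normal(size=5)
# Reference: -0.5 * (n log 2pi + log det K + y^T K^{-1} y)
ref = -0.5 * (5 * np.log(2 * np.pi) + np.linalg.slogdet(K)[1] + y @ np.linalg.solve(K, y))
print(np.isclose(mvn_logp_chol(K, y), ref))
print(np.isclose(-0.5 * 255 * np.log(2 * np.pi), -234.32932596719152))
```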

@bwengals I was looking into the Coregion covariance (essentially $\operatorname{diag}(\kappa) + WW^T$, where $WW^T$ is low-rank), and it doesn't seem that we take advantage of the low-rank structure. Do you think it would make sense to extend `BaseCovariance` with methods like these, with a default implementation based on factorizations that could be overridden by special covariances for which we have better ways to do the computation?

Something along the lines of (with better names?)
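As a hypothetical illustration of the idea in plain NumPy (not PyMC's `BaseCovariance` API; `DenseCov`, `DiagPlusLowRankCov`, `solve` and `logdet` are names invented for this sketch): a base class provides `solve` and `logdet` through a dense Cholesky factorization, and a diagonal-plus-low-rank subclass overrides both using the Woodbury identity and the matrix determinant lemma, so callers never need to know which path is taken.

```python
import numpy as np


class DenseCov:
    """Hypothetical base class: default solve/logdet via a dense Cholesky factorization."""

    def __init__(self, K):
        self.K = np.asarray(K, dtype=float)

    def matrix(self):
        return self.K

    def solve(self, b):
        L = np.linalg.cholesky(self.matrix())
        # Solve L u = b, then L^T x = u (generic solves used here for brevity)
        return np.linalg.solve(L.T, np.linalg.solve(L, b))

    def logdet(self):
        L = np.linalg.cholesky(self.matrix())
        return 2.0 * np.sum(np.log(np.diag(L)))


class DiagPlusLowRankCov(DenseCov):
    """Hypothetical subclass for diag(d) + W W^T that overrides the dense defaults."""

    def __init__(self, d, W):
        self.d = np.asarray(d, dtype=float)
        self.W = np.asarray(W, dtype=float)

    def matrix(self):
        # Dense form, still available for fallbacks or debugging
        return np.diag(self.d) + self.W @ self.W.T

    def _capacitance(self):
        Dinv_W = self.W / self.d[:, None]
        return Dinv_W, np.eye(self.W.shape[1]) + self.W.T @ Dinv_W

    def solve(self, b):
        # Woodbury identity; b is assumed to be 1-D
        Dinv_W, C = self._capacitance()
        Dinv_b = b / self.d
        return Dinv_b - Dinv_W @ np.linalg.solve(C, self.W.T @ Dinv_b)

    def logdet(self):
        # Matrix determinant lemma
        _, C = self._capacitance()
        return np.linalg.slogdet(C)[1] + np.sum(np.log(self.d))


rng = np.random.default_rng(3)
d, W, b = rng.uniform(0.5, 2.0, 8), rng.normal(size=(8, 2)), rng.normal(size=8)
dense, low_rank = DenseCov(np.diag(d) + W @ W.T), DiagPlusLowRankCov(d, W)
print(np.allclose(dense.solve(b), low_rank.solve(b)))
print(np.isclose(dense.logdet(), low_rank.logdet()))
```

The design choice here is plain method overriding: code that only calls `solve` and `logdet` picks up the cheaper path automatically whenever a covariance knows its own structure.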

bwengals commented 1 year ago

Hey, thanks for tagging me, I'm super interested in this. I responded here: https://github.com/pymc-devs/pymc/discussions/6615