bpajusco closed this issue 1 year ago.
@bpajusco Thanks for reporting. Can you format your message better using ``` ?
Done, and sorry about not doing it properly the first time around.
That looks like an incorrect numba implementation for one of the pytensor ops...
Can you share the `get_icm` function, so that I can debug locally? If not, it would help to see the output of this: `pytensor.dprint(icm.logp_dlogp_function()._pytensor_function)`.
Independent of nutpie vs. jax etc., if you want to speed this up, could you maybe use the Woodbury identity to avoid factorizing $κI + WW^T$ (assuming that's what's happening in `get_icm`)?
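For reference, here are the two identities involved, assuming the matrix has the form $\kappa I_n + WW^T$ with $W \in \mathbb{R}^{n \times k}$ and $k \ll n$: the Woodbury identity for the inverse and the matrix determinant lemma for the log-determinant. Both only require factorizing a $k \times k$ matrix, so the cost drops from $O(n^3)$ to roughly $O(nk^2 + k^3)$.

$$
(\kappa I_n + WW^T)^{-1} = \kappa^{-1} I_n - \kappa^{-1} W \left(\kappa I_k + W^T W\right)^{-1} W^T
$$

$$
\log\det(\kappa I_n + WW^T) = n \log\kappa + \log\det\left(I_k + \kappa^{-1} W^T W\right)
$$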
Here is the output from `pytensor.dprint()`:
Sum{acc_dtype=float64} [id A] '__logp' 115
|MakeVector{dtype='float64'} [id B] 114
|Elemwise{Composite{(Switch(i0, ((i1 - (i2 * i3)) + Switch(i4, i5, i6)), i5) + i6)}}[(0, 3)] [id C] 'λ_log___logprob' 113
| |Elemwise{ge,no_inplace} [id D] 14
| | |Elemwise{exp,no_inplace} [id E] 'λ_log___log' 1
| | | |λ_log__ [id F]
| | |TensorConstant{0.0} [id G]
| |TensorConstant{-1.3862943611198906} [id H]
| |TensorConstant{0.5} [id I]
| |Elemwise{exp,no_inplace} [id E] 'λ_log___log' 1
| |Elemwise{eq,no_inplace} [id J] 13
| | |Elemwise{exp,no_inplace} [id E] 'λ_log___log' 1
| | |TensorConstant{0} [id K]
| |TensorConstant{-inf} [id L]
| |λ_log__ [id F]
|Elemwise{Composite{(Switch(i0, ((i1 - i2) + Switch(i3, i4, (i5 * i6))), i4) + i6)}}[(0, 2)] [id M] 'η_log___logprob' 102
| |Elemwise{ge,no_inplace} [id N] 11
| | |Elemwise{exp,no_inplace} [id O] 'η_log___log' 0
| | | |η_log__ [id P]
| | |TensorConstant{0.0} [id G]
| |TensorConstant{-0.6931471805599454} [id Q]
| |Elemwise{exp,no_inplace} [id O] 'η_log___log' 0
| |Elemwise{eq,no_inplace} [id R] 10
| | |Elemwise{exp,no_inplace} [id O] 'η_log___log' 0
| | |TensorConstant{0} [id K]
| |TensorConstant{-inf} [id L]
| |TensorConstant{2.0} [id S]
| |η_log__ [id P]
|Elemwise{Composite{(Switch(i0, ((i1 * sqr((i2 * i3))) - i4), i5) + i6)}}[(0, 3)] [id T] 'σ_log___logprob' 101
| |Elemwise{ge,no_inplace} [id U] 16
| | |Elemwise{exp,no_inplace} [id V] 'σ_log___log' 2
| | | |σ_log__ [id W]
| | |TensorConstant{0.0} [id G]
| |TensorConstant{-0.5} [id X]
| |TensorConstant{0.3333333333333333} [id Y]
| |Elemwise{exp,no_inplace} [id V] 'σ_log___log' 2
| |TensorConstant{1.3244036413128373} [id Z]
| |TensorConstant{-inf} [id L]
| |σ_log__ [id W]
|Sum{acc_dtype=float64} [id BA] 20
| |Elemwise{Composite{((i0 * sqr((i1 * i2))) - i3)}} [id BB] 'sigma > 0' 8
| |TensorConstant{(1, 1) of -0.5} [id BC]
| |TensorConstant{(1, 1) of ..3333333333} [id BD]
| |W [id BE]
| |TensorConstant{(1, 1) of ..8218727822} [id BF]
|Sum{acc_dtype=float64} [id BG] 108
| |Elemwise{Composite{(Switch(i0, ((i1 - i2) + Switch(i3, i4, (i5 * i6))), i4) + i6)}}[(0, 2)] [id BH] 'κ_log___logprob' 106
| |Elemwise{ge,no_inplace} [id BI] 19
| | |Elemwise{exp,no_inplace} [id BJ] 'κ_log___log' 4
| | | |κ_log__ [id BK]
| | |TensorConstant{(1,) of 0.0} [id BL]
| |TensorConstant{(1,) of 0...3763524526} [id BM]
| |Elemwise{exp,no_inplace} [id BJ] 'κ_log___log' 4
| |Elemwise{eq,no_inplace} [id BN] 18
| | |Elemwise{exp,no_inplace} [id BJ] 'κ_log___log' 4
| | |TensorConstant{(1,) of 0} [id BO]
| |TensorConstant{(1,) of -inf} [id BP]
| |TensorConstant{(1,) of 0.5} [id BQ]
| |κ_log__ [id BK]
|Elemwise{Composite{Switch(i0, ((i1 - (i2 * i3)) - i4), i5)}}[(0, 3)] [id BR] 'posdef' 70
|All [id BS] 47
| |Elemwise{gt,no_inplace} [id BT] 43
| |ExtractDiag{offset=0, axis1=0, axis2=1, view=False} [id BU] 41
| | |Cholesky{lower=True, destructive=False, on_error='nan'} [id BV] 39
| | |Elemwise{Composite{((i0 * i1 * i2) + i3 + i4)}} [id BW] 38
| | |Elemwise{Composite{exp((i0 * clip(i1, i2, i3)))}}[(0, 1)] [id BX] 37
| | | |TensorConstant{(1, 1) of -0.5} [id BC]
| | | |Elemwise{Add}[(0, 0)] [id BY] 35
| | | | |Dot22Scalar [id BZ] 31
| | | | | |Elemwise{true_div,no_inplace} [id CA] 22
| | | | | | |TensorConstant{[[ 0.]
[ ...]
[84.]]} [id CB]
| | | | | | |InplaceDimShuffle{x,x} [id CC] 12
| | | | | | |Elemwise{exp,no_inplace} [id E] 'λ_log___log' 1
| | | | | |InplaceDimShuffle{1,0} [id CD] 27
| | | | | | |Elemwise{true_div,no_inplace} [id CA] 22
| | | | | |TensorConstant{-2.0} [id CE]
| | | | |InplaceDimShuffle{0,x} [id CF] 34
| | | | | |Elemwise{sqr,no_inplace} [id CG] 30
| | | | | |InplaceDimShuffle{0} [id CH] 26
| | | | | |Elemwise{true_div,no_inplace} [id CA] 22
| | | | |InplaceDimShuffle{x,0} [id CI] 33
| | | | |Elemwise{sqr,no_inplace} [id CG] 30
| | | |TensorConstant{(1, 1) of 0.0} [id CJ]
| | | |TensorConstant{(1, 1) of inf} [id CK]
| | |Elemwise{sqr,no_inplace} [id CL] 21
| | | |InplaceDimShuffle{x,x} [id CM] 9
| | | |Elemwise{exp,no_inplace} [id O] 'η_log___log' 0
| | |AdvancedSubtensor [id CN] 32
| | | |SpecifyShape [id CO] 29
| | | | |Gemm{inplace} [id CP] 25
| | | | | |AllocDiag{offset=0, axis1=0, axis2=1} [id CQ] 17
| | | | | | |Elemwise{exp,no_inplace} [id BJ] 'κ_log___log' 4
| | | | | |(d__logp/dη_log___log_jacobian){1.0} [id CR]
| | | | | |W [id BE]
| | | | | |InplaceDimShuffle{1,0} [id CS] 'W.T' 3
| | | | | | |W [id BE]
| | | | | |(d__logp/dη_log___log_jacobian){1.0} [id CR]
| | | | |TensorConstant{3} [id CT]
| | | | |TensorConstant{3} [id CT]
| | | |TensorConstant{[[0]
[0]
.. [2]
[2]]} [id CU]
| | | |TensorConstant{[[0 0 0 0 ..
2 2 2]]} [id CV]
| | |AllocDiag{offset=0, axis1=0, axis2=1} [id CW] 28
| | | |Alloc [id CX] 24
| | | |Elemwise{sqr,no_inplace} [id CY] 15
| | | | |Elemwise{exp,no_inplace} [id V] 'σ_log___log' 2
| | | |TensorConstant{255} [id CZ]
| | |TensorConstant{[[1.e-06 0..0 1.e-06]]} [id DA]
| |TensorConstant{(1,) of 0} [id BO]
|TensorConstant{-234.32932596719152} [id DB]
|TensorConstant{0.5} [id I]
|InplaceDimShuffle{} [id DC] 67
| |Sum{axis=[1], acc_dtype=float64} [id DD] 64
| |Elemwise{Sqr}[(0, 0)] [id DE] 62
| |InplaceDimShuffle{1,0} [id DF] 58
| |SolveTriangular{lower=True, trans=0, unit_diagonal=False, check_finite=True} [id DG] 56
| |Elemwise{Switch}[(0, 1)] [id DH] 54
| | |InplaceDimShuffle{x,x} [id DI] 51
| | | |All [id BS] 47
| | |Cholesky{lower=True, destructive=False, on_error='nan'} [id BV] 39
| | |TensorConstant{(1, 1) of 1} [id DJ]
| |TensorConstant{[[-0.68584..56159017]]} [id DK]
|Sum{acc_dtype=float64} [id DL] 49
| |Elemwise{Log}[(0, 0)] [id DM] 45
| |ExtractDiag{offset=0, axis1=0, axis2=1, view=False} [id BU] 41
|TensorConstant{-inf} [id L]
Elemwise{Composite{(Switch(i0, (i1 * i2), i3) + Switch(i4, i5, Switch(i0, i6, i5)) + (i7 * i2) + i6)}}[(0, 7)] [id DN] 'λ_log___grad' 112
|Elemwise{ge,no_inplace} [id D] 14
|TensorConstant{-0.5} [id X]
|Elemwise{exp,no_inplace} [id E] 'λ_log___log' 1
|TensorConstant{0} [id K]
|Elemwise{eq,no_inplace} [id J] 13
|TensorConstant{0.0} [id G]
|(d__logp/dη_log___log_jacobian){1.0} [id CR]
|InplaceDimShuffle{} [id DO] 111
|Sum{axis=[0], acc_dtype=float64} [id DP] 110
|Elemwise{Composite{((-((((i0 * ((i1 * i2 * i3) + (i4 * i2 * i5)) * i6) / i7) + i8) * i6)) / i9)}}[(0, 3)] [id DQ] 109
|TensorConstant{(1, 1) of 2.0} [id DR]
|TensorConstant{(1, 1) of -0.5} [id BC]
|Elemwise{sqr,no_inplace} [id CL] 21
|InplaceDimShuffle{0,x} [id DS] 99
| |Sum{axis=[1], acc_dtype=float64} [id DT] 90
| |Elemwise{mul,no_inplace} [id DU] 85
| |Elemwise{Composite{AND(GE(i0, i1), LE(i0, i2))}} [id DV] 36
| | |Elemwise{Add}[(0, 0)] [id BY] 35
| | |TensorConstant{(1, 1) of 0.0} [id CJ]
| | |TensorConstant{(1, 1) of inf} [id CK]
| |Elemwise{Composite{Switch(i0, (((i1 + i2) * i3) - i4), i5)}} [id DW] 81
| | |Elemwise{Invert}[(0, 0)] [id DX] 50
| | | |InplaceDimShuffle{x,x} [id DY] 46
| | | |Any [id DZ] 42
| | | |Elemwise{isnan,no_inplace} [id EA] 40
| | | |Cholesky{lower=True, destructive=False, on_error='nan'} [id BV] 39
| | |SolveTriangular{lower=False, trans=0, unit_diagonal=False, check_finite=True} [id EB] 77
| | | |InplaceDimShuffle{1,0} [id EC] 55
| | | | |Elemwise{switch,no_inplace} [id ED] 53
| | | | |Elemwise{Invert}[(0, 0)] [id DX] 50
| | | | |Cholesky{lower=True, destructive=False, on_error='nan'} [id BV] 39
| | | | |TensorConstant{(1, 1) of 1} [id DJ]
| | | |InplaceDimShuffle{1,0} [id EE] 76
| | | |SolveTriangular{lower=False, trans=0, unit_diagonal=False, check_finite=True} [id EF] 75
| | | |InplaceDimShuffle{1,0} [id EC] 55
| | | |Elemwise{Composite{((i0 * i1) - i2)}}[(0, 0)] [id EG] 74
| | | |Dot22 [id EH] 69
| | | | |InplaceDimShuffle{1,0} [id EI] 66
| | | | | |Elemwise{Composite{Switch(i0, (i1 + Switch(i2, (i3 * i4), i5)), i6)}}[(0, 1)] [id EJ] 63
| | | | | |Elemwise{Invert}[(0, 0)] [id DX] 50
| | | | | |IncSubtensor{InplaceSet;:int64:, :int64:} [id EK] 52
| | | | | | |Alloc [id EL] 5
| | | | | | | |TensorConstant{(1, 1) of 0.0} [id EM]
| | | | | | | |TensorConstant{255} [id EN]
| | | | | | | |TensorConstant{255} [id EN]
| | | | | | |AllocDiag{offset=0, axis1=0, axis2=1} [id EO] 48
| | | | | | | |Elemwise{true_div} [id EP] 44
| | | | | | | |TensorConstant{(1,) of -1.0} [id EQ]
| | | | | | | |ExtractDiag{offset=0, axis1=0, axis2=1, view=False} [id BU] 41
| | | | | | |ScalarConstant{255} [id ER]
| | | | | | |ScalarConstant{255} [id ER]
| | | | | |InplaceDimShuffle{x,x} [id DI] 51
| | | | | |Dot22Scalar [id ES] 61
| | | | | | |SolveTriangular{lower=False, trans=0, unit_diagonal=False, check_finite=True} [id ET] 60
| | | | | | | |InplaceDimShuffle{1,0} [id EU] 57
| | | | | | | | |Elemwise{Switch}[(0, 1)] [id DH] 54
| | | | | | | |Elemwise{mul,no_inplace} [id EV] 59
| | | | | | | |TensorConstant{(1, 1) of 2.0} [id DR]
| | | | | | | |TensorConstant{(1, 1) of -0.5} [id BC]
| | | | | | | |SolveTriangular{lower=True, trans=0, unit_diagonal=False, check_finite=True} [id DG] 56
| | | | | | |InplaceDimShuffle{1,0} [id DF] 58
| | | | | | |TensorConstant{-1.0} [id EW]
| | | | | |TensorConstant{[[1. 0. 0...1. 1. 1.]]} [id EX]
| | | | | |TensorConstant{(1, 1) of 0.0} [id CJ]
| | | | | |TensorConstant{(1, 1) of 1} [id DJ]
| | | | |Elemwise{switch,no_inplace} [id ED] 53
| | | |TensorConstant{[[1. 1. 1...0. 0. 1.]]} [id EY]
| | | |InplaceDimShuffle{1,0} [id EZ] 73
| | | |AllocDiag{offset=0, axis1=0, axis2=1} [id FA] 72
| | | |Elemwise{Mul}[(0, 1)] [id FB] 71
| | | |TensorConstant{(1,) of 0.5} [id BQ]
| | | |ExtractDiag{offset=0, axis1=0, axis2=1, view=False} [id FC] 68
| | | |Dot22 [id FD] 65
| | | |InplaceDimShuffle{1,0} [id EC] 55
| | | |Elemwise{Composite{Switch(i0, (i1 + Switch(i2, (i3 * i4), i5)), i6)}}[(0, 1)] [id EJ] 63
| | |InplaceDimShuffle{1,0} [id FE] 79
| | | |SolveTriangular{lower=False, trans=0, unit_diagonal=False, check_finite=True} [id EB] 77
| | |TensorConstant{[[1. 0. 0...1. 1. 1.]]} [id EX]
| | |AllocDiag{offset=0, axis1=0, axis2=1} [id FF] 80
| | | |ExtractDiag{offset=0, axis1=0, axis2=1, view=False} [id FG] 78
| | | |SolveTriangular{lower=False, trans=0, unit_diagonal=False, check_finite=True} [id EB] 77
| | |TensorConstant{[[nan]]} [id FH]
| |AdvancedSubtensor [id CN] 32
| |Elemwise{Composite{exp((i0 * clip(i1, i2, i3)))}}[(0, 1)] [id BX] 37
|TensorConstant{(1, 1) of -0.5} [id BC]
|InplaceDimShuffle{0,x} [id FI] 100
| |Sum{axis=[0], acc_dtype=float64} [id FJ] 91
| |Elemwise{mul,no_inplace} [id DU] 85
|TensorConstant{[[ 0.]
[ ...]
[84.]]} [id CB]
|InplaceDimShuffle{x,x} [id CC] 12
|InplaceDimShuffle{0,x} [id FK] 107
| |CGemv{inplace} [id FL] 103
| |CGemv{inplace} [id FM] 93
| | |AllocEmpty{dtype='float64'} [id FN] 7
| | | |TensorConstant{255} [id FO]
| | |(d__logp/dη_log___log_jacobian){1.0} [id CR]
| | |Elemwise{Mul}[(0, 1)] [id FP] 86
| | | |Elemwise{Composite{AND(GE(i0, i1), LE(i0, i2))}} [id DV] 36
| | | |Elemwise{Composite{Switch(i0, (((i1 + i2) * i3) - i4), i5)}} [id DW] 81
| | | |Elemwise{sqr,no_inplace} [id CL] 21
| | | |AdvancedSubtensor [id CN] 32
| | | |Elemwise{Composite{exp((i0 * clip(i1, i2, i3)))}}[(0, 1)] [id BX] 37
| | |InplaceDimShuffle{0} [id CH] 26
| | |TensorConstant{0.0} [id FQ]
| |(d__logp/dη_log___log_jacobian){1.0} [id CR]
| |InplaceDimShuffle{1,0} [id FR] 92
| | |Elemwise{Mul}[(0, 1)] [id FP] 86
| |InplaceDimShuffle{0} [id CH] 26
| |(d__logp/dη_log___log_jacobian){1.0} [id CR]
|Elemwise{sqr,no_inplace} [id FS] 23
|InplaceDimShuffle{x,x} [id CC] 12
Elemwise{Composite{(Switch(i0, (-i1), i2) + Switch(i3, i2, Switch(i0, i4, i2)) + (i5 * i6 * i1 * i1) + i7)}}[(0, 6)] [id FT] 'η_log___grad' 98
|Elemwise{ge,no_inplace} [id N] 11
|Elemwise{exp,no_inplace} [id O] 'η_log___log' 0
|TensorConstant{0} [id K]
|Elemwise{eq,no_inplace} [id R] 10
|TensorConstant{2.0} [id S]
|TensorConstant{2.0} [id S]
|Sum{acc_dtype=float64} [id FU] 89
| |Elemwise{mul} [id FV] 84
| |Elemwise{Composite{Switch(i0, (((i1 + i2) * i3) - i4), i5)}} [id DW] 81
| |Elemwise{Composite{exp((i0 * clip(i1, i2, i3)))}}[(0, 1)] [id BX] 37
| |AdvancedSubtensor [id CN] 32
|(d__logp/dη_log___log_jacobian){1.0} [id CR]
Elemwise{Composite{(Switch(i0, (i1 * i2 * i2), i3) + (i4 * i5 * i2 * i2) + i6)}}[(0, 5)] [id FW] 'σ_log___grad' 97
|Elemwise{ge,no_inplace} [id U] 16
|TensorConstant{-0.1111111111111111} [id FX]
|Elemwise{exp,no_inplace} [id V] 'σ_log___log' 2
|TensorConstant{0} [id K]
|TensorConstant{2.0} [id S]
|Sum{acc_dtype=float64} [id FY] 88
| |ExtractDiag{offset=0, axis1=0, axis2=1, view=False} [id FZ] 83
| |Elemwise{Composite{Switch(i0, (((i1 + i2) * i3) - i4), i5)}} [id DW] 81
|(d__logp/dη_log___log_jacobian){1.0} [id CR]
Gemm{inplace} [id GA] 'W_grad' 105
|Gemm{no_inplace} [id GB] 96
| |W [id BE]
| |(d__logp/dη_log___log_jacobian){1.0} [id CR]
| |AdvancedIncSubtensor{inplace=True, set_instead_of_inc=False} [id GC] 87
| | |Alloc [id GD] 6
| | | |TensorConstant{(1, 1) of 0.0} [id EM]
| | | |TensorConstant{3} [id GE]
| | | |TensorConstant{3} [id GE]
| | |Elemwise{mul} [id GF] 82
| | | |Elemwise{Composite{Switch(i0, (((i1 + i2) * i3) - i4), i5)}} [id DW] 81
| | | |Elemwise{Composite{exp((i0 * clip(i1, i2, i3)))}}[(0, 1)] [id BX] 37
| | | |Elemwise{sqr,no_inplace} [id CL] 21
| | |TensorConstant{[[0]
[0]
.. [2]
[2]]} [id CU]
| | |TensorConstant{[[0 0 0 0 ..
2 2 2]]} [id CV]
| |W [id BE]
| |TensorConstant{-0.1111111111111111} [id FX]
|(d__logp/dη_log___log_jacobian){1.0} [id CR]
|InplaceDimShuffle{1,0} [id GG] 95
| |AdvancedIncSubtensor{inplace=True, set_instead_of_inc=False} [id GC] 87
|W [id BE]
|(d__logp/dη_log___log_jacobian){1.0} [id CR]
Elemwise{Composite{(Switch(i0, (-i1), i2) + Switch(i3, i2, Switch(i0, i4, i2)) + (i5 * i1) + i6)}} [id GH] 'κ_log___grad' 104
|Elemwise{ge,no_inplace} [id BI] 19
|Elemwise{exp,no_inplace} [id BJ] 'κ_log___log' 4
|TensorConstant{(1,) of 0} [id BO]
|Elemwise{eq,no_inplace} [id BN] 18
|TensorConstant{(1,) of 0.5} [id BQ]
|ExtractDiag{offset=0, axis1=0, axis2=1, view=False} [id GI] 94
| |AdvancedIncSubtensor{inplace=True, set_instead_of_inc=False} [id GC] 87
|TensorConstant{(1,) of 1.0} [id GJ]
And here is the `get_icm` function:
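```python
import pymc as pm

def get_icm(input_dim, kernel, W=None, kappa=None, B=None, active_dims=None, name='ICM'):
    # Coregion uses B = W W^T + diag(kappa) (or B directly) on the task-index column
    coreg = pm.gp.cov.Coregion(input_dim=input_dim, W=W, kappa=kappa, B=B, active_dims=active_dims)
    icm_cov = kernel * coreg  # ICM covariance: input kernel times coregionalization kernel
    return icm_cov  # note: `name` is accepted but unused
```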
And indeed, this is what `get_icm` does.
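To make the structure of `kernel * Coregion` concrete: the covariance entry for inputs $(x, i)$ and $(x', j)$ is $k(x, x')\,B_{ij}$ with $B = WW^T + \operatorname{diag}(\kappa)$. Below is a minimal NumPy sketch, assuming an exponentiated-quadratic base kernel on a single input column and a second column holding an integer task index. `exp_quad` and `icm_cov` are hypothetical helpers, not pymc API. The `B[np.ix_(task, task)]` gather appears to correspond to the `AdvancedSubtensor` node in the graph above.

```python
import numpy as np

def exp_quad(x1, x2, ls):
    # Exponentiated-quadratic kernel on 1-D inputs: exp(-0.5 * (x - x')^2 / ls^2)
    d = (x1[:, None] - x2[None, :]) / ls
    return np.exp(-0.5 * d**2)

def icm_cov(X, W, kappa, ls):
    # Hypothetical helper mirroring `kernel * Coregion`.
    # Assumed layout: column 0 = continuous input, column 1 = integer task index.
    x, task = X[:, 0], X[:, 1].astype(int)
    B = W @ W.T + np.diag(kappa)                 # coregionalization matrix (tasks x tasks)
    return exp_quad(x, x, ls) * B[np.ix_(task, task)]  # K[a, b] = k(x_a, x_b) * B[t_a, t_b]

# 3 tasks observed at the same 4 inputs -> a 12 x 12 covariance
rng = np.random.default_rng(0)
x = np.tile(np.linspace(0.0, 1.0, 4), 3)
task = np.repeat(np.arange(3), 4)
X = np.column_stack([x, task])
W = rng.normal(size=(3, 1))    # rank-1 coupling between tasks
kappa = np.full(3, 0.5)
K = icm_cov(X, W, kappa, ls=0.3)
print(K.shape)  # (12, 12)
```

When rows are sorted by task and every task shares the same inputs, this matrix is exactly `np.kron(B, Kx)`, which is another structure that can be exploited instead of a dense Cholesky of the full matrix.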
@bpajusco Should be fixed on pytensor main now.
You still won't get great performance, however, because there are still two Ops in your graph (`SolveTriangular` and `AdvancedSubtensor`) that do not have a numba implementation and need to call into Python.
@bwengals I was looking into the `Coregion` covariance (essentially $I + WW^T$, where $W$ is low-rank), and it doesn't seem we are taking advantage of the low-rank structure. Do you think it would make sense to extend `BaseCovariance` with methods like the ones below, which would have a default implementation based on factorizations but could be overridden by special covariances where we have better ways to do the computation?
Something along the lines of (with better names?):

- `logdet_and_solve`: compute $\log\det(C)$ and $C^{-1}x$. This could then be used to implement the logp.
- `factorization_product`: compute $Zx$, where $C = ZZ^T$ is a factorization. This could be used for non-centered parametrizations.
- `matvec`: compute $Cx$. This could, for instance, be used to iteratively compute $\sqrt{C}\,x$?

Hey, thanks for tagging me, super interested in this. I responded here: https://github.com/pymc-devs/pymc/discussions/6615
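A minimal NumPy sketch of how the proposed methods could look. `DenseCov` and `LowRankPlusDiagCov` are hypothetical classes, not part of pymc's `BaseCovariance`, and the low-rank case assumes $C = \operatorname{diag}(d) + WW^T$ with $W$ of shape $(n, k)$ and vector inputs only. The dense defaults rely on a Cholesky factorization; the low-rank override swaps in the Woodbury identity and the matrix determinant lemma, and uses the non-square factor $Z = [\operatorname{diag}(\sqrt{d}),\, W]$, which satisfies $ZZ^T = C$.

```python
import numpy as np

class DenseCov:
    """Hypothetical base class: default methods built on a Cholesky factorization."""

    def __init__(self, C):
        self.C = np.asarray(C, dtype=float)

    def matvec(self, x):
        # C x with a dense product, O(n^2)
        return self.C @ x

    def factorization_product(self, z):
        # Z z where Z is the lower Cholesky factor, so Z Z^T = C
        return np.linalg.cholesky(self.C) @ z

    def logdet_and_solve(self, x):
        # log det(C) and C^{-1} x from C = L L^T, O(n^3)
        L = np.linalg.cholesky(self.C)
        logdet = 2.0 * np.sum(np.log(np.diag(L)))
        y = np.linalg.solve(L, x)  # general solver used for brevity; a triangular solve is cheaper
        return logdet, np.linalg.solve(L.T, y)


class LowRankPlusDiagCov(DenseCov):
    """Hypothetical override for C = diag(d) + W W^T with W of shape (n, k), k << n."""

    def __init__(self, d, W):
        # No dense C is stored; every method below is overridden.
        self.d = np.asarray(d, dtype=float)
        self.W = np.asarray(W, dtype=float)

    def matvec(self, x):
        # (diag(d) + W W^T) x in O(nk); x is assumed to be a 1-D vector
        return self.d * x + self.W @ (self.W.T @ x)

    def factorization_product(self, z):
        # Z = [diag(sqrt(d)), W] satisfies Z Z^T = C; z has length n + k
        n = self.d.shape[0]
        return np.sqrt(self.d) * z[:n] + self.W @ z[n:]

    def logdet_and_solve(self, x):
        # Woodbury identity + matrix determinant lemma via a k x k "capacitance" matrix
        Dinv_W = self.W / self.d[:, None]                  # D^{-1} W
        cap = np.eye(self.W.shape[1]) + self.W.T @ Dinv_W  # I_k + W^T D^{-1} W
        Lc = np.linalg.cholesky(cap)
        logdet = np.sum(np.log(self.d)) + 2.0 * np.sum(np.log(np.diag(Lc)))
        Dinv_x = x / self.d
        solve = Dinv_x - Dinv_W @ np.linalg.solve(cap, self.W.T @ Dinv_x)
        return logdet, solve


# Usage: compare the low-rank override against the dense default
rng = np.random.default_rng(1)
n, k = 50, 2
d = rng.uniform(0.5, 2.0, size=n)
W = rng.normal(size=(n, k))
x = rng.normal(size=n)
fast = LowRankPlusDiagCov(d, W)
dense = DenseCov(np.diag(d) + W @ W.T)
```

Because the low-rank factor is non-square, `factorization_product` takes a vector of length $n + k$. For a non-centered parametrization that only means drawing $n + k$ standard normal entries instead of $n$.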
I'm trying to use nutpie to speed up an ICM GP model. I've run the model through pymc without errors, but once I call `nutpie.compile_pymc_model(icm)` I get the following error message:
And here is the pymc model for reference: