pymc-devs / nutpie

Python wrapper for nuts-rs
MIT License

Vectorized Typing error from Numba #37

Closed giiyms closed 1 year ago

giiyms commented 1 year ago

Hello,

I am getting a TypingError from Numba (raised in PyTensor's vectorized elemwise code) when compiling my model with nutpie.

Any ideas how to fix this?

import pymc as pm
import nutpie
import numpy as np

test_data = np.array([642.29826899, 667.29826899, 692.29826899])
perturbations = np.array(
    [288, 200, 288, 200, 200, 200, 200, 200, 1, 1, 1, 1, 1, 2, 200, 2, 200]
)

partial_derivatives = np.array(
    [
        [-5.16130147e-02],
        [2.52964940e-01],
        [1.75868011e-01],
        [1.67508144e-01],
        [1.22967884e-01],
        [8.50826581e-02],
        [4.51845806e-02],
        [1.52296378e-02],
        [8.54567700e-03],
        [6.86547627e-03],
        [4.94782294e-03],
        [2.71454319e-03],
        [9.46222115e-04],
        [-1.27348193e00],
        [1.27300075e-01],
        [-4.70858227e00],
        [1.72071494e-01],
    ]
)

init_temp = np.array([717.29826899])

with pm.Model() as model:
    # Assuming uniform priors for BCs
    params = [
        pm.Uniform(f"params{idx}", -param, param, shape=1)
        for idx, param in enumerate(perturbations)
    ]
    params_arr = pm.math.concatenate(params, axis=0)
    simulated_model = pm.Deterministic(
        f"thermalmodel",
        pm.math.dot(params_arr, pm.math.constant(partial_derivatives))
        + pm.math.constant(init_temp),
    )
    obs_sigma = pm.HalfCauchy(f"obs_sigma", beta=2) + pm.math.constant(6.3)
    observed = pm.StudentT(
        f"observed",
        nu=len(test_data) - 1,
        mu=simulated_model,
        sigma=obs_sigma,
        observed=test_data,
    )

compiled_model = nutpie.compile_pymc_model(model)

trace = nutpie.sample(
    compiled_model, draws=3000, tune=1000, chains=10, save_warmup=False
)
----------------
TypingError                               Traceback (most recent call last)
HOMEDIR\Dev\python\trame-htbctool\nutpie_bug_example.py in line 53
     44     obs_sigma = pm.HalfCauchy(f"obs_sigma", beta=2) + pm.math.constant(6.3)
     45     observed = pm.StudentT(
     46         f"observed",
     47         nu=len(test_data) - 1,
   (...)
     50         observed=test_data,
     51     )
---> 53 compiled_model = nutpie.compile_pymc_model(model)
     54 trace = nutpie.sample(
     55     compiled_model, draws=3000, tune=1000, chains=10, save_warmup=False
     56 )

File HOMEDIR\envs\nutpie_debug\lib\site-packages\nutpie\compile_pymc.py:121, in compile_pymc_model(model, **kwargs)
    116 user_data = make_user_data(logp_fn_pt, shared_data)
    118 logp_numba_raw, c_sig = _make_c_logp_func(
    119     n_dim, logp_fn, user_data, shared_logp, shared_data
    120 )
--> 121 logp_numba = numba.cfunc(c_sig, **kwargs)(logp_numba_raw)
    123 def expand_draw(x, seed, chain, draw, *, shared_data):
    124     return expand_fn(x, **{name: shared_data[name] for name in shared_expand})[0]

File HOMEDIR\envs\nutpie_debug\lib\site-packages\numba\core\decorators.py:282, in cfunc.<locals>.wrapper(func)
...
File "..\..\..\envs\nutpie_debug\lib\site-packages\nutpie\compile_pymc.py", line 265:
        def extract_shared(x, user_data_):
            return inner(x)
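The deterministic part of the model (thermalmodel) can be sanity-checked outside PyMC. The following is a NumPy sketch, not nutpie or PyMC code. The all-zeros parameter vector and the all-ones stand-in for partial_derivatives are hypothetical, and only their shapes match the model above. It shows that the dot product yields a single predicted temperature of shape (1,), which then broadcasts against all three observations.

```python
import numpy as np

# Values copied from the model above.
test_data = np.array([642.29826899, 667.29826899, 692.29826899])
init_temp = np.array([717.29826899])
n_params = 17

# Hypothetical stand-ins: only the shapes match the real model.
partial_derivatives = np.ones((n_params, 1))  # real values are listed above
params_arr = np.zeros(n_params)               # one hypothetical parameter draw

# (17,) dot (17, 1) -> (1,); adding init_temp (1,) keeps shape (1,)
thermalmodel = np.dot(params_arr, partial_derivatives) + init_temp

# The single prediction broadcasts against all three observations: (3,) - (1,) -> (3,)
residuals = test_data - thermalmodel
```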

Environment:



name: nutpie_debug
channels:
  - conda-forge
dependencies:
  - appdirs=1.4.4=pyh9f0ad1d_0
  - arviz=0.14.0=pyhd8ed1ab_0
  - asttokens=2.2.1=pyhd8ed1ab_0
  - backcall=0.2.0=pyh9f0ad1d_0
  - backports=1.0=pyhd8ed1ab_3
  - backports.functools_lru_cache=1.6.4=pyhd8ed1ab_0
  - blas=2.0=netlib
  - brotli=1.0.9=hcfcfb64_8
  - brotli-bin=1.0.9=hcfcfb64_8
  - brotlipy=0.7.0=py310h8d17308_1005
  - bzip2=1.0.8=h8ffe710_4
  - ca-certificates=2022.12.7=h5b45459_0
  - cachetools=5.3.0=pyhd8ed1ab_0
  - certifi=2022.12.7=pyhd8ed1ab_0
  - cffi=1.15.1=py310h628cb3f_3
  - cftime=1.6.2=py310h9b08ddd_1
  - charset-normalizer=2.1.1=pyhd8ed1ab_0
  - cloudpickle=2.2.1=pyhd8ed1ab_0
  - colorama=0.4.6=pyhd8ed1ab_0
  - comm=0.1.2=pyhd8ed1ab_0
  - cons=0.4.5=pyhd8ed1ab_0
  - contourpy=1.0.7=py310h232114e_0
  - cryptography=39.0.1=py310h6e82f81_0
  - curl=7.87.0=h68f0423_0
  - cycler=0.11.0=pyhd8ed1ab_0
  - debugpy=1.6.6=py310h00ffb61_0
  - decorator=5.1.1=pyhd8ed1ab_0
  - etuples=0.3.8=pyhd8ed1ab_0
  - executing=1.2.0=pyhd8ed1ab_0
  - fastprogress=1.0.3=pyhd8ed1ab_0
  - filelock=3.9.0=pyhd8ed1ab_0
  - fonttools=4.38.0=py310h8d17308_1
  - freetype=2.12.1=h546665d_1
  - hdf4=4.2.15=h1b1b6ef_5
  - hdf5=1.12.2=nompi_h57737ce_101
  - idna=3.4=pyhd8ed1ab_0
  - importlib-metadata=6.0.0=pyha770c72_0
  - importlib_metadata=6.0.0=hd8ed1ab_0
  - intel-openmp=2023.0.0=h57928b3_25922
  - ipykernel=6.21.1=pyh025b116_0
  - ipython=8.9.0=pyh08f2357_0
  - jedi=0.18.2=pyhd8ed1ab_0
  - jpeg=9e=h8ffe710_2
  - jupyter_client=8.0.2=pyhd8ed1ab_0
  - jupyter_core=5.2.0=py310h5588dad_0
  - kiwisolver=1.4.4=py310h232114e_1
  - krb5=1.20.1=heb0366b_0
  - lcms2=2.14=ha5c8aab_1
  - lerc=4.0.0=h63175ca_0
  - libaec=1.0.6=h63175ca_1
  - libblas=3.9.0=0_h8933c1f_netlib
  - libbrotlicommon=1.0.9=hcfcfb64_8
  - libbrotlidec=1.0.9=hcfcfb64_8
  - libbrotlienc=1.0.9=hcfcfb64_8
  - libcblas=3.9.0=0_h8933c1f_netlib
  - libcurl=7.87.0=h68f0423_0
  - libdeflate=1.17=hcfcfb64_0
  - libffi=3.4.2=h8ffe710_5
  - libhwloc=2.8.0=h039e092_1
  - libiconv=1.17=h8ffe710_0
  - liblapack=3.9.0=0_h8933c1f_netlib
  - liblapacke=3.9.0=0_h8933c1f_netlib
  - libnetcdf=4.8.1=nompi_h8c042bf_106
  - libpng=1.6.39=h19919ed_0
  - libpython=2.2=py310h5588dad_2
  - libsodium=1.0.18=h8d14728_1
  - libsqlite=3.40.0=hcfcfb64_0
  - libssh2=1.10.0=h9a1e1f7_3
  - libtiff=4.5.0=hf8721a0_2
  - libwebp-base=1.2.4=h8ffe710_0
  - libxcb=1.13=hcd874cb_1004
  - libxml2=2.10.3=hc3477c8_0
  - libzip=1.9.2=h519de47_1
  - libzlib=1.2.13=hcfcfb64_4
  - llvmlite=0.39.1=py310hb84602e_1
  - logical-unification=0.4.5=pyhd8ed1ab_0
  - m2w64-binutils=2.25.1=5
  - m2w64-bzip2=1.0.6=6
  - m2w64-crt-git=5.0.0.4636.2595836=2
  - m2w64-gcc=5.3.0=6
  - m2w64-gcc-ada=5.3.0=6
  - m2w64-gcc-fortran=5.3.0=6
  - m2w64-gcc-libgfortran=5.3.0=6
  - m2w64-gcc-libs=5.3.0=7
  - m2w64-gcc-libs-core=5.3.0=7
  - m2w64-gcc-objc=5.3.0=6
  - m2w64-gmp=6.1.0=2
  - m2w64-headers-git=5.0.0.4636.c0ad18a=2
  - m2w64-isl=0.16.1=2
  - m2w64-libiconv=1.14=6
  - m2w64-libmangle-git=5.0.0.4509.2e5a9a2=2
  - m2w64-libwinpthread-git=5.0.0.4634.697f757=2
  - m2w64-make=4.1.2351.a80a8b8=2
  - m2w64-mpc=1.0.3=3
  - m2w64-mpfr=3.1.4=4
  - m2w64-pkg-config=0.29.1=2
  - m2w64-toolchain=5.3.0=7
  - m2w64-toolchain_win-64=2.4.0=0
  - m2w64-tools-git=5.0.0.4592.90b8472=2
  - m2w64-windows-default-manifest=6.4=3
  - m2w64-winpthreads-git=5.0.0.4634.697f757=2
  - m2w64-zlib=1.2.8=10
  - matplotlib-base=3.6.3=py310h51140c5_0
  - matplotlib-inline=0.1.6=pyhd8ed1ab_0
  - minikanren=1.0.3=pyhd8ed1ab_0
  - mkl=2022.2.1=h6a75c08_19751
  - mkl-service=2.4.0=py310h84a9c25_0
  - msys2-conda-epoch=20160418=1
  - multipledispatch=0.6.0=py_0
  - munkres=1.1.4=pyh9f0ad1d_0
  - nest-asyncio=1.5.6=pyhd8ed1ab_0
  - netcdf4=1.6.2=nompi_py310h459bb5f_100
  - numba=0.56.4=py310h19bcfe9_0
  - numpy=1.23.5=py310h4a8f9c9_0
  - nutpie=0.5.1=py310h96eb580_0
  - openjpeg=2.5.0=ha2aaf27_2
  - openssl=3.0.8=hcfcfb64_0
  - packaging=23.0=pyhd8ed1ab_0
  - pandas=1.5.3=py310h1c4a608_0
  - parso=0.8.3=pyhd8ed1ab_0
  - pickleshare=0.7.5=py_1003
  - pillow=9.4.0=py310hdbb7713_1
  - pip=23.0=pyhd8ed1ab_0
  - platformdirs=3.0.0=pyhd8ed1ab_0
  - pooch=1.6.0=pyhd8ed1ab_0
  - prompt-toolkit=3.0.36=pyha770c72_0
  - psutil=5.9.4=py310h8d17308_0
  - pthread-stubs=0.4=hcd874cb_1001
  - pthreads-win32=2.9.1=hfa6e2cd_3
  - pure_eval=0.2.2=pyhd8ed1ab_0
  - pycparser=2.21=pyhd8ed1ab_0
  - pygments=2.14.0=pyhd8ed1ab_0
  - pymc=5.0.2=hd8ed1ab_0
  - pymc-base=5.0.2=pyhd8ed1ab_0
  - pyopenssl=23.0.0=pyhd8ed1ab_0
  - pyparsing=3.0.9=pyhd8ed1ab_0
  - pysocks=1.7.1=pyh0701188_6
  - pytensor=2.9.1=py310h53af72e_0
  - pytensor-base=2.9.1=py310h00ffb61_0
  - python=3.10.9=h4de0772_0_cpython
  - python-dateutil=2.8.2=pyhd8ed1ab_0
  - python_abi=3.10=3_cp310
  - pytz=2022.7.1=pyhd8ed1ab_0
  - pywin32=304=py310h00ffb61_2
  - pyzmq=25.0.0=py310hcd737a0_0
  - requests=2.28.2=pyhd8ed1ab_0
  - scipy=1.10.0=py310h578b7cb_2
  - setuptools=67.1.0=pyhd8ed1ab_0
  - six=1.16.0=pyh6c4a22f_0
  - stack_data=0.6.2=pyhd8ed1ab_0
  - tbb=2021.7.0=h91493d7_1
  - tk=8.6.12=h8ffe710_0
  - toolz=0.12.0=pyhd8ed1ab_0
  - tornado=6.2=py310h8d17308_1
  - traitlets=5.9.0=pyhd8ed1ab_0
  - typing-extensions=4.4.0=hd8ed1ab_0
  - typing_extensions=4.4.0=pyha770c72_0
  - tzdata=2022g=h191b570_0
  - ucrt=10.0.22621.0=h57928b3_0
  - unicodedata2=15.0.0=py310h8d17308_0
  - urllib3=1.26.14=pyhd8ed1ab_0
  - vc=14.3=hb6edc58_10
  - vs2015_runtime=14.34.31931=h4c5c07a_10
  - wcwidth=0.2.6=pyhd8ed1ab_0
  - wheel=0.38.4=pyhd8ed1ab_0
  - win_inet_pton=1.1.0=pyhd8ed1ab_6
  - xarray=2023.2.0=pyhd8ed1ab_0
  - xarray-einstats=0.5.1=pyhd8ed1ab_0
  - xorg-libxau=1.0.9=hcd874cb_0
  - xorg-libxdmcp=1.1.3=hcd874cb_0
  - xz=5.2.6=h8d14728_0
  - zeromq=4.3.4=h0e60522_1
  - zipp=3.13.0=pyhd8ed1ab_0
  - zlib=1.2.13=hcfcfb64_4
  - zstd=1.5.2=h12be248_6
prefix: C:\envs\nutpie_debug
giiyms commented 1 year ago

Longer error message:

C:\envs\nutpie_debug\lib\site-packages\pymc\util.py:501: FutureWarning: The tag attribute observations is deprecated. Use model.rvs_to_values[rv] instead
  warnings.warn(
Backend TkAgg is interactive backend. Turning interactive mode on.
Traceback (most recent call last):
  File "C:\envs\nutpie_debug\lib\runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "C:\envs\nutpie_debug\lib\runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "c:\.vscode\extensions\ms-python.python-2023.2.0\pythonFiles\lib\python\debugpy\__main__.py", line 39, in <module>
    cli.main()
  File "c:\.vscode\extensions\ms-python.python-2023.2.0\pythonFiles\lib\python\debugpy/..\debugpy\server\cli.py", line 430, in main
    run()
  File "c:\.vscode\extensions\ms-python.python-2023.2.0\pythonFiles\lib\python\debugpy/..\debugpy\server\cli.py", line 284, in run_file
    runpy.run_path(target, run_name="__main__")
  File "c:\.vscode\extensions\ms-python.python-2023.2.0\pythonFiles\lib\python\debugpy\_vendored\pydevd\_pydevd_bundle\pydevd_runpy.py", line 321, in run_path
    return _run_module_code(code, init_globals, run_name,
  File "c:\.vscode\extensions\ms-python.python-2023.2.0\pythonFiles\lib\python\debugpy\_vendored\pydevd\_pydevd_bundle\pydevd_runpy.py", line 135, in _run_module_code
    _run_code(code, mod_globals, init_globals,
  File "c:\.vscode\extensions\ms-python.python-2023.2.0\pythonFiles\lib\python\debugpy\_vendored\pydevd\_pydevd_bundle\pydevd_runpy.py", line 124, in _run_code
    exec(code, run_globals)
  File "C:\Dev\python\trame-htbctool\nutpie_bug_example.py", line 57, in <module>
    compiled_model = nutpie.compile_pymc_model(model)
  File "C:\envs\nutpie_debug\lib\site-packages\nutpie\compile_pymc.py", line 121, in compile_pymc_model
    logp_numba = numba.cfunc(c_sig, **kwargs)(logp_numba_raw)
  File "C:\envs\nutpie_debug\lib\site-packages\numba\core\decorators.py", line 282, in wrapper
    res.compile()
  File "C:\envs\nutpie_debug\lib\site-packages\numba\core\compiler_lock.py", line 35, in _acquire_compile_lock   
    return func(*args, **kwargs)
  File "C:\envs\nutpie_debug\lib\site-packages\numba\core\ccallback.py", line 67, in compile
    cres = self._compile_uncached()
  File "C:\envs\nutpie_debug\lib\site-packages\numba\core\ccallback.py", line 81, in _compile_uncached
    return self._compiler.compile(sig.args, sig.return_type)
  File "C:\envs\nutpie_debug\lib\site-packages\numba\core\dispatcher.py", line 129, in compile
    raise retval
  File "C:\envs\nutpie_debug\lib\site-packages\numba\core\dispatcher.py", line 139, in _compile_cached
    retval = self._compile_core(args, return_type)
  File "C:\envs\nutpie_debug\lib\site-packages\numba\core\dispatcher.py", line 152, in _compile_core
    cres = compiler.compile_extra(self.targetdescr.typing_context,
  File "C:\envs\nutpie_debug\lib\site-packages\numba\core\compiler.py", line 716, in compile_extra
    return pipeline.compile_extra(func)
  File "C:\envs\nutpie_debug\lib\site-packages\numba\core\compiler.py", line 452, in compile_extra
    return self._compile_bytecode()
  File "C:\envs\nutpie_debug\lib\site-packages\numba\core\compiler.py", line 520, in _compile_bytecode
    return self._compile_core()
  File "C:\envs\nutpie_debug\lib\site-packages\numba\core\compiler.py", line 499, in _compile_core
    raise e
  File "C:\envs\nutpie_debug\lib\site-packages\numba\core\compiler.py", line 486, in _compile_core
    pm.run(self.state)
  File "C:\envs\nutpie_debug\lib\site-packages\numba\core\compiler_machinery.py", line 368, in run
    raise patched_exception
  File "C:\envs\nutpie_debug\lib\site-packages\numba\core\compiler_machinery.py", line 356, in run
    self._runPass(idx, pass_inst, state)
  File "C:\envs\nutpie_debug\lib\site-packages\numba\core\compiler_lock.py", line 35, in _acquire_compile_lock   
    return func(*args, **kwargs)
  File "C:\envs\nutpie_debug\lib\site-packages\numba\core\compiler_machinery.py", line 311, in _runPass
    mutated |= check(pss.run_pass, internal_state)
  File "C:\envs\nutpie_debug\lib\site-packages\numba\core\compiler_machinery.py", line 273, in check
    mangled = func(compiler_state)
  File "C:\envs\nutpie_debug\lib\site-packages\numba\core\typed_passes.py", line 105, in run_pass
    typemap, return_type, calltypes, errs = type_inference_stage(
  File "C:\envs\nutpie_debug\lib\site-packages\numba\core\typed_passes.py", line 83, in type_inference_stage     
    errs = infer.propagate(raise_errors=raise_errors)
  File "C:\envs\nutpie_debug\lib\site-packages\numba\core\typeinfer.py", line 1086, in propagate
    raise errors[0]
numba.core.errors.TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Failed in nopython mode pipeline (step: nopython frontend)
No implementation of function Function(<function numba_funcify_Elemwise.<locals>.elemwise at 0x0000022D579439A0>) found for signature:

 >>> elemwise(readonly array(float64, 0d, C), float64, readonly array(float64, 0d, C), float64, readonly array(float64, 0d, C), readonly array(float64, 0d, C), readonly array(float32, 0d, C), readonly array(float64, 0d, C), float64, array(float64, 0d, C))

There are 2 candidate implementations:
  - Of which 2 did not match due to:
  Overload in function 'numba_funcify_Elemwise.<locals>.ov_elemwise': File: pytensor\link\numba\dispatch\elemwise.py: Line 687.
    With argument(s): '(readonly array(float64, 0d, C), float64, readonly array(float64, 0d, C), float64, readonly array(float64, 0d, C), readonly array(float64, 0d, C), readonly array(float32, 0d, C), readonly array(float64, 0d, C), float64, array(float64, 0d, C))':
   Rejected as the implementation raised a specific error:
     TypingError: Failed in nopython mode pipeline (step: nopython frontend)
   No implementation of function Function(<intrinsic _vectorized>) found for signature:

    >>> _vectorized(type(CPUDispatcher(<function numba_funcified_fgraph at 0x0000022D57941E10>)), Literal[str](gASVDgAAAAAAAAAoKSkpKSkpKSkpKXSULg==
   ), Literal[str](gASVBAAAAAAAAAAphZQu
   ), Literal[str](gASVDQAAAAAAAACMB2Zsb2F0NjSUhZQu
   ), Literal[str](gASVCQAAAAAAAABLAEsIhpSFlC4=
   ), StarArgTuple(readonly array(float64, 0d, C), float64, readonly array(float64, 0d, C), float64, readonly array(float64, 0d, C), readonly array(float64, 0d, C), readonly array(float32, 0d, C), readonly array(float64, 0d, C), float64, array(float64, 0d, C)))

   There are 2 candidate implementations:
         - Of which 1 did not match due to:
         Intrinsic in function '_vectorized': File: pytensor\link\numba\dispatch\elemwise.py: Line 466.
           With argument(s): '(type(CPUDispatcher(<function numba_funcified_fgraph at 0x0000022D57941E10>)), unicode_type, unicode_type, unicode_type, unicode_type, StarArgTuple(readonly array(float64, 0d, C), float64, readonly array(float64, 0d, C), float64, readonly array(float64, 0d, C), readonly array(float64, 0d, C), readonly array(float32, 0d, C), readonly array(float64, 0d, C), float64, array(float64, 0d, C)))':
          Rejected as the implementation raised a specific error:
            TypingError: input_bc_patterns must be literal.
     raised from C:\envs\nutpie_debug\lib\site-packages\pytensor\link\numba\dispatch\elemwise.py:486
         - Of which 1 did not match due to:
         Intrinsic in function '_vectorized': File: pytensor\link\numba\dispatch\elemwise.py: Line 466.
           With argument(s): '(type(CPUDispatcher(<function numba_funcified_fgraph at 0x0000022D57941E10>)), Literal[str](gASVDgAAAAAAAAAoKSkpKSkpKSkpKXSULg==
         ), Literal[str](gASVBAAAAAAAAAAphZQu
         ), Literal[str](gASVDQAAAAAAAACMB2Zsb2F0NjSUhZQu
         ), Literal[str](gASVCQAAAAAAAABLAEsIhpSFlC4=
         ), StarArgTuple(readonly array(float64, 0d, C), float64, readonly array(float64, 0d, C), float64, readonly array(float64, 0d, C), readonly array(float64, 0d, C), readonly array(float32, 0d, C), readonly array(float64, 0d, C), float64, array(float64, 0d, C)))':
          Rejected as the implementation raised a specific error:
            TypingError: Inputs to elemwise must be arrays.
     raised from C:\envs\nutpie_debug\lib\site-packages\pytensor\link\numba\dispatch\elemwise.py:514

   During: resolving callee type: Function(<intrinsic _vectorized>)
   During: typing of call at C:\envs\nutpie_debug\lib\site-packages\pytensor\link\numba\dispatch\elemwise.py (648)

   File "..\..\..\envs\nutpie_debug\lib\site-packages\pytensor\link\numba\dispatch\elemwise.py", line 648:
       def elemwise_wrapper(*inputs):
           return _vectorized(
           ^

  raised from C:\envs\nutpie_debug\lib\site-packages\numba\core\typeinfer.py:1086

During: resolving callee type: Function(<function numba_funcify_Elemwise.<locals>.elemwise at 0x0000022D579439A0>)
During: typing of call at C:\AppData\Local\Temp\tmpgsgit4yd (445)

File "..\..\..\AppData\Local\Temp\tmpgsgit4yd", line 445:
def numba_funcified_fgraph(_joined_variables):
    <source elided>
    # Elemwise{Composite{((i0 + Switch(AND(GE(i1, i2), LE(i3, i4)), i5, i6)) - ((i7 * scalar_softplus(i8)) + i9))}}[(0, 8)](TensorConstant{5.991464547107982}, InplaceDimShuffle{}.0, TensorConstant{-200.0}, InplaceDimShuffle{}.0, TensorConstant{200.0}, TensorConstant{-5.991464547107982}, TensorConstant{-inf}, TensorConstant{2.0}, InplaceDimShuffle{}.0, Reshape{0}.0)
    tensor_variable_186 = elemwise_102(tensor_constant_31, tensor_variable_126, tensor_constant_32, tensor_variable_126, tensor_constant_33, tensor_constant_34, tensor_constant_25, tensor_constant_35, tensor_variable_91, tensor_variable_18)
    ^

During: resolving callee type: type(CPUDispatcher(<function numba_funcified_fgraph at 0x0000022D58135BD0>))
During: typing of call at C:\envs\nutpie_debug\lib\site-packages\nutpie\compile_pymc.py (265)

During: resolving callee type: type(CPUDispatcher(<function numba_funcified_fgraph at 0x0000022D58135BD0>))
During: typing of call at C:\envs\nutpie_debug\lib\site-packages\nutpie\compile_pymc.py (265)

File "..\..\..\envs\nutpie_debug\lib\site-packages\nutpie\compile_pymc.py", line 265:
        def extract_shared(x, user_data_):
            return inner(x)
            ^
aseyboldt commented 1 year ago

Thank you for reporting this.

This is indeed a problem in pytensor. A shorter reproducer:

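import numpy as np
import pymc as pm
import pytensor

with pm.Model() as model:
    pm.Uniform("x", 0., 1., shape=1)

# Fails: logp compiled with the Numba backend
func = pytensor.function(model.value_vars, model.logp(), mode="NUMBA")
func(np.zeros(1))

# Works: logp compiled with the default mode
#func = pytensor.function(model.value_vars, model.logp())
#func(np.zeros(1))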

In the meantime, you can work around this by defining params as a single array in the first place:

params_arr = pm.Uniform("params", -perturbations, perturbations, shape=len(perturbations))
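A single array-valued Uniform works here because its bounds are applied element by element. As a rough NumPy sketch (it ignores the interval transform PyMC applies for sampling, and uniform_logpdf is a hypothetical helper, not a PyMC function), each element gets log density -log(upper - lower) inside its own bounds and -inf outside. For bounds of ±200 that is -log(400) ≈ -5.9915, the same constant that appears in the Elemwise node of the traceback above.

```python
import numpy as np

perturbations = np.array(
    [288, 200, 288, 200, 200, 200, 200, 200, 1, 1, 1, 1, 1, 2, 200, 2, 200]
)

def uniform_logpdf(x, lower, upper):
    """Element-wise log density of Uniform(lower, upper); -inf outside the bounds."""
    inside = (x >= lower) & (x <= upper)
    return np.where(inside, -np.log(upper - lower), -np.inf)

# One vector of 17 values checked against 17 different intervals at once,
# mirroring Uniform("params", -perturbations, perturbations, shape=17).
x = np.zeros(len(perturbations))
logp = uniform_logpdf(x, -perturbations, perturbations)
```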

Or, if you want individual distributions for each parameter, define them as scalars (no shape argument) and then stack them:

    params = [
        pm.Uniform(f"params{idx}", -param, param)
        for idx, param in enumerate(perturbations)
    ]
    params_arr = pm.math.stack(params, axis=0)
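The difference between the two variants comes down to shapes. Assuming pm.math.stack and pm.math.concatenate follow NumPy's shape rules (they mirror numpy.stack and numpy.concatenate), this NumPy sketch with hypothetical zero values shows that stacking scalars and concatenating length-1 vectors both give a vector of shape (17,), while stacking length-1 vectors adds an extra dimension.

```python
import numpy as np

n = 17
scalars = [np.float64(0.0) for _ in range(n)]  # like pm.Uniform(...) with the default shape=()
length_one = [np.zeros(1) for _ in range(n)]   # like pm.Uniform(..., shape=1)

stacked_scalars = np.stack(scalars, axis=0)         # shape (17,)
concatenated = np.concatenate(length_one, axis=0)   # shape (17,)
stacked_length_one = np.stack(length_one, axis=0)   # shape (17, 1): one extra dimension
```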
twiecki commented 1 year ago

Should we open a pytensor issue?

giiyms commented 1 year ago

Hi, thanks for the quick responses. I can't use the shape argument, as not all input params will be uniform.

Using pm.math.stack(params, axis=0) does not work for me:

ValueError: Size length is incompatible with batched dimensions of parameter 1 thermalmodel:
len(size) = 1, len(batched dims thermalmodel) = 2. Size length must be 0 or >= 2

Is there any way to debug and see what the current value of params_arr is?

If I do:

params_arr = pm.math.stack(params, axis=0).reshape(-1,1)

This allows the model to start compiling, but then I get the same TypingError as before.

aseyboldt commented 1 year ago

@twiecki

"Should we open a pytensor issue?" I already opened a PR.

@giiyms That sounds to me like some random variable in params still has rank 1 (i.e. you are still passing shape=1). The default is shape=(), which is different. Could you double-check whether that's the case?
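The distinction here is between rank 1 (shape=1, i.e. shape (1,)) and rank 0 (the default shape=()). Below is a minimal NumPy sketch of that check, using hypothetical draws and a hypothetical helper non_scalar that lists every value that is not 0-d. In PyMC the same test could be applied to the variables in model.free_RVs, since PyTensor tensor variables also carry an ndim attribute.

```python
import numpy as np

def non_scalar(named_values):
    """Return the names whose values are not 0-d (rank > 0)."""
    return [name for name, value in named_values.items() if np.ndim(value) != 0]

# Hypothetical draws: "params0" declared with shape=1, "params1" with the default shape=().
draws = {"params0": np.zeros(1), "params1": np.zeros(())}
offenders = non_scalar(draws)
```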

giiyms commented 1 year ago

@aseyboldt You are correct, thank you. It works again. I appreciate the help.

Should I close this now, or do you want to close it once the pytensor bug is fixed?