pymc-devs / pymc

Bayesian Modeling and Probabilistic Programming in Python
https://docs.pymc.io/

pm.Data using lots of memory on v4 #5468

Closed fbarfi closed 2 years ago

fbarfi commented 2 years ago

I have been running a PyMC3 model on an Apple M1 for two months with no problems. I tried it on PyMC v4: everything went well, except that RAM usage hit the 64 GB limit within a few seconds of the MCMC starting. On PyMC3 it never goes beyond 22 GB. Just puzzled.

ricardoV94 commented 2 years ago

Can you share some code to reproduce the issue you are seeing?

fbarfi commented 2 years ago

Sure, here it is attached. Thanks for your quick response.


ericmjl commented 2 years ago

@fbarfi thanks for responding. We'd like to help; however, replying by email to a GitHub thread apparently doesn't carry the attachment over.

The most helpful thing would be a minimal reproducible example that we can execute in its entirety; the two easiest ways to share one are to upload a Jupyter notebook to a GitHub Gist or to paste the code here.

The other intricacy I can think of here is Rosetta. Can you confirm that you're running PyMC under native ARM-compiled Python and not under Rosetta emulation? One way to check is to fire up a Python interpreter and immediately import pymc:

❯ python
Python 3.9.10 | packaged by conda-forge | (main, Feb  1 2022, 21:25:34) 
[Clang 11.1.0 ] on darwin
Type "help", "copyright", "credits" or "license" for more information.
>>> import pymc
>>> import platform
>>> platform.architecture()
('64bit', '')
>>> 
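Note that platform.architecture() reports '64bit' either way; a more direct check (not shown in the session above) is platform.machine(), which returns 'arm64' for a native build and 'x86_64' under Rosetta:

>>> import platform
>>> platform.machine()  # 'arm64' = native Apple Silicon; 'x86_64' = Rosetta
'arm64'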

The other thing that could be handy for us is to paste in a screenshot of Activity Monitor. Ensure that there is only one Python session running, and then search for Python:

[screenshot: Activity Monitor filtered to "Python"]

If the "Kind" column is Apple, then we're running native. I would strongly recommend running in native mode rather than emulated; as a sane default, this is just better.

Not many of us have M1 Macs, and some of us are restricted to emulation, which is why we should rule out this hypothesis early by checking first. Other than that, a copy/paste-able example would be a great help.

Finally, we'd also need to know the other details - how did you install PyMC, for instance? And what version of PyMC v4 are you running?

fbarfi commented 2 years ago

Thank you for your response. Yes, I can confirm that I am running pymc natively and not through Rosetta (I have never had to use it, from day one). I will produce a minimal reproducible example in a notebook and upload it to GitHub Gist tomorrow morning. Here is a capture of the import in a terminal window.

fbarfi commented 2 years ago

Great news! I found the source of my predicament.

I use a number of pm.Data containers in the model. I ran the code again with them and it quickly consumed all 64 GB of RAM and crashed. I changed them to pm.ConstantData and, eureka, the issue is completely resolved: RAM usage is comparable to pymc3. The amazing thing is that it is also much, much faster!
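For illustration, a minimal sketch of the change (the array and names here are placeholders, not my actual model):

import numpy as np
import pymc as pm

X_np = np.random.rand(100, 3)  # placeholder data

with pm.Model():
    # before: a mutable data container, which ate all the RAM in my model
    # X = pm.Data('X', X_np)
    # after: values baked into the graph as constants
    X = pm.ConstantData('X', X_np)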

The only issue I have now is the following error, which I think comes from scipy rather than pymc. I have been searching online about it for many months (since before I switched to pymc), to no avail. Here is a picture of the message; you may have seen it elsewhere. Thank you!

[screenshot of the error messages]


twiecki commented 2 years ago

Great, glad it's working. The scipy issue should be resolved in the next version.

fbarfi commented 2 years ago

I look forward to that. It does not seem to affect anything, apart from annoyingly cluttering the output, whether in Jupyter notebooks or in the logs when I run pymc in the background.

By the way, I have been using pymc3 for some two years now, and I am really impressed by the speed improvement that PyMC v4 delivers: code that used to take long hours now runs in some ten minutes! Aesara is indeed a great step up from Theano, as are the other improvements in sampling.

Many kudos to you and your colleagues.

Just to let you know, I teach advanced statistics, including Bayesian statistics, to social scientists. After I encountered pymc3, I started teaching Bayesian statistics using Python and pymc; to put this in context, very few social scientists know Python; they rely heavily on R and on commercial software like Stata and SPSS.

Personally, I migrated from Fortran to MATLAB to R, and now I work fully in Python (with some C++ and Julia).

Thank you!!! And best wishes.


fbarfi commented 2 years ago

I was able to solve the scipy issue (error messages after pm.sample) by uninstalling the scipy 1.7.3 that is installed with pymc 4.0b2. Then I installed scipy with

pip install scipy==1.8.0

And all error messages went away!

Best.


fbarfi commented 2 years ago

Apologies, I misspoke in my previous message: I did not uninstall scipy 1.7.3 through conda. Instead, I was able to solve the scipy issue (error messages after pm.sample) just by running

pip install scipy==1.8.0

And all error messages went away!

Best.


twiecki commented 2 years ago

Great, thanks for reporting back.

fbarfi commented 2 years ago

You’re welcome. A quick question, if I may: is JAX used automatically in pymc v4, or does it need to be configured in aesara? If so, how? I could not find any good example or tutorial. Thanks.

twiecki commented 2 years ago

No, you have to call the pm.sampling_jax.sample_numpyro_nuts() function.

fbarfi commented 2 years ago

Tried it but got this error message:

AttributeError                            Traceback (most recent call last)
AttributeError: module 'pymc' has no attribute 'sampling_jax'

fbarfi commented 2 years ago

I checked the source code and it does have the file sampling_jax.py. What am I missing here? Thanks.

ricardoV94 commented 2 years ago

I checked the source code and it does have the file sampling_jax.py. What am I missing here? Thanks.

You have to import it explicitly:

import pymc.sampling_jax
pymc.sampling_jax.sample_numpyro_nuts(...)

Or

from pymc.sampling_jax import sample_numpyro_nuts
sample_numpyro_nuts(...)
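For context, here is the call inside a minimal (hypothetical) model, just to show where it goes:

import pymc as pm
from pymc.sampling_jax import sample_numpyro_nuts

with pm.Model():
    x = pm.Normal("x", 0.0, 1.0)  # placeholder variable
    idata = sample_numpyro_nuts(draws=500, tune=500)
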
fbarfi commented 2 years ago

Thanks. I had just figured that out and did it before your response, and now it is working as it should. Now I am getting another, more specific, error message:

NotImplementedError: No JAX conversion for the given Op: LogDet

This is because I am using pm.math.LogDet in the code.

I could not find any LogDet or Det in aesara. Any suggestions? Thanks in advance.

twiecki commented 2 years ago

@fbarfi Indeed, it seems we need to add this Op to aesara. Want to create an issue / work on it? It would be a great learning experience.

ricardoV94 commented 2 years ago

I am not sure that is needed. You can compute it using Aesara primitives:

import aesara.tensor as at

def log_det(x):
    return at.sum(at.log(at.abs(at.linalg.svd(x, compute_uv=False))))

Maybe the gradient is missing? Yup, it's the gradient.

Anyway, you shouldn't need the gradients for sample_numpyro_nuts.
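As a quick sanity check, this matches numpy's slogdet on a random matrix (a small self-contained example, not from the codebase):

import numpy as np
import aesara.tensor as at

def log_det(x):
    # log|det(x)| equals the sum of the logs of the singular values of x
    return at.sum(at.log(at.abs(at.linalg.svd(x, compute_uv=False))))

A = np.random.rand(4, 4)
print(log_det(at.as_tensor_variable(A)).eval())  # ~ np.linalg.slogdet(A)[1]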

twiecki commented 2 years ago

@ricardoV94 Nice, we should probably replace this in math.py then.

ricardoV94 commented 2 years ago

@ricardoV94 Nice, we should probably replace this in math.py then.

We might have it there because of the gradient... it's a question of checking if we use it anywhere in the codebase

fbarfi commented 2 years ago

@ricardoV94 Nice, we should probably replace this in math.py then.

We might have it there because of the gradient... it's a question of checking if we use it anywhere in the codebase

Just tried it. Wonderful suggestion, thanks.

fbarfi commented 2 years ago

sampling_jax works wonderfully on one of my models (a joint survival-analysis model), with a large increase in speed. FYI: I am running pymc on an Apple M1 Max.

On another model I got this error message (seemingly aesara-related):

    141 @singledispatch
    142 def jax_funcify(op, node=None, storage_map=None, **kwargs):
    143     """Create a JAX compatible function from an Aesara Op."""
--> 144     raise NotImplementedError(f"No JAX conversion for the given Op: {op}")

NotImplementedError: No JAX conversion for the given Op: BroadcastTo

Thanks in advance for any clues.

twiecki commented 2 years ago

@fbarfi Seems like we need to add a JAX implementation for that Op --> aesara issue.

fbarfi commented 2 years ago

I was able to resolve this issue by refraining from using any pm.math functions; using aesara functions everywhere in the model instead made the issue disappear.

Now I have a new issue with the Polyagamma distribution:

    141 @singledispatch
    142 def jax_funcify(op, node=None, storage_map=None, **kwargs):
    143     """Create a JAX compatible function from an Aesara Op."""
--> 144     raise NotImplementedError(f"No JAX conversion for the given Op: {op}")

NotImplementedError: No JAX conversion for the given Op: _PolyaGammaLogDistFunc{get_pdf=True}

Not using JAX is one solution, but then sampling takes a long time and overwhelms my 64 GB of RAM; before it reaches any convergence, the kernel crashes (both in Jupyter and when running the code through nohup in the background).

fbarfi commented 2 years ago

Tried pm.Gamma and pm.Beta but still got a problem:

    141 @singledispatch
    142 def jax_funcify(op, node=None, storage_map=None, **kwargs):
    143     """Create a JAX compatible function from an Aesara Op."""
--> 144     raise NotImplementedError(f"No JAX conversion for the given Op: {op}")

NotImplementedError: No JAX conversion for the given Op: BroadcastTo

twiecki commented 2 years ago

@fbarfi Added an aesara issue here: https://github.com/aesara-devs/aesara/issues/859

danhphan commented 2 years ago

Hi @fbarfi, I am working on this issue: aesara-devs/aesara/issues/859.

Could you post the code that replicates this error?

tried pm.Gamma and pm.Beta but still got a problem:
NotImplementedError: No JAX conversion for the given Op: BroadcastTo

I would like to use it to test the new implementation of the JAX conversion for the BroadcastTo Op.

Thank you.

fbarfi commented 2 years ago

Thanks @danhphan. Here it is, with the exact error message.

import warnings
warnings.filterwarnings("ignore")

import pymc as pm
from pymc import sampling_jax

import aesara
import aesara.tensor as at
from aesara.tensor.nlinalg import tensorinv,matrix_dot,matrix_inverse
from aesara.tensor.nnet.basic import sigmoid

import statsmodels.api as sm
import numpy as np
import pandas as pd

import arviz as az
from polyagamma import random_polyagamma

def ols_model(X,y):
    data = pd.DataFrame(data = X)
    res = sm.OLS(y, data).fit()
    cov = res._results.normalized_cov_params
    coef = res._results.params
    return coef, cov

def logp(ρ, tt_W, X_β,Ω,z):
    log_det  = at.sum(at.log(at.abs(at.linalg.svd((1. - ρ*tt_W), 
                                                  compute_uv=False))))

    return log_det -.5* matrix_dot(((1. - ρ*tt_W)@z - X_β ).T, Ω,
                                     ((1. - ρ*tt_W)@z - X_β ))

def All_func(ρ, β, tt_W, tt_X, tt_y,μ_0, σ_0):
    A = 1. - ρ*tt_W

    inv_A = matrix_inverse(A)

    X_β = tt_X @ β
    η = (inv_A@inv_A) @ X_β 

    ω = at.as_tensor_variable(random_polyagamma(h=1,z=η.eval(),
                                                size=η.eval().shape[0]))

    Ω = at.basic.AllocDiag()(ω)
    z = (tt_y - .5*at.ones(η.shape))/ω
    σ_bar = matrix_inverse(matrix_dot(tt_X.T, Ω,tt_X)+ matrix_inverse(σ_0))  
    sd_bar = at.basic.extract_diag(σ_bar)
    μ_bar = σ_bar @(matrix_dot(tt_X.T,Ω,A) @ z + matrix_inverse(σ_0) @ μ_0 )

    return X_β, Ω, z, sd_bar, μ_bar,η

def model():

    with pm.Model(coords=coords) as SAR_model:

        tt_X = pm.ConstantData('tt_X',X, dims = ('X_dim_0','X_dim_1'))
        tt_y = pm.ConstantData('tt_y',y, dims = 'X_dim_0')
        tt_W = pm.ConstantData('tt_W',W, dims = ('X_dim_0','W_dim_2'))
        μ_0 = pm.ConstantData('μ_0', μ_ols,dims = 'X_dim_1')
        σ_0 = pm.ConstantData('σ_0', σ_ols ,dims = ('σ_dim_0','σ_dim_1' ))

        ξ = pm.Exponential('ξ', 1.)
        ρ = pm.Normal('ρ',0.,ξ) 

        β = pm.MvNormal('β', mu = μ_ols,cov =σ_ols ,dims = 'X_dim_1')

        X_β, Ω, z, sd_bar, μ_bar, η =  All_func(ρ, β, tt_W, tt_X,tt_y, μ_0, σ_0 )

        β_like = pm.Normal('β_like', mu = μ_bar,sd = sd_bar)

        ρ_like = pm.Potential('ρ_like', logp(ρ, tt_W, X_β,Ω,z))

        y_pred = pm.Binomial('y_pred',tt_N,p=sigmoid(η),observed = y,
                                 dims= 'X_dim_0')
    return SAR_model

if __name__ == "__main__":

    random_seed=1234
    N = 200
    m = 5
    X = np.random.rand(N,m)
    W_ = np.random.rand(N,N)
    W_l = np.tril(W_,k=-1)
    W = W_l+ W_l.T
    y = np.random.randint(low=0,high=2,size=N)

    μ_ols, σ_ols =  ols_model(X,y)
    tt_N =at.as_tensor_variable(X.shape[0],dtype='int64')

    coords = {'X_dim_0': np.arange(X.shape[0]),
         'X_dim_1': np.arange(X.shape[1]),
          'V_dim_1': np.arange(X.shape[1]),
         'W_dim_2': np.arange(X.shape[0]),
         'σ_dim_0':np.arange(X.shape[1]),
         'σ_dim_1': np.arange(X.shape[1])}

    SAR_model = model()

    with SAR_model:
        idata = pm.sampling_jax.sample_numpyro_nuts(draws=100,tune=100,
                                           random_seed=random_seed) 

---------------------------------------------------------------------------
NotImplementedError                       Traceback (most recent call last)
Input In [1], in <module>
    100 SAR_model = model()
    102 with SAR_model:
--> 103     idata = pm.sampling_jax.sample_numpyro_nuts(draws=100,tune=100,
    104                                        random_seed=random_seed)

File /opt/homebrew/Caskroom/miniforge/base/envs/pymc-dev-py39/lib/python3.9/site-packages/pymc/sampling_jax.py:231, in sample_numpyro_nuts(draws, tune, chains, target_accept, random_seed, initvals, model, var_names, progress_bar, keep_untransformed, chain_method, idata_kwargs, nuts_kwargs)
    222 print("Compiling...", file=sys.stdout)
    224 init_params = _get_batched_jittered_initial_points(
    225     model=model,
    226     chains=chains,
    227     initvals=initvals,
    228     random_seed=random_seed,
    229 )
--> 231 logp_fn = get_jaxified_logp(model)
    233 if nuts_kwargs is None:
    234     nuts_kwargs = {}

File /opt/homebrew/Caskroom/miniforge/base/envs/pymc-dev-py39/lib/python3.9/site-packages/pymc/sampling_jax.py:102, in get_jaxified_logp(model)
    100 def get_jaxified_logp(model: Model) -> Callable:
--> 102     logp_fn = get_jaxified_graph(inputs=model.value_vars, outputs=[model.logpt()])
    104     def logp_fn_wrap(x):
    105         # NumPyro expects a scalar potential with the opposite sign of model.logpt
    106         res = logp_fn(*x)[0]

File /opt/homebrew/Caskroom/miniforge/base/envs/pymc-dev-py39/lib/python3.9/site-packages/pymc/sampling_jax.py:97, in get_jaxified_graph(inputs, outputs)
     94 mode.JAX.optimizer.optimize(fgraph)
     96 # We now jaxify the optimized fgraph
---> 97 return jax_funcify(fgraph)

File /opt/homebrew/Caskroom/miniforge/base/envs/pymc-dev-py39/lib/python3.9/functools.py:888, in singledispatch.<locals>.wrapper(*args, **kw)
    884 if not args:
    885     raise TypeError(f'{funcname} requires at least '
    886                     '1 positional argument')
--> 888 return dispatch(args[0].__class__)(*args, **kw)

File /opt/homebrew/Caskroom/miniforge/base/envs/pymc-dev-py39/lib/python3.9/site-packages/aesara/link/jax/dispatch.py:641, in jax_funcify_FunctionGraph(fgraph, node, fgraph_name, **kwargs)
    634 @jax_funcify.register(FunctionGraph)
    635 def jax_funcify_FunctionGraph(
    636     fgraph,
   (...)
    639     **kwargs,
    640 ):
--> 641     return fgraph_to_python(
    642         fgraph,
    643         jax_funcify,
    644         type_conversion_fn=jax_typify,
    645         fgraph_name=fgraph_name,
    646         **kwargs,
    647     )

File /opt/homebrew/Caskroom/miniforge/base/envs/pymc-dev-py39/lib/python3.9/site-packages/aesara/link/utils.py:723, in fgraph_to_python(fgraph, op_conversion_fn, type_conversion_fn, order, input_storage, output_storage, storage_map, fgraph_name, global_env, local_env, get_name_for_object, squeeze_output, **kwargs)
    721 body_assigns = []
    722 for node in order:
--> 723     compiled_func = op_conversion_fn(
    724         node.op, node=node, storage_map=storage_map, **kwargs
    725     )
    727     # Create a local alias with a unique name
    728     local_compiled_func_name = unique_name(compiled_func)

File /opt/homebrew/Caskroom/miniforge/base/envs/pymc-dev-py39/lib/python3.9/functools.py:888, in singledispatch.<locals>.wrapper(*args, **kw)
    884 if not args:
    885     raise TypeError(f'{funcname} requires at least '
    886                     '1 positional argument')
--> 888 return dispatch(args[0].__class__)(*args, **kw)

File /opt/homebrew/Caskroom/miniforge/base/envs/pymc-dev-py39/lib/python3.9/site-packages/aesara/link/jax/dispatch.py:143, in jax_funcify(op, node, storage_map, **kwargs)
    140 @singledispatch
    141 def jax_funcify(op, node=None, storage_map=None, **kwargs):
    142     """Create a JAX compatible function from an Aesara `Op`."""
--> 143     raise NotImplementedError(f"No JAX conversion for the given `Op`: {op}")

NotImplementedError: No JAX conversion for the given `Op`: SolveTriangular{lower=True, trans=0, unit_diagonal=False, check_finite=True}

fbarfi commented 2 years ago

The other version of my code is as follows. I used pm.Normal instead of pm.MvNormal for β, and pm.PolyaGamma instead of random_polyagamma for ω. Now the error comes from PolyaGamma, as you can see:


import warnings
warnings.filterwarnings("ignore")

import pymc as pm
from pymc import sampling_jax

import aesara
import aesara.tensor as at
from aesara.tensor.nlinalg import tensorinv,matrix_dot,matrix_inverse
from aesara.tensor.nnet.basic import sigmoid

import statsmodels.api as sm
import numpy as np
import pandas as pd

import arviz as az
# from polyagamma import random_polyagamma

def ols_model(X,y):
    data = pd.DataFrame(data = X)
    res = sm.OLS(y, data).fit()
    cov = res._results.normalized_cov_params
    coef = res._results.params
    return coef, cov

def logp(ρ, tt_W, X_β,Ω,z):
    log_det  = at.sum(at.log(at.abs(at.linalg.svd((1. - ρ*tt_W), 
                                                  compute_uv=False))))

    return log_det -.5* matrix_dot(((1. - ρ*tt_W)@z - X_β ).T, Ω,
                                     ((1. - ρ*tt_W)@z - X_β ))

def model():

    with pm.Model(coords=coords) as SAR_model:

        tt_X = pm.ConstantData('tt_X',X, dims = ('X_dim_0','X_dim_1'))
        tt_y = pm.ConstantData('tt_y',y, dims = 'X_dim_0')
        tt_W = pm.ConstantData('tt_W',W, dims = ('X_dim_0','W_dim_2'))
        μ_0 = pm.ConstantData('μ_0', μ_ols,dims = 'X_dim_1')
        σ_0 = pm.ConstantData('σ_0', σ_ols ,dims = ('σ_dim_0','σ_dim_1' ))

        ξ = pm.Exponential('ξ', 1.)
        ρ = pm.Normal('ρ',0.,ξ) 

        β = pm.Normal('β', mu = μ_ols, sd = np.diag(σ_ols) ,dims = 'X_dim_1')

        A = 1. - ρ*tt_W

        inv_A = matrix_inverse(A)

        X_β = tt_X @ β
        η = (inv_A@inv_A) @ X_β 

        ω = pm.PolyaGamma('ω',h=1,z=η,size=N)

        Ω = at.basic.AllocDiag()(ω)
        z = (tt_y - .5*at.ones(η.shape))/ω
        σ_bar = matrix_inverse(matrix_dot(tt_X.T, Ω,tt_X)+ matrix_inverse(σ_0))  
        sd_bar = at.basic.extract_diag(σ_bar)
        μ_bar = σ_bar @(matrix_dot(tt_X.T,Ω,A) @ z + matrix_inverse(σ_0) @ μ_0 )

        β_like = pm.Normal('β_like', mu = μ_bar,sd = sd_bar)

        ρ_like = pm.Potential('ρ_like', logp(ρ, tt_W, X_β,Ω,z))

        y_pred = pm.Binomial('y_pred',tt_N,p=sigmoid(η),observed = y,
                                 dims= 'X_dim_0')
    return SAR_model

if __name__ == "__main__":

    random_seed=1234
    N = 200
    m = 5
    X = np.random.rand(N,m)
    W_ = np.random.rand(N,N)
    W_l = np.tril(W_,k=-1)
    W = W_l+ W_l.T
    y = np.random.randint(low=0,high=2,size=N)

    μ_ols, σ_ols =  ols_model(X,y)
    tt_N =at.as_tensor_variable(X.shape[0],dtype='int64')

    coords = {'X_dim_0': np.arange(X.shape[0]),
         'X_dim_1': np.arange(X.shape[1]),
          'V_dim_1': np.arange(X.shape[1]),
         'W_dim_2': np.arange(X.shape[0]),
         'σ_dim_0':np.arange(X.shape[1]),
         'σ_dim_1': np.arange(X.shape[1])}

    SAR_model = model()

    with SAR_model:
        idata = pm.sampling_jax.sample_numpyro_nuts(draws=100,tune=100,
                                           random_seed=random_seed) 
---------------------------------------------------------------------------
NotImplementedError                       Traceback (most recent call last)
Input In [1], in <module>
    114 SAR_model = model()
    116 with SAR_model:
--> 117     idata = pm.sampling_jax.sample_numpyro_nuts(draws=100,tune=100,
    118                                        random_seed=random_seed)

File /opt/homebrew/Caskroom/miniforge/base/envs/pymc-dev-py39/lib/python3.9/site-packages/pymc/sampling_jax.py:231, in sample_numpyro_nuts(draws, tune, chains, target_accept, random_seed, initvals, model, var_names, progress_bar, keep_untransformed, chain_method, idata_kwargs, nuts_kwargs)
    222 print("Compiling...", file=sys.stdout)
    224 init_params = _get_batched_jittered_initial_points(
    225     model=model,
    226     chains=chains,
    227     initvals=initvals,
    228     random_seed=random_seed,
    229 )
--> 231 logp_fn = get_jaxified_logp(model)
    233 if nuts_kwargs is None:
    234     nuts_kwargs = {}

File /opt/homebrew/Caskroom/miniforge/base/envs/pymc-dev-py39/lib/python3.9/site-packages/pymc/sampling_jax.py:102, in get_jaxified_logp(model)
    100 def get_jaxified_logp(model: Model) -> Callable:
--> 102     logp_fn = get_jaxified_graph(inputs=model.value_vars, outputs=[model.logpt()])
    104     def logp_fn_wrap(x):
    105         # NumPyro expects a scalar potential with the opposite sign of model.logpt
    106         res = logp_fn(*x)[0]

File /opt/homebrew/Caskroom/miniforge/base/envs/pymc-dev-py39/lib/python3.9/site-packages/pymc/sampling_jax.py:97, in get_jaxified_graph(inputs, outputs)
     94 mode.JAX.optimizer.optimize(fgraph)
     96 # We now jaxify the optimized fgraph
---> 97 return jax_funcify(fgraph)

File /opt/homebrew/Caskroom/miniforge/base/envs/pymc-dev-py39/lib/python3.9/functools.py:888, in singledispatch.<locals>.wrapper(*args, **kw)
    884 if not args:
    885     raise TypeError(f'{funcname} requires at least '
    886                     '1 positional argument')
--> 888 return dispatch(args[0].__class__)(*args, **kw)

File /opt/homebrew/Caskroom/miniforge/base/envs/pymc-dev-py39/lib/python3.9/site-packages/aesara/link/jax/dispatch.py:641, in jax_funcify_FunctionGraph(fgraph, node, fgraph_name, **kwargs)
    634 @jax_funcify.register(FunctionGraph)
    635 def jax_funcify_FunctionGraph(
    636     fgraph,
   (...)
    639     **kwargs,
    640 ):
--> 641     return fgraph_to_python(
    642         fgraph,
    643         jax_funcify,
    644         type_conversion_fn=jax_typify,
    645         fgraph_name=fgraph_name,
    646         **kwargs,
    647     )

File /opt/homebrew/Caskroom/miniforge/base/envs/pymc-dev-py39/lib/python3.9/site-packages/aesara/link/utils.py:723, in fgraph_to_python(fgraph, op_conversion_fn, type_conversion_fn, order, input_storage, output_storage, storage_map, fgraph_name, global_env, local_env, get_name_for_object, squeeze_output, **kwargs)
    721 body_assigns = []
    722 for node in order:
--> 723     compiled_func = op_conversion_fn(
    724         node.op, node=node, storage_map=storage_map, **kwargs
    725     )
    727     # Create a local alias with a unique name
    728     local_compiled_func_name = unique_name(compiled_func)

File /opt/homebrew/Caskroom/miniforge/base/envs/pymc-dev-py39/lib/python3.9/functools.py:888, in singledispatch.<locals>.wrapper(*args, **kw)
    884 if not args:
    885     raise TypeError(f'{funcname} requires at least '
    886                     '1 positional argument')
--> 888 return dispatch(args[0].__class__)(*args, **kw)

File /opt/homebrew/Caskroom/miniforge/base/envs/pymc-dev-py39/lib/python3.9/site-packages/aesara/link/jax/dispatch.py:143, in jax_funcify(op, node, storage_map, **kwargs)
    140 @singledispatch
    141 def jax_funcify(op, node=None, storage_map=None, **kwargs):
    142     """Create a JAX compatible function from an Aesara `Op`."""
--> 143     raise NotImplementedError(f"No JAX conversion for the given `Op`: {op}")

NotImplementedError: No JAX conversion for the given `Op`: _PolyaGammaLogDistFunc{get_pdf=True}

fbarfi commented 2 years ago

@danhphan there are thus two issues, pm.MvNormal and pm.PolyaGamma, both of which seem to come from the JAX conversion. Here are my watermarks:

Last updated: 2022-04-02T08:01:07.939068-04:00

Python implementation: CPython
Python version       : 3.9.10
IPython version      : 8.0.1

Compiler    : Clang 11.1.0 
OS          : Darwin
Release     : 21.4.0
Machine     : arm64
Processor   : arm
CPU cores   : 10
Architecture: 64bit

aesara     : 2.4.0
arviz      : 0.11.4
statsmodels: 0.13.2
json       : 2.0.9
pymc       : 4.0.0b3
numpy      : 1.21.5
pandas     : 1.4.1

danhphan commented 2 years ago

Hi @fbarfi,

Both Aesara's SolveTriangular (https://github.com/aesara-devs/aesara/pull/880) and BroadcastTo (https://github.com/aesara-devs/aesara/pull/883) Ops now support JAX. I am not sure these changes have made it into the latest PyMC version yet, but you can pull the latest PyMC and Aesara and run your code again to check.

Hopefully it will work :D

ricardoV94 commented 2 years ago

Unfortunately, PyMC will likely fail with later Aesara versions (we have a PR open for that), but we will try to fix it soon.

fbarfi commented 2 years ago

@danhphan thank you for the update. Will do, and I will keep you posted. Best.

fbarfi commented 2 years ago

How do you update pymc without reinstalling the whole thing, given that I am already using v4.0.0b3 (dev)? I tried pip but nothing changed. Thanks.

danhphan commented 2 years ago

You can use the git command git pull to update to the latest pymc code.
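For a development install cloned from GitHub, that amounts to something like this (the editable reinstall is only needed if the package was not originally installed with pip install -e):

cd pymc           # your local clone of pymc-devs/pymc
git pull          # fetch and merge the latest development code
pip install -e .  # reinstall in editable mode if needed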

Also, you may need to wait for PR #5672 to be merged first, in order to get the latest code from Aesara. Hopefully it will be merged soon. :)

fbarfi commented 2 years ago

I already updated to pymc 4.0.0b6 and aesara 2.5.1 using git pull. Everything works fine so far. Thanks.