pymc-devs / pymc

Bayesian Modeling and Probabilistic Programming in Python
https://docs.pymc.io/

`pt.max` is not differentiable in PyMC models #7251

Closed. jessegrabowski closed this issue 5 months ago.

jessegrabowski commented 5 months ago

Description

Calling `model.dlogp()` on the following model raises a `NotImplementedError`:

import numpy as np
import pymc as pm

with pm.Model() as model:
    x = pm.Normal('x', shape=(2,))
    mu = x.max()
    pm.Normal('obs', mu, observed=np.random.uniform())

model.dlogp()  # raises NotImplementedError

This should be differentiable (`pt.max` has gradients implemented), so something seems to be going wrong in the rewrites, either in the logp rewrites or in how `MaxAndArgmax` is handled.
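For reference, here is a minimal NumPy sketch (not PyMC's or PyTensor's implementation) of the gradient that `model.dlogp()` should return for the model above, assuming the unit-scale Normal prior and likelihood from the snippet. Away from ties, the derivative of `max(x)` is one-hot at the argmax, so the likelihood term only feeds into that single component. The helpers `logp`, `dlogp_analytic` and `dlogp_numeric` are hypothetical names used for illustration; the analytic gradient is checked against central finite differences.

```python
import numpy as np


def logp(x, obs):
    # Joint log-density up to an additive constant (constants do not affect
    # the gradient): x ~ Normal(0, 1) elementwise, obs ~ Normal(max(x), 1).
    return -0.5 * np.sum(x**2) - 0.5 * (obs - np.max(x)) ** 2


def dlogp_analytic(x, obs):
    # The prior contributes -x. The likelihood contributes
    # (obs - max(x)) * d max(x)/dx, and d max(x)/dx is one-hot at argmax(x)
    # (assuming no ties).
    grad = -np.asarray(x, dtype=float)
    grad[np.argmax(x)] += obs - np.max(x)
    return grad


def dlogp_numeric(x, obs, eps=1e-6):
    # Central finite differences, one coordinate at a time.
    grad = np.zeros_like(x, dtype=float)
    for i in range(x.size):
        step = np.zeros_like(x, dtype=float)
        step[i] = eps
        grad[i] = (logp(x + step, obs) - logp(x - step, obs)) / (2 * eps)
    return grad


x = np.array([0.3, -1.2])
obs = 0.7
print(dlogp_analytic(x, obs))
print(dlogp_numeric(x, obs))
```

The two printed gradients should agree, which is the sense in which the model is differentiable even though `max` is only piecewise smooth.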

bburan commented 5 months ago

`min` also raises a `NotImplementedError`.

ricardoV94 commented 5 months ago

`min` is implemented as the negative of the max of the negated input, i.e. min(x) = -max(-x), so that's expected.
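A small NumPy illustration of that identity (a sketch, not PyTensor's code): because `min` is built on top of `max`, whatever breaks the `max` gradient also breaks `min`, and once `max` is differentiable the chain rule through the two negations gives a one-hot gradient at the argmin. The helper names `max_grad`, `min_via_max` and `min_grad_via_max` are hypothetical, and ties are assumed away.

```python
import numpy as np


def max_grad(x):
    # Gradient of max(x) w.r.t. x: one-hot at the argmax (assuming no ties).
    g = np.zeros_like(x, dtype=float)
    g[np.argmax(x)] = 1.0
    return g


def min_via_max(x):
    # min(x) written as the negative of the max of the negated input.
    return -np.max(-x)


def min_grad_via_max(x):
    # Chain rule through both negations:
    # d/dx[-max(-x)] = (-1) * max_grad(-x) * (-1) = max_grad(-x),
    # which is one-hot at argmin(x).
    return max_grad(-x)


x = np.array([2.0, -0.5, 1.0])
print(min_via_max(x), min_grad_via_max(x))
```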

tanish1729 commented 5 months ago

Hi! I can take up this issue. It looks like it involves handling gradients in the rewrites for some specific functions. Could you provide some more details so I can start working on it?

ricardoV94 commented 5 months ago

@tanish1729 this one is not really a beginner-friendly issue. I'll try to fix it myself now. Let us know if you need help finding more suitable issues.

tanish1729 commented 5 months ago

Oh great, I see you fixed this yourself. I'll go through the code and see if I can understand it. What are some other good beginner-friendly issues open right now?

ricardoV94 commented 5 months ago


You can filter issues on GitHub by label: https://github.com/pymc-devs/pymc/issues?q=is%3Aissue+is%3Aopen+label%3A%22beginner+friendly%22
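The query string in that link is GitHub's issue-search syntax (`is:issue is:open label:"beginner friendly"`), URL-encoded into the `q` parameter. A stdlib-only sketch that rebuilds the same link for any repository and label query; the helper name `issues_search_url` is hypothetical.

```python
from urllib.parse import quote_plus


def issues_search_url(repo: str, query: str) -> str:
    # GitHub issue search reads the query from the `q` parameter.
    # quote_plus encodes ':' as %3A, '"' as %22 and spaces as '+'.
    return f"https://github.com/{repo}/issues?q={quote_plus(query)}"


url = issues_search_url("pymc-devs/pymc", 'is:issue is:open label:"beginner friendly"')
print(url)
```

Swapping in a different label (or adding qualifiers such as `no:assignee`) only changes the `query` argument.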

bburan commented 5 months ago

@ricardoV94 Thanks so much for fixing this so quickly. I can confirm this solved the issue in my model, and I am now able to sample my model using NUTS alone. This brings the runtime down from ~2 hours to 10 minutes (3 minutes if I use an experimental NUTS sampler such as NumPyro).

ricardoV94 commented 5 months ago

You're welcome. By the way, we don't consider the NumPyro integration experimental anymore.