mjo22 / cryojax

Cryo electron microscopy image simulation and analysis built on JAX.
https://mjo22.github.io/cryojax/
GNU Lesser General Public License v2.1

Build a class that computes the log-likelihood over ensembles? #182

Closed mjo22 closed 7 months ago

mjo22 commented 8 months ago

I've been thinking more about how we might build batching into the library, specifically for computing likelihoods. If there is interest, it would be easy to implement a new class in cryojax.inference that computes a batch of log-likelihood values.

# Library code in cryojax.inference
import equinox as eqx
from jaxtyping import Shaped
from cryojax.typing import Image, RealNumber  # assumes RealNumber is exported alongside Image
from cryojax.inference import AbstractDistribution, AbstractMarginalDistribution
from cryojax.core import filter_vmap_with_spec

class LogLikelihoodBatcher(eqx.Module):

    distribution: AbstractDistribution
    filter_spec: AbstractDistribution[bool]

    @eqx.filter_jit
    def __call__(self, observed_stack: Shaped[Image, "batch_dim"]) -> Shaped[RealNumber, "batch_dim"]:
        """Compute a vector of log-likelihood values over parameters and observed data."""
        if isinstance(self.distribution, AbstractMarginalDistribution):
            log_likelihood_func = self.distribution.marginal_log_likelihood
        else:
            log_likelihood_func = self.distribution.log_likelihood
        # vmap over arbitrary pytree leaves, specified by `filter_spec`
        vmapped_log_likelihood_func = filter_vmap_with_spec(log_likelihood_func, filter_spec=self.filter_spec)
        # compute batch of log-likelihood values
        return vmapped_log_likelihood_func(observed_stack)

# Some script a user has written
import equinox as eqx
import jax.numpy as jnp
from cryojax.core import get_filter_spec  # see here for an introduction to filter_specs: https://mjo22.github.io/cryojax/examples/simulate-micrograph/
from cryojax.inference import IndependentFourierGaussian, LogLikelihoodBatcher

# Read observed data
image_stack = ...
# Initialize statistical model, making sure that params we want to batch
# over have a batch dimension
...
distribution = IndependentFourierGaussian(...)
# Point to the parameters we want to batch over
where = lambda dist: dist.path.to.vmapped.params  # this is pseudocode
filter_spec = get_filter_spec(distribution, where)
# Initialize our new batching utility
log_likelihood_batcher = LogLikelihoodBatcher(distribution, filter_spec)

# The batcher holds non-array leaves (the boolean filter_spec), so use
# eqx.filter_jit rather than jax.jit
@eqx.filter_jit
def compute_log_likelihood(log_likelihood_batcher, image_stack):
    return jnp.sum(log_likelihood_batcher(image_stack))  # assuming we just want to add log-likelihood values

# Finally, compute
loglike = compute_log_likelihood(log_likelihood_batcher, image_stack)

I think this is a good starting skeleton. It uses utilities I've recently written in cryojax.core (see here for a tutorial).

I’m not totally sure this should go in the library, because this LogLikelihoodBatcher class is essentially a function of only a few lines (I would even recommend that people try writing this themselves). But it depends on people’s interest and the direction of the library!

geoffwoollard commented 8 months ago

Computing log_prob batches raises the question of shape/vmap and the independence structure of inference.

Probabilistic programming frameworks like Pyro have ways to do this bookkeeping (batch dimensions, plates): https://pyro.ai/examples/tensor_shapes.html and https://ericmjl.github.io/blog/2019/5/29/reasoning-about-shapes-and-probability-distributions/
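
To make the bookkeeping concrete, here is a minimal sketch in plain jax.numpy (not cryojax or Pyro code). The array names and the toy Gaussian model are assumptions for illustration only. It shows that which axes you sum over encodes an independence assumption: summing over pixel axes treats pixels as independent within an image, and summing over the batch axis treats images as independent of each other.

```python
import jax.numpy as jnp

# Hypothetical toy data: a stack of 3 images, each 2x2 pixels
n_images, ny, nx = 3, 2, 2
observed = jnp.ones((n_images, ny, nx))
predicted = jnp.zeros((n_images, ny, nx))
variance = 1.0

# Per-pixel Gaussian log-density, shape (n_images, ny, nx)
per_pixel = -0.5 * (
    (observed - predicted) ** 2 / variance + jnp.log(2 * jnp.pi * variance)
)

# Summing over the pixel axes assumes pixels are independent within an image;
# the result keeps only the batch dimension, shape (n_images,)
per_image = jnp.sum(per_pixel, axis=(-2, -1))

# Summing over the batch axis additionally assumes images are independent;
# the result is a scalar joint log-likelihood
joint = jnp.sum(per_image)
```

In Pyro's terminology the pixel axes would roughly play the role of the event shape and the image axis the batch shape; a batching utility has to fix that choice for the user, which is the concern here.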

mjo22 commented 8 months ago

This makes sense. This may be a reason not to include things like this in the library! What I wrote is easy enough to do externally, and in general we should avoid assuming exactly how a user might want to do this. A good alternative would be to write up good notebook examples in the docs.

mjo22 commented 7 months ago

Closing as I think this is no longer the way to go. It is easy enough to just vmap in a user's specific workflow.
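
For reference, a minimal sketch of what "just vmap" can look like in a user's script. It uses only jax, not cryojax: `simulate` and `log_likelihood` are hypothetical stand-ins for an image model and a white-noise Gaussian likelihood, and the per-image `offset` parameter is an assumption made for illustration.

```python
import jax
import jax.numpy as jnp


def simulate(offset, shape):
    """Hypothetical stand-in for an image simulator: a constant image."""
    return offset * jnp.ones(shape)


def log_likelihood(offset, observed, variance=1.0):
    """Gaussian white-noise log-likelihood of one image, up to a constant."""
    residual = observed - simulate(offset, observed.shape)
    return -0.5 * jnp.sum(residual**2) / variance


# Batch over axis 0 of both the per-image parameter and the image stack
batched_log_likelihood = jax.jit(jax.vmap(log_likelihood, in_axes=(0, 0)))

offsets = jnp.array([0.0, 1.0, 2.0])
image_stack = jnp.stack([jnp.full((4, 4), v) for v in (0.0, 1.0, 3.0)])

log_likelihoods = batched_log_likelihood(offsets, image_stack)  # shape (3,)
total = jnp.sum(log_likelihoods)  # treats the images as independent
```

Which arguments get a batch axis is set entirely by `in_axes`, so the user decides per workflow what is shared across images and what varies.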