emdgroup / baybe

Bayesian Optimization and Design of Experiments
https://emdgroup.github.io/baybe/
Apache License 2.0

Dealing with zero-inflated data #368

Open brandon-holt opened 2 months ago

brandon-holt commented 2 months ago

I was wondering if there are any built-in options for transforming or otherwise handling zero-inflated data. My understanding is that this process works best with Gaussian-distributed data.

@AdrianSosic @Scienfitz

AdrianSosic commented 2 months ago

Hi @brandon-holt. Yes, at the moment we largely focus on Gaussian data (this assumption is basically hard-coded in our GP model), and with the upcoming 0.11.0 release you'll also be able to model binary targets using bandits (#343). We generally plan to broaden the capabilities in all directions, that is, we want to also

Of course, this is not just a matter of configurability but requires actual extensions of our models, such as approximate GP variants. The latter would also generally enable your zero-inflated case.

That said, even with the current state of the code, nothing would block you from implementing a surrogate that is tailor-made for zero-inflated data; the necessary machinery is all there, I think. All that is needed is a corresponding surrogate model class that either 1) inherits from our SurrogateModel base class, which lets you reuse much of the existing functionality such as scaling, or 2) simply conforms to our SurrogateModelProtocol.

In such a class, you could train a specialized probabilistic machine learning model that is capable of dealing with zero-inflated data. Nowadays, this can be done rather easily with frameworks such as Pyro, where you can build Bayesian models that encode the zero-inflatedness property through priors, e.g. the horseshoe prior.
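To make "zero-inflated" concrete, here is a minimal sketch of one common way to formalize it for continuous outcomes: a two-part (hurdle) model in which an observation is exactly zero with probability pi and otherwise follows Normal(mu, sigma). This is a swapped-in illustration of the likelihood, not the horseshoe-prior approach mentioned above, and the function names `fit_hurdle` and `hurdle_log_likelihood` are hypothetical. Its maximum-likelihood estimates have a closed form: pi is the fraction of zeros, and mu and sigma are the mean and population standard deviation of the nonzero values.

```python
import math


def fit_hurdle(ys):
    """Closed-form MLE for a zero/Normal hurdle model (illustrative sketch)."""
    nonzero = [y for y in ys if y != 0]
    if not nonzero:
        raise ValueError("need at least one nonzero observation")
    pi = (len(ys) - len(nonzero)) / len(ys)  # probability of an exact zero
    mu = sum(nonzero) / len(nonzero)  # mean of the nonzero part
    var = sum((y - mu) ** 2 for y in nonzero) / len(nonzero)
    if var == 0:
        raise ValueError("need at least two distinct nonzero observations")
    return pi, mu, math.sqrt(var)


def hurdle_log_likelihood(ys, pi, mu, sigma):
    """Zeros contribute log(pi); nonzeros contribute log(1 - pi) + Normal log-density."""
    total = 0.0
    for y in ys:
        if y == 0:
            total += math.log(pi)
        else:
            z = (y - mu) / sigma
            total += math.log(1 - pi) - 0.5 * z * z - math.log(sigma * math.sqrt(2 * math.pi))
    return total


data = [0.0, 0.0, 1.0, 3.0]
pi, mu, sigma = fit_hurdle(data)
print(pi, mu, sigma)  # 0.5 2.0 1.0
print(hurdle_log_likelihood(data, pi, mu, sigma))
```

The point mass at zero is modeled separately from the continuous part, which is exactly what a single Gaussian likelihood cannot express.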

I'd love to assist, but I have to be realistic: until the end of this year, I'll need to focus heavily on other topics like multi-target support. But I'm happy to answer questions if that helps ✌🏼

brandon-holt commented 2 months ago

@AdrianSosic I see, thanks for the informed response! How does this look for approach 1 (inheriting from the SurrogateModel base class):

import pyro
import pyro.distributions as dist
import torch
from pyro import poutine
from pyro.infer import SVI, Predictive, Trace_ELBO
from pyro.optim import Adam
from torch.distributions import constraints
from baybe.surrogate_model import SurrogateModel  # module path may differ between baybe versions

class ZeroInflatedSurrogateModel(SurrogateModel):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim  # this sketch handles a single target (output_dim == 1)

    def model(self, x, y=None):
        # Hurdle-style zero-inflated model: a Bernoulli "gate" decides whether an
        # observation is exactly zero; nonzero observations follow a Normal.
        # Both parts depend linearly on the inputs x.
        w_gate = pyro.param("w_gate", torch.zeros(self.input_dim))
        b_gate = pyro.param("b_gate", torch.tensor(0.0))
        w_mean = pyro.param("w_mean", torch.zeros(self.input_dim))
        b_mean = pyro.param("b_mean", torch.tensor(0.0))
        std = pyro.param("std", torch.tensor(1.0), constraint=constraints.positive)
        gate_logits = x @ w_gate + b_gate  # log-odds of an exact zero
        mean = x @ w_mean + b_mean  # mean of the nonzero part
        is_zero_obs = None if y is None else (y == 0).float()
        nonzero = torch.ones(x.shape[0], dtype=torch.bool) if y is None else (y != 0)
        with pyro.plate("data", x.shape[0]):
            # Every sample site needs a unique name (reusing "obs" twice is an error)
            is_zero = pyro.sample("is_zero", dist.Bernoulli(logits=gate_logits), obs=is_zero_obs)
            # Zeros are explained by the gate, so mask them out of the Normal likelihood
            with poutine.mask(mask=nonzero):
                value = pyro.sample("nonzero_value", dist.Normal(mean, std), obs=y)
            return pyro.deterministic("y", (1.0 - is_zero) * value)

    def guide(self, x, y=None):
        # All unknowns are pyro.param point estimates and every sample site is
        # observed during training, so the guide is empty (maximum likelihood).
        pass

    def fit(self, x, y):
        # Fit the model to the data
        pyro.clear_param_store()
        x, y = x.float(), y.float().reshape(-1)
        svi = SVI(model=self.model, guide=self.guide, optim=Adam({"lr": 0.01}), loss=Trace_ELBO())
        num_steps = 1000
        for step in range(num_steps):
            loss = svi.step(x, y)
            if step % 100 == 0:
                print(f"Step {step} : loss = {loss}")

    def predict(self, x):
        # Draw predictive samples from the fitted model and average them
        predictive = Predictive(self.model, guide=self.guide, num_samples=1000, return_sites=["y"])
        samples = predictive(x.float())
        return samples["y"].mean(0)

And then, to use the custom model:

campaign.recommender.recommender.surrogate_model = ZeroInflatedSurrogateModel(input_dim, output_dim)

1) Would this have all the necessary machinery to be functional with baybe? 2) Would this mean GPs are no longer used anywhere to model any aspect of the data when the surrogate model is assigned this way?

AdrianSosic commented 2 months ago

Hi @brandon-holt, please excuse the delay; I've been traveling for two weeks and will be out of office until early October. I'm writing from my phone, so my capabilities are very limited and I won't include any links or code.

To answer your question: no, this is not (yet) functional, but you are almost there. With the latest 0.11.0 release, you basically have two ways to inject your own models (perhaps I already mentioned this above, but I can't see my old text while typing):

  1. Write a class that conforms to our SurrogateProtocol. Pros: absolutely minimal requirements; no need to dig into our own baybe classes, so entirely frictionless. Just implement one method that fits the model and one that exports the fitted model to BoTorch (see the docstrings). Cons: you need to take care of everything yourself, e.g. proper input/output scaling.

  2. Inherit from our SurrogateModel base class. Pros: you only need to implement the core mathematical operations of your model and can easily configure scaling if needed. Cons: you need to understand our class/method hierarchy.
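The difference between the two options can be sketched in plain Python. Everything below is hypothetical: `SurrogateLike` is a stand-in for a protocol with a fit method and an export method, not baybe's actual SurrogateProtocol (whose exact signatures are in its docstrings), and `ScalingSurrogateBase` is an invented base class standing in for the shared scaling logic, not baybe's SurrogateModel. Option 1 conforms structurally without inheriting anything and does its own scaling; option 2 inherits the scaling and implements only the core fitting step.

```python
from abc import ABC, abstractmethod
from typing import Protocol, runtime_checkable


@runtime_checkable
class SurrogateLike(Protocol):
    """Hypothetical stand-in for a surrogate protocol: fit + export."""

    def fit(self, x: list[list[float]], y: list[float]) -> None: ...

    def to_botorch(self) -> object: ...


class ProtocolSurrogate:
    """Option 1: no base class; conforms structurally and handles scaling itself."""

    def fit(self, x, y):
        # Everything, including any output scaling, is this class's responsibility
        self.y_mean = sum(y) / len(y)

    def to_botorch(self):
        # Placeholder: a real implementation would return a BoTorch model
        return {"kind": "exported-model", "y_mean": self.y_mean}


class ScalingSurrogateBase(ABC):
    """Option 2 (hypothetical base class): shared scaling lives here."""

    def fit(self, x, y):
        # Standardize targets once, then hand off to the subclass's core math
        mu = sum(y) / len(y)
        sd = (sum((v - mu) ** 2 for v in y) / len(y)) ** 0.5 or 1.0
        self.y_loc, self.y_scale = mu, sd
        self._fit(x, [(v - mu) / sd for v in y])

    @abstractmethod
    def _fit(self, x, y_scaled): ...

    def to_botorch(self):
        # Placeholder: a real implementation would return a BoTorch model
        return {"kind": "exported-model", "scale": self.y_scale}


class MinimalSubclass(ScalingSurrogateBase):
    def _fit(self, x, y_scaled):
        # Only the model-specific fitting is implemented here
        self.y_scaled = y_scaled


print(isinstance(ProtocolSurrogate(), SurrogateLike))  # True
print(isinstance(MinimalSubclass(), SurrogateLike))  # True
```

Both classes satisfy the same structural interface; the trade-off is only in how much shared machinery you inherit versus write yourself.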

In both cases, what will be required in the end (and what is missing in your outlined approach) is that your class can produce a Posterior object for an arbitrary set of candidate points. This is a mathematical requirement for proper batch prediction in the Bayesian optimization sense. The good news: since you have coded a model that is inherently probabilistic, you already have all the necessary pieces. Just wrap your predictions in an appropriate Posterior object instead of projecting them down to mean values only. Have a look at our existing GP or RandomForest models to see how it can be done, or check out BoTorch's doc page on posterior objects.
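To see why the full posterior matters more than its mean, especially for zero-inflated outputs, here is a stdlib-only sketch of a Monte Carlo batch expected improvement (qEI-style) estimate computed from joint posterior draws. The helper names and sample values are hypothetical; in practice, BoTorch's acquisition functions draw such samples from the Posterior object themselves. The example shows two things: collapsing to the mean can hide all improvement potential, and two batches with identical per-candidate marginals can have different batch value depending on their joint structure.

```python
def mc_batch_ei(samples, best_so_far):
    """Monte Carlo batch EI: average over joint draws of the best improvement in the batch.

    samples: list of joint draws, each a list with one value per candidate in the batch.
    """
    improvements = [max(max(draw) - best_so_far, 0.0) for draw in samples]
    return sum(improvements) / len(improvements)


def mean_only_improvement(samples, best_so_far):
    """What remains if the posterior is collapsed to per-candidate means first."""
    q = len(samples[0])
    means = [sum(draw[j] for draw in samples) / len(samples) for j in range(q)]
    return max(max(means) - best_so_far, 0.0)


# Hypothetical zero-inflated predictions for a single candidate (q = 1):
# half of the posterior draws are exactly zero, the other half are 2.0.
draws = [[0.0], [2.0], [0.0], [2.0]]
print(mean_only_improvement(draws, best_so_far=1.5))  # 0.0 (the mean is 1.0)
print(mc_batch_ei(draws, best_so_far=1.5))  # 0.25

# Same per-candidate marginals, different joint structure (q = 2):
anti = [[0.0, 2.0], [2.0, 0.0]]
together = [[0.0, 0.0], [2.0, 2.0]]
print(mc_batch_ei(anti, 1.5), mc_batch_ei(together, 1.5))  # 0.5 0.25
```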

Once this is done, you should be able to run it. And yes, you'll then completely bypass any Gaussianity assumption whatsoever.

Hope that helps 👌 Happy to help you further, though probably with some delay. But I'm sure that @Scienfitz and @AVHopp can also point you to the right pages in our docs or BoTorch's if needed 🙃