sbi-dev / sbi

Simulation-based inference toolkit
https://sbi-dev.github.io/sbi/

Adding RNN to MDN #818

Closed: philipp128 closed this issue 1 year ago

philipp128 commented 1 year ago

Hello,

I am using the SNPE-C algorithm in my current inference problem. I don't know if this is the correct place to ask, but I am struggling with adding an RNN, as described in appendix A.5.3 of Greenberg et al. (2019), to the MDN network.

If I understand the Lotka-Volterra benchmark test correctly, I think adding an RNN in my inference problem could be very beneficial.

In my problem I have a data vector of length 750, and my simulator produces output of the same shape. My parameter space has dimension 10. At the moment, I am using the default implementation of the MAF network to fit a mock data set for which I know the exact true parameter values (the mock data is created with the same simulator). However, running the inference returns very bad results, with some parameters way off from the true values.

As I said, I think following the approach of Greenberg et al. (2019) in the Lotka-Volterra example, i.e. inference on raw time series, could help in my case.

Unfortunately, I am quite new to this topic, and I don't understand how I can add the RNN, in the form of an initial layer of 100 GRU units, to the MDN network. Looking at the code, it does not seem that I could change the default MDN by just adding another layer to it. Would I need to build a custom density estimator from scratch, or is there another way to do it? From my understanding, adding an embedding net is not the correct way, is it?

Thanks in advance for your help :)

Philipp

manuelgloeckler commented 1 year ago

Hey,

An embedding net should be the correct way to do this.

If you look at the documentation, you can modify the default density estimators by passing a function which, when called, builds the density estimator. To add an RNN or any other neural net, you can use:

from sbi.inference import SNPE
from sbi.utils.get_nn_models import (
    posterior_nn,
)  # For SNLE: likelihood_nn(). For SNRE: classifier_nn()

rnn = RNN()  # Your RNN: maps data of dim data_dim to a summary of dim some_dim
density_estimator_build_fun = posterior_nn(
    model="mdn",
    embedding_net=rnn,
)
inference = SNPE(prior=prior, density_estimator=density_estimator_build_fun)
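
In case a concrete sketch helps (this is an illustration, not part of sbi; the class name GRUEmbedding and its shapes are made up to match your setting of a length-750 series with one scalar per time step and 100 GRU units):

import torch
import torch.nn as nn

class GRUEmbedding(nn.Module):
    # Hypothetical embedding net: (batch, 750) raw series -> (batch, 100) summary.
    def __init__(self, hidden_size: int = 100):
        super().__init__()
        # One scalar observation per time step, hence input_size=1.
        self.gru = nn.GRU(input_size=1, hidden_size=hidden_size, batch_first=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x.unsqueeze(-1)   # (batch, seq_len) -> (batch, seq_len, 1)
        _, h_n = self.gru(x)  # h_n has shape (num_layers, batch, hidden_size)
        return h_n[-1]        # final hidden state: (batch, hidden_size)

rnn = GRUEmbedding(hidden_size=100)  # drop-in for the RNN() placeholder above

A quick shape check: GRUEmbedding()(torch.randn(64, 750)) should come out with shape (64, 100), i.e. 100 learned summary features per simulation.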

What happens internally is that your data is first passed through the RNN, resulting in a "feature"/"statistic" summarizing it. This is then passed to the density estimator, which then tries to estimate the posterior.

If you want to make other modifications to the MDN, you can also do this within the "density_estimator_build_fun".
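
For example (a sketch; hidden_features and num_components are keyword arguments of the posterior_nn factory, but double-check against your installed sbi version):

# Builds on the snippet above.
density_estimator_build_fun = posterior_nn(
    model="mdn",
    embedding_net=rnn,
    hidden_features=50,   # width of the MDN's hidden layers
    num_components=10,    # number of Gaussian mixture components
)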

Kind regards, Manuel

michaeldeistler commented 1 year ago

You might also want to read this tutorial

philipp128 commented 1 year ago

Thanks Manuel and Michael, that was really helpful. I really believed that an embedding net would not be the right choice here, but it makes sense now.

I included the RNN in my network but unfortunately it seems to not help that much. Was worth a try tho. :)

michaeldeistler commented 1 year ago

Nice to hear that it worked, but a bummer that it does not help much. To use an embedding net effectively, you often need many simulations. See also my answer here.