sbi-dev / sbi

Simulation-based inference toolkit
https://sbi-dev.github.io/sbi/
Apache License 2.0

Automatic feature extraction #527

Closed akapet00 closed 3 years ago

akapet00 commented 3 years ago

Hi,

In [Goncalves et al., 2020] it is stated that: "SNPE can be applied to, and might benefit from the use of summary features, but it also makes use of the ability of neural networks to automatically learn informative features in high-dimensional data. Thus, SNPE can also be applied directly to raw data (e.g. using recurrent neural networks [Lueckmann et al., 2017]), ...". According to the sbi documentation, the work by [Lueckmann et al., 2017] corresponds to the method SNPE_B. However, in the version of sbi I am currently using (0.16.0), this inference algorithm is stated to be currently not implemented.

Nevertheless, I have been playing around with SNPE (more precisely, SNPE_C) on raw data, and it seems to work quite well for a very simple example similar to the one in the official brian2 examples directory, available here. In this example, the Hodgkin-Huxley neuron model is used to test simulation-based inference and its possible integration with brian2. The inference is based on a synthetic current-clamp recording generated from the same model that is used in the inference process. Two of the parameters (the maximum sodium and potassium conductances) are treated as unknown and are inferred from the data.

The first thing I tried was using an embedding network as a way to semi-automatically extract relevant features. This embedding network is based on Time2Vec, which is, in a nutshell, a very simple sinusoidal layer:

import torch
from torch import nn


class SinLayer(nn.Module):
    """One linear feature concatenated with (out_features - 1) sine features."""

    def __init__(self, in_features, out_features):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        # Weights and bias of the single linear (non-periodic) component.
        self.w0 = nn.Parameter(torch.randn(in_features, 1))
        self.b0 = nn.Parameter(torch.randn(1))
        # Weights and biases of the periodic (sine) components.
        self.w = nn.Parameter(torch.randn(in_features, out_features - 1))
        self.b = nn.Parameter(torch.randn(out_features - 1))

    def forward(self, x):
        linear = torch.matmul(x, self.w0) + self.b0
        periodic = torch.sin(torch.matmul(x, self.w) + self.b)
        return torch.cat([linear, periodic], dim=1)

    def extra_repr(self):
        return (
            f'in_features={self.in_features}, '
            f'out_features={self.out_features}')

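class Time2Vec(nn.Module):
    def __init__(self, in_features, hidden_features, embeddings_size):
        super().__init__()
        # Layer sizes are passed in explicitly instead of being read from globals.
        self.p1 = SinLayer(in_features, hidden_features)
        self.p2 = SinLayer(hidden_features, hidden_features)
        self.l = nn.Linear(hidden_features, embeddings_size)

    def forward(self, x):
        x = self.p1(x)
        x = self.p2(x)
        x = self.l(x)
        return x

# x holds the simulated voltage traces, one trace per row.
in_features = x.shape[1]
hidden_features = ...
embeddings_size = ...
embedding_net = Time2Vec(in_features, hidden_features, embeddings_size)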

According to commenters on r/MachineLearning, it is nothing but a "quality case of 'just throw neural networks at it' and is overall just a shitty rehashing of discrete Fourier transforms". In the original Time2Vec paper, the authors use this sine representation of the input only as an additional layer for an LSTM or GRU, where it seems to produce better results than vanilla recurrent networks. In the case I've been working on, however, it does not seem to work well.

The next approach was feeding the raw simulator output (generated voltage traces), x, of size (10000, 7000), directly to SNPE. It works extremely well and is comparable to (if not better than) the setup in which I used summary statistics consisting of the mean and std of the action potential, the number of spikes, and the maximum value of the membrane potential of the generated traces. What I do not understand is how this is possible. Am I doing something wrong, or is the SNPE_C approach able to automatically extract features from the data even though embedding_net is still set to None? In [Lueckmann et al., 2017], subsection 2.3, under Learning Features, it is stated that when time-series recordings are fed directly into the network, the first layer of the MDN becomes a recurrent layer instead of a fully connected one. But even with different methods, such as NSF, I've been able to obtain good results, although much more slowly. The notebook is available here.
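For reference, the hand-crafted summary statistics mentioned above could look roughly like the sketch below. It is not the code from the notebook: the function name `summary_statistics`, the spike threshold of -20 mV, and the choice to take the mean and std over the whole trace are assumptions made purely for illustration.

```python
import numpy as np


def summary_statistics(trace, spike_threshold=-20.0):
    """Reduce one voltage trace (in mV) to four hand-picked features.

    The threshold value and the exact feature definitions are assumptions.
    """
    trace = np.asarray(trace, dtype=float)
    above = trace > spike_threshold
    # Count a spike at every upward crossing of the threshold.
    n_spikes = int(np.count_nonzero(~above[:-1] & above[1:]))
    return np.array([trace.mean(), trace.std(), n_spikes, trace.max()])


# Toy trace: rest at -65 mV with three rectangular "spikes" reaching 30 mV.
trace = np.full(1000, -65.0)
for start in (100, 400, 700):
    trace[start:start + 10] = 30.0

features = summary_statistics(trace)
```

Applied row by row to x, this would shrink each 7000-sample trace to a vector of four numbers, which is the low-dimensional input the raw-data approach is being compared against.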

Sorry for this long text :grimacing:

refs.
Goncalves et al., eLife 2020; 9:e56261, available online
Lueckmann et al., in Proceedings of NIPS 2017, available online

michaeldeistler commented 3 years ago

Hi there,

thanks for reaching out! I had a brief look at your notebook, and I do not think that you are doing anything wrong. A few thoughts:

  1. Both SNPE-B and SNPE-C can "learn informative features in high-dimensional data". The mechanism by which they do this is also identical (adding layers at the beginning of the neural net), so there's no difference there. Aside: SNPE-C and SNPE-B are identical if you do only one "round" (i.e. if you do not do what is described here).
  2. If you specify embedding_net=None, the embedding network is basically a multi-layer perceptron (MLP), defined here. So if this MLP is sufficient to learn informative features of your data, it will work. I think this is what is happening in your case.
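To make the second point more concrete, here is a minimal sketch of what an explicit MLP embedding network could look like in PyTorch. It is not the network that sbi builds internally: the class name `MLPEmbedding`, the layer sizes, and the batch of random data standing in for voltage traces are all assumptions.

```python
import torch
from torch import nn


class MLPEmbedding(nn.Module):
    """Plain fully connected network mapping a raw trace to a few learned features.

    Hypothetical example; layer sizes are arbitrary choices.
    """

    def __init__(self, in_features, hidden_features=50, embedding_size=8):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_features, hidden_features),
            nn.ReLU(),
            nn.Linear(hidden_features, hidden_features),
            nn.ReLU(),
            nn.Linear(hidden_features, embedding_size),
        )

    def forward(self, x):
        # x: (batch, in_features) -> (batch, embedding_size)
        return self.net(x)


# Random stand-in for a small batch of 7000-sample voltage traces.
x = torch.randn(16, 7000)
embedding_net = MLPEmbedding(in_features=x.shape[1])
features = embedding_net(x)
```

A module like this would be passed through the embedding_net argument instead of leaving it as None. Its weights are then trained jointly with the density estimator, which is how the features get "learned".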

I hope that helps! Michael

akapet00 commented 3 years ago

Yes, it helps, thank you very much. I still have a few questions.

  1. Does the same thing described in your second point apply to other neural networks (such as nsf, for example)?
  2. Do you scale the data that are fed into the network anywhere?
  3. Why is the output of the network ReLU activated (line 62)?
  4. Is there any example where you applied a custom embedding network to time series data? I am not sure how the mentioned RNN would need to be defined in order to provide automatic feature extraction for a more complex task than the one I've been playing around with in the notebook I provided.

Thank you again.

Best, Ante

michaeldeistler commented 3 years ago
  1. Somewhat, yes. In the case of flows, the simulation outputs are added to the activations of the flows and then passed into a resnet. So yes, even with flows, this can work.
  2. By scale, you mean normalize? If yes, then: yes, we do that here. You can turn it off via density_estimator = posterior_nn("mdn", z_score_x=False)
  3. After the ReLU, there are linear units that map onto the means, stds, and weights of the mixture of Gaussians; see this file
  4. I have not done this, but you might want to have a look at this
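Regarding question 4, one possible shape for a recurrent embedding network is sketched below. This is only an illustration written independently of the linked reference, not a tested recipe: the class name `GRUEmbedding`, the hidden and embedding sizes, and the short random traces are assumptions.

```python
import torch
from torch import nn


class GRUEmbedding(nn.Module):
    """Summarize a univariate time series by the final hidden state of a GRU.

    Hypothetical example; sizes are arbitrary choices.
    """

    def __init__(self, hidden_size=32, embedding_size=8):
        super().__init__()
        self.gru = nn.GRU(input_size=1, hidden_size=hidden_size, batch_first=True)
        self.head = nn.Linear(hidden_size, embedding_size)

    def forward(self, x):
        # x: (batch, time) -> (batch, time, 1), one input channel per time step.
        _, h_n = self.gru(x.unsqueeze(-1))
        # h_n: (num_layers, batch, hidden_size); use the last layer's final state.
        return self.head(h_n[-1])


# Random stand-in for a small batch of short voltage traces.
x = torch.randn(4, 200)
embedding_net = GRUEmbedding()
features = embedding_net(x)
```

For traces as long as 7000 samples, a recurrent layer applied to every time step is likely to be slow, so downsampling the trace first may be worth considering.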

Best Michael

akapet00 commented 3 years ago
> 4. I have not done this, but you might want to have a look at this

It is so good to see a paper with <15 pages for a change. I'll check it out!

Thanks for all the information, really helpful! Feel free to close this issue now since I am all out of questions :smile:

Best, Ante