TyXe-BDL / TyXe

MIT License

Flipout implementation for recurrent networks (lstm, gru) #6

Open lbasora opened 2 years ago

lbasora commented 2 years ago

Thanks for the excellent TyXe initiative. Currently, flipout is implemented in TyXe for linear and convolutional layers. Are you considering supporting RNNs as well in the near future? If I understand correctly, for linear and conv layers you monkey-patched F.linear and F.conv, but I didn't see any equivalent functions for RNNs in torch.nn.functional that would allow a similar solution. Do you have any idea how to implement this?
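As background for the question above, here is a minimal plain-PyTorch sketch of what a flipout version of F.linear has to compute. It is not TyXe's code: the helper names `rademacher_like` and `flipout_linear` are hypothetical, and `w_delta` stands for the sampled weight perturbation (sample minus mean). Each example n effectively uses its own weight `w_loc + w_delta * outer(r_n, s_n)` built from random ±1 sign vectors, which the last lines compare against an explicit per-example computation.

```python
import torch
import torch.nn.functional as F


def rademacher_like(t: torch.Tensor) -> torch.Tensor:
    """Random +1/-1 entries with the same shape, dtype and device as t."""
    return torch.randint(0, 2, t.shape, device=t.device).to(t.dtype) * 2 - 1


def flipout_linear(x, w_loc, w_delta, bias=None, sign_in=None, sign_out=None):
    """Flipout estimate of a linear layer whose weight is w_loc + w_delta.

    x: (batch, in_features); w_loc, w_delta: (out_features, in_features).
    The perturbation is decorrelated across the batch at the cost of a
    second matrix multiplication.
    """
    if sign_in is None:
        sign_in = rademacher_like(x)  # s_n, shape (batch, in)
    if sign_out is None:
        sign_out = rademacher_like(x.new_empty(x.shape[0], w_loc.shape[0]))  # r_n, (batch, out)
    mean_out = F.linear(x, w_loc, bias)
    pert_out = F.linear(x * sign_in, w_delta) * sign_out
    return mean_out + pert_out


torch.manual_seed(0)
x = torch.randn(4, 3)
w_loc, w_delta, b = torch.randn(2, 3), 0.1 * torch.randn(2, 3), torch.randn(2)
s, r = rademacher_like(x), rademacher_like(torch.empty(4, 2))
out = flipout_linear(x, w_loc, w_delta, b, sign_in=s, sign_out=r)

# The same numbers from an explicit weight matrix per example:
# W_n = w_loc + w_delta * outer(r_n, s_n)
w_per_example = w_loc + w_delta * r[:, :, None] * s[:, None, :]  # (batch, out, in)
explicit = torch.einsum("noi,ni->no", w_per_example, x) + b
```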

hpplyt commented 2 years ago

Thanks for your kind words about the project!

We currently don't have plans to add flipout support for RNNs ourselves, but it would of course be great to have in the library. I had a quick look at what the pytorch RNN classes do under the hood, and they all seem to call torch._VF.gru, torch._VF.lstm etc. for the forward pass, so I'd imagine that if you monkey-patch these functions, a solution similar to the one for linear and conv layers could work.

I think it would just require a couple of small changes to the Reparameterization and FlipoutMessenger, specifically not assuming that the functions live in nn.functional and accounting for the different argument structure of the RNN functions compared to linear and conv. Let me know if you'd like to have a go at implementing this (and ideally submit a PR :-) ); I'm happy to advise in more detail.
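A minimal sketch of a patching helper that makes no assumption about where the patched function lives: `patched` (a hypothetical name) takes the namespace object and the attribute name, so the same context manager could target `torch.nn.functional` or, in principle, `torch._VF`. Only the `F.linear` case is exercised here; `torch._VF` is a private PyTorch module whose RNN functions take a long positional argument list, which this sketch does not attempt to handle.

```python
import contextlib

import torch
import torch.nn as nn
import torch.nn.functional as F


@contextlib.contextmanager
def patched(namespace, name, make_replacement):
    """Temporarily replace namespace.<name>; restore it on exit, even on error.

    `make_replacement` receives the original function and returns the wrapper.
    """
    original = getattr(namespace, name)
    setattr(namespace, name, make_replacement(original))
    try:
        yield
    finally:
        setattr(namespace, name, original)


calls = []


def logging_linear(original):
    def wrapper(input, weight, bias=None):
        calls.append(tuple(weight.shape))  # record the weight shape of each call
        return original(input, weight, bias)
    return wrapper


original_linear = F.linear
layer = nn.Linear(3, 2)
x = torch.randn(4, 3)
with patched(F, "linear", logging_linear):
    y_patched = layer(x)  # nn.Linear.forward looks up F.linear at call time
y_plain = layer(x)        # patch removed: no call is recorded
```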

lbasora commented 2 years ago

Not ready for a PR yet but have a look at: https://github.com/lbasora/TyXe/blob/56ae7d32f6bf877d142bba99e2c38d84b0c25999/tyxe/poutine/reparameterization_messengers.py

It's just an initial attempt to test the monkey-patching and the working principle. Let me know whether the solution is more or less what you had in mind.

The code has not been properly tested and is still incomplete: the dropout, bidirectional and multi-layer options are not implemented yet.

Here is a very basic code snippet to play with it:

```python
from functools import partial

import pyro
import pyro.distributions as dist
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from tqdm.auto import trange

import tyxe


class Lstm(nn.Module):
    def __init__(self):
        super().__init__()
        self.lstm = nn.LSTM(3, 2, batch_first=True)
        self.out = nn.Linear(2, 1)

    def forward(self, x):
        _, (ht, _) = self.lstm(x)
        # ht: (num_layers, batch, hidden); take the last layer's hidden state
        return self.out(ht[-1])


class Gru(nn.Module):
    def __init__(self):
        super().__init__()
        self.gru = nn.GRU(3, 2, batch_first=True)
        self.out = nn.Linear(2, 1)

    def forward(self, x):
        _, ht = self.gru(x)
        return self.out(ht[-1])


pyro.set_rng_seed(42)  # seed before creating the network and the data
USE_CUDA = torch.cuda.is_available()
DEVICE = torch.device("cuda") if USE_CUDA else torch.device("cpu")

net = Lstm().to(DEVICE)  # swap in Gru() to try the GRU version
x = torch.rand(5, 4, 3)  # (batch, seq_len, features) since batch_first=True
y = torch.rand(5, 1)

prior = tyxe.priors.IIDPrior(
    dist.Normal(torch.tensor(0.0, device=DEVICE), torch.tensor(1.0, device=DEVICE))
)
likelihood = tyxe.likelihoods.HomoskedasticGaussian(len(x), scale=0.1)
guide = partial(tyxe.guides.AutoNormal, init_scale=0.1)
bnn = tyxe.VariationalBNN(net, prior, likelihood, guide)

ds = TensorDataset(x, y)
dl = DataLoader(ds, batch_size=len(x), pin_memory=USE_CUDA)
pyro.clear_param_store()
optim = pyro.optim.Adam({"lr": 1e-3})
num_epochs = 1
pbar = trange(num_epochs)
elbos = []


def callback(_bnn, _epoch, elbo):
    # called by bnn.fit after each epoch
    elbos.append(elbo / len(dl.sampler))
    pbar.update()


with tyxe.poutine.flipout():
    bnn.fit(dl, optim, num_epochs, device=DEVICE, callback=callback)
```

hpplyt commented 2 years ago

Thanks, I'll try to have a closer look at this in the next couple of days. In terms of the interface, it is what I had in mind.

Regarding the implementation, I was hoping it would be possible to avoid re-implementing the forward passes by hand. Since what you're doing ultimately relies on calling F.linear, I wonder if it wouldn't be easier to offer regular nn.Module implementations of LSTM/GRU and the corresponding cells that call F.linear. Then the existing implementation would just work with those modules, if I'm not mistaken?
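To illustrate this suggestion, here is a sketch of a single-layer, unidirectional LSTM forward pass written entirely with F.linear, so anything that patches F.linear would also reach the recurrent weights. The function name `lstm_forward` is hypothetical, and the sketch ignores the multi-layer, bidirectional, dropout and projection options; it reads the weights of a regular nn.LSTM and is compared against the built-in implementation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def lstm_forward(lstm: nn.LSTM, x, state=None):
    """LSTM forward pass built from F.linear (single layer, unidirectional).

    Gate order follows PyTorch's weight layout: input, forget, cell, output.
    """
    assert lstm.num_layers == 1 and not lstm.bidirectional
    assert getattr(lstm, "proj_size", 0) == 0
    if lstm.batch_first:
        x = x.transpose(0, 1)  # -> (seq_len, batch, features)
    if state is None:
        h = c = x.new_zeros(x.shape[1], lstm.hidden_size)
    else:
        h, c = state[0][0], state[1][0]  # drop the num_layers dimension
    b_ih = getattr(lstm, "bias_ih_l0", None)  # None when bias=False
    b_hh = getattr(lstm, "bias_hh_l0", None)
    outputs = []
    for x_t in x:  # one iteration per time step
        gates = F.linear(x_t, lstm.weight_ih_l0, b_ih) + F.linear(h, lstm.weight_hh_l0, b_hh)
        i, f, g, o = gates.chunk(4, dim=1)
        c = torch.sigmoid(f) * c + torch.sigmoid(i) * torch.tanh(g)
        h = torch.sigmoid(o) * torch.tanh(c)
        outputs.append(h)
    out = torch.stack(outputs)
    if lstm.batch_first:
        out = out.transpose(0, 1)
    return out, (h.unsqueeze(0), c.unsqueeze(0))


torch.manual_seed(0)
lstm = nn.LSTM(3, 2, batch_first=True)
x = torch.rand(5, 4, 3)
ref_out, (ref_h, ref_c) = lstm(x)       # PyTorch's built-in kernel
out, (h, c) = lstm_forward(lstm, x)     # explicit F.linear version
```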

I'm also a little worried that, with the way this is implemented, a new flipout mask would be sampled at every time step, corresponding to a different weight sample. However, the sample from the weight posterior should be tied across all time steps, so I think we'd need some extra logic for caching the sign multipliers in the original reparameterize function.

Regarding testing, just as a basic sanity check, have you tried passing an input that is repeated along the batch dimension through the network inside and outside of a flipout context? Inside the context you should get different outputs (if the implementation works), whereas outside they should all be identical since the same sample for the weights should be used across the batch.
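The sanity check above can be written as two small helpers. To keep the sketch runnable without TyXe, a toy stand-in (`toy_flipout_linear`, hypothetical) takes the place of a network run inside a flipout context; with TyXe one would instead pass the repeated batch through the network inside and outside `tyxe.poutine.flipout()`.

```python
import torch
import torch.nn.functional as F


def rows_identical(out: torch.Tensor, atol: float = 1e-6) -> bool:
    """True if every row of `out` equals the first row."""
    return torch.allclose(out, out[:1].expand_as(out), atol=atol)


def repeated_batch(x_single: torch.Tensor, batch: int = 8) -> torch.Tensor:
    """Repeat one example `batch` times along a new leading batch dimension."""
    return x_single.unsqueeze(0).expand(batch, *x_single.shape).contiguous()


def toy_flipout_linear(x, w_loc, w_delta):
    # Stand-in for a flipout-patched F.linear: fresh +1/-1 signs per example.
    s = torch.randint(0, 2, x.shape).to(x.dtype) * 2 - 1
    r = torch.randint(0, 2, (x.shape[0], w_loc.shape[0])).to(x.dtype) * 2 - 1
    return F.linear(x, w_loc) + F.linear(x * s, w_delta) * r


torch.manual_seed(0)
x = repeated_batch(torch.randn(3))             # 8 identical rows
w_loc, w_delta = torch.randn(2, 3), torch.randn(2, 3)
shared = F.linear(x, w_loc + w_delta)          # one weight sample for the whole batch
flipped = toy_flipout_linear(x, w_loc, w_delta)
# Expected: rows_identical(shared) is True, rows_identical(flipped) is False.
```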

lbasora commented 2 years ago

To avoid re-implementing the forward pass, are you suggesting TyXe should provide the user with LSTM/GRU Flipout layers? As in bayesian-torch or in Edward2? But then what about the nice TyXe feature of turning a regular pytorch LSTM/GRU model into a BNN model?

I agree with you that the current implementation isn't correct because, as you say, a new flipout mask is applied at each time step. But I don't know yet how we can cache the sign multipliers. Let me know if you see a practical way to do that.

Good suggestion for the testing.

hpplyt commented 2 years ago

> But then what about the nice TyXe feature of turning a regular pytorch LSTM/GRU model into a BNN model?

That is definitely the priority :) I was just thinking that if we add code to the library that is an "explicit" pytorch implementation of a gru/lstm forward pass we might as well expose it rather than hiding it in the reparameterization messengers. This seems to be something that some people are interested in.

That would of course require an additional helper function for conversion, but I think this could all be as simple as putting your code for the forward pass in a separate function, setting up tyxe.nn.LSTM etc. classes that inherit from the corresponding pytorch modules and having them call this function in the forward pass. Then the converter just has to change the __class__ attribute of the module.
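A sketch of this conversion approach, assuming a single-layer, unidirectional GRU. `ExplicitGRU` and `make_explicit` are hypothetical names standing in for the proposed tyxe.nn classes and converter. Because the subclass keeps nn.GRU's parameters, swapping `__class__` only changes the forward pass, and the converted module is still an instance of nn.GRU.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ExplicitGRU(nn.GRU):
    """nn.GRU whose forward pass is written with F.linear (hypothetical class).

    Single unidirectional layer only. Gate order follows PyTorch's weight
    layout: reset, update, new.
    """

    def forward(self, x, h0=None):
        assert self.num_layers == 1 and not self.bidirectional
        if self.batch_first:
            x = x.transpose(0, 1)  # -> (seq_len, batch, features)
        h = x.new_zeros(x.shape[1], self.hidden_size) if h0 is None else h0[0]
        b_ih = getattr(self, "bias_ih_l0", None)  # None when bias=False
        b_hh = getattr(self, "bias_hh_l0", None)
        outputs = []
        for x_t in x:
            gi = F.linear(x_t, self.weight_ih_l0, b_ih).chunk(3, dim=1)
            gh = F.linear(h, self.weight_hh_l0, b_hh).chunk(3, dim=1)
            r = torch.sigmoid(gi[0] + gh[0])
            z = torch.sigmoid(gi[1] + gh[1])
            n = torch.tanh(gi[2] + r * gh[2])
            h = (1 - z) * n + z * h
            outputs.append(h)
        out = torch.stack(outputs)
        if self.batch_first:
            out = out.transpose(0, 1)
        return out, h.unsqueeze(0)


def make_explicit(model: nn.Module) -> nn.Module:
    """Swap the class of every plain nn.GRU submodule in place (hypothetical helper)."""
    for module in model.modules():
        if type(module) is nn.GRU:
            module.__class__ = ExplicitGRU
    return model


torch.manual_seed(0)
gru = nn.GRU(3, 2, batch_first=True)
x = torch.rand(5, 4, 3)
ref_out, ref_h = gru(x)   # built-in kernel
make_explicit(gru)
out, h = gru(x)           # same parameters, explicit forward pass
```

The converted module keeps the original parameter names (`weight_ih_l0` etc.); the trade-off is a Python loop over time steps instead of PyTorch's fused RNN kernel.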

This also has the advantage that it's a bit more explicit that the forward pass is being changed. That said, I'm a bit concerned that such a pure pytorch implementation would be a lot slower than the cuda kernel that pytorch calls.

> But I don't know yet how we can cache the sign multipliers. Let me know if you see a practical way to do that.

I had another look at the implementation and realized that I was already caching the masks by setting them as attributes on the sampled weights. So the samples should actually be consistent across time steps. I'll need to figure out how to test this though to be absolutely sure, let me know if you have any thoughts on this.
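One way to sketch both the caching and a test for it, assuming the signs are stored as an attribute on the sampled weight tensor as described above. The attribute name `_flipout_signs` and all function names are hypothetical, and the toy tanh RNN applies flipout only to its input-to-hidden weight. The test rebuilds each example's effective weight from the cached signs and reruns the recurrence; if the signs were redrawn at every time step, the two final hidden states would not match.

```python
import torch
import torch.nn.functional as F


def cached_signs(weight, x):
    """Return (sign_in, sign_out) for this weight sample, drawing them only once."""
    signs = getattr(weight, "_flipout_signs", None)
    if signs is None or signs[0].shape != x.shape:
        s = torch.randint(0, 2, x.shape, device=x.device).to(x.dtype) * 2 - 1
        r = torch.randint(0, 2, (x.shape[0], weight.shape[0]), device=x.device).to(x.dtype) * 2 - 1
        signs = (s, r)
        weight._flipout_signs = signs  # cache on the sampled tensor itself
    return signs


def flipout_linear(x, w_loc, w_sample):
    s, r = cached_signs(w_sample, x)
    return F.linear(x, w_loc) + F.linear(x * s, w_sample - w_loc) * r


def flipout_rnn(xs, w_loc, w_sample, u):
    """Toy tanh RNN; xs: (seq_len, batch, in). Only the input weight uses flipout."""
    h = xs.new_zeros(xs.shape[1], u.shape[0])
    for x_t in xs:
        h = torch.tanh(flipout_linear(x_t, w_loc, w_sample) + F.linear(h, u))
    return h


def per_example_rnn(xs, w_loc, w_sample, u):
    """Reference: fixed per-example weights built from the cached signs."""
    s, r = w_sample._flipout_signs
    w_n = w_loc + (w_sample - w_loc) * r[:, :, None] * s[:, None, :]  # (batch, out, in)
    h = xs.new_zeros(xs.shape[1], u.shape[0])
    for x_t in xs:
        h = torch.tanh(torch.einsum("noi,ni->no", w_n, x_t) + F.linear(h, u))
    return h


torch.manual_seed(0)
xs = torch.randn(6, 4, 3)
w_loc = torch.randn(2, 3)
w_sample = w_loc + 0.5 * torch.randn(2, 3)  # one draw from the weight posterior
u = torch.randn(2, 2)                       # deterministic hidden-to-hidden weight
h_flipout = flipout_rnn(xs, w_loc, w_sample, u)
h_reference = per_example_rnn(xs, w_loc, w_sample, u)
signs_before = w_sample._flipout_signs
flipout_rnn(xs, w_loc, w_sample, u)         # same weight sample: cached signs are reused
```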

lbasora commented 2 years ago

I've committed some changes (8e1e5b37ff6e07c6fffaca967aa2171b36c1d64d) to try to take your last comments into account. Please let me know whether the interface is convenient for you or whether you'd like some changes.

I've tested it, and we obtain the same results as with the pytorch GRU/LSTM implementation when not using flipout. I also did the basic sanity check you suggested: when the input is repeated along the batch dimension, we get different outputs inside the flipout context, whereas outside they are identical. I haven't yet tested the consistency of the cached sign multipliers, though.

The current LSTM/GRU implementation is not complete yet. Features like multiple layers, bidirectionality and dropout can be added later, though.