pytorch / functorch

functorch provides JAX-like composable function transforms for PyTorch.
https://pytorch.org/functorch/
BSD 3-Clause "New" or "Revised" License

Add an Ensemble Module that is constructed from a list of Modules and encapsulates the necessary state #992

Open issue, opened by sinking-point 2 years ago

sinking-point commented 2 years ago

Most of the examples I've seen use vmap at the top level, to create an 'outer' ensemble of models or to factor out the batch dimension. However, my use case is 'inner' ensembles of modules within a larger model. This means I have to register the parameters and buffers returned by combine_state_for_ensemble with the parent module myself, which is annoying and messy.

An obvious solution is to create an Ensemble module that internally calls combine_state_for_ensemble and vmap, and stores the necessary state:

self.ens = Ensemble(my_modules, in_dims=(0, 0, 2), out_dims=(0, 0, 2))
...
x = self.ens(x)

Even if registering the state weren't an issue, I still think this would be a popular feature. It's more intuitive than the current method of creating ensembles.
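A minimal sketch of what in_dims=(0, 0, 2) asks vmap to do, written against torch.func, which supersedes functorch's combine_state_for_ensemble as of PyTorch 2.0 (torch.func.stack_module_state plays the same role here). The stacked parameters and buffers are split along dim 0, so member i gets its own weights, and the input is split along dim 2, so member i sees x[:, :, i]. The three nn.Linear(4, 2) members and all shapes are illustrative assumptions. Because this model has a single output, out_dims is a single int; the issue's out_dims=(0, 0, 2) would correspond to a model returning three outputs.

```python
import copy

import torch
from torch import nn
from torch.func import functional_call, stack_module_state, vmap

num_models = 3
members = [nn.Linear(4, 2) for _ in range(num_models)]
# Stack each parameter/buffer across members along a new leading dim 0.
params, buffers = stack_module_state(members)

# Stateless "template" on the meta device; only its structure is used.
base = copy.deepcopy(members[0]).to("meta")


def fmodel(params, buffers, x):
    # Run the template with one member's parameters and buffers.
    return functional_call(base, (params, buffers), (x,))


# Input laid out as (batch, features, member): member i sees x[:, :, i].
x = torch.randn(5, 4, num_models)
# in_dims=(0, 0, 2): params/buffers split on dim 0, x split on dim 2.
# out_dims=2 puts the member dimension back at position 2 of the output.
out = vmap(fmodel, in_dims=(0, 0, 2), out_dims=2)(params, buffers, x)
print(out.shape)  # torch.Size([5, 2, 3])
```

With in_dims=(0, 0, None) instead, every member would receive the full, unsplit input.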

sinking-point commented 2 years ago

Something like this, perhaps:

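from torch import nn
from functorch import combine_state_for_ensemble, vmap

class Ensemble(nn.Module):
    def __init__(self, modules, **kwargs):
        super().__init__()

        # Stack the state of all modules along a new leading dimension.
        fmodel, params, buffers = combine_state_for_ensemble(modules)

        # kwargs (in_dims, out_dims, ...) are forwarded to vmap.
        self.vmap_model = vmap(fmodel, **kwargs)

        # Register the stacked state so parameters(), .to() and state_dict() see it.
        # (Storing it as self.buffers would also shadow nn.Module.buffers().)
        self.num_params = len(params)
        self.num_buffers = len(buffers)
        for i, param in enumerate(params):
            self.register_parameter('param_' + str(i), nn.Parameter(param.detach()))
        for i, buffer in enumerate(buffers):
            self.register_buffer('buffer_' + str(i), buffer)

    def forward(self, *args, **kwargs):
        # Read the registered tensors back so gradients reach this module's parameters.
        params = tuple(getattr(self, 'param_' + str(i)) for i in range(self.num_params))
        buffers = tuple(getattr(self, 'buffer_' + str(i)) for i in range(self.num_buffers))
        return self.vmap_model(params, buffers, *args, **kwargs)
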
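For comparison, a sketch of the same Ensemble idea on torch.func, assuming PyTorch 2.x. Ensemble and Parent are hypothetical names, and in_dims follows the issue's convention of covering (params, buffers, *inputs). Two details matter for the inner-ensemble use case: forward reads the registered tensors back instead of keeping a separate copy, so gradients and optimizer updates stay in sync, and the meta-device template is kept out of the submodule tree so parameters() and state_dict() contain only the stacked state.

```python
import copy

import torch
from torch import nn
from torch.func import functional_call, stack_module_state, vmap


class Ensemble(nn.Module):
    """Hypothetical sketch: runs M same-class modules as one vmapped call."""

    def __init__(self, members, in_dims=(0, 0, 0), out_dims=0):
        super().__init__()
        # Dicts of name -> tensor, each stacked along a new leading dim of size M.
        params, buffers = stack_module_state(members)
        # Names like "0.weight" contain dots, which register_parameter rejects,
        # so register under index-based names and remember the original names.
        self._param_names = list(params)
        self._buffer_names = list(buffers)
        for i, name in enumerate(self._param_names):
            self.register_parameter(f"param_{i}", nn.Parameter(params[name].detach()))
        for i, name in enumerate(self._buffer_names):
            self.register_buffer(f"buffer_{i}", buffers[name])
        # Stateless template on the meta device, kept in a plain list so it
        # is not registered as a submodule.
        self._template = [copy.deepcopy(members[0]).to("meta")]
        # in_dims covers (params, buffers, *inputs), as in the issue's example.
        self._in_dims = in_dims
        self._out_dims = out_dims

    def _call_one(self, params, buffers, *args):
        # Run the template with one member's slice of the stacked state.
        return functional_call(self._template[0], (params, buffers), args)

    def forward(self, *args):
        # Rebuild name -> tensor dicts from the registered tensors so that
        # gradients land on this module's own parameters.
        params = {n: getattr(self, f"param_{i}") for i, n in enumerate(self._param_names)}
        buffers = {n: getattr(self, f"buffer_{i}") for i, n in enumerate(self._buffer_names)}
        batched = vmap(self._call_one, in_dims=self._in_dims, out_dims=self._out_dims)
        return batched(params, buffers, *args)


class Parent(nn.Module):
    """An 'inner' ensemble: three heads share one encoder."""

    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(8, 4)
        # None: every head sees the same encoder output.
        self.heads = Ensemble([nn.Linear(4, 2) for _ in range(3)], in_dims=(0, 0, None))

    def forward(self, x):
        return self.heads(self.encoder(x))


model = Parent()
out = model(torch.randn(5, 8))
print(out.shape)  # torch.Size([3, 5, 2])
print(sum(p.numel() for p in model.parameters()))  # 66
```

Because heads is an ordinary submodule, the parent's parameters(), .to() and state_dict() pick up the stacked state without any manual registration in Parent.
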
zou3519 commented 2 years ago

This seems convenient to have. I am not sure whether this would go into functorch or torch.nn in the long term, but we can certainly toss something like this into functorch to start. cc @samdow who is thinking about functional modules. Also curious to hear @jbschlosser and @albanD's opinions as torch.nn maintainers.

albanD commented 2 years ago

This would need to be part of a bigger plan to move things like combine_state_for_ensemble as well? Also this seems to be very vmap specific?
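One way to read the "vmap specific" concern: an ensemble's semantics are just "run every member and stack the outputs", which a plain loop expresses without any function transforms, while vmap over stacked state is a vectorized way to compute the same result. A minimal sketch comparing the two, assuming torch.func and three illustrative nn.Linear(4, 2) members:

```python
import copy

import torch
from torch import nn
from torch.func import functional_call, stack_module_state, vmap

members = [nn.Linear(4, 2) for _ in range(3)]
x = torch.randn(5, 4)

# Loop version: ordinary module calls, no function transforms.
looped = torch.stack([m(x) for m in members])

# vmap version: stack the members' state once, then map over its dim 0.
params, buffers = stack_module_state(members)
base = copy.deepcopy(members[0]).to("meta")


def fmodel(p, b, inp):
    # Run the stateless template with one member's parameters and buffers.
    return functional_call(base, (p, b), (inp,))


# None: every member sees the same, unsplit input.
vmapped = vmap(fmodel, in_dims=(0, 0, None))(params, buffers, x)
print(looped.shape, vmapped.shape)  # torch.Size([3, 5, 2]) torch.Size([3, 5, 2])
```

An Ensemble module could in principle expose the loop semantics and use vmap only as an optimization, which is one way the API could be made less tied to vmap.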

zou3519 commented 2 years ago

> Also this seems to be very vmap specific?

Are you suggesting that we should put the nn.Ensemble API into functorch because it is vmap specific?

sinking-point commented 2 years ago

I did wonder about this, because my suggestion is not really functional, so it doesn't fit the theme of this package. However, functorch is the only place it can go, since torch can't have functorch as a dependency, unless we create a new package for this, I guess.

albanD commented 2 years ago

> Are you suggesting that we should put the nn.Ensemble API into functorch because it is vmap specific?

Not necessarily, but it does sound much "higher level" than the things currently in torch.nn, so I'm not sure where it should live.