pytorch / functorch

functorch is JAX-like composable function transforms for PyTorch.
https://pytorch.org/functorch/
BSD 3-Clause "New" or "Revised" License

Use functional models inside usual nn.Module #1111

Open subho406 opened 1 year ago

subho406 commented 1 year ago

Hi, thanks for adding functional features to PyTorch. I want to use an nn.Module converted into functional form inside a regular, stateful nn.Module. However, the code below does not correctly register the parameters of the functional module with the outer module. Is there currently a way to do this?

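import torch
import torch.nn as nn
from functorch import make_functional

x = torch.randn(4, 10)
class TinyModel(torch.nn.Module):

    def __init__(self):
        super().__init__()
        # Convert the inner layer into a stateless function plus a tuple of its parameters
        self.func_l, self.params_l = make_functional(nn.Linear(10, 10))
        # Register each parameter on the outer module under the names "0", "1", ...
        for i, ele in enumerate(self.params_l):
            self.register_parameter(str(i), ele)
    def forward(self, inputs):
        # Reads the plain tuple self.params_l, not the registered attributes
        return self.func_l(self.params_l, inputs)

model = TinyModel()
func, params = make_functional(model)
# Call the outer functional model; its forward still reads self.params_l
out = func(params, x)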
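A likely cause: the outer functional call substitutes values for the registered parameters ("0", "1"), but forward reads self.params_l, a plain Python tuple that nn.Module does not track, so the substituted values never reach the computation. The sketch below reproduces this pattern under stated assumptions: it uses torch.func.functional_call (PyTorch 2.0+, where it replaces functorch's make_functional) instead of functorch, the module name TupleHeld is hypothetical, and it assumes make_functional behaves the same way for this pattern.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.func import functional_call  # PyTorch >= 2.0; stands in for functorch.make_functional


class TupleHeld(nn.Module):
    """Hypothetical module mimicking TinyModel: parameters registered, but read via a plain tuple."""

    def __init__(self):
        super().__init__()
        lin = nn.Linear(10, 10)
        # A plain tuple is not tracked by nn.Module, so it keeps pointing at the original tensors.
        self.params_l = (lin.weight, lin.bias)
        for i, p in enumerate(self.params_l):
            self.register_parameter(str(i), p)  # registered as "0" and "1"

    def forward(self, x):
        w, b = self.params_l  # bypasses the registered attributes
        return F.linear(x, w, b)


model = TupleHeld()
x = torch.randn(4, 10)
# Try to run the module with all-zero parameters substituted for "0" and "1".
zeros = {name: torch.zeros_like(p) for name, p in model.named_parameters()}

out = functional_call(model, zeros, (x,))
# If the supplied parameters were used, `out` would be all zeros; instead it matches model(x).
print(torch.allclose(out, model(x)))  # True
```

The substituted zeros are silently ignored, which matches the symptom described above.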

This is useful for me because I want to apply functional transforms (such as vmap, jvp, and vjp) to an inner nn.Module inside the forward pass of an outer nn.Module. The idea is to have lifted versions of vjp, jvp, etc., similar to Flax (https://flax.readthedocs.io/en/latest/api_reference/_autosummary/flax.linen.vjp.html).
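One way to get a similar effect without make_functional is to keep the inner model as an ordinary, registered submodule and apply a torch.func transform to it inside the outer forward. This is a minimal sketch assuming PyTorch 2.0+ (torch.func); the class name OuterWithJvp and the choice of adding the tangent to the output are illustrative only, and this is not an equivalent of Flax's lifted API.

```python
import torch
import torch.nn as nn
from torch.func import functional_call, jvp  # PyTorch >= 2.0


class OuterWithJvp(nn.Module):
    """Hypothetical outer module that applies jvp to an inner nn.Module in its forward pass."""

    def __init__(self):
        super().__init__()
        self.inner = nn.Linear(10, 10)  # ordinary submodule, so its parameters are registered

    def forward(self, x, v):
        params = dict(self.inner.named_parameters())

        def f(inp):
            # Pure function of the input; the inner module's parameters are passed explicitly.
            return functional_call(self.inner, params, (inp,))

        # Forward-mode AD: returns inner(x) and the Jacobian-vector product J(x) @ v.
        out, tangent = jvp(f, (x,), (v,))
        return out + tangent


model = OuterWithJvp()
x, v = torch.randn(4, 10), torch.randn(4, 10)
y = model(x, v)
y.sum().backward()  # gradients flow back to model.inner.weight and model.inner.bias
```

For a linear inner layer the tangent is simply v @ W.T, which makes the result easy to check. Because the inner layer is a normal submodule, its parameters appear in model.parameters() and can be handed to an optimizer as usual.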

subho406 commented 1 year ago

I figured out a way to do this. Here is some sample code:

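import torch
import functorch

class LinearModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # make_functional returns a stateless module and a tuple of its parameters
        self.model, params = functorch.make_functional(torch.nn.Linear(10, 20))
        # Wrapping them in a ParameterList registers each one on this module
        self.params = torch.nn.ParameterList(params)

    def forward(self, inputs):
        # Pass the registered ParameterList to the stateless function
        return self.model(self.params, inputs)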
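A quick way to check that this workaround behaves as intended is to substitute parameters on the outer module and confirm they are actually used. The sketch assumes PyTorch 2.0+ and uses torch.func.functional_call in place of functorch.make_functional; ListHeld is a hypothetical stand-in for LinearModule that calls torch.nn.functional.linear directly rather than a functorch stateless module.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.func import functional_call  # PyTorch >= 2.0


class ListHeld(nn.Module):
    """Hypothetical analogue of LinearModule: weights kept in a registered ParameterList."""

    def __init__(self):
        super().__init__()
        lin = nn.Linear(10, 20)
        # Registered under the names "params.0" and "params.1"
        self.params = nn.ParameterList([lin.weight, lin.bias])

    def forward(self, x):
        w, b = self.params  # reads the registered container, so substituted values are seen
        return F.linear(x, w, b)


model = ListHeld()
x = torch.randn(4, 10)
zeros = {name: torch.zeros_like(p) for name, p in model.named_parameters()}

out = functional_call(model, zeros, (x,))
print(torch.count_nonzero(out).item())  # 0: the supplied parameters were used
```

Because forward reads the parameters through the registered ParameterList rather than a separately stored tuple, values supplied at call time replace the stored ones for that call, and functional_call restores the originals afterwards.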