subho406 opened 1 year ago
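I figured out a way to do this. Here is some sample code:

```python
import torch
import functorch

class LinearModule(torch.nn.Module):
    def __init__(self):
        super(LinearModule, self).__init__()
        # make_functional returns a stateless callable plus the extracted parameters.
        self.model, params = functorch.make_functional(torch.nn.Linear(10, 20))
        # Wrapping the parameters in a ParameterList registers them with this module.
        self.params = torch.nn.ParameterList(params)

    def forward(self, inputs):
        return self.model(self.params, inputs)
```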
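For comparison, here is a minimal sketch assuming PyTorch 2.0 or later, where `torch.func.functional_call` is the suggested replacement for `functorch.make_functional`. The inner module stays a registered submodule, so its parameters are registered automatically, and the forward pass calls it functionally with an explicit parameter dict. The attribute name `inner` is an illustrative choice, not something from the original code.

```python
import torch
from torch.func import functional_call


class LinearModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Assigning the inner module as an attribute registers its parameters
        # ("inner" is a hypothetical name chosen for this sketch).
        self.inner = torch.nn.Linear(10, 20)

    def forward(self, inputs):
        # Gather the registered parameters and pass them explicitly,
        # so the inner module is evaluated as a pure function of (params, inputs).
        params = dict(self.inner.named_parameters())
        return functional_call(self.inner, params, (inputs,))


model = LinearModule()
x = torch.randn(3, 10)
out = model(x)
print([name for name, _ in model.named_parameters()])  # ['inner.weight', 'inner.bias']
print(out.shape)  # torch.Size([3, 20])
```

Because the parameters live on an ordinary submodule, `model.parameters()` can be handed straight to an optimizer without a separate `ParameterList`.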
Hi, thanks for adding functional features to PyTorch. I want to use an `nn.Module` converted into functional form inside a regular stateful `nn.Module`. However, the code below does not correctly register the parameters of the functional module. Is there currently a way to do this?

This is useful for me because I want to apply functional transforms (such as vmap, jvp, and vjp) to an inner `nn.Module` inside the forward pass of an outer `nn.Module`. The idea is to have lifted versions of vjp, jvp, etc., similar to Flax (https://flax.readthedocs.io/en/latest/api_reference/_autosummary/flax.linen.vjp.html).
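A lifted jvp along these lines could look roughly like the sketch below, again assuming PyTorch 2.0+ and `torch.func`. The class `JvpModule`, the `Linear`/`Tanh` inner network, and the `tangent` argument are hypothetical choices for illustration, not an official PyTorch API. As I understand it, `torch.func` transforms compose with regular autograd, so a backward pass through the jvp output should still reach the registered parameters.

```python
import torch
from torch.func import functional_call, jvp


class JvpModule(torch.nn.Module):
    """Hypothetical outer module whose forward applies jvp to a registered inner module."""

    def __init__(self):
        super().__init__()
        # Hypothetical inner network; its parameters are registered as usual.
        self.inner = torch.nn.Sequential(torch.nn.Linear(10, 20), torch.nn.Tanh())

    def forward(self, inputs, tangent):
        params = dict(self.inner.named_parameters())

        def f(x):
            # Pure function of x; the parameters are supplied explicitly.
            return functional_call(self.inner, params, (x,))

        # Returns (f(inputs), J_f(inputs) @ tangent) as a tuple.
        return jvp(f, (inputs,), (tangent,))


torch.manual_seed(0)
model = JvpModule()
x = torch.randn(4, 10)
v = torch.randn(4, 10)
out, jvp_out = model(x, v)

# Backpropagate through the jvp output to the outer module's parameters.
jvp_out.sum().backward()
print(model.inner[0].weight.grad is not None)
```

For this particular network the jvp has a closed form, `(1 - tanh(Wx + b)**2) * (W v)`, which makes the result easy to sanity-check by hand.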