waleedka / hiddenlayer

Neural network graphs and training metrics for PyTorch, TensorFlow, and Keras.
MIT License
1.79k stars 266 forks

Profile/fold custom layer #76

Open matthijsvk opened 4 years ago

matthijsvk commented 4 years ago

Hi,

even for very simple wrappers around PyTorch layers, the graph can get quite complex. Is there a way to create subgraphs and automatically fold each of them into a single block?

example:

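import torch
import torch.nn as nn
import hiddenlayer as hl

class Lin(nn.Module):
    """Thin wrapper: flatten the input, then apply nn.Linear."""
    def __init__(self, *args, **kwargs):
        super().__init__()
        self.lin = nn.Linear(*args, **kwargs)

    def forward(self, x):
        x = x.view(x.shape[0], -1)
        return self.lin(x)

# nn.Linear needs both in_features and out_features, so Lin does too
model = nn.Sequential(*[nn.Sequential(Lin(10, 10), nn.ReLU()) for i in range(5)])
batch = torch.randn((4, 10))
print(model)

graph = hl.build_graph(model, batch)
graph.save("graph", format="pdf")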

What I mean is something like:

# profile a submodule
lin_example = Lin(10, 10)
lin_graph = hl.build_graph(lin_example, batch)

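# now fold every occurrence of that submodule in the main graph into one block
# (fold_submodule is the API I am proposing, not an existing method)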
graph = graph.fold_submodule(lin_graph)
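
To make the request concrete, the following is a minimal pure-Python sketch of the folding idea, independent of hiddenlayer. It assumes every op carries a scope string naming the module that produced it; the scope names, the `fold_by_scope` helper, and the `depth` parameter are all hypothetical and only illustrate the intended result, not how hiddenlayer stores its graph.

```python
def fold_by_scope(nodes, edges, depth):
    """Collapse nodes whose scope shares the first `depth` path components."""
    def group(scope):
        return "/".join(scope.split("/")[:depth])

    # dict.fromkeys keeps first-seen order while dropping duplicates
    folded_nodes = list(dict.fromkeys(group(n) for n in nodes))

    folded_edges = []
    for src, dst in edges:
        g_src, g_dst = group(src), group(dst)
        # Drop edges that end up inside a single folded block
        if g_src != g_dst and (g_src, g_dst) not in folded_edges:
            folded_edges.append((g_src, g_dst))
    return folded_nodes, folded_edges


# Hypothetical scoped op names for two Sequential(Lin, ReLU) blocks:
# "<outer index>/<inner index>/<op>"
nodes = ["0/0/view", "0/0/lin", "0/1/relu", "1/0/view", "1/0/lin", "1/1/relu"]
edges = list(zip(nodes, nodes[1:]))  # a simple chain

folded_nodes, folded_edges = fold_by_scope(nodes, edges, depth=2)
print(folded_nodes)
print(folded_edges)
```

With `depth=2`, the `view` and `lin` ops of each `Lin` wrapper share a scope and collapse into one node, which is the "one block per custom layer" behaviour asked about above.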