HarisIqbal88 / PlotNeuralNet

Latex code for making neural networks diagrams
MIT License

Feature: Automated Creation Based on Example for PyTorch Linear Modules with ReLU Activations #126

Open git-thor opened 2 years ago

git-thor commented 2 years ago

Work in progress

The PR addresses #124

Automated diagram generation from a PyTorch module class (a `torch.nn.Module` child), leveraging the torchinfo architecture summary interface, which is comparable to the TensorFlow/Keras summary method.

image

This is created from the following code:

# Define example module
import torch as th

class MLP(th.nn.Module):

    def __init__(self):
        super(MLP, self).__init__()
        self.net = th.nn.Sequential(
            th.nn.Linear(2, 16),
            th.nn.ReLU(),
            th.nn.Linear(16, 16),
            th.nn.ReLU(),
            th.nn.Linear(16, 1)
        )

    def forward(self, x):
        x = self.net(x)
        return x

# Parse the example module
from pycore.torchparse import TorchArchParser
from pycore.tikzeng import to_generate

mlp = MLP()
parser = TorchArchParser(torch_module=mlp, input_size=(1,2))
arch = parser.get_arch()
to_generate(arch, pathname="./test_torch_mlp.tex")
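To make the role of `input_size=(1,2)` concrete, the sketch below traces by hand the shapes a summary tool would report for the example MLP. It is dependency-free and purely illustrative: `trace_shapes` and the tuple-based layer description are hypothetical helpers, not part of `pycore.torchparse` or of this PR.

```python
# Hypothetical stand-in for the shape bookkeeping behind an architecture
# summary; NOT the TorchArchParser implementation.

def trace_shapes(layers, input_size):
    """Return (layer kind, output shape) after each layer."""
    shape = tuple(input_size)
    shapes = []
    for kind, *args in layers:
        if kind == "Linear":
            in_features, out_features = args
            if shape[-1] != in_features:
                raise ValueError(
                    f"Linear expects last dim {in_features}, got {shape[-1]}"
                )
            # A linear layer only changes the last dimension.
            shape = shape[:-1] + (out_features,)
        elif kind == "ReLU":
            # Element-wise activation: shape is unchanged.
            pass
        else:
            raise ValueError(f"unsupported layer kind: {kind}")
        shapes.append((kind, shape))
    return shapes


# Same layer stack as the MLP example above.
mlp_layers = [
    ("Linear", 2, 16),
    ("ReLU",),
    ("Linear", 16, 16),
    ("ReLU",),
    ("Linear", 16, 1),
]

shapes = trace_shapes(mlp_layers, input_size=(1, 2))
for kind, shape in shapes:
    print(f"{kind:<8} -> {shape}")
```

The first entry of `input_size` is the batch dimension and the second must match the `in_features` of the first `Linear` layer, which is why `(1, 2)` fits this model.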

TODOs for subsequent PRs

Addressed #124 with respect to PyTorch

git-thor commented 2 years ago

Please feel free to give feedback and contribute to this feature, especially regarding the subsequent Keras support.

git-thor commented 1 year ago

Ready as initial functionality for PyTorch automated generation support. Please review and merge if deemed OK.

space192 commented 8 months ago

Hey guys, any plans to support conv layers?

git-thor commented 8 months ago

Hey @space192, I am currently busy and unable to push the PR further, but we would happily accept your extension to CNNs. You can open a PR against the corresponding branch of my fork so that it shows up here.