mert-kurttutan / torchview

torchview: visualize pytorch models
https://torchview.dev
MIT License

Support forward with multiple arguments #93

Open: joaolcguerreiro opened this issue 1 year ago

joaolcguerreiro commented 1 year ago

Imagine I have a module like this:

import torch.nn as nn

class Model(nn.Module):
    def __init__(self, generator, discriminator):
        super(Model, self).__init__()

        # Define Generator
        self.generator = generator

        # Define Discriminator
        self.discriminator = discriminator

    def forward(self, lr, hr):
        gen = self.generator(lr)

        return gen, self.discriminator(gen), self.discriminator(hr)

If I want to call draw_graph(model, input_size=..., depth=1), what should input_size look like? Is this supported?

I think draw_graph could accept input_size as a list, where each element describes one argument, so that forward receives as many arguments as there are elements in the list.

mert-kurttutan commented 1 year ago

Yes, it is supported. input_size is either a SizeItem or a list of SizeItems (list(SizeItem)), where a SizeItem is anything that can represent the shape of a torch Tensor, e.g. a tuple or a torch.Size.

mert-kurttutan commented 1 year ago

If this does not work for you, could you share your code?