TylerYep / torchinfo

View model summaries in PyTorch!
MIT License

Nested module not printing the output shape #299

Closed sup3rgiu closed 4 months ago

sup3rgiu commented 5 months ago

Describe the bug

In a nested-module scenario, the output shape of an nn.Module is not printed if the depth is not high enough.

[screenshot: summary output at depth=2; the nested modules' output shapes are not shown]

Increasing the depth shows the output shape of the nested layers, but this could be quite inconvenient if the model is deeply nested and I'm only interested in the output shape of the "main" module.

[screenshot: summary output at a higher depth, showing the nested layers' output shapes]

To Reproduce

import torch
import torch.nn as nn
from torchinfo import summary

class MyModule(nn.Module):
    def __init__(self):
        super().__init__()

        self.module = nn.ModuleList()
        for i in range(4):
            block = nn.ModuleList()
            for j in range(5):
                block.append(nn.Linear(10, 10))
            # Bare nn.Module used purely as a container; it has no forward()
            module = nn.Module()
            module.block = block
            self.module.append(module)

    def forward(self, x):
        for i in range(4):
            for j in range(5):
                x = self.module[i].block[j](x)
        return x

module = MyModule()

summary(module, input_size=[(3, 10)], dtypes=[torch.float32], depth=2)

Expected behavior

Always print the output shape of the most deeply nested layer/module that is visible at the selected depth.

allispaul commented 4 months ago

I think the problem is that the nn.Module submodules of that object don't have forward methods and are never themselves called; only their children are. This variant using nn.Sequential does what you want:

class MyModule(nn.Module):
    def __init__(self):
        super().__init__()

        self.module = nn.ModuleList()
        for i in range(4):
            block = nn.ModuleList()
            for j in range(5):
                block.append(nn.Linear(10, 10))
            # nn.Sequential has a forward(), so torchinfo records its output shape
            self.module.append(nn.Sequential(*block))

    def forward(self, x):
        for i in range(4):
            x = self.module[i](x)
        return x

as does defining a custom submodule class with a forward method and using that class for the submodules; see the sketch below. Does either of those resolve the issue?
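
For the second option, here is a minimal sketch of such a custom submodule class (the name Block and the 10x10 layers are placeholders, not anything from torchinfo):

class Block(nn.Module):
    # Container with an explicit forward(), so torchinfo can record its output
    def __init__(self, num_layers=5):
        super().__init__()
        self.block = nn.ModuleList([nn.Linear(10, 10) for _ in range(num_layers)])

    def forward(self, x):
        for layer in self.block:
            x = layer(x)
        return x

Replacing the bare nn.Module containers with Block instances should make summary report an output shape for each of them even at depth=2.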

sup3rgiu commented 4 months ago

Actually, the problem arose when trying to get a nice summary table for the VAE of Stable Diffusion, where, as you can see, they use this approach: https://github.com/Stability-AI/generative-models/blob/main/sgm/modules/diffusionmodules/model.py#L542

However, I see your point about the missing forward method, and indeed it would make no sense to compute an exact output shape for such an nn.Module.

Still, I think it might be useful to find a workaround that prints an "expected" output shape. Maybe we could set it to the output shape of the last-called submodule of such an nn.Module.

For instance, in the following example:

class MyModule(nn.Module):
    def __init__(self):
        super().__init__()

        self.module = nn.ModuleList()
        for i in range(4):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            for j in range(5):
                block.append(nn.Linear(10, 10))
                attn.append(nn.Linear(10, 10))
            # Bare nn.Module container holding two ModuleLists
            module = nn.Module()
            module.block = block
            module.attn = attn
            self.module.append(module)

    def forward(self, x):
        for i in range(4):
            for j in range(5):
                x = self.module[i].block[j](x)
                x = self.module[i].attn[j](x)
        return x

the output shape of the first nn.Module would be the output shape of self.module[0].attn[4](x), since .attn[4] is the last submodule called inside self.module[0]; the output shape of the second nn.Module would be that of self.module[1].attn[4](x); and so on.

But I don't know whether torchinfo could easily be extended to keep track of the last-called "submodule" of an nn.Module; a rough sketch of the idea in plain PyTorch is below.
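
For illustration only, outside of torchinfo: forward hooks can record the output shape of whichever leaf module fires last under a container (track_last_child_output and last_output_shape are made-up names, not existing torchinfo API):

import torch
import torch.nn as nn

def track_last_child_output(container):
    # Hypothetical helper: records the output shape of whichever leaf module
    # under `container` fires last during a forward pass.
    def hook(module, inputs, output):
        container.last_output_shape = tuple(output.shape)

    handles = []
    for leaf in container.modules():
        if leaf is not container and len(list(leaf.children())) == 0:
            handles.append(leaf.register_forward_hook(hook))
    return handles  # keep these so the hooks can be removed later

model = MyModule()
handles = track_last_child_output(model.module[0])
model(torch.randn(3, 10))
print(model.module[0].last_output_shape)  # (3, 10), set by attn[4], the last call

Whether exposing such an "estimated" shape in the summary would actually be a good idea is another question.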

allispaul commented 4 months ago

Personally, I'm not sure this would be desirable behavior. If an nn.Module is used purely as a container for submodules, as it is here, one can imagine the submodules being used in many different configurations, with different kinds of "output" being extracted from inside the nn.Module. Picking one of these as the "output shape" could be more confusing than the current behavior, which is at least clear and unambiguous (the module has no output and thus no "output shape").

But the actual maintainers should probably weigh in -- I'm just a guy looking for issues to work on :)

sup3rgiu commented 4 months ago

Yes, I agree with you that it's probably not the most desirable behavior. Even if the summary somehow marked these shapes as "estimated", it would probably still be confusing.

Since this is probably a non-issue that won't be fixed, we can close it as "Won't fix".

TylerYep commented 4 months ago

Agreed that this behavior would be ambiguous and hard to maintain. Thanks for the discussion; hopefully this thread helps answer future issues like this.