sksq96 / pytorch-summary

Model summary in PyTorch similar to `model.summary()` in Keras
MIT License

About the input_size of summary #177

Open UnBuen opened 3 years ago

UnBuen commented 3 years ago

When I use `summary(model, input_size)`, there is a problem with the parameter `input_size`: in general `input_size = (C, H, W)`, but graph datasets have nodes, labels, and edges, so this doesn't work.
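Concretely, a call like the sketch below fails, because torchsummary fabricates a dense dummy tensor from `input_size`, while my `forward` expects a PyG `data` object with `.x`, `.edge_index`, and `.batch` (the hyperparameters and the `(3,)` here are just illustrative guesses):

```python
from torchsummary import summary

# Placeholder hyperparameters; GNNStack is defined in model.py below.
model = GNNStack(input_dim=3, hidden_dim=32, output_dim=2, task='graph')

# torchsummary internally does roughly model(torch.rand(2, *input_size)).
# A plain tensor has no .x / .edge_index / .batch attributes, so
# GNNStack.forward raises an AttributeError here.
summary(model, input_size=(3,))
```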

I don't even need `summary` to print the size of each layer of the network; I only need to print the model structure.
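Something like the built-in module repr would already be enough; a minimal sketch (the `hidden_dim`/`output_dim` values are again placeholders):

```python
# nn.Module.__repr__ walks the registered submodules and prints the
# module tree, with no dummy input or input_size required.
model = GNNStack(input_dim=3, hidden_dim=32, output_dim=2, task='graph')
print(model)
```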

Code address: https://github.com/debadyuti23/GraphCovidNet. If needed, I can provide a processed dataset.


Please help me again; it may be due to my limited experience. I have tried many times but failed to solve this problem.

`model.py` is as follows:

```python

import torch.nn as nn
import torch.nn.functional as F
import torch_geometric.nn as pyg_nn
import torch
class GNNStack(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, task='node'):
        super(GNNStack, self).__init__()
        self.task = task
        self.convs = nn.ModuleList()
        self.convs.append(self.build_conv_model(input_dim, hidden_dim))
        self.lns = nn.ModuleList()
        self.lns.append(nn.LayerNorm(hidden_dim))
        self.lns.append(nn.LayerNorm(hidden_dim))
        for _ in range(2):    # two more conv layers, num_layers = 3 in total
            self.convs.append(self.build_conv_model(hidden_dim, hidden_dim))

        # post-message-passing
        self.post_mp = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim), nn.Dropout(0.5),
            nn.Linear(hidden_dim, output_dim))
        if self.task not in ('node', 'graph'):
            raise RuntimeError('Unknown task.')

        self.dropout = 0.5
        self.num_layers = 3

    def build_conv_model(self, input_dim, hidden_dim):
        # refer to pytorch geometric nn module for different implementation of GNNs.
        if self.task == 'node':
            return pyg_nn.GCNConv(input_dim, hidden_dim)
        else:
            return pyg_nn.GINConv(nn.Sequential(nn.Linear(input_dim, hidden_dim),
                                                nn.ReLU(), nn.Linear(hidden_dim, hidden_dim)))

    def forward(self, data):

        x, edge_index, batch = data.x, data.edge_index, data.batch

        # print("data:{}".format(data))
        """
            data:Batch(batch=[7127], edge_index=[2, 37816], x=[7127, 3], y=[1])
            data:Batch(batch=[2469], edge_index=[2, 12144], x=[2469, 3], y=[1])
            data:Batch(batch=[4846], edge_index=[2, 27550], x=[4846, 3], y=[1])
            data:Batch(batch=[4955], edge_index=[2, 19078], x=[4955, 3], y=[1])
            data:Batch(batch=[8360], edge_index=[2, 48258], x=[8360, 3], y=[1])
            ......
        """
        # print("data.num_node_features:", data.num_node_features)    # 3
        # fall back to constant node features if the graph has none
        if data.num_node_features == 0:
            x = torch.ones(data.num_nodes, 1, device=edge_index.device)

        for i in range(self.num_layers):    # num_layers = 3
            x = self.convs[i](x, edge_index)
            emb = x  # keep the pre-activation node embeddings to return later
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
            if i != self.num_layers - 1:
                x = self.lns[i](x)

        if self.task == 'graph':
            x = pyg_nn.global_mean_pool(x, batch)

        x = self.post_mp(x)

        return emb, F.log_softmax(x, dim=1), F.softmax(x, dim=1)

    def loss(self, pred, label):
        return F.nll_loss(pred, label)
```
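For context, a minimal sketch of how this model is actually called: it takes a `torch_geometric` `Batch`, not a `(C, H, W)` tensor, which is exactly why `summary`'s `input_size` can't describe it. The sizes and hyperparameters here (`hidden_dim=32`, `output_dim=2`, the 5-node toy graph) are made-up placeholders, not values from the real dataset:

```python
# Hypothetical smoke test: feed the model one tiny made-up graph.
# Assumes GNNStack from model.py above; the real GraphCovidNet
# data loader is not shown here.
import torch
from torch_geometric.data import Data, Batch

x = torch.ones(5, 3)                      # 5 nodes, 3 features each (matches x=[N, 3] above)
edge_index = torch.tensor([[0, 1, 2, 3],  # source nodes
                           [1, 2, 3, 4]]) # target nodes
batch = Batch.from_data_list([Data(x=x, edge_index=edge_index)])

model = GNNStack(input_dim=3, hidden_dim=32, output_dim=2, task='graph')
emb, log_probs, probs = model(batch)
print(log_probs.shape)                    # torch.Size([1, 2]) -- one graph, two classes
```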