pyg-team / pytorch_geometric

Graph Neural Network Library for PyTorch
https://pyg.org
MIT License
21.23k stars 3.64k forks source link

Size error while calling the forward method of hetero model #7852

Closed Amirmohammad-Bamdad closed 1 year ago

Amirmohammad-Bamdad commented 1 year ago

🐛 Describe the bug

Description: I've created a model using conv.LGConv and converted it to a heterogeneous model using to_hetero(). When I call the model with my batch's x_dict and edge_index_dict, an exception is raised. What should I do?

CODE:

from torch_geometric.nn.conv import LGConv

class GNN(torch.nn.Module):
    """Minimal module that applies a single ``LGConv`` layer.

    NOTE(review): according to the reply in this thread, ``LGConv`` does not
    support bipartite graphs, which is reported as the cause of the
    ``'tuple' object has no attribute 'size'`` error once this module is
    wrapped by ``to_hetero``. Confirm against the PyG GNN cheatsheet.
    """

    def __init__(self):
        super().__init__()
        self.conv = LGConv()

    def forward(self, x, edge_index):
        """Return the result of the convolution on ``x`` and ``edge_index``."""
        return self.conv(x, edge_index)

# Build the homogeneous model and print its structure.
model = GNN()
print(model)
print("===================")
# Rebind `model` to the to_hetero() conversion (built from data1's metadata
# with aggr='sum'), move it to `device`, and print the converted structure.
model = to_hetero(model, data1.metadata(), aggr='sum').to(device)
print(model)

@torch.no_grad()
def init_params():
    """Forward the first batch of ``train_loader`` through ``model`` once.

    Runs with gradients disabled. The batch's ``x_dict`` is printed (type and
    contents) before the forward pass as a debugging aid.
    """
    # Initialize lazy parameters via forwarding a single batch to the model:
    first_batch = next(iter(train_loader)).to(device, 'edge_index')
    features = first_batch.x_dict
    print('**', type(features), features)
    model(features, first_batch.edge_index_dict)

def train():
    """Run one optimisation pass over ``train_loader``.

    For every batch, only the first ``batch_size`` rows of the ``'user'``
    output and labels contribute to the cross-entropy loss.

    Returns:
        The sum of per-batch losses weighted by ``batch_size``, divided by
        the total number of examples counted.
    """
    model.train()

    seen = 0
    loss_sum = 0
    for batch in tqdm(train_loader):
        optimizer.zero_grad()
        batch = batch.to(device)
        n_seed = batch['user'].batch_size

        # Keep only the leading `n_seed` 'user' rows of the prediction.
        user_out = model(batch.x_dict, batch.edge_index_dict)['user']
        logits = user_out[:n_seed]
        targets = batch['user'].y[:n_seed]

        loss = F.cross_entropy(logits, targets)
        loss.backward()
        optimizer.step()

        seen += n_seed
        loss_sum += float(loss) * n_seed

    return loss_sum / seen

@torch.no_grad()
def test(loader):
    """Evaluate ``model`` on ``loader`` with gradients disabled.

    Args:
        loader: Iterable of batches; each batch is moved to ``device`` with
            the ``'edge_index'`` argument before the forward pass.

    Returns:
        Number of correct ``'user'`` predictions divided by the number of
        examples counted (the first ``batch_size`` rows of every batch).
    """
    model.eval()

    n_seen = 0
    n_hit = 0
    for batch in tqdm(loader):
        batch = batch.to(device, 'edge_index')
        n_seed = batch['user'].batch_size

        user_out = model(batch.x_dict, batch.edge_index_dict)['user']
        pred = user_out[:n_seed].argmax(dim=-1)
        labels = batch['user'].y[:n_seed]

        n_seen += n_seed
        n_hit += int((pred == labels).sum())

    return n_hit / n_seen

# Run a single forward pass before creating the optimizer over the model's
# parameters.
init_params()  # Initialize parameters.
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

# range(1, 2) yields exactly one epoch (epoch == 1).
for epoch in range(1,2):
    loss = train()
    # Validation is currently disabled, so nothing is printed after 'Val:'.
    #val_acc = test(val_loader)
    print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}, Val:')

ERROR:

AttributeError                            Traceback (most recent call last)
[c:\Users\Amir\Downloads\my_D.ipynb](file:///C:/Users/Amir/Downloads/my_D.ipynb) Cell 14 in 6
     [60](vscode-notebook-cell:/c%3A/Users/Amir/Downloads/my_D.ipynb#X23sZmlsZQ%3D%3D?line=59)         total_correct += int((pred == batch['user'].y[:batch_size]).sum())
     [62](vscode-notebook-cell:/c%3A/Users/Amir/Downloads/my_D.ipynb#X23sZmlsZQ%3D%3D?line=61)     return total_correct / total_examples
---> [65](vscode-notebook-cell:/c%3A/Users/Amir/Downloads/my_D.ipynb#X23sZmlsZQ%3D%3D?line=64) init_params()  # Initialize parameters.
     [66](vscode-notebook-cell:/c%3A/Users/Amir/Downloads/my_D.ipynb#X23sZmlsZQ%3D%3D?line=65) optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
     [68](vscode-notebook-cell:/c%3A/Users/Amir/Downloads/my_D.ipynb#X23sZmlsZQ%3D%3D?line=67) for epoch in range(1,2):

File [c:\Users\Amir\AppData\Local\Programs\Python\Python39\lib\site-packages\torch\utils\_contextlib.py:115](file:///C:/Users/Amir/AppData/Local/Programs/Python/Python39/lib/site-packages/torch/utils/_contextlib.py:115), in context_decorator..decorate_context(*args, **kwargs)
    112 @functools.wraps(func)
    113 def decorate_context(*args, **kwargs):
    114     with ctx_factory():
--> 115         return func(*args, **kwargs)

[c:\Users\Amir\Downloads\my_D.ipynb](file:///C:/Users/Amir/Downloads/my_D.ipynb) Cell 14 in 9
      [7](vscode-notebook-cell:/c%3A/Users/Amir/Downloads/my_D.ipynb#X23sZmlsZQ%3D%3D?line=6) batch = batch.to(device, 'edge_index')
      [8](vscode-notebook-cell:/c%3A/Users/Amir/Downloads/my_D.ipynb#X23sZmlsZQ%3D%3D?line=7) print('**', type(batch.x_dict), batch.x_dict)
----> [9](vscode-notebook-cell:/c%3A/Users/Amir/Downloads/my_D.ipynb#X23sZmlsZQ%3D%3D?line=8) model(batch.x_dict, batch.edge_index_dict)

File [c:\Users\Amir\AppData\Local\Programs\Python\Python39\lib\site-packages\torch\fx\graph_module.py:662](file:///C:/Users/Amir/AppData/Local/Programs/Python/Python39/lib/site-packages/torch/fx/graph_module.py:662), in GraphModule.recompile..call_wrapped(self, *args, **kwargs)
    661 def call_wrapped(self, *args, **kwargs):
--> 662     return self._wrapped_call(self, *args, **kwargs)

File [c:\Users\Amir\AppData\Local\Programs\Python\Python39\lib\site-packages\torch\fx\graph_module.py:281](file:///C:/Users/Amir/AppData/Local/Programs/Python/Python39/lib/site-packages/torch/fx/graph_module.py:281), in _WrappedCall.__call__(self, obj, *args, **kwargs)
...
     42                        add_self_loops=False, flow=self.flow, dtype=x.dtype)
     43         edge_index, edge_weight = out
     44     elif self.normalize and isinstance(edge_index, SparseTensor):

AttributeError: 'tuple' object has no attribute 'size'

Environment

EdisonLeeeee commented 1 year ago

LGConv does not support bipartite graphs. You can refer to the GNN cheatsheet for alternative methods that do support them.

Amirmohammad-Bamdad commented 1 year ago

> LGConv does not support bipartite graphs. You can refer to the GNN cheatsheet for alternative methods that do support them.

Yes, that was it. Thanks for your help.