dmlc / dgl

Python package built to ease deep learning on graphs, on top of existing DL frameworks.
http://dgl.ai
Apache License 2.0
13.54k stars 3.01k forks source link

DGLError Expected data to have %d rows, got %d. occurs at large batch size #4512

Open sidazhou opened 2 years ago

sidazhou commented 2 years ago

🐛 Bug

DGLError('Expected data to have %d rows, got %d.') occurs at a large batch_size and doesn't occur at smaller batch sizes. The larger the batch_size, the larger the difference in rows. It feels like a rounding error somewhere.

To Reproduce

# Reproduction: sample two layers of MFGs for the first BATCH_SIZE seeds and run
# the model on them.
# NOTE(review): train_pos_g, seed_ids and model come from the reporter's
# notebook and are not shown here; model is presumably an instance of the Model
# class shown further down in this report.
BATCH_SIZE = 1000 #  <---- works fine
# BATCH_SIZE = 5000 # <---- DGL errors

# Two sampling layers, each with a fan-out of 4 neighbors.
sampler = dgl.dataloading.NeighborSampler([4, 4])
# NOTE(review): later comments in this issue trace the failure to duplicate IDs
# in the seed slice. A larger slice presumably makes a duplicate more likely,
# which would explain why only the larger batch size fails.
_, _, mfgs = sampler.sample_blocks(train_pos_g, seed_ids[:BATCH_SIZE])

print(mfgs[0].srcdata['feat'].shape)
# torch.Size([10239, 128]) <---- works fine
# torch.Size([48913, 128]) <---- DGL errors

model(mfgs, mfgs[0].srcdata['feat']) # <---- errors

Expected behavior

The model should run without raising a DGLError.

Environment

Additional context

model:

# The model is the default two-layer GraphSAGE model from the tutorials.
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn import SAGEConv
class Model(nn.Module):
    """Two-layer GraphSAGE network (mean aggregation) applied to a list of MFGs.

    Args:
        in_feats: width of the input node features.
        h_feats: width of the hidden layer and of the output.
    """

    def __init__(self, in_feats, h_feats):
        super().__init__()
        # Layer 1 maps input -> hidden; layer 2 maps hidden -> hidden.
        self.conv1 = SAGEConv(in_feats, h_feats, aggregator_type='mean')
        self.conv2 = SAGEConv(h_feats, h_feats, aggregator_type='mean')
        self.h_feats = h_feats

    def forward(self, mfgs, x):
        """Run both SAGE layers and return the output of the second.

        Args:
            mfgs: sequence of message-flow graphs; only ``mfgs[0]`` and
                ``mfgs[1]`` are used, in that order.
            x: features of the source nodes of ``mfgs[0]``.

        Returns:
            The result of ``conv2`` on ``mfgs[1]``.

        The destination features of each layer are taken as the leading
        ``num_dst_nodes()`` rows of its source features. This relies on the MFG
        listing its destination nodes first among its source nodes.
        NOTE(review): per the discussion in this issue, duplicate seed IDs can
        make ``num_dst_nodes()`` disagree with the MFG's destination frame,
        which raises a DGLError inside SAGEConv. Deduplicate seeds before
        sampling.
        """
        block = mfgs[0]
        hidden = F.relu(self.conv1(block, (x, x[:block.num_dst_nodes()])))
        block = mfgs[1]
        return self.conv2(block, (hidden, hidden[:block.num_dst_nodes()]))

Error stack:

    ---------------------------------------------------------------------------
    DGLError                                  Traceback (most recent call last)
    Input In [58], in <cell line: 13>()
          9 print(mfgs[0].srcdata['feat'].shape)
         10 # torch.Size([10239, 128]) works fine
         11 # torch.Size([48913, 128]) DGL errors
    ---> 13 model(mfgs, mfgs[0].srcdata['feat'])

    File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1110, in Module._call_impl(self, *input, **kwargs)
       1106 # If we don't have any hooks, we want to skip the rest of the logic in
       1107 # this function, and just call forward.
       1108 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
       1109         or _global_forward_hooks or _global_forward_pre_hooks):
    -> 1110     return forward_call(*input, **kwargs)
       1111 # Do not call functions when jit is used
       1112 full_backward_hooks, non_full_backward_hooks = [], []

    Input In [1], in Model.forward(self, mfgs, x)
        101 h = F.relu(h)
        102 h_dst = h[:mfgs[1].num_dst_nodes()]
    --> 103 h = self.conv2(mfgs[1], (h, h_dst))
        104 return h

    File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1110, in Module._call_impl(self, *input, **kwargs)
       1106 # If we don't have any hooks, we want to skip the rest of the logic in
       1107 # this function, and just call forward.
       1108 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
       1109         or _global_forward_hooks or _global_forward_pre_hooks):
    -> 1110     return forward_call(*input, **kwargs)
       1111 # Do not call functions when jit is used
       1112 full_backward_hooks, non_full_backward_hooks = [], []

    File /opt/conda/lib/python3.10/site-packages/dgl/nn/pytorch/conv/sageconv.py:235, in SAGEConv.forward(self, graph, feat, edge_weight)
        233 if self._aggre_type == 'mean':
        234     graph.srcdata['h'] = self.fc_neigh(feat_src) if lin_before_mp else feat_src
    --> 235     graph.update_all(msg_fn, fn.mean('m', 'neigh'))
        236     h_neigh = graph.dstdata['neigh']
        237     if not lin_before_mp:

    File /opt/conda/lib/python3.10/site-packages/dgl/heterograph.py:4900, in DGLHeteroGraph.update_all(self, message_func, reduce_func, apply_node_func, etype)
       4898         key = list(ndata.keys())[0]
       4899         ndata[key] = F.replace_inf_with_zero(ndata[key])
    -> 4900     self._set_n_repr(dtid, ALL, ndata)
       4901 else:   # heterogeneous graph with number of relation types > 1
       4902     if not core.is_builtin(message_func) or not core.is_builtin(reduce_func):

    File /opt/conda/lib/python3.10/site-packages/dgl/heterograph.py:4136, in DGLHeteroGraph._set_n_repr(self, ntid, u, data)
       4132         raise DGLError('Pinned graph requires the node data to be pinned as well. '
       4133                        'Please pin the node data before assignment.')
       4135 if is_all(u):
    -> 4136     self._node_frames[ntid].update(data)
       4137 else:
       4138     self._node_frames[ntid].update_row(u, data)

    File /opt/conda/lib/python3.10/_collections_abc.py:994, in MutableMapping.update(self, other, **kwds)
        992 if isinstance(other, Mapping):
        993     for key in other:
    --> 994         self[key] = other[key]
        995 elif hasattr(other, "keys"):
        996     for key in other.keys():

    File /opt/conda/lib/python3.10/site-packages/dgl/frame.py:584, in Frame.__setitem__(self, name, data)
        574 def __setitem__(self, name, data):
        575     """Update the whole column.
        576 
        577     Parameters
       (...)
        582         The column data.
        583     """
    --> 584     self.update_column(name, data)

    File /opt/conda/lib/python3.10/site-packages/dgl/frame.py:661, in Frame.update_column(self, name, data)
        659 col = Column.create(data)
        660 if len(col) != self.num_rows:
    --> 661     raise DGLError('Expected data to have %d rows, got %d.' %
        662                    (self.num_rows, len(col)))
        663 self._columns[name] = col

    DGLError: Expected data to have 5000 rows, got 4998.
sidazhou commented 2 years ago

image

I narrowed it down to a mismatch between output_nodes.shape and mfgs[1].dstnodes().shape (100 vs. 99). Why does this happen? Surely it's a bug?

sidazhou commented 2 years ago

So the issue seems to occur when seed_nodes contains duplicate IDs. Is this a bug or a feature?

rudongyu commented 2 years ago

It seems that to_block removes duplicate nodes during sampling, which causes an inconsistency between the number of destination nodes and the size of the destination node features. @BarclayII I think we should check for possible duplicates in seed_nodes before sampling. What do you think?

sidazhou commented 2 years ago

Surely it's a bug, right? The dataloader yields MFGs that cannot be used as input to model().

RManLuo commented 2 years ago

Hi, I am also facing this problem. The seed_nodes I pass in contain some duplicate IDs, but I need the embeddings for these duplicates. Is there any workaround for now?

RManLuo commented 2 years ago

I tried using dgl.dataloading.MultiLayerFullNeighborSampler to sample blocks for a set of seed_nodes that contains duplicate items. If I sample them on the CPU, the returned MFG contains inconsistent results. However, if I sample them on the GPU, the duplicate seed nodes are not removed. I think sampling results should be the same on different devices.

# Reproduction: the same duplicated seed (node 8 given twice) produces different
# MFGs depending on whether sampling runs on the CPU or the GPU.
import torch
import dgl

# Directed graph over node IDs 0..11. The second half of `src` equals the first
# half of `dst` and vice versa, so every edge also appears reversed.
src = torch.LongTensor(
    [0, 0, 0, 1, 2, 2, 2, 3, 3, 4, 4, 5, 5, 6, 7, 7, 8, 9, 10,
     1, 2, 3, 3, 3, 4, 5, 5, 6, 5, 8, 6, 8, 9, 8, 11, 11, 10, 11])
dst = torch.LongTensor(
    [1, 2, 3, 3, 3, 4, 5, 5, 6, 5, 8, 6, 8, 9, 8, 11, 11, 10, 11,
     0, 0, 0, 1, 2, 2, 2, 3, 3, 4, 4, 5, 5, 6, 7, 7, 8, 9, 10])
g = dgl.graph((src, dst))

# Two sampling layers; this sampler presumably keeps every neighbor (no fan-out
# limit).
sampler = dgl.dataloading.MultiLayerFullNeighborSampler(2)

# Sample on CPU, with node 8 passed twice as a seed.
idx = torch.LongTensor([8,8])
src_nodes, dst_nodes, mfgs = sampler.sample_blocks(g, idx)
print(dst_nodes) # tensor([8, 8])
print(mfgs[-1].num_dst_nodes()) # 1
print(mfgs[-1].dstdata) # {'_ID': tensor([8, 8])}
# Inconsistent: the last MFG reports 1 destination node, but its dstdata '_ID'
# column still has 2 rows.

# Sample on GPU (requires a CUDA device 'cuda:0').
device = torch.device('cuda:0')
src_nodes, dst_nodes, mfgs = sampler.sample_blocks(g.to(device), idx.to(device))
print(dst_nodes) # tensor([8, 8], device='cuda:0')
print(mfgs[-1].num_dst_nodes()) # 2
print(mfgs[-1].dstdata) # {'_ID': tensor([8, 8], device='cuda:0')}
# Consistent: the duplicate seed is kept, so num_dst_nodes() (2) matches the
# 2 rows of dstdata '_ID'.
FAF-D2 commented 2 years ago

By the way, it seems that sampling on the GPU still can't solve the problem of duplicate nodes in heterographs.

# Reproduction: duplicate seed nodes on a heterograph still break a two-layer
# HeteroGraphConv model, even when sampling on the GPU.
import torch
import dgl
import dgl.nn.pytorch as dglnn

# Same src/dst lists as the homogeneous example above: the second half of each
# list mirrors the first half.
src = torch.LongTensor(
    [0, 0, 0, 1, 2, 2, 2, 3, 3, 4, 4, 5, 5, 6, 7, 7, 8, 9, 10,
     1, 2, 3, 3, 3, 4, 5, 5, 6, 5, 8, 6, 8, 9, 8, 11, 11, 10, 11])
dst = torch.LongTensor(
    [1, 2, 3, 3, 3, 4, 5, 5, 6, 5, 8, 6, 8, 9, 8, 11, 11, 10, 11,
     0, 0, 0, 1, 2, 2, 2, 3, 3, 4, 4, 5, 5, 6, 7, 7, 8, 9, 10])
# 'plays' connects user -> game using the lists above; 'follows' adds four
# user -> user edges (0->5, 1->6, 2->7, 3->8).
graph_data = {
        ('user', 'plays', 'game') : (src, dst),
        ('user', 'follows', 'user'): (torch.LongTensor([0, 1, 2, 3]), torch.LongTensor([5, 6, 7, 8]))
    }
g = dgl.heterograph(graph_data)
# Constant all-ones 16-dim input features for every node of both types.
g.nodes['user'].data['h'] = torch.ones(g.num_nodes('user'), 16)
g.nodes['game'].data['h'] = torch.ones(g.num_nodes('game'), 16)

sampler = dgl.dataloading.MultiLayerFullNeighborSampler(2)

# Seeds: user nodes 0, 2 and 4, each given twice.
uid = torch.LongTensor([0, 0, 2, 2, 4, 4])
# Requires a CUDA device 'cuda:0'.
device = torch.device('cuda:0')
src_nodes, dst_nodes, mfgs = sampler.sample_blocks(g.to(device), {'user': uid.to(device)})
print(dst_nodes)  # {'user': tensor([0, 0, 2, 2, 4, 4], device='cuda:0')}
print(mfgs[-1].num_dst_nodes()) # 6

# Two HeteroGraphConv layers (16 -> 32 -> 32), one 'gcn' SAGEConv per relation,
# with the per-relation results combined by 'sum'.
conv1 = dglnn.HeteroGraphConv({
                    'plays': dglnn.SAGEConv(16, 32, 'gcn'),
                    'follows': dglnn.SAGEConv(16, 32, 'gcn')
                }, 'sum').to(device)
conv2 = dglnn.HeteroGraphConv({
                    'plays': dglnn.SAGEConv(32, 32, 'gcn'),
                    'follows': dglnn.SAGEConv(32, 32, 'gcn')
                }, 'sum').to(device)

# Input features of the first MFG's source nodes, keyed by node type.
out = mfgs[0].srcdata['h']

print(mfgs[0].num_dst_nodes()) # 3
print(len(out['game'])) # 0
print(len(out['user'])) # 3

out = conv1(mfgs[0], out)

print(mfgs[1].num_dst_nodes()) # 6
print(len(out['game'])) # 0
print(len(out['user'])) # 3

# NOTE(review): presumably fails because mfgs[1] expects 6 destination nodes
# while the layer-1 output has only 3 'user' rows (see the prints above).
out = conv2(mfgs[1], out) # Error
czkkkkkk commented 2 years ago

Sorry, the sampler currently doesn't support duplicate values in the seed nodes. We've added this to our backlog, to be prioritized against other feature requests on our roadmap.

yueliu1999 commented 1 year ago

I also ran into this error on the paper100M dataset. Has this bug been fixed yet? Are there any other potential workarounds?

czkkkkkk commented 1 year ago

This hasn't been solved yet. We suggest that users explicitly deduplicate seed nodes.