pyg-team / pytorch_geometric

Graph Neural Network Library for PyTorch
https://pyg.org
MIT License

DataLoader is slow #802

Open LeeHuan18 opened 4 years ago

LeeHuan18 commented 4 years ago

Hi, I generated my own graphs and followed the documentation instructions for the DataLoader.

from torch_geometric.data import Data, DataLoader

data_list = [Data(...), ..., Data(...)]
loader = DataLoader(data_list, batch_size=32)

But with this loader, the model trains very slowly and GPU utilization is only 2%. I suspect the DataLoader might be the bottleneck?

So I tried a simple DNN with this DataLoader; it is still slow and the GPU usage is still 2%.

import torch
import torch.nn as nn
import torch.nn.functional as F


class Net(torch.nn.Module):
    def __init__(self, in_features=7, num_classes=2):
        super(Net, self).__init__()
        # Each graph has 4 nodes with `in_features` features, flattened below.
        self.fc = nn.Sequential(
            nn.Linear(in_features * 4, 2**8),
            nn.BatchNorm1d(2**8),
            nn.ReLU(),
            nn.Linear(2**8, num_classes),
        )

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = x.view(-1, 28)  # flatten the 4 nodes x 7 features of each graph
        x = self.fc(x)
        return F.softmax(x, dim=1)

So is this a problem with the DataLoader, or is there another possible reason?

rusty1s commented 4 years ago

I never witnessed this myself and always thought the DataLoader is fast as f***. Since your list is already holding data objects, this can only be caused by Batch.from_data_list. Can you actually measure the time consumption for a call such as:

torch_geometric.data.Batch.from_data_list(data_list[:32])

In addition, it would be great if you could send some details about your data objects. How big/small are your graphs?
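
For reference, a minimal sketch of such a measurement, assuming tiny dummy graphs shaped like the batch printed in the reply below (4 nodes, 7 node features, no edges):

import time

import torch
from torch_geometric.data import Batch, Data

# Dummy graphs matching the shapes reported below: 4 nodes, 7 features, 0 edges.
data_list = [
    Data(x=torch.randn(4, 7),
         edge_index=torch.empty(2, 0, dtype=torch.long),
         y=torch.tensor([0]))
    for _ in range(32)
]

start = time.perf_counter()
batch = Batch.from_data_list(data_list)
print(f'from_data_list took {(time.perf_counter() - start) * 1e3:.2f} ms')
print(batch)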

LeeHuan18 commented 4 years ago

@rusty1s Using %%time, the result is:

CPU times: user 6.99 ms, sys: 153 µs, total: 7.15 ms
Wall time: 22.9 ms
Batch(batch=[128], edge_index=[2, 0], x=[128, 7], y=[32])

I don't have a sense of whether that is fast or not...

About the graphs: as you can see above, each graph has 0 edges and 4 nodes, and each node is a 1-D array with 7 features. There is also an additional attribute y holding the graph label. So each graph is very small.

rusty1s commented 4 years ago

I think it is quite understandable that your GPU usage is low with such small graphs and a batch_size of 32.

LeeHuan18 commented 4 years ago

So I increased the batch_size to 512, but it doesn't improve. Could you please give a hint on how to improve it?

SaschaStenger commented 4 years ago

I asked a question that might be related to yours. I had high CPU load and low GPU load, and was able to fix it by changing the number of workers for the DataLoader (see the linked issue).
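
For illustration, a rough sketch of that setting; the worker count, batch size, and pin_memory flag here are just values to tune, not a recipe:

from torch_geometric.loader import DataLoader  # torch_geometric.data.DataLoader in older versions

# Worker processes move sample fetching and collation off the main training
# process, so the GPU spends less time waiting for the next batch.
loader = DataLoader(
    data_list,
    batch_size=512,
    shuffle=True,
    num_workers=4,     # tune to the number of available CPU cores
    pin_memory=True,   # can speed up host-to-GPU transfers
)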

LeeHuan18 commented 4 years ago

@SaschaStenger Thank you very much! Let me try.

ghorbanimahdi73 commented 2 years ago

Hi. I have a graph dataset class that overrides __getitem__, where each datapoint consists of one anchor graph, 10 randomly selected graphs from the dataset, and their label distances, returned as a dictionary. During training, the DataLoader becomes very slow, roughly in proportion to the number of graph neighbors (10 here). Below is a snapshot of the code. To build the list of neighbor Data objects I use the get() function. Is there a way around this without using get()?

import numpy as np
import torch
from torch_geometric.data import InMemoryDataset


class GraphTripleDataset(InMemoryDataset):
    def __init__(self,
                 args,
                 transform=None,
                 pre_transform=None,
                 mode='train',
                 base_dir=None,
                 n_neighbors=10):
        ...

    def _triplet_mining(self, anchor_idx):
        # Sample `n_neighbors` random graphs and compute their label
        # distances to the anchor graph.
        avail_idxs = range(len(self.data_list))
        n_idxs = np.random.randint(0, len(avail_idxs), size=self.n_neighbors)
        neighbors = [self.get(i) for i in n_idxs]
        neighbor_distances = self.compute_distances_from_anchor(
            anchor_idx, neighbors_idx=n_idxs)
        return neighbors, torch.tensor(neighbor_distances)

    def __getitem__(self, idx):
        if (isinstance(idx, (int, np.integer))
                or (isinstance(idx, torch.Tensor) and idx.dim() == 0)
                or (isinstance(idx, np.ndarray) and np.isscalar(idx))):
            anchor = self.get(self.indices()[idx])
        else:
            anchor = self.index_select(idx)

        neighbors, distances = self._triplet_mining(idx)

        return {
            'anchor': anchor,
            'neighbors': neighbors,
            'distances': distances,
        }

rusty1s commented 2 years ago

Your code looks good to me. Do you know which part is slow: getting all the data objects within __getitem__, or collating them together in the DataLoader?

ghorbanimahdi73 commented 2 years ago

It is __getitem__ that makes it slow. Since every datapoint has 10 neighbors, the data loading is 10 times slower when I also return the list of neighbors. I am using the get() function to retrieve each data object. I tried using index_select(idx), giving it an array of indices as idx, but that seems to return the full dataset rather than the selected indices, since calling dataset.index_select(idx).data shows all the data points.

rusty1s commented 2 years ago

Since you are requesting 10 more objects, it is expected that the data access takes 10 times longer - I don't think there is much we can do here besides improving the speed of __getitem__ and batching overall, or leveraging more workers. I still think that the main bottleneck should be in the collate function though (which does the heavy work of merging these objects into one), so it might be good to confirm that this is indeed the case.

ghorbanimahdi73 commented 2 years ago

So, if I understand correctly, collate only affects the preprocessing of the data, merging it into one huge Data object. Does it also affect batching during data loading? In that case, do I need to redefine the collate function?

rusty1s commented 2 years ago

So the DataLoader does two things: (1) it gets all graphs/examples that we want to batch together (the __getitem__ call), and (2) it batches these graphs together into one huge data object. Usually, step (2) is more expensive, since (1) is just a simple list access without any real computation taking place.
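
To make the two steps concrete, here is a simplified sketch (not the actual library implementation) that also times the two steps separately, assuming a dataset whose __getitem__ returns Data objects:

import time

from torch_geometric.data import Batch

def load_batch(dataset, indices):
    t0 = time.perf_counter()
    samples = [dataset[i] for i in indices]   # (1) fetch the individual examples
    t1 = time.perf_counter()
    batch = Batch.from_data_list(samples)     # (2) merge them into one Data object
    t2 = time.perf_counter()
    print(f'fetch: {(t1 - t0) * 1e3:.1f} ms, collate: {(t2 - t1) * 1e3:.1f} ms')
    return batch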

ghorbanimahdi73 commented 2 years ago

Sure, that is right, collating takes time. In this case, should I avoid collating and instead load individual graphs from a GPU scratch directory, or is there another way to redefine the collate function?

rusty1s commented 2 years ago

Can you clarify what you mean by re-defining the collate function? I can look into any bottlenecks of the collate function. Do you have a small example to reproduce the bottleneck?

AneesBKazi commented 1 year ago

Hi, I am facing a similar problem. Each of my graphs in a graph classification problem is of size 3k x 3k. Each epoch takes 6s. When profiling with snakeviz, I see that dataloader: next --> fetch takes the most time. Any idea how to fix this, @rusty1s?

rusty1s commented 1 year ago

Do you have a reproducible example? What num_workers configuration are you using?

OliEfr commented 5 months ago

To add my two pence: I had a similar problem (slow collate), and I solved it by saving the batched data to disk (i.e., collating offline before training) and then loading the pre-batched samples directly in a custom dataset. See the docs here.

In that case, I created the batches using a PyG DataLoader and Dataset, iterated through this DataLoader, and saved all batches to disk. During training, I then use a custom PyTorch Dataset that loads the batches from disk in __getitem__ and returns them. For the PyTorch DataLoader, simply set batch_size=None to disable automatic batching.
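
A rough sketch of this pattern, where dataset, the batches/ directory, and the batch size are placeholder assumptions:

import glob

import torch
from torch_geometric.loader import DataLoader as PyGDataLoader

# Offline step: collate once with the PyG DataLoader and save each pre-batched
# Batch object to disk.
loader = PyGDataLoader(dataset, batch_size=512, shuffle=True)
for i, batch in enumerate(loader):
    torch.save(batch, f'batches/batch_{i:05d}.pt')

# Training step: a plain PyTorch Dataset that only loads the pre-collated batches.
class PreBatchedDataset(torch.utils.data.Dataset):
    def __init__(self, pattern='batches/batch_*.pt'):
        self.paths = sorted(glob.glob(pattern))

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        # Newer PyTorch versions may require torch.load(..., weights_only=False)
        # to unpickle full Batch objects.
        return torch.load(self.paths[idx])

# batch_size=None disables automatic batching, so each item is already a full Batch.
train_loader = torch.utils.data.DataLoader(PreBatchedDataset(), batch_size=None, num_workers=2)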