pyg-team / pytorch_geometric

Graph Neural Network Library for PyTorch
https://pyg.org
MIT License

Does GraphSAGE support multi-GPUs training? #1447

Closed maqy1995 closed 4 years ago

maqy1995 commented 4 years ago

❓ Questions & Help

There is a multi-GPU implementation of GraphSAGE in DGL: https://github.com/dmlc/dgl/blob/master/examples/pytorch/graphsage/train_sampling_multi_gpu.py

The example uses DistributedDataParallel, and each process samples its own subgraphs.

Can PyG also implement multi-GPU training in this way?

rusty1s commented 4 years ago

DistributedDataParallel support is planned for PyG 1.7 or PyG 1.8, so please stay tuned :)

Just out of curiosity: For what use-case do you desperately need multi-GPU support? IMO, you can train large graphs, e.g., the ones from OGB, already in like 10 minutes.

maqy1995 commented 4 years ago

Thank you for your quick reply; I look forward to the new features of PyG.

I don't have an actual use case for multi-GPU training. In fact, I only started learning PyTorch and GNNs recently, along with PyG and DGL. I am currently studying GraphSAGE/GAT, and a natural learning path goes from single-machine single-GPU, to single-machine multi-GPU, to multi-machine multi-GPU. Today I saw a multi-GPU GraphSAGE example in DGL and tried to reimplement it in PyG, but it did not run successfully, which is why I asked this question.

What puzzles me is why PyG cannot currently use DistributedDataParallel for multi-GPU training. Judging from the DGL multi-GPU example, only native PyTorch facilities are used (torch.multiprocessing and DistributedDataParallel), and each process has its own sampler and dataloader to generate the subgraphs used in mini-batch training. Can't PyG use the same approach? Are there other details that I have missed? Looking forward to your reply.

rusty1s commented 4 years ago

Yes, the idea would be to utilize torch.multiprocessing and DistributedDataParallel from PyTorch, and have each client sample its own data. I doubt that this is hard to implement; however, I haven't tested it yet, so it's definitely unsupported at the moment. I will look into it in the upcoming weeks and see what needs to be done to allow distributed training :)
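One way to picture "each client samples its own data" is to give every process a disjoint slice of the training node indices and build a sampler over that slice only. The helper below is a minimal sketch of just the slicing step. `shard_train_nodes` is a hypothetical name, not a PyG or PyTorch function, and it assumes a boolean training mask.

```python
import torch


def shard_train_nodes(train_mask: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
    """Return the training node indices assigned to process `rank`.

    Hypothetical helper for illustration; not part of PyG or PyTorch.
    """
    train_idx = train_mask.nonzero(as_tuple=False).view(-1)
    # Drop the remainder so that every rank gets the same number of nodes
    # and therefore runs the same number of mini-batches per epoch.
    per_rank = train_idx.numel() // world_size
    return train_idx[rank * per_rank:(rank + 1) * per_rank]


if __name__ == "__main__":
    mask = torch.tensor([True, False, True, True, False, True, True])
    for rank in range(2):
        print(rank, shard_train_nodes(mask, rank, 2).tolist())
```

Keeping the shards equal in size is a deliberate choice here: DistributedDataParallel synchronizes gradients during the backward pass, so processes that run different numbers of iterations can end up waiting on each other.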

maqy1995 commented 4 years ago

Today I kept trying, and single-machine multi-GPU training now seems to run successfully (I modified this code to use the Reddit dataset: https://github.com/rusty1s/pytorch_geometric/blob/master/examples/ogbn_products_gat.py).

Compared with the multi-GPU training implemented in DGL, the main difference is:

    if n_gpus == 1:
        run(0, n_gpus, args, devices, data)
    else:
        procs = []
        for proc_id in range(n_gpus):
            p = mp.Process(target=thread_wrapped_func(run),
                           args=(proc_id, n_gpus, args, devices, data))
            p.start()
            procs.append(p)
        for p in procs:
            p.join()

should be changed to the following (this is where I went wrong yesterday):

    if n_gpus == 1:
        run(0, n_gpus, args, devices, data)
    else:
        mp.spawn(run, args=(n_gpus, args, devices, data), nprocs=n_gpus, join=True)

However, I don't know whether there are other hidden bugs.
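The change above relies on the calling convention of `torch.multiprocessing.spawn`: it calls the target as `fn(i, *args)`, where `i` is the process index, so the index must not be repeated in `args`. The following is a small sketch of that convention only; `run`, `worker`, `launch`, and the `tag` argument are illustrative names, and no GPUs or process groups are involved.

```python
import torch.multiprocessing as mp


def run(proc_id, n_gpus, tag):
    # With mp.spawn, proc_id is supplied automatically (0 .. nprocs - 1);
    # only n_gpus and tag come from the `args` tuple.
    return f"{tag}: process {proc_id} of {n_gpus}"


def worker(proc_id, n_gpus, tag):
    print(run(proc_id, n_gpus, tag))


def launch(n_gpus, tag="demo"):
    if n_gpus == 1:
        # Single-process path: pass the index explicitly.
        worker(0, n_gpus, tag)
    else:
        # No proc_id in `args`; spawn prepends it for each child process.
        mp.spawn(worker, args=(n_gpus, tag), nprocs=n_gpus, join=True)
```

To try the multi-process path, call `launch(2)` from inside an `if __name__ == "__main__":` block, since the spawn start method re-imports the main module in every child process.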

maqy1995 commented 4 years ago

The complete code is below; if there are any bugs, please tell me.

# -*- coding: utf-8 -*-
"""
@Author: maqy
@Time: 2020/7/16
@Description:
"""
import argparse
import os.path as osp

import torch
import torch.nn.functional as F
import time
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel
import numpy as np
from torch.nn import Linear as Lin
from tqdm import tqdm
from torch_geometric.data import NeighborSampler
from torch_geometric.nn import GATConv
from torch_geometric.datasets import Reddit
# from utils import thread_wrapped_func

class GAT(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, num_layers,
                 heads):
        super(GAT, self).__init__()

        self.num_layers = num_layers
        self.convs = torch.nn.ModuleList()
        self.convs.append(GATConv(in_channels, hidden_channels, heads))
        for _ in range(num_layers - 2):
            self.convs.append(
                GATConv(heads * hidden_channels, hidden_channels, heads))
        self.convs.append(
            GATConv(heads * hidden_channels, out_channels, heads,
                    concat=False))

        # residual
        self.skips = torch.nn.ModuleList()
        self.skips.append(Lin(in_channels, hidden_channels * heads))
        for _ in range(num_layers - 2):
            self.skips.append(
                Lin(hidden_channels * heads, hidden_channels * heads))
        self.skips.append(Lin(hidden_channels * heads, out_channels))

    def reset_parameters(self):
        for conv in self.convs:
            conv.reset_parameters()
        for skip in self.skips:
            skip.reset_parameters()

    def forward(self, x, adjs):
        # `train_loader` computes the k-hop neighborhood of a batch of nodes,
        # and returns, for each layer, a bipartite graph object, holding the
        # bipartite edges `edge_index`, the index `e_id` of the original edges,
        # and the size/shape `size` of the bipartite graph.
        # Target nodes are also included in the source nodes so that one can
        # easily apply skip-connections or add self-loops.
        for i, (edge_index, _, size) in enumerate(adjs):
            x_target = x[:size[1]]  # Target nodes are always placed first.
            x = self.convs[i]((x, x_target), edge_index)
            x = x + self.skips[i](x_target)
            if i != self.num_layers - 1:
                x = F.elu(x)
                x = F.dropout(x, p=0.5, training=self.training)
        return x.log_softmax(dim=-1)

    def inference(self, x_all, device, subgraph_loader):
        pbar = tqdm(total=x_all.size(0) * self.num_layers)
        pbar.set_description('Evaluating')

        # Compute representations of nodes layer by layer, using *all*
        # available edges. This leads to faster computation in contrast to
        # immediately computing the final representations of each batch.
        total_edges = 0
        for i in range(self.num_layers):
            xs = []
            for batch_size, n_id, adj in subgraph_loader:
                edge_index, _, size = adj.to(device)
                total_edges += edge_index.size(1)
                x = x_all[n_id].to(device)
                x_target = x[:size[1]]
                x = self.convs[i]((x, x_target), edge_index)
                x = x + self.skips[i](x_target)

                if i != self.num_layers - 1:
                    x = F.elu(x)
                xs.append(x.cpu())

                pbar.update(batch_size)

            x_all = torch.cat(xs, dim=0)

        pbar.close()

        return x_all

@torch.no_grad()
def test(model, x, y, data, device, subgraph_loader):
    model.eval()

    out = model.inference(x, device, subgraph_loader)

    y_true = y.cpu().unsqueeze(-1)
    y_pred = out.argmax(dim=-1, keepdim=True)

    results = []
    for mask in [data.train_mask, data.val_mask, data.test_mask]:
        results += [int(y_pred[mask].eq(y_true[mask]).sum()) / int(mask.sum())]

    return results

def run(proc_id, n_gpus, devices, dataset, args):
    data = dataset[0]

    dev_id = devices[proc_id]
    if n_gpus > 1:
        dist_init_method = 'tcp://{master_ip}:{master_port}'.format(
            master_ip='127.0.0.1', master_port='12346')
        world_size = n_gpus
        torch.distributed.init_process_group(backend="nccl",
                                             init_method=dist_init_method,
                                             world_size=world_size,
                                             rank=proc_id)
    torch.cuda.set_device(dev_id)

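    # Indices of the training nodes as a 1-D LongTensor.
    train_nid = data.train_mask.nonzero(as_tuple=False).view(-1)
    # Split train_nid so that each process trains on its own shard.
    train_nid = torch.split(train_nid, len(train_nid) // n_gpus)[proc_id]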

    train_loader = NeighborSampler(data.edge_index, node_idx=train_nid.view(-1),
                                   sizes=list(map(int, args.sample_size.split(','))),
                                   batch_size=args.train_batch_size,
                                   shuffle=True, num_workers=args.num_workers)

    model = GAT(dataset.num_features, args.num_hidden, dataset.num_classes, num_layers=args.num_layers,
                heads=args.num_heads)
    model = model.to(dev_id)
    model.reset_parameters()

    if n_gpus > 1:
        model = DistributedDataParallel(model, device_ids=[dev_id], output_device=dev_id)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    x = data.x.to(dev_id)
    y = data.y.squeeze().to(dev_id)

    for epoch in range(1, args.num_epochs):
        tic = time.time()
        device = dev_id
        model.train()

        step = 0
        for batch_size, n_id, adjs in train_loader:
            # `adjs` holds a list of `(edge_index, e_id, size)` tuples.
            adjs = [adj.to(device) for adj in adjs]

            optimizer.zero_grad()
            out = model(x[n_id], adjs)
            loss = F.nll_loss(out, y[n_id[:batch_size]])
            loss.backward()
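            # Note: DistributedDataParallel already all-reduces and averages
            # the gradients during backward(), so this manual averaging is
            # redundant (it averages values that are already identical).
            if n_gpus > 1:
                for param in model.parameters():
                    if param.requires_grad and param.grad is not None:
                        torch.distributed.all_reduce(param.grad.data,
                                                     op=torch.distributed.ReduceOp.SUM)
                        param.grad.data /= n_gpus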
            optimizer.step()

            if step % args.log_every == 0 and proc_id == 0:
                pred_correct = int(out.argmax(dim=-1).eq(y[n_id[:batch_size]]).sum())
                cur_acc = float(pred_correct) / batch_size
                # Convert bytes to MiB (1 MiB = 1024 ** 2 bytes) to match the label below.
                gpu_mem_alloc = torch.cuda.max_memory_allocated() / (1024 ** 2) if torch.cuda.is_available() else 0
                print(
                    'Epoch {:05d} | Step {:05d} | Loss {:.4f} | Train Acc {:.4f} | GPU {:.1f} MiB'.format(
                        epoch, step, loss.item(), cur_acc, gpu_mem_alloc))

            step = step + 1

        if n_gpus > 1:
            torch.distributed.barrier()

        toc = time.time()
        if proc_id == 0:
            print('Epoch Time(s): {:.4f}'.format(toc - tic))

        # Evaluate on process 0 only; the other processes wait at the barrier below.
        if epoch % args.eval_every == 0 and epoch != 0 and proc_id == 0:
            subgraph_loader = NeighborSampler(data.edge_index, node_idx=None, sizes=[-1],
                                              batch_size=args.eval_batch_size, shuffle=False,
                                              num_workers=args.num_workers)
            if n_gpus == 1:
                train_acc, val_acc, test_acc = test(model, x, y, data, devices[0], subgraph_loader)
            else:
                train_acc, val_acc, test_acc = test(model.module, x, y, data, devices[0], subgraph_loader)
            print(f'Train: {train_acc:.4f}, Val: {val_acc:.4f}, '
                  f'Test: {test_acc:.4f}')

        if n_gpus > 1:
            torch.distributed.barrier()

if __name__ == '__main__':
    argparser = argparse.ArgumentParser("multi-gpu training")
    argparser.add_argument('--gpu', type=str, default='0,1,2',
                           help="Comma separated list of GPU device IDs.")
    argparser.add_argument('--num-epochs', type=int, default=31)
    argparser.add_argument('--num-hidden', type=int, default=128)
    argparser.add_argument('--num-layers', type=int, default=2)
    argparser.add_argument('--sample-size', type=str, default='10,25')
    argparser.add_argument('--num-heads', type=int, default=4)
    argparser.add_argument('--train-batch-size', type=int, default=1024)
    argparser.add_argument('--eval-batch-size', type=int, default=1024)
    argparser.add_argument('--eval-every', type=int, default=5)
    argparser.add_argument('--lr', type=float, default=0.001)
    argparser.add_argument('--log-every', type=int, default=20)
    argparser.add_argument('--num-workers', type=int, default=2,
                           help="Number of sampling processes. Use 0 for no extra process.")

    args = argparser.parse_args()

    process_start_time = time.time()

    devices = list(map(int, args.gpu.split(',')))
    n_gpus = len(devices)

    # change for local data
    path = osp.join('/home/maqy/gnn/dataset/pyG/')
    load_data_start_time = time.time()
    dataset = Reddit(path)
    load_data_end_time = time.time()
    print("pyG load reddit data time: {:.4f} s".format(load_data_end_time - load_data_start_time))

    if n_gpus == 1:
        run(0, n_gpus, devices, dataset, args)
    else:
        mp.spawn(run, args=(n_gpus, devices, dataset, args), nprocs=n_gpus, join=True)
        # from dgl
        # procs = []
        # for proc_id in range(n_gpus):
        #     p = mp.Process(target=thread_wrapped_func(run),
        #                    args=(proc_id, n_gpus, devices, dataset))
        #     p.start()
        #     procs.append(p)
        #
        # for p in procs:
        #     p.join()

P.S.: I do not use thread_wrapped_func from 'utils' to wrap the function 'run' (the DGL example does), because doing so raises an error:

_pickle.PicklingError: Can't pickle <function run at 0x19204ea2412>: it's not the same object as __main__.run.

I am not sure whether omitting thread_wrapped_func will cause problems such as deadlocks.

utils is from: https://github.com/dmlc/dgl/blob/master/examples/pytorch/graphsage/utils.py
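A likely explanation for the pickling error, sketched below with the standard library only: a wrapper built with `functools.wraps` copies the wrapped function's `__module__` and `__qualname__`, so pickle looks up `run` in the module, finds the original function rather than the wrapper, and refuses to continue. `wrap_at_call_site` is a hypothetical stand-in, and the assumption that thread_wrapped_func is built on `functools.wraps` is not confirmed in this thread.

```python
import functools
import pickle


def wrap_at_call_site(func):
    """Hypothetical stand-in for a wrapper built on functools.wraps."""
    @functools.wraps(func)  # copies __module__ and __qualname__ from func
    def decorated(*args, **kwargs):
        return func(*args, **kwargs)
    return decorated


def run(proc_id):
    return proc_id


def can_pickle(obj):
    """Report whether pickle accepts `obj` without raising PicklingError."""
    try:
        pickle.dumps(obj)
        return True
    except pickle.PicklingError:
        return False


wrapped = wrap_at_call_site(run)
# `wrapped` reports its name as `run`, but the module-level name `run` still
# refers to the original function, so pickle's identity check fails.
print(wrapped.__qualname__, can_pickle(wrapped))
```

This would matter with mp.spawn because the spawn start method has to pickle the target function to send it to the child process, whereas a fork-based mp.Process (the default on Linux at the time) does not. Applying the wrapper as a decorator on the definition of `run`, so that the module-level name refers to the wrapped function, is one possible way around the check.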

rusty1s commented 4 years ago

That looks cool! Are you interested in sending this as a PR? :)

maqy1995 commented 4 years ago

Of course! I will google how to submit a PR now.

maqy1995 commented 4 years ago

I have submitted a PR, so I am closing this issue.