DistributedDataParallel support is planned for PyG 1.7 or PyG 1.8, so please stay tuned :)
Just out of curiosity: For what use-case do you desperately need multi-GPU support? IMO, you can train large graphs, e.g., the ones from OGB, already in like 10 minutes.
Thank you for your quick reply; I look forward to the new features of PyG.

I don't have an actual scenario that requires multi-GPU training. In fact, I only started learning PyTorch and GNNs recently, and of course PyG and DGL as well. I am currently studying GraphSAGE/GAT, and a natural learning progression is from single-machine single-GPU, to single-machine multi-GPU, to multi-machine multi-GPU. Today I saw a multi-GPU GraphSAGE example in DGL, so I wanted to follow it and implement the same thing in PyG, but it did not run successfully, which is why I asked this question.

What puzzles me is: why can't PyG use DistributedDataParallel for multi-GPU training right now? Judging from DGL's multi-GPU example, it seems that only native PyTorch facilities are used (torch.multiprocessing and DistributedDataParallel), and each process has its own sampler and data loader to generate the subgraphs used in mini-batch training. Can't PyG implement multi-GPU training the same way? Are there other details I have missed? Looking forward to your reply.
Yes, the idea would be to utilize torch.multiprocessing and DistributedDataParallel from PyTorch, and have each client sample its own data. I doubt that this is hard to implement; however, I haven't tested it yet, so it's definitely unsupported at the moment. I will look into it in the upcoming weeks and see what needs to be done to allow distributed training :)
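Roughly, a minimal sketch of that pattern in plain PyTorch could look as follows (no PyG-specific code yet; the linear model, the random batches, and the 127.0.0.1:12345 rendezvous address are just placeholders):

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel


def run(rank, world_size):
    # One process per GPU; every process joins the same process group.
    dist.init_process_group(
        backend='nccl', init_method='tcp://127.0.0.1:12345',
        world_size=world_size, rank=rank)
    torch.cuda.set_device(rank)

    # Placeholder model; in PyG this would be a GNN such as GraphSAGE/GAT.
    model = torch.nn.Linear(16, 2).to(rank)
    model = DistributedDataParallel(model, device_ids=[rank])
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

    # Each process would sample its own mini-batches here, e.g., with a
    # per-process NeighborSampler over its own shard of training nodes.
    for _ in range(10):
        x = torch.randn(32, 16, device=rank)
        y = torch.randint(0, 2, (32,), device=rank)
        optimizer.zero_grad()
        loss = torch.nn.functional.cross_entropy(model(x), y)
        loss.backward()  # DDP averages gradients across processes here.
        optimizer.step()

    dist.destroy_process_group()


if __name__ == '__main__':
    world_size = torch.cuda.device_count()
    mp.spawn(run, args=(world_size,), nprocs=world_size, join=True)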
Today I continued to try it, and it seems that single-machine multi-GPU training can run successfully (I modified this code to use the Reddit dataset: https://github.com/rusty1s/pytorch_geometric/blob/master/examples/ogbn_products_gat.py).
Compared with the multi-GPU training implemented in DGL, the main difference is that:
if n_gpus == 1:
    run(0, n_gpus, args, devices, data)
else:
    procs = []
    for proc_id in range(n_gpus):
        p = mp.Process(target=thread_wrapped_func(run),
                       args=(proc_id, n_gpus, args, devices, data))
        p.start()
        procs.append(p)
    for p in procs:
        p.join()
should be changed to (this is where I went wrong yesterday):

if n_gpus == 1:
    run(0, n_gpus, args, devices, data)
else:
    mp.spawn(run, args=(n_gpus, devices, data), nprocs=n_gpus, join=True)
But I don't know if there are other hidden bugs.
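One thing worth spelling out, since it is what confused me (this is standard torch.multiprocessing behavior, not anything PyG-specific): mp.spawn injects the process index as the first argument of the target function, so the args tuple must not include it:

# mp.spawn(run, args=(n_gpus, devices, data), nprocs=n_gpus) makes each
# spawned process execute:
#     run(proc_id, n_gpus, devices, data)   # with proc_id = 0 .. n_gpus - 1
# i.e. fn(i, *args), per the torch.multiprocessing.spawn documentation.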
The complete code is below; if there are any bugs, please tell me.
# -*- coding: utf-8 -*-
"""
@Author: maqy
@Time: 2020/7/16
@Description:
"""
import argparse
import os.path as osp
import torch
import torch.nn.functional as F
import time
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel
import numpy as np
from torch.nn import Linear as Lin
from tqdm import tqdm
from torch_geometric.data import NeighborSampler
from torch_geometric.nn import GATConv
from torch_geometric.datasets import Reddit
# from utils import thread_wrapped_func
class GAT(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, num_layers,
                 heads):
        super(GAT, self).__init__()
        self.num_layers = num_layers

        self.convs = torch.nn.ModuleList()
        self.convs.append(GATConv(in_channels, hidden_channels, heads))
        for _ in range(num_layers - 2):
            self.convs.append(
                GATConv(heads * hidden_channels, hidden_channels, heads))
        self.convs.append(
            GATConv(heads * hidden_channels, out_channels, heads,
                    concat=False))

        # Residual (skip) connections.
        self.skips = torch.nn.ModuleList()
        self.skips.append(Lin(in_channels, hidden_channels * heads))
        for _ in range(num_layers - 2):
            self.skips.append(
                Lin(hidden_channels * heads, hidden_channels * heads))
        self.skips.append(Lin(hidden_channels * heads, out_channels))

    def reset_parameters(self):
        for conv in self.convs:
            conv.reset_parameters()
        for skip in self.skips:
            skip.reset_parameters()
    def forward(self, x, adjs):
        # `train_loader` computes the k-hop neighborhood of a batch of nodes,
        # and returns, for each layer, a bipartite graph object, holding the
        # bipartite edges `edge_index`, the index `e_id` of the original edges,
        # and the size/shape `size` of the bipartite graph.
        # Target nodes are also included in the source nodes so that one can
        # easily apply skip-connections or add self-loops.
        for i, (edge_index, _, size) in enumerate(adjs):
            x_target = x[:size[1]]  # Target nodes are always placed first.
            x = self.convs[i]((x, x_target), edge_index)
            x = x + self.skips[i](x_target)
            if i != self.num_layers - 1:
                x = F.elu(x)
                x = F.dropout(x, p=0.5, training=self.training)
        return x.log_softmax(dim=-1)
    def inference(self, x_all, device, subgraph_loader):
        pbar = tqdm(total=x_all.size(0) * self.num_layers)
        pbar.set_description('Evaluating')

        # Compute representations of nodes layer by layer, using *all*
        # available edges. This leads to faster computation in contrast to
        # immediately computing the final representations of each batch.
        total_edges = 0
        for i in range(self.num_layers):
            xs = []
            for batch_size, n_id, adj in subgraph_loader:
                edge_index, _, size = adj.to(device)
                total_edges += edge_index.size(1)
                x = x_all[n_id].to(device)
                x_target = x[:size[1]]
                x = self.convs[i]((x, x_target), edge_index)
                x = x + self.skips[i](x_target)
                if i != self.num_layers - 1:
                    x = F.elu(x)
                xs.append(x.cpu())
                pbar.update(batch_size)
            x_all = torch.cat(xs, dim=0)
        pbar.close()
        return x_all
@torch.no_grad()
def test(model, x, y, data, device, subgraph_loader):
    model.eval()
    out = model.inference(x, device, subgraph_loader)

    y_true = y.cpu().unsqueeze(-1)
    y_pred = out.argmax(dim=-1, keepdim=True)

    results = []
    for mask in [data.train_mask, data.val_mask, data.test_mask]:
        results += [int(y_pred[mask].eq(y_true[mask]).sum()) / int(mask.sum())]
    return results
def run(proc_id, n_gpus, devices, dataset, args):
    data = dataset[0]
    dev_id = devices[proc_id]
    if n_gpus > 1:
        dist_init_method = 'tcp://{master_ip}:{master_port}'.format(
            master_ip='127.0.0.1', master_port='12346')
        world_size = n_gpus
        torch.distributed.init_process_group(backend="nccl",
                                             init_method=dist_init_method,
                                             world_size=world_size,
                                             rank=proc_id)
    torch.cuda.set_device(dev_id)

    train_nid = data.train_mask.nonzero(as_tuple=False).view(-1)
    # Split train_nid so that each process trains on its own shard of the
    # training nodes.
    train_nid = torch.split(train_nid, len(train_nid) // n_gpus)[proc_id]

    train_loader = NeighborSampler(data.edge_index, node_idx=train_nid,
                                   sizes=list(map(int, args.sample_size.split(','))),
                                   batch_size=args.train_batch_size,
                                   shuffle=True, num_workers=args.num_workers)

    model = GAT(dataset.num_features, args.num_hidden, dataset.num_classes,
                num_layers=args.num_layers, heads=args.num_heads)
    model = model.to(dev_id)
    model.reset_parameters()
    if n_gpus > 1:
        model = DistributedDataParallel(model, device_ids=[dev_id],
                                        output_device=dev_id)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    x = data.x.to(dev_id)
    y = data.y.squeeze().to(dev_id)
    for epoch in range(1, args.num_epochs):
        tic = time.time()
        device = dev_id
        model.train()

        step = 0
        for batch_size, n_id, adjs in train_loader:
            # `adjs` holds a list of `(edge_index, e_id, size)` tuples.
            adjs = [adj.to(device) for adj in adjs]
            optimizer.zero_grad()
            out = model(x[n_id], adjs)
            loss = F.nll_loss(out, y[n_id[:batch_size]])
            loss.backward()
            if n_gpus > 1:
                # Note: DistributedDataParallel already averages gradients
                # across processes during backward(), so this manual
                # all-reduce (kept from the DGL example) re-averages
                # already identical gradients; it should be safe to remove.
                for param in model.parameters():
                    if param.requires_grad and param.grad is not None:
                        torch.distributed.all_reduce(
                            param.grad.data, op=torch.distributed.ReduceOp.SUM)
                        param.grad.data /= n_gpus
            optimizer.step()

            if step % args.log_every == 0 and proc_id == 0:
                pred_correct = int(out.argmax(dim=-1).eq(y[n_id[:batch_size]]).sum())
                cur_acc = float(pred_correct) / batch_size
                gpu_mem_alloc = torch.cuda.max_memory_allocated() / 1000000 if torch.cuda.is_available() else 0
                print('Epoch {:05d} | Step {:05d} | Loss {:.4f} | '
                      'Train Acc {:.4f} | GPU {:.1f} MB'.format(
                          epoch, step, loss.item(), cur_acc, gpu_mem_alloc))
            step = step + 1

        if n_gpus > 1:
            torch.distributed.barrier()

        toc = time.time()
        if proc_id == 0:
            print('Epoch Time(s): {:.4f}'.format(toc - tic))

        # Evaluation runs on process 0 only.
        if epoch % args.eval_every == 0 and epoch != 0 and proc_id == 0:
            subgraph_loader = NeighborSampler(data.edge_index, node_idx=None, sizes=[-1],
                                              batch_size=args.eval_batch_size, shuffle=False,
                                              num_workers=args.num_workers)
            if n_gpus == 1:
                train_acc, val_acc, test_acc = test(model, x, y, data, devices[0], subgraph_loader)
            else:
                train_acc, val_acc, test_acc = test(model.module, x, y, data, devices[0], subgraph_loader)
            print(f'Train: {train_acc:.4f}, Val: {val_acc:.4f}, '
                  f'Test: {test_acc:.4f}')
        if n_gpus > 1:
            torch.distributed.barrier()
if __name__ == '__main__':
    argparser = argparse.ArgumentParser("multi-gpu training")
    argparser.add_argument('--gpu', type=str, default='0,1,2',
                           help="Comma separated list of GPU device IDs.")
    argparser.add_argument('--num-epochs', type=int, default=31)
    argparser.add_argument('--num-hidden', type=int, default=128)
    argparser.add_argument('--num-layers', type=int, default=2)
    argparser.add_argument('--sample-size', type=str, default='10,25')
    argparser.add_argument('--num-heads', type=int, default=4)
    argparser.add_argument('--train-batch-size', type=int, default=1024)
    argparser.add_argument('--eval-batch-size', type=int, default=1024)
    argparser.add_argument('--eval-every', type=int, default=5)
    argparser.add_argument('--lr', type=float, default=0.001)  # was type=int, which truncates the learning rate
    argparser.add_argument('--log-every', type=int, default=20)
    argparser.add_argument('--num-workers', type=int, default=2,
                           help="Number of sampling processes. Use 0 for no extra process.")
    args = argparser.parse_args()

    process_start_time = time.time()
    devices = list(map(int, args.gpu.split(',')))
    n_gpus = len(devices)

    # Change this path for your local data.
    path = osp.join('/home/maqy/gnn/dataset/pyG/')
    load_data_start_time = time.time()
    dataset = Reddit(path)
    load_data_end_time = time.time()
    print("pyG load reddit data time: {:.4f} s".format(load_data_end_time - load_data_start_time))

    if n_gpus == 1:
        run(0, n_gpus, devices, dataset, args)
    else:
        mp.spawn(run, args=(n_gpus, devices, dataset, args), nprocs=n_gpus, join=True)
    # The original launch style from DGL, for comparison:
    # procs = []
    # for proc_id in range(n_gpus):
    #     p = mp.Process(target=thread_wrapped_func(run),
    #                    args=(proc_id, n_gpus, devices, dataset))
    #     p.start()
    #     procs.append(p)
    #
    # for p in procs:
    #     p.join()
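For anyone trying to reproduce this: assuming the script is saved as, say, train_reddit_multi_gpu.py (the file name is just my choice) and the dataset path above is adjusted, a three-GPU run would be launched like:

python train_reddit_multi_gpu.py --gpu 0,1,2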
PS: I did not use thread_wrapped_func from 'utils' to wrap the run function (the DGL example uses it), because doing so raises an error:

_pickle.PicklingError: Can't pickle <function run at 0x19204ea2412>: it's not the same object as __main__.run.

I am not sure whether skipping thread_wrapped_func will cause problems, for example, a deadlock.
utils is from: https://github.com/dmlc/dgl/blob/master/examples/pytorch/graphsage/utils.py
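Digging a little further: my understanding is that mp.spawn pickles the target function by its module-level name, and thread_wrapped_func is a decorator, so the object being pickled no longer matches the name __main__.run; that is exactly what the error message says. A tiny self-contained illustration of the mechanism (my own example, independent of DGL):

import functools
import pickle


def decorate(fn):
    @functools.wraps(fn)  # wrapper keeps run's name, as I believe thread_wrapped_func does
    def wrapper(*args, **kwargs):
        return fn(*args, **kwargs)
    return wrapper


def run():
    pass


# Pickling the decorated function looks `run` up by name in this module,
# finds the original (different) object, and fails with:
#   PicklingError: Can't pickle <function run ...>:
#   it's not the same object as __main__.run
try:
    pickle.dumps(decorate(run))
except pickle.PicklingError as e:
    print(e)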
That looks cool! Are you interested in sending this as a PR? :)
Of course! I will google how to submit a PR now.
I have submitted a PR, so I am closing this issue.
❓ Questions & Help
There is a multi-GPU implementation of GraphSAGE in DGL: https://github.com/dmlc/dgl/blob/master/examples/pytorch/graphsage/train_sampling_multi_gpu.py
The example uses DistributedDataParallel, and the subgraphs are sampled in each process.
Can PyG also implement multi-GPU training in this way?