idekazuki / diary

記録
0 stars 1 forks source link

Getting started with distributed data parallel #85

Open idekazuki opened 4 years ago

idekazuki commented 4 years ago

https://pytorch.org/tutorials/intermediate/ddp_tutorial.html 上のリンクの日本語訳+メモ

DistributedDataParallel(DDP)は、モジュールレベルでデータの並列処理を実装することができる。 torch.distributed packageを使用することで勾配、パラメーター、バッファーを同期できる。 並列処理は、プロセス内、ブロセス間で利用できる。

プロセスがGPUデバイスを共有しない限り、プロセスをどの使用可能なリソースに割り当てるのかはユーザー次第。

推奨される方法は、すべてのモジュールレプリカに対してプロセスを作成すること。つまり、プロセス内にモジュールの複製はしない。

idekazuki commented 4 years ago

DataParallelとDistributedDataParallelの比較 なぜより複雑になるのにDataParallelではなくDDPの使用を検討する必要があるのかを明確にする。

idekazuki commented 4 years ago

DDPモジュールを作成するには、まずプロセスグループを適切に設定する必要がある。

import os
import tempfile
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
import torch.multiprocessing as mp

from torch.nn.parallel import DistributedDataParallel as DDP

def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'

    # initialize the process group
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    # Explicitly setting seed to make sure that models created in two processes
    # start from same random weights and biases.
    torch.manual_seed(42)

def cleanup():
    dist.destroy_process_group()
idekazuki commented 4 years ago

DDPの他のすべてのメソッドを呼び出す前にtorch.distributed.init_process_group()を呼び出す必要がある。

init_process_group() 初期化には主に2つ方法がある。 store, rank, world_sizeを明示的に指定する。 init_method でpeerを検出する場所と方法を示すurlを指定する。

どちらも指定されていない場合、init_methodは"env://"とみなされる。 parameter

訓練がランダムなパラメーターから開始するとき、すべてのDDPプロセスが同じ初期値を使用する必要がある。そうでないとglobal gradient の同期は意味がない。

idekazuki commented 4 years ago
class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = nn.Linear(10, 10)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(10, 5)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))

def demo_basic(rank, world_size):
    setup(rank, world_size)

    # setup devices for this process, rank 1 uses GPUs [0, 1, 2, 3] and
    # rank 2 uses GPUs [4, 5, 6, 7].
    n = torch.cuda.device_count() // world_size
    device_ids = list(range(rank * n, (rank + 1) * n))

    # create model and move it to device_ids[0]
    model = ToyModel().to(device_ids[0])
    # output_device defaults to device_ids[0]
    ddp_model = DDP(model, device_ids=device_ids)

    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

    optimizer.zero_grad()
    outputs = ddp_model(torch.randn(20, 10))
    labels = torch.randn(20, 5).to(device_ids[0])
    loss_fn(outputs, labels).backward()
    optimizer.step()

    cleanup()

def run_demo(demo_fn, world_size):
    mp.spawn(demo_fn,
             args=(world_size,),
             nprocs=world_size,
             join=True)
idekazuki commented 4 years ago
idekazuki commented 4 years ago
def demo_checkpoint(rank, world_size):
    setup(rank, world_size)

    # setup devices for this process, rank 1 uses GPUs [0, 1, 2, 3] and
    # rank 2 uses GPUs [4, 5, 6, 7].
    n = torch.cuda.device_count() // world_size
    device_ids = list(range(rank * n, (rank + 1) * n))

    model = ToyModel().to(device_ids[0])
    # output_device defaults to device_ids[0]
    ddp_model = DDP(model, device_ids=device_ids)

    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

    CHECKPOINT_PATH = tempfile.gettempdir() + "/model.checkpoint"
    if rank == 0:
        # All processes should see same parameters as they all start from same
        # random parameters and gradients are synchronized in backward passes.
        # Therefore, saving it in one process is sufficient.
        torch.save(ddp_model.state_dict(), CHECKPOINT_PATH)

    # Use a barrier() to make sure that process 1 loads the model after process
    # 0 saves it.
    dist.barrier()
    # configure map_location properly
    rank0_devices = [x - rank * len(device_ids) for x in device_ids]
    device_pairs = zip(rank0_devices, device_ids)
    map_location = {'cuda:%d' % x: 'cuda:%d' % y for x, y in device_pairs}
    ddp_model.load_state_dict(
        torch.load(CHECKPOINT_PATH, map_location=map_location))

    optimizer.zero_grad()
    outputs = ddp_model(torch.randn(20, 10))
    labels = torch.randn(20, 5).to(device_ids[0])
    loss_fn = nn.MSELoss()
    loss_fn(outputs, labels).backward()
    optimizer.step()

    # Use a barrier() to make sure that all processes have finished reading the
    # checkpoint
    dist.barrier()

    if rank == 0:
        os.remove(CHECKPOINT_PATH)

    cleanup()
idekazuki commented 4 years ago

DDPとモデルの並列処理の組み合わせ DDPはマルチGPUにモデルでも動作するが、プロセス内の複製はサポートしていない。モジュールレプリカごとに1つのプロセスを作成する必要がある。通常、プロセスごとに複数のレプリカを作成するよりもパフォーマンスは向上する。DDPラッピングマルチGPUモデルは、大量のデータを含む大規模モデルを訓練するときに役立つ。この機能を使用するとき、異なるモデルのレプリカが異なるデバイスに配置されるため、ハードコーディングされたデバイスを避けるために、マルチGPUモデルを慎重に実装する必要がある。

idekazuki commented 4 years ago
class ToyMpModel(nn.Module):
    def __init__(self, dev0, dev1):
        super(ToyMpModel, self).__init__()
        self.dev0 = dev0
        self.dev1 = dev1
        self.net1 = torch.nn.Linear(10, 10).to(dev0)
        self.relu = torch.nn.ReLU()
        self.net2 = torch.nn.Linear(10, 5).to(dev1)

    def forward(self, x):
        x = x.to(self.dev0)
        x = self.relu(self.net1(x))
        x = x.to(self.dev1)
        return self.net2(x)
idekazuki commented 4 years ago

マルチGPUモデルをDDPに渡す場合、device_ids及びoutput_deviceを設定しないこと。入力または出力データは適切なデバイスに配置される。

def demo_model_parallel(rank, world_size):
    setup(rank, world_size)

    # setup mp_model and devices for this process
    dev0 = rank * 2
    dev1 = rank * 2 + 1
    mp_model = ToyMpModel(dev0, dev1)
    ddp_mp_model = DDP(mp_model)

    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_mp_model.parameters(), lr=0.001)

    optimizer.zero_grad()
    # outputs will be on dev1
    outputs = ddp_mp_model(torch.randn(20, 10))
    labels = torch.randn(20, 5).to(dev1)
    loss_fn(outputs, labels).backward()
    optimizer.step()

    cleanup()

if __name__ == "__main__":
    run_demo(demo_basic, 2)
    run_demo(demo_checkpoint, 2)

    if torch.cuda.device_count() >= 8:
        run_demo(demo_model_parallel, 4)