pytorch / xla

Enabling PyTorch on XLA Devices (e.g. Google TPU)
https://pytorch.org/xla

Multiprocessing (8) is broken on colab #1969

Closed anhmeow closed 4 years ago

anhmeow commented 4 years ago

🐛 Bug

Calling xmp.spawn() with nprocs=8 makes the Colab notebook run indefinitely. It works with nprocs=1.

If I stop the running cell and do !ps -ef | grep python, there are 30+ python processes that have been triggered.

To Reproduce

Install torch_xla in a Colab notebook and execute the following code:

import torch
import torch.nn as nn
import torch.nn.functional as F

import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.distributed.parallel_loader as pl

import numpy as np

def map_train(index=0):
    device = xm.xla_device() 

    generator = nn.Linear(1, 2).to(device)       #dummy generator
    classifier_1 = nn.Linear(2, 1).to(device)    #dummy classifier

    sampler = torch.utils.data.distributed.DistributedSampler(
                    list(np.arange(4, dtype=np.float32)),
                    num_replicas=xm.xrt_world_size(),
                    rank=xm.get_ordinal(),
                    shuffle=True)

    loader = torch.utils.data.DataLoader(
              list(np.arange(4, dtype=np.float32)),
              batch_size=4,
              sampler=sampler,
              num_workers=1,
              drop_last=True)

    para_loader = pl.ParallelLoader(loader, [device]).per_device_loader(device)
    loss = torch.nn.BCEWithLogitsLoss().to(device)

    optimizer = torch.optim.Adam(list(generator.parameters()) + 
                                 list(classifier_1.parameters()))

    for count, number in enumerate(para_loader):
        print(index, count, number)
        number = number.unsqueeze(-1)

        optimizer.zero_grad()
        gen = generator(number)
        class1 = classifier_1(gen)

        loss_class1 = loss(class1, number)
        loss_class1.backward()
        xm.optimizer_step(optimizer)

    print('done')

    return 0

xmp.spawn(map_train, nprocs=8, start_method='fork')
#change nprocs to 1 and it will work
ncoop57 commented 4 years ago

Having a similar issue with nprocs=8 in Google Colab. Which version of XLA are you using?

dlibenzi commented 4 years ago

I have tried with our MNIST multi-process training and it seems fine.

[screenshot: MNIST multi-process training run completing normally]
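
For reference, the script mentioned above ships with the repo as test/test_train_mp_mnist.py; in a Colab cell it can be launched with something like the line below (the clone path is an assumption and depends on where the repository was checked out in the VM):

!python pytorch/xla/test/test_train_mp_mnist.py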

dlibenzi commented 4 years ago

Hmm, I do see the hanging if I start a new Colab notebook. Weird. Seems like the client mesh master does not come up. There is an error, but we saw this before IIRC and it was benign:

E0426 17:12:19.151082192    1316 server_chttp2.cc:40]        {"created":"@1587921139.151062454","description":"Only 1 addresses added out of total 2 resolved","file":"external/com_github_grpc_grpc/src/core/ext/transport/chttp2/server/chttp2_server.cc","file_line":404,"referenced_errors":[{"created":"@1587921139.151059962","description":"Address family not supported by protocol","errno":97,"file":"external/com_github_grpc_grpc/src/core/lib/iomgr/socket_utils_common_posix.cc","file_line":420,"os_error":"Address family not supported by protocol","syscall":"socket","target_address":"[::1]:56359"}]}

Unless we were lucky before and both mesh client and server picked up the same INET address family, and now we are no longer so lucky. @jysohn23 Have you ever seen this in Colab?

dlibenzi commented 4 years ago

So that is not the issue. Even when forcing IPv4 you can see the mesh clients starting and waiting to connect, and the mesh master starting, but the clients still hang. Something must have changed in the Colab OS setup.
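
For anyone who wants to surface that client/master activity themselves, bumping the TensorFlow C++ log level before importing torch_xla or spawning usually helps; the exact module filter below is an assumption and may differ by release:

import os

# Assumption: the XRT client used by torch_xla honors the standard TensorFlow
# C++ logging environment variables; the module name in TF_CPP_VMODULE is an
# illustrative guess and may need adjusting for your build.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '0'         # show INFO-level C++ logs
os.environ['TF_CPP_VMODULE'] = 'mesh_service=5'  # extra-verbose mesh master/client logs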

dlibenzi commented 4 years ago

Argh! The code was incorrectly setting up the data loader and sampler, so cores were ending at an uneven number of steps. Here's the code that works:

%%writefile train.py
import torch
import torch.nn as nn
import torch.nn.functional as F

import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.distributed.parallel_loader as pl

import numpy as np

def map_train(index):
    device = xm.xla_device() 

    generator = nn.Linear(1, 2).to(device)       #dummy generator
    classifier_1 = nn.Linear(2, 1).to(device)    #dummy classifier

    loader = torch.utils.data.DataLoader(
              list(np.arange(4, dtype=np.float32)) * 32,  # replicate the tiny dataset so every core runs the same number of steps
              batch_size=4,
              sampler=None,
              num_workers=1,
              drop_last=True)

    para_loader = pl.ParallelLoader(loader, [device]).per_device_loader(device)
    loss = torch.nn.BCEWithLogitsLoss().to(device)

    optimizer = torch.optim.Adam(list(generator.parameters()) + 
                                 list(classifier_1.parameters()))

    for count, number in enumerate(para_loader):
        print(index, count)
        number = number.unsqueeze(-1)

        optimizer.zero_grad()
        gen = generator(number)
        class1 = classifier_1(gen)

        loss_class1 = loss(class1, number)
        loss_class1.backward()
        xm.optimizer_step(optimizer)

    print('done')

    return 0

if __name__ == '__main__':
    torch.set_default_tensor_type('torch.FloatTensor')
    xmp.spawn(map_train, args=(), nprocs=8, start_method='fork')
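
If the DistributedSampler from the original snippet is still wanted, the same fix applies: make the dataset long enough that every replica gets the same non-zero number of batches. A minimal sketch (make_loader is just an illustrative helper, not code from the thread), usable inside map_train in place of the plain DataLoader:

import numpy as np
import torch
import torch_xla.core.xla_model as xm

def make_loader(batch_size=4):
    # 4 values replicated 32 times -> 128 samples; with 8 replicas each core
    # gets 16 samples, i.e. 4 full batches, so all cores finish together.
    data = list(np.arange(4, dtype=np.float32)) * 32
    sampler = torch.utils.data.distributed.DistributedSampler(
        data,
        num_replicas=xm.xrt_world_size(),
        rank=xm.get_ordinal(),
        shuffle=True)
    return torch.utils.data.DataLoader(
        data,
        batch_size=batch_size,
        sampler=sampler,
        num_workers=1,
        drop_last=True)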
anhmeow commented 4 years ago

Hello,

Indeed, you are right, I over-simplified my sample case and made the error.

On Colab, it works now, but at the end there is still an error: Exception in device=TPU:2: tensorflow/compiler/xla/xla_client/xrt_computation_client.cc:367 : Check failed: session->session()->Run( session_work->feed_inputs, session_work->outputs_handles, &outputs) == ::tensorflow::Status::OK() (Unavailable: From /job:tpu_worker/replica:0/task:0: DisableableBlockingRefcount is disabled. [[{{node XRTAllocateFromTensor_8}}]] vs. OK) Begin stack trace tensorflow::CurrentStackTrace[abi:cxx11]()

On top of that, unfortunately, on my Colab notebook that worked fine before, the issue is still present: it hangs if I set nprocs=8 without any error, and works if I set nprocs=1.

Is there something I can activate to get more logs?

swaroop-nath commented 2 years ago

I have the same issue: running the A3C algorithm with PyTorch multiprocessing, everything works fine on my laptop, but the processes never seem to advance in Google Colab.