pytorch / xla

Enabling PyTorch on XLA Devices (e.g. Google TPU)
https://pytorch.org/xla

tensorboard support #1795

Closed · level14taken closed this issue 4 years ago

level14taken commented 4 years ago

❓ Questions and Help

Hey, does training with num_cores=8 on TPU cause problems when using TensorBoard to visualize the training loss? Despite many trials, it ends up with the following error:

*** End stack trace ***
Failed to connect to client mesh master: localhost:47971

I'm using torch.utils.tensorboard.

dlibenzi commented 4 years ago

If you are writing to the same file, writes will happen from multiple processes (or threads when using the old multi-threaded parallelism). So that might cause problems.
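For example, a minimal sketch of writing from a single process only (the 'runs/exp' path, loss_value, and step are placeholders):

import torch_xla.core.xla_model as xm
from torch.utils.tensorboard import SummaryWriter

# Only the master process creates a writer; every other process keeps
# None, so no two processes ever write to the same event file.
writer = SummaryWriter('runs/exp') if xm.is_master_ordinal() else None

if writer is not None:
  writer.add_scalar('Loss/train', loss_value, step)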

level14taken commented 4 years ago

I have guarded the writes with xm.get_ordinal() == 0 so that only one process writes to the file, but training still ends with the same error. Any ideas on what can be done?

dlibenzi commented 4 years ago

Can you show the code snippet?

level14taken commented 4 years ago

writer = SummaryWriter('runs/unet_experiment_2')

for image, mask in loader:
  y_pred = model(image)
  loss = loss_fn(y_pred, mask)
  if xm.get_ordinal() == 0:
    writer.add_scalar('TestLoss', loss, epoch)
  print('[xla:{}] Loss={:.5f}'.format(xm.get_ordinal(), loss.item()), flush=True)

zcain117 commented 4 years ago

Here is an example of how we use the SummaryWriter with multiprocessing: https://github.com/pytorch/xla/blob/master/test/test_train_mp_mnist.py#L113

We use xm.is_master_ordinal and interact with the summary writer via test_utils.
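Sketched briefly, the pattern from that test looks like this (the logdir value is a placeholder; the test_utils helpers do nothing when the writer is None, so they are safe to call from every core):

import torch_xla.core.xla_model as xm
import torch_xla.test.test_utils as test_utils

writer = None
if xm.is_master_ordinal():
  writer = test_utils.get_summary_writer('/tmp/tensorboard')

# Called on every core; only the master, which holds a real writer,
# actually writes the scalars.
test_utils.write_to_summary(writer, epoch, dict_to_write={'Accuracy/test': accuracy})

# When training finishes:
test_utils.close_summary_writer(writer)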

dlibenzi commented 4 years ago

writer = SummaryWriter('runs/unet_experiment_2')

for image, mask in loader:
  y_pred = model(image)
  loss = loss_fn(y_pred, mask)
  if xm.get_ordinal() == 0:
    writer.add_scalar('TestLoss', loss, epoch)
  print('[xla:{}] Loss={:.5f}'.format(xm.get_ordinal(), loss.item()), flush=True)

Every core should run the same graph, so call loss.item() on every core and guard only the host-side write. In your example you should:

loss = loss_fn(y_pred, mask)
loss_value = loss.item()
if xm.get_ordinal() == 0:
  write(loss_value)
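(Here write() stands in for whatever host-side logging you use. The point is that loss.item() executes on every core, so all cores build the same graph, while only ordinal 0 performs the write.)
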
dlibenzi commented 4 years ago

In general though, we recommend not writing data at every step, for performance reasons. It is better to write every N steps, and to use step closures (especially if you are writing more values than just the loss).

https://github.com/pytorch/xla/blob/e58a6a213173dce566352cf626387104b12a514e/test/test_train_mp_mnist.py#L129
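For instance, a minimal sketch of writing every N steps through a step closure (the writer, the N=10 interval, and the tag name are placeholders):

import torch_xla.core.xla_model as xm

def _log_fn(writer, step, loss):
  # Step closures run once the step's tensor values are available,
  # so calling loss.item() here does not force an extra device sync.
  if writer is not None:
    writer.add_scalar('Loss/train', loss.item(), step)

for step, (data, target) in enumerate(loader):
  # ... forward, backward, xm.optimizer_step(optimizer) ...
  if step % 10 == 0:
    xm.add_step_closure(_log_fn, args=(writer, step, loss))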

level14taken commented 4 years ago

@zcain117 I have tried the exact same code (xla/test/test_train_mp_mnist.py) in Colab, but it didn't work (process 2 terminated with exit code 1). Then I changed start_method='fork', and the code resulted in the same error as above (*** End stack trace *** Failed to connect to client mesh master: localhost:51627).

@dlibenzi I tried your approach as well, but it didn't make the error go away either (*** End stack trace *** Failed to connect to client mesh master: localhost:51627).

So can you suggest any other way I can use TensorBoard with TPU in Colab?

zcain117 commented 4 years ago

Can you post a bit more of the stack trace?

level14taken commented 4 years ago

-- Process 1 terminated with the following error:
Traceback (most recent call last):
  File "/usr/local/lib/python3.6/dist-packages/torch/multiprocessing/spawn.py", line 20, in _wrap
    fn(i, *args)
  File "/usr/local/lib/python3.6/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 116, in _start_fn
    _setup_replication()
  File "/usr/local/lib/python3.6/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 108, in _setup_replication
    device = xm.xla_device()
  File "/usr/local/lib/python3.6/dist-packages/torch_xla/core/xla_model.py", line 137, in xla_device
    devkind=[devkind] if devkind is not None else None)
  File "/usr/local/lib/python3.6/dist-packages/torch_xla/core/xla_model.py", line 41, in get_xla_supported_devices
    xla_devices = torch_xla._XLAC._xla_get_devices()
RuntimeError: tensorflow/compiler/xla/xla_client/mesh_service.cc:212 : Check failed: impl_->channel->WaitForConnected( std::chrono::system_clock::now() + std::chrono::seconds(connect_wait_seconds)) 
*** Begin stack trace ***
    tensorflow::CurrentStackTrace[abi:cxx11]()
    xla::service::MeshClient::MeshClient(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&)
    xla::service::MeshClient::Get()
    xla::ComputationClient::Create()
    xla::ComputationClient::Get()

    _PyEval_EvalFrameDefault

    _PyEval_EvalFrameDefault

    _PyEval_EvalFrameDefault

    _PyEval_EvalFrameDefault

    PyObject_Call
    _PyEval_EvalFrameDefault

    PyObject_Call
    _PyEval_EvalFrameDefault

    _PyEval_EvalFrameDefault

    _PyEval_EvalFrameDefault

    _PyEval_EvalFrameDefault
    _PyFunction_FastCallDict

    _PyObject_FastCallKeywords

    _PyEval_EvalFrameDefault

    _PyEval_EvalFrameDefault

    _PyEval_EvalFrameDefault

    _PyEval_EvalFrameDefault

    _PyEval_EvalFrameDefault

    _PyEval_EvalFrameDefault

    _PyEval_EvalFrameDefault

    _PyEval_EvalFrameDefault

    _PyFunction_FastCallDict

    PyObject_Call
    _PyEval_EvalFrameDefault

    _PyEval_EvalFrameDefault

    _PyEval_EvalFrameDefault

    _PyEval_EvalFrameDefault

    _PyEval_EvalFrameDefault

    PyObject_Call
    _PyEval_EvalFrameDefault

    PyObject_Call
    _PyEval_EvalFrameDefault

    _PyEval_EvalFrameDefault

    _PyEval_EvalFrameDefault
    _PyFunction_FastCallDict

    PyObject_Call
    _PyEval_EvalFrameDefault

    _PyEval_EvalFrameDefault

    _PyEval_EvalFrameDefault

    _PyEval_EvalFrameDefault

    _PyEval_EvalFrameDefault

    _PyEval_EvalFrameDefault

*** End stack trace ***
Failed to connect to client mesh master: localhost:51627

dlibenzi commented 4 years ago

Can you please post the repro code?

level14taken commented 4 years ago

datadir = '/tmp/mnist-data'
logdir = '/tmp/tensorboard'  # log directory for the TensorBoard summary writer
batch_size = 128
momentum = 0.5
lr = 0.01
target_accuracy = 98.0
num_epochs = 18

import os
import shutil
import sys
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import torch_xla
import torch_xla.debug.metrics as met
import torch_xla.distributed.data_parallel as dp
import torch_xla.distributed.parallel_loader as pl
import torch_xla.utils.utils as xu
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.test.test_utils as test_utils

class MNIST(nn.Module):

  def __init__(self):
    super(MNIST, self).__init__()
    self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
    self.bn1 = nn.BatchNorm2d(10)
    self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
    self.bn2 = nn.BatchNorm2d(20)
    self.fc1 = nn.Linear(320, 50)
    self.fc2 = nn.Linear(50, 10)

  def forward(self, x):
    x = F.relu(F.max_pool2d(self.conv1(x), 2))
    x = self.bn1(x)
    x = F.relu(F.max_pool2d(self.conv2(x), 2))
    x = self.bn2(x)
    x = torch.flatten(x, 1)
    x = F.relu(self.fc1(x))
    x = self.fc2(x)
    return F.log_softmax(x, dim=1)

def _train_update(device, x, loss, tracker):
  test_utils.print_training_update(device, x, loss.item(), tracker.rate(),
                                   tracker.global_rate())

def train_mnist():
  torch.manual_seed(1)

  if False:  # fake-data path from the original test script, disabled here
    train_loader = xu.SampleGenerator(
        data=(torch.zeros(batch_size, 1, 28,
                          28), torch.zeros(batch_size,
                                           dtype=torch.int64)),
        sample_count=60000 // batch_size // xm.xrt_world_size())
    test_loader = xu.SampleGenerator(
        data=(torch.zeros(batch_size, 1, 28,
                          28), torch.zeros(batch_size,
                                           dtype=torch.int64)),
        sample_count=10000 // batch_size // xm.xrt_world_size())
  else:
    train_dataset = datasets.MNIST(
        os.path.join(datadir, str(xm.get_ordinal())),
        train=True,
        download=True,
        transform=transforms.Compose(
            [transforms.ToTensor(),
             transforms.Normalize((0.1307,), (0.3081,))]))
    test_dataset = datasets.MNIST(
        os.path.join(datadir, str(xm.get_ordinal())),
        train=False,
        download=True,
        transform=transforms.Compose(
            [transforms.ToTensor(),
             transforms.Normalize((0.1307,), (0.3081,))]))
    train_sampler = None
    if xm.xrt_world_size() > 1:
      train_sampler = torch.utils.data.distributed.DistributedSampler(
          train_dataset,
          num_replicas=xm.xrt_world_size(),
          rank=xm.get_ordinal(),
          shuffle=True)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=batch_size,
        sampler=train_sampler,
        drop_last=True,
        shuffle=False if train_sampler else True,
        num_workers=8)
    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=batch_size,
        drop_last=True,
        shuffle=False,
        num_workers=8)

  # Scale the learning rate to the number of cores. Bind it to a new
  # name so we do not rebind the module-level `lr` inside this
  # function, which would raise UnboundLocalError.
  scaled_lr = lr * xm.xrt_world_size()

  device = xm.xla_device()
  model = MNIST().to(device)
  writer = None
  if xm.is_master_ordinal():
    writer = test_utils.get_summary_writer(logdir)
  optimizer = optim.SGD(model.parameters(), lr=scaled_lr, momentum=momentum)
  loss_fn = nn.NLLLoss()

  def train_loop_fn(loader):
    tracker = xm.RateTracker()
    print('here')
    model.train()
    for x, (data, target) in enumerate(loader):
      optimizer.zero_grad()
      output = model(data)
      loss = loss_fn(output, target)
      loss.backward()
      xm.optimizer_step(optimizer)
      tracker.add(batch_size)
      if x % 10 == 0:
        xm.add_step_closure(_train_update, args=(device, x, loss, tracker))

  def test_loop_fn(loader):
    total_samples = 0
    correct = 0
    model.eval()
    for data, target in loader:
      output = model(data)
      pred = output.max(1, keepdim=True)[1]
      correct += pred.eq(target.view_as(pred)).sum().item()
      total_samples += data.size()[0]

    accuracy = 100.0 * correct / total_samples
    test_utils.print_test_update(device, accuracy)
    return accuracy

  accuracy = 0.0
  max_accuracy = 0.0
  for epoch in range(1, num_epochs + 1):
    para_loader = pl.ParallelLoader(train_loader, [device])
    train_loop_fn(para_loader.per_device_loader(device))
    xm.master_print('Finished training epoch {}'.format(epoch))

    para_loader = pl.ParallelLoader(test_loader, [device])
    accuracy = test_loop_fn(para_loader.per_device_loader(device))
    max_accuracy = max(accuracy, max_accuracy)
    test_utils.write_to_summary(writer, epoch,
                                dict_to_write={'Accuracy/test': accuracy},
                                write_xla_metrics=True)
    if False:  # metrics debugging from the original test script, disabled here
      xm.master_print(met.metrics_report())

  test_utils.close_summary_writer(writer)
  xm.master_print('Max Accuracy: {:.2f}%'.format(max_accuracy))
  return max_accuracy

def _mp_fn(index, flags):
  global FLAGS
  FLAGS = flags
  torch.set_default_tensor_type('torch.FloatTensor')
  accuracy = train_mnist()
  if False and os.path.isdir(datadir):  # dataset cleanup from the original test, disabled here
    shutil.rmtree(datadir)
  if accuracy < 90:
    print('Accuracy {} is below target {}'.format(accuracy,90))
    sys.exit(21)

if __name__ == '__main__':
  FLAGS = {}
  xmp.spawn(_mp_fn, args=(FLAGS,), nprocs=8, start_method='fork')

I have only changed start_method='fork' from the original: https://github.com/pytorch/xla/blob/master/test/test_train_mp_mnist.py#L113

dlibenzi commented 4 years ago

That seems unrelated. It errors out at xm.xla_device(), which is well before any TensorBoard interaction.

level14taken commented 4 years ago

Sorry, I just restarted the notebook and the code is running just fine now. Thanks!