pytorch / xla

Enabling PyTorch on XLA Devices (e.g. Google TPU)
https://pytorch.org/xla

Status: Not found: Op type not registered 'XRTMemoryInfo' What does this relate to? #2567

Closed maitham closed 4 years ago

maitham commented 4 years ago

I'm training a custom LSTM architecture but get the following error when trying to train, and I'm not sure where to start debugging it.

PyTorch: 1.6, PyTorch Lightning: latest, XLA: 1.6.0, TF: 2.3.1
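For reference, a quick way to confirm the client-side versions (a minimal sketch, assuming each package exposes `__version__` as in these releases; note this only reports the client libraries, not the TPU runtime):

import torch
import torch_xla
import pytorch_lightning
import tensorflow as tf

# Client-side library versions only; the TPU runtime version is chosen
# separately when the TPU node is created.
print("torch:", torch.__version__)
print("torch_xla:", torch_xla.__version__)
print("pytorch_lightning:", pytorch_lightning.__version__)
print("tensorflow:", tf.__version__)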

**train.py** 

from torch.utils.data import Dataset, DataLoader, IterableDataset
import glob
import tqdm
import numpy as np
import sys
import torch
from model import RNNLitModel
from dataset import DataReader, TextDataset
from vocab import Vocab
from utils import open_list

import pytorch_lightning as pl
from pytorch_lightning import  seed_everything

if __name__ == "__main__":
    seed_everything(0)

    train_data = DataReader(
                dtype="train",
                emails="/batched_data/*emails_train.pt",
                contexts="/batched_data/*subjects_train.pt",
                lengths=np.load("train_lengths.npy").tolist())

    valid_data = DataReader(
                dtype="valid",
                emails="/batched_data/*emails_valid.pt",
                contexts="/batched_data/*subjects_valid.pt",
                lengths=np.load("valid_lengths.npy").tolist())

    test_data = DataReader(
                dtype="test",
                emails="/home/maithamdib/batched_data/*emails_test.pt",
                contexts="/home/maithamdib/batched_data/*subjects_test.pt",
                lengths=np.load("test_lengths.npy").tolist())

    vocab = Vocab(open_list("combined_vocab.pkl"))

    batch_size=64
    rnn_type = "LSTM"
    vocab_size = len(vocab.itos)
    nhid = 1024
    emb_size = 256
    nlayers = 2
    bptt=60

    train_loader = DataLoader(TextDataset(80, train_data), num_workers=1, batch_size=None, pin_memory=True)
    valid_loader = DataLoader(TextDataset(80, valid_data), num_workers=1, batch_size=None, pin_memory=True)
    test_loader = DataLoader(TextDataset(80, test_data), num_workers=1, batch_size=None, pin_memory=True)

    model = RNNLitModel(0.1, rnn_type, vocab_size, emb_size, nhid, nlayers=2, dropout=0.5)

    trainer = pl.Trainer(val_check_interval=0.25, 
                     max_epochs=5,
                     gradient_clip_val=0.5, 
                     tpu_cores=8, 
                     auto_lr_find=True,
                     automatic_optimization=True,  
                     progress_bar_refresh_rate=20
                    )
    trainer.tune(model, train_loader)

    trainer.fit(model, train_loader, valid_loader)

**Error Log**
GPU available: False, used: False
TPU available: True, using: 1 TPU cores
training on 1 TPU cores
INIT TPU local core: 0, global rank: 0 with XLA_USE_BF16=None
2020-10-21 19:19:10.923243: E    9922 tensorflow/compiler/xla/xla_client/xla_util.cc:76] >>> Dumping Computation 0
2020-10-21 19:19:10.923308: E    9922 tensorflow/compiler/xla/xla_client/xla_util.cc:76] HloModule SyncTensorsGraph.10
2020-10-21 19:19:10.923316: E    9922 tensorflow/compiler/xla/xla_client/xla_util.cc:76]
2020-10-21 19:19:10.923322: E    9922 tensorflow/compiler/xla/xla_client/xla_util.cc:76] ENTRY %SyncTensorsGraph.10 (p0.1: f32[431]) -> (f32[431]) {
2020-10-21 19:19:10.923328: E    9922 tensorflow/compiler/xla/xla_client/xla_util.cc:76]   %constant.2 = f32[] constant(0)
2020-10-21 19:19:10.923334: E    9922 tensorflow/compiler/xla/xla_client/xla_util.cc:76]   %reshape.3 = f32[1]{0} reshape(f32[] %constant.2)
2020-10-21 19:19:10.923340: E    9922 tensorflow/compiler/xla/xla_client/xla_util.cc:76]   %broadcast.4 = f32[1]{0} broadcast(f32[1]{0} %reshape.3), dimensions={0}
2020-10-21 19:19:10.923346: E    9922 tensorflow/compiler/xla/xla_client/xla_util.cc:76]   %reshape.5 = f32[] reshape(f32[1]{0} %broadcast.4)
2020-10-21 19:19:10.923368: E    9922 tensorflow/compiler/xla/xla_client/xla_util.cc:76]   %broadcast.6 = f32[431]{0} broadcast(f32[] %reshape.5), dimensions={}
2020-10-21 19:19:10.923396: E    9922 tensorflow/compiler/xla/xla_client/xla_util.cc:76]   %p0.1 = f32[431]{0} parameter(0)
2020-10-21 19:19:10.923406: E    9922 tensorflow/compiler/xla/xla_client/xla_util.cc:76]   %constant.7 = f32[] constant(0)
2020-10-21 19:19:10.923417: E    9922 tensorflow/compiler/xla/xla_client/xla_util.cc:76]   %pad.8 = f32[431]{0} pad(f32[431]{0} %p0.1, f32[] %constant.7), padding=0_0
2020-10-21 19:19:10.923435: E    9922 tensorflow/compiler/xla/xla_client/xla_util.cc:76]   ROOT %tuple.9 = (f32[431]{0}) tuple(f32[431]{0} %pad.8)
2020-10-21 19:19:10.923444: E    9922 tensorflow/compiler/xla/xla_client/xla_util.cc:76] }
2020-10-21 19:19:10.923456: E    9922 tensorflow/compiler/xla/xla_client/xla_util.cc:76]
2020-10-21 19:19:10.923468: E    9922 tensorflow/compiler/xla/xla_client/xla_util.cc:76]
2020-10-21 19:19:10.923482: E    9922 tensorflow/compiler/xla/xla_client/xla_util.cc:76] OutputShape: (f32[431]{0})
2020-10-21 19:19:10.923494: E    9922 tensorflow/compiler/xla/xla_client/xla_util.cc:76]
2020-10-21 19:19:10.923505: E    9922 tensorflow/compiler/xla/xla_client/xla_util.cc:76] StackTrace:
2020-10-21 19:19:10.923517: E    9922 tensorflow/compiler/xla/xla_client/xla_util.cc:76] *** Begin stack trace ***
2020-10-21 19:19:10.923530: E    9922 tensorflow/compiler/xla/xla_client/xla_util.cc:76]    tensorflow::CurrentStackTrace()
2020-10-21 19:19:10.923546: E    9922 tensorflow/compiler/xla/xla_client/xla_util.cc:76]    xla::util::ReportComputationError(tensorflow::Status const&, absl::lts_2020_02_25::Span<xla::XlaComputation const* const>, absl::lts_2020_02_25::Span<xla::Shape const* const>)
2020-10-21 19:19:10.923563: E    9922 tensorflow/compiler/xla/xla_client/xla_util.cc:76]    xla::XrtComputationClient::CheckCompileStatus(tensorflow::Status const&, std::vector<xla::ComputationClient::CompileInstance, std::allocator<xla::ComputationClient::CompileInstance> > const&, xla::XrtComputationClient::SessionWork const&)
2020-10-21 19:19:10.923580: E    9922 tensorflow/compiler/xla/xla_client/xla_util.cc:76]
2020-10-21 19:19:10.923595: E    9922 tensorflow/compiler/xla/xla_client/xla_util.cc:76]    xla::util::MultiWait::Complete(std::function<void ()> const&)
2020-10-21 19:19:10.923609: E    9922 tensorflow/compiler/xla/xla_client/xla_util.cc:76]
2020-10-21 19:19:10.923624: E    9922 tensorflow/compiler/xla/xla_client/xla_util.cc:76]
2020-10-21 19:19:10.923639: E    9922 tensorflow/compiler/xla/xla_client/xla_util.cc:76]
2020-10-21 19:19:10.923653: E    9922 tensorflow/compiler/xla/xla_client/xla_util.cc:76]    clone
2020-10-21 19:19:10.923668: E    9922 tensorflow/compiler/xla/xla_client/xla_util.cc:76] *** End stack trace ***
2020-10-21 19:19:10.923683: E    9922 tensorflow/compiler/xla/xla_client/xla_util.cc:76]
2020-10-21 19:19:10.923698: E    9922 tensorflow/compiler/xla/xla_client/xla_util.cc:76] Status: Not found: Op type not registered 'XRTMemoryInfo' in binary running on n-29d6244c-w-0. Make sure the Op and Kernel are registered in the binary running in this process. Note that if you are loading a saved graph which used ops from tf.contrib, accessing (e.g.) `tf.contrib.resampler` should be done before importing the graph, as contrib ops are lazily registered when the module is first accessed.
Traceback (most recent call last):
  File "train.py", line 63, in <module>
    trainer.fit(model, train_loader, valid_loader)
  File "/anaconda3/envs/torch-xla-1.6/lib/python3.6/site-packages/pytorch_lightning/trainer/trainer.py", line 439, in fit
    results = self.accelerator_backend.train()
  File "/anaconda3/envs/torch-xla-1.6/lib/python3.6/site-packages/pytorch_lightning/accelerators/tpu_accelerator.py", line 98, in train
    start_method=self.start_method
  File "/anaconda3/envs/torch-xla-1.6/lib/python3.6/site-packages/torch_xla/distributed/xla_multiprocessing.py", line 387, in spawn
    _start_fn(0, pf_cfg, fn, args)
  File "/anaconda3/envs/torch-xla-1.6/lib/python3.6/site-packages/torch_xla/distributed/xla_multiprocessing.py", line 324, in _start_fn
    fn(gindex, *args)
  File "/anaconda3/envs/torch-xla-1.6/lib/python3.6/site-packages/pytorch_lightning/accelerators/tpu_accelerator.py", line 123, in tpu_train_in_process
    self.trainer.train_loop.setup_training(model)
  File "/anaconda3/envs/torch-xla-1.6/lib/python3.6/site-packages/pytorch_lightning/trainer/training_loop.py", line 146, in setup_training
    self.trainer.on_pretrain_routine_start(ref_model)
  File "/anaconda3/envs/torch-xla-1.6/lib/python3.6/site-packages/pytorch_lightning/trainer/callback_hook.py", line 122, in on_pretrain_routine_start
    callback.on_pretrain_routine_start(self, model)
  File "/anaconda3/envs/torch-xla-1.6/lib/python3.6/site-packages/pytorch_lightning/callbacks/model_checkpoint.py", line 161, in on_pretrain_routine_start
    self.__resolve_ckpt_dir(trainer, pl_module)
  File "/anaconda3/envs/torch-xla-1.6/lib/python3.6/site-packages/pytorch_lightning/callbacks/model_checkpoint.py", line 405, in __resolve_ckpt_dir
    version, name = trainer.accelerator_backend.broadcast((version, trainer.logger.name))
  File "/anaconda3/envs/torch-xla-1.6/lib/python3.6/site-packages/pytorch_lightning/accelerators/tpu_accelerator.py", line 330, in broadcast
    buffer = io.BytesIO(data.cpu().byte().numpy())
RuntimeError: Not found: Op type not registered 'XRTMemoryInfo' in binary running on n-29d6244c-w-0. Make sure the Op and Kernel are registered in the binary running in this process. Note that if you are loading a saved graph which used ops from tf.contrib, accessing (e.g.) `tf.contrib.resampler` should be done before importing the graph, as contrib ops are lazily registered when the module is first accessed.
terminate called after throwing an instance of 'std::runtime_error'
  what():  tensorflow/compiler/xla/xla_client/xrt_computation_client.cc:1110 : Check failed: session->session()->Run( feed_inputs, {}, {cached_node.operations[0]}, &outputs) == ::tensorflow::Status::OK() (Invalid argument: A session is not created yet.... vs. OK)
*** Begin stack trace ***
    tensorflow::CurrentStackTrace()
    xla::XrtComputationClient::ReleaseHandles(std::vector<xla::XrtComputationClient::DeviceHandle, std::allocator<xla::XrtComputationClient::DeviceHandle> >*, std::function<xla::XrtSession::CachedNode const& (xla::XrtSession*, tensorflow::Scope const&, std::string const&)> const&, xla::metrics::Metric*, xla::metrics::Counter*)
    xla::XrtComputationClient::HandleReleaser()
    xla::util::TriggeredTask::Runner()

    clone
*** End stack trace ***

Aborted
JackCaoG commented 4 years ago

Hi, this is most likely due to the TPU runtime version being too old. Did you select the TPU version to be `--version=pytorch-1.6`?
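Once the TPU is recreated with the `pytorch-1.6` runtime, a quick sanity check that the client can reach it (a minimal sketch, not from the original report, assuming the standard torch_xla 1.6 API):

import torch
import torch_xla.core.xla_model as xm

# Acquire the default XLA device and force a trivial computation on it.
# With a matching pytorch-1.6 TPU runtime this should print the device and 4.0
# instead of raising the 'XRTMemoryInfo' not-registered error.
device = xm.xla_device()
t = torch.ones(2, 2, device=device)
print(device, t.sum().item())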

maitham commented 4 years ago

@JackCaoG Yep, I was using the TF 2.3.1 TPU runtime when I should have been using pytorch-1.6. Thanks!