jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

TPU deadlock #4153

Closed agemagician closed 1 year ago

agemagician commented 4 years ago

Hello,

I am trying to train a Reformer model using Trax and JAX. Training fails on Google Colab because of memory limitations. When I run it on a Google Cloud server + TPU, it hangs on the "trax.supervised.Trainer".

The warning is as follows:

2020-08-26 17:46:37.421334: W external/org_tensorflow/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.cc:601] TPU Execute is taking a long time. This might be due to a deadlock between multiple TPU cores or a very slow program.

The code is very straightforward:

import requests
import os

from jax.config import config
config.FLAGS.jax_xla_backend = "tpu_driver"
config.FLAGS.jax_backend_target = "grpc://" + "10.206.164.18"
print(config.FLAGS.jax_backend_target)

from tensorflow.compat.v1.io.gfile import GFile
import gin
import os
import jax
import trax
from trax.data import inputs

import numpy as np
import jax.numpy as jnp

from scipy.special import softmax

import sentencepiece as spm
from sentencepiece import SentencePieceProcessor

import random, glob, os

def fake_data():
  with open("vocab.txt",'w') as f:
    f.write("[MASK]\nL\nA\nG\nV\nE\nS\nI\nK\nR\nD\nT\nP\nN\nQ\nF\nY\nM\nH\nC\nW\nX\nU\nB\nZ\nO")

  if not os.path.exists('dataset'):
    os.makedirs('dataset')

  with open("dataset/train_0.txt",'w') as f:
    for i in range(50):
      f.write("M A F S A E D V L K E Y D R R R R M E A L L L S L Y Y P N D R K L L D Y K E W S P P R V Q V E C P K A P V E W N N P P S E K G L I V G H F S G I K Y K G E K A Q A S E V D V N K M C C W V S K F K D A M R R Y Q G I Q T C K I P G K V L S D L D A K I K A Y N L T V E G V E G F V R Y S R V T K Q H V A A F L K E L R H S K Q Y E N V N L I H Y I L T D K R V D I Q H L E K D L V K D F K A L V E S A H R M R Q G H M I N V K Y I L Y Q L L K K H G H G P D G P D I L T V K T G S K G V L Y D D S F R K I Y T D L G W K F T P L\n")
      f.write("M S I I G A T R L Q N D K S D T Y S A G P C Y A G G C S A F T P R G T C G K D W D L G E Q T C A S G F C T S Q P L C A R I K K T Q V C G L R Y S S K G K D P L V S A E W D S R G A P Y V R C T Y D A D L I D T Q A Q V D Q F V S M F G E S P S L A E R Y C M R G V K N T A G E L V S R V S S D A D P A G G W C R K W Y S A H R G P D Q D A A L G S F C I K N P G A A D C K C I N R A S D P V Y Q K V K T L H A Y P D Q C W Y V P C A A D V G E L K M G T Q R D T P T N C P T Q V C Q I V F N M L D D G S V T M D D V K N T I N C D F S K Y V P P P P P P K P T P P T P P T P P T P P T P P T P P T P P T P R P V H N R K V M F F V A G A V L V A I L I S T V R W\n")
      f.write("M A S N T V S A Q G G S N R P V R D F S N I Q D V A Q F L L F D P I W N E Q P G S I V P W K M N R E Q A L A E R Y P E L Q T S E P S E D Y S G P V E S L E L L P L E I K L D I M Q Y L S W E Q I S W C K H P W L W T R W Y K D N V V R V S A I T F E D F Q R E Y A F P E K I Q E I H F T D T R A E E I K A I L E T T P N V T R L V I R R I D D M N Y N T H G D L G L D D L E F L T H L M V E D A C G F T D F W A P S L T H L T I K N L D M H P R W F G P V M D G I K S M Q S T L K Y L Y I F E T Y G V N K P F V Q W C T D N I E T F Y C T N S Y R Y E N V P R P I Y V W V L F Q E D E W H G Y R V E D N K F H R R Y M Y S T I L H K R D T D W V E N N P L K T P A Q V E M Y K F L L R I S Q L N R D G T G Y E S D S D P E N E H F D D E S F S S G E E D S S D E D D P T W A P D S D D S D W E T E T E E E P S V A A R I L E K G K L T I T N L M K S L G F K P K P K K I Q S I D R Y F C S L D S N Y N S E D E D F E Y D S D S E D D D S D S E D D C\n")
      f.write("M Y Q A I N P C P Q S W Y G S P Q L E R E I V C K M S G A P H Y P N Y Y P V H P N A L G G A W F D T S L N A R S L T T T P S L T T C T P P S L A A C T P P T S L G M V D S P P H I N P P R R I G T L C F D F G S A K S P Q R C E C V A S D R P S T T S N T A P D T Y R L L I T N S K T R K N N Y G T C R L E P L T Y G I\n")
      f.write("M A R P L L G K T S S V R R R L E S L S A C S I F F F L R K F C Q K M A S L V F L N S P V Y Q M S N I L L T E R R Q V D R A M G G S D D D G V M V V A L S P S D F K T V L G S A L L A V E R D M V H V V P K Y L Q T P G I L H D M L V L L T P I F G E A L S V D M S G A T D V M V Q Q I A T A G F V D V D P L H S S V S W K D N V S C P V A L L A V S N A V R T M M G Q P C Q V T L I I D V G T Q N I L R D L V N L P V E M S G D L Q V M A Y T K D P L G K V P A V G V S V F D S G S V Q K G D A H S V G A P D G L V S F H T H P V S S A V E L N Y H A G W P S N V D M S S L L T M K N L M H V V V A E E G L W T M A R T L S M Q R L T K V L T D A E K D V M R A A A F N L F L P L N E L R V M G T K D S N N K S L K T Y F E V F E T F T I G A L M K H S G V T P T A F V D R R W L D N T I Y H M G F I P W G R D M R F V V E Y D L D G T N P F L N T V P T L M S V K R K A K I Q E M F D N M V S R M V T S\n")
      f.write("M N A K Y D T D Q G V G R M L F L G T I G L A V V V G G L M A Y G Y Y Y D G K T P S S G T S F H T A S P S F S S R Y R Y\n")
      f.write("M R Y T V L I A L Q G A L L L L L L I D D G Q G Q S P Y P Y P G M P C N S S R Q C G L G T C V H S R C A H C S S D G T L C S P E D P T M V W P C C P E S S C Q L V V G L P S L V N H Y N C L P N Q C T D S S Q C P G G F G C M T R R S K C E L C K A D G E A C N S P Y L D W R K D K E C C S G Y C H T E A R G L E G V C I D P K K I F C T P K N P W Q L A P Y P P S Y H Q P T T L R P P T S L Y D S W L M S G F L V K S T T A P S T Q E E E D D Y\n")
      f.write("M Q N P L P E V M S P E H D K R T T T P M S K E A N K F I R E L D K K P G D L A V V S D F V K R N T G K R L P I G K R S N L Y V R I C D L S G T I Y M G E T F I L E S W E E L Y L P E P T K M E V L G T L E S C C G I P P F P E W I V M V G E D Q C V Y A Y G D E E I L L F A Y S V K Q L V E E G I Q E T G I S Y K Y P D D I S D V D E E V L Q Q D E E I Q K I R K K T R E F V D K D A Q E F Q D F L N S L D A S L L S\n")
      f.write("M D S L N E V C Y E Q I K G T F Y K G L F G D F P L I V D K K T G C F N A T K L C V L G G K R F V D W N K T L R S K K L I Q Y Y E T R C D I K T E S L L Y E I K G D N N D E I T K Q I T G T Y L P K E F I L D I A S W I S V E F Y D K C N N I I I N Y F V N E Y K T M D K K T L Q S K I N E V E E K M Q K L L N E K E E E L Q E K N D K I D E L I L F S K R M E E D R K K D R E M M I K Q E K M L R E L G I H L E D V S S Q N N E L I E K V D E Q V E Q N A V L N F K I D N I Q N K L E I A V E D R A P Q P K Q N L K R E R F I L L K R N D D Y Y P Y Y T I R A Q D I N A R S A L K R Q K N L Y N E V S V L L D L T C H P N S K T L Y V R V K D E L K Q K G V V F N L C K V S I S N S K I N E E E L I K A M E T I N D E K R D V\n") 

    with open("dataset/train_1.txt",'w') as f:
      for i in range(50):
        f.write("M A F S A E D V L K E Y D R R R R M E A L L L S L Y Y P N D R K L L D Y K E W S P P R V Q V E C P K A P V E W N N P P S E K G L I V G H F S G I K Y K G E K A Q A S E V D V N K M C C W V S K F K D A M R R Y Q G I Q T C K I P G K V L S D L D A K I K A Y N L T V E G V E G F V R Y S R V T K Q H V A A F L K E L R H S K Q Y E N V N L I H Y I L T D K R V D I Q H L E K D L V K D F K A L V E S A H R M R Q G H M I N V K Y I L Y Q L L K K H G H G P D G P D I L T V K T G S K G V L Y D D S F R K I Y T D L G W K F T P L\n")
        f.write("M S I I G A T R L Q N D K S D T Y S A G P C Y A G G C S A F T P R G T C G K D W D L G E Q T C A S G F C T S Q P L C A R I K K T Q V C G L R Y S S K G K D P L V S A E W D S R G A P Y V R C T Y D A D L I D T Q A Q V D Q F V S M F G E S P S L A E R Y C M R G V K N T A G E L V S R V S S D A D P A G G W C R K W Y S A H R G P D Q D A A L G S F C I K N P G A A D C K C I N R A S D P V Y Q K V K T L H A Y P D Q C W Y V P C A A D V G E L K M G T Q R D T P T N C P T Q V C Q I V F N M L D D G S V T M D D V K N T I N C D F S K Y V P P P P P P K P T P P T P P T P P T P P T P P T P P T P P T P R P V H N R K V M F F V A G A V L V A I L I S T V R W\n")
        f.write("M A S N T V S A Q G G S N R P V R D F S N I Q D V A Q F L L F D P I W N E Q P G S I V P W K M N R E Q A L A E R Y P E L Q T S E P S E D Y S G P V E S L E L L P L E I K L D I M Q Y L S W E Q I S W C K H P W L W T R W Y K D N V V R V S A I T F E D F Q R E Y A F P E K I Q E I H F T D T R A E E I K A I L E T T P N V T R L V I R R I D D M N Y N T H G D L G L D D L E F L T H L M V E D A C G F T D F W A P S L T H L T I K N L D M H P R W F G P V M D G I K S M Q S T L K Y L Y I F E T Y G V N K P F V Q W C T D N I E T F Y C T N S Y R Y E N V P R P I Y V W V L F Q E D E W H G Y R V E D N K F H R R Y M Y S T I L H K R D T D W V E N N P L K T P A Q V E M Y K F L L R I S Q L N R D G T G Y E S D S D P E N E H F D D E S F S S G E E D S S D E D D P T W A P D S D D S D W E T E T E E E P S V A A R I L E K G K L T I T N L M K S L G F K P K P K K I Q S I D R Y F C S L D S N Y N S E D E D F E Y D S D S E D D D S D S E D D C\n")
        f.write("M Y Q A I N P C P Q S W Y G S P Q L E R E I V C K M S G A P H Y P N Y Y P V H P N A L G G A W F D T S L N A R S L T T T P S L T T C T P P S L A A C T P P T S L G M V D S P P H I N P P R R I G T L C F D F G S A K S P Q R C E C V A S D R P S T T S N T A P D T Y R L L I T N S K T R K N N Y G T C R L E P L T Y G I\n")
        f.write("M A R P L L G K T S S V R R R L E S L S A C S I F F F L R K F C Q K M A S L V F L N S P V Y Q M S N I L L T E R R Q V D R A M G G S D D D G V M V V A L S P S D F K T V L G S A L L A V E R D M V H V V P K Y L Q T P G I L H D M L V L L T P I F G E A L S V D M S G A T D V M V Q Q I A T A G F V D V D P L H S S V S W K D N V S C P V A L L A V S N A V R T M M G Q P C Q V T L I I D V G T Q N I L R D L V N L P V E M S G D L Q V M A Y T K D P L G K V P A V G V S V F D S G S V Q K G D A H S V G A P D G L V S F H T H P V S S A V E L N Y H A G W P S N V D M S S L L T M K N L M H V V V A E E G L W T M A R T L S M Q R L T K V L T D A E K D V M R A A A F N L F L P L N E L R V M G T K D S N N K S L K T Y F E V F E T F T I G A L M K H S G V T P T A F V D R R W L D N T I Y H M G F I P W G R D M R F V V E Y D L D G T N P F L N T V P T L M S V K R K A K I Q E M F D N M V S R M V T S\n")
        f.write("M N A K Y D T D Q G V G R M L F L G T I G L A V V V G G L M A Y G Y Y Y D G K T P S S G T S F H T A S P S F S S R Y R Y\n")
        f.write("M R Y T V L I A L Q G A L L L L L L I D D G Q G Q S P Y P Y P G M P C N S S R Q C G L G T C V H S R C A H C S S D G T L C S P E D P T M V W P C C P E S S C Q L V V G L P S L V N H Y N C L P N Q C T D S S Q C P G G F G C M T R R S K C E L C K A D G E A C N S P Y L D W R K D K E C C S G Y C H T E A R G L E G V C I D P K K I F C T P K N P W Q L A P Y P P S Y H Q P T T L R P P T S L Y D S W L M S G F L V K S T T A P S T Q E E E D D Y\n")
        f.write("M Q N P L P E V M S P E H D K R T T T P M S K E A N K F I R E L D K K P G D L A V V S D F V K R N T G K R L P I G K R S N L Y V R I C D L S G T I Y M G E T F I L E S W E E L Y L P E P T K M E V L G T L E S C C G I P P F P E W I V M V G E D Q C V Y A Y G D E E I L L F A Y S V K Q L V E E G I Q E T G I S Y K Y P D D I S D V D E E V L Q Q D E E I Q K I R K K T R E F V D K D A Q E F Q D F L N S L D A S L L S\n")
        f.write("M D S L N E V C Y E Q I K G T F Y K G L F G D F P L I V D K K T G C F N A T K L C V L G G K R F V D W N K T L R S K K L I Q Y Y E T R C D I K T E S L L Y E I K G D N N D E I T K Q I T G T Y L P K E F I L D I A S W I S V E F Y D K C N N I I I N Y F V N E Y K T M D K K T L Q S K I N E V E E K M Q K L L N E K E E E L Q E K N D K I D E L I L F S K R M E E D R K K D R E M M I K Q E K M L R E L G I H L E D V S S Q N N E L I E K V D E Q V E Q N A V L N F K I D N I Q N K L E I A V E D R A P Q P K Q N L K R E R F I L L K R N D D Y Y P Y Y T I R A Q D I N A R S A L K R Q K N L Y N E V S V L L D L T C H P N S K T L Y V R V K D E L K Q K G V V F N L C K V S I S N S K I N E E E L I K A M E T I N D E K R D V\n") 

fake_data()

spm.SentencePieceTrainer.train(input='vocab.txt', 
                               model_prefix='protein', 
                               vocab_size=30, 
                               model_type="word", 
                               #user_defined_symbols="<MASK>",
                               pad_id=0,
                               unk_id=1,
                               bos_id=2,
                               eos_id=3,
                               pad_piece="[PAD]",
                               unk_piece="[UNK]",
                               bos_piece="[BOS]",
                               eos_piece="[EOS]")

tokenizer = spm.SentencePieceProcessor(model_file='protein.model')

train_files = glob.glob("dataset/train*",recursive=True)
random.shuffle(train_files)

def mask_seq(seq,mask_prob=0.15):
  seq = np.array(seq)
  minValue = 1
  maxValue = len(seq) - 2
  max_mask_tokens = int(maxValue * mask_prob + 0.5)
  randomlist = random.sample(range(minValue, maxValue), max_mask_tokens)
  seq_masked = seq
  seq_masked[randomlist] = tokenizer.encode("[MASK]")[0]
  return seq_masked

def get_seq(train_files):
  while True:
    for file in train_files:
      with open(file) as fp: 
        for line in fp:
          yield line

def get_batch(seq_gen, batch_length):

  batch = []
  while True:
    seq = next(seq_gen)
    seq_ids = tokenizer.encode(seq,add_bos=True,add_eos=True)

    new_batch_len = len(batch) + len(seq_ids)

    if new_batch_len <= batch_length :
      batch = batch + seq_ids
      continue

    next_batch = batch
    batch = seq_ids

    yield next_batch

# Set up the data pipeline.
def my_inputs(n_devices):

  MAX_BATCH_LENGTH = 1024*4

  seq_gen = get_seq(train_files)
  batch_gen = get_batch(seq_gen,MAX_BATCH_LENGTH)

  while True:

    inputs = []
    targets = []
    mask = []
    for i in range(n_devices):
      batch_ids = next(batch_gen)
      masked_seq_ids = mask_seq(batch_ids)
      pad_amount = MAX_BATCH_LENGTH - len(batch_ids)

      inputs.append(np.pad(masked_seq_ids, (0,pad_amount)))
      targets.append(np.pad(batch_ids, (0,pad_amount)))
      mask.append(np.pad(np.ones_like(batch_ids, dtype=np.float32),
                          (0,pad_amount),
                          mode='constant'))
    inputs = np.stack(inputs)
    targets = np.stack(targets)
    mask = np.stack(mask)
    yield (inputs, targets, mask)

inp_gen_test = my_inputs(trax.fastmath.device_count())

res = next(inp_gen_test)

print(tokenizer.decode(res[0][0].tolist()))

print(tokenizer.decode(res[1][0].tolist()))

# Configure hyperparameters.
gin.parse_config("""
import trax.layers
import trax.models
import trax.optimizers
import trax.data.inputs
import trax.supervised.trainer_lib

# Parameters that will vary between experiments:
# ==============================================================================
train.model = @trax.models.Reformer
n_layers = 15
n_heads = 16
dropout = 0.1
n_tokens = 40000 # They have used very small n_tokens = 2048
vocab_size= 30
d_model = 1024
d_ff = 4096

# Done
# Parameters for MultifactorSchedule:
# ==============================================================================
multifactor.constant = 0.088
multifactor.decay_factor = 0.5
multifactor.factors = 'constant * linear_warmup * rsqrt_decay'
multifactor.steps_per_cycle = 100000
multifactor.steps_per_decay = 20000
multifactor.warmup_steps = 8000

# Done
# Parameters for Adam:
# ==============================================================================
Adam.b1 = 0.9
Adam.b2 = 0.98
Adam.eps = 1e-09
Adam.weight_decay_rate = 1e-05

# Done
# Parameters for SelfAttention:
# ==============================================================================
#trax.layers.SelfAttention.attention_dropout = 0.05
#trax.layers.SelfAttention.chunk_len = 64
#trax.layers.SelfAttention.n_chunks_before = 1
#trax.layers.SelfAttention.n_parallel_heads = 1
trax.layers.SelfAttention.causal = False
trax.layers.SelfAttention.chunk_len = None
trax.layers.SelfAttention.masked = False
trax.layers.SelfAttention.n_chunks_after = 0
trax.layers.SelfAttention.n_chunks_before = 0
trax.layers.SelfAttention.n_parallel_heads = None
trax.layers.SelfAttention.predict_drop_len = 64
trax.layers.SelfAttention.predict_mem_len = 192
trax.layers.SelfAttention.share_qk = False
trax.layers.SelfAttention.use_python_loop = False
trax.layers.SelfAttention.use_reference_code = False

# Done
# Parameters for EncDecAttention:
# ==============================================================================
trax.layers.EncDecAttention.masked = True
trax.layers.EncDecAttention.n_parallel_heads = None
trax.layers.EncDecAttention.use_python_loop = False
trax.layers.EncDecAttention.use_reference_code = False

# Done
# Parameters for LSHSelfAttention:
# ==============================================================================
#LSHSelfAttention.attention_dropout = 0.0
#LSHSelfAttention.chunk_len = 64
#LSHSelfAttention.n_buckets = [64, 128]
#LSHSelfAttention.n_chunks_after = 0
#LSHSelfAttention.n_chunks_before = 1
#LSHSelfAttention.n_hashes = 1
#LSHSelfAttention.n_parallel_heads = 1
#LSHSelfAttention.predict_drop_len = 128
#LSHSelfAttention.predict_mem_len = 1024

# Done
# Parameters for Reformer:
# ==============================================================================
Reformer.d_model = %d_model
Reformer.d_ff = %d_ff
Reformer.dropout = %dropout
Reformer.ff_activation = @trax.layers.Relu
Reformer.max_len = %n_tokens
Reformer.mode = 'train'
Reformer.n_heads = %n_heads
Reformer.n_encoder_layers = %n_layers
Reformer.n_decoder_layers = %n_layers
Reformer.input_vocab_size = %vocab_size

""")

# Set up a Trainer.
output_dir = os.path.expanduser('train_dir/')

trainer = trax.supervised.Trainer(
    model=trax.models.Reformer,
    loss_fn=trax.layers.CrossEntropyLoss(),
    optimizer=trax.optimizers.Adam,
    lr_schedule=trax.lr.multifactor(),
    inputs=trax.data.inputs.Inputs(my_inputs),
    output_dir=output_dir)

# Run one training step, to make sure the model fits in memory.
# The first time trainer.train_epoch is called, it will JIT the entire network
# architecture, which takes around 2 minutes. The JIT-compiled model is saved
# so subsequent runs will be much faster than the first.
trainer.train_epoch(n_steps=1, n_eval_steps=1)

Any idea how I can solve this issue?

agemagician commented 4 years ago

I have tested it on a Google Cloud machine with 8x V100 GPUs, and it gives a similar issue:

2020-08-27 18:34:07.848853: E external/org_tensorflow/tensorflow/compiler/xla/service/slow_operation_alarm.cc:55]
********************************
Very slow compile?  If you want to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
********************************
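The message above asks for an XLA dump. A minimal sketch of setting that flag from Python follows. It assumes the variable has to be set before `jax` is first imported in the process, and the helper name `add_xla_dump_flag` is hypothetical; the `/tmp/foo` path is the one suggested in the message.

```python
import os

def add_xla_dump_flag(dump_dir, env=None):
    """Append --xla_dump_to=<dump_dir> to XLA_FLAGS, keeping any existing flags."""
    env = os.environ if env is None else env
    flag = "--xla_dump_to=" + dump_dir
    existing = env.get("XLA_FLAGS", "")
    # Only append the flag if it is not already present.
    if flag not in existing.split():
        env["XLA_FLAGS"] = (existing + " " + flag).strip()
    return env["XLA_FLAGS"]

# Assumption: call this before the first `import jax`, so the flag is
# already in the environment when the XLA backend starts up.
add_xla_dump_flag("/tmp/foo")
```

Appending rather than overwriting keeps any flags that were already exported in the shell.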

agemagician commented 4 years ago

For the GPU problem, I found out that JAX doesn't detect the GPU because of a CUDA 11 problem with TensorFlow.

Still, I can't figure out what the problem with the TPU is.

jekbradbury commented 4 years ago

I think it may just be taking a very long time to compile. (And potentially the tpu_driver has a time-out that's too short here).

agemagician commented 4 years ago

Thanks @jekbradbury for your reply. Is there a workaround for it?

I usually train much bigger models with TensorFlow and PyTorch without any issues. Is this a problem related to JAX?

agemagician commented 4 years ago

I have fixed the TensorFlow issue, but JAX still can't see the GPUs:

Python 3.6.10 |Anaconda, Inc.| (default, May  8 2020, 02:54:21) 
[GCC 7.3.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import tensorflow as tf
2020-08-27 20:21:00.806963: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcudart.so.10.1
>>> physical_devices = tf.config.list_physical_devices('GPU')
2020-08-27 20:21:04.095189: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcuda.so.1
2020-08-27 20:21:04.548746: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:982] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2020-08-27 20:21:04.550269: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1716] Found device 0 with properties: 
pciBusID: 0000:00:04.0 name: Tesla V100-SXM2-16GB computeCapability: 7.0
coreClock: 1.53GHz coreCount: 80 deviceMemorySize: 15.78GiB deviceMemoryBandwidth: 836.37GiB/s
2020-08-27 20:21:04.550419: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:982] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2020-08-27 20:21:04.551886: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1716] Found device 1 with properties: 
pciBusID: 0000:00:05.0 name: Tesla V100-SXM2-16GB computeCapability: 7.0
coreClock: 1.53GHz coreCount: 80 deviceMemorySize: 15.78GiB deviceMemoryBandwidth: 836.37GiB/s
2020-08-27 20:21:04.551995: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:982] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2020-08-27 20:21:04.553494: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1716] Found device 2 with properties: 
pciBusID: 0000:00:06.0 name: Tesla V100-SXM2-16GB computeCapability: 7.0
coreClock: 1.53GHz coreCount: 80 deviceMemorySize: 15.78GiB deviceMemoryBandwidth: 836.37GiB/s
2020-08-27 20:21:04.553588: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:982] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2020-08-27 20:21:04.555113: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1716] Found device 3 with properties: 
pciBusID: 0000:00:07.0 name: Tesla V100-SXM2-16GB computeCapability: 7.0
coreClock: 1.53GHz coreCount: 80 deviceMemorySize: 15.78GiB deviceMemoryBandwidth: 836.37GiB/s
2020-08-27 20:21:04.555202: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:982] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2020-08-27 20:21:04.556676: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1716] Found device 4 with properties: 
pciBusID: 0000:00:08.0 name: Tesla V100-SXM2-16GB computeCapability: 7.0
coreClock: 1.53GHz coreCount: 80 deviceMemorySize: 15.78GiB deviceMemoryBandwidth: 836.37GiB/s
2020-08-27 20:21:04.556769: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:982] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2020-08-27 20:21:04.558275: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1716] Found device 5 with properties: 
pciBusID: 0000:00:09.0 name: Tesla V100-SXM2-16GB computeCapability: 7.0
coreClock: 1.53GHz coreCount: 80 deviceMemorySize: 15.78GiB deviceMemoryBandwidth: 836.37GiB/s
2020-08-27 20:21:04.558362: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:982] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2020-08-27 20:21:04.559834: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1716] Found device 6 with properties: 
pciBusID: 0000:00:0a.0 name: Tesla V100-SXM2-16GB computeCapability: 7.0
coreClock: 1.53GHz coreCount: 80 deviceMemorySize: 15.78GiB deviceMemoryBandwidth: 836.37GiB/s
2020-08-27 20:21:04.559919: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:982] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2020-08-27 20:21:04.561362: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1716] Found device 7 with properties: 
pciBusID: 0000:00:0b.0 name: Tesla V100-SXM2-16GB computeCapability: 7.0
coreClock: 1.53GHz coreCount: 80 deviceMemorySize: 15.78GiB deviceMemoryBandwidth: 836.37GiB/s
2020-08-27 20:21:04.561405: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcudart.so.10.1
2020-08-27 20:21:04.563455: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcublas.so.10
2020-08-27 20:21:04.565374: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcufft.so.10
2020-08-27 20:21:04.565677: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcurand.so.10
2020-08-27 20:21:04.567552: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcusolver.so.10
2020-08-27 20:21:04.568606: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcusparse.so.10
2020-08-27 20:21:04.572985: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcudnn.so.7
2020-08-27 20:21:04.573151: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:982] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2020-08-27 20:21:04.574731: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:982] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2020-08-27 20:21:04.576168: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:982] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2020-08-27 20:21:04.577606: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:982] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2020-08-27 20:21:04.579063: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:982] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2020-08-27 20:21:04.580476: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:982] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2020-08-27 20:21:04.582068: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:982] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2020-08-27 20:21:04.583612: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:982] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2020-08-27 20:21:04.585144: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:982] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2020-08-27 20:21:04.586596: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:982] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2020-08-27 20:21:04.588028: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:982] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2020-08-27 20:21:04.589484: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:982] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2020-08-27 20:21:04.590972: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:982] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2020-08-27 20:21:04.592478: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:982] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2020-08-27 20:21:04.593962: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:982] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2020-08-27 20:21:04.595447: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:982] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2020-08-27 20:21:04.596795: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1858] Adding visible gpu devices: 0, 1, 2, 3, 4, 5, 6, 7
>>> from jax.lib import xla_bridge
>>> print(xla_bridge.get_backend().platform)
/home/ahmed/anaconda3/envs/trax/lib/python3.6/site-packages/jax/lib/xla_bridge.py:130: UserWarning: No GPU/TPU found, falling back to CPU.
  warnings.warn('No GPU/TPU found, falling back to CPU.')
cpu
>>>

agemagician commented 4 years ago

Any news about this issue?

skye commented 4 years ago

If you just import jax and not TF, does jax see the GPUs?

agemagician commented 4 years ago

No, it doesn't see the GPU, but TensorFlow sees it.
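A quick way to check this in a fresh interpreter, without importing TensorFlow, is to ask JAX for its device list. The sketch below relies on `jax.devices()`; the helper name `jax_platforms` is hypothetical, and it returns `None` when `jax` is not importable.

```python
def jax_platforms():
    """Return the sorted platform names JAX can see, or None if jax is missing."""
    try:
        import jax
    except ImportError:
        return None
    # Each device object reports the platform it belongs to (for example "cpu").
    return sorted({device.platform for device in jax.devices()})

print(jax_platforms())
```

If the result lists only `cpu`, one possible cause is an installed jaxlib built without accelerator support.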

agemagician commented 4 years ago

My main issue here is the JAX compilation speed, which is very slow. I killed the process after leaving it running for 1 hour.
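Rather than killing a stuck run by hand, the script can be launched under a wall-clock budget. This is only a sketch using the standard library's `subprocess.run` with its `timeout` argument; `run_with_timeout` is a hypothetical helper, and the sleeping child process stands in for the real training script.

```python
import subprocess
import sys

def run_with_timeout(cmd, timeout_s):
    """Run cmd; return its exit code, or None if it exceeded timeout_s seconds."""
    try:
        completed = subprocess.run(cmd, timeout=timeout_s)
    except subprocess.TimeoutExpired:
        # subprocess.run kills the child process before re-raising.
        return None
    return completed.returncode

# Stand-in for the training script: a child that sleeps longer than the budget.
result = run_with_timeout(
    [sys.executable, "-c", "import time; time.sleep(5)"], timeout_s=1
)
print("timed out" if result is None else "exit code %d" % result)
```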

skye commented 4 years ago

Can you tell which line of code is causing the long compile?

For the GPU issue, how are you installing jax + jaxlib?
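One way to narrow down which call is slow is to time each stage of the script separately. The sketch below is a generic standard-library timer, not part of Trax or JAX; `timed` and the stage labels are hypothetical, and the two stand-in stages would be replaced by the real calls (for example, building the trainer and running the first training step).

```python
import time
from contextlib import contextmanager

@contextmanager
def timed(label, results):
    """Record the wall-clock duration of the enclosed block under `label`."""
    start = time.perf_counter()
    try:
        yield
    finally:
        results[label] = time.perf_counter() - start

timings = {}

# Stand-in stages; swap in the real calls being investigated.
with timed("build", timings):
    sum(range(1000))
with timed("first_step", timings):
    time.sleep(0.05)

slowest = max(timings, key=timings.get)
print(slowest, timings)
```

Because the duration is recorded in a `finally` clause, a stage interrupted with Ctrl-C still leaves its elapsed time in `timings`.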

agemagician commented 4 years ago

1) Can you tell which line of code is causing the long compile? The trainer:

trainer = trax.supervised.Trainer(
    model=trax.models.Reformer,
    loss_fn=trax.layers.CrossEntropyLoss(),
    optimizer=trax.optimizers.Adam,
    lr_schedule=trax.lr.multifactor(),
    inputs=trax.data.inputs.Inputs(my_inputs),
    output_dir=output_dir)

I believe this triggers JAX compilation.

2) For the GPU issue, I figured out the problem. I was using "pip install --upgrade jax jaxlib", which doesn't support CUDA. I followed the README, and JAX can now see the GPUs. I think the default JAX installation should support both TPUs and GPUs; that would be easier for users.

The very long compilation does not occur on GPU with the same trainer code shown above.

agemagician commented 4 years ago

I have tested the above script using GPUs and it is working without any issue. The long compilation process only occurs with TPUs.

skye commented 4 years ago

Glad to hear you resolved the jaxlib issue! Unfortunately we cannot support GPUs by default since we need different jaxlibs for different CUDA versions. https://github.com/google/jax/pull/4065 should make this a little simpler at least.

For the long compilation on TPU, like @jekbradbury says, there might be an issue with the underlying TPU driver timing out. We're aware of this issue, but I don't know of any immediate workaround. Your best bet is probably to use GPU for now.

agemagician commented 4 years ago

Thanks a lot @skye for your explanation and support. I will use the GPUs for now and I hope the TPU issue will be solved in the near future.

jayendra13 commented 4 years ago

I am also getting the same error when trying to train the standard TransformerLM with a slight modification to the parameters (vocab_size=50000, max_len=1024). The same code works on the Colab + TPU setup, but when I try it on the GCE VM + Cloud TPU setup, I get the same error.

ProBuro commented 2 years ago

It is now 12.01.22 and the problem still exists.

sarthak-314 commented 2 years ago

I am also getting the same error when trying to train the standard TransformerLM with a slight modification to the parameters (vocab_size=50000, max_len=1024). The same code works on the Colab + TPU setup, but when I try it on the GCE VM + Cloud TPU setup, I get the same error.

I have the same issue. I can run a BigBird model with a max_length of 1024, but compilation times out with a max_length of 2048.

hawkinsp commented 1 year ago

We no longer support remote Cloud TPUs, which are the kind of TPU referenced in this issue. We only support Cloud TPU VMs, which have a very different runtime stack. You can get TPU VMs either from GCP or via Kaggle notebooks at the time of writing.