google-research / t5x

Apache License 2.0
2.65k stars, 301 forks

bad_alloc error before training starts, seemingly caused by TPU backend not found on TPU v2 and v3 VM #1004

Open RobertLiJN opened 1 year ago

RobertLiJN commented 1 year ago

Hi,

I have been trying to run the WMT demo on TPUv2 or TPUv3 VMs, but I keep hitting a bad_alloc error before training even starts. The output also says that a TPU backend could not be found, although I have verified that JAX can see the 8 TPUs.

Specifically, I first acquire a TPU VM with the command gcloud compute tpus tpu-vm create t5_test_3 --zone=europe-west4-a --accelerator-type=v3-8 --version=tpu-vm-tf-2.11.0.

Then I log in to the VM using gcloud alpha compute tpus tpu-vm ssh t5_test_3 --zone=europe-west4-a.

Next, I install T5X and its dependencies with

git clone --branch=main https://github.com/google-research/t5x
cd t5x

python3 -m pip install -e '.[tpu]' -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
cd ..

I set up the directories with

mkdir model
MODEL_DIR="model"
T5X_DIR="t5x"
TFDS_DATA_DIR="gs://<somebucketname>/wmt_t2t_translate/de-en/1.0.0"

Now I run the pretraining command

python3 ${T5X_DIR}/t5x/train.py \
  --gin_file="t5x/examples/t5/t5_1_1/examples/base_wmt_from_scratch.gin" \
  --gin.MODEL_DIR=\"${MODEL_DIR}\" \
  --tfds_data_dir=${TFDS_DATA_DIR}
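The escaped quotes in --gin.MODEL_DIR=\"${MODEL_DIR}\" matter because the value is handed to gin as a literal: bash passes --gin.MODEL_DIR="model", and without the quotes gin would see a bare name rather than a string. T5X rewrites this shorthand into a gin binding, which is what the first line of the output reports. The function below is a hypothetical sketch of that rewrite (the name rewrite_gin_arg is illustrative, not T5X's actual gin_utils code); it only reproduces the "Rewritten gin arg" line.

```python
def rewrite_gin_arg(arg: str) -> str:
    """Hypothetical sketch: expand '--gin.NAME=VALUE' into a --gin_bindings flag.

    Not T5X's actual implementation; it only mirrors the
    'Rewritten gin arg' line in the log.
    """
    prefix = "--gin."
    if not arg.startswith(prefix):
        return arg  # ordinary flag such as --tfds_data_dir: leave untouched
    name, sep, value = arg[len(prefix):].partition("=")
    if not sep:
        raise ValueError(f"expected --gin.NAME=VALUE, got {arg!r}")
    # VALUE is kept verbatim, so its double quotes reach gin and make it a string literal.
    return f"--gin_bindings={name} = {value}"


if __name__ == "__main__":
    # What bash passes for --gin.MODEL_DIR=\"${MODEL_DIR}\" when MODEL_DIR=model:
    print(rewrite_gin_arg('--gin.MODEL_DIR="model"'))  # --gin_bindings=MODEL_DIR = "model"
```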

The output is as follows

Rewritten gin arg: --gin_bindings=MODEL_DIR = "model"
I0106 20:23:28.418207 140414174514240 resource_reader.py:50] system_path_file_exists:t5x/examples/t5/t5_1_1/examples/base_wmt_from_scratch.gin
I0106 20:23:28.419633 140414174514240 resource_reader.py:37] gin-config opened resource file:/home/robertli/t5x/t5x/examples/t5/t5_1_1/examples/base_wmt_from_scratch.gin
I0106 20:23:28.436774 140414174514240 resource_reader.py:50] system_path_file_exists:t5x/examples/t5/t5_1_1/base.gin
I0106 20:23:28.436974 140414174514240 resource_reader.py:37] gin-config opened resource file:/home/robertli/t5x/t5x/examples/t5/t5_1_1/base.gin
I0106 20:23:28.461038 140414174514240 resource_reader.py:50] system_path_file_exists:t5x/configs/runs/pretrain.gin
I0106 20:23:28.461671 140414174514240 resource_reader.py:37] gin-config opened resource file:/home/robertli/t5x/t5x/configs/runs/pretrain.gin
I0106 20:23:28.492016 140414174514240 gin_utils.py:86] Gin Configuration:
I0106 20:23:28.501875 140414174514240 gin_utils.py:88] from __gin__ import dynamic_registration
I0106 20:23:28.501985 140414174514240 gin_utils.py:88] import __main__ as train_script
I0106 20:23:28.502058 140414174514240 gin_utils.py:88] import seqio
I0106 20:23:28.502127 140414174514240 gin_utils.py:88] from t5.data import mixtures
I0106 20:23:28.502193 140414174514240 gin_utils.py:88] from t5x import adafactor
I0106 20:23:28.502257 140414174514240 gin_utils.py:88] from t5x.examples.t5 import network
I0106 20:23:28.502321 140414174514240 gin_utils.py:88] from t5x import gin_utils
I0106 20:23:28.502383 140414174514240 gin_utils.py:88] from t5x import models
I0106 20:23:28.502447 140414174514240 gin_utils.py:88] from t5x import partitioning
I0106 20:23:28.502509 140414174514240 gin_utils.py:88] from t5x import trainer
I0106 20:23:28.502572 140414174514240 gin_utils.py:88] from t5x import utils
I0106 20:23:28.502641 140414174514240 gin_utils.py:88] 
I0106 20:23:28.502705 140414174514240 gin_utils.py:88] # Macros:
I0106 20:23:28.502770 140414174514240 gin_utils.py:88] # ==============================================================================
I0106 20:23:28.502834 140414174514240 gin_utils.py:88] BATCH_SIZE = 128
I0106 20:23:28.502897 140414174514240 gin_utils.py:88] DROPOUT_RATE = 0.0
I0106 20:23:28.502961 140414174514240 gin_utils.py:88] LABEL_SMOOTHING = 0.0
I0106 20:23:28.503024 140414174514240 gin_utils.py:88] LOSS_NORMALIZING_FACTOR = None
I0106 20:23:28.503087 140414174514240 gin_utils.py:88] MIXTURE_OR_TASK_MODULE = None
I0106 20:23:28.503150 140414174514240 gin_utils.py:88] MIXTURE_OR_TASK_NAME = 'wmt_t2t_ende_v003'
I0106 20:23:28.503214 140414174514240 gin_utils.py:88] MODEL = @models.EncoderDecoderModel()
I0106 20:23:28.503276 140414174514240 gin_utils.py:88] MODEL_DIR = 'model'
I0106 20:23:28.503339 140414174514240 gin_utils.py:88] OPTIMIZER = @adafactor.Adafactor()
I0106 20:23:28.503402 140414174514240 gin_utils.py:88] RANDOM_SEED = None
I0106 20:23:28.503465 140414174514240 gin_utils.py:88] SHUFFLE_TRAIN_EXAMPLES = True
I0106 20:23:28.503527 140414174514240 gin_utils.py:88] TASK_FEATURE_LENGTHS = {'inputs': 256, 'targets': 256}
I0106 20:23:28.503595 140414174514240 gin_utils.py:88] TRAIN_STEPS = 50000
I0106 20:23:28.503658 140414174514240 gin_utils.py:88] USE_CACHED_TASKS = True
I0106 20:23:28.503721 140414174514240 gin_utils.py:88] USE_HARDWARE_RNG = False
I0106 20:23:28.503784 140414174514240 gin_utils.py:88] VOCABULARY = @seqio.SentencePieceVocabulary()
I0106 20:23:28.503847 140414174514240 gin_utils.py:88] Z_LOSS = 0.0001
I0106 20:23:28.503910 140414174514240 gin_utils.py:88] 
I0106 20:23:28.503972 140414174514240 gin_utils.py:88] # Parameters for adafactor.Adafactor:
I0106 20:23:28.504035 140414174514240 gin_utils.py:88] # ==============================================================================
I0106 20:23:28.504098 140414174514240 gin_utils.py:88] adafactor.Adafactor.decay_rate = 0.8
I0106 20:23:28.504180 140414174514240 gin_utils.py:88] adafactor.Adafactor.logical_factor_rules = \
I0106 20:23:28.504245 140414174514240 gin_utils.py:88]     @adafactor.standard_logical_factor_rules()
I0106 20:23:28.504308 140414174514240 gin_utils.py:88] adafactor.Adafactor.step_offset = 0
I0106 20:23:28.504371 140414174514240 gin_utils.py:88] 
I0106 20:23:28.504434 140414174514240 gin_utils.py:88] # Parameters for utils.CheckpointConfig:
I0106 20:23:28.504497 140414174514240 gin_utils.py:88] # ==============================================================================
I0106 20:23:28.504560 140414174514240 gin_utils.py:88] utils.CheckpointConfig.restore = @utils.RestoreCheckpointConfig()
I0106 20:23:28.504628 140414174514240 gin_utils.py:88] utils.CheckpointConfig.save = @utils.SaveCheckpointConfig()
I0106 20:23:28.504691 140414174514240 gin_utils.py:88] 
I0106 20:23:28.504755 140414174514240 gin_utils.py:88] # Parameters for utils.create_learning_rate_scheduler:
I0106 20:23:28.504817 140414174514240 gin_utils.py:88] # ==============================================================================
I0106 20:23:28.504881 140414174514240 gin_utils.py:88] utils.create_learning_rate_scheduler.base_learning_rate = 1.0
I0106 20:23:28.504944 140414174514240 gin_utils.py:88] utils.create_learning_rate_scheduler.factors = 'constant * rsqrt_decay'
I0106 20:23:28.505007 140414174514240 gin_utils.py:88] utils.create_learning_rate_scheduler.warmup_steps = 10000
I0106 20:23:28.505070 140414174514240 gin_utils.py:88] 
I0106 20:23:28.505133 140414174514240 gin_utils.py:88] # Parameters for infer_eval/utils.DatasetConfig:
I0106 20:23:28.505196 140414174514240 gin_utils.py:88] # ==============================================================================
I0106 20:23:28.505259 140414174514240 gin_utils.py:88] infer_eval/utils.DatasetConfig.batch_size = 128
I0106 20:23:28.505323 140414174514240 gin_utils.py:88] infer_eval/utils.DatasetConfig.mixture_or_task_name = %MIXTURE_OR_TASK_NAME
I0106 20:23:28.505386 140414174514240 gin_utils.py:88] infer_eval/utils.DatasetConfig.seed = 0
I0106 20:23:28.505449 140414174514240 gin_utils.py:88] infer_eval/utils.DatasetConfig.shuffle = False
I0106 20:23:28.505512 140414174514240 gin_utils.py:88] infer_eval/utils.DatasetConfig.split = 'validation'
I0106 20:23:28.505580 140414174514240 gin_utils.py:88] infer_eval/utils.DatasetConfig.task_feature_lengths = None
I0106 20:23:28.505647 140414174514240 gin_utils.py:88] infer_eval/utils.DatasetConfig.use_cached = False
I0106 20:23:28.505710 140414174514240 gin_utils.py:88] 
I0106 20:23:28.505773 140414174514240 gin_utils.py:88] # Parameters for train/utils.DatasetConfig:
I0106 20:23:28.505836 140414174514240 gin_utils.py:88] # ==============================================================================
I0106 20:23:28.505899 140414174514240 gin_utils.py:88] train/utils.DatasetConfig.batch_size = 128
I0106 20:23:28.505962 140414174514240 gin_utils.py:88] train/utils.DatasetConfig.mixture_or_task_name = %MIXTURE_OR_TASK_NAME
I0106 20:23:28.506025 140414174514240 gin_utils.py:88] train/utils.DatasetConfig.module = %MIXTURE_OR_TASK_MODULE
I0106 20:23:28.506089 140414174514240 gin_utils.py:88] train/utils.DatasetConfig.pack = True
I0106 20:23:28.506152 140414174514240 gin_utils.py:88] train/utils.DatasetConfig.seed = 0
I0106 20:23:28.506215 140414174514240 gin_utils.py:88] train/utils.DatasetConfig.shuffle = %SHUFFLE_TRAIN_EXAMPLES
I0106 20:23:28.506278 140414174514240 gin_utils.py:88] train/utils.DatasetConfig.split = 'train'
I0106 20:23:28.506341 140414174514240 gin_utils.py:88] train/utils.DatasetConfig.task_feature_lengths = %TASK_FEATURE_LENGTHS
I0106 20:23:28.506404 140414174514240 gin_utils.py:88] train/utils.DatasetConfig.use_cached = False
I0106 20:23:28.506467 140414174514240 gin_utils.py:88] 
I0106 20:23:28.506529 140414174514240 gin_utils.py:88] # Parameters for train_eval/utils.DatasetConfig:
I0106 20:23:28.506600 140414174514240 gin_utils.py:88] # ==============================================================================
I0106 20:23:28.506664 140414174514240 gin_utils.py:88] train_eval/utils.DatasetConfig.batch_size = 128
I0106 20:23:28.506727 140414174514240 gin_utils.py:88] train_eval/utils.DatasetConfig.mixture_or_task_name = %MIXTURE_OR_TASK_NAME
I0106 20:23:28.506790 140414174514240 gin_utils.py:88] train_eval/utils.DatasetConfig.module = %MIXTURE_OR_TASK_MODULE
I0106 20:23:28.506853 140414174514240 gin_utils.py:88] train_eval/utils.DatasetConfig.pack = True
I0106 20:23:28.506916 140414174514240 gin_utils.py:88] train_eval/utils.DatasetConfig.seed = 0
I0106 20:23:28.506979 140414174514240 gin_utils.py:88] train_eval/utils.DatasetConfig.shuffle = False
I0106 20:23:28.507042 140414174514240 gin_utils.py:88] train_eval/utils.DatasetConfig.split = 'validation'
I0106 20:23:28.507105 140414174514240 gin_utils.py:88] train_eval/utils.DatasetConfig.task_feature_lengths = %TASK_FEATURE_LENGTHS
I0106 20:23:28.507168 140414174514240 gin_utils.py:88] train_eval/utils.DatasetConfig.use_cached = False
I0106 20:23:28.507231 140414174514240 gin_utils.py:88] 
I0106 20:23:28.507294 140414174514240 gin_utils.py:88] # Parameters for models.EncoderDecoderModel:
I0106 20:23:28.507357 140414174514240 gin_utils.py:88] # ==============================================================================
I0106 20:23:28.507420 140414174514240 gin_utils.py:88] models.EncoderDecoderModel.input_vocabulary = %VOCABULARY
I0106 20:23:28.507483 140414174514240 gin_utils.py:88] models.EncoderDecoderModel.label_smoothing = %LABEL_SMOOTHING
I0106 20:23:28.507546 140414174514240 gin_utils.py:88] models.EncoderDecoderModel.loss_normalizing_factor = %LOSS_NORMALIZING_FACTOR
I0106 20:23:28.507614 140414174514240 gin_utils.py:88] models.EncoderDecoderModel.module = @network.Transformer()
I0106 20:23:28.507677 140414174514240 gin_utils.py:88] models.EncoderDecoderModel.optimizer_def = %OPTIMIZER
I0106 20:23:28.507740 140414174514240 gin_utils.py:88] models.EncoderDecoderModel.output_vocabulary = %VOCABULARY
I0106 20:23:28.507803 140414174514240 gin_utils.py:88] models.EncoderDecoderModel.z_loss = %Z_LOSS
I0106 20:23:28.507866 140414174514240 gin_utils.py:88] 
I0106 20:23:28.507928 140414174514240 gin_utils.py:88] # Parameters for models.EncoderDecoderModel.predict_batch_with_aux:
I0106 20:23:28.507992 140414174514240 gin_utils.py:88] # ==============================================================================
I0106 20:23:28.508055 140414174514240 gin_utils.py:88] models.EncoderDecoderModel.predict_batch_with_aux.num_decodes = 4
I0106 20:23:28.508131 140414174514240 gin_utils.py:88] 
I0106 20:23:28.508196 140414174514240 gin_utils.py:88] # Parameters for seqio.Evaluator:
I0106 20:23:28.508260 140414174514240 gin_utils.py:88] # ==============================================================================
I0106 20:23:28.508323 140414174514240 gin_utils.py:88] seqio.Evaluator.logger_cls = \
I0106 20:23:28.508387 140414174514240 gin_utils.py:88]     [@seqio.PyLoggingLogger, @seqio.TensorBoardLogger, @seqio.JSONLogger]
I0106 20:23:28.508450 140414174514240 gin_utils.py:88] seqio.Evaluator.num_examples = None
I0106 20:23:28.508513 140414174514240 gin_utils.py:88] seqio.Evaluator.use_memory_cache = True
I0106 20:23:28.508577 140414174514240 gin_utils.py:88] 
I0106 20:23:28.508644 140414174514240 gin_utils.py:88] # Parameters for partitioning.PjitPartitioner:
I0106 20:23:28.508707 140414174514240 gin_utils.py:88] # ==============================================================================
I0106 20:23:28.508771 140414174514240 gin_utils.py:88] partitioning.PjitPartitioner.logical_axis_rules = \
I0106 20:23:28.508834 140414174514240 gin_utils.py:88]     @partitioning.standard_logical_axis_rules()
I0106 20:23:28.508897 140414174514240 gin_utils.py:88] partitioning.PjitPartitioner.model_parallel_submesh = None
I0106 20:23:28.508961 140414174514240 gin_utils.py:88] partitioning.PjitPartitioner.num_partitions = 2
I0106 20:23:28.509024 140414174514240 gin_utils.py:88] 
I0106 20:23:28.509087 140414174514240 gin_utils.py:88] # Parameters for utils.RestoreCheckpointConfig:
I0106 20:23:28.509150 140414174514240 gin_utils.py:88] # ==============================================================================
I0106 20:23:28.509214 140414174514240 gin_utils.py:88] utils.RestoreCheckpointConfig.path = []
I0106 20:23:28.509277 140414174514240 gin_utils.py:88] 
I0106 20:23:28.509339 140414174514240 gin_utils.py:88] # Parameters for utils.SaveCheckpointConfig:
I0106 20:23:28.509402 140414174514240 gin_utils.py:88] # ==============================================================================
I0106 20:23:28.509466 140414174514240 gin_utils.py:88] utils.SaveCheckpointConfig.dtype = 'float32'
I0106 20:23:28.509529 140414174514240 gin_utils.py:88] utils.SaveCheckpointConfig.keep = None
I0106 20:23:28.509597 140414174514240 gin_utils.py:88] utils.SaveCheckpointConfig.period = 5000
I0106 20:23:28.509660 140414174514240 gin_utils.py:88] utils.SaveCheckpointConfig.save_dataset = False
I0106 20:23:28.509724 140414174514240 gin_utils.py:88] 
I0106 20:23:28.509786 140414174514240 gin_utils.py:88] # Parameters for seqio.SentencePieceVocabulary:
I0106 20:23:28.509849 140414174514240 gin_utils.py:88] # ==============================================================================
I0106 20:23:28.509912 140414174514240 gin_utils.py:88] seqio.SentencePieceVocabulary.sentencepiece_model_file = \
I0106 20:23:28.509975 140414174514240 gin_utils.py:88]     'gs://t5-data/vocabs/cc_all.32000.100extra/sentencepiece.model'
I0106 20:23:28.510038 140414174514240 gin_utils.py:88] 
I0106 20:23:28.510101 140414174514240 gin_utils.py:88] # Parameters for network.T5Config:
I0106 20:23:28.510164 140414174514240 gin_utils.py:88] # ==============================================================================
I0106 20:23:28.510227 140414174514240 gin_utils.py:88] network.T5Config.dropout_rate = %DROPOUT_RATE
I0106 20:23:28.510290 140414174514240 gin_utils.py:88] network.T5Config.dtype = 'bfloat16'
I0106 20:23:28.510353 140414174514240 gin_utils.py:88] network.T5Config.emb_dim = 768
I0106 20:23:28.510416 140414174514240 gin_utils.py:88] network.T5Config.head_dim = 64
I0106 20:23:28.510479 140414174514240 gin_utils.py:88] network.T5Config.logits_via_embedding = False
I0106 20:23:28.510542 140414174514240 gin_utils.py:88] network.T5Config.mlp_activations = ('gelu', 'linear')
I0106 20:23:28.510609 140414174514240 gin_utils.py:88] network.T5Config.mlp_dim = 2048
I0106 20:23:28.510672 140414174514240 gin_utils.py:88] network.T5Config.num_decoder_layers = 12
I0106 20:23:28.510735 140414174514240 gin_utils.py:88] network.T5Config.num_encoder_layers = 12
I0106 20:23:28.510798 140414174514240 gin_utils.py:88] network.T5Config.num_heads = 12
I0106 20:23:28.510861 140414174514240 gin_utils.py:88] network.T5Config.vocab_size = 32128
I0106 20:23:28.510924 140414174514240 gin_utils.py:88] 
I0106 20:23:28.510987 140414174514240 gin_utils.py:88] # Parameters for train_script.train:
I0106 20:23:28.511049 140414174514240 gin_utils.py:88] # ==============================================================================
I0106 20:23:28.511114 140414174514240 gin_utils.py:88] train_script.train.checkpoint_cfg = @utils.CheckpointConfig()
I0106 20:23:28.511177 140414174514240 gin_utils.py:88] train_script.train.eval_period = 500
I0106 20:23:28.511240 140414174514240 gin_utils.py:88] train_script.train.eval_steps = 20
I0106 20:23:28.511302 140414174514240 gin_utils.py:88] train_script.train.infer_eval_dataset_cfg = @infer_eval/utils.DatasetConfig()
I0106 20:23:28.511365 140414174514240 gin_utils.py:88] train_script.train.inference_evaluator_cls = @seqio.Evaluator
I0106 20:23:28.511429 140414174514240 gin_utils.py:88] train_script.train.model = %MODEL
I0106 20:23:28.511491 140414174514240 gin_utils.py:88] train_script.train.model_dir = %MODEL_DIR
I0106 20:23:28.511554 140414174514240 gin_utils.py:88] train_script.train.partitioner = @partitioning.PjitPartitioner()
I0106 20:23:28.511625 140414174514240 gin_utils.py:88] train_script.train.random_seed = 0
I0106 20:23:28.511689 140414174514240 gin_utils.py:88] train_script.train.summarize_config_fn = @gin_utils.summarize_gin_config
I0106 20:23:28.511752 140414174514240 gin_utils.py:88] train_script.train.total_steps = %TRAIN_STEPS
I0106 20:23:28.511815 140414174514240 gin_utils.py:88] train_script.train.train_dataset_cfg = @train/utils.DatasetConfig()
I0106 20:23:28.511878 140414174514240 gin_utils.py:88] train_script.train.train_eval_dataset_cfg = @train_eval/utils.DatasetConfig()
I0106 20:23:28.511941 140414174514240 gin_utils.py:88] train_script.train.trainer_cls = @trainer.Trainer
I0106 20:23:28.512004 140414174514240 gin_utils.py:88] train_script.train.use_hardware_rng = True
I0106 20:23:28.512067 140414174514240 gin_utils.py:88] 
I0106 20:23:28.512146 140414174514240 gin_utils.py:88] # Parameters for trainer.Trainer:
I0106 20:23:28.512211 140414174514240 gin_utils.py:88] # ==============================================================================
I0106 20:23:28.512274 140414174514240 gin_utils.py:88] trainer.Trainer.learning_rate_fn = @utils.create_learning_rate_scheduler()
I0106 20:23:28.512337 140414174514240 gin_utils.py:88] trainer.Trainer.num_microbatches = None
I0106 20:23:28.512400 140414174514240 gin_utils.py:88] 
I0106 20:23:28.512462 140414174514240 gin_utils.py:88] # Parameters for network.Transformer:
I0106 20:23:28.512525 140414174514240 gin_utils.py:88] # ==============================================================================
I0106 20:23:28.512592 140414174514240 gin_utils.py:88] network.Transformer.config = @network.T5Config()
I0106 20:23:28.514295 140414174514240 partitioning.py:498] `activation_partitioning_dims` = 1, `parameter_partitioning_dims` = 1
I0106 20:23:28.534243 140414174514240 xla_bridge.py:355] Unable to initialize backend 'tpu_driver': NOT_FOUND: Unable to find driver in registry given worker: 
I0106 20:23:28.534394 140414174514240 xla_bridge.py:355] Unable to initialize backend 'cuda': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
I0106 20:23:28.534486 140414174514240 xla_bridge.py:355] Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
D0106 20:23:28.754797177   75726 config.cc:113]              gRPC EXPERIMENT tcp_frame_size_tuning               OFF (default:OFF)
D0106 20:23:28.754821824   75726 config.cc:113]              gRPC EXPERIMENT tcp_read_chunks                     OFF (default:OFF)
D0106 20:23:28.754830329   75726 config.cc:113]              gRPC EXPERIMENT tcp_rcv_lowat                       OFF (default:OFF)
D0106 20:23:28.754837909   75726 config.cc:113]              gRPC EXPERIMENT peer_state_based_framing            OFF (default:OFF)
D0106 20:23:28.754845063   75726 config.cc:113]              gRPC EXPERIMENT flow_control_fixes                  OFF (default:OFF)
D0106 20:23:28.754852281   75726 config.cc:113]              gRPC EXPERIMENT memory_pressure_controller          OFF (default:OFF)
D0106 20:23:28.754859526   75726 config.cc:113]              gRPC EXPERIMENT periodic_resource_quota_reclamation ON  (default:ON)
D0106 20:23:28.754866837   75726 config.cc:113]              gRPC EXPERIMENT unconstrained_max_quota_buffer_size OFF (default:OFF)
D0106 20:23:28.754873901   75726 config.cc:113]              gRPC EXPERIMENT new_hpack_huffman_decoder           OFF (default:OFF)
D0106 20:23:28.754880949   75726 config.cc:113]              gRPC EXPERIMENT event_engine_client                 OFF (default:OFF)
D0106 20:23:28.754888043   75726 config.cc:113]              gRPC EXPERIMENT monitoring_experiment               ON  (default:ON)
D0106 20:23:28.754895088   75726 config.cc:113]              gRPC EXPERIMENT promise_based_client_call           OFF (default:OFF)
I0106 20:23:28.755228101   75726 ev_epoll1_linux.cc:121]     grpc epoll fd: 7
D0106 20:23:28.755266738   75726 ev_posix.cc:141]            Using polling engine: epoll1
D0106 20:23:28.755309299   75726 dns_resolver_ares.cc:824]   Using ares dns resolver
D0106 20:23:28.755678463   75726 lb_policy_registry.cc:45]   registering LB policy factory for "priority_experimental"
D0106 20:23:28.755716641   75726 lb_policy_registry.cc:45]   registering LB policy factory for "outlier_detection_experimental"
D0106 20:23:28.755725218   75726 lb_policy_registry.cc:45]   registering LB policy factory for "weighted_target_experimental"
D0106 20:23:28.755733248   75726 lb_policy_registry.cc:45]   registering LB policy factory for "pick_first"
D0106 20:23:28.755741496   75726 lb_policy_registry.cc:45]   registering LB policy factory for "round_robin"
D0106 20:23:28.755754622   75726 lb_policy_registry.cc:45]   registering LB policy factory for "ring_hash_experimental"
D0106 20:23:28.755781182   75726 lb_policy_registry.cc:45]   registering LB policy factory for "grpclb"
D0106 20:23:28.755870535   75726 lb_policy_registry.cc:45]   registering LB policy factory for "rls_experimental"
D0106 20:23:28.755914982   75726 lb_policy_registry.cc:45]   registering LB policy factory for "xds_cluster_manager_experimental"
D0106 20:23:28.755946604   75726 lb_policy_registry.cc:45]   registering LB policy factory for "xds_cluster_impl_experimental"
D0106 20:23:28.755955454   75726 lb_policy_registry.cc:45]   registering LB policy factory for "cds_experimental"
D0106 20:23:28.755963256   75726 lb_policy_registry.cc:45]   registering LB policy factory for "xds_cluster_resolver_experimental"
D0106 20:23:28.755972515   75726 certificate_provider_registry.cc:35] registering certificate provider factory for "file_watcher"
I0106 20:23:28.776241706   75726 socket_utils_common_posix.cc:336] TCP_USER_TIMEOUT is available. TCP_USER_TIMEOUT will be used thereafter
I0106 20:23:31.340027 140414174514240 xla_bridge.py:355] Unable to initialize backend 'plugin': xla_extension has no attributes named get_plugin_device_client. Compile TensorFlow with //tensorflow/compiler/xla/python:enable_plugin_device set to true (defaults to false) to enable this.
I0106 20:23:31.340459 140414174514240 train.py:195] Process ID: 0
I0106 20:23:31.340698 140414174514240 train.py:199] GlobalDeviceArray enabled.
I0106 20:23:31.340821 140414174514240 train.py:239] Using fast RngBitGenerator PRNG for initialization and dropout.
W0106 20:23:31.340916 140414174514240 train.py:246] When using hardware RNG with a fixed seed, repeatability is only guaranteed for fixed hardware and partitioning schemes and for a fixed version of this code and its dependencies.
tcmalloc: large alloc 140413996990464 bytes == (nil) @  0x7fb4b9605680 0x7fb4b9625ff4 0x7fb4b6a3228d 0x7fb4b6a33bbf 0x7fb4b6a2acaa 0x7fb4b4b1b66e 0x7fb4b25b57b8 0x7fb4b25b12b2 0x7fb4b25b7a4b 0x7fb4b259de85 0x7fb4b242ed35 0x7fb4b240697c 0x5f6929 0x5f74f6 0x50c383 0x570b26 0x5f6cd6 0x56bacd 0x569dba 0x5f6eb3 0x570b26 0x501923 0x524f74 0x5f15c4 0x5f745f 0x570d55 0x569dba 0x5f6eb3 0x56bacd 0x569dba 0x5f6eb3
Traceback (most recent call last):
  File "t5x/t5x/train.py", line 793, in <module>
    gin_utils.run(main)
  File "/home/robertli/t5x/t5x/gin_utils.py", line 130, in run
    app.run(
  File "/usr/local/lib/python3.8/dist-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/usr/local/lib/python3.8/dist-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "t5x/t5x/train.py", line 759, in main
    _main(argv)
  File "t5x/t5x/train.py", line 789, in _main
    train_using_gin()
  File "/home/robertli/.local/lib/python3.8/site-packages/gin/config.py", line 1605, in gin_wrapper
    utils.augment_exception_message_and_reraise(e, err_str)
  File "/home/robertli/.local/lib/python3.8/site-packages/gin/utils.py", line 41, in augment_exception_message_and_reraise
    raise proxy.with_traceback(exception.__traceback__) from None
  File "/home/robertli/.local/lib/python3.8/site-packages/gin/config.py", line 1582, in gin_wrapper
    return fn(*new_args, **new_kwargs)
  File "t5x/t5x/train.py", line 251, in train
    rng = random.PRNGKey(random_seed)
  File "/home/robertli/.local/lib/python3.8/site-packages/jax/_src/random.py", line 133, in PRNGKey
    key = prng.seed_with_impl(impl, seed)
  File "/home/robertli/.local/lib/python3.8/site-packages/jax/_src/prng.py", line 267, in seed_with_impl
    return random_seed(seed, impl=impl)
  File "/home/robertli/.local/lib/python3.8/site-packages/jax/_src/prng.py", line 572, in random_seed
    seeds_arr = jnp.asarray(np.int64(seeds))
  File "/home/robertli/.local/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 2036, in asarray
    return array(a, dtype=dtype, copy=False, order=order)  # type: ignore
  File "/home/robertli/.local/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 2017, in array
    out_array: Array = lax_internal._convert_element_type(out, dtype, weak_type=weak_type)
  File "/home/robertli/.local/lib/python3.8/site-packages/jax/_src/lax/lax.py", line 588, in _convert_element_type
    return convert_element_type_p.bind(operand, new_dtype=new_dtype,
  File "/home/robertli/.local/lib/python3.8/site-packages/jax/core.py", line 329, in bind
    return self.bind_with_trace(find_top_trace(args), args, params)
  File "/home/robertli/.local/lib/python3.8/site-packages/jax/core.py", line 332, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/home/robertli/.local/lib/python3.8/site-packages/jax/core.py", line 712, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/home/robertli/.local/lib/python3.8/site-packages/jax/_src/dispatch.py", line 120, in apply_primitive
    return compiled_fun(*args)
  File "/home/robertli/.local/lib/python3.8/site-packages/jax/_src/dispatch.py", line 205, in <lambda>
    return lambda *args, **kw: compiled(*args, **kw)[0]
  File "/home/robertli/.local/lib/python3.8/site-packages/jax/interpreters/pxla.py", line 3638, in _execute_trivial
    return out_handler(in_handler(outs))
  File "/home/robertli/.local/lib/python3.8/site-packages/jax/interpreters/pxla.py", line 1895, in __call__
    return self.handler(input_buffers)
  File "/home/robertli/.local/lib/python3.8/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/home/robertli/.local/lib/python3.8/site-packages/jax/interpreters/pxla.py", line 416, in shard_args
    return [_shard_arg(arg, devices, indices[i], mode) for i, arg in enumerate(args)]
  File "/home/robertli/.local/lib/python3.8/site-packages/jax/interpreters/pxla.py", line 416, in <listcomp>
    return [_shard_arg(arg, devices, indices[i], mode) for i, arg in enumerate(args)]
  File "/home/robertli/.local/lib/python3.8/site-packages/jax/interpreters/pxla.py", line 394, in _shard_arg
    return shard_arg_handlers[type(arg)](arg, devices, arg_indices, mode)
  File "/home/robertli/.local/lib/python3.8/site-packages/jax/interpreters/pxla.py", line 428, in _shard_array
    return device_put([x[i] for i in indices], devices)
  File "/home/robertli/.local/lib/python3.8/site-packages/jax/interpreters/pxla.py", line 3916, in device_put
    return list(it.chain.from_iterable(dispatch.device_put(val, device) for val, device in safe_zip(x, devices)))
  File "/home/robertli/.local/lib/python3.8/site-packages/jax/interpreters/pxla.py", line 3916, in <genexpr>
    return list(it.chain.from_iterable(dispatch.device_put(val, device) for val, device in safe_zip(x, devices)))
  File "/home/robertli/.local/lib/python3.8/site-packages/jax/_src/dispatch.py", line 1266, in device_put
    return device_put_handlers[type(x)](x, device)
  File "/home/robertli/.local/lib/python3.8/site-packages/jax/_src/dispatch.py", line 1277, in _device_put_array
    return (backend.buffer_from_pyval(x, device),)
MemoryError: std::bad_alloc
  In call to configurable 'train' (<function train at 0x7fb35e783d30>)
D0106 20:23:32.244766237   75726 init.cc:190]                grpc_shutdown starts clean-up now

Note the line at 20:23:28.534243 that says: Unable to initialize backend 'tpu_driver': NOT_FOUND: Unable to find driver in registry given worker:

I wonder whether this is the cause of the bad_alloc error, and whether there is a way to fix it. Thanks in advance!
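For context, the size in the "tcmalloc: large alloc ... bytes == (nil)" line just before the traceback is not a plausible buffer size. The quick calculation below converts it to TiB and to hex. One hedged reading: at roughly 128 TiB, and with a hex value sharing the 0x7fb4... prefix of the return addresses printed on the same line, it looks more like a corrupted size (possibly a pointer) than genuine memory pressure from the model, which fits a library mismatch better than a real out-of-memory condition.

```python
# Requested size copied from the "tcmalloc: large alloc ... bytes == (nil)" log line.
requested_bytes = 140413996990464

TIB = 1024 ** 4  # bytes per tebibyte

print(f"{requested_bytes / TIB:.2f} TiB")  # 127.71 TiB
print(hex(requested_bytes))                # 0x7fb4ae698000
```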

skye commented 1 year ago

I'm not sure, but I think the t5x requirements may be messed up and not installing libtpu, which is the low-level library that jax and other frameworks need to access the TPU. Can you try manually installing jax[tpu], which includes the proper libtpu version, and see if that fixes the issue?

pip install 'jax[tpu]' -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

I also recommend using --version=tpu-vm-base when creating a TPU VM for use with jax (or tpu-vm-v4-base if creating a TPU v4 VM). The TF images come with a preinstalled libtpu version for the specified TF version, whereas the base images do not. I think jax ended up using the incorrect preinstalled libtpu version in this case, which can lead to confusing errors like this one. (The "Unable to initialize backend 'tpu_driver'" message actually refers to the old TPU Node architecture, and doesn't mean jax isn't using the TPU on a TPU VM.)
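A quick way to tell whether a pip-installed libtpu is present (as opposed to only a copy preinstalled in the TF image) is to ask pip's package metadata. This is a minimal sketch using only the standard library; the distribution names in CANDIDATES are an assumption, since jax[tpu] has depended on "libtpu-nightly" in some releases and "libtpu" in others.

```python
from importlib import metadata

# Assumed distribution names: jax[tpu] has pulled in "libtpu-nightly" in some
# releases and "libtpu" in others.
CANDIDATES = ("libtpu-nightly", "libtpu")


def installed_libtpu() -> dict:
    """Map each candidate libtpu distribution that pip can see to its version."""
    found = {}
    for name in CANDIDATES:
        try:
            found[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            pass  # not installed under this name
    return found


if __name__ == "__main__":
    versions = installed_libtpu()
    if versions:
        print("pip-installed libtpu:", versions)
    else:
        print("no pip-installed libtpu found; jax may use a preinstalled copy instead")
```

An empty result after installing t5x[tpu] would be consistent with the install problem discussed in this thread.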

RobertLiJN commented 1 year ago

Hi Skye,

Manually installing jax[tpu] with this command fixes the issue, even on a TensorFlow VM! It looks like the T5X install command for jax[tpu] is problematic here for some reason. Thank you so much for your help!

skye commented 1 year ago

Awesome!

even on a TensorFlow VM

Yup, this makes sense, since jax always uses the pip-installed libtpu if one is available. I recommend the base image only because, in cases like this where the pip-installed libtpu is missing for some reason, jax falls back to CPU instead of crashing in a weird way, which makes the problem easier to debug.
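Because a silent CPU fallback is easy to miss, one option is to fail fast before training starts. The guard below is a hypothetical helper (require_tpu and expected_devices are illustrative names, not part of T5X or JAX); it takes the backend name and device count as plain values so it can be fed from jax.default_backend() and jax.device_count(). Note its limits: per this thread, on the TF image JAX still listed 8 TPUs, so this check only catches the base-image fallback case, not the original crash.

```python
def require_tpu(backend: str, device_count: int, expected_devices: int = 8) -> None:
    """Hypothetical pre-training guard: raise if jax did not come up on TPU.

    expected_devices=8 assumes a v2-8 or v3-8 slice, as in this issue.
    """
    if backend != "tpu":
        raise RuntimeError(
            f"jax default backend is {backend!r}, not 'tpu'; "
            "check that a pip-installed libtpu is present"
        )
    if device_count != expected_devices:
        raise RuntimeError(
            f"expected {expected_devices} TPU devices, found {device_count}"
        )


if __name__ == "__main__":
    # On the TPU VM this would be: require_tpu(jax.default_backend(), jax.device_count())
    require_tpu("tpu", 8)
    print("TPU backend looks OK")
```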

I'm gonna leave this issue open until we fix the underlying install issue, since other people could easily hit this. It looks like t5x[tpu] pulls in jax[tpu] here: https://github.com/google-research/t5x/blob/2a62e14fd2806a28c8b24c7674fdd5423aa95e3d/setup.py#L72

I don't understand why this is only pulling in jaxlib and not libtpu. Here's the jax[tpu] definition: https://github.com/google/jax/blob/fc04c71d9342186b1ec51fcdb0a13fe1c6fcd5e2/setup.py#L84-L87

I can't dig into this right now, but I wonder if we're hitting some strange pip edge case around custom package indices (the -f libtpu_releases.html page is how pip locates the libtpu package).
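When debugging which packages an extra actually pulled in, it can help to compare what the installed jax distribution declares for its "tpu" extra against what pip installed. The sketch below reads the declared requirements via importlib.metadata.requires and keeps those guarded by an extra == "tpu" marker. The marker matching is a deliberately simple string check (normalizing quotes and spaces); a more robust tool would parse markers with packaging.requirements.Requirement. The helper names are hypothetical.

```python
from importlib import metadata
from typing import Iterable, List, Optional


def filter_extra(requirements: Iterable[str], extra: str) -> List[str]:
    """Return the requirement specs guarded by `extra == "<extra>"`."""
    wanted = f'extra=="{extra}"'
    selected = []
    for req in requirements:
        spec, _, marker = req.partition(";")
        # Normalize quote style and spacing before a simple substring check.
        if wanted in marker.replace("'", '"').replace(" ", ""):
            selected.append(spec.strip())
    return selected


def installed_extra_requirements(dist: str, extra: str) -> Optional[List[str]]:
    """Requirements that an installed distribution declares for dist[extra]."""
    try:
        reqs = metadata.requires(dist) or []
    except metadata.PackageNotFoundError:
        return None  # distribution not installed
    return filter_extra(reqs, extra)


if __name__ == "__main__":
    # Compare this list with `pip list` to see whether libtpu was actually installed.
    print(installed_extra_requirements("jax", "tpu"))
```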

skye commented 1 year ago

Oh, I can't reopen it. @RobertLiJN, if you're able to reopen it, please do so.

RobertLiJN commented 1 year ago

Hi Skye, I have reopened it. Thanks again!