google / paxml

Pax is a Jax-based machine learning framework for training large-scale models. Pax allows for advanced and fully configurable experimentation and parallelization, and has demonstrated industry-leading model flop utilization rates.
Apache License 2.0

Pipeline Parallelism: F external/org_tensorflow/tensorflow/compiler/xla/array.h:446] Check failed: n < sizes_size Fatal Python error: Aborted #4

Closed: abhinavgoel95 closed this issue 1 year ago

abhinavgoel95 commented 2 years ago

Hello!

I am trying to train a 126-million-parameter GPT-3 model with pipeline parallelism on Paxml. I run into errors when NUM_MICROBATCHES > 1.

System:

8x NVIDIA A100-SXM 80 GB GPUs

Gin Configs:

from __gin__ import dynamic_registration

import __main__ as train_script
from paxml import gin_utils
from paxml.tasks.lm import model_params_with_gin
from paxml.tasks.lm.params import datasets_gin
from praxis import optimizers
from praxis import schedules
from praxis.layers import activations
from praxis.layers import repeats
from jax import numpy as jnp

MAX_SL=2048
SUMMARY_INTERVAL_STEPS=100
CHECKPOINT_EVERY_N_STEPS=1000
EVAL_INTERVAL_STEPS=100
MAX_STEPS=600000
NUM_STAGES = 4
ICI_MESH_SHAPE=[%NUM_STAGES, 1, 1, 2]
PERCORE_BATCH_SIZE = 2

MODEL = @model_params_with_gin.TransformerLmSpmdPipeline()
model_params_with_gin.TransformerLmSpmdPipeline:
  USE_REPEATED_LAYER = False
  MAX_SEQ_LEN = %MAX_SL
  NUM_LAYERS = 12
  NUM_HEADS = 12
  MODEL_DIMS = 768
  HIDDEN_DIMS = 3072
  DIMS_PER_HEAD = 64
  VOCAB_SIZE = 51200
  TRAINABLE_POSITION_EMB = True
  TRAINABLE_PE_MAX_SEQ_LEN = %MAX_SL
  ACTIVATION_CLS = @activations.GELU.HParams()
  PACKED_INPUT = True
  USE_BIAS = False
  MAX_STEPS=%MAX_STEPS
  INIT_STD = 0.023
  EVAL_INTERVAL_STEPS = 100
  NUM_STAGES = %NUM_STAGES
  NUM_MICROBATCHES = 2
  ICI_MESH_SHAPE = %ICI_MESH_SHAPE
  FPROP_DTYPE = @jnp.bfloat16
  SUMMARY_INTERVAL_STEPS=%SUMMARY_INTERVAL_STEPS
  CHECKPOINT_EVERY_N_STEPS=%CHECKPOINT_EVERY_N_STEPS
  EVAL_INTERVAL_STEPS=%EVAL_INTERVAL_STEPS

OPTIMIZER = @optimizers.Adam.HParams()
optimizers.Adam.HParams:
  beta1 = 0.9
  beta2 = 0.95
  learning_rate = 6e-4
  epsilon_root = 0.0
  epsilon = 1e-8
  weight_decay = 0.1
  clip_threshold = 1.0
  clip_gradient_norm_to_value = 5.0

SCHEDULER = @schedules.LinearRampupCosineDecay.HParams()
schedules.LinearRampupCosineDecay.HParams:
  warmup_steps = 636
  decay_start = 637
  decay_end = 500000
  min_ratio = 0.1
  max = 1.0

DATASET = @datasets_gin.PileUnsupervisedDataset()
datasets_gin.PileUnsupervisedDataset:
  MAX_SEQ_LEN = %MAX_SL
  PERCORE_BATCH_SIZE = %PERCORE_BATCH_SIZE

## experiment == model + dataset
EXPERIMENT = @model_params_with_gin.Experiment()
model_params_with_gin.Experiment:
  model = %MODEL
  dataset = %DATASET
  optimizer = %OPTIMIZER
  scheduler = %SCHEDULER

train_script.run:
  experiment_config = %EXPERIMENT
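
For reference, here is a minimal, illustrative sketch (using current JAX APIs) of how the 8 GPUs map onto ICI_MESH_SHAPE = [4, 1, 1, 2] above: 4 pipeline stages times 2-way model parallelism, with the other two axes of size 1. The axis names below are placeholders of my own choosing, not necessarily the identifiers Praxis uses internally:

import jax
from jax.experimental import mesh_utils
from jax.sharding import Mesh

# Build a (4, 1, 1, 2) logical mesh over the 8 local GPUs; the product of the
# mesh dimensions must equal the number of devices (4 * 1 * 1 * 2 = 8).
devices = mesh_utils.create_device_mesh((4, 1, 1, 2))
mesh = Mesh(devices, axis_names=("stage", "replica", "data", "mdl"))
print(dict(mesh.shape))  # {'stage': 4, 'replica': 1, 'data': 1, 'mdl': 2}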

Command:

#! /bin/bash

set -x

PYTHONPATH=/pax/paxml:/pax/praxis python3 /pax/paxml/paxml/main.py \
    --exp=tasks.lm.params.c4.PileSpmdAdam \
    --gin_file="/pax/paxml/configs/gpt3_126_pp.gin" \
    --tfds_data_dir="/pax/datasets" \
    --vocab_path='/pax/vocab/c4_en_301_5Mexp2_spm.model' \
    --pmap_use_tensorstore=True \
    --job_log_dir=/logs/ \
    --alsologtostderr 

set +x

XLA Compile-Time Error:

2022-10-10 16:01:05.537760: F external/org_tensorflow/tensorflow/compiler/xla/array.h:446] Check failed: n < sizes_size 
Fatal Python error: Aborted

Current thread 0x00007f5c10b73740 (most recent call first):
  File "/usr/local/lib/python3.8/dist-packages/jax/_src/dispatch.py", line 940 in backend_compile
  File "/usr/local/lib/python3.8/dist-packages/jax/_src/profiler.py", line 294 in wrapper
  File "/usr/local/lib/python3.8/dist-packages/jax/_src/dispatch.py", line 996 in compile_or_get_cached
  File "/usr/local/lib/python3.8/dist-packages/jax/interpreters/pxla.py", line 3048 in from_hlo
  File "/usr/local/lib/python3.8/dist-packages/jax/interpreters/pxla.py", line 2890 in compile
  File "/usr/local/lib/python3.8/dist-packages/jax/experimental/pjit.py", line 815 in _pjit_call_impl
  File "/usr/local/lib/python3.8/dist-packages/jax/core.py", line 685 in process_primitive
  File "/usr/local/lib/python3.8/dist-packages/jax/core.py", line 327 in bind_with_trace
  File "/usr/local/lib/python3.8/dist-packages/jax/core.py", line 324 in bind
  File "/usr/local/lib/python3.8/dist-packages/jax/experimental/pjit.py", line 385 in wrapped
  File "/pax/paxml/paxml/train.py", line 1087 in train_and_evaluate_spmd_model
  File "/pax/paxml/paxml/train.py", line 271 in train_and_evaluate
  File "/pax/paxml/paxml/main.py", line 290 in run_experiment
  File "/pax/paxml/paxml/main.py", line 535 in run
  File "/usr/local/lib/python3.8/dist-packages/gin/config.py", line 1582 in gin_wrapper
  File "/pax/paxml/paxml/main.py", line 588 in main
  File "/usr/local/lib/python3.8/dist-packages/absl/app.py", line 251 in _run_main
  File "/usr/local/lib/python3.8/dist-packages/absl/app.py", line 303 in run
  File "/pax/paxml/paxml/main.py", line 631 in <module>

There is no problem when NUM_MICROBATCHES = 1.

It would be great if someone could look into what is causing XLA to break when NUM_MICROBATCHES > 1.
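
For context, a rough conceptual sketch of what changes structurally when NUM_MICROBATCHES > 1: the per-step batch is split along a leading microbatch dimension and the per-stage computation is scanned over it, so XLA sees different shapes and loop structure than in the single-microbatch case. This is only an illustration of the idea, not the actual Praxis pipeline implementation; all shapes and names below are made up:

import jax
import jax.numpy as jnp

NUM_MICROBATCHES = 2
batch = jnp.ones((4, 2048, 768))  # (global batch, seq len, model dims); sizes are illustrative

# Split the batch into (num_microbatches, microbatch_size, seq_len, model_dims).
microbatches = batch.reshape(NUM_MICROBATCHES, -1, *batch.shape[1:])

def stage_fn(carry, microbatch):
  # Stand-in for one pipeline stage's forward computation.
  return carry, microbatch * 2.0

# The pipeline body is scanned over the leading microbatch dimension.
_, outputs = jax.lax.scan(stage_fn, 0.0, microbatches)
print(outputs.shape)  # (2, 2, 2048, 768)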

abhinavgoel95 commented 2 years ago

HLO: https://drive.google.com/drive/folders/1eaMwD6EWdA8XB5FS4KD9XI6iXtyLJn0D?usp=sharing

cheshire commented 2 years ago

Tracked internally in b/253051570.

zhangqiaorjc commented 1 year ago

Already fixed.