alpa-projects / alpa

Training and serving large-scale neural networks with auto parallelization.
https://alpa.ai
Apache License 2.0

IndexError: `InlinedVector::at(size_type) const` failed bounds check #957

Open caixiiaoyang opened 1 year ago

caixiiaoyang commented 1 year ago

Please describe the bug
IndexError: `InlinedVector::at(size_type) const` failed bounds check

System information and environment

To Reproduce
Steps to reproduce the behavior:
1. Training a LLaMA model implemented in Flax produces the following error.

2. See error:

2023-09-24 12:29:49,782 INFO worker.py:1342 -- Connecting to existing Ray cluster at address: 10.233.115.148:6379...
2023-09-24 12:29:49,795 INFO worker.py:1528 -- Connected to Ray cluster.
Training/epoch 0: 0%| | 0/7473 [00:01<?, ?it/s]
Traceback (most recent call last):
  File "./Trainer/train_ray_batch.py", line 149, in <module>
    main()
  File "./Trainer/train_ray_batch.py", line 139, in main
    state, loss = train_step(state, seq, seq_mask, labels, labels_mask)
  File "/home/mpi/.local/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/mpi/.local/lib/python3.8/site-packages/alpa/api.py", line 121, in __call__
    self._decode_args_and_get_executable(args))
  File "/home/mpi/.local/lib/python3.8/site-packages/alpa/api.py", line 191, in _decode_args_and_get_executable
    executable = _compile_parallel_executable(f, in_tree, out_tree_hashable,
  File "/home/mpi/.local/lib/python3.8/site-packages/jax/linear_util.py", line 309, in memoized_fun
    ans = call(fun, *args)
  File "/home/mpi/.local/lib/python3.8/site-packages/alpa/api.py", line 223, in _compile_parallel_executable
    return method.compile_executable(fun, in_tree, out_tree_thunk,
  File "/home/mpi/.local/lib/python3.8/site-packages/alpa/parallel_method.py", line 108, in compile_executable
    return compile_shard_executable(fun, in_tree, out_tree_thunk,
  File "/home/mpi/.local/lib/python3.8/site-packages/alpa/shard_parallel/compile_executable.py", line 78, in compile_shard_executable
    return shard_parallel_internal(fun, in_tree, out_tree_thunk,
  File "/home/mpi/.local/lib/python3.8/site-packages/alpa/shard_parallel/compile_executable.py", line 139, in shard_parallel_internal
    hlo, stage_plan = run_auto_sharding_pass(hlo, logical_mesh_choices[0],
  File "/home/mpi/.local/lib/python3.8/site-packages/alpa/shard_parallel/auto_sharding.py", line 345, in run_auto_sharding_pass
    xe.run_auto_sharding(hlo.get_module(), compile_options)
jax._src.traceback_util.UnfilteredStackTrace: IndexError: InlinedVector::at(size_type) const failed bounds check

The stack trace below excludes JAX-internal frames. The preceding is the original exception that occurred, unmodified.


The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "./Trainer/train_ray_batch.py", line 149, in <module>
    main()
  File "./Trainer/train_ray_batch.py", line 139, in main
    state, loss = train_step(state, seq, seq_mask, labels, labels_mask)
  File "/home/mpi/.local/lib/python3.8/site-packages/alpa/shard_parallel/auto_sharding.py", line 345, in run_auto_sharding_pass
    xe.run_auto_sharding(hlo.get_module(), compile_options)
IndexError: InlinedVector::at(size_type) const failed bounds check
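
For context on where the error is raised: the failure happens while Alpa compiles the @alpa.parallelize-decorated function on its first call, inside the auto-sharding pass of the default ShardParallel method. A minimal sketch of that call path, with a toy function and hypothetical shapes rather than the reporter's model:

import alpa
import jax.numpy as jnp

# Connect to the existing Ray cluster, as the script below does.
alpa.init(cluster="ray")

# The default parallel method is ShardParallel; its compilation invokes
# alpa.shard_parallel.auto_sharding.run_auto_sharding_pass, the frame in
# which the IndexError above is raised.
@alpa.parallelize
def toy_step(w, x):
    return ((x @ w) ** 2).sum()

w = jnp.ones((16, 4))
x = jnp.ones((8, 16))   # argument 1 is treated as the data batch by default
# Compilation, and therefore the auto-sharding pass, runs lazily on the
# first call with concrete arguments.
toy_step(w, x)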

Screenshots
(screenshot of the error attached in the original issue)

Code snippet to reproduce the problem

# (imports and helper definitions from train_ray_batch.py are omitted in this excerpt)

@alpa.parallelize(batch_argnums=(1, 2, 3, 4))
def train_step(state, seq, seq_mask, labels, labels_mask):

    def train_forward(params):
        # seq, seq_mask, labels and labels_mask are closed over from
        # train_step's arguments (the batch_argnums above).
        position_ids = jnp.broadcast_to(
            jnp.arange(jnp.atleast_2d(seq).shape[-1]), seq.shape)
        outputs = state.apply_fn(
            params,
            seq,
            seq_mask,
            position_ids,
            deterministic=False,
            return_dict=False,
        )
        logits = outputs[0]
        loss = cross_entropy_loss(logits, labels, mask=labels_mask)
        return loss

    dynamic_scale = state.dynamic_scale
    if dynamic_scale:
        grad_fn = dynamic_scale.value_and_grad(train_forward)
        dynamic_scale, is_fin, loss, grads = grad_fn(state.params)

    new_state = state.apply_gradients(grads=grads)

    if dynamic_scale:
        new_state = new_state.replace(
            opt_state=jax.tree_map(
                functools.partial(jnp.where, is_fin),
                new_state.opt_state, state.opt_state),
            params=jax.tree_map(
                functools.partial(jnp.where, is_fin),
                new_state.params, state.params),
            master_copy=jax.tree_map(
                functools.partial(jnp.where, is_fin),
                new_state.master_copy, state.master_copy),
            dynamic_scale=dynamic_scale)

    return new_state, loss
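
For readers unfamiliar with the decorator above: batch_argnums lists the positions of the decorated function's arguments that carry the data batch (Alpa uses this to split the batch into micro batches), so (1, 2, 3, 4) marks seq, seq_mask, labels and labels_mask, while argument 0 (state) is not a batch argument. A minimal, self-contained sketch of the same pattern, with a toy linear model and hypothetical shapes (not the reporter's code):

import alpa
import jax.numpy as jnp

alpa.init(cluster="ray")  # as in main() below

# Arguments 1 and 2 (inputs, targets) carry a leading batch dimension;
# argument 0 (params) is the training state and is not a batch argument.
@alpa.parallelize(batch_argnums=(1, 2))
def toy_step(params, inputs, targets):
    preds = inputs @ params                 # toy linear "model"
    return jnp.mean((preds - targets) ** 2)

params = jnp.ones((16, 1))
inputs = jnp.ones((8, 16))                  # batch of 8 examples
targets = jnp.zeros((8, 1))
loss = toy_step(params, inputs, targets)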

def main() -> None:
    global llama_model
    alpa.init(cluster="ray")
    lr = 0.001
    batch_size = 1
    max_len = 640
    n_epochs = 7

    load_pretrained_model = False
    ckpt_dir = "./JAX_model/7B"

    # prepare dataset
    tokenizer = LLaMATokenizer("./JAX_model/tokenizer.model")
    dataset = GSMDataset(split='train')
    collate_fn = partial(gsm_collate_fn_train, tokenizer=tokenizer, max_len=max_len)
    dataloader = LlamaDataLoader(dataset, batch_size, collate_fn)

    # set config
    if load_pretrained_model:
        with open(Path(ckpt_dir) / "params.json", "r") as f:
            config_params = json.loads(f.read())
        config_params.update({'vocab_size': len(tokenizer), 'max_seq_len': max_len})
        llama_config = LLaMAConfig(**config_params)
    else:
        llama_config = LLaMAConfig(num_hidden_layers=4)
    llama_model = LLaMAForCausalLMModule(llama_config)

    # init model
    input_ids = jnp.ones((batch_size, max_len))
    attention_mask = jnp.ones_like(input_ids)
    position_ids = jnp.broadcast_to(
        jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)
    params = llama_model.init(input_ids, attention_mask, position_ids,
                              return_dict=False, init_cache=False)

    if load_pretrained_model:
        param = restore(Path(ckpt_dir) / "consolidated.nra", replace_keys=False)
        params['param'] = param

    n_steps = math.ceil(len(dataloader))

    schedule = warmup_cosine_decay_schedule(
        init_value=0.,
        peak_value=lr,
        warmup_steps=n_steps,
        decay_steps=n_steps + 1,
        end_value=lr,
    )
    optimizer = adamw(learning_rate=schedule)

    use_master_copy = True
    dynamic_scale = DynamicScale()
    alpa.global_config.flax_always_use_fp16_embedding = True
    state = TrainState.create(apply_fn=llama_model.run, params=params, tx=optimizer,
                              dynamic_scale=dynamic_scale, use_master_copy=use_master_copy)

    for epoch in range(n_epochs):
        with tqdm(dataloader) as tepoch:
            tepoch.set_description(f"Training/epoch {epoch}")
            for batch in tepoch:
                seq, seq_mask, labels, labels_mask = batch
                state, loss = train_step(state, seq, seq_mask, labels, labels_mask)

if __name__ == '__main__':
    main()
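
The helper cross_entropy_loss used inside train_forward is not included in the excerpt. For completeness, a hypothetical stand-in (name, signature, and shapes assumed, not taken from the reporter's code) for a masked token-level cross-entropy in JAX:

import jax
import jax.numpy as jnp

def cross_entropy_loss(logits, labels, mask):
    # logits: (batch, seq_len, vocab_size) floats
    # labels: (batch, seq_len) integer token ids
    # mask:   (batch, seq_len), 1 for real tokens, 0 for padding
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    label_log_probs = jnp.take_along_axis(
        log_probs, labels[..., None].astype(jnp.int32), axis=-1).squeeze(-1)
    mask = mask.astype(label_log_probs.dtype)
    # Average the negative log-likelihood over unmasked positions only.
    return -(label_log_probs * mask).sum() / jnp.maximum(mask.sum(), 1.0)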

Additional information
Add any other context about the problem here or include any logs that would be helpful to diagnose the problem.