Please describe the bug
IndexError: InlinedVector::at(size_type) const failed bounds check
System information and environment
OS Platform and Distribution (e.g., Linux Ubuntu 16.04, docker):
Python version: 3.8.10
CUDA version: 11.3
NCCL version: 2.9
cupy version: 11.3
GPU model and memory: 2x A100 (80 GB)
Alpa version: 0.2.3
TensorFlow version: 2.8.0
JAX version: 0.3.22
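If useful, the versions above can be double-checked on the failing machine with a small convenience snippet; this is only a sketch, and it assumes each package is importable under the name shown and exposes __version__ (including alpa).

    # Convenience sketch for double-checking the versions reported above.
    # Assumption: the packages are importable under these names and expose __version__.
    import alpa, cupy, jax, jaxlib, tensorflow as tf

    print("alpa        :", alpa.__version__)
    print("jax         :", jax.__version__, "/ jaxlib", jaxlib.__version__)
    print("tensorflow  :", tf.__version__)
    print("cupy        :", cupy.__version__)
    print("CUDA runtime:", cupy.cuda.runtime.runtimeGetVersion())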
To Reproduce
Steps to reproduce the behavior:
1. Training a LLaMA model implemented in Flax produces the error below.
2. See error:
2023-09-24 12:29:49,782 INFO worker.py:1342 -- Connecting to existing Ray cluster at address: 10.233.115.148:6379...
2023-09-24 12:29:49,795 INFO worker.py:1528 -- Connected to Ray cluster.
Training/epoch 0: 0%| | 0/7473 [00:01<?, ?it/s]
Traceback (most recent call last):
  File "./Trainer/train_ray_batch.py", line 149, in <module>
    main()
  File "./Trainer/train_ray_batch.py", line 139, in main
    state, loss = train_step(state, seq, seq_mask, labels, labels_mask)
  File "/home/mpi/.local/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/mpi/.local/lib/python3.8/site-packages/alpa/api.py", line 121, in __call__
    self._decode_args_and_get_executable(args))
  File "/home/mpi/.local/lib/python3.8/site-packages/alpa/api.py", line 191, in _decode_args_and_get_executable
    executable = _compile_parallel_executable(f, in_tree, out_tree_hashable,
  File "/home/mpi/.local/lib/python3.8/site-packages/jax/linear_util.py", line 309, in memoized_fun
    ans = call(fun, *args)
  File "/home/mpi/.local/lib/python3.8/site-packages/alpa/api.py", line 223, in _compile_parallel_executable
    return method.compile_executable(fun, in_tree, out_tree_thunk,
  File "/home/mpi/.local/lib/python3.8/site-packages/alpa/parallel_method.py", line 108, in compile_executable
    return compile_shard_executable(fun, in_tree, out_tree_thunk,
  File "/home/mpi/.local/lib/python3.8/site-packages/alpa/shard_parallel/compile_executable.py", line 78, in compile_shard_executable
    return shard_parallel_internal(fun, in_tree, out_tree_thunk,
  File "/home/mpi/.local/lib/python3.8/site-packages/alpa/shard_parallel/compile_executable.py", line 139, in shard_parallel_internal
    hlo, stage_plan = run_auto_sharding_pass(hlo, logical_mesh_choices[0],
  File "/home/mpi/.local/lib/python3.8/site-packages/alpa/shard_parallel/auto_sharding.py", line 345, in run_auto_sharding_pass
    xe.run_auto_sharding(hlo.get_module(), compile_options)
jax._src.traceback_util.UnfilteredStackTrace: IndexError: InlinedVector::at(size_type) const failed bounds check

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "./Trainer/train_ray_batch.py", line 149, in <module>
    main()
  File "./Trainer/train_ray_batch.py", line 139, in main
    state, loss = train_step(state, seq, seq_mask, labels, labels_mask)
  File "/home/mpi/.local/lib/python3.8/site-packages/alpa/shard_parallel/auto_sharding.py", line 345, in run_auto_sharding_pass
    xe.run_auto_sharding(hlo.get_module(), compile_options)
IndexError: InlinedVector::at(size_type) const failed bounds check
Code snippet to reproduce the problem
@alpa.parallelize(batch_argnums=(1, 2, 3, 4))
def train_step(state, seq, seq_mask, labels, labels_mask):
    def train_forward(params):
        seq, seq_mask, labels, labels_mask = data_batch
        ...

    dynamic_scale = state.dynamic_scale
    if dynamic_scale:
        grad_fn = dynamic_scale.value_and_grad(train_forward)
        dynamic_scale, is_fin, loss, grads = grad_fn(state.params)
    new_state = state.apply_gradients(grads=grads)
    if dynamic_scale:
        new_state = new_state.replace(
            opt_state=jax.tree_map(
                functools.partial(jnp.where, is_fin),
                new_state.opt_state, state.opt_state),
            params=jax.tree_map(
                functools.partial(jnp.where, is_fin),
                new_state.params, state.params),
            master_copy=jax.tree_map(
                functools.partial(jnp.where, is_fin),
                new_state.master_copy, state.master_copy),
            dynamic_scale=dynamic_scale)
    return new_state, loss

def main() -> None:
    global llama_model
    alpa.init(cluster="ray")
    lr = 0.001
    batch_size = 1
    max_len = 640
    n_epochs = 7
    load_pretrained_model = False
    ckpt_dir = "./JAX_model/7B"

    # prepare dataset
    tokenizer = LLaMATokenizer("./JAX_model/tokenizer.model")
    dataset = GSMDataset(split='train')
    collate_fn = partial(gsm_collate_fn_train, tokenizer=tokenizer, max_len=max_len)
    dataloader = LlamaDataLoader(dataset, batch_size, collate_fn)

    # set config
    if load_pretrained_model:
        with open(Path(ckpt_dir) / "params.json", "r") as f:
            config_params = json.loads(f.read())
        config_params.update({'vocab_size': len(tokenizer), 'max_seq_len': max_len})
        llama_config = LLaMAConfig(**config_params)
    else:
        llama_config = LLaMAConfig(num_hidden_layers=4)
    llama_model = LLaMAForCausalLMModule(llama_config)

    # init model
    input_ids = jnp.ones((batch_size, max_len))
    attention_mask = jnp.ones_like(input_ids)
    position_ids = jnp.broadcast_to(
        jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)
    params = llama_model.init(input_ids, attention_mask, position_ids,
                              return_dict=False, init_cache=False)
    if load_pretrained_model:
        param = restore(Path(ckpt_dir) / "consolidated.nra", replace_keys=False)
        params['param'] = param

    n_steps = math.ceil(len(dataloader))
    schedule = warmup_cosine_decay_schedule(
        init_value=0.,
        peak_value=lr,
        warmup_steps=n_steps,
        decay_steps=n_steps + 1,
        end_value=lr,
    )
    optimizer = adamw(learning_rate=schedule)

    use_master_copy = True
    dynamic_scale = DynamicScale()
    alpa.global_config.flax_always_use_fp16_embedding = True
    state = TrainState.create(apply_fn=llama_model.run, params=params, tx=optimizer,
                              dynamic_scale=dynamic_scale, use_master_copy=use_master_copy)

    for epoch in range(n_epochs):
        with tqdm(dataloader) as tepoch:
            tepoch.set_description(f"Training/epoch {epoch}")
            for batch in tepoch:
                seq, seq_mask, labels, labels_mask = batch
                state, loss = train_step(state, seq, seq_mask, labels, labels_mask)

if __name__ == '__main__':
    main()
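In case it helps narrow things down, below is a minimal sketch (placeholder ToyMLP model, made-up shapes, optax optimizer; none of it is taken from the failing script) that exercises the same alpa.parallelize compile path, and therefore the same auto-sharding pass, on the same Ray cluster.

    # Minimal sketch (placeholder model/shapes, not from the failing script) that
    # exercises the same alpa.parallelize -> ShardParallel -> auto-sharding path.
    import alpa
    import jax
    import jax.numpy as jnp
    import flax.linen as nn
    import optax
    from flax.training.train_state import TrainState


    class ToyMLP(nn.Module):
        """Tiny stand-in model; only here to trigger the auto-sharding pass."""
        hidden: int = 256

        @nn.compact
        def __call__(self, x):
            x = nn.Dense(self.hidden)(x)
            x = nn.relu(x)
            return nn.Dense(1)(x)


    @alpa.parallelize(batch_argnums=(1, 2))
    def toy_train_step(state, x, y):
        def loss_fn(params):
            pred = state.apply_fn(params, x)
            return jnp.mean((pred - y) ** 2)

        grads = jax.grad(loss_fn)(state.params)
        return state.apply_gradients(grads=grads)


    def check_auto_sharding():
        alpa.init(cluster="ray")
        model = ToyMLP()
        x = jnp.ones((8, 64), jnp.float32)
        y = jnp.ones((8, 1), jnp.float32)
        params = model.init(jax.random.PRNGKey(0), x)
        state = TrainState.create(apply_fn=model.apply, params=params,
                                  tx=optax.adam(1e-3))
        # The first call compiles the step and runs Alpa's auto-sharding pass.
        state = toy_train_step(state, x, y)
        print("auto-sharding pass completed for the toy model")


    if __name__ == "__main__":
        check_auto_sharding()

If this toy step compiles cleanly on the same cluster, the crash is more likely tied to the LLaMA HLO or to the DynamicScale / master-copy state than to the Alpa installation itself.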
Additional information
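If it would help triage, the HLO that XLA compiles can be dumped and attached here. The sketch below relies on the standard --xla_dump_to XLA flag; whether Alpa's custom run_auto_sharding pass also honors it for the module it receives is an assumption on my part, and the dump directory is arbitrary.

    # Sketch: set XLA's dump flag before jax/alpa are imported so compiled HLO
    # modules are written to disk and can be attached to this issue.
    # Assumption: --xla_dump_to also covers the module handed to Alpa's
    # auto-sharding pass; the dump directory is arbitrary.
    import os

    os.environ["XLA_FLAGS"] = (os.environ.get("XLA_FLAGS", "")
                               + " --xla_dump_to=/tmp/alpa_hlo_dump")

    import alpa  # noqa: E402  (import after XLA_FLAGS so the flag takes effect)
    import jax   # noqa: E402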