stanford-crfm / haliax

Named Tensors for Legible Deep Learning in JAX
Apache License 2.0
149 stars · 11 forks

Update __bool__() in core.py #31

Closed Ivan-Zhou closed 1 year ago

Ivan-Zhou commented 1 year ago

I encountered an error here when training Levanter's Llama model. The complaint was ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all().

Through debugging, I found that self.array is a (non-numpy) Array. bool(Array) triggers an error, but bool(self.array.all()) outputs a bool. If I understand correctly, we should use self.array.all() rather than self.array.any() here, correct?

-> return bool(self.array)
(Pdb) self.array
Array([[ True,  True,  True, ...,  True,  True,  True],
       [ True,  True,  True, ...,  True,  True,  True],
       [ True,  True,  True, ...,  True,  True,  True],
       ...,
       [ True,  True,  True, ...,  True,  True,  True],
       [ True,  True,  True, ...,  True,  True,  True],
       [ True,  True,  True, ...,  True,  True,  True]], dtype=bool)
(Pdb) self.array.shape
(2048, 8)
(Pdb) bool(self.array)
*** ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
(Pdb) bool(self.array.all())
True

I also tried training the GPT model, but this line was not triggered. I don't know exactly what the intended input format for this line is. If I missed any edge cases, please correct me.
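For illustration, the same failure (and the .all() workaround) can be reproduced with a small NumPy sketch, since NumPy follows the same truthiness convention as JAX; the (4, 2) shape here is just a stand-in for the (2048, 8) mask above:

```python
import numpy as np

# A multi-element boolean array, standing in for the (2048, 8) mask above.
mask = np.ones((4, 2), dtype=bool)

try:
    bool(mask)  # ambiguous: the truth value of which element?
except ValueError as e:
    print(e)

# Reducing to a single scalar first makes the conversion well-defined.
print(bool(mask.all()))  # True
```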


For reference, here's the full traceback of my original error:

Traceback (most recent call last):
  File "/home/ivan/levanter/src/levanter/main/train_lm.py", line 187, in <module>
    levanter.config.main(main)()
  File "/home/ivan/levanter/src/levanter/config.py", line 84, in wrapper_inner
    response = fn(cfg, *args, **kwargs)
  File "/home/ivan/levanter/src/levanter/main/train_lm.py", line 118, in main
    state = trainer.initial_state(training_key, model_init=lambda: config.model.build(Vocab, key=model_key))
  File "/home/ivan/levanter/src/levanter/trainer.py", line 203, in initial_state
    model, opt_state = named_jit(self._init_model_and_opt_state, self.parameter_axis_mapping)(model_init)
  File "/home/ivan/venv310/lib/python3.10/site-packages/haliax/partitioning.py", line 333, in f
    return cached_pjitted_fun(dynamic_donated, dynamic_reserved, static)
  File "<string>", line 4, in __eq__
Traceback (most recent call last):
  File "/home/ivan/levanter/src/levanter/main/train_lm.py", line 187, in <module>
    levanter.config.main(main)()
  File "/home/ivan/levanter/src/levanter/config.py", line 84, in wrapper_inner
    response = fn(cfg, *args, **kwargs)
  File "/home/ivan/levanter/src/levanter/main/train_lm.py", line 118, in main
    state = trainer.initial_state(training_key, model_init=lambda: config.model.build(Vocab, key=model_key))
  File "/home/ivan/levanter/src/levanter/trainer.py", line 203, in initial_state
    model, opt_state = named_jit(self._init_model_and_opt_state, self.parameter_axis_mapping)(model_init)
  File "/home/ivan/venv310/lib/python3.10/site-packages/haliax/partitioning.py", line 333, in f
    return cached_pjitted_fun(dynamic_donated, dynamic_reserved, static)
  File "/home/ivan/venv310/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/ivan/venv310/lib/python3.10/site-packages/jax/_src/pjit.py", line 250, in cache_miss
    outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper(
  File "/home/ivan/venv310/lib/python3.10/site-packages/jax/_src/pjit.py", line 158, in _python_pjit_helper
    args_flat, _, params, in_tree, out_tree, _ = infer_params_fn(
  File "/home/ivan/venv310/lib/python3.10/site-packages/jax/_src/pjit.py", line 775, in infer_params
    return common_infer_params(pjit_info_args, *args, **kwargs)
  File "/home/ivan/venv310/lib/python3.10/site-packages/jax/_src/pjit.py", line 505, in common_infer_params
    jaxpr, consts, canonicalized_out_shardings_flat = _pjit_jaxpr(
  File "/home/ivan/venv310/lib/python3.10/site-packages/jax/_src/pjit.py", line 973, in _pjit_jaxpr
    canonicalized_out_shardings_flat = _check_and_canonicalize_out_shardings(
  File "/home/ivan/venv310/lib/python3.10/site-packages/jax/_src/pjit.py", line 951, in _check_and_canonicalize_out_shardings
    out_shardings_flat = flatten_axis_resources(
  File "/home/ivan/venv310/lib/python3.10/site-packages/jax/_src/pjit.py", line 878, in flatten_axis_resources
    errors = prefix_errors(axis_tree, dummy_tree)
  File "/home/ivan/venv310/lib/python3.10/site-packages/jax/_src/tree_util.py", line 433, in prefix_errors
    return list(_prefix_error((), prefix_tree, full_tree, is_leaf))
  File "/home/ivan/venv310/lib/python3.10/site-packages/jax/_src/tree_util.py", line 907, in _prefix_error
    yield from _prefix_error((*key_path, k), t1, t2)
  File "/home/ivan/venv310/lib/python3.10/site-packages/jax/_src/tree_util.py", line 907, in _prefix_error
    yield from _prefix_error((*key_path, k), t1, t2)
  File "/home/ivan/venv310/lib/python3.10/site-packages/jax/_src/tree_util.py", line 907, in _prefix_error
    yield from _prefix_error((*key_path, k), t1, t2)
  [Previous line repeated 3 more times]
  File "/home/ivan/venv310/lib/python3.10/site-packages/jax/_src/tree_util.py", line 880, in _prefix_error
    if prefix_tree_meta != full_tree_meta:
  File "<string>", line 4, in __eq__
  File "/home/ivan/venv310/lib/python3.10/site-packages/haliax/core.py", line 590, in __bool__
    return bool(self.array)
  File "/home/ivan/venv310/lib/python3.10/site-packages/jax/_src/array.py", line 257, in __bool__
    return bool(self._value)
jax._src.traceback_util.UnfilteredStackTrace: ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

Here is the debugging on Levanter side:

> /home/ivan/levanter/src/levanter/main/train_lm.py(118)main()
-> state = trainer.initial_state(training_key, model_init=lambda: config.model.build(Vocab, key=model_key))
(Pdb) training_key
Array([2467461003, 3840466878], dtype=uint32)

# trainer.py
(Pdb) self.parameter_axis_mapping
{'mlp': 'model', 'heads': 'model', 'batch': 'data', 'embed': 'data'}
(Pdb) self._init_model_and_opt_state
<bound method Trainer._init_model_and_opt_state of <levanter.trainer.Trainer object at 0x7f657f9b78e0>>

# Haliax.core
-> return bool(self.array)
(Pdb) self.array
Array([[ True,  True,  True, ...,  True,  True,  True],
       [ True,  True,  True, ...,  True,  True,  True],
       [ True,  True,  True, ...,  True,  True,  True],
       ...,
       [ True,  True,  True, ...,  True,  True,  True],
       [ True,  True,  True, ...,  True,  True,  True],
       [ True,  True,  True, ...,  True,  True,  True]], dtype=bool)
(Pdb) self.array.shape
(2048, 8)
(Pdb)
dlwh commented 1 year ago

no, it's standard in all of these array libraries that bool(array) is defined iff array.size == 1. (jax, numpy, torch, etc)
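For illustration, the size-1 rule described above can be checked directly with NumPy (JAX and PyTorch behave the same way):

```python
import numpy as np

# bool(array) is well-defined only when array.size == 1
assert bool(np.array([True]))       # size 1: fine
assert bool(np.array([[7]]))        # size 1 regardless of shape or dtype

try:
    bool(np.array([True, False]))   # size 2: ambiguous
except ValueError:
    print("size > 1 raises ValueError")
```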

this looks like an error within an error... something else is wrong and it's trying to report an error, but it explodes first

Ivan-Zhou commented 1 year ago

@dlwh I agree with your judgment that this is an error within an error.

I spent some time tweaking and debugging with llama_nano of https://github.com/stanford-crfm/levanter/pull/298. I haven't found the exact root cause yet, but I had a few observations:

  1. The bool() call is very likely only part of the larger error, not the root cause. With the proposed fix, model training proceeds without issue.
  2. The error happens immediately after the following line: jnp.equal(x1, x2) outputs a bool array of shape [2048, 8], which comes from cos_cached and sin_cached of this line.
  File "/home/ivan/levanter/src/levanter/trainer.py", line 203, in initial_state
    model, opt_state = named_jit(self._init_model_and_opt_state, self.parameter_axis_mapping)(model_init)
  File "/home/ivan/haliax/src/haliax/partitioning.py", line 333, in f
    out, out_static = cached_pjitted_fun(dynamic_donated, dynamic_reserved, static)
  File "<string>", line 4, in __eq__
  File "/home/ivan/haliax/src/haliax/core.py", line 478, in __eq__
    return haliax.equal(self, other)
  File "/home/ivan/haliax/src/haliax/wrap.py", line 100, in binop
    return NamedArray(op(a.array, b.array), axes)
  File "/home/ivan/haliax/src/haliax/__init__.py", line 632, in equal
    return jnp.equal(x1, x2)  # type: ignore
  3. The error is related to partitioning: without partitioning, it is not triggered.
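The chain of events can be sketched in plain Python with a hypothetical FakeNamedArray and Meta class (NumPy standing in for jnp; the real players are haliax's NamedArray and the pytree metadata compared in jax's prefix_errors): the dataclass-generated __eq__ compares fields as tuples, tuple comparison forces a bool() on each element-wise comparison result, and that bool() is exactly the call that explodes.

```python
import numpy as np
from dataclasses import dataclass

# Hypothetical stand-in for haliax's NamedArray: __eq__ returns an
# element-wise boolean array, and __bool__ defers to the raw array.
class FakeNamedArray:
    def __init__(self, array):
        self.array = array
    def __eq__(self, other):
        return FakeNamedArray(np.equal(self.array, other.array))
    def __bool__(self):
        return bool(self.array)  # raises ValueError when array.size > 1

# Hypothetical stand-in for the pytree metadata compared in jax's
# prefix_errors: the dataclass-generated __eq__ compares fields as
# tuples, which calls bool() on each element comparison.
@dataclass
class Meta:
    payload: FakeNamedArray

a = Meta(FakeNamedArray(np.ones((4, 2), dtype=bool)))
b = Meta(FakeNamedArray(np.ones((4, 2), dtype=bool)))

try:
    a != b  # __ne__ -> dataclass __eq__ -> tuple compare -> bool(...)
except ValueError as e:
    print("exploded while comparing metadata:", e)
```

Note that the comparison explodes even when the two trees are identical, which matches the observation that the bool() is a symptom rather than the root cause.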
Ivan-Zhou commented 1 year ago

Given that this is likely not an issue in Haliax and this PR is not the proper fix, I am going to close it and move the investigation to https://github.com/stanford-crfm/levanter/issues/299.