Closed Ivan-Zhou closed 1 year ago
No, it's standard in all of these array libraries (jax, numpy, torch, etc.) that `bool(array)` is defined iff `array.size == 1`.
this looks like an error within an error... something else is wrong and it's trying to report an error, but it explodes first
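A minimal sketch of the rule described above, using NumPy as a stand-in for the other libraries. The helper name `truthiness` is hypothetical and exists only for this illustration.

```python
import numpy as np


def truthiness(arr):
    """Return bool(arr), or the error message if the conversion is ambiguous."""
    try:
        return bool(arr)
    except ValueError as err:
        return f"ValueError: {err}"


single = np.array([True])        # size == 1: bool() is defined
many = np.array([True, False])   # size == 2: bool() is ambiguous

single_result = truthiness(single)
many_result = truthiness(many)

print(single_result)
print(many_result)
```

The second call reports the same kind of "truth value ... is ambiguous" message quoted in this PR, which is why an explicit reduction such as `.any()` or `.all()` is needed before converting a larger array to a bool.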
@dlwh I agree with your judgement that this is an error within an error.
I spent some time tweaking and debugging with llama_nano of https://github.com/stanford-crfm/levanter/pull/298. I haven't found the exact root cause yet, but I had a few observations:
The `bool()` call is very likely only one part of the overall error. With the proposed fix, model training proceeds without problems.
`jnp.equal(x1, x2)` outputs a bool array of shape [2048, 8], which comes from `cos_cached` and `sin_cached` of this line.
  File "/home/ivan/levanter/src/levanter/trainer.py", line 203, in initial_state
    model, opt_state = named_jit(self._init_model_and_opt_state, self.parameter_axis_mapping)(model_init)
  File "/home/ivan/haliax/src/haliax/partitioning.py", line 333, in f
    out, out_static = cached_pjitted_fun(dynamic_donated, dynamic_reserved, static)
  File "<string>", line 4, in __eq__
  File "/home/ivan/haliax/src/haliax/core.py", line 478, in __eq__
    return haliax.equal(self, other)
  File "/home/ivan/haliax/src/haliax/wrap.py", line 100, in binop
    return NamedArray(op(a.array, b.array), axes)
  File "/home/ivan/haliax/src/haliax/__init__.py", line 632, in equal
    return jnp.equal(x1, x2) # type: ignore
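The `File "<string>", line 4, in __eq__` frame is consistent with a generated `__eq__`, such as the one Python dataclasses produce, although that is an assumption and not something the traceback confirms. Under that assumption, the failure mode can be sketched with plain NumPy: a generated `__eq__` compares fields with `==` and then needs a single bool, which an elementwise array comparison cannot provide. `RotaryCache`, `make_cache`, and `compare` are hypothetical names; only the field names and the [2048, 8] shape come from the thread.

```python
from dataclasses import dataclass

import numpy as np


@dataclass
class RotaryCache:
    """Hypothetical container; only the field names come from the thread."""

    cos_cached: np.ndarray
    sin_cached: np.ndarray


def make_cache():
    # Fresh arrays on every call, so the two instances never share objects.
    return RotaryCache(np.ones((2048, 8)), np.zeros((2048, 8)))


def compare(a, b):
    """Return a == b, or the error message if the comparison blows up."""
    try:
        return a == b
    except ValueError as err:
        return f"ValueError: {err}"


result = compare(make_cache(), make_cache())
print(result)
```

If something like this is what happens here, the comparison itself is being triggered by some other code path, which fits the "error within an error" reading.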
Given that this is likely not an issue in Haliax and this PR is not the proper fix, I am going to close it and move the investigation to https://github.com/stanford-crfm/levanter/issues/299.
I encountered an error here when training Levanter's Llama. The complaint was:
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
Through debugging, I found that `self.array` is a (non-NumPy) Array. `bool(Array)` triggers an error, but `bool(self.array.all())` would output a bool. If I understand correctly, we should use `self.array.all()` instead of `self.array.any()` here, correct?
I also tried training the GPT model, but this line was not triggered. I don't know exactly what the intended input format for this line was. If I missed any edge cases, please correct me.
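To make the `.all()` versus `.any()` question concrete, here is a small sketch with NumPy arrays of the same [2048, 8] shape mentioned in the thread. The data is made up for illustration: two arrays that agree everywhere except in one element.

```python
import numpy as np

x1 = np.zeros((2048, 8))
x2 = np.zeros((2048, 8))
x2[0, 0] = 1.0  # make exactly one element differ

elementwise = np.equal(x1, x2)        # bool array, one entry per element

all_equal = bool(elementwise.all())   # True only if every element matches
any_equal = bool(elementwise.any())   # True if at least one element matches

print(elementwise.shape, all_equal, any_equal)
```

Both reductions give a single bool, so either one avoids the ValueError. They answer different questions, though: `.all()` asks whether the arrays are equal everywhere, while `.any()` asks whether they agree anywhere.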
For reference, here's the full traceback of my original error:
Here is the debugging on the Levanter side: