stanford-crfm / haliax

Named Tensors for Legible Deep Learning in JAX
Apache License 2.0

Make a better error message for abstract tracer errors #28

Open dlwh opened 10 months ago

dlwh commented 10 months ago

To make these tracer errors less scary, we should add a guide on how to fix them, and detect them automatically.
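A rough sketch of what automatic detection could look like, assuming the check runs at the point where a value is about to be stored in a static field. The helper name `check_static_value` and the wording of the message are hypothetical, not an existing haliax or Equinox API. It relies only on `jax.tree_util.tree_leaves` and `jax.core.Tracer`, and assumes the value is a pytree whose leaves are arrays (as with a NamedArray).

```python
import jax
import jax.numpy as jnp


def check_static_value(field_name, value):
    """Hypothetical helper: raise a readable error if `value` contains a JAX tracer."""
    leaves = jax.tree_util.tree_leaves(value)
    traced = [leaf for leaf in leaves if isinstance(leaf, jax.core.Tracer)]
    if traced:
        raise TypeError(
            f"Static field {field_name!r} holds {len(traced)} traced array(s). "
            "Either compute it under jax.ensure_compile_time_eval() "
            "or make the field non-static."
        )
    return value


messages = []


def build(x):
    # jnp.cos(x) is a tracer here because `build` runs under jax.eval_shape.
    try:
        check_static_value("cos_cache", jnp.cos(x))
    except TypeError as err:
        messages.append(str(err))
    return x


jax.eval_shape(build, jnp.ones((128, 4)))
print(messages[0])
```

The point of the design is that the error names the offending field and both possible fixes at the moment the tracer is stored, rather than when the trace exits.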

The error in the traceback below was solved by using jax.ensure_compile_time_eval(), but sometimes the right fix is to not make the field static.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/Users/ivan/dev/levanter/local_dev_roundtrip.py", line 145, in <module>
    main()
  File "/Users/ivan/dev/levanter/local_dev_roundtrip.py", line 41, in main
    model = converter.load_pretrained(LlamaLMHeadModel, RepoRef(model_id))
  File "/Users/ivan/dev/levanter/src/levanter/compat/hf_checkpoints.py", line 469, in load_pretrained
    lev_model = eqx.filter_eval_shape(lm_model_cls.init, Vocab, config, key=PRNGKey(0))
  File "/Users/ivan/dev/miniconda3/envs/levanter/lib/python3.10/site-packages/equinox/_eval_shape.py", line 36, in filter_eval_shape
    dynamic_out, static_out = jax.eval_shape(ft.partial(_fn, static), dynamic)
  File "/Users/ivan/dev/miniconda3/envs/levanter/lib/python3.10/contextlib.py", line 142, in __exit__
    next(self.gen)
Exception: Leaked trace MainTrace(1,DynamicJaxprTrace). Leaked tracer(s):

Traced<ShapedArray(float32[128,4])>with<DynamicJaxprTrace(level=1/0)>
This DynamicJaxprTracer was created on line /Users/ivan/dev/levanter/src/levanter/models/llama.py:220 (_get_cos_sin_cache)
<DynamicJaxprTracer 6163549296> is referred to by <NamedArray 6163407600>.array
<NamedArray 6163407600> is referred to by <tuple 6163920192>[2]
<tuple 6163920192> is referred to by <_FlattenedData 6163738976>.static_field_values

Traced<ShapedArray(float32[128,4])>with<DynamicJaxprTrace(level=1/0)>
This DynamicJaxprTracer was created on line /Users/ivan/dev/levanter/src/levanter/models/llama.py:219 (_get_cos_sin_cache)
<DynamicJaxprTracer 6163548816> is referred to by <NamedArray 6163404000>.array
<NamedArray 6163404000> is referred to by <tuple 6163920192>[1]
<tuple 6163920192> is referred to by <_FlattenedData 6163738976>.static_field_values
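
For the guide, a minimal sketch of why jax.ensure_compile_time_eval() helps in a case like the one above. The function `cos_cache` is a hypothetical stand-in for something like `_get_cos_sin_cache`, not the actual levanter code. The sketch only shows that the same computation yields a tracer when run inside a trace, and a concrete array when wrapped in the context manager, which is what makes it safe to keep in a static field.

```python
import jax
import jax.numpy as jnp


def cos_cache(seq_len: int, dim: int):
    # Hypothetical stand-in for a cache-building function.
    pos = jnp.arange(seq_len)[:, None]
    inv_freq = 1.0 / (10000.0 ** (jnp.arange(dim) / dim))
    return jnp.cos(pos * inv_freq)


def cos_cache_concrete(seq_len: int, dim: int):
    # Same computation, but evaluated eagerly even when called inside a trace.
    with jax.ensure_compile_time_eval():
        return cos_cache(seq_len, dim)


is_tracer = {}


def init(x):
    # Record only booleans so that no tracer escapes this function.
    is_tracer["plain"] = isinstance(cos_cache(128, 4), jax.core.Tracer)
    is_tracer["concrete"] = isinstance(cos_cache_concrete(128, 4), jax.core.Tracer)
    return x


jax.eval_shape(init, jnp.zeros(()))
print(is_tracer)
```

The other fix mentioned above, not using a static field, avoids the problem differently: a traced array stored in an ordinary (dynamic) field is returned as part of the pytree output instead of being hidden in the static metadata.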