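To make this less scary, we should add a guide explaining how to fix these errors, and we should also detect them automatically.
This one was solved by using `jax.ensure_compile_time_eval()`, but sometimes the answer is to not use a static field at all. Here, the leaked tracers were the cos/sin caches built in `_get_cos_sin_cache` while `eqx.filter_eval_shape` was tracing `LlamaLMHeadModel.init`. They were stored in a static field (`static_field_values`), so they outlived the trace.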
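Below is a minimal sketch of the `jax.ensure_compile_time_eval()` fix, written in plain JAX rather than equinox or levanter. It makes two assumptions. First, `rotary_tables` is a hypothetical stand-in for levanter's `_get_cos_sin_cache`, with shapes chosen to match the `float32[128,4]` tracers in the traceback. Second, the module-level `STATIC` dict plays the role of an equinox static field, meaning storage that JAX never sees as a traced leaf. The tables depend only on Python ints, so evaluating them eagerly inside the trace yields concrete arrays that are safe to keep around.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Hypothetical stand-in for a static field: storage that outlives the trace
# and is never treated by JAX as a traced leaf.
STATIC = {}


def rotary_tables(seq_len: int, head_dim: int):
    """Hypothetical RoPE cos/sin tables, each of shape [seq_len, head_dim // 2]."""
    inv_freq = 1.0 / (10000.0 ** (jnp.arange(0, head_dim, 2) / head_dim))
    positions = jnp.arange(seq_len, dtype=jnp.float32)
    freqs = jnp.outer(positions, inv_freq)
    return jnp.cos(freqs), jnp.sin(freqs)


def init_leaky(key):
    weight = jax.random.normal(key, (4,))
    # Under jax.eval_shape these are tracers; stashing them outside the
    # returned pytree leaks them (the pattern shown in the traceback).
    STATIC["cos"], STATIC["sin"] = rotary_tables(128, 8)
    return weight


def init_fixed(key):
    weight = jax.random.normal(key, (4,))
    # The tables depend only on Python ints, so they can be computed eagerly
    # even while the rest of this function is being traced.
    with jax.ensure_compile_time_eval():
        cos, sin = rotary_tables(128, 8)
    STATIC["cos"], STATIC["sin"] = cos, sin
    return weight


out = jax.eval_shape(init_fixed, jax.random.PRNGKey(0))
print(out.shape, np.asarray(STATIC["cos"]).shape)
```

`init_leaky` is defined only for contrast. Calling it under `jax.eval_shape` would leave tracers in `STATIC`, which is the pattern JAX's tracer-leak checking (`jax.checking_leaks()`) is meant to catch. This fix works only because the tables are compile-time constants. If a value depends on traced inputs, it is not a constant, and that is the case where not using a static field is the better answer.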
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/Users/ivan/dev/levanter/local_dev_roundtrip.py", line 145, in <module>
main()
File "/Users/ivan/dev/levanter/local_dev_roundtrip.py", line 41, in main
model = converter.load_pretrained(LlamaLMHeadModel, RepoRef(model_id))
File "/Users/ivan/dev/levanter/src/levanter/compat/hf_checkpoints.py", line 469, in load_pretrained
lev_model = eqx.filter_eval_shape(lm_model_cls.init, Vocab, config, key=PRNGKey(0))
File "/Users/ivan/dev/miniconda3/envs/levanter/lib/python3.10/site-packages/equinox/_eval_shape.py", line 36, in filter_eval_shape
dynamic_out, static_out = jax.eval_shape(ft.partial(_fn, static), dynamic)
File "/Users/ivan/dev/miniconda3/envs/levanter/lib/python3.10/contextlib.py", line 142, in __exit__
next(self.gen)
Exception: Leaked trace MainTrace(1,DynamicJaxprTrace). Leaked tracer(s):
Traced<ShapedArray(float32[128,4])>with<DynamicJaxprTrace(level=1/0)>
This DynamicJaxprTracer was created on line /Users/ivan/dev/levanter/src/levanter/models/llama.py:220 (_get_cos_sin_cache)
<DynamicJaxprTracer 6163549296> is referred to by <NamedArray 6163407600>.array
<NamedArray 6163407600> is referred to by <tuple 6163920192>[2]
<tuple 6163920192> is referred to by <_FlattenedData 6163738976>.static_field_values
Traced<ShapedArray(float32[128,4])>with<DynamicJaxprTrace(level=1/0)>
This DynamicJaxprTracer was created on line /Users/ivan/dev/levanter/src/levanter/models/llama.py:219 (_get_cos_sin_cache)
<DynamicJaxprTracer 6163548816> is referred to by <NamedArray 6163404000>.array
<NamedArray 6163404000> is referred to by <tuple 6163920192>[1]
<tuple 6163920192> is referred to by <_FlattenedData 6163738976>.static_field_values
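The reference chain above (tracer → `NamedArray.array` → a tuple in `static_field_values`) shows the cos/sin caches being held in a static field. The other fix, not using a static field, can be sketched by keeping only plain Python ints as static configuration and building the tables inside the call. This way, they are created and consumed within the same trace. `Rotary` and `rotary_tables` are hypothetical and do not reproduce levanter's `llama.py`. The rotate-half form of rotary embeddings is also an assumption made for illustration.

```python
import jax
import jax.numpy as jnp


def rotary_tables(seq_len: int, head_dim: int):
    """Hypothetical RoPE cos/sin tables, each of shape [seq_len, head_dim // 2]."""
    inv_freq = 1.0 / (10000.0 ** (jnp.arange(0, head_dim, 2) / head_dim))
    positions = jnp.arange(seq_len, dtype=jnp.float32)
    freqs = jnp.outer(positions, inv_freq)
    return jnp.cos(freqs), jnp.sin(freqs)


class Rotary:
    """Hypothetical module: stores only sizes (plain ints are safe static data),
    never the arrays derived from them."""

    def __init__(self, seq_len: int, head_dim: int):
        self.seq_len = seq_len
        self.head_dim = head_dim

    def __call__(self, x):
        # x: [seq_len, head_dim]. The tables are built inside whatever trace is
        # running this call, so no tracer outlives it.
        cos, sin = rotary_tables(self.seq_len, self.head_dim)
        half = self.head_dim // 2
        x1, x2 = x[:, :half], x[:, half:]
        # Rotate each (x1[i], x2[i]) pair by its position-dependent angle.
        return jnp.concatenate([x1 * cos - x2 * sin, x1 * sin + x2 * cos], axis=-1)


rope = Rotary(128, 8)
x = jax.random.normal(jax.random.PRNGKey(0), (128, 8))
y = jax.jit(lambda v: rope(v))(x)
```

The trade-off is that the tables become part of every traced computation instead of being precomputed once. A related option is to store them as ordinary (non-static) array fields, so that JAX treats them as leaves.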