phlippe / uvadlc_notebooks

Repository of Jupyter notebook tutorials for teaching the Deep Learning Course at the University of Amsterdam (MSc AI), Fall 2023
https://uvadlc-notebooks.readthedocs.io/en/latest/
MIT License

jax tutorial 17: cannot mix flax and jit #94

Closed murphyk closed 1 year ago

murphyk commented 1 year ago

Tutorial: 17 (JAX)

https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/JAX/tutorial17/SimCLR.html

simclr_model = simclr_trainer.model.bind({'params': simclr_trainer.state.params, 
                                          'batch_stats': simclr_trainer.state.batch_stats})
encode_fn = jax.jit(lambda img: simclr_model.encode(img))
train_feats_simclr = prepare_data_features(encode_fn, train_img_data)
test_feats_simclr = prepare_data_features(encode_fn, test_img_data)

This produces the error below. Omitting jax.jit avoids the error, but feature extraction is then slower.

JaxTransformError                         Traceback (most recent call last)
<ipython-input-31-23f764a52004> in <cell line: 4>()
      2                                           'batch_stats': simclr_trainer.state.batch_stats})
      3 encode_fn = jax.jit(lambda img: simclr_model.encode(img))
----> 4 train_feats_simclr = prepare_data_features(encode_fn, train_img_data)
      5 test_feats_simclr = prepare_data_features(encode_fn, test_img_data)

    [... skipping hidden 12 frame]

    [... skipping hidden 5 frame]

    [... skipping hidden 5 frame]

/usr/local/lib/python3.9/dist-packages/flax/core/tracers.py in check_trace_level(base_level)
     34   level = trace_level(current_trace())
     35   if level != base_level:
---> 36     raise errors.JaxTransformError()

JaxTransformError: Jax transforms and Flax models cannot be mixed. (https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html#flax.errors.JaxTransformError)

phlippe commented 1 year ago

Thanks, finally fixed in 2324e9ae21b470cc254af8800800a70bff41178a!