murphyk closed this issue 1 year ago.
Tutorial: 17 (SimCLR, JAX version)
https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/JAX/tutorial17/SimCLR.html
```python
simclr_model = simclr_trainer.model.bind({'params': simclr_trainer.state.params,
                                          'batch_stats': simclr_trainer.state.batch_stats})
encode_fn = jax.jit(lambda img: simclr_model.encode(img))
train_feats_simclr = prepare_data_features(encode_fn, train_img_data)
test_feats_simclr = prepare_data_features(encode_fn, test_img_data)
```
This code produces the error below. Omitting `jax.jit` avoids the error, but feature extraction is then slower.
```
JaxTransformError                         Traceback (most recent call last)
<ipython-input-31-23f764a52004> in <cell line: 4>()
      2                                           'batch_stats': simclr_trainer.state.batch_stats})
      3 encode_fn = jax.jit(lambda img: simclr_model.encode(img))
----> 4 train_feats_simclr = prepare_data_features(encode_fn, train_img_data)
      5 test_feats_simclr = prepare_data_features(encode_fn, test_img_data)

[... skipping hidden 12 frame]

[... skipping hidden 5 frame]

[... skipping hidden 5 frame]

/usr/local/lib/python3.9/dist-packages/flax/core/tracers.py in check_trace_level(base_level)
     34   level = trace_level(current_trace())
     35   if level != base_level:
---> 36     raise errors.JaxTransformError()

JaxTransformError: Jax transforms and Flax models cannot be mixed. (https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html#flax.errors.JaxTransformError)
```
Thanks, finally fixed in 2324e9ae21b470cc254af8800800a70bff41178a!