keras-team / keras-io

Keras documentation, hosted live at keras.io
Apache License 2.0

Convert 'Vision Transformer without Attention' to Keras 3. #1855

Open fkouteib opened 1 month ago

fkouteib commented 1 month ago

Compatible with the TensorFlow and PyTorch backends only.

fkouteib commented 1 month ago

On TensorFlow, I can train and test the model, but I hit this issue when loading the saved model to run inference. It may be the same issue as https://github.com/keras-team/keras/issues/19492, but I am not 100% sure.

$HOME/.tf_venv/lib/python3.10/site-packages/keras/src/saving/saving_lib.py:418: UserWarning: Skipping variable loading for optimizer 'adamw', because it has 1 variables whereas the saved optimizer has 219 variables.
  trackable.load_own_variables(weights_store.get(inner_path))
Traceback (most recent call last):
  File "$HOME/keras-io_rw/examples/vision/shiftvit.py", line 1092, in <module>
    probabilities = predict(predict_ds)
  File "$HOME/keras-io_rw/examples/vision/shiftvit.py", line 1062, in predict
    logits = saved_model.predict(predict_ds)
  File "$HOME/.tf_venv/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py", line 122, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "$HOME/keras-io_rw/examples/vision/shiftvit.py", line 720, in call
    augmented_images = self.data_augmentation(images)
TypeError: Exception encountered when calling ShiftViTModel.call().

'TrackedDict' object is not callable

Arguments received by ShiftViTModel.call():
  • images=tf.Tensor(shape=(10, 32, 32, 3), dtype=uint8)
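The traceback shows that `self.data_augmentation` is a dict-like object rather than a callable layer after loading. One possible cause (an assumption, not confirmed in this thread) is that the nested layer is saved as a config dict and the dict is handed straight to the constructor on load. The backend-free sketch below mimics that mechanism with hypothetical stand-ins (`Rescale`, `TinyModel`, `serialize`, `deserialize`, `REGISTRY`); in Keras 3 the corresponding helpers would be `keras.saving.serialize_keras_object` and `keras.saving.deserialize_keras_object`.

```python
class Rescale:
    """Hypothetical stand-in for a preprocessing/augmentation layer."""

    def __init__(self, scale):
        self.scale = scale

    def __call__(self, values):
        return [v * self.scale for v in values]

    def get_config(self):
        return {"scale": self.scale}


# Hypothetical registry mapping class names back to classes.
REGISTRY = {"Rescale": Rescale}


def serialize(obj):
    """Turn an object into a plain, JSON-friendly dict."""
    return {"class_name": type(obj).__name__, "config": obj.get_config()}


def deserialize(spec):
    """Rebuild an object from the dict produced by serialize()."""
    return REGISTRY[spec["class_name"]](**spec["config"])


class TinyModel:
    """Hypothetical stand-in for a model that owns a nested layer."""

    def __init__(self, data_augmentation):
        self.data_augmentation = data_augmentation

    def call(self, images):
        return self.data_augmentation(images)

    def get_config(self):
        # The nested layer is stored as a dict, not as an object.
        return {"data_augmentation": serialize(self.data_augmentation)}

    @classmethod
    def from_config(cls, config):
        # Rebuild the nested layer before calling the constructor.
        config = dict(config)
        config["data_augmentation"] = deserialize(config["data_augmentation"])
        return cls(**config)


saved = TinyModel(Rescale(0.5)).get_config()

# Passing the raw config to the constructor leaves a dict where a layer should be.
broken = TinyModel(**saved)
try:
    broken.call([2, 4])
except TypeError as err:
    load_error = str(err)

# Going through from_config() restores a callable layer.
restored = TinyModel.from_config(saved)
outputs = restored.call([2, 4])
```

If this is indeed the cause, overriding `from_config()` on `ShiftViTModel` so that it deserializes `data_augmentation` before constructing the model would be one way to address it.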

fkouteib commented 1 month ago

On PyTorch, I am hitting this issue when compiling the initial model before training starts.

  File "$HOME/keras-io_rw/examples/vision/shiftvit.py", line 937, in <module>
    model(sample_ds, training=False)
  File "$HOME/.torch_venv/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py", line 122, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "$HOME/.torch_venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "$HOME/.torch_venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "$HOME/keras-io_rw/examples/vision/shiftvit.py", line 723, in call
    x = stage(x, training=False)
  File "$HOME/.torch_venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "$HOME/.torch_venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "$HOME/keras-io_rw/examples/vision/shiftvit.py", line 569, in call
    x = shift_block(x, training=training)
  File "$HOME/.torch_venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "$HOME/.torch_venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "$HOME/keras-io_rw/examples/vision/shiftvit.py", line 429, in call
    x_splits[0] = self.get_shift_pad(x_splits[0], mode="left")
TypeError: Exception encountered when calling ShiftViTBlock.call().

'tuple' object does not support item assignment

Arguments received by ShiftViTBlock.call():
  • x=torch.Tensor(shape=torch.Size([256, 12, 12, 96]), dtype=float32)
  • training=False
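The last frame fails because `x_splits` is a tuple on this backend (`torch.split` returns a tuple, and tuples do not support item assignment). A minimal backend-free sketch of one possible fix is to copy the splits into a list before modifying them. `split_channels` and `shift_left` are hypothetical stand-ins for the split call and for `get_shift_pad(..., mode="left")`; they are not the code in shiftvit.py.

```python
def split_channels(channels, num_groups):
    """Hypothetical stand-in for a backend split op that returns a tuple."""
    size = len(channels) // num_groups
    return tuple(channels[i * size:(i + 1) * size] for i in range(num_groups))


def shift_left(group):
    """Hypothetical stand-in for a left shift: drop the first item, zero-pad the end."""
    return group[1:] + [0]


channels = [1, 2, 3, 4, 5, 6, 7, 8]

# Assigning into the tuple reproduces the error from the traceback.
x_splits = split_channels(channels, num_groups=4)
try:
    x_splits[0] = shift_left(x_splits[0])
except TypeError as err:
    tuple_error = str(err)

# Converting to a list first makes the in-place update legal on any backend.
x_splits = list(split_channels(channels, num_groups=4))
x_splits[0] = shift_left(x_splits[0])
```

Wrapping the split result in `list(...)` is harmless on backends that already return a list, so the same line could serve every backend.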

fkouteib commented 1 month ago

Thanks for the review and suggestion, Francois! I dropped the custom train and test steps. Overriding the call() method and relying on the native compute_loss() method turned out to be equivalent to the custom loss method.

Current issues I am debugging:

Encountered an unexpected tracer. A function transformed by JAX had a side effect, allowing for a reference to an intermediate value with type float32[] wrapped in a DynamicJaxprTracer to escape the scope of the transformation.
JAX transformations require that functions explicitly return their outputs, and disallow saving intermediate values to global state.
The function being traced when the value leaked was wrapped_fn at /home/faycel.kouteib/.tf_jax_venv/lib/python3.10/site-packages/keras/src/backend/jax/core.py:153 traced for make_jaxpr.
The leaked intermediate value was created on line /home/faycel.kouteib/keras-io_rw/examples/vision/shiftvit.py:535 ().
When the value was created, the final 5 stack frames (most recent last) excluding JAX-internal frames were:
  $HOME/.tf_jax_venv/lib/python3.10/site-packages/keras/src/layers/layer.py:771 (call)
  $HOME/.tf_jax_venv/lib/python3.10/site-packages/keras/src/layers/layer.py:1279 (_maybe_build)
  $HOME/.tf_jax_venv/lib/python3.10/site-packages/keras/src/layers/layer.py:223 (build_wrapper)
  $HOME/keras-io_rw/examples/vision/shiftvit.py:535 (build)
  $HOME/keras-io_rw/examples/vision/shiftvit.py:535 ()
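The message says a scalar float32 value created inside `build()` was stored somewhere that outlives the traced function. One way this can happen (an assumption about line 535, which is not shown in this thread) is computing per-block hyperparameters, such as stochastic depth rates, with backend ops inside `build()` and keeping the results on the layer. A minimal sketch of the alternative is to compute such values as plain Python floats so that no traced value is created; `stochastic_depth_rates` is a hypothetical helper, not a function from shiftvit.py.

```python
def stochastic_depth_rates(max_rate, num_blocks):
    """Linearly spaced drop rates from 0 to max_rate, as plain Python floats."""
    if num_blocks == 1:
        return [0.0]
    step = max_rate / (num_blocks - 1)
    return [i * step for i in range(num_blocks)]


# Plain floats are safe to store on a layer while JAX is tracing.
rates = stochastic_depth_rates(max_rate=0.2, num_blocks=3)
```

Because the helper never touches a backend array, the values can be kept as attributes without escaping a JAX transformation.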