djsaber / Keras-ViT

A Keras implementation of the ViT model, fine-tuned on the CIFAR-10 dataset from pretrained weights and evaluated on image classification accuracy.
MIT License

Running on Google Colab got Keras error #1

Open: michelpf opened this issue 3 months ago

michelpf commented 3 months ago

I tried the very simple demo shared on the documentation page, running on Google Colab:

```python
from keras_vit.vit import ViT_B16
vit_1 = ViT_B16(weights = "imagenet21k")
```

I got the following error.

```
AttributeError                            Traceback (most recent call last)
<ipython-input-39-e0ab52c82c08> in <cell line: 2>()
      1 from keras_vit.vit import ViT_B16
----> 2 vit_1 = ViT_B16(weights = "imagenet21k")

7 frames
/usr/local/lib/python3.10/dist-packages/keras_vit/layers.py in build(self, input_shape)
     71 
     72     def build(self, input_shape):
---> 73         self.dk = K.sqrt(K.cast(input_shape[-1]//self.heads, dtype=K.tf.float32))
     74         self.q_dense = layers.Dense(input_shape[-1], name="query")
     75         self.k_dense = layers.Dense(input_shape[-1], name="key")

AttributeError: Exception encountered when calling layer 'transformer_block_0' (type TransformerEncoder).

in user code:

    File "/usr/local/lib/python3.10/dist-packages/keras_vit/layers.py", line 146, in call  *
        x = self.multi_head_attens(x)
    File "/usr/local/lib/python3.10/dist-packages/keras/src/utils/traceback_utils.py", line 70, in error_handler  **
        raise e.with_traceback(filtered_tb) from None
    File "/usr/local/lib/python3.10/dist-packages/keras_vit/layers.py", line 73, in build
        self.dk = K.sqrt(K.cast(input_shape[-1]//self.heads, dtype=K.tf.float32))

    AttributeError: module 'keras.backend' has no attribute 'tf'

Call arguments received by layer 'transformer_block_0' (type TransformerEncoder):
  • inputs=tf.Tensor(shape=(None, 197, 768), dtype=float32)
```

I checked the Keras version, and it seems to be compatible:

```
keras.__version__
2.15.0
```

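
The failing line in the first traceback only needs the scalar sqrt(d_k), with d_k = `input_shape[-1] // self.heads`. A minimal sketch, assuming nothing else in `build` relies on `self.dk` being a tensor, computes that value with the standard library so `keras.backend` is never consulted. `attention_scale` is a hypothetical helper, not part of Keras-ViT; the 768 comes from the traceback's input shape `(None, 197, 768)`, while 12 heads is an assumption (the usual ViT-B setting).

```python
import math


def attention_scale(input_shape, heads):
    """Hypothetical helper: return sqrt(d_k) for multi-head attention.

    Mirrors K.sqrt(K.cast(input_shape[-1] // heads, dtype=K.tf.float32))
    from the traceback, but returns a plain Python float and never looks
    anything up on keras.backend.
    """
    d_model = input_shape[-1]
    if d_model % heads != 0:
        raise ValueError(f"embedding size {d_model} is not divisible by {heads} heads")
    # d_k is the per-head width; sqrt(d_k) is the attention scaling factor
    return math.sqrt(d_model // heads)


# Input shape taken from the traceback; 12 heads is an assumption (standard ViT-B).
print(attention_scale((None, 197, 768), heads=12))  # 8.0
```

In the layer's `build`, the corresponding one-line change would presumably be `self.dk = math.sqrt(input_shape[-1] // self.heads)`.
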
Cyrus-Hikari commented 2 months ago

I tried running the CIFAR-10 demo and got the following error:

```
Traceback (most recent call last):
  File "D:\Keras-ViT-1.0.1\fine_tuning_on_CIFAR10_demo.py", line 67, in <module>
    bestloss, = vit.evaluate(valid_data_gen, steps=VALIDATION_STEPS)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "G:\P\Lib\site-packages\keras\src\utils\traceback_utils.py", line 122, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "G:\P\Lib\site-packages\tensorflow\python\eager\execute.py", line 53, in quick_execute
    tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
              ^^^^^^^
TypeError: <tf.Tensor 'multi_head_attention_layer/Sqrt:0' shape=() dtype=float32> is out of scope and cannot be used here. Use return values, explicit Python locals or TensorFlow collections to access it. Please see https://www.tensorflow.org/guide/function#all_outputs_of_a_tffunction_must_be_return_values for more information.

<tf.Tensor 'multi_head_attention_layer/Sqrt:0' shape=() dtype=float32> was defined here:
  File "D:\Keras-ViT-1.0.1\fine_tuning_on_CIFAR10_demo.py", line 32, in <module>
  File "D:\Keras-ViT-1.0.1\ViT_Keras\vit.py", line 226, in ViT_B32
  File "D:\Keras-ViT-1.0.1\ViT_Keras\vit.py", line 76, in __init__
  File "G:\P\Lib\site-packages\keras\src\layers\layer.py", line 223, in build_wrapper
  File "D:\Keras-ViT-1.0.1\ViT_Keras\vit.py", line 94, in build
  File "D:\Keras-ViT-1.0.1\ViT_Keras\vit.py", line 101, in call
  File "G:\P\Lib\site-packages\keras\src\utils\traceback_utils.py", line 117, in error_handler
  File "G:\P\Lib\site-packages\keras\src\layers\layer.py", line 846, in __call__
  File "G:\P\Lib\site-packages\keras\src\utils\traceback_utils.py", line 117, in error_handler
  File "G:\P\Lib\site-packages\keras\src\ops\operation.py", line 48, in __call__
  File "G:\P\Lib\site-packages\keras\src\utils\traceback_utils.py", line 156, in error_handler
  File "G:\P\Lib\site-packages\keras\src\ops\operation.py", line 60, in symbolic_call
  File "G:\P\Lib\site-packages\keras\src\layers\layer.py", line 1011, in compute_output_spec
  File "G:\P\Lib\site-packages\keras\src\ops\operation.py", line 80, in compute_output_spec
  File "G:\P\Lib\site-packages\keras\src\backend\tensorflow\core.py", line 192, in compute_output_spec
  File "D:\Keras-ViT-1.0.1\ViT_Keras\layers.py", line 147, in call
  File "G:\P\Lib\site-packages\keras\src\utils\traceback_utils.py", line 117, in error_handler
  File "G:\P\Lib\site-packages\keras\src\layers\layer.py", line 771, in __call__
  File "G:\P\Lib\site-packages\keras\src\layers\layer.py", line 1279, in _maybe_build
  File "G:\P\Lib\site-packages\keras\src\layers\layer.py", line 223, in build_wrapper
  File "D:\Keras-ViT-1.0.1\ViT_Keras\layers.py", line 74, in build
  File "G:\P\Lib\site-packages\tensorflow\python\ops\weak_tensor_ops.py", line 88, in wrapper
  File "G:\P\Lib\site-packages\tensorflow\python\util\traceback_utils.py", line 150, in error_handler
  File "G:\P\Lib\site-packages\tensorflow\python\util\dispatch.py", line 1260, in op_dispatch_handler
  File "G:\P\Lib\site-packages\tensorflow\python\ops\math_ops.py", line 5455, in sqrt
  File "G:\P\Lib\site-packages\tensorflow\python\ops\gen_math_ops.py", line 11943, in sqrt
  File "G:\P\Lib\site-packages\tensorflow\python\framework\op_def_library.py", line 796, in _apply_op_helper
  File "G:\P\Lib\site-packages\tensorflow\python\framework\func_graph.py", line 670, in _create_op_internal
  File "G:\P\Lib\site-packages\tensorflow\python\framework\ops.py", line 2682, in _create_op_internal
  File "G:\P\Lib\site-packages\tensorflow\python\framework\ops.py", line 1177, in from_node_def

The tensor <tf.Tensor 'multi_head_attention_layer/Sqrt:0' shape=() dtype=float32> cannot be accessed from here, because it was defined in FuncGraph(name=scratch_graph_3, id=1796492036544), which is out of scope.

Process finished with exit code 1
```
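
Both tracebacks end at the statement in `build` that computes `self.dk` (layers.py line 73 in the first traceback, line 74 in the second). The first fails because this `keras.backend` has no `tf` attribute. In the second, the `keras\src\ops\operation.py` frames suggest Keras 3, where `build` ran inside a temporary `FuncGraph` (`scratch_graph_3`), so the `Sqrt` tensor stored on the layer cannot be used once `evaluate` runs outside that graph. One possible workaround, not the maintainers' fix, is to keep the scale as a plain Python float. `SelfAttentionSketch` is a hypothetical layer, not the repository's class; it assumes the TensorFlow backend and self-attention, with queries, keys and values projected from the same input.

```python
import math

import keras
import tensorflow as tf


class SelfAttentionSketch(keras.layers.Layer):
    """Hypothetical multi-head self-attention layer (not Keras-ViT's class).

    sqrt(d_k) is stored as a Python float, so no tensor created while
    build() runs in a temporary graph is kept on the layer.
    """

    def __init__(self, heads, **kwargs):
        super().__init__(**kwargs)
        self.heads = heads

    def build(self, input_shape):
        d_model = int(input_shape[-1])
        if d_model % self.heads != 0:
            raise ValueError("embedding size must be divisible by heads")
        self.depth = d_model // self.heads
        # Plain float instead of K.sqrt(K.cast(..., dtype=K.tf.float32))
        self.dk = math.sqrt(self.depth)
        self.q_dense = keras.layers.Dense(d_model, name="query")
        self.k_dense = keras.layers.Dense(d_model, name="key")
        self.v_dense = keras.layers.Dense(d_model, name="value")
        self.out_dense = keras.layers.Dense(d_model, name="out")
        super().build(input_shape)

    def _split_heads(self, x):
        # (batch, seq, d_model) -> (batch, heads, seq, depth)
        batch, seq = tf.shape(x)[0], tf.shape(x)[1]
        x = tf.reshape(x, (batch, seq, self.heads, self.depth))
        return tf.transpose(x, (0, 2, 1, 3))

    def call(self, inputs):
        q = self._split_heads(self.q_dense(inputs))
        k = self._split_heads(self.k_dense(inputs))
        v = self._split_heads(self.v_dense(inputs))
        # Scaled dot-product attention; dividing by a float is safe in any graph
        scores = tf.matmul(q, k, transpose_b=True) / self.dk
        weights = tf.nn.softmax(scores, axis=-1)
        context = tf.matmul(weights, v)  # (batch, heads, seq, depth)
        context = tf.transpose(context, (0, 2, 1, 3))
        batch, seq = tf.shape(inputs)[0], tf.shape(inputs)[1]
        context = tf.reshape(context, (batch, seq, self.heads * self.depth))
        return self.out_dense(context)


layer = SelfAttentionSketch(heads=2)
print(layer(tf.random.normal((3, 5, 8))).shape)
```

If the repository's layer follows the same pattern, only its `self.dk = ...` line would need to change; the rest of the sketch exists just to make the example runnable on its own.
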