keras-team / keras

Deep Learning for humans
http://keras.io/
Apache License 2.0

AttributeError: 'SymbolicTensor' object has no attribute '_datatype_enum' #19067

Open · innat opened this issue 9 months ago

innat commented 9 months ago

The code below works in Keras 2 but fails in Keras 3.

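```python
!pip install keras-nightly -q

import tensorflow as tf  # 2.15
import keras
from keras import layers
import numpy as np
keras.__version__  # 3.0.3.dev2024011803

# Preprocessing model: per-channel normalization with fixed ImageNet statistics.
# Note that the model is never built explicitly.
processing_model = keras.Sequential(
    [
        layers.Normalization(
            mean=[123.675, 116.28, 103.53],
            variance=[np.square(58.395), np.square(57.12), np.square(57.375)]
        )
    ]
)

# 4 elements, each a batch of 32 RGB images (224x224) with a length-3 label.
dataset = tf.data.Dataset.from_tensor_slices(
    (
        tf.random.uniform([4, 32, 224, 224, 3]),
        tf.random.uniform([4, 3]),
    )
)

# Map the same (unbuilt) model over the dataset several times.
# The first iteration prints 0; the second ds.map call raises the error below.
for i in range(3):
    ds = dataset
    ds = ds.map(lambda x, y: (processing_model(x), y))
    print(i)
```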
```
0
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-10-d75ead39f56a> in <cell line: 1>()
      1 for i in range(3):
      2     ds = dataset
----> 3     ds = ds.map(lambda x, y: (processing_model(x), y))
      4     print(i)

7 frames
/usr/local/lib/python3.10/dist-packages/tensorflow/python/framework/tensor.py in __getattr__(self, name)
    259         tf.experimental.numpy.experimental_enable_numpy_behavior()
    260       """)
--> 261     self.__getattribute__(name)
    262 
    263   @property

AttributeError: 'SymbolicTensor' object has no attribute '_datatype_enum'
```

Similar issue: https://github.com/tensorflow/tensorflow/issues/29931#issuecomment-529637694

grasskin commented 8 months ago

Thank you @innat. As a temporary workaround, you can call `processing_model.build((None, 224, 224, 3))` before the loop. Alternatively, the issue goes away when TensorFlow runs in eager mode, so it seems likely that the model build is not being triggered in graph mode from within the `map` call.
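A minimal sketch of the `build()` workaround applied to the reproduction, assuming the TensorFlow backend. The only changes from the report are the explicit `build` call before the loop and a smaller dataset (2 elements of 8 images instead of 4 of 32, chosen here only to keep memory use low). Each dataset element has shape `(8, 224, 224, 3)`, which matches the `(None, 224, 224, 3)` input shape passed to `build`.

```python
import numpy as np
import tensorflow as tf
import keras
from keras import layers

# Same preprocessing model as in the report.
processing_model = keras.Sequential(
    [
        layers.Normalization(
            mean=[123.675, 116.28, 103.53],
            variance=[np.square(58.395), np.square(57.12), np.square(57.375)],
        )
    ]
)

# Workaround: build the model eagerly, before tf.data traces the map function.
# None = any batch size; 224x224 RGB images.
processing_model.build((None, 224, 224, 3))

# Smaller than the report's dataset (assumption, only to save memory).
dataset = tf.data.Dataset.from_tensor_slices(
    (
        tf.random.uniform([2, 8, 224, 224, 3]),
        tf.random.uniform([2, 3]),
    )
)

for i in range(3):
    ds = dataset.map(lambda x, y: (processing_model(x), y))
    print(i)

# Pull one element to confirm the mapped pipeline actually runs.
images, labels = next(iter(ds))
print(images.shape)  # (8, 224, 224, 3)
```

Another way to end up with an already-built model is to start the `Sequential` list with `keras.Input(shape=(224, 224, 3))`, which builds the model at construction time; it should avoid the error for the same reason as the explicit `build` call, though that variant is not tested in this thread.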