amd / ZenDNN-tensorflow-plugin


keras_nlp Gemma error #1

Open pwipo opened 2 months ago

pwipo commented 2 months ago

Hello, I get an error when running a simple Gemma test with keras_nlp: `Input dims must be <= 4 and >=1`.

Environment: TensorFlow 2.17.0 (also tried 2.16.0), Python 3.11.

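Script:

```python
import os

# Environment variables must be set before keras (and TensorFlow) is imported.
# Kaggle credentials are left blank here; they are needed to download the preset.
os.environ["KAGGLE_USERNAME"] = ""
os.environ["KAGGLE_KEY"] = ""
os.environ["KERAS_BACKEND"] = "tensorflow"
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.9"
# Enable ZenDNN optimizations and disable oneDNN ones.
os.environ["TF_ENABLE_ZENDNN_OPTS"] = "1"
os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"

import keras  # 3.5.0
import keras_nlp  # 0.14.4

gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma2_2b_en")
print(gemma_lm.generate("what is keras in 3 bullet points?", max_length=64))
```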
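To check whether the failure depends on the rank of a softmax input rather than on Gemma itself, a minimal repro along these lines may help. It assumes, without confirmation in this thread, that the failing op is a softmax rewritten by the ZenDNN plugin, and that Gemma 2's grouped-query attention may feed it 5-D logits. The loop runs `tf.nn.softmax` inside a `tf.function` for ranks 1 to 5 and records which ranks fail. Running it once with `TF_ENABLE_ZENDNN_OPTS=1` and once with `0` should show whether the plugin is involved.

```python
import os

# Must be set before TensorFlow is imported; values mirror the reporter's script.
os.environ.setdefault("TF_ENABLE_ZENDNN_OPTS", "1")
os.environ.setdefault("TF_ENABLE_ONEDNN_OPTS", "0")

import numpy as np
import tensorflow as tf


@tf.function  # graph mode, so graph-level rewrites (if any) get a chance to apply
def softmax_last_axis(x):
    return tf.nn.softmax(x, axis=-1)


results = {}
for rank in range(1, 6):
    # Small tensor whose last axis has 8 entries, e.g. shape [2, 2, 2, 2, 8] for rank 5.
    x = tf.random.normal([2] * (rank - 1) + [8])
    try:
        y = softmax_last_axis(x)
        sums_ok = np.allclose(tf.reduce_sum(y, axis=-1).numpy(), 1.0, atol=1e-5)
        results[rank] = "ok" if sums_ok else "wrong sums"
    except tf.errors.OpError as err:
        # Record the error type and message instead of stopping the sweep.
        results[rank] = f"{type(err).__name__}: {err.message}"
    print(f"rank {rank}: {results[rank]}")
```

If only rank 5 fails with the plugin enabled, that would point at the same rank limit quoted in the error message.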

sanadani commented 1 month ago

Hi @pwipo, thanks for reporting this issue.

We are looking into it, and it will be fixed in an upcoming release.

Until we fix this and update the repo, please use the following workaround:

Disable the Softmax rewrite as shown below.

Set the following environment variable:

Build zentf from source as described in the README.
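Separately from the maintainers' workaround, the error text suggests the rank limit is a kernel restriction rather than a mathematical one: a softmax over the last axis treats each leading index independently, so a tensor of rank 5 or more can be reshaped to rank 4, normalized, and reshaped back without changing the result. The NumPy sketch below illustrates that equivalence (merging leading axes before a last-axis softmax). It is not a fix provided by the ZenDNN maintainers; `softmax_rank_limited` and the 5-D logits shape are hypothetical.

```python
import numpy as np


def softmax_last_axis(x: np.ndarray) -> np.ndarray:
    """Numerically stable softmax over the last axis."""
    shifted = x - x.max(axis=-1, keepdims=True)
    e = np.exp(shifted)
    return e / e.sum(axis=-1, keepdims=True)


def softmax_rank_limited(x: np.ndarray, max_rank: int = 4) -> np.ndarray:
    """Hypothetical helper: last-axis softmax computed on a view of rank <= max_rank.

    Leading axes are merged into one, which is safe because a last-axis softmax
    normalizes every leading index independently.
    """
    if max_rank < 2:
        raise ValueError("max_rank must be at least 2 to keep the softmax axis separate")
    if x.ndim <= max_rank:
        return softmax_last_axis(x)
    # Keep the trailing (max_rank - 1) axes and merge everything before them.
    trailing = x.shape[x.ndim - (max_rank - 1):]
    merged = x.reshape((-1,) + trailing)
    return softmax_last_axis(merged).reshape(x.shape)


rng = np.random.default_rng(0)
# Hypothetical 5-D attention logits, e.g. (batch, kv_heads, groups, q_len, kv_len).
logits = rng.standard_normal((2, 4, 2, 8, 8))
out = softmax_rank_limited(logits)
```

Because `reshape` uses C order, the trailing axes are left untouched, so each softmax row in the merged view is exactly one row of the original tensor.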