amd / ZenDNN-tensorflow-plugin


keras_nlp Gemma error #1

Closed: pwipo closed this issue 2 weeks ago

pwipo commented 3 months ago

Hello, I get an error when running a simple Gemma test with keras_nlp: `Input dims must be <= 4 and >=1`. Environment: TensorFlow 2.17.0 (also tried 2.16.0), Python 3.11.

Script:

```python
import os

os.environ["KAGGLE_USERNAME"] = ""
os.environ["KAGGLE_KEY"] = ""
os.environ["KERAS_BACKEND"] = "tensorflow"
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.9"
os.environ["TF_ENABLE_ZENDNN_OPTS"] = "1"
os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"

import keras  # 3.5.0
import keras_nlp  # 0.14.4

gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma2_2b_en")
gemma_lm.generate("what is keras in 3 bullet points?", max_length=64)
```

sanadani commented 2 months ago

Hi @pwipo, thanks for reporting this issue.

We are looking into it, and it will be fixed in an upcoming release.

Until we fix and update this repo, please use the workaround below:

Disable the Softmax rewrite.

Set the environment variable below:

Build zentf from source as described in the README.
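The workaround relies on an environment variable, and flags like `TF_ENABLE_ZENDNN_OPTS` are, as far as we can tell, read when TensorFlow and its plugins load, so they likely need to be set before `import keras` or `import keras_nlp` (which pull in TensorFlow when `KERAS_BACKEND` is `tensorflow`). Below is a minimal sketch of a hypothetical helper, `set_env_before_import`, that refuses to set variables once TensorFlow is already imported. The helper and its name are assumptions for illustration, not part of the plugin; the usage only sets variables from the original script, since the Softmax-rewrite variable name is not preserved here.

```python
import os
import sys


def set_env_before_import(module: str = "tensorflow", **env: str) -> None:
    """Hypothetical helper: set env vars only if `module` is not yet imported.

    Assumption: TensorFlow/plugin flags are read at load time, so setting
    them after the import would probably have no effect.
    """
    if module in sys.modules:
        raise RuntimeError(
            f"{module} is already imported; set {sorted(env)} before importing it"
        )
    for name, value in env.items():
        os.environ[name] = value


# Usage: call before `import keras` / `import keras_nlp`.
set_env_before_import(
    KERAS_BACKEND="tensorflow",
    TF_ENABLE_ZENDNN_OPTS="1",
    TF_ENABLE_ONEDNN_OPTS="0",
)
```

Failing loudly is a deliberate choice here: a flag silently ignored because of import order is harder to diagnose than an explicit error.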

sanadani commented 2 weeks ago

@pwipo Thank you for your patience. We have fixed the issue in 81bbf17. Please give it a try with v5.0.