keras-team / keras

Deep Learning for humans
http://keras.io/
Apache License 2.0

Custom activation functions cause TensorFlow to crash #20333

Open AtticusBeachy opened 1 month ago

AtticusBeachy commented 1 month ago

I originally posted this issue in the TensorFlow GitHub, and was told it looks like a Keras issue and I should post it here.

TensorFlow version: 2.17.0

OS: Linux Mint 22

Python version: 3.12.7

Issue: I can successfully define a custom activation function, but when I try to use it TensorFlow crashes.

Minimal reproducible example:

import tensorflow as tf
from tensorflow.keras.utils import get_custom_objects
from tensorflow.keras.layers import Activation

def fourier_activation_lambda(freq):
    fn = lambda x : tf.sin(freq*x)
    return(fn)

freq = 1.0
fourier = fourier_activation_lambda(freq)

get_custom_objects()["fourier"] = Activation(fourier)

print(3*"\n")
print(f"After addition: {get_custom_objects()=}")

x_input = tf.keras.Input(shape=[5])
activation = "fourier"
layer_2 = tf.keras.layers.Dense(100, input_shape = [5],
                                activation=activation,
                                )(x_input)

model = tf.keras.Model(inputs=x_input, outputs=layer_2)
model.compile(optimizer='adam', loss='mse')
model.summary()

The output of the print statement above indicates that the custom activation function was added successfully. Maybe the crash is related to "built=False"?

# output of print statement
get_custom_objects()={'fourier': <Activation name=activation, built=False>}

The error message reads:

# error message
Traceback (most recent call last):
  File "/home/orca/Downloads/minimal_tf_err.py", line 20, in <module>
    layer_2 = tf.keras.layers.Dense(100, input_shape = [5],
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/orca/.local/lib/python3.12/site-packages/keras/src/layers/core/dense.py", line 89, in __init__
    self.activation = activations.get(activation)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/orca/.local/lib/python3.12/site-packages/keras/src/activations/__init__.py", line 104, in get
    raise ValueError(
ValueError: Could not interpret activation function identifier: fourier

fchollet commented 1 month ago

In the example above, you are passing the string "fourier" as activation. A string is not a tensor-in tensor-out callable, so it doesn't work as an activation.

Your code, simplified:

activation = "fourier"
layer_2 = tf.keras.layers.Dense(100, input_shape=[5], activation=activation)(x_input)
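
The difference between a string identifier and a callable can be checked directly with tf.keras.activations.get, which is the function named in the traceback. This is a minimal sketch, assuming the TensorFlow 2.17 environment from the report. The identifier "not_a_builtin_activation" is made up for illustration.

```python
import tensorflow as tf

# A built-in name resolves to a callable.
relu = tf.keras.activations.get("relu")
print(callable(relu))

# A callable is handed back unchanged, so it can be used as an activation.
fourier = lambda x: tf.sin(1.0 * x)
print(tf.keras.activations.get(fourier) is fourier)

# A string that is not a built-in name raises ValueError.
# "not_a_builtin_activation" is a hypothetical name used only for this check.
try:
    tf.keras.activations.get("not_a_builtin_activation")
    unknown_resolved = True
except ValueError as err:
    unknown_resolved = False
    print(f"ValueError: {err}")
```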

AtticusBeachy commented 1 month ago

I added "fourier" as a key in the global dictionary of custom objects, as described here:

# add fourier function to global dictionary of custom objects
get_custom_objects()["fourier"] = Activation(fourier)

Further, the code runs successfully on my old computer (using TensorFlow 2.10.0). Instead of crashing it runs to the end and outputs the model summary:

Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 input_1 (InputLayer)        [(None, 5)]               0         

 dense (Dense)               (None, 100)               600       

=================================================================
Total params: 600
Trainable params: 600
Non-trainable params: 0

mehtamansi29 commented 1 month ago

Hi @AtticusBeachy -

Instead of registering "fourier" through get_custom_objects(), you can pass the callable returned by your "fourier_activation_lambda" function directly as the layer's activation argument.

import tensorflow as tf

def fourier_activation_lambda(freq):
    # Return a closure that applies sin(freq * x) element-wise.
    fn = lambda x: tf.sin(freq * x)
    return fn

freq = 1.0
fourier = fourier_activation_lambda(freq)

print(fourier)  # <function fourier_activation_lambda.<locals>.<lambda> at 0x7ace59acec20>

x_input = tf.keras.Input(shape=[5])
# Pass the callable itself, not a string name, as the activation.
layer_2 = tf.keras.layers.Dense(100, activation=fourier)(x_input)
model = tf.keras.Model(inputs=x_input, outputs=layer_2)
model.compile(optimizer='adam', loss='mse')
model.summary()

A gist is attached for reference as well.

AtticusBeachy commented 1 month ago

Thanks for the help! That solution works well for the minimal example, but is not as clean for my actual code. I have it working now though, even if it's not pretty.

Do you know whether there are any plans to solve the underlying bug?
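
For code that selects activations by name, one way to keep string identifiers without touching get_custom_objects() is a small project-local lookup that converts names to callables before they reach the layer. This is only a sketch under assumptions: MY_ACTIVATIONS, make_fourier, and resolve_activation are hypothetical helpers, not part of Keras, and anything not in the table is deferred to tf.keras.activations.get.

```python
import tensorflow as tf

def make_fourier(freq):
    """Return a sine activation with a fixed frequency (hypothetical helper)."""
    def fourier(x):
        return tf.sin(freq * x)
    return fourier

# Hypothetical project-local table of custom activations, keyed by name.
MY_ACTIVATIONS = {"fourier": make_fourier(1.0)}

def resolve_activation(identifier):
    """Map a project-local name to its callable; defer everything else to Keras."""
    if isinstance(identifier, str) and identifier in MY_ACTIVATIONS:
        return MY_ACTIVATIONS[identifier]
    return tf.keras.activations.get(identifier)

x_input = tf.keras.Input(shape=(5,))
outputs = tf.keras.layers.Dense(
    100, activation=resolve_activation("fourier")
)(x_input)
model = tf.keras.Model(inputs=x_input, outputs=outputs)
model.summary()
```

The layer still receives a callable, so this only restores the convenience of choosing activations by name. It does not by itself make the custom function known to Keras when a saved model is loaded.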

sampathweb commented 3 weeks ago

@AtticusBeachy - We should avoid adding keys to the get_custom_objects() dict directly. It's better to use the register_keras_serializable API, as shown below. I understand that it worked in the past, but that's not the intended behavior. With the serializable API, you can also save and load these activation functions with the model. Hope this clarifies.

import tensorflow as tf
from tensorflow.keras.utils import get_custom_objects

@tf.keras.utils.register_keras_serializable(name="fourier")
def fourier_activation_lambda(freq):
    fn = lambda x : tf.sin(freq*x)
    return(fn)

freq = 1.0
fourier = fourier_activation_lambda(freq)

print(3*"\n")
print(f"After addition: {get_custom_objects()=}")

x_input = tf.keras.Input(shape=[5])
activation = "fourier"
layer_2 = tf.keras.layers.Dense(100, input_shape = [5],
                                activation=fourier,
                                )(x_input)

model = tf.keras.Model(inputs=x_input, outputs=layer_2)
model.compile(optimizer='adam', loss='mse')
model.summary()

# Output
After addition: get_custom_objects()={'fourier': <Activation name=activation, built=False>, 'Custom>forier': <function fourier_activation_lambda at 0x78b543235b40>, 'Custom>fourier': <function fourier_activation_lambda at 0x78b5430b20e0>}
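
Note that the snippet above decorates the factory function and still passes the callable fourier to Dense. The string variable activation is unused, and the registered key in the output is "Custom>fourier" rather than "fourier". When the frequency needs to be saved with the model, one alternative is to register a small Layer subclass that stores freq in its config. The following is a sketch under assumptions, not the approach prescribed in this thread: the class name Fourier and its freq argument are hypothetical.

```python
import tensorflow as tf

@tf.keras.utils.register_keras_serializable(package="Custom", name="Fourier")
class Fourier(tf.keras.layers.Layer):
    """Applies sin(freq * x) element-wise (hypothetical example layer)."""

    def __init__(self, freq=1.0, **kwargs):
        super().__init__(**kwargs)
        self.freq = freq

    def call(self, inputs):
        return tf.sin(self.freq * inputs)

    def get_config(self):
        # Include freq so the layer can be rebuilt from its config.
        config = super().get_config()
        config.update({"freq": self.freq})
        return config

x_input = tf.keras.Input(shape=(5,))
hidden = tf.keras.layers.Dense(100)(x_input)  # linear Dense, activation applied next
fourier_layer = Fourier(freq=1.0)
outputs = fourier_layer(hidden)
model = tf.keras.Model(inputs=x_input, outputs=outputs)

# The class is registered under "<package>><name>".
print(tf.keras.saving.get_registered_name(Fourier))

# Rebuild the layer from its config to confirm freq survives the round trip.
restored = Fourier.from_config(fourier_layer.get_config())
print(restored.freq)
```

Applying the activation as a separate layer keeps the parameter count of the Dense layer unchanged while making the frequency part of the saved configuration.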

github-actions[bot] commented 1 week ago

This issue is stale because it has been open for 14 days with no activity. It will be closed if no further activity occurs. Thank you.