keras-team / keras

Deep Learning for humans
http://keras.io/
Apache License 2.0

Use of torch loss function with Keras 3.0 #18686

Closed emi-dm closed 10 months ago

emi-dm commented 11 months ago

Is it possible to use a torch loss function (passing it to the compile method) with the new Keras?

SuryanarayanaY commented 11 months ago

Hi @emi-research-dl,

You can refer to this tutorial on writing a custom training loop for a torch model.

However, if you are asking about passing a torch loss function directly to model.compile(), then AFAIK that may not be possible. For optimizers, Keras picks the implementation based on the backend, but I doubt it does the same for losses: Keras 3 has its own loss implementations, and I can see the epsilon value of the respective backend being used in the calculations.

I may be wrong, so I would like confirmation from the Keras dev team.

emi-dm commented 11 months ago

Hi, I was referring to using torch loss functions or torch metrics with the model. When using only the CPU I imagine it should not be complicated, but can they be used with CUDA? I think loss functions can be passed, but when using the compute_loss method you have to take into account that torch loss functions take the logits first (the opposite of the TF/Keras convention, where y_true comes first).

My question comes from a failure when trying to use a loss function (CrossEntropy) on a GPU device. I think Keras may have a bug when moving the tensors to the GPU (tensor.to(device)).
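The two concerns above (torch losses take the logits first, and both tensors must live on the same device) can be handled by a small adapter that exposes the Keras (y_true, y_pred) signature. A minimal sketch, assuming a plain torch loss such as torch.nn.CrossEntropyLoss with one-hot float targets (class-probability targets need PyTorch 1.10 or later); the helper name keras_signature is hypothetical, not a Keras or PyTorch API:

```python
import torch


def keras_signature(torch_loss):
    """Adapt a torch loss with signature (input, target) to Keras' (y_true, y_pred).

    Hypothetical helper for illustration only.
    """
    def loss_fn(y_true, y_pred):
        # Move the targets onto whatever device the predictions are on.
        y_true = y_true.to(y_pred.device)
        # torch losses take predictions (logits) first, targets second.
        return torch_loss(y_pred, y_true)
    return loss_fn


ce = keras_signature(torch.nn.CrossEntropyLoss())
logits = torch.randn(4, 10)
targets = torch.nn.functional.one_hot(torch.tensor([2, 5, 0, 7]), 10).float()
value = ce(targets, logits)  # called in Keras order: (y_true, y_pred)
```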

Could any dev clarify this for me and provide an example?

Thanks in advance!!

innat commented 11 months ago

@emi-research-dl This may help. Specifically, see cell no. 15.

emi-dm commented 11 months ago

Yes, but that simply loads a torch model and trains it. I meant passing a loss function directly and using it with the Keras model itself, not with a preloaded torch model. Do you know of any examples, or whether it can be done? I imagine it can, by calling it inside train_step in a class that inherits from keras.Model, but I think it doesn't work with CUDA because of how the tensors are managed (I imagine the torch function needs them on the GPU or the CPU, while Keras keeps them somewhere else entirely). Attached example:

import os

os.environ["KERAS_BACKEND"] = "torch"

import keras
from keras import layers
import torch
import tensorflow as tf

print(
    f"Keras version: {keras.__version__}, Keras backend: {keras.backend.backend()}, PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

class CustomModel(keras.Model):

    def train_step(self, data):
        # Unpack the data. Its structure depends on your model and
        # on what you pass to `fit()`.
        x, y = data

        # Call torch.nn.Module.zero_grad() to clear the leftover gradients
        # for the weights from the previous train step.
        self.zero_grad()

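        # Compute loss. y and y_pred are swapped on purpose so that the torch
        # loss receives the logits first, which is the order torch losses expect.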
        logits = self(x, training=True)
        loss = self.compute_loss(y=logits, y_pred=y)
        y_pred = keras.activations.softmax(logits, axis=-1)

        # Call torch.Tensor.backward() on the loss to compute gradients
        # for the weights.
        loss.backward()

        trainable_weights = [v for v in self.trainable_weights]
        gradients = [v.value.grad for v in trainable_weights]

        # Update weights
        with torch.no_grad():
            self.optimizer.apply(gradients, trainable_weights)

        # Update metrics (includes the metric that tracks the loss)
        for metric in self.metrics:
            if metric.name == "loss":
                metric.update_state(loss)
            else:
                metric.update_state(y, y_pred)

        # Return a dict mapping metric names to current value
        # Note that it will include the loss (tracked in self.metrics).
        return {m.name: m.result() for m in self.metrics}

    def test_step(self, data):
        # Unpack the data
        x, y = data
        # Compute predictions
        logits = self(x, training=False)
        y_pred = keras.activations.softmax(logits, axis=-1)
        # Compute the loss (same swapped argument order as in train_step)
        loss = self.compute_loss(y=logits, y_pred=y)
        # Update the metrics.
        for metric in self.metrics:
            if metric.name == "loss":
                metric.update_state(loss)
            else:
                metric.update_state(y, y_pred)
        # Return a dict mapping metric names to current value.
        # Note that it will include the loss (tracked in self.metrics).
        return {m.name: m.result() for m in self.metrics}

(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

AUTOTUNE = tf.data.AUTOTUNE
train_ds = torch.utils.data.TensorDataset(torch.from_numpy(x_train), torch.from_numpy(keras.utils.to_categorical(y_train,10)))
test_ds = torch.utils.data.TensorDataset(torch.from_numpy(x_test), torch.from_numpy(keras.utils.to_categorical(y_test,10)))
# train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train)).map(lambda x, y: (x, keras.utils.to_categorical(y, 10))).batch(32).prefetch(buffer_size=AUTOTUNE)
# test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test)).map(lambda x, y: (x, keras.utils.to_categorical(y, 10))).batch(32).prefetch(buffer_size=AUTOTUNE)

inputs = keras.Input(shape=(28, 28, 1))
x = layers.Conv2D(32, 3, activation="relu")(inputs)
x = layers.Flatten()(x)
x = layers.Dense(128, activation="relu")(x)
outputs = layers.Dense(10)(x)
model = CustomModel(inputs=inputs, outputs=outputs)

model.compile(loss=torch.nn.CrossEntropyLoss(),
              optimizer=keras.optimizers.Adam(),
              metrics=["accuracy"])

model.fit(train_ds, epochs=5, validation_data=test_ds, verbose=1)
model.save("model.keras")

If you pass a torch.utils.data.TensorDataset to fit(), it throws this error:


2023-10-26 08:41:56.462750: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
Keras version: 3.0.0, Keras backend: torch, PyTorch version: 2.1.0+cu118
CUDA available: True
Traceback (most recent call last):
  File "/home/user/PycharmProjects/Segmentation/main.py", line 81, in <module>
    x = layers.Conv2D(32, 3, activation="relu")(inputs)
  File "/home/user/miniconda3/envs/keras-torch/lib/python3.9/site-packages/keras/src/utils/traceback_utils.py", line 123, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "/home/user/miniconda3/envs/keras-torch/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py", line 328, in _fn
    return fn(*args, **kwargs)
RuntimeError: CUDA error: out of memory
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
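The traceback above fails while Conv2D is being applied to the Input, before any training data is touched, so it says little about the TensorDataset itself. Separately, the Keras 3 fit() documentation lists torch.utils.data.DataLoader (not a bare Dataset) among its accepted inputs. A minimal sketch of building such a loader from MNIST-shaped NumPy arrays, assuming float32 pixels scaled to [0, 1] and the trailing channel axis that keras.Input(shape=(28, 28, 1)) expects; make_loader is a hypothetical helper:

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset


def make_loader(x, y, num_classes=10, batch_size=32, shuffle=False):
    """Wrap NumPy image/label arrays in a DataLoader (hypothetical helper)."""
    # uint8 pixels -> float32 in [0, 1], plus a trailing channel axis: (N, 28, 28, 1)
    images = torch.from_numpy(x.astype("float32") / 255.0).unsqueeze(-1)
    # Integer labels -> one-hot float32 vectors, computed eagerly up front
    labels = torch.nn.functional.one_hot(torch.from_numpy(y).long(), num_classes).float()
    return DataLoader(TensorDataset(images, labels), batch_size=batch_size, shuffle=shuffle)


# Random stand-ins for the arrays returned by keras.datasets.mnist.load_data()
rng = np.random.default_rng(0)
x_demo = rng.integers(0, 256, size=(100, 28, 28), dtype=np.uint8)
y_demo = rng.integers(0, 10, size=(100,))
loader = make_loader(x_demo, y_demo)
```

With the real data this would be make_loader(x_train, y_train, shuffle=True) passed to model.fit().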

If you use tf.data.Dataset.from_tensor_slices instead, it raises this one:

2023-10-26 08:56:50.869581: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
Keras version: 3.0.0, Keras backend: torch, PyTorch version: 2.1.0+cu118
CUDA available: True
Traceback (most recent call last):
  File "/home/user/PycharmProjects/CV19-Segmentation/main.py", line 76, in <module>
    train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train)).map(lambda x, y: (x, keras.utils.to_categorical(y, 10))).batch(32).prefetch(buffer_size=AUTOTUNE)
  File "/home/user/miniconda3/envs/keras-torch/lib/python3.9/site-packages/tensorflow/python/data/ops/dataset_ops.py", line 2280, in map
    return map_op._map_v2(
  File "/home/user/miniconda3/envs/keras-torch/lib/python3.9/site-packages/tensorflow/python/data/ops/map_op.py", line 37, in _map_v2
    return _MapDataset(
  File "/home/user/miniconda3/envs/keras-torch/lib/python3.9/site-packages/tensorflow/python/data/ops/map_op.py", line 107, in __init__
    self._map_func = structured_function.StructuredFunctionWrapper(
  File "/home/user/miniconda3/envs/keras-torch/lib/python3.9/site-packages/tensorflow/python/data/ops/structured_function.py", line 265, in __init__
    self._function = fn_factory()
  File "/home/user/miniconda3/envs/keras-torch/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py", line 1227, in get_concrete_function
    concrete = self._get_concrete_function_garbage_collected(*args, **kwargs)
  File "/home/user/miniconda3/envs/keras-torch/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py", line 1197, in _get_concrete_function_garbage_collected
    self._initialize(args, kwargs, add_initializers_to=initializers)
  File "/home/user/miniconda3/envs/keras-torch/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py", line 695, in _initialize
    self._concrete_variable_creation_fn = tracing_compilation.trace_function(
  File "/home/user/miniconda3/envs/keras-torch/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compilation.py", line 178, in trace_function
    concrete_function = _maybe_define_function(
  File "/home/user/miniconda3/envs/keras-torch/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compilation.py", line 283, in _maybe_define_function
    concrete_function = _create_concrete_function(
  File "/home/user/miniconda3/envs/keras-torch/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compilation.py", line 310, in _create_concrete_function
    traced_func_graph = func_graph_module.func_graph_from_py_func(
  File "/home/user/miniconda3/envs/keras-torch/lib/python3.9/site-packages/tensorflow/python/framework/func_graph.py", line 1059, in func_graph_from_py_func
    func_outputs = python_func(*func_args, **func_kwargs)
  File "/home/user/miniconda3/envs/keras-torch/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py", line 598, in wrapped_fn
    out = weak_wrapped_fn().__wrapped__(*args, **kwds)
  File "/home/user/miniconda3/envs/keras-torch/lib/python3.9/site-packages/tensorflow/python/data/ops/structured_function.py", line 231, in wrapped_fn
    ret = wrapper_helper(*args)
  File "/home/user/miniconda3/envs/keras-torch/lib/python3.9/site-packages/tensorflow/python/data/ops/structured_function.py", line 161, in wrapper_helper
    ret = autograph.tf_convert(self._func, ag_ctx)(*nested_args)
  File "/home/user/miniconda3/envs/keras-torch/lib/python3.9/site-packages/tensorflow/python/autograph/impl/api.py", line 693, in wrapper
    raise e.ag_error_metadata.to_exception(e)
  File "/home/user/miniconda3/envs/keras-torch/lib/python3.9/site-packages/tensorflow/python/autograph/impl/api.py", line 690, in wrapper
    return converted_call(f, args, kwargs, options=options)
  File "/home/user/miniconda3/envs/keras-torch/lib/python3.9/site-packages/tensorflow/python/autograph/impl/api.py", line 439, in converted_call
    result = converted_f(*effective_args, **kwargs)
  File "/tmp/__autograph_generated_filestywdmce.py", line 5, in <lambda>
    tf__lam = lambda x, y: ag__.with_function_scope(lambda lscope: (x, ag__.converted_call(keras.utils.to_categorical, (y, 10), None, lscope)), 'lscope', ag__.STD)
  File "/home/user/miniconda3/envs/keras-torch/lib/python3.9/site-packages/tensorflow/python/autograph/core/function_wrappers.py", line 113, in with_function_scope
    return thunk(scope)
  File "/tmp/__autograph_generated_filestywdmce.py", line 5, in <lambda>
    tf__lam = lambda x, y: ag__.with_function_scope(lambda lscope: (x, ag__.converted_call(keras.utils.to_categorical, (y, 10), None, lscope)), 'lscope', ag__.STD)
  File "/home/user/miniconda3/envs/keras-torch/lib/python3.9/site-packages/tensorflow/python/autograph/impl/api.py", line 377, in converted_call
    return _call_unconverted(f, args, kwargs, options)
  File "/home/user/miniconda3/envs/keras-torch/lib/python3.9/site-packages/tensorflow/python/autograph/impl/api.py", line 460, in _call_unconverted
    return f(*args)
  File "/home/user/miniconda3/envs/keras-torch/lib/python3.9/site-packages/keras/src/utils/numerical_utils.py", line 92, in to_categorical
    x = np.array(x, dtype="int64")
  File "/home/user/miniconda3/envs/keras-torch/lib/python3.9/site-packages/tensorflow/python/framework/tensor.py", line 628, in __array__
    raise NotImplementedError(
NotImplementedError: in user code:

    File "/home/user/PycharmProjects/Segmentation/main.py", line 76, in None  *
        lambda x, y: (x, keras.utils.to_categorical(y, 10))
    File "/home/user/miniconda3/envs/keras-torch/lib/python3.9/site-packages/keras/src/utils/numerical_utils.py", line 92, in to_categorical  **
        x = np.array(x, dtype="int64")

    NotImplementedError: Cannot convert a symbolic tf.Tensor (args_1:0) to a numpy array. This error may indicate that you're trying to pass a Tensor to a NumPy call, which is not supported.

Without these methods, training works, but the loss function can't be serialized when saving the model:

TypeError: Cannot serialize object CrossEntropyLoss() of type <class 'torch.nn.modules.loss.CrossEntropyLoss'>. To be serializable, a class must implement the `get_config()` method.

SuryanarayanaY commented 11 months ago

Hi @emi-research-dl,

Alternatively, you can create a class that subclasses torch.nn.Module, define a Keras model inside it, and override the forward method to call the Keras model on the given input. Please check whether this example serves your purpose.

emi-dm commented 11 months ago

It's great, thanks @SuryanarayanaY. But I would like Keras to log the metrics automatically, without my having to define that logic.

fchollet commented 11 months ago

> RuntimeError: CUDA error: out of memory

That doesn't quite sound related to the use of the torch function as the loss argument. What is really causing the issue?

> train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train)).map(lambda x, y: (x, keras.utils.to_categorical(y, 10))).batch(32).prefetch(buffer_size=AUTOTUNE)

What is failing is this line, which is expected: you are using the Torch backend (so keras.utils.to_categorical will assume torch tensors) and mapping Keras ops in a TF data pipeline (where tensors are symbolic TF tensors).
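Following that explanation, a sketch of the same pipeline that sticks to TF ops inside map(): tf.one_hot replaces keras.utils.to_categorical, which (per the traceback) calls np.array() on its argument and so cannot handle symbolic tensors. The helper name make_dataset and the [0, 1] scaling with an added channel axis are assumptions, not part of the original code:

```python
import numpy as np
import tensorflow as tf


def make_dataset(x, y, num_classes=10, batch_size=32):
    """tf.data pipeline whose map() uses only TF ops (hypothetical helper)."""
    def preprocess(image, label):
        # Inside map() these are symbolic tf.Tensors, so only TF ops are used here.
        image = tf.cast(image, tf.float32)[..., tf.newaxis] / 255.0
        return image, tf.one_hot(label, num_classes)

    ds = tf.data.Dataset.from_tensor_slices((x, y))
    ds = ds.map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
    return ds.batch(batch_size).prefetch(tf.data.AUTOTUNE)


# Random stand-ins for the MNIST arrays
rng = np.random.default_rng(0)
x_demo = rng.integers(0, 256, size=(100, 28, 28), dtype=np.uint8)
y_demo = rng.integers(0, 10, size=(100,))
ds = make_dataset(x_demo, y_demo)
```

Alternatively, keras.utils.to_categorical(y_train, 10) can be applied eagerly to the NumPy labels before from_tensor_slices, as the TensorDataset variant in the question already does.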

github-actions[bot] commented 10 months ago

This issue is stale because it has been open for 14 days with no activity. It will be closed if no further activity occurs. Thank you.

github-actions[bot] commented 10 months ago

This issue was closed because it has been inactive for 28 days. Please reopen if you'd like to work on this further.