remicres / otbtf

Deep learning with otb (mirror of https://forgemia.inra.fr/orfeo-toolbox/otbtf)
Apache License 2.0

Support for TF 2.16.0 #97

Open daspk04 opened 8 months ago

daspk04 commented 8 months ago

Hi @remicres !

Last week TensorFlow released version 2.16.0-rc. One interesting point is that Keras 3 will be the default version. Keras 3 looks quite interesting: it supports multiple frameworks (TensorFlow, PyTorch, JAX).

Use data pipelines from any source. The Keras 3 fit()/evaluate()/predict() routines are compatible with tf.data.Dataset objects, with PyTorch DataLoader objects, with NumPy arrays, Pandas dataframes — regardless of the backend you're using. You can train a Keras 3 + TensorFlow model on a PyTorch DataLoader or train a Keras 3 + PyTorch model on a tf.data.Dataset.

So I assume we would also be able to directly run any model written in PyTorch with OTBTF?
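For context, Keras 3 chooses its backend from the `KERAS_BACKEND` environment variable, which is read when `keras` is first imported. The snippet below is a minimal sketch of that selection step only; `select_keras_backend` and `SUPPORTED_BACKENDS` are hypothetical names introduced for illustration, and the block deliberately does not import Keras itself.

```python
import os

# Backend names accepted by Keras 3 (hypothetical constant for this sketch).
SUPPORTED_BACKENDS = ("tensorflow", "jax", "torch")


def select_keras_backend(name: str) -> str:
    """Set KERAS_BACKEND. This must run before the first `import keras`."""
    if name not in SUPPORTED_BACKENDS:
        raise ValueError(f"Unknown backend {name!r}, expected one of {SUPPORTED_BACKENDS}")
    os.environ["KERAS_BACKEND"] = name
    return name


select_keras_backend("tensorflow")
# A later `import keras` in the same process would pick up this setting.
```

This only covers backend selection. Whether a model running on the PyTorch backend can be exported in a form that OTBTF can serve is not established in this thread.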

remicres commented 8 months ago

Hi @Pratyush1991 ,

It looks like it, yes... Thanks for the information, I have to dig into that. If somebody would be kind enough to provide us with a minimal working example, that would be great.


daspk04 commented 8 months ago

Thanks @remicres!

I might be able to do it, at least based on the current OTBTF version:

vidlb commented 3 weeks ago

Hi @daspk04, I've pushed release candidate images of OTBTF including TF 2.18 and OTB 9.1.
We still need to update some Python scripts to Keras 3, but if you already want to try something with the OTB apps, you may use one of the following images:

I've made a lot of changes in the build process and hope I didn't break anything in the TF install, so any feedback would be much appreciated!

Cheers

daspk04 commented 2 weeks ago

Hi @vidlb !

Thanks for the upgrade. I gave it a try with the OTBTF tutorial.

There were a few specific issues:

Working model example:

import keras
import otbtf

# `inp_key`, `tgt_key`, `class_nb`, `conv` and `pool` are assumed to be
# defined beforehand, as in the OTBTF tutorial.


class SimpleCNNModel(otbtf.ModelBase):
    """ This is a subclass of `otbtf.ModelBase` to implement a CNN """

    def normalize_inputs(self, inputs):
        """ This function normalizes the input, scaling values by 0.0001 """
        return {inp_key: keras.ops.cast(inputs[inp_key], "float32") * 0.0001}

    def get_outputs(self, normalized_inputs):
        """ This function implements the model """
        inp = normalized_inputs[inp_key]
        net = conv(inp, 16, 5, "conv1")  # 12x12x16
        net = pool(net)                  # 6x6x16
        net = conv(net, 32, 3, "conv2")  # 4x4x32
        net = pool(net)                  # 2x2x32
        net = conv(net, 64, 2, "feats")  # 1x1x64

        # Final 1x1 convolution without activation, followed by a softmax
        net = conv(net, class_nb, 1, "classifier", None)
        softmax_op = keras.layers.Softmax(name="softmax_layer")
        estim = softmax_op(net)

        return {
            tgt_key: estim
        }
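
The shape comments in the model above can be checked with a little arithmetic. The sketch below assumes a 16x16 input patch, unpadded ("valid") convolutions with stride 1, and 2x2 pooling with stride 2. None of these are stated in this comment, and `conv_out` and `pool_out` are hypothetical helper names.

```python
def conv_out(size: int, kernel: int, stride: int = 1) -> int:
    """Spatial size after an unpadded ("valid") convolution."""
    return (size - kernel) // stride + 1


def pool_out(size: int, window: int = 2) -> int:
    """Spatial size after non-overlapping pooling."""
    return size // window


# (operation, kernel or window size) in the order used by get_outputs()
layers = [("conv", 5), ("pool", 2), ("conv", 3), ("pool", 2), ("conv", 2)]

size = 16  # assumed input patch size
trace = []
for op, arg in layers:
    size = conv_out(size, arg) if op == "conv" else pool_out(size, arg)
    trace.append(size)

print(trace)  # [12, 6, 4, 2, 1]
```

Under these assumptions the spatial size shrinks to 1x1 after the "feats" layer. That layer has 64 filters, so its output is 1x1x64.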

remicres commented 2 weeks ago

Looks like Keras 3 will give us a bit of extra work :)

vidlb commented 2 weeks ago

A bit, yes, but it shouldn't be too hard.
Initially I started to update the code, then thought my MR was already way too big, so I saved this in a patch file.
If you want to take a look: as pointed out by Pratyush, this is mostly stuff related to keras.ops, _keras_history, and some function arguments (e.g. keras.ops.one_hot):

Diff

```patch
diff --git a/otbtf/examples/tensorflow_v2x/deterministic/l2_norm.py b/otbtf/examples/tensorflow_v2x/deterministic/l2_norm.py
index b23d86c..59e2b0a 100644
--- a/otbtf/examples/tensorflow_v2x/deterministic/l2_norm.py
+++ b/otbtf/examples/tensorflow_v2x/deterministic/l2_norm.py
@@ -14,14 +14,15 @@ otbcli_TensorflowModelServe \

 """

-import tensorflow as tf
+import keras
+

 # Input
-x = tf.keras.Input(shape=[None, None, None], name="x")  # [1, h, w, N]
+x = keras.Input(shape=[None, None, None], name="x")  # [1, h, w, N]

 # Compute norm on the last axis
-y = tf.norm(x, axis=-1)
+y = keras.ops.norm(x, axis=-1)

 # Create model
-model = tf.keras.Model(inputs={"x": x}, outputs={"y": y})
-model.save("l2_norm_savedmodel")
+model = keras.Model(inputs={"x": x}, outputs={"y": y})
+model.export("l2_norm_savedmodel")
diff --git a/otbtf/examples/tensorflow_v2x/deterministic/scalar_prod.py b/otbtf/examples/tensorflow_v2x/deterministic/scalar_prod.py
index 57127c5..1d5be34 100644
--- a/otbtf/examples/tensorflow_v2x/deterministic/scalar_prod.py
+++ b/otbtf/examples/tensorflow_v2x/deterministic/scalar_prod.py
@@ -16,15 +16,16 @@ OTB_TF_NSOURCES=2 otbcli_TensorflowModelServe \

 """

-import tensorflow as tf
+
+import keras

 # Input
-x1 = tf.keras.Input(shape=[None, None, None], name="x1")  # [1, h, w, N]
-x2 = tf.keras.Input(shape=[None, None, None], name="x2")  # [1, h, w, N]
+x1 = keras.Input(shape=[None, None, None], name="x1")  # [1, h, w, N]
+x2 = keras.Input(shape=[None, None, None], name="x2")  # [1, h, w, N]

 # Compute scalar product
-y = tf.reduce_sum(tf.multiply(x1, x2), axis=-1)
+y = keras.ops.reduce_sum(keras.ops.multiply(x1, x2), axis=-1)

 # Create model
-model = tf.keras.Model(inputs={"x1": x1, "x2": x2}, outputs={"y": y})
-model.save("scalar_product_savedmodel")
+model = keras.Model(inputs={"x1": x1, "x2": x2}, outputs={"y": y})
+model.export("scalar_product_savedmodel")
diff --git a/otbtf/examples/tensorflow_v2x/fcnn/fcnn_model.py b/otbtf/examples/tensorflow_v2x/fcnn/fcnn_model.py
index fcd14a2..7327f69 100644
--- a/otbtf/examples/tensorflow_v2x/fcnn/fcnn_model.py
+++ b/otbtf/examples/tensorflow_v2x/fcnn/fcnn_model.py
@@ -1,16 +1,18 @@
 """
 Implementation of a small U-Net like model
 """
+
 import logging
 import tensorflow as tf
+import keras

 from otbtf.model import ModelBase

 logging.basicConfig(
-    format='%(asctime)s %(levelname)-8s %(message)s',
+    format="%(asctime)s %(levelname)-8s %(message)s",
     level=logging.INFO,
-    datefmt='%Y-%m-%d %H:%M:%S'
+    datefmt="%Y-%m-%d %H:%M:%S",
 )

 # Number of classes estimated by the model
@@ -51,7 +53,7 @@ class FCNNModel(ModelBase):
         Returns:
             dict of normalized inputs, ready to be used from `get_outputs()`
         """
-        return {INPUT_NAME: tf.cast(inputs[INPUT_NAME], tf.float32) * 0.0001}
+        return {INPUT_NAME: keras.ops.cast(inputs[INPUT_NAME], tf.float32) * 0.0001}

     def get_outputs(self, normalized_inputs: dict) -> dict:
         """
@@ -71,24 +73,24 @@ class FCNNModel(ModelBase):
         norm_inp = normalized_inputs[INPUT_NAME]

         def _conv(inp, depth, name):
-            conv_op = tf.keras.layers.Conv2D(
+            conv_op = keras.layers.Conv2D(
                 filters=depth,
                 kernel_size=3,
                 strides=2,
                 activation="relu",
                 padding="same",
-                name=name
+                name=name,
             )
             return conv_op(inp)

         def _tconv(inp, depth, name, activation="relu"):
-            tconv_op = tf.keras.layers.Conv2DTranspose(
+            tconv_op = keras.layers.Conv2DTranspose(
                 filters=depth,
                 kernel_size=3,
                 strides=2,
                 activation=activation,
                 padding="same",
-                name=name
+                name=name,
             )
             return tconv_op(inp)

@@ -110,7 +112,7 @@ class FCNNModel(ModelBase):
         # command.
         #
         # Do not confuse **the name of the output layers** (i.e. the "name"
-        # property of the tf.keras.layer that is used to generate an output
+        # property of the keras.layer that is used to generate an output
         # tensor) and **the key of the output tensor**, in the dict returned
         # from `MyModel.get_output()`. They are two identifiers with a
         # different purpose:
@@ -120,7 +122,7 @@ class FCNNModel(ModelBase):
         # fit the targets to model outputs during training process, but it
         # can also be used to access the tensors as tf/keras objects, for
         # instance to display previews images in TensorBoard.
-        softmax_op = tf.keras.layers.Softmax(name=OUTPUT_SOFTMAX_NAME)
+        softmax_op = keras.layers.Softmax(name=OUTPUT_SOFTMAX_NAME)
         predictions = softmax_op(out_tconv4)

         # note that we could also add additional outputs, for instance the
@@ -158,10 +160,12 @@ def dataset_preprocessing_fn(examples: dict):

     """
     return {
         INPUT_NAME: examples["input_xs_patches"],
-        TARGET_NAME: tf.one_hot(
-            tf.squeeze(tf.cast(examples["labels_patches"], tf.int32), axis=-1),
-            depth=N_CLASSES
-        )
+        TARGET_NAME: keras.ops.one_hot(
+            keras.ops.squeeze(
+                keras.ops.cast(examples["labels_patches"], tf.int32), axis=-1
+            ),
+            N_CLASSES,
+        ),
     }


@@ -190,18 +194,12 @@ def train(params, ds_train, ds_valid, ds_test):
         # This ensures a better optimization control, and also avoids lots of
         # useless outputs (e.g. metrics computed over extra outputs).
         model.compile(
-            loss={
-                TARGET_NAME: tf.keras.losses.CategoricalCrossentropy()
-            },
-            optimizer=tf.keras.optimizers.Adam(
-                learning_rate=params.learning_rate
-            ),
-            metrics={
-                TARGET_NAME: [
-                    tf.keras.metrics.Precision(class_id=1),
-                    tf.keras.metrics.Recall(class_id=1)
-                ]
-            }
+            loss=keras.losses.CategoricalCrossentropy(),
+            optimizer=keras.optimizers.Adam(learning_rate=params.learning_rate),
+            metrics=[
+                keras.metrics.Precision(class_id=1),
+                keras.metrics.Recall(class_id=1),
+            ],
         )

         # Summarize the model (in CLI)
@@ -215,4 +213,4 @@ def train(params, ds_train, ds_valid, ds_test):
         model.evaluate(ds_test, batch_size=params.batch_size)

         # Save trained model as SavedModel
-        model.save(params.model_dir)
+        model.export(params.model_dir)
diff --git a/otbtf/layers.py b/otbtf/layers.py
index ef65ec1..028ba9a 100644
--- a/otbtf/layers.py
+++ b/otbtf/layers.py
@@ -25,13 +25,14 @@ The utils module provides some useful keras layers to build deep nets.
 """
 from typing import List, Tuple, Any
 import tensorflow as tf
+import keras


 Tensor = Any
 Scalars = List[float] | Tuple[float]


-class DilatedMask(tf.keras.layers.Layer):
+class DilatedMask(keras.layers.Layer):
     """Layer to dilate a binary mask."""
     def __init__(self, nodata_value: float, radius: int, name: str = None):
         """
@@ -70,7 +71,7 @@ class DilatedMask(tf.keras.layers.Layer):
         return tf.cast(conv2d_out, tf.uint8)


-class ApplyMask(tf.keras.layers.Layer):
+class ApplyMask(keras.layers.Layer):
     """Layer to apply a binary mask to one input."""
     def __init__(self, out_nodata: float, name: str = None):
         """
@@ -95,7 +96,7 @@ class ApplyMask(tf.keras.layers.Layer):
         return tf.where(mask == 1, float(self.out_nodata), inp)


-class ScalarsTile(tf.keras.layers.Layer):
+class ScalarsTile(keras.layers.Layer):
     """
     Layer to duplicate some scalars in a whole array.
     Simple example with only one scalar = 0.152:
@@ -127,7 +128,7 @@ class ScalarsTile(tf.keras.layers.Layer):
         return tf.tile(inp, [1, tf.shape(ref)[1], tf.shape(ref)[2], 1])


-class Argmax(tf.keras.layers.Layer):
+class Argmax(keras.layers.Layer):
     """
     Layer to compute the argmax of a tensor.

@@ -165,7 +166,7 @@ class Argmax(tf.keras.layers.Layer):
         return argmax


-class Max(tf.keras.layers.Layer):
+class Max(keras.layers.Layer):
     """
     Layer to compute the max of a tensor.

diff --git a/otbtf/model.py b/otbtf/model.py
index 9958510..83013ff 100644
--- a/otbtf/model.py
+++ b/otbtf/model.py
@@ -27,6 +27,7 @@ from typing import List, Dict, Any
 import abc
 import logging
 import tensorflow as tf
+import keras

 Tensor = Any
 TensorsDict = Dict[str, Tensor]
@@ -116,7 +117,7 @@ class ModelBase(abc.ABC):
             if len(new_shape) > 2:
                 new_shape[0] = None
                 new_shape[1] = None
-            placeholder = tf.keras.Input(shape=new_shape, name=key)
+            placeholder = keras.Input(shape=new_shape, name=key)
             logging.info("New shape for input %s: %s", key, new_shape)
             model_inputs.update({key: placeholder})
         return model_inputs
@@ -185,21 +186,21 @@ class ModelBase(abc.ABC):
             for crop in self.inference_cropping:
                 extra_output_key = cropped_tensor_name(out_key, crop)
                 extra_output_name = cropped_tensor_name(
-                    out_tensor._keras_history.layer.name, crop
+                    out_tensor._keras_history.operation.name, crop
                 )
                 logging.info(
                     "Adding extra output for tensor %s with crop %s (%s)",
                     out_key, crop, extra_output_name
                 )
                 cropped = out_tensor[:, crop:-crop, crop:-crop, :]
-                identity = tf.keras.layers.Activation(
+                identity = keras.layers.Activation(
                     'linear', name=extra_output_name
                 )
                 extra_outputs[extra_output_key] = identity(cropped)

         return extra_outputs

-    def create_network(self) -> tf.keras.Model:
+    def create_network(self) -> keras.Model:
         """
         This method returns the Keras model. This needs to be called
         **inside** the strategy.scope(). Can be reimplemented depending on the
@@ -230,7 +231,7 @@ class ModelBase(abc.ABC):
         outputs.update(postprocessed_outputs)

         # Return the keras model
-        return tf.keras.Model(
+        return keras.Model(
             inputs=inputs,
             outputs=outputs,
             name=self.__class__.__name__
@@ -265,7 +266,7 @@ class ModelBase(abc.ABC):

         # When multiworker strategy, only plot if the worker is chief
         if not strategy or _is_chief(strategy):
-            tf.keras.utils.plot_model(
+            keras.utils.plot_model(
                 self.model, output_path, show_shapes=show_shapes
             )

diff --git a/otbtf/ops.py b/otbtf/ops.py
index ef5c52b..5a47356 100644
--- a/otbtf/ops.py
+++ b/otbtf/ops.py
@@ -26,6 +26,7 @@ and train deep nets.
 """
 from typing import List, Tuple, Any
 import tensorflow as tf
+import keras


 Tensor = Any
@@ -44,5 +45,7 @@ def one_hot(labels: Tensor, nb_classes: int):
         one-hot encoded vector (shape [x, y, nb_classes])

     """
-    labels_xy = tf.squeeze(tf.cast(labels, tf.int32), axis=-1)  # shape [x, y]
-    return tf.one_hot(labels_xy, depth=nb_classes)  # shape [x, y, nb_classes]
+    # shape [x, y]
+    labels_xy = keras.ops.squeeze(keras.ops.cast(labels, tf.int32), axis=-1)
+    # shape [x, y, nb_classes]
+    return keras.ops.one_hot(labels_xy, nb_classes)
```
vidlb commented 2 weeks ago

I believe the most annoying change is that Keras now refuses to take a dict of named outputs. The name should be set in, or inferred from, the layer properties, but it sometimes seems to get lost due to optimizations.