james77777778 / keras-aug

A library that includes Keras 3 preprocessing and augmentation layers, providing support for various data types such as images, labels, bounding boxes, segmentation masks, and more.
Apache License 2.0
15 stars 0 forks source link

Auto Augmentation with mixed precision bug #125

Closed — EugenioTL closed this issue 10 months ago.

EugenioTL commented 10 months ago

I get the following error when applying the TrivialAugmentWide or RandAugment layers with mixed precision enabled (`mixed_float16` policy). Any idea what is causing it?

from tensorflow.keras import mixed_precision
# Switch Keras to the "mixed_float16" dtype policy for every layer created after
# this point.
# NOTE(review): this also affects the keras_aug augmentation layers built below:
# the traceback shows their inputs arriving as float16 images/labels inside
# tf.data. Older keras_aug releases raised a dtype TypeError under this policy.
mixed_precision.set_global_policy(policy="mixed_float16")

def apply_augmentation(images, labels, augmentation=None):
    """Run ``augmentation`` on one (images, labels) batch and unpack the result.

    Args:
        images: Batch of images as produced by the dataset.
        labels: Matching batch of labels.
        augmentation: Optional callable taking a dict
            ``{"images": ..., "labels": ...}`` and returning a dict with the
            same keys (e.g. a keras_aug layer or a Sequential of them).
            ``None`` means "no augmentation".

    Returns:
        ``(images, labels)``, augmented when ``augmentation`` is given,
        otherwise the inputs unchanged.
    """
    if augmentation is None:
        return images, labels
    outputs = augmentation({"images": images, "labels": labels})
    return outputs["images"], outputs["labels"]


def make_dataset(X, y, batch_size, autotune=None, augmentation=None, seed=None):
    """Build a batched, optionally augmented, prefetched ``tf.data`` pipeline.

    Args:
        X: Inputs accepted by ``tf.data.Dataset.from_tensor_slices``.
        y: Labels, sliced alongside ``X``.
        batch_size: Number of samples per batch. Augmentation runs on whole
            batches, after batching.
        autotune: Value for ``num_parallel_calls`` and the prefetch buffer.
            ``None`` (the default) means ``tf.data.AUTOTUNE``.
        augmentation: Optional dict-in/dict-out callable; see
            ``apply_augmentation``.
        seed: Unused. Kept only so existing callers passing ``seed=`` keep
            working. (Previously its default referred to an undefined global.)

    Returns:
        A ``tf.data.Dataset`` yielding ``(images, labels)`` batches.
    """
    # Local import so this module (and apply_augmentation) can be imported
    # and tested without evaluating tf.data.AUTOTUNE at definition time.
    import tensorflow as tf

    if autotune is None:
        autotune = tf.data.AUTOTUNE

    dataset = tf.data.Dataset.from_tensor_slices((X, y)).batch(batch_size)
    if augmentation is not None:
        # Without augmentation the map would be an identity, so it is skipped.
        dataset = dataset.map(
            lambda images, labels: apply_augmentation(images, labels, augmentation),
            num_parallel_calls=autotune,
        )
    return dataset.prefetch(autotune)

# Batch-level augmentation pipeline: TrivialAugmentWide, then RandomErase.
# value_range=(0, 1) declares the pixel range the layers should assume;
# presumably X_train is already scaled to [0, 1] — confirm.
_augmentation_steps = [
    keras_aug.layers.TrivialAugmentWide(
        value_range=(0, 1),
        interpolation="bilinear",
        name="trivial_augment",
    ),
    keras_aug.layers.RandomErase(
        area_factor=(0.02, 0.1),
        fill_mode="gaussian_noise",
        name="random_erase",
    ),
]
augmentation_layer = tfk.Sequential(_augmentation_steps, name="preprocessing")

# Training pipeline that applies the augmentation to every batch.
training_dataset = make_dataset(
    X_train,
    y_train,
    batch_size=batch_size,
    augmentation=augmentation_layer,
)

--------------------------------------------------------------------------------------------------------------------------------

File ~usr/local/lib/python3.11/dist-packages/tensorflow/python/data/ops/dataset_ops.py:2268, in DatasetV2.map(self, map_func, num_parallel_calls, deterministic, name)
   2264 # Loaded lazily due to a circular dependency (dataset_ops -> map_op ->
   2265 # dataset_ops).
   2266 # pylint: disable=g-import-not-at-top,protected-access
   2267 from tensorflow.python.data.ops import map_op
-> 2268 return map_op._map_v2(
   2269     self,
   2270     map_func,
   2271     num_parallel_calls=num_parallel_calls,
   2272     deterministic=deterministic,
   2273     name=name)

File ~usr/local/lib/python3.11/dist-packages/tensorflow/python/data/ops/map_op.py:40, in _map_v2(input_dataset, map_func, num_parallel_calls, deterministic, name)
     37   return _MapDataset(
     38       input_dataset, map_func, preserve_cardinality=True, name=name)
     39 else:
---> 40   return _ParallelMapDataset(
     41       input_dataset,
     42       map_func,
     43       num_parallel_calls=num_parallel_calls,
     44       deterministic=deterministic,
     45       preserve_cardinality=True,
     46       name=name)

File ~usr/local/lib/python3.11/dist-packages/tensorflow/python/data/ops/map_op.py:148, in _ParallelMapDataset.__init__(self, input_dataset, map_func, num_parallel_calls, deterministic, use_inter_op_parallelism, preserve_cardinality, use_legacy_function, name)
    146 self._input_dataset = input_dataset
    147 self._use_inter_op_parallelism = use_inter_op_parallelism
--> 148 self._map_func = structured_function.StructuredFunctionWrapper(
    149     map_func,
    150     self._transformation_name(),
    151     dataset=input_dataset,
    152     use_legacy_function=use_legacy_function)
    153 if deterministic is None:
    154   self._deterministic = "default"

File ~usr/local/lib/python3.11/dist-packages/tensorflow/python/data/ops/structured_function.py:265, in StructuredFunctionWrapper.__init__(self, func, transformation_name, dataset, input_classes, input_shapes, input_types, input_structure, add_to_graph, use_legacy_function, defun_kwargs)
    258       warnings.warn(
    259           "Even though the `tf.config.experimental_run_functions_eagerly` "
    260           "option is set, this option does not apply to tf.data functions. "
    261           "To force eager execution of tf.data functions, please use "
    262           "`tf.data.experimental.enable_debug_mode()`.")
    263     fn_factory = trace_tf_function(defun_kwargs)
--> 265 self._function = fn_factory()
    266 # There is no graph to add in eager mode.
    267 add_to_graph &= not context.executing_eagerly()

File ~usr/local/lib/python3.11/dist-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py:1222, in Function.get_concrete_function(self, *args, **kwargs)
   1220 def get_concrete_function(self, *args, **kwargs):
   1221   # Implements GenericFunction.get_concrete_function.
-> 1222   concrete = self._get_concrete_function_garbage_collected(*args, **kwargs)
   1223   concrete._garbage_collector.release()  # pylint: disable=protected-access
   1224   return concrete

File ~usr/local/lib/python3.11/dist-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py:1192, in Function._get_concrete_function_garbage_collected(self, *args, **kwargs)
   1190   if self._variable_creation_config is None:
   1191     initializers = []
-> 1192     self._initialize(args, kwargs, add_initializers_to=initializers)
   1193     self._initialize_uninitialized_variables(initializers)
   1195 if self._created_variables:
   1196   # In this case we have created variables on the first call, so we run the
   1197   # version which is guaranteed to never create variables.

File ~usr/local/lib/python3.11/dist-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py:694, in Function._initialize(self, args, kwds, add_initializers_to)
    689 self._variable_creation_config = self._generate_scoped_tracing_options(
    690     variable_capturing_scope,
    691     tracing_compilation.ScopeType.VARIABLE_CREATION,
    692 )
    693 # Force the definition of the function for these arguments
--> 694 self._concrete_variable_creation_fn = tracing_compilation.trace_function(
    695     args, kwds, self._variable_creation_config
    696 )
    698 def invalid_creator_scope(*unused_args, **unused_kwds):
    699   """Disables variable creation."""

File ~usr/local/lib/python3.11/dist-packages/tensorflow/python/eager/polymorphic_function/tracing_compilation.py:178, in trace_function(args, kwargs, tracing_options)
    175     args = tracing_options.input_signature
    176     kwargs = {}
--> 178   concrete_function = _maybe_define_function(
    179       args, kwargs, tracing_options
    180   )
    181   _set_arg_keywords(concrete_function)
    183 if not tracing_options.bind_graph_to_function:

File ~usr/local/lib/python3.11/dist-packages/tensorflow/python/eager/polymorphic_function/tracing_compilation.py:284, in _maybe_define_function(args, kwargs, tracing_options)
    282 else:
    283   target_func_type = lookup_func_type
--> 284 concrete_function = _create_concrete_function(
    285     target_func_type, lookup_func_context, func_graph, tracing_options
    286 )
    288 if tracing_options.function_cache is not None:
    289   tracing_options.function_cache.add(
    290       concrete_function, current_func_context
    291   )

File ~usr/local/lib/python3.11/dist-packages/tensorflow/python/eager/polymorphic_function/tracing_compilation.py:308, in _create_concrete_function(function_type, type_context, func_graph, tracing_options)
    303 with func_graph.as_default():
    304   placeholder_bound_args = function_type.placeholder_arguments(
    305       placeholder_context
    306   )
--> 308 traced_func_graph = func_graph_module.func_graph_from_py_func(
    309     tracing_options.name,
    310     tracing_options.python_function,
    311     placeholder_bound_args.args,
    312     placeholder_bound_args.kwargs,
    313     None,
    314     func_graph=func_graph,
    315     arg_names=function_type_utils.to_arg_names(function_type),
    316     create_placeholders=False,
    317 )
    319 transform.apply_func_graph_transforms(traced_func_graph)
    321 graph_capture_container = traced_func_graph.function_captures

File ~usr/local/lib/python3.11/dist-packages/tensorflow/python/framework/func_graph.py:1059, in func_graph_from_py_func(name, python_func, args, kwargs, signature, func_graph, add_control_dependencies, arg_names, op_return_value, collections, capture_by_value, create_placeholders)
   1056   return x
   1058 _, original_func = tf_decorator.unwrap(python_func)
-> 1059 func_outputs = python_func(*func_args, **func_kwargs)
   1061 # invariant: `func_outputs` contains only Tensors, CompositeTensors,
   1062 # TensorArrays and `None`s.
   1063 func_outputs = variable_utils.convert_variables_to_tensors(func_outputs)

File ~usr/local/lib/python3.11/dist-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py:597, in Function._generate_scoped_tracing_options.<locals>.wrapped_fn(*args, **kwds)
    593 with default_graph._variable_creator_scope(scope, priority=50):  # pylint: disable=protected-access
    594   # __wrapped__ allows AutoGraph to swap in a converted function. We give
    595   # the function a weak reference to itself to avoid a reference cycle.
    596   with OptionalXlaContext(compile_with_xla):
--> 597     out = weak_wrapped_fn().__wrapped__(*args, **kwds)
    598   return out

File ~usr/local/lib/python3.11/dist-packages/tensorflow/python/data/ops/structured_function.py:231, in StructuredFunctionWrapper.__init__.<locals>.trace_tf_function.<locals>.wrapped_fn(*args)
    230 def wrapped_fn(*args):  # pylint: disable=missing-docstring
--> 231   ret = wrapper_helper(*args)
    232   ret = structure.to_tensor_list(self._output_structure, ret)
    233   return [ops.convert_to_tensor(t) for t in ret]

File ~usr/local/lib/python3.11/dist-packages/tensorflow/python/data/ops/structured_function.py:161, in StructuredFunctionWrapper.__init__.<locals>.wrapper_helper(*args)
    159 if not _should_unpack(nested_args):
    160   nested_args = (nested_args,)
--> 161 ret = autograph.tf_convert(self._func, ag_ctx)(*nested_args)
    162 ret = variable_utils.convert_variables_to_tensors(ret)
    163 if _should_pack(ret):

File ~usr/local/lib/python3.11/dist-packages/tensorflow/python/autograph/impl/api.py:693, in convert.<locals>.decorator.<locals>.wrapper(*args, **kwargs)
    691 except Exception as e:  # pylint:disable=broad-except
    692   if hasattr(e, 'ag_error_metadata'):
--> 693     raise e.ag_error_metadata.to_exception(e)
    694   else:
    695     raise

File ~usr/local/lib/python3.11/dist-packages/tensorflow/python/autograph/impl/api.py:690, in convert.<locals>.decorator.<locals>.wrapper(*args, **kwargs)
    688 try:
    689   with conversion_ctx:
--> 690     return converted_call(f, args, kwargs, options=options)
    691 except Exception as e:  # pylint:disable=broad-except
    692   if hasattr(e, 'ag_error_metadata'):

File ~usr/local/lib/python3.11/dist-packages/tensorflow/python/autograph/impl/api.py:439, in converted_call(f, args, kwargs, caller_fn_scope, options)
    437 try:
    438   if kwargs is not None:
--> 439     result = converted_f(*effective_args, **kwargs)
    440   else:
    441     result = converted_f(*effective_args)

File ~tmp/__autograph_generated_filebrqr6ub4.py:7, in outer_factory.<locals>.inner_factory.<locals>.<lambda>(x, y)
      6 def inner_factory(ag__):
----> 7     tf__lam = lambda x, y: ag__.with_function_scope(lambda lscope: ag__.converted_call(preprocess_data, (x, y), dict(augmentation=augmentation), lscope), 'lscope', ag__.STD)
      8     return tf__lam

File ~usr/local/lib/python3.11/dist-packages/tensorflow/python/autograph/core/function_wrappers.py:113, in with_function_scope(thunk, scope_name, options)
    111 """Inline version of the FunctionScope context manager."""
    112 with FunctionScope('lambda_', scope_name, options) as scope:
--> 113   return thunk(scope)

File ~tmp/__autograph_generated_filebrqr6ub4.py:7, in outer_factory.<locals>.inner_factory.<locals>.<lambda>(lscope)
      6 def inner_factory(ag__):
----> 7     tf__lam = lambda x, y: ag__.with_function_scope(lambda lscope: ag__.converted_call(preprocess_data, (x, y), dict(augmentation=augmentation), lscope), 'lscope', ag__.STD)
      8     return tf__lam

File ~usr/local/lib/python3.11/dist-packages/tensorflow/python/autograph/impl/api.py:439, in converted_call(f, args, kwargs, caller_fn_scope, options)
    437 try:
    438   if kwargs is not None:
--> 439     result = converted_f(*effective_args, **kwargs)
    440   else:
    441     result = converted_f(*effective_args)

File ~tmp/__autograph_generated_filelvb17_ff.py:11, in outer_factory.<locals>.inner_factory.<locals>.tf__preprocess_data(images, labels, augmentation)
      9 retval_ = ag__.UndefinedReturnValue()
     10 inputs = {'images': ag__.ld(images), 'labels': ag__.ld(labels)}
---> 11 outputs = ag__.if_exp(ag__.ld(augmentation) != None, lambda: ag__.converted_call(ag__.ld(augmentation), (ag__.ld(inputs),), None, fscope), lambda: ag__.ld(inputs), 'augmentation != None')
     12 try:
     13     do_return = True

File ~usr/local/lib/python3.11/dist-packages/tensorflow/python/autograph/operators/conditional_expressions.py:27, in if_exp(cond, if_true, if_false, expr_repr)
     25   return _tf_if_exp(cond, if_true, if_false, expr_repr)
     26 else:
---> 27   return _py_if_exp(cond, if_true, if_false)

File ~usr/local/lib/python3.11/dist-packages/tensorflow/python/autograph/operators/conditional_expressions.py:52, in _py_if_exp(cond, if_true, if_false)
     51 def _py_if_exp(cond, if_true, if_false):
---> 52   return if_true() if cond else if_false()

File ~tmp/__autograph_generated_filelvb17_ff.py:11, in outer_factory.<locals>.inner_factory.<locals>.tf__preprocess_data.<locals>.<lambda>()
      9 retval_ = ag__.UndefinedReturnValue()
     10 inputs = {'images': ag__.ld(images), 'labels': ag__.ld(labels)}
---> 11 outputs = ag__.if_exp(ag__.ld(augmentation) != None, lambda: ag__.converted_call(ag__.ld(augmentation), (ag__.ld(inputs),), None, fscope), lambda: ag__.ld(inputs), 'augmentation != None')
     12 try:
     13     do_return = True

File ~usr/local/lib/python3.11/dist-packages/tensorflow/python/autograph/impl/api.py:331, in converted_call(f, args, kwargs, caller_fn_scope, options)
    329 if conversion.is_in_allowlist_cache(f, options):
    330   logging.log(2, 'Allowlisted %s: from cache', f)
--> 331   return _call_unconverted(f, args, kwargs, options, False)
    333 if ag_ctx.control_status_ctx().status == ag_ctx.Status.DISABLED:
    334   logging.log(2, 'Allowlisted: %s: AutoGraph is disabled in context', f)

File ~usr/local/lib/python3.11/dist-packages/tensorflow/python/autograph/impl/api.py:460, in _call_unconverted(f, args, kwargs, options, update_cache)
    458 if kwargs is not None:
    459   return f(*args, **kwargs)
--> 460 return f(*args)

File ~usr/local/lib/python3.11/dist-packages/keras/src/utils/traceback_utils.py:70, in filter_traceback.<locals>.error_handler(*args, **kwargs)
     67     filtered_tb = _process_traceback_frames(e.__traceback__)
     68     # To get the full stack trace, call:
     69     # `tf.debugging.disable_traceback_filtering()`
---> 70     raise e.with_traceback(filtered_tb) from None
     71 finally:
     72     del filtered_tb

File ~tmp/__autograph_generated_file12c62saa.py:33, in outer_factory.<locals>.inner_factory.<locals>.tf__call(self, inputs)
     31     nonlocal do_return, retval_
     32     raise ag__.converted_call(ag__.ld(ValueError), (f'Image augmentation layers are expecting inputs to be rank 3 (HWC) or 4D (NHWC) tensors. Got shape: {ag__.ld(images).shape}',), None, fscope)
---> 33 ag__.if_stmt(ag__.or_(lambda: ag__.ld(images).shape.rank == 3, lambda: ag__.ld(images).shape.rank == 4), if_body, else_body, get_state, set_state, ('do_return', 'retval_'), 2)
     34 return fscope.ret(retval_, do_return)

File ~tmp/__autograph_generated_file12c62saa.py:25, in outer_factory.<locals>.inner_factory.<locals>.tf__call.<locals>.if_body()
     23 try:
     24     do_return = True
---> 25     retval_ = ag__.converted_call(ag__.ld(self)._format_output, (ag__.converted_call(ag__.ld(self)._batch_augment, (ag__.ld(inputs),), None, fscope), ag__.ld(metadata)), None, fscope)
     26 except:
     27     do_return = False

File ~tmp/__autograph_generated_filep1xqu7wy.py:35, in outer_factory.<locals>.inner_factory.<locals>.tf___batch_augment(self, inputs)
     33 ag__.if_stmt(ag__.ld(bounding_boxes) is not None, if_body, else_body, get_state, set_state, ('inputs[BOUNDING_BOXES]', 'ori_bbox_info'), 2)
     34 inputs_for_trivial_augment_single_input = {'inputs': ag__.ld(inputs), 'transformations': ag__.ld(transformations)}
---> 35 result = ag__.converted_call(ag__.ld(tf).map_fn, (ag__.ld(self).trivial_augment_single_input, ag__.ld(inputs_for_trivial_augment_single_input)), dict(fn_output_signature=ag__.converted_call(ag__.ld(augmentation_utils).compute_signature, (ag__.ld(inputs), ag__.ld(self).compute_dtype), None, fscope)), fscope)
     36 bounding_boxes = ag__.converted_call(ag__.ld(result).get, (ag__.ld(BOUNDING_BOXES), None), None, fscope)
     38 def get_state_2():

File ~tmp/__autograph_generated_filev6ti3m33.py:26, in outer_factory.<locals>.inner_factory.<locals>.tf__trivial_augment_single_input(self, inputs)
     24 idx = ag__.Undefined('idx')
     25 ag__.for_stmt(ag__.converted_call(ag__.ld(enumerate), (ag__.ld(self).aug_layers,), None, fscope), None, loop_body, get_state, set_state, (), {'iterate_names': '(idx, layer)'})
---> 26 result = ag__.converted_call(ag__.ld(tf).switch_case, (ag__.ld(random_indice),), dict(branch_fns=ag__.ld(branch_fns)), fscope)
     28 def get_state_1():
     29     return (ag__.ldu(lambda: result[BOUNDING_BOXES], 'result[BOUNDING_BOXES]'),)

File ~tmp/__autograph_generated_file12c62saa.py:33, in outer_factory.<locals>.inner_factory.<locals>.tf__call(self, inputs)
     31     nonlocal do_return, retval_
     32     raise ag__.converted_call(ag__.ld(ValueError), (f'Image augmentation layers are expecting inputs to be rank 3 (HWC) or 4D (NHWC) tensors. Got shape: {ag__.ld(images).shape}',), None, fscope)
---> 33 ag__.if_stmt(ag__.or_(lambda: ag__.ld(images).shape.rank == 3, lambda: ag__.ld(images).shape.rank == 4), if_body, else_body, get_state, set_state, ('do_return', 'retval_'), 2)
     34 return fscope.ret(retval_, do_return)

File ~tmp/__autograph_generated_file12c62saa.py:25, in outer_factory.<locals>.inner_factory.<locals>.tf__call.<locals>.if_body()
     23 try:
     24     do_return = True
---> 25     retval_ = ag__.converted_call(ag__.ld(self)._format_output, (ag__.converted_call(ag__.ld(self)._batch_augment, (ag__.ld(inputs),), None, fscope), ag__.ld(metadata)), None, fscope)
     26 except:
     27     do_return = False

File ~tmp/__autograph_generated_filevkq0dz7z.py:36, in outer_factory.<locals>.inner_factory.<locals>.tf___batch_augment(self, inputs)
     34     images = ag__.converted_call(ag__.ld(self).augment_images, (ag__.ld(images),), dict(transformations=ag__.ld(transformations), bounding_boxes=ag__.ld(bounding_boxes), labels=ag__.ld(labels)), fscope)
     35 inputs_for_raggeds = ag__.Undefined('inputs_for_raggeds')
---> 36 ag__.if_stmt(ag__.and_(lambda: ag__.converted_call(ag__.ld(isinstance), (ag__.ld(images), ag__.ld(tf).RaggedTensor), None, fscope), lambda: ag__.not_(ag__.ld(self).force_no_unwrap_ragged_image_call)), if_body, else_body, get_state, set_state, ('images',), 1)
     38 def get_state_1():
     39     return (images,)

File ~tmp/__autograph_generated_filevkq0dz7z.py:34, in outer_factory.<locals>.inner_factory.<locals>.tf___batch_augment.<locals>.else_body()
     32 def else_body():
     33     nonlocal images
---> 34     images = ag__.converted_call(ag__.ld(self).augment_images, (ag__.ld(images),), dict(transformations=ag__.ld(transformations), bounding_boxes=ag__.ld(bounding_boxes), labels=ag__.ld(labels)), fscope)

File ~tmp/__autograph_generated_filegmle6ti9.py:15, in outer_factory.<locals>.inner_factory.<locals>.tf__augment_images(self, images, transformations, **kwargs)
     13 scales = 255.0 / (ag__.ld(highs) - ag__.ld(lows))
     14 eq_idxs = ag__.converted_call(ag__.ld(tf).math.is_inf, (ag__.ld(scales),), None, fscope)
---> 15 lows = ag__.converted_call(ag__.ld(tf).where, (ag__.ld(eq_idxs), 0.0, ag__.ld(lows)), None, fscope)
     16 scales = ag__.converted_call(ag__.ld(tf).where, (ag__.ld(eq_idxs), 1.0, ag__.ld(scales)), None, fscope)
     17 images = ag__.converted_call(ag__.ld(tf).clip_by_value, ((ag__.ld(images) - ag__.ld(lows)) * ag__.ld(scales), 0, 255), None, fscope)

TypeError: in user code:

    File "/tmp/ipykernel_11/1206971647.py", line 15, in None  *
        lambda x, y: preprocess_data(x, y, augmentation=augmentation)
    File "/tmp/ipykernel_11/1206971647.py", line 11, in preprocess_data  *
        outputs = augmentation(inputs) if augmentation != None else inputs
    File "/usr/local/lib/python3.11/dist-packages/keras/src/utils/traceback_utils.py", line 70, in error_handler  **
        raise e.with_traceback(filtered_tb) from None
    File "/tmp/__autograph_generated_file12c62saa.py", line 33, in tf__call
        ag__.if_stmt(ag__.or_(lambda: ag__.ld(images).shape.rank == 3, lambda: ag__.ld(images).shape.rank == 4), if_body, else_body, get_state, set_state, ('do_return', 'retval_'), 2)
    File "/tmp/__autograph_generated_file12c62saa.py", line 28, in if_body
        raise
    File "/tmp/__autograph_generated_filep1xqu7wy.py", line 35, in tf___batch_augment
        result = ag__.converted_call(ag__.ld(tf).map_fn, (ag__.ld(self).trivial_augment_single_input, ag__.ld(inputs_for_trivial_augment_single_input)), dict(fn_output_signature=ag__.converted_call(ag__.ld(augmentation_utils).compute_signature, (ag__.ld(inputs), ag__.ld(self).compute_dtype), None, fscope)), fscope)
    File "/tmp/__autograph_generated_filev6ti3m33.py", line 26, in tf__trivial_augment_single_input
        result = ag__.converted_call(ag__.ld(tf).switch_case, (ag__.ld(random_indice),), dict(branch_fns=ag__.ld(branch_fns)), fscope)
    File "/tmp/__autograph_generated_file12c62saa.py", line 33, in tf__call
        ag__.if_stmt(ag__.or_(lambda: ag__.ld(images).shape.rank == 3, lambda: ag__.ld(images).shape.rank == 4), if_body, else_body, get_state, set_state, ('do_return', 'retval_'), 2)
    File "/tmp/__autograph_generated_file12c62saa.py", line 28, in if_body
        raise
    File "/tmp/__autograph_generated_filevkq0dz7z.py", line 36, in tf___batch_augment
        ag__.if_stmt(ag__.and_(lambda: ag__.converted_call(ag__.ld(isinstance), (ag__.ld(images), ag__.ld(tf).RaggedTensor), None, fscope), lambda: ag__.not_(ag__.ld(self).force_no_unwrap_ragged_image_call)), if_body, else_body, get_state, set_state, ('images',), 1)
    File "/tmp/__autograph_generated_filevkq0dz7z.py", line 34, in else_body
        images = ag__.converted_call(ag__.ld(self).augment_images, (ag__.ld(images),), dict(transformations=ag__.ld(transformations), bounding_boxes=ag__.ld(bounding_boxes), labels=ag__.ld(labels)), fscope)
    File "/tmp/__autograph_generated_filegmle6ti9.py", line 15, in tf__augment_images
        lows = ag__.converted_call(ag__.ld(tf).where, (ag__.ld(eq_idxs), 0.0, ag__.ld(lows)), None, fscope)

    TypeError: Exception encountered when calling layer 'trivial_augment' (type TrivialAugmentWide).

    in user code:

        File "/usr/local/lib/python3.11/dist-packages/keras_aug/layers/base/vectorized_base_random_layer.py", line 613, in call  *
            if images.shape.rank == 3 or images.shape.rank == 4:
        File "/usr/local/lib/python3.11/dist-packages/keras_aug/layers/augmentation/auto/trivial_augment_wide.py", line 281, in _batch_augment  *
            result = tf.map_fn(
        File "/usr/local/lib/python3.11/dist-packages/keras_aug/layers/augmentation/auto/trivial_augment_wide.py", line 316, in trivial_augment_single_input  *
            result = tf.switch_case(random_indice, branch_fns=branch_fns)
        File "/usr/local/lib/python3.11/dist-packages/keras/src/utils/traceback_utils.py", line 70, in error_handler
            raise e.with_traceback(filtered_tb) from None
        File "/tmp/__autograph_generated_file12c62saa.py", line 33, in tf__call
            ag__.if_stmt(ag__.or_(lambda: ag__.ld(images).shape.rank == 3, lambda: ag__.ld(images).shape.rank == 4), if_body, else_body, get_state, set_state, ('do_return', 'retval_'), 2)
        File "/tmp/__autograph_generated_file12c62saa.py", line 28, in if_body
            raise
        File "/tmp/__autograph_generated_filevkq0dz7z.py", line 36, in tf___batch_augment
            ag__.if_stmt(ag__.and_(lambda: ag__.converted_call(ag__.ld(isinstance), (ag__.ld(images), ag__.ld(tf).RaggedTensor), None, fscope), lambda: ag__.not_(ag__.ld(self).force_no_unwrap_ragged_image_call)), if_body, else_body, get_state, set_state, ('images',), 1)
        File "/tmp/__autograph_generated_filevkq0dz7z.py", line 34, in else_body
            images = ag__.converted_call(ag__.ld(self).augment_images, (ag__.ld(images),), dict(transformations=ag__.ld(transformations), bounding_boxes=ag__.ld(bounding_boxes), labels=ag__.ld(labels)), fscope)
        File "/tmp/__autograph_generated_filegmle6ti9.py", line 15, in tf__augment_images
            lows = ag__.converted_call(ag__.ld(tf).where, (ag__.ld(eq_idxs), 0.0, ag__.ld(lows)), None, fscope)

        TypeError: Exception encountered when calling layer 'trivial_augment' (type AutoContrast).

        in user code:

            File "/usr/local/lib/python3.11/dist-packages/keras_aug/layers/base/vectorized_base_random_layer.py", line 614, in call  *
                return self._format_output(self._batch_augment(inputs), metadata)
            File "/usr/local/lib/python3.11/dist-packages/keras_aug/layers/base/vectorized_base_random_layer.py", line 416, in _batch_augment  *
                images = self.augment_images(
            File "/usr/local/lib/python3.11/dist-packages/keras_aug/layers/preprocessing/intensity/auto_contrast.py", line 50, in augment_images  *
                lows = tf.where(eq_idxs, 0.0, lows)

            TypeError: Input 'e' of 'SelectV2' Op has type float16 that does not match type float32 of argument 't'.

        Call arguments received by layer 'trivial_augment' (type AutoContrast):
          • inputs={'images': 'tf.Tensor(shape=(1, 32, 32, 3), dtype=float16)', 'labels': 'tf.Tensor(shape=(1, 10), dtype=float16)'}

    Call arguments received by layer 'trivial_augment' (type TrivialAugmentWide):
      • inputs={'images': 'tf.Tensor(shape=(None, 32, 32, 3), dtype=float16)', 'labels': 'tf.Tensor(shape=(None, 10), dtype=float16)'}
james77777778 commented 10 months ago

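Hi @EugenioTL, thanks for reaching out. I have published a new version that fixes the inconsistent-dtype bug in `tf.where`. Under the `mixed_float16` policy the images were float16, but a Python float literal in `tf.where` was treated as float32, hence the `SelectV2` type mismatch. You can install it with: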

pip3 install keras-aug -U

Feel free to let me know if the issue persists.

EugenioTL commented 10 months ago

It seems to work properly now, thank you very much!