huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for PyTorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

How to save wrapped DistilBERT without using `save_pretrained`? #14152

Closed hardianlawi closed 3 years ago

hardianlawi commented 3 years ago

Who can help

@Rocketknight1

To reproduce

Simply run the code below:

import tensorflow as tf
from transformers import (
    TFDistilBertModel,
    DistilBertTokenizerFast,
    DistilBertConfig,
)

def build_classifier_model():
    input_ids = tf.keras.layers.Input(shape=(None,), name="input_ids", dtype=tf.int32)
    attention_mask = tf.keras.layers.Input(
        shape=(None,), name="attention_mask", dtype=tf.int32
    )

    config = DistilBertConfig(
        dropout=0.2,
        attention_dropout=0.2,
        output_attentions=True,
        output_hidden_states=False,
        return_dict=False,
    )
    transformer = TFDistilBertModel.from_pretrained(
        "distilbert-base-uncased", config=config
    )
    transformer.trainable = False

    last_hidden_state = transformer(
        [input_ids, attention_mask],
    )[0]

    x = last_hidden_state[:, 0, :]
    x = tf.keras.layers.Dense(768, activation="relu")(x)
    x = tf.keras.layers.Dropout(0.2)(x)

    outputs = {
        label_name: tf.keras.layers.Dense(1, activation="sigmoid", name=label_name)(x)
        for label_name in ['A', 'B', 'C']
    }

    return tf.keras.Model([input_ids, attention_mask], outputs)

model = build_classifier_model()
model.save('./dump/savedmodel')

Expected behavior

I expected this to save the model in SavedModel format, but instead I got the following error:

~/miniforge3/envs/folder/lib/python3.8/site-packages/transformers/models/distilbert/modeling_tf_distilbert.py in call(self, input_ids, attention_mask, head_mask, inputs_embeds, output_attentions, output_hidden_states, return_dict, training, **kwargs)
    561         **kwargs,
    562     ):
--> 563         inputs = input_processing(
    564             func=self.call,
    565             config=self.config,

~/miniforge3/envs/folder/lib/python3.8/site-packages/transformers/modeling_tf_utils.py in input_processing(func, config, input_ids, **kwargs)
    376                     output[tensor_name] = input
    377                 else:
--> 378                     output[parameter_names[i]] = input
    379             elif isinstance(input, allowed_types) or input is None:
    380                 output[parameter_names[i]] = input

IndexError: list index out of range
hardianlawi commented 3 years ago

I saw that others have posted similar issues (https://github.com/huggingface/transformers/issues/13610 and https://github.com/huggingface/transformers/issues/13742). However, since I am wrapping the model in a tf.keras.Model, save_pretrained isn't a viable solution. Are there any workarounds?

github-actions[bot] commented 3 years ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

Rocketknight1 commented 3 years ago

This issue should be resolved by recent PRs - if you're still encountering difficulties after installing the most recent release, please reopen it and let us know!

kapilkd13 commented 2 years ago

@hardianlawi Were you able to solve this issue? I am still facing this issue, can you help?

Hi @Rocketknight1, I am still facing this issue. I am not able to save a fine-tuned TFDistilBertModel in Keras with model.save(). Since the model is wrapped in a tf.keras.Model, I can't use save_pretrained.

transformers version: 4.15
Platform: Ubuntu 20
Python version: 3.8
PyTorch version (GPU?):
TensorFlow version (GPU?): 2.6.2
Using GPU in script?: Yes
Using distributed or parallel set-up in script?: Yes

Error:

---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
/tmp/ipykernel_47167/2242234445.py in <module>
----> 1 model.save("save_path")

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/keras/engine/training.py in save(self, filepath, overwrite, include_optimizer, save_format, signatures, options, save_traces)
   2143     """
   2144     # pylint: enable=line-too-long
-> 2145     save.save_model(self, filepath, overwrite, include_optimizer, save_format,
   2146                     signatures, options, save_traces)
   2147 

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/keras/saving/save.py in save_model(model, filepath, overwrite, include_optimizer, save_format, signatures, options, save_traces)
    147   else:
    148     with generic_utils.SharedObjectSavingScope():
--> 149       saved_model_save.save(model, filepath, overwrite, include_optimizer,
    150                             signatures, options, save_traces)
    151 

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/keras/saving/saved_model/save.py in save(model, filepath, overwrite, include_optimizer, signatures, options, save_traces)
     88   with K.deprecated_internal_learning_phase_scope(0):
     89     with utils.keras_option_scope(save_traces):
---> 90       saved_nodes, node_paths = save_lib.save_and_return_nodes(
     91           model, filepath, signatures, options)
     92 

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/tensorflow/python/saved_model/save.py in save_and_return_nodes(obj, export_dir, signatures, options, experimental_skip_checkpoint)
   1226 
   1227   _, exported_graph, object_saver, asset_info, saved_nodes, node_paths = (
-> 1228       _build_meta_graph(obj, signatures, options, meta_graph_def))
   1229   saved_model.saved_model_schema_version = (
   1230       pywrap_libexport.SAVED_MODEL_SCHEMA_VERSION)

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/tensorflow/python/saved_model/save.py in _build_meta_graph(obj, signatures, options, meta_graph_def)
   1397 
   1398   with save_context.save_context(options):
-> 1399     return _build_meta_graph_impl(obj, signatures, options, meta_graph_def)

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/tensorflow/python/saved_model/save.py in _build_meta_graph_impl(obj, signatures, options, meta_graph_def)
   1333   checkpoint_graph_view = _AugmentedGraphView(obj)
   1334   if signatures is None:
-> 1335     signatures = signature_serialization.find_function_to_export(
   1336         checkpoint_graph_view)
   1337 

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/tensorflow/python/saved_model/signature_serialization.py in find_function_to_export(saveable_view)
     97   # If the user did not specify signatures, check the root object for a function
     98   # that can be made into a signature.
---> 99   functions = saveable_view.list_functions(saveable_view.root)
    100   signature = functions.get(DEFAULT_SIGNATURE_ATTR, None)
    101   if signature is not None:

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/tensorflow/python/saved_model/save.py in list_functions(self, obj)
    161     obj_functions = self._functions.get(obj, None)
    162     if obj_functions is None:
--> 163       obj_functions = obj._list_functions_for_serialization(  # pylint: disable=protected-access
    164           self._serialization_cache)
    165       self._functions[obj] = obj_functions

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/keras/engine/training.py in _list_functions_for_serialization(self, serialization_cache)
   2810     self.predict_function = None
   2811     self.train_tf_function = None
-> 2812     functions = super(
   2813         Model, self)._list_functions_for_serialization(serialization_cache)
   2814     self.train_function = train_function

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/keras/engine/base_layer.py in _list_functions_for_serialization(self, serialization_cache)
   3083 
   3084   def _list_functions_for_serialization(self, serialization_cache):
-> 3085     return (self._trackable_saved_model_saver
   3086             .list_functions_for_serialization(serialization_cache))
   3087 

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/keras/saving/saved_model/base_serialization.py in list_functions_for_serialization(self, serialization_cache)
     91       return {}
     92 
---> 93     fns = self.functions_to_serialize(serialization_cache)
     94 
     95     # The parent AutoTrackable class saves all user-defined tf.functions, and

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/keras/saving/saved_model/layer_serialization.py in functions_to_serialize(self, serialization_cache)
     71 
     72   def functions_to_serialize(self, serialization_cache):
---> 73     return (self._get_serialized_attributes(
     74         serialization_cache).functions_to_serialize)
     75 

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/keras/saving/saved_model/layer_serialization.py in _get_serialized_attributes(self, serialization_cache)
     87       return serialized_attr
     88 
---> 89     object_dict, function_dict = self._get_serialized_attributes_internal(
     90         serialization_cache)
     91 

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/keras/saving/saved_model/model_serialization.py in _get_serialized_attributes_internal(self, serialization_cache)
     54     # the ones serialized by Layer.
     55     objects, functions = (
---> 56         super(ModelSavedModelSaver, self)._get_serialized_attributes_internal(
     57             serialization_cache))
     58     functions['_default_save_signature'] = default_signature

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/keras/saving/saved_model/layer_serialization.py in _get_serialized_attributes_internal(self, serialization_cache)
     97     """Returns dictionary of serialized attributes."""
     98     objects = save_impl.wrap_layer_objects(self.obj, serialization_cache)
---> 99     functions = save_impl.wrap_layer_functions(self.obj, serialization_cache)
    100     # Attribute validator requires that the default save signature is added to
    101     # function dict, even if the value is None.

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/keras/saving/saved_model/save_impl.py in wrap_layer_functions(layer, serialization_cache)
    195       for fn in fns.values():
    196         if fn is not None and not isinstance(fn, LayerCall):
--> 197           fn.get_concrete_function()
    198 
    199   # Restore overwritten functions and losses

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/contextlib.py in __exit__(self, type, value, traceback)
    118         if type is None:
    119             try:
--> 120                 next(self.gen)
    121             except StopIteration:
    122                 return False

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/keras/saving/saved_model/save_impl.py in tracing_scope()
    357       if training is not None:
    358         with K.deprecated_internal_learning_phase_scope(training):
--> 359           fn.get_concrete_function(*args, **kwargs)
    360       else:
    361         fn.get_concrete_function(*args, **kwargs)

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py in get_concrete_function(self, *args, **kwargs)
   1231   def get_concrete_function(self, *args, **kwargs):
   1232     # Implements GenericFunction.get_concrete_function.
-> 1233     concrete = self._get_concrete_function_garbage_collected(*args, **kwargs)
   1234     concrete._garbage_collector.release()  # pylint: disable=protected-access
   1235     return concrete

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py in _get_concrete_function_garbage_collected(self, *args, **kwargs)
   1211       if self._stateful_fn is None:
   1212         initializers = []
-> 1213         self._initialize(args, kwargs, add_initializers_to=initializers)
   1214         self._initialize_uninitialized_variables(initializers)
   1215 

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py in _initialize(self, args, kwds, add_initializers_to)
    757     self._graph_deleter = FunctionDeleter(self._lifted_initializer_graph)
    758     self._concrete_stateful_fn = (
--> 759         self._stateful_fn._get_concrete_function_internal_garbage_collected(  # pylint: disable=protected-access
    760             *args, **kwds))
    761 

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/tensorflow/python/eager/function.py in _get_concrete_function_internal_garbage_collected(self, *args, **kwargs)
   3064       args, kwargs = None, None
   3065     with self._lock:
-> 3066       graph_function, _ = self._maybe_define_function(args, kwargs)
   3067     return graph_function
   3068 

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/tensorflow/python/eager/function.py in _maybe_define_function(self, args, kwargs)
   3461 
   3462           self._function_cache.missed.add(call_context_key)
-> 3463           graph_function = self._create_graph_function(args, kwargs)
   3464           self._function_cache.primary[cache_key] = graph_function
   3465 

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/tensorflow/python/eager/function.py in _create_graph_function(self, args, kwargs, override_flat_arg_shapes)
   3296     arg_names = base_arg_names + missing_arg_names
   3297     graph_function = ConcreteFunction(
-> 3298         func_graph_module.func_graph_from_py_func(
   3299             self._name,
   3300             self._python_function,

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py in func_graph_from_py_func(name, python_func, args, kwargs, signature, func_graph, autograph, autograph_options, add_control_dependencies, arg_names, op_return_value, collections, capture_by_value, override_flat_arg_shapes, acd_record_initial_resource_uses)
   1005         _, original_func = tf_decorator.unwrap(python_func)
   1006 
-> 1007       func_outputs = python_func(*func_args, **func_kwargs)
   1008 
   1009       # invariant: `func_outputs` contains only Tensors, CompositeTensors,

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py in wrapped_fn(*args, **kwds)
    666         # the function a weak reference to itself to avoid a reference cycle.
    667         with OptionalXlaContext(compile_with_xla):
--> 668           out = weak_wrapped_fn().__wrapped__(*args, **kwds)
    669         return out
    670 

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/keras/saving/saved_model/save_impl.py in wrapper(*args, **kwargs)
    570       with autocast_variable.enable_auto_cast_variables(
    571           layer._compute_dtype_object):  # pylint: disable=protected-access
--> 572         ret = method(*args, **kwargs)
    573     _restore_layer_losses(original_losses)
    574     return ret

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/keras/saving/saved_model/utils.py in wrap_with_training_arg(*args, **kwargs)
    162       return wrapped_call(*args, **kwargs)
    163 
--> 164     return control_flow_util.smart_cond(
    165         training, lambda: replace_training_and_call(True),
    166         lambda: replace_training_and_call(False))

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/keras/utils/control_flow_util.py in smart_cond(pred, true_fn, false_fn, name)
    103     return tf.cond(
    104         pred, true_fn=true_fn, false_fn=false_fn, name=name)
--> 105   return tf.__internal__.smart_cond.smart_cond(
    106       pred, true_fn=true_fn, false_fn=false_fn, name=name)
    107 

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/tensorflow/python/framework/smart_cond.py in smart_cond(pred, true_fn, false_fn, name)
     56       return true_fn()
     57     else:
---> 58       return false_fn()
     59   else:
     60     return control_flow_ops.cond(pred, true_fn=true_fn, false_fn=false_fn,

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/keras/saving/saved_model/utils.py in <lambda>()
    164     return control_flow_util.smart_cond(
    165         training, lambda: replace_training_and_call(True),
--> 166         lambda: replace_training_and_call(False))
    167 
    168   # Create arg spec for decorated function. If 'training' is not defined in the

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/keras/saving/saved_model/utils.py in replace_training_and_call(training)
    160     def replace_training_and_call(training):
    161       set_training_arg(training, training_arg_index, args, kwargs)
--> 162       return wrapped_call(*args, **kwargs)
    163 
    164     return control_flow_util.smart_cond(

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/keras/saving/saved_model/save_impl.py in call(inputs, *args, **kwargs)
    649     return layer.keras_api.__call__  # pylint: disable=protected-access
    650   def call(inputs, *args, **kwargs):
--> 651     return call_and_return_conditional_losses(inputs, *args, **kwargs)[0]
    652   return _create_call_fn_decorator(layer, call)
    653 

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/keras/saving/saved_model/save_impl.py in __call__(self, *args, **kwargs)
    607   def __call__(self, *args, **kwargs):
    608     self._maybe_trace(args, kwargs)
--> 609     return self.wrapped_call(*args, **kwargs)
    610 
    611   def get_concrete_function(self, *args, **kwargs):

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py in __call__(self, *args, **kwds)
    883 
    884       with OptionalXlaContext(self._jit_compile):
--> 885         result = self._call(*args, **kwds)
    886 
    887       new_tracing_count = self.experimental_get_tracing_count()

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py in _call(self, *args, **kwds)
    931       # This is the first call of __call__, so we have to initialize.
    932       initializers = []
--> 933       self._initialize(args, kwds, add_initializers_to=initializers)
    934     finally:
    935       # At this point we know that the initialization is complete (or less

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py in _initialize(self, args, kwds, add_initializers_to)
    757     self._graph_deleter = FunctionDeleter(self._lifted_initializer_graph)
    758     self._concrete_stateful_fn = (
--> 759         self._stateful_fn._get_concrete_function_internal_garbage_collected(  # pylint: disable=protected-access
    760             *args, **kwds))
    761 

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/tensorflow/python/eager/function.py in _get_concrete_function_internal_garbage_collected(self, *args, **kwargs)
   3064       args, kwargs = None, None
   3065     with self._lock:
-> 3066       graph_function, _ = self._maybe_define_function(args, kwargs)
   3067     return graph_function
   3068 

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/tensorflow/python/eager/function.py in _maybe_define_function(self, args, kwargs)
   3461 
   3462           self._function_cache.missed.add(call_context_key)
-> 3463           graph_function = self._create_graph_function(args, kwargs)
   3464           self._function_cache.primary[cache_key] = graph_function
   3465 

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/tensorflow/python/eager/function.py in _create_graph_function(self, args, kwargs, override_flat_arg_shapes)
   3296     arg_names = base_arg_names + missing_arg_names
   3297     graph_function = ConcreteFunction(
-> 3298         func_graph_module.func_graph_from_py_func(
   3299             self._name,
   3300             self._python_function,

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py in func_graph_from_py_func(name, python_func, args, kwargs, signature, func_graph, autograph, autograph_options, add_control_dependencies, arg_names, op_return_value, collections, capture_by_value, override_flat_arg_shapes, acd_record_initial_resource_uses)
   1005         _, original_func = tf_decorator.unwrap(python_func)
   1006 
-> 1007       func_outputs = python_func(*func_args, **func_kwargs)
   1008 
   1009       # invariant: `func_outputs` contains only Tensors, CompositeTensors,

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py in wrapped_fn(*args, **kwds)
    666         # the function a weak reference to itself to avoid a reference cycle.
    667         with OptionalXlaContext(compile_with_xla):
--> 668           out = weak_wrapped_fn().__wrapped__(*args, **kwds)
    669         return out
    670 

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/keras/saving/saved_model/save_impl.py in wrapper(*args, **kwargs)
    570       with autocast_variable.enable_auto_cast_variables(
    571           layer._compute_dtype_object):  # pylint: disable=protected-access
--> 572         ret = method(*args, **kwargs)
    573     _restore_layer_losses(original_losses)
    574     return ret

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/keras/saving/saved_model/utils.py in wrap_with_training_arg(*args, **kwargs)
    162       return wrapped_call(*args, **kwargs)
    163 
--> 164     return control_flow_util.smart_cond(
    165         training, lambda: replace_training_and_call(True),
    166         lambda: replace_training_and_call(False))

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/keras/utils/control_flow_util.py in smart_cond(pred, true_fn, false_fn, name)
    103     return tf.cond(
    104         pred, true_fn=true_fn, false_fn=false_fn, name=name)
--> 105   return tf.__internal__.smart_cond.smart_cond(
    106       pred, true_fn=true_fn, false_fn=false_fn, name=name)
    107 

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/tensorflow/python/framework/smart_cond.py in smart_cond(pred, true_fn, false_fn, name)
     56       return true_fn()
     57     else:
---> 58       return false_fn()
     59   else:
     60     return control_flow_ops.cond(pred, true_fn=true_fn, false_fn=false_fn,

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/keras/saving/saved_model/utils.py in <lambda>()
    164     return control_flow_util.smart_cond(
    165         training, lambda: replace_training_and_call(True),
--> 166         lambda: replace_training_and_call(False))
    167 
    168   # Create arg spec for decorated function. If 'training' is not defined in the

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/keras/saving/saved_model/utils.py in replace_training_and_call(training)
    160     def replace_training_and_call(training):
    161       set_training_arg(training, training_arg_index, args, kwargs)
--> 162       return wrapped_call(*args, **kwargs)
    163 
    164     return control_flow_util.smart_cond(

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/keras/saving/saved_model/save_impl.py in call_and_return_conditional_losses(*args, **kwargs)
    631   def call_and_return_conditional_losses(*args, **kwargs):
    632     """Returns layer (call_output, conditional losses) tuple."""
--> 633     call_output = layer_call(*args, **kwargs)
    634     if version_utils.is_v1_layer_or_model(layer):
    635       conditional_losses = layer.get_losses_for(

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/keras/engine/functional.py in call(self, inputs, training, mask)
    412         a list of tensors if there are more than one outputs.
    413     """
--> 414     return self._run_internal_graph(
    415         inputs, training=training, mask=mask)
    416 

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/keras/engine/functional.py in _run_internal_graph(self, inputs, training, mask)
    548 
    549         args, kwargs = node.map_arguments(tensor_dict)
--> 550         outputs = node.layer(*args, **kwargs)
    551 
    552         # Update tensor_dict.

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/keras/engine/base_layer.py in __call__(self, *args, **kwargs)
   1035         with autocast_variable.enable_auto_cast_variables(
   1036             self._compute_dtype_object):
-> 1037           outputs = call_fn(inputs, *args, **kwargs)
   1038 
   1039         if self._activity_regularizer:

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/keras/saving/saved_model/utils.py in return_outputs_and_add_losses(*args, **kwargs)
     66       args = args[1:]
     67 
---> 68     outputs, losses = fn(*args, **kwargs)
     69     layer.add_loss(losses, inputs=True)
     70 

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/keras/saving/saved_model/utils.py in wrap_with_training_arg(*args, **kwargs)
    162       return wrapped_call(*args, **kwargs)
    163 
--> 164     return control_flow_util.smart_cond(
    165         training, lambda: replace_training_and_call(True),
    166         lambda: replace_training_and_call(False))

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/keras/utils/control_flow_util.py in smart_cond(pred, true_fn, false_fn, name)
    103     return tf.cond(
    104         pred, true_fn=true_fn, false_fn=false_fn, name=name)
--> 105   return tf.__internal__.smart_cond.smart_cond(
    106       pred, true_fn=true_fn, false_fn=false_fn, name=name)
    107 

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/tensorflow/python/framework/smart_cond.py in smart_cond(pred, true_fn, false_fn, name)
     56       return true_fn()
     57     else:
---> 58       return false_fn()
     59   else:
     60     return control_flow_ops.cond(pred, true_fn=true_fn, false_fn=false_fn,

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/keras/saving/saved_model/utils.py in <lambda>()
    164     return control_flow_util.smart_cond(
    165         training, lambda: replace_training_and_call(True),
--> 166         lambda: replace_training_and_call(False))
    167 
    168   # Create arg spec for decorated function. If 'training' is not defined in the

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/keras/saving/saved_model/utils.py in replace_training_and_call(training)
    160     def replace_training_and_call(training):
    161       set_training_arg(training, training_arg_index, args, kwargs)
--> 162       return wrapped_call(*args, **kwargs)
    163 
    164     return control_flow_util.smart_cond(

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/keras/saving/saved_model/save_impl.py in __call__(self, *args, **kwargs)
    607   def __call__(self, *args, **kwargs):
    608     self._maybe_trace(args, kwargs)
--> 609     return self.wrapped_call(*args, **kwargs)
    610 
    611   def get_concrete_function(self, *args, **kwargs):

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py in __call__(self, *args, **kwds)
    883 
    884       with OptionalXlaContext(self._jit_compile):
--> 885         result = self._call(*args, **kwds)
    886 
    887       new_tracing_count = self.experimental_get_tracing_count()

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py in _call(self, *args, **kwds)
    922       # In this case we have not created variables on the first call. So we can
    923       # run the first trace but we should fail if variables are created.
--> 924       results = self._stateful_fn(*args, **kwds)
    925       if self._created_variables and not ALLOW_DYNAMIC_VARIABLE_CREATION:
    926         raise ValueError("Creating variables on a non-first call to a function"

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/tensorflow/python/eager/function.py in __call__(self, *args, **kwargs)
   3036     with self._lock:
   3037       (graph_function,
-> 3038        filtered_flat_args) = self._maybe_define_function(args, kwargs)
   3039     return graph_function._call_flat(
   3040         filtered_flat_args, captured_inputs=graph_function.captured_inputs)  # pylint: disable=protected-access

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/tensorflow/python/eager/function.py in _maybe_define_function(self, args, kwargs)
   3461 
   3462           self._function_cache.missed.add(call_context_key)
-> 3463           graph_function = self._create_graph_function(args, kwargs)
   3464           self._function_cache.primary[cache_key] = graph_function
   3465 

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/tensorflow/python/eager/function.py in _create_graph_function(self, args, kwargs, override_flat_arg_shapes)
   3296     arg_names = base_arg_names + missing_arg_names
   3297     graph_function = ConcreteFunction(
-> 3298         func_graph_module.func_graph_from_py_func(
   3299             self._name,
   3300             self._python_function,

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py in func_graph_from_py_func(name, python_func, args, kwargs, signature, func_graph, autograph, autograph_options, add_control_dependencies, arg_names, op_return_value, collections, capture_by_value, override_flat_arg_shapes, acd_record_initial_resource_uses)
   1005         _, original_func = tf_decorator.unwrap(python_func)
   1006 
-> 1007       func_outputs = python_func(*func_args, **func_kwargs)
   1008 
   1009       # invariant: `func_outputs` contains only Tensors, CompositeTensors,

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py in wrapped_fn(*args, **kwds)
    666         # the function a weak reference to itself to avoid a reference cycle.
    667         with OptionalXlaContext(compile_with_xla):
--> 668           out = weak_wrapped_fn().__wrapped__(*args, **kwds)
    669         return out
    670 

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/keras/saving/saved_model/save_impl.py in wrapper(*args, **kwargs)
    570       with autocast_variable.enable_auto_cast_variables(
    571           layer._compute_dtype_object):  # pylint: disable=protected-access
--> 572         ret = method(*args, **kwargs)
    573     _restore_layer_losses(original_losses)
    574     return ret

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/keras/saving/saved_model/utils.py in wrap_with_training_arg(*args, **kwargs)
    162       return wrapped_call(*args, **kwargs)
    163 
--> 164     return control_flow_util.smart_cond(
    165         training, lambda: replace_training_and_call(True),
    166         lambda: replace_training_and_call(False))

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/keras/utils/control_flow_util.py in smart_cond(pred, true_fn, false_fn, name)
    103     return tf.cond(
    104         pred, true_fn=true_fn, false_fn=false_fn, name=name)
--> 105   return tf.__internal__.smart_cond.smart_cond(
    106       pred, true_fn=true_fn, false_fn=false_fn, name=name)
    107 

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/tensorflow/python/framework/smart_cond.py in smart_cond(pred, true_fn, false_fn, name)
     56       return true_fn()
     57     else:
---> 58       return false_fn()
     59   else:
     60     return control_flow_ops.cond(pred, true_fn=true_fn, false_fn=false_fn,

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/keras/saving/saved_model/utils.py in <lambda>()
    164     return control_flow_util.smart_cond(
    165         training, lambda: replace_training_and_call(True),
--> 166         lambda: replace_training_and_call(False))
    167 
    168   # Create arg spec for decorated function. If 'training' is not defined in the

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/keras/saving/saved_model/utils.py in replace_training_and_call(training)
    160     def replace_training_and_call(training):
    161       set_training_arg(training, training_arg_index, args, kwargs)
--> 162       return wrapped_call(*args, **kwargs)
    163 
    164     return control_flow_util.smart_cond(

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/keras/saving/saved_model/save_impl.py in call_and_return_conditional_losses(*args, **kwargs)
    631   def call_and_return_conditional_losses(*args, **kwargs):
    632     """Returns layer (call_output, conditional losses) tuple."""
--> 633     call_output = layer_call(*args, **kwargs)
    634     if version_utils.is_v1_layer_or_model(layer):
    635       conditional_losses = layer.get_losses_for(

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/transformers/models/distilbert/modeling_tf_distilbert.py in call(self, input_ids, attention_mask, head_mask, inputs_embeds, output_attentions, output_hidden_states, return_dict, training, **kwargs)
    560         **kwargs,
    561     ):
--> 562         inputs = input_processing(
    563             func=self.call,
    564             config=self.config,

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/transformers/modeling_tf_utils.py in input_processing(func, config, input_ids, **kwargs)
    418                     output[tensor_name] = input
    419                 else:
--> 420                     output[parameter_names[i]] = input
    421             elif isinstance(input, allowed_types) or input is None:
    422                 output[parameter_names[i]] = input

IndexError: list index out of range

Attaching code to reproduce the issue:

import tensorflow as tf
from tensorflow.keras import backend as K
from transformers import TFDistilBertModel, DistilBertConfig
from focal_loss import SparseCategoricalFocalLoss  # pip install focal-loss

MAX_LENGTH = 256
LAYER_DROPOUT = 0.2
LEARNING_RATE = 5e-5
RANDOM_STATE = 42
NUM_CLASSES = 3

# Compatible with the TensorFlow backend
# (kept here for reference; the model below is compiled with SparseCategoricalFocalLoss instead)

def focal_loss(gamma=2., alpha=.25):
    def focal_loss_fixed(y_true, y_pred):
        pt_1 = tf.where(tf.equal(y_true, 1), y_pred, tf.ones_like(y_pred))
        pt_0 = tf.where(tf.equal(y_true, 0), y_pred, tf.zeros_like(y_pred))
        return (-K.mean(alpha * K.pow(1. - pt_1, gamma) * K.log(pt_1 + K.epsilon()))
                - K.mean((1 - alpha) * K.pow(pt_0, gamma) * K.log(1. - pt_0 + K.epsilon())))
    return focal_loss_fixed

def build_model(transformer, max_length=MAX_LENGTH):

    # Define weight initializer with a random seed to ensure reproducibility
    weight_initializer = tf.keras.initializers.GlorotNormal(seed=RANDOM_STATE) 

    # Define input layers
    input_ids_layer = tf.keras.layers.Input(shape=(max_length,), 
                                            name='input_ids', 
                                            dtype='int32')
    input_attention_layer = tf.keras.layers.Input(shape=(max_length,), 
                                                  name='attention_mask', 
                                                  dtype='int32')

    # Extract [CLS] embedding
    # It is a tf.Tensor of shape (batch_size, sequence_length, hidden_size=768).
    last_hidden_state = transformer([input_ids_layer, input_attention_layer])[0]
    cls_token = last_hidden_state[:, 0, :]

    ##                                                 ##
    ## Define additional dropout and dense layers here ##
    ##                                                 ##

    # Define a FCN layer
    output = tf.keras.layers.Dense(NUM_CLASSES, 
                                   activation='softmax',
                                   kernel_initializer=weight_initializer,  
                                   kernel_constraint=None,
                                   bias_initializer='zeros'
                                   )(cls_token)

    # Define the model
    model = tf.keras.Model([input_ids_layer, input_attention_layer], output)

    # Compile the model
    model.compile(tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE), 
                  loss=SparseCategoricalFocalLoss(gamma=2),
                  metrics=['accuracy'])

    return model

def get_distil_bert_model(trainable=False, config=None):
    if not config:
        DISTILBERT_DROPOUT = 0.2
        DISTILBERT_ATT_DROPOUT = 0.2

        # Configure DistilBERT's initialization
        config = DistilBertConfig(dropout=DISTILBERT_DROPOUT, 
                                  attention_dropout=DISTILBERT_ATT_DROPOUT, 
                                  output_hidden_states=False)

    distilBert = TFDistilBertModel.from_pretrained('distilbert-base-uncased', config=config)

    if not trainable:
        for layer in distilBert.layers:
            layer.trainable = False

    return distilBert

def get_compiled_model():
    distilBert = get_distil_bert_model()
    classification_model = build_model(distilBert)
    return classification_model

model = get_compiled_model()
model.save("model_save_path")
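A workaround that has been reported for wrapped models like this is to avoid a full SavedModel export and instead save only the weights, then rebuild the architecture in code and restore them; passing the encoder its inputs as a single dict (rather than a positional list) has also been reported to sidestep the positional-argument unpacking inside `input_processing`. The sketch below is a minimal illustration of both patterns, not the transformers API: `ToyEncoder`, `build_toy_model`, and all other names are hypothetical stand-ins so the example runs without downloading DistilBERT weights.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


class ToyEncoder(tf.keras.layers.Layer):
    """Hypothetical stand-in for TFDistilBertModel: consumes the same
    {"input_ids", "attention_mask"} dict and returns a sequence output."""

    def __init__(self, hidden_size=16, **kwargs):
        super().__init__(**kwargs)
        self.embed = tf.keras.layers.Embedding(1000, hidden_size)

    def call(self, inputs):
        # A single dict argument avoids positional unpacking of the inputs,
        # the step where the IndexError above is raised.
        ids = inputs["input_ids"]
        mask = tf.cast(inputs["attention_mask"], tf.float32)[..., tf.newaxis]
        return self.embed(ids) * mask


def build_toy_model(max_length=8, num_classes=3):
    inputs = {
        "input_ids": tf.keras.Input(shape=(max_length,), dtype=tf.int32, name="input_ids"),
        "attention_mask": tf.keras.Input(shape=(max_length,), dtype=tf.int32, name="attention_mask"),
    }
    # Explicit layer names so a rebuilt model lines up with the saved weights.
    hidden = ToyEncoder(name="encoder")(inputs)
    cls_token = hidden[:, 0, :]  # analogous to taking the [CLS] embedding
    output = tf.keras.layers.Dense(num_classes, activation="softmax", name="classifier")(cls_token)
    return tf.keras.Model(inputs, output)


batch = {
    "input_ids": np.random.randint(0, 1000, size=(2, 8)).astype("int32"),
    "attention_mask": np.ones((2, 8), dtype="int32"),
}

model = build_toy_model()
probs_before = model(batch).numpy()

with tempfile.TemporaryDirectory() as tmp:
    weights_path = os.path.join(tmp, "toy.weights.h5")
    model.save_weights(weights_path)   # save only the weights, not a SavedModel
    rebuilt = build_toy_model()        # re-run the architecture code...
    rebuilt.load_weights(weights_path) # ...then restore the trained weights
    probs_after = rebuilt(batch).numpy()
```

Rebuilding from code plus `load_weights` avoids tracing the transformer's `call` for serialization entirely, which is where the failure above occurs; the cost is that the architecture-building code must ship alongside the weights file.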
Zjq9409 commented 2 years ago

@hardianlawi Were you able to solve this issue? I am still facing this issue, can you help?

Hi @Rocketknight1, I am still facing this issue. I am not able to save a fine-tuned TFDistilBertModel in Keras with `model.save()`. Since the model is wrapped in a `tf.keras.Model`, I can't use `save_pretrained`.

- transformers version: 4.15
- Platform: Ubuntu 20
- Python version: 3.8
- PyTorch version (GPU?):
- Tensorflow version (GPU?): 2.6.2
- Using GPU in script?: Yes
- Using distributed or parallel set-up in script?: Yes

Error:

---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
/tmp/ipykernel_47167/2242234445.py in <module>
----> 1 model.save("save_path")

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/keras/engine/training.py in save(self, filepath, overwrite, include_optimizer, save_format, signatures, options, save_traces)
   2143     """
   2144     # pylint: enable=line-too-long
-> 2145     save.save_model(self, filepath, overwrite, include_optimizer, save_format,
   2146                     signatures, options, save_traces)
   2147 

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/keras/saving/save.py in save_model(model, filepath, overwrite, include_optimizer, save_format, signatures, options, save_traces)
    147   else:
    148     with generic_utils.SharedObjectSavingScope():
--> 149       saved_model_save.save(model, filepath, overwrite, include_optimizer,
    150                             signatures, options, save_traces)
    151 

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/keras/saving/saved_model/save.py in save(model, filepath, overwrite, include_optimizer, signatures, options, save_traces)
     88   with K.deprecated_internal_learning_phase_scope(0):
     89     with utils.keras_option_scope(save_traces):
---> 90       saved_nodes, node_paths = save_lib.save_and_return_nodes(
     91           model, filepath, signatures, options)
     92 

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/tensorflow/python/saved_model/save.py in save_and_return_nodes(obj, export_dir, signatures, options, experimental_skip_checkpoint)
   1226 
   1227   _, exported_graph, object_saver, asset_info, saved_nodes, node_paths = (
-> 1228       _build_meta_graph(obj, signatures, options, meta_graph_def))
   1229   saved_model.saved_model_schema_version = (
   1230       pywrap_libexport.SAVED_MODEL_SCHEMA_VERSION)

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/tensorflow/python/saved_model/save.py in _build_meta_graph(obj, signatures, options, meta_graph_def)
   1397 
   1398   with save_context.save_context(options):
-> 1399     return _build_meta_graph_impl(obj, signatures, options, meta_graph_def)

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/tensorflow/python/saved_model/save.py in _build_meta_graph_impl(obj, signatures, options, meta_graph_def)
   1333   checkpoint_graph_view = _AugmentedGraphView(obj)
   1334   if signatures is None:
-> 1335     signatures = signature_serialization.find_function_to_export(
   1336         checkpoint_graph_view)
   1337 

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/tensorflow/python/saved_model/signature_serialization.py in find_function_to_export(saveable_view)
     97   # If the user did not specify signatures, check the root object for a function
     98   # that can be made into a signature.
---> 99   functions = saveable_view.list_functions(saveable_view.root)
    100   signature = functions.get(DEFAULT_SIGNATURE_ATTR, None)
    101   if signature is not None:

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/tensorflow/python/saved_model/save.py in list_functions(self, obj)
    161     obj_functions = self._functions.get(obj, None)
    162     if obj_functions is None:
--> 163       obj_functions = obj._list_functions_for_serialization(  # pylint: disable=protected-access
    164           self._serialization_cache)
    165       self._functions[obj] = obj_functions

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/keras/engine/training.py in _list_functions_for_serialization(self, serialization_cache)
   2810     self.predict_function = None
   2811     self.train_tf_function = None
-> 2812     functions = super(
   2813         Model, self)._list_functions_for_serialization(serialization_cache)
   2814     self.train_function = train_function

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/keras/engine/base_layer.py in _list_functions_for_serialization(self, serialization_cache)
   3083 
   3084   def _list_functions_for_serialization(self, serialization_cache):
-> 3085     return (self._trackable_saved_model_saver
   3086             .list_functions_for_serialization(serialization_cache))
   3087 

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/keras/saving/saved_model/base_serialization.py in list_functions_for_serialization(self, serialization_cache)
     91       return {}
     92 
---> 93     fns = self.functions_to_serialize(serialization_cache)
     94 
     95     # The parent AutoTrackable class saves all user-defined tf.functions, and

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/keras/saving/saved_model/layer_serialization.py in functions_to_serialize(self, serialization_cache)
     71 
     72   def functions_to_serialize(self, serialization_cache):
---> 73     return (self._get_serialized_attributes(
     74         serialization_cache).functions_to_serialize)
     75 

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/keras/saving/saved_model/layer_serialization.py in _get_serialized_attributes(self, serialization_cache)
     87       return serialized_attr
     88 
---> 89     object_dict, function_dict = self._get_serialized_attributes_internal(
     90         serialization_cache)
     91 

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/keras/saving/saved_model/model_serialization.py in _get_serialized_attributes_internal(self, serialization_cache)
     54     # the ones serialized by Layer.
     55     objects, functions = (
---> 56         super(ModelSavedModelSaver, self)._get_serialized_attributes_internal(
     57             serialization_cache))
     58     functions['_default_save_signature'] = default_signature

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/keras/saving/saved_model/layer_serialization.py in _get_serialized_attributes_internal(self, serialization_cache)
     97     """Returns dictionary of serialized attributes."""
     98     objects = save_impl.wrap_layer_objects(self.obj, serialization_cache)
---> 99     functions = save_impl.wrap_layer_functions(self.obj, serialization_cache)
    100     # Attribute validator requires that the default save signature is added to
    101     # function dict, even if the value is None.

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/keras/saving/saved_model/save_impl.py in wrap_layer_functions(layer, serialization_cache)
    195       for fn in fns.values():
    196         if fn is not None and not isinstance(fn, LayerCall):
--> 197           fn.get_concrete_function()
    198 
    199   # Restore overwritten functions and losses

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/contextlib.py in __exit__(self, type, value, traceback)
    118         if type is None:
    119             try:
--> 120                 next(self.gen)
    121             except StopIteration:
    122                 return False

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/keras/saving/saved_model/save_impl.py in tracing_scope()
    357       if training is not None:
    358         with K.deprecated_internal_learning_phase_scope(training):
--> 359           fn.get_concrete_function(*args, **kwargs)
    360       else:
    361         fn.get_concrete_function(*args, **kwargs)

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py in get_concrete_function(self, *args, **kwargs)
   1231   def get_concrete_function(self, *args, **kwargs):
   1232     # Implements GenericFunction.get_concrete_function.
-> 1233     concrete = self._get_concrete_function_garbage_collected(*args, **kwargs)
   1234     concrete._garbage_collector.release()  # pylint: disable=protected-access
   1235     return concrete

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py in _get_concrete_function_garbage_collected(self, *args, **kwargs)
   1211       if self._stateful_fn is None:
   1212         initializers = []
-> 1213         self._initialize(args, kwargs, add_initializers_to=initializers)
   1214         self._initialize_uninitialized_variables(initializers)
   1215 

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py in _initialize(self, args, kwds, add_initializers_to)
    757     self._graph_deleter = FunctionDeleter(self._lifted_initializer_graph)
    758     self._concrete_stateful_fn = (
--> 759         self._stateful_fn._get_concrete_function_internal_garbage_collected(  # pylint: disable=protected-access
    760             *args, **kwds))
    761 

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/tensorflow/python/eager/function.py in _get_concrete_function_internal_garbage_collected(self, *args, **kwargs)
   3064       args, kwargs = None, None
   3065     with self._lock:
-> 3066       graph_function, _ = self._maybe_define_function(args, kwargs)
   3067     return graph_function
   3068 

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/tensorflow/python/eager/function.py in _maybe_define_function(self, args, kwargs)
   3461 
   3462           self._function_cache.missed.add(call_context_key)
-> 3463           graph_function = self._create_graph_function(args, kwargs)
   3464           self._function_cache.primary[cache_key] = graph_function
   3465 

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/tensorflow/python/eager/function.py in _create_graph_function(self, args, kwargs, override_flat_arg_shapes)
   3296     arg_names = base_arg_names + missing_arg_names
   3297     graph_function = ConcreteFunction(
-> 3298         func_graph_module.func_graph_from_py_func(
   3299             self._name,
   3300             self._python_function,

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py in func_graph_from_py_func(name, python_func, args, kwargs, signature, func_graph, autograph, autograph_options, add_control_dependencies, arg_names, op_return_value, collections, capture_by_value, override_flat_arg_shapes, acd_record_initial_resource_uses)
   1005         _, original_func = tf_decorator.unwrap(python_func)
   1006 
-> 1007       func_outputs = python_func(*func_args, **func_kwargs)
   1008 
   1009       # invariant: `func_outputs` contains only Tensors, CompositeTensors,

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py in wrapped_fn(*args, **kwds)
    666         # the function a weak reference to itself to avoid a reference cycle.
    667         with OptionalXlaContext(compile_with_xla):
--> 668           out = weak_wrapped_fn().__wrapped__(*args, **kwds)
    669         return out
    670 

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/keras/saving/saved_model/save_impl.py in wrapper(*args, **kwargs)
    570       with autocast_variable.enable_auto_cast_variables(
    571           layer._compute_dtype_object):  # pylint: disable=protected-access
--> 572         ret = method(*args, **kwargs)
    573     _restore_layer_losses(original_losses)
    574     return ret

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/keras/saving/saved_model/utils.py in wrap_with_training_arg(*args, **kwargs)
    162       return wrapped_call(*args, **kwargs)
    163 
--> 164     return control_flow_util.smart_cond(
    165         training, lambda: replace_training_and_call(True),
    166         lambda: replace_training_and_call(False))

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/keras/utils/control_flow_util.py in smart_cond(pred, true_fn, false_fn, name)
    103     return tf.cond(
    104         pred, true_fn=true_fn, false_fn=false_fn, name=name)
--> 105   return tf.__internal__.smart_cond.smart_cond(
    106       pred, true_fn=true_fn, false_fn=false_fn, name=name)
    107 

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/tensorflow/python/framework/smart_cond.py in smart_cond(pred, true_fn, false_fn, name)
     56       return true_fn()
     57     else:
---> 58       return false_fn()
     59   else:
     60     return control_flow_ops.cond(pred, true_fn=true_fn, false_fn=false_fn,

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/keras/saving/saved_model/utils.py in <lambda>()
    164     return control_flow_util.smart_cond(
    165         training, lambda: replace_training_and_call(True),
--> 166         lambda: replace_training_and_call(False))
    167 
    168   # Create arg spec for decorated function. If 'training' is not defined in the

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/keras/saving/saved_model/utils.py in replace_training_and_call(training)
    160     def replace_training_and_call(training):
    161       set_training_arg(training, training_arg_index, args, kwargs)
--> 162       return wrapped_call(*args, **kwargs)
    163 
    164     return control_flow_util.smart_cond(

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/keras/saving/saved_model/save_impl.py in call(inputs, *args, **kwargs)
    649     return layer.keras_api.__call__  # pylint: disable=protected-access
    650   def call(inputs, *args, **kwargs):
--> 651     return call_and_return_conditional_losses(inputs, *args, **kwargs)[0]
    652   return _create_call_fn_decorator(layer, call)
    653 

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/keras/saving/saved_model/save_impl.py in __call__(self, *args, **kwargs)
    607   def __call__(self, *args, **kwargs):
    608     self._maybe_trace(args, kwargs)
--> 609     return self.wrapped_call(*args, **kwargs)
    610 
    611   def get_concrete_function(self, *args, **kwargs):

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py in __call__(self, *args, **kwds)
    883 
    884       with OptionalXlaContext(self._jit_compile):
--> 885         result = self._call(*args, **kwds)
    886 
    887       new_tracing_count = self.experimental_get_tracing_count()

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py in _call(self, *args, **kwds)
    931       # This is the first call of __call__, so we have to initialize.
    932       initializers = []
--> 933       self._initialize(args, kwds, add_initializers_to=initializers)
    934     finally:
    935       # At this point we know that the initialization is complete (or less

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py in _initialize(self, args, kwds, add_initializers_to)
    757     self._graph_deleter = FunctionDeleter(self._lifted_initializer_graph)
    758     self._concrete_stateful_fn = (
--> 759         self._stateful_fn._get_concrete_function_internal_garbage_collected(  # pylint: disable=protected-access
    760             *args, **kwds))
    761 

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/tensorflow/python/eager/function.py in _get_concrete_function_internal_garbage_collected(self, *args, **kwargs)
   3064       args, kwargs = None, None
   3065     with self._lock:
-> 3066       graph_function, _ = self._maybe_define_function(args, kwargs)
   3067     return graph_function
   3068 

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/tensorflow/python/eager/function.py in _maybe_define_function(self, args, kwargs)
   3461 
   3462           self._function_cache.missed.add(call_context_key)
-> 3463           graph_function = self._create_graph_function(args, kwargs)
   3464           self._function_cache.primary[cache_key] = graph_function
   3465 

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/tensorflow/python/eager/function.py in _create_graph_function(self, args, kwargs, override_flat_arg_shapes)
   3296     arg_names = base_arg_names + missing_arg_names
   3297     graph_function = ConcreteFunction(
-> 3298         func_graph_module.func_graph_from_py_func(
   3299             self._name,
   3300             self._python_function,

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py in func_graph_from_py_func(name, python_func, args, kwargs, signature, func_graph, autograph, autograph_options, add_control_dependencies, arg_names, op_return_value, collections, capture_by_value, override_flat_arg_shapes, acd_record_initial_resource_uses)
   1005         _, original_func = tf_decorator.unwrap(python_func)
   1006 
-> 1007       func_outputs = python_func(*func_args, **func_kwargs)
   1008 
   1009       # invariant: `func_outputs` contains only Tensors, CompositeTensors,

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py in wrapped_fn(*args, **kwds)
    666         # the function a weak reference to itself to avoid a reference cycle.
    667         with OptionalXlaContext(compile_with_xla):
--> 668           out = weak_wrapped_fn().__wrapped__(*args, **kwds)
    669         return out
    670 

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/keras/saving/saved_model/save_impl.py in wrapper(*args, **kwargs)
    570       with autocast_variable.enable_auto_cast_variables(
    571           layer._compute_dtype_object):  # pylint: disable=protected-access
--> 572         ret = method(*args, **kwargs)
    573     _restore_layer_losses(original_losses)
    574     return ret

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/keras/saving/saved_model/utils.py in wrap_with_training_arg(*args, **kwargs)
    162       return wrapped_call(*args, **kwargs)
    163 
--> 164     return control_flow_util.smart_cond(
    165         training, lambda: replace_training_and_call(True),
    166         lambda: replace_training_and_call(False))

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/keras/utils/control_flow_util.py in smart_cond(pred, true_fn, false_fn, name)
    103     return tf.cond(
    104         pred, true_fn=true_fn, false_fn=false_fn, name=name)
--> 105   return tf.__internal__.smart_cond.smart_cond(
    106       pred, true_fn=true_fn, false_fn=false_fn, name=name)
    107 

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/tensorflow/python/framework/smart_cond.py in smart_cond(pred, true_fn, false_fn, name)
     56       return true_fn()
     57     else:
---> 58       return false_fn()
     59   else:
     60     return control_flow_ops.cond(pred, true_fn=true_fn, false_fn=false_fn,

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/keras/saving/saved_model/utils.py in <lambda>()
    164     return control_flow_util.smart_cond(
    165         training, lambda: replace_training_and_call(True),
--> 166         lambda: replace_training_and_call(False))
    167 
    168   # Create arg spec for decorated function. If 'training' is not defined in the

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/keras/saving/saved_model/utils.py in replace_training_and_call(training)
    160     def replace_training_and_call(training):
    161       set_training_arg(training, training_arg_index, args, kwargs)
--> 162       return wrapped_call(*args, **kwargs)
    163 
    164     return control_flow_util.smart_cond(

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/keras/saving/saved_model/save_impl.py in call_and_return_conditional_losses(*args, **kwargs)
    631   def call_and_return_conditional_losses(*args, **kwargs):
    632     """Returns layer (call_output, conditional losses) tuple."""
--> 633     call_output = layer_call(*args, **kwargs)
    634     if version_utils.is_v1_layer_or_model(layer):
    635       conditional_losses = layer.get_losses_for(

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/keras/engine/functional.py in call(self, inputs, training, mask)
    412         a list of tensors if there are more than one outputs.
    413     """
--> 414     return self._run_internal_graph(
    415         inputs, training=training, mask=mask)
    416 

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/keras/engine/functional.py in _run_internal_graph(self, inputs, training, mask)
    548 
    549         args, kwargs = node.map_arguments(tensor_dict)
--> 550         outputs = node.layer(*args, **kwargs)
    551 
    552         # Update tensor_dict.

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/keras/engine/base_layer.py in __call__(self, *args, **kwargs)
   1035         with autocast_variable.enable_auto_cast_variables(
   1036             self._compute_dtype_object):
-> 1037           outputs = call_fn(inputs, *args, **kwargs)
   1038 
   1039         if self._activity_regularizer:

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/keras/saving/saved_model/utils.py in return_outputs_and_add_losses(*args, **kwargs)
     66       args = args[1:]
     67 
---> 68     outputs, losses = fn(*args, **kwargs)
     69     layer.add_loss(losses, inputs=True)
     70 

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/keras/saving/saved_model/utils.py in wrap_with_training_arg(*args, **kwargs)
    162       return wrapped_call(*args, **kwargs)
    163 
--> 164     return control_flow_util.smart_cond(
    165         training, lambda: replace_training_and_call(True),
    166         lambda: replace_training_and_call(False))

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/keras/utils/control_flow_util.py in smart_cond(pred, true_fn, false_fn, name)
    103     return tf.cond(
    104         pred, true_fn=true_fn, false_fn=false_fn, name=name)
--> 105   return tf.__internal__.smart_cond.smart_cond(
    106       pred, true_fn=true_fn, false_fn=false_fn, name=name)
    107 

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/tensorflow/python/framework/smart_cond.py in smart_cond(pred, true_fn, false_fn, name)
     56       return true_fn()
     57     else:
---> 58       return false_fn()
     59   else:
     60     return control_flow_ops.cond(pred, true_fn=true_fn, false_fn=false_fn,

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/keras/saving/saved_model/utils.py in <lambda>()
    164     return control_flow_util.smart_cond(
    165         training, lambda: replace_training_and_call(True),
--> 166         lambda: replace_training_and_call(False))
    167 
    168   # Create arg spec for decorated function. If 'training' is not defined in the

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/keras/saving/saved_model/utils.py in replace_training_and_call(training)
    160     def replace_training_and_call(training):
    161       set_training_arg(training, training_arg_index, args, kwargs)
--> 162       return wrapped_call(*args, **kwargs)
    163 
    164     return control_flow_util.smart_cond(

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/keras/saving/saved_model/save_impl.py in __call__(self, *args, **kwargs)
    607   def __call__(self, *args, **kwargs):
    608     self._maybe_trace(args, kwargs)
--> 609     return self.wrapped_call(*args, **kwargs)
    610 
    611   def get_concrete_function(self, *args, **kwargs):

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py in __call__(self, *args, **kwds)
    883 
    884       with OptionalXlaContext(self._jit_compile):
--> 885         result = self._call(*args, **kwds)
    886 
    887       new_tracing_count = self.experimental_get_tracing_count()

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py in _call(self, *args, **kwds)
    922       # In this case we have not created variables on the first call. So we can
    923       # run the first trace but we should fail if variables are created.
--> 924       results = self._stateful_fn(*args, **kwds)
    925       if self._created_variables and not ALLOW_DYNAMIC_VARIABLE_CREATION:
    926         raise ValueError("Creating variables on a non-first call to a function"

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/tensorflow/python/eager/function.py in __call__(self, *args, **kwargs)
   3036     with self._lock:
   3037       (graph_function,
-> 3038        filtered_flat_args) = self._maybe_define_function(args, kwargs)
   3039     return graph_function._call_flat(
   3040         filtered_flat_args, captured_inputs=graph_function.captured_inputs)  # pylint: disable=protected-access

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/tensorflow/python/eager/function.py in _maybe_define_function(self, args, kwargs)
   3461 
   3462           self._function_cache.missed.add(call_context_key)
-> 3463           graph_function = self._create_graph_function(args, kwargs)
   3464           self._function_cache.primary[cache_key] = graph_function
   3465 

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/tensorflow/python/eager/function.py in _create_graph_function(self, args, kwargs, override_flat_arg_shapes)
   3296     arg_names = base_arg_names + missing_arg_names
   3297     graph_function = ConcreteFunction(
-> 3298         func_graph_module.func_graph_from_py_func(
   3299             self._name,
   3300             self._python_function,

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py in func_graph_from_py_func(name, python_func, args, kwargs, signature, func_graph, autograph, autograph_options, add_control_dependencies, arg_names, op_return_value, collections, capture_by_value, override_flat_arg_shapes, acd_record_initial_resource_uses)
   1005         _, original_func = tf_decorator.unwrap(python_func)
   1006 
-> 1007       func_outputs = python_func(*func_args, **func_kwargs)
   1008 
   1009       # invariant: `func_outputs` contains only Tensors, CompositeTensors,

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py in wrapped_fn(*args, **kwds)
    666         # the function a weak reference to itself to avoid a reference cycle.
    667         with OptionalXlaContext(compile_with_xla):
--> 668           out = weak_wrapped_fn().__wrapped__(*args, **kwds)
    669         return out
    670 

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/keras/saving/saved_model/save_impl.py in wrapper(*args, **kwargs)
    570       with autocast_variable.enable_auto_cast_variables(
    571           layer._compute_dtype_object):  # pylint: disable=protected-access
--> 572         ret = method(*args, **kwargs)
    573     _restore_layer_losses(original_losses)
    574     return ret

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/keras/saving/saved_model/utils.py in wrap_with_training_arg(*args, **kwargs)
    162       return wrapped_call(*args, **kwargs)
    163 
--> 164     return control_flow_util.smart_cond(
    165         training, lambda: replace_training_and_call(True),
    166         lambda: replace_training_and_call(False))

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/keras/utils/control_flow_util.py in smart_cond(pred, true_fn, false_fn, name)
    103     return tf.cond(
    104         pred, true_fn=true_fn, false_fn=false_fn, name=name)
--> 105   return tf.__internal__.smart_cond.smart_cond(
    106       pred, true_fn=true_fn, false_fn=false_fn, name=name)
    107 

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/tensorflow/python/framework/smart_cond.py in smart_cond(pred, true_fn, false_fn, name)
     56       return true_fn()
     57     else:
---> 58       return false_fn()
     59   else:
     60     return control_flow_ops.cond(pred, true_fn=true_fn, false_fn=false_fn,

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/keras/saving/saved_model/utils.py in <lambda>()
    164     return control_flow_util.smart_cond(
    165         training, lambda: replace_training_and_call(True),
--> 166         lambda: replace_training_and_call(False))
    167 
    168   # Create arg spec for decorated function. If 'training' is not defined in the

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/keras/saving/saved_model/utils.py in replace_training_and_call(training)
    160     def replace_training_and_call(training):
    161       set_training_arg(training, training_arg_index, args, kwargs)
--> 162       return wrapped_call(*args, **kwargs)
    163 
    164     return control_flow_util.smart_cond(

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/keras/saving/saved_model/save_impl.py in call_and_return_conditional_losses(*args, **kwargs)
    631   def call_and_return_conditional_losses(*args, **kwargs):
    632     """Returns layer (call_output, conditional losses) tuple."""
--> 633     call_output = layer_call(*args, **kwargs)
    634     if version_utils.is_v1_layer_or_model(layer):
    635       conditional_losses = layer.get_losses_for(

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/transformers/models/distilbert/modeling_tf_distilbert.py in call(self, input_ids, attention_mask, head_mask, inputs_embeds, output_attentions, output_hidden_states, return_dict, training, **kwargs)
    560         **kwargs,
    561     ):
--> 562         inputs = input_processing(
    563             func=self.call,
    564             config=self.config,

~/SageMaker/custom-miniconda/miniconda/envs/custom_python_38/lib/python3.8/site-packages/transformers/modeling_tf_utils.py in input_processing(func, config, input_ids, **kwargs)
    418                     output[tensor_name] = input
    419                 else:
--> 420                     output[parameter_names[i]] = input
    421             elif isinstance(input, allowed_types) or input is None:
    422                 output[parameter_names[i]] = input

IndexError: list index out of range
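The failing frame is the positional-argument mapping inside `input_processing`: each positional tensor is assigned to a parameter name by index, so when Keras's SavedModel tracing rewraps the call and the tensors no longer line up one-to-one with the parameter names, the lookup runs off the end of the list. A toy illustration of that mapping (a deliberate simplification, not the actual transformers code):

```python
def map_positional_inputs(parameter_names, inputs):
    """Toy sketch of the index-based mapping that raises above."""
    output = {}
    for i, tensor in enumerate(inputs):
        # Mirrors `output[parameter_names[i]] = input` from the traceback:
        output[parameter_names[i]] = tensor
    return output

# Works when tensors and names line up one-to-one:
map_positional_inputs(["input_ids", "attention_mask"], ["ids", "mask"])

# Raises IndexError when tracing supplies more tensors than names:
try:
    map_positional_inputs(["input_ids"], ["ids", "mask"])
except IndexError:
    print("list index out of range, as in the traceback")
```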

Attaching code to replicate the issue:

import tensorflow as tf
from tensorflow import keras
from keras import backend as K
from transformers import TFDistilBertModel, DistilBertConfig
from focal_loss import SparseCategoricalFocalLoss

MAX_LENGTH = 256
LAYER_DROPOUT = 0.2
LEARNING_RATE = 5e-5
RANDOM_STATE = 42
NUM_CLASSES = 3

# Compatible with tensorflow backend

def focal_loss(gamma=2., alpha=.25):
    # Defined for reference; the model below compiles with
    # SparseCategoricalFocalLoss rather than this closure.
    def focal_loss_fixed(y_true, y_pred):
        pt_1 = tf.where(tf.equal(y_true, 1), y_pred, tf.ones_like(y_pred))
        pt_0 = tf.where(tf.equal(y_true, 0), y_pred, tf.zeros_like(y_pred))
        return (
            -K.mean(alpha * K.pow(1. - pt_1, gamma) * K.log(pt_1 + K.epsilon()))
            - K.mean((1 - alpha) * K.pow(pt_0, gamma) * K.log(1. - pt_0 + K.epsilon()))
        )
    return focal_loss_fixed

def build_model(transformer, max_length=MAX_LENGTH):

    # Define weight initializer with a random seed to ensure reproducibility
    weight_initializer = tf.keras.initializers.GlorotNormal(seed=RANDOM_STATE) 

    # Define input layers
    input_ids_layer = tf.keras.layers.Input(shape=(max_length,), 
                                            name='input_ids', 
                                            dtype='int32')
    input_attention_layer = tf.keras.layers.Input(shape=(max_length,),
                                                  name='attention_mask',
                                                  dtype='int32')

    # Extract [CLS] embedding
    # It is a tf.Tensor of shape (batch_size, sequence_length, hidden_size=768).
    last_hidden_state = transformer([input_ids_layer, input_attention_layer])[0]
    cls_token = last_hidden_state[:, 0, :]

    ##                                                 ##
    ## Define additional dropout and dense layers here ##
    ##                                                 ##

    # Define a FCN layer
    output = tf.keras.layers.Dense(NUM_CLASSES, 
                                   activation='softmax',
                                   kernel_initializer=weight_initializer,  
                                   kernel_constraint=None,
                                   bias_initializer='zeros'
                                   )(cls_token)

    # Define the model
    model = tf.keras.Model([input_ids_layer, input_attention_layer], output)

    # Compile the model
    model.compile(tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE), 
                  loss=SparseCategoricalFocalLoss(gamma=2),
                  metrics=['accuracy'])

    return model

def get_distil_bert_model(trainable=False, config=None):
    if not config:
        DISTILBERT_DROPOUT = 0.2
        DISTILBERT_ATT_DROPOUT = 0.2

        # Configure DistilBERT's initialization
        config = DistilBertConfig(dropout=DISTILBERT_DROPOUT, 
                                  attention_dropout=DISTILBERT_ATT_DROPOUT, 
                                  output_hidden_states=False)

    distilBert = TFDistilBertModel.from_pretrained('distilbert-base-uncased', config=config)

    if not trainable:
        for layer in distilBert.layers:
            layer.trainable = False

    return distilBert

def get_compiled_model():
    distilBert = get_distil_bert_model()
    classification_model = build_model(distilBert)
    return classification_model

model = get_compiled_model()
model.save("model_save_path")
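For reference, the binary `focal_loss` closure in the listing reduces, per example, to a simple scalar formula. A dependency-free sketch (`focal_loss_scalar` is a hypothetical helper for illustration, not part of the script, and omits TensorFlow's masking via `tf.where`):

```python
import math

def focal_loss_scalar(y_true, y_pred, gamma=2.0, alpha=0.25, eps=1e-7):
    """Per-example binary focal loss, mirroring focal_loss_fixed above."""
    if y_true == 1:
        # Positive example: -alpha * (1 - p)^gamma * log(p)
        return -alpha * (1.0 - y_pred) ** gamma * math.log(y_pred + eps)
    # Negative example: -(1 - alpha) * p^gamma * log(1 - p)
    return -(1.0 - alpha) * y_pred ** gamma * math.log(1.0 - y_pred + eps)
```

The `(1 - p)^gamma` factor is what makes this "focal": confidently correct predictions contribute almost nothing, so training concentrates on hard examples.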

I have the same problem. How can it be solved?

hardianlawi commented 2 years ago

@kapilkd13 @Zjq9409 I completely switched to PyTorch and PyTorch Lightning since they made my life easier :')