awsaf49 / gcvit-tf

TensorFlow 2.0 Implementation of GCViT: Global Context Vision Transformer
MIT License

GCViT model load fails for subclassed models #20

Closed andreped closed 1 year ago

andreped commented 1 year ago

Describe the bug: I have been fine-tuning a GCViTXXTiny architecture on an image classification problem. As discussed in https://github.com/awsaf49/gcvit-tf/discussions/19, I managed to get it to converge, and all seems good.

However, when attempting to load the model, I ran into a TypeError. I have seen something like this before when subclassing Keras models, which I also do in this case.

Keras has tried to move away from requiring custom layers and classes to implement the get_config()/from_config() methods. However, I have seen multiple scenarios (at least with older TF versions, <= 2.11) where this did not in fact work.

I think it might be necessary to add these methods to all custom layers and classes, so that trained models can be loaded as normal. This used to be the default, and Keras' own tutorials still do it (see here).
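
As a rough illustration of the get_config()/from_config() pattern used in the Keras tutorials, here is a minimal sketch on a hypothetical layer (WindowPartition and window_size are made up for this example, not taken from gcvit). get_config() returns every constructor argument, and from_config() undoes a side effect of JSON serialization before calling the constructor:

```python
import json

import tensorflow as tf


class WindowPartition(tf.keras.layers.Layer):
    """Hypothetical layer used only to illustrate the config round trip."""

    def __init__(self, window_size=(7, 7), **kwargs):
        super().__init__(**kwargs)
        self.window_size = window_size

    def get_config(self):
        # Start from the base config (name, trainable, dtype, ...) and add
        # every argument the constructor needs to rebuild this layer.
        config = super().get_config()
        config.update({"window_size": self.window_size})
        return config

    @classmethod
    def from_config(cls, config):
        config = dict(config)
        # A JSON round trip turns tuples into lists; restore the tuple
        # before handing the config to the constructor.
        config["window_size"] = tuple(config["window_size"])
        return super().from_config(config)


layer = WindowPartition(window_size=(8, 8), name="partition")
config = json.loads(json.dumps(layer.get_config()))  # mimic writing to disk
restored = WindowPartition.from_config(config)
print(restored.name, restored.window_size)  # partition (8, 8)
```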

We might also need to add the @keras.saving.register_keras_serializable() decorator to all custom layers to make them properly serializable. Maybe that is the main problem here. Not sure.
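
A minimal sketch of registering a custom layer, assuming a hypothetical ScaledAdd layer and a made-up package name gcvit_demo. It uses tf.keras.utils.register_keras_serializable, the spelling also available in older tf.keras releases (newer releases expose the same decorator as keras.saving.register_keras_serializable). Registration stores the class under a package-qualified name that the loader can resolve without custom_objects:

```python
import tensorflow as tf


# Registering under an explicit package gives the class a unique
# serialized name, "gcvit_demo>ScaledAdd".
@tf.keras.utils.register_keras_serializable(package="gcvit_demo")
class ScaledAdd(tf.keras.layers.Layer):
    """Hypothetical layer that adds a constant to its input."""

    def __init__(self, scale=1.0, **kwargs):
        super().__init__(**kwargs)
        self.scale = scale

    def call(self, inputs):
        return inputs + self.scale

    def get_config(self):
        config = super().get_config()
        config.update({"scale": self.scale})
        return config


name = tf.keras.utils.get_registered_name(ScaledAdd)
print(name)  # gcvit_demo>ScaledAdd
print(tf.keras.utils.get_registered_object(name) is ScaledAdd)  # True
```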

To Reproduce: I have shared a simple gist to reproduce the error. For some reason, it does not reproduce the error, so I believe I am doing something wrong in the simple reproduction script. I will update it if I find out what the problem is.

I simply fine-tuned a classifier with GCViT as the backbone and tried to save and load it. It fails at loading. It also complains during saving, which may be related to this issue.

Error log

Traceback (most recent call last):
  File "apps/train_lowres.py", line 242, in <module>
    main()
  File "apps/train_lowres.py", line 192, in main
    model = tf.keras.models.load_model(model_path + curr_time + "_" + curr_date + "_model_lowres_classifier_" + name, compile=False)
  File "/home/andrep/workspace/bcgrade/venv_latest/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 70, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "/home/andrep/workspace/bcgrade/venv_latest/lib/python3.8/site-packages/keras/engine/training.py", line 3132, in from_config
    raise TypeError(
TypeError: Unable to revive model from config. When overriding the `get_config()`, make sure that the returned config contains all items used as arguments in the constructor to <class 'gcvit.models.gcvit.GCViT'>, which is the default behavior. You can override this default behavior by defining a `from_config` method to specify how to create an instance of GCViT from the config. 

Error encountered during deserialization:
__init__() missing 4 required positional arguments: 'window_size', 'dim', 'depths', and 'num_heads'

There were also these warnings when saving:

WARNING:tensorflow:Skipping full serialization of Keras layer <keras.layers.reshaping.zero_padding2d.ZeroPadding2D object at 0x7f16c0476040>, because it is not built.
WARNING:tensorflow:Skipping full serialization of Keras layer <gcvit.layers.feature.ReduceSize object at 0x7f16c028b3a0>, because it is not built.
WARNING:tensorflow:Skipping full serialization of Keras layer <keras.layers.reshaping.zero_padding2d.ZeroPadding2D object at 0x7f16c022a520>, because it is not built.
WARNING:absl:Found untraced functions such as proj_layer_call_fn, proj_layer_call_and_return_conditional_losses, _jit_compiled_convolution_op, conv_down_layer_call_fn, conv_down_layer_call_and_return_conditional_losses while saving (showing 5 of 1351). These functions will not be directly callable after loading.
WARNING:absl:<gcvit.layers.feature.Resizing object at 0x7f16c061e5e0> has the same name 'Resizing' as a built-in Keras object. Consider renaming <class 'gcvit.layers.feature.Resizing'> to avoid naming conflicts when loading with `tf.keras.models.load_model`. If renaming is not possible, pass the object in the `custom_objects` parameter of the load function.
WARNING:absl:<gcvit.layers.feature.Resizing object at 0x7f16c0529c40> has the same name 'Resizing' as a built-in Keras object. Consider renaming <class 'gcvit.layers.feature.Resizing'> to avoid naming conflicts when loading with `tf.keras.models.load_model`. If renaming is not possible, pass the object in the `custom_objects` parameter of the load function.
WARNING:absl:<gcvit.layers.feature.Resizing object at 0x7f16c0471100> has the same name 'Resizing' as a built-in Keras object. Consider renaming <class 'gcvit.layers.feature.Resizing'> to avoid naming conflicts when loading with `tf.keras.models.load_model`. If renaming is not possible, pass the object in the `custom_objects` parameter of the load function.
WARNING:absl:<gcvit.layers.feature.Resizing object at 0x7f16c02245e0> has the same name 'Resizing' as a built-in Keras object. Consider renaming <class 'gcvit.layers.feature.Resizing'> to avoid nami


andreped commented 1 year ago

The last model saving warning can likely be solved by simply renaming the custom Resizing layer seen here to something else that does not collide with Keras' own built-in layer. Suggestion: Resizer or TensorResizer.
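
A sketch of what such a rename might look like. The thread does not show the signature of gcvit's Resizing layer, so TensorResizer's height, width and method arguments (and the gcvit_demo package name) are assumptions. The point is only that the class name no longer matches the built-in tf.keras.layers.Resizing, and that registering it under a package gives it a distinct serialized name:

```python
import tensorflow as tf


@tf.keras.utils.register_keras_serializable(package="gcvit_demo")
class TensorResizer(tf.keras.layers.Layer):
    """Hypothetical resizing layer whose class name does not collide with
    the built-in tf.keras.layers.Resizing."""

    def __init__(self, height, width, method="bilinear", **kwargs):
        super().__init__(**kwargs)
        self.height = height
        self.width = width
        self.method = method

    def call(self, inputs):
        # Resize the spatial dimensions of an NHWC tensor.
        return tf.image.resize(inputs, (self.height, self.width), method=self.method)

    def get_config(self):
        config = super().get_config()
        config.update({"height": self.height, "width": self.width, "method": self.method})
        return config


layer = TensorResizer(14, 14)
out = layer(tf.zeros((2, 7, 7, 64)))
print(tuple(out.shape))  # (2, 14, 14, 64)
```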

andreped commented 1 year ago

Perhaps I'm not building the backbone properly before training? Right now, I simply do:

import tensorflow as tf
from gcvit import GCViTXXTiny

base = GCViTXXTiny(input_shape=self.instance_size[1:], pretrain=True, resize_query=True)

# remove classifier head and pooling
base = tf.keras.models.Sequential(base.layers[:-2])
base.build(self.instance_size)  # instance_size includes the batch dimension

The base model is then added to a tf.keras.Model through the Functional API. Any ideas?
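
For reference, a minimal sketch of the setup described above, with a hypothetical two-layer Sequential standing in for the truncated GCViT backbone (make_backbone, make_classifier, the 64x64 input and the 4 classes are assumptions). Calling the backbone on an explicit tf.keras.Input builds it with a known shape inside the Functional graph; whether that alone resolves the "not built" warnings for GCViT is untested:

```python
import tensorflow as tf


def make_backbone():
    # Hypothetical stand-in for the truncated GCViT backbone.
    return tf.keras.Sequential(
        [
            tf.keras.layers.Conv2D(64, 3, strides=2, padding="same", activation="relu"),
            tf.keras.layers.Conv2D(128, 3, strides=2, padding="same", activation="relu"),
        ],
        name="backbone",
    )


def make_classifier(input_shape=(64, 64, 3), num_classes=4):
    inputs = tf.keras.Input(shape=input_shape)
    # Calling the backbone on a symbolic Input builds it (and its
    # sublayers) with a known input shape.
    x = make_backbone()(inputs)
    x = tf.keras.layers.GlobalAveragePooling2D()(x)
    outputs = tf.keras.layers.Dense(num_classes, activation="softmax")(x)
    return tf.keras.Model(inputs, outputs, name="classifier")


model = make_classifier()
print(model.get_layer("backbone").built)  # True
print(tuple(model.outputs[0].shape))  # (None, 4)
```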

andreped commented 1 year ago

@awsaf49 Perhaps simply adding default arguments to this class is all that is required: https://github.com/awsaf49/gcvit-tf/blob/371e736ab8a707025facdafbc9a75a4dd9016723/gcvit/models/gcvit.py#L83

That would at least explain why it complained about those four arguments, as they are the only ones in that class without default values.
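
Default values would let from_config() call the constructor, but the model would then likely be rebuilt from the defaults rather than from the trained settings. A hedged alternative sketch: have get_config() return the four required arguments so that from_config() can pass them back. ToyBackbone and its argument values are illustrative only, not GCViT's actual implementation or configuration:

```python
import tensorflow as tf


class ToyBackbone(tf.keras.Model):
    """Hypothetical stand-in for a subclassed model such as GCViT."""

    def __init__(self, window_size, dim, depths, num_heads, **kwargs):
        super().__init__(**kwargs)
        self.window_size = window_size
        self.dim = dim
        self.depths = depths
        self.num_heads = num_heads
        self.stem = tf.keras.layers.Conv2D(dim, 3, padding="same")

    def call(self, inputs):
        return self.stem(inputs)

    def get_config(self):
        # Return every required constructor argument. The dict is built
        # explicitly because Model.get_config() may raise
        # NotImplementedError for subclassed models in older tf.keras.
        return {
            "name": self.name,
            "window_size": self.window_size,
            "dim": self.dim,
            "depths": list(self.depths),
            "num_heads": list(self.num_heads),
        }

    @classmethod
    def from_config(cls, config):
        return cls(**config)


model = ToyBackbone(
    window_size=7, dim=64, depths=[2, 2, 6, 2], num_heads=[2, 4, 8, 16],
    name="toy_backbone",
)
clone = ToyBackbone.from_config(model.get_config())
print(clone.dim, list(clone.depths))  # 64 [2, 2, 6, 2]
```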

andreped commented 1 year ago

I just tried to make a fix for this issue, but I realised that something is broken in the main branch compared to the stable release.

The stable release seems to work fine (with the exception of the aforementioned bug). If I try to use GCViT from the main branch directly, I get this when training starts:

Traceback (most recent call last):
  File "apps/train_lowres.py", line 242, in <module>
    main()
  File "apps/train_lowres.py", line 130, in main
    model = network.create()
  File "/home/andrep/workspace/bcgrade/pypai/models/classifiers.py", line 188, in create
    x = base_model(input_)
  File "/home/andrep/workspace/bcgrade/venv_latest/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 70, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "/tmp/__autograph_generated_fileucbrdl9q.py", line 10, in tf__call
    x = ag__.converted_call(ag__.ld(self).forward_features, (ag__.ld(inputs),), None, fscope)
  File "/tmp/__autograph_generated_filemifzgwsj.py", line 25, in tf__forward_features
    ag__.for_stmt(ag__.ld(self).levels, None, loop_body, get_state, set_state, ('x',), {'iterate_names': 'level'})
  File "/tmp/__autograph_generated_filemifzgwsj.py", line 23, in loop_body
    x = ag__.converted_call(ag__.ld(level), (ag__.ld(x),), None, fscope)
  File "/tmp/__autograph_generated_file2xiefe65.py", line 12, in tf__call
    q_global = ag__.converted_call(ag__.ld(self).q_global_gen, (ag__.ld(x),), None, fscope)
  File "/tmp/__autograph_generated_filehxv07mbz.py", line 24, in tf__call
    ag__.for_stmt(ag__.ld(self).to_q_global, None, loop_body, get_state, set_state, ('x',), {'iterate_names': 'layer'})
  File "/tmp/__autograph_generated_filehxv07mbz.py", line 22, in loop_body
    x = ag__.converted_call(ag__.ld(layer), (ag__.ld(x),), None, fscope)
  File "/tmp/__autograph_generated_file_ipi8lex.py", line 25, in tf__call
    ag__.for_stmt(ag__.ld(self).conv, None, loop_body, get_state, set_state, ('xr',), {'iterate_names': 'layer'})
  File "/tmp/__autograph_generated_file_ipi8lex.py", line 23, in loop_body
    xr = ag__.converted_call(ag__.ld(layer), (ag__.ld(xr),), None, fscope)
  File "/tmp/__autograph_generated_filefg1s1a87.py", line 11, in tf__call
    x = ag__.converted_call(ag__.ld(tf).reshape, (ag__.converted_call(ag__.ld(self).avg_pool, (ag__.ld(inputs),), None, fscope), (ag__.ld(b), ag__.ld(c))), None, fscope)
  File "/tmp/__autograph_generated_filent3lrkbg.py", line 58, in tf__call
    ag__.if_stmt(ag__.and_((lambda : ((ag__.ld(input_shape)[0] % ag__.ld(self).output_size[0]) == 0)), (lambda : ((ag__.ld(input_shape)[1] % ag__.ld(self).output_size[1]) == 0))), if_body_1, else_body_1, get_state_1, set_state_1, ('do_return', 'retval_'), 2)
  File "/tmp/__autograph_generated_filent3lrkbg.py", line 58, in <lambda>
    ag__.if_stmt(ag__.and_((lambda : ((ag__.ld(input_shape)[0] % ag__.ld(self).output_size[0]) == 0)), (lambda : ((ag__.ld(input_shape)[1] % ag__.ld(self).output_size[1]) == 0))), if_body_1, else_body_1, get_state_1, set_state_1, ('do_return', 'retval_'), 2)
TypeError: Exception encountered when calling layer "gcvit_tiny" (type GCViT).

in user code:

    File "/home/andrep/workspace/bcgrade/venv_latest/lib/python3.8/site-packages/gcvit/models/gcvit.py", line 200, in call  *
        x = self.forward_features(inputs)
    File "/home/andrep/workspace/bcgrade/venv_latest/lib/python3.8/site-packages/gcvit/models/gcvit.py", line 187, in forward_features  *
        x = level(x)
    File "/home/andrep/workspace/bcgrade/venv_latest/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 70, in error_handler  **
        raise e.with_traceback(filtered_tb) from None
    File "/tmp/__autograph_generated_file2xiefe65.py", line 12, in tf__call
        q_global = ag__.converted_call(ag__.ld(self).q_global_gen, (ag__.ld(x),), None, fscope)
    File "/tmp/__autograph_generated_filehxv07mbz.py", line 24, in tf__call
        ag__.for_stmt(ag__.ld(self).to_q_global, None, loop_body, get_state, set_state, ('x',), {'iterate_names': 'layer'})
    File "/tmp/__autograph_generated_filehxv07mbz.py", line 22, in loop_body
        x = ag__.converted_call(ag__.ld(layer), (ag__.ld(x),), None, fscope)
    File "/tmp/__autograph_generated_file_ipi8lex.py", line 25, in tf__call
        ag__.for_stmt(ag__.ld(self).conv, None, loop_body, get_state, set_state, ('xr',), {'iterate_names': 'layer'})
    File "/tmp/__autograph_generated_file_ipi8lex.py", line 23, in loop_body
        xr = ag__.converted_call(ag__.ld(layer), (ag__.ld(xr),), None, fscope)
    File "/tmp/__autograph_generated_filefg1s1a87.py", line 11, in tf__call
        x = ag__.converted_call(ag__.ld(tf).reshape, (ag__.converted_call(ag__.ld(self).avg_pool, (ag__.ld(inputs),), None, fscope), (ag__.ld(b), ag__.ld(c))), None, fscope)
    File "/tmp/__autograph_generated_filent3lrkbg.py", line 58, in tf__call
        ag__.if_stmt(ag__.and_((lambda : ((ag__.ld(input_shape)[0] % ag__.ld(self).output_size[0]) == 0)), (lambda : ((ag__.ld(input_shape)[1] % ag__.ld(self).output_size[1]) == 0))), if_body_1, else_body_1, get_state_1, set_state_1, ('do_return', 'retval_'), 2)
    File "/tmp/__autograph_generated_filent3lrkbg.py", line 58, in <lambda>
        ag__.if_stmt(ag__.and_((lambda : ((ag__.ld(input_shape)[0] % ag__.ld(self).output_size[0]) == 0)), (lambda : ((ag__.ld(input_shape)[1] % ag__.ld(self).output_size[1]) == 0))), if_body_1, else_body_1, get_state_1, set_state_1, ('do_return', 'retval_'), 2)

    TypeError: Exception encountered when calling layer 'levels/0' (type GCViTLevel).

    in user code:

        File "/home/andrep/workspace/bcgrade/venv_latest/lib/python3.8/site-packages/gcvit/layers/level.py", line 79, in call  *
            q_global = self.q_global_gen(x)  # (B, H, W, C)
        File "/home/andrep/workspace/bcgrade/venv_latest/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 70, in error_handler  **
            raise e.with_traceback(filtered_tb) from None
        File "/tmp/__autograph_generated_filehxv07mbz.py", line 24, in tf__call
            ag__.for_stmt(ag__.ld(self).to_q_global, None, loop_body, get_state, set_state, ('x',), {'iterate_names': 'layer'})
        File "/tmp/__autograph_generated_filehxv07mbz.py", line 22, in loop_body
            x = ag__.converted_call(ag__.ld(layer), (ag__.ld(x),), None, fscope)
        File "/tmp/__autograph_generated_file_ipi8lex.py", line 25, in tf__call
            ag__.for_stmt(ag__.ld(self).conv, None, loop_body, get_state, set_state, ('xr',), {'iterate_names': 'layer'})
        File "/tmp/__autograph_generated_file_ipi8lex.py", line 23, in loop_body
            xr = ag__.converted_call(ag__.ld(layer), (ag__.ld(xr),), None, fscope)
        File "/tmp/__autograph_generated_filefg1s1a87.py", line 11, in tf__call
            x = ag__.converted_call(ag__.ld(tf).reshape, (ag__.converted_call(ag__.ld(self).avg_pool, (ag__.ld(inputs),), None, fscope), (ag__.ld(b), ag__.ld(c))), None, fscope)
        File "/tmp/__autograph_generated_filent3lrkbg.py", line 58, in tf__call
            ag__.if_stmt(ag__.and_((lambda : ((ag__.ld(input_shape)[0] % ag__.ld(self).output_size[0]) == 0)), (lambda : ((ag__.ld(input_shape)[1] % ag__.ld(self).output_size[1]) == 0))), if_body_1, else_body_1, get_state_1, set_state_1, ('do_return', 'retval_'), 2)
        File "/tmp/__autograph_generated_filent3lrkbg.py", line 58, in <lambda>
            ag__.if_stmt(ag__.and_((lambda : ((ag__.ld(input_shape)[0] % ag__.ld(self).output_size[0]) == 0)), (lambda : ((ag__.ld(input_shape)[1] % ag__.ld(self).output_size[1]) == 0))), if_body_1, else_body_1, get_state_1, set_state_1, ('do_return', 'retval_'), 2)

        TypeError: Exception encountered when calling layer 'q_global_gen' (type GlobalQueryGen).

        in user code:

            File "/home/andrep/workspace/bcgrade/venv_latest/lib/python3.8/site-packages/gcvit/layers/feature.py", line 250, in call  *
                x = layer(x)
            File "/home/andrep/workspace/bcgrade/venv_latest/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 70, in error_handler  **
                raise e.with_traceback(filtered_tb) from None
            File "/tmp/__autograph_generated_file_ipi8lex.py", line 25, in tf__call
                ag__.for_stmt(ag__.ld(self).conv, None, loop_body, get_state, set_state, ('xr',), {'iterate_names': 'layer'})
            File "/tmp/__autograph_generated_file_ipi8lex.py", line 23, in loop_body
                xr = ag__.converted_call(ag__.ld(layer), (ag__.ld(xr),), None, fscope)
            File "/tmp/__autograph_generated_filefg1s1a87.py", line 11, in tf__call
                x = ag__.converted_call(ag__.ld(tf).reshape, (ag__.converted_call(ag__.ld(self).avg_pool, (ag__.ld(inputs),), None, fscope), (ag__.ld(b), ag__.ld(c))), None, fscope)
            File "/tmp/__autograph_generated_filent3lrkbg.py", line 58, in tf__call
                ag__.if_stmt(ag__.and_((lambda : ((ag__.ld(input_shape)[0] % ag__.ld(self).output_size[0]) == 0)), (lambda : ((ag__.ld(input_shape)[1] % ag__.ld(self).output_size[1]) == 0))), if_body_1, else_body_1, get_state_1, set_state_1, ('do_return', 'retval_'), 2)
            File "/tmp/__autograph_generated_filent3lrkbg.py", line 58, in <lambda>
                ag__.if_stmt(ag__.and_((lambda : ((ag__.ld(input_shape)[0] % ag__.ld(self).output_size[0]) == 0)), (lambda : ((ag__.ld(input_shape)[1] % ag__.ld(self).output_size[1]) == 0))), if_body_1, else_body_1, get_state_1, set_state_1, ('do_return', 'retval_'), 2)

            TypeError: Exception encountered when calling layer 'to_q_global/0' (type FeatExtract).

            in user code:

                File "/home/andrep/workspace/bcgrade/venv_latest/lib/python3.8/site-packages/gcvit/layers/feature.py", line 210, in call  *
                    xr = layer(xr)
                File "/home/andrep/workspace/bcgrade/venv_latest/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 70, in error_handler  **
                    raise e.with_traceback(filtered_tb) from None
                File "/tmp/__autograph_generated_filefg1s1a87.py", line 11, in tf__call
                    x = ag__.converted_call(ag__.ld(tf).reshape, (ag__.converted_call(ag__.ld(self).avg_pool, (ag__.ld(inputs),), None, fscope), (ag__.ld(b), ag__.ld(c))), None, fscope)
                File "/tmp/__autograph_generated_filent3lrkbg.py", line 58, in tf__call
                    ag__.if_stmt(ag__.and_((lambda : ((ag__.ld(input_shape)[0] % ag__.ld(self).output_size[0]) == 0)), (lambda : ((ag__.ld(input_shape)[1] % ag__.ld(self).output_size[1]) == 0))), if_body_1, else_body_1, get_state_1, set_state_1, ('do_return', 'retval_'), 2)
                File "/tmp/__autograph_generated_filent3lrkbg.py", line 58, in <lambda>
                    ag__.if_stmt(ag__.and_((lambda : ((ag__.ld(input_shape)[0] % ag__.ld(self).output_size[0]) == 0)), (lambda : ((ag__.ld(input_shape)[1] % ag__.ld(self).output_size[1]) == 0))), if_body_1, else_body_1, get_state_1, set_state_1, ('do_return', 'retval_'), 2)

                TypeError: Exception encountered when calling layer 'conv/2' (type SE).

                in user code:

                    File "/home/andrep/workspace/bcgrade/venv_latest/lib/python3.8/site-packages/gcvit/layers/feature.py", line 82, in call  *
                        x = tf.reshape(self.avg_pool(inputs), (b, c))
                    File "/home/andrep/workspace/bcgrade/venv_latest/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 70, in error_handler  **
                        raise e.with_traceback(filtered_tb) from None
                    File "/tmp/__autograph_generated_filent3lrkbg.py", line 58, in tf__call
                        ag__.if_stmt(ag__.and_((lambda : ((ag__.ld(input_shape)[0] % ag__.ld(self).output_size[0]) == 0)), (lambda : ((ag__.ld(input_shape)[1] % ag__.ld(self).output_size[1]) == 0))), if_body_1, else_body_1, get_state_1, set_state_1, ('do_return', 'retval_'), 2)
                    File "/tmp/__autograph_generated_filent3lrkbg.py", line 58, in <lambda>
                        ag__.if_stmt(ag__.and_((lambda : ((ag__.ld(input_shape)[0] % ag__.ld(self).output_size[0]) == 0)), (lambda : ((ag__.ld(input_shape)[1] % ag__.ld(self).output_size[1]) == 0))), if_body_1, else_body_1, get_state_1, set_state_1, ('do_return', 'retval_'), 2)

                    TypeError: Exception encountered when calling layer 'avg_pool' (type AdaptiveAveragePooling2D).

                    in user code:

                        File "/home/andrep/workspace/bcgrade/venv_latest/lib/python3.8/site-packages/gcvit/layers/pooling.py", line 117, in call  *
                            if (

                        TypeError: unsupported operand type(s) for %: 'NoneType' and 'int'

                    Call arguments received by layer 'avg_pool' (type AdaptiveAveragePooling2D):
                      • inputs=tf.Tensor(shape=(None, None, None, 64), dtype=float16)

                Call arguments received by layer 'conv/2' (type SE):
                  • inputs=tf.Tensor(shape=(None, None, None, 64), dtype=float16)
                  • kwargs={'training': 'None'}

            Call arguments received by layer 'to_q_global/0' (type FeatExtract):
              • inputs=tf.Tensor(shape=(None, None, None, 64), dtype=float16)
              • kwargs={'training': 'None'}

        Call arguments received by layer 'q_global_gen' (type GlobalQueryGen):
          • inputs=tf.Tensor(shape=(None, None, None, 64), dtype=float16)
          • kwargs={'training': 'None'}

    Call arguments received by layer 'levels/0' (type GCViTLevel):
      • inputs=tf.Tensor(shape=(None, 256, 256, 64), dtype=float16)
      • kwargs={'training': 'None'}

Call arguments received by layer "gcvit_tiny" (type GCViT):
  • inputs=tf.Tensor(shape=(None, 1024, 1024, 3), dtype=float16)
  • kwargs={'training': 'None'}
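
The innermost error in the traceback above is "unsupported operand type(s) for %: 'NoneType' and 'int'", raised while the static input shape is (None, None, None, 64). In graph mode an unknown static dimension is None, so Python arithmetic on it fails. A hedged sketch of one common workaround, falling back to tf.shape() for the dynamic size (spatial_dims and divisible_by_7 are hypothetical helpers, not gcvit's pooling code):

```python
import tensorflow as tf


def spatial_dims(x):
    """Return (height, width) of an NHWC tensor, preferring static ints and
    falling back to dynamic scalar tensors when the static size is None."""
    static = x.shape
    dynamic = tf.shape(x)
    h = static[1] if static[1] is not None else dynamic[1]
    w = static[2] if static[2] is not None else dynamic[2]
    return h, w


@tf.function(input_signature=[tf.TensorSpec(shape=(None, None, None, 64))])
def divisible_by_7(x):
    h, w = spatial_dims(x)
    # Under this signature h and w are tensors, so % and == are traced as
    # TensorFlow ops instead of failing on None.
    return tf.logical_and(tf.equal(h % 7, 0), tf.equal(w % 7, 0))


print(spatial_dims(tf.zeros((1, 14, 21, 64))))  # (14, 21)
print(bool(divisible_by_7(tf.zeros((1, 14, 21, 64)))))  # True
```
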
andreped commented 1 year ago

Something might have broken in this commit: https://github.com/awsaf49/gcvit-tf/commit/720adf7e43e875daa880ab5588d16612f0f4dd6a

andreped commented 1 year ago

I reverted the changes from the aforementioned commit and set default arguments for both the GCViT and GCViTLevel classes. See my fork here.

Training now starts, and after the first epoch I save the model to disk and reload it. This results in a new error:

Traceback (most recent call last):
  File "apps/train_lowres.py", line 242, in <module>
    main()
  File "apps/train_lowres.py", line 192, in main
    model = tf.keras.models.load_model(model_path + curr_time + "_" + curr_date + "_model_lowres_classifier_" + name, compile=False)
  File "/home/andrep/workspace/bcgrade/venv_latest/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 70, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "/home/andrep/workspace/bcgrade/venv_latest/lib/python3.8/site-packages/tensorflow/python/training/saving/saveable_object_util.py", line 139, in restore
    raise ValueError(
ValueError: Received incompatible tensor with shape (3, 3, 3, 64) when attempting to restore variable with shape (3, 3, 3, 128) and name variables/0/.ATTRIBUTES/VARIABLE_VALUE.

It might be that setting default arguments for the four variables I mentioned breaks something, but I could not really see why.
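
One possible reading of this error, offered with caution: (3, 3, 3, 64) is the kernel shape of a 3x3 convolution on 3 input channels with 64 filters, while the freshly constructed model expects 128 filters. That is what you would see if the loader rebuilt the model from default arguments that differ from the trained configuration. The thread does not show the fork's defaults, so StemOnly and its dim=128 default are purely illustrative:

```python
import tensorflow as tf


class StemOnly(tf.keras.Model):
    """Hypothetical model whose stem width comes from a constructor argument."""

    def __init__(self, dim=128, **kwargs):  # illustrative default
        super().__init__(**kwargs)
        self.stem = tf.keras.layers.Conv2D(dim, 3)

    def call(self, inputs):
        return self.stem(inputs)


trained = StemOnly(dim=64)  # architecture the weights were saved from
rebuilt = StemOnly()        # what a loader gets if `dim` is missing from the config
for m in (trained, rebuilt):
    m(tf.zeros((1, 32, 32, 3)))  # build the variables for RGB input

# Conv2D kernels are (kernel_h, kernel_w, in_channels, filters).
print(tuple(trained.stem.kernel.shape), tuple(rebuilt.stem.kernel.shape))
# (3, 3, 3, 64) (3, 3, 3, 128)
```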

I would rather suspect that resize_query=True, which we set to support different input sizes to the network, does not work properly when the model is loaded from disk afterwards. But this is only speculation.

awsaf49 commented 1 year ago

@andreped Apologies, but I'm currently tied up with something important. I'll be available after September 24th to address this matter. Thanks for your patience regarding the delay.

awsaf49 commented 1 year ago

I do want to point out, though, that a few months ago I used GCViT in the IEEE VIP Cup 2022 with different image sizes, and saving and loading worked fine. I used the SavedModel format instead of .h5. So either something has changed since then, or we need to save/load the model differently.

andreped commented 1 year ago

It could be that the TF version I am using is not compatible. I am also using an older protobuf, and I am subclassing the model. It could be lots of things. I can try to debug it further tomorrow and see if I can make a reproducible gist.

Anyway, I don't expect you to look at it until you have time.

awsaf49 commented 1 year ago

Might be that something broke in the commit: 720adf7

You are right, this commit broke the code. I've reverted it.

awsaf49 commented 1 year ago

I have updated the code so that the model saves and loads properly. It used to work fine, but I think a TF version change caused this issue. Anyway, I used the following code to save and load the model.

save:

import tensorflow as tf

save_locally = tf.saved_model.SaveOptions(experimental_io_device='/job:localhost')
model.save('./checkpoint', options=save_locally)  # saving in TensorFlow's "SavedModel" format

load:

# with strategy.scope():  # only needed on TPU
load_locally = tf.saved_model.LoadOptions(experimental_io_device='/job:localhost')
model2 = tf.keras.models.load_model('./checkpoint', options=load_locally)  # loading in TensorFlow's "SavedModel" format

You can also check it in the GCViT - Starter Notebook. I have also added a dummy training run before saving and loading the model, using a non-default image size (128, 128, 3) with resize_query=True.
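
A self-contained round trip in the same spirit: a model with a non-default 128x128 input and a custom layer is saved, reloaded, and checked for matching predictions (no training step here). Note that this sketch swaps in the newer Keras .keras file format rather than the SavedModel directory used above, since recent Keras releases require .keras (or .h5) paths in model.save(). Scale and the gcvit_demo package name are hypothetical:

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


@tf.keras.utils.register_keras_serializable(package="gcvit_demo")
class Scale(tf.keras.layers.Layer):
    """Hypothetical custom layer with a trainable per-channel scale."""

    def __init__(self, init_value=1.0, **kwargs):
        super().__init__(**kwargs)
        self.init_value = init_value

    def build(self, input_shape):
        self.gamma = self.add_weight(
            name="gamma",
            shape=(input_shape[-1],),
            initializer=tf.keras.initializers.Constant(self.init_value),
        )
        super().build(input_shape)

    def call(self, inputs):
        return inputs * self.gamma

    def get_config(self):
        config = super().get_config()
        config.update({"init_value": self.init_value})
        return config


# Non-default 128x128 input, mirroring the notebook setup described above.
inputs = tf.keras.Input(shape=(128, 128, 3))
x = tf.keras.layers.Conv2D(8, 3, padding="same")(inputs)
x = Scale(0.5)(x)
outputs = tf.keras.layers.GlobalAveragePooling2D()(x)
model = tf.keras.Model(inputs, outputs)

path = os.path.join(tempfile.mkdtemp(), "model.keras")
model.save(path)
# No custom_objects needed: Scale is resolved through the registry.
reloaded = tf.keras.models.load_model(path)

batch = np.random.rand(2, 128, 128, 3).astype("float32")
np.testing.assert_allclose(
    model.predict(batch, verbose=0), reloaded.predict(batch, verbose=0), rtol=1e-5, atol=1e-6
)
```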

andreped commented 1 year ago

@awsaf49 Why is it necessary to use tf.saved_model.SaveOptions and tf.saved_model.LoadOptions when saving and loading the model? I have not used them before.

awsaf49 commented 1 year ago

It's actually not necessary to use these options, at least not on GPU/CPU. However, on TPU the model runs on a VM, so the options are required to save and load the model locally.
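
Based on that explanation, one way to apply the options only where they are needed is to check for a TPU first. This is a sketch: local_io_options and its force flag are hypothetical, and the returned objects are meant for the SavedModel-format save/load calls shown earlier in the thread (where passing None is the same as omitting options):

```python
import tensorflow as tf


def local_io_options(force=False):
    """Return (SaveOptions, LoadOptions) that route file I/O through the
    local host, or (None, None) when no TPU is visible and force is False."""
    on_tpu = bool(tf.config.list_logical_devices("TPU"))
    if not (on_tpu or force):
        return None, None
    device = "/job:localhost"
    return (
        tf.saved_model.SaveOptions(experimental_io_device=device),
        tf.saved_model.LoadOptions(experimental_io_device=device),
    )


save_opts, load_opts = local_io_options()
# On CPU/GPU both are None, so the options can be passed unconditionally.
```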