joshcarty / tfgnn-ogb

Examples using TensorFlow GNN with Open Graph Benchmark datasets.
MIT License

Can't reload model to train #5

Closed SidneyLann closed 2 years ago

SidneyLann commented 2 years ago

Logs for tf.keras.models.save_model(gnn, save_path) or tf.saved_model.save(gnn, save_path):

/home/sidney/Python3102/lib/python3.10/site-packages/tensorflow/python/framework/indexed_slices.py:444: UserWarning: Converting sparse IndexedSlices(IndexedSlices(indices=Tensor("gradient_tape/node_classification_model/readout_first_node/Reshape_1:0", shape=(128,), dtype=int32), values=Tensor("gradient_tape/node_classification_model/readout_first_node/Reshape:0", shape=(128, 32), dtype=float32), dense_shape=Tensor("gradient_tape/node_classification_model/readout_first_node/Cast:0", shape=(2,), dtype=int32))) to a dense Tensor of unknown shape. This may consume a large amount of memory.
  warnings.warn(
/home/sidney/Python3102/lib/python3.10/site-packages/tensorflow/python/framework/indexed_slices.py:444: UserWarning: Converting sparse IndexedSlices(IndexedSlices(indices=Tensor("gradient_tape/node_classification_model/graph_update_1/node_set_update/gat_v2_convolution/Reshape_3:0", shape=(None,), dtype=int64), values=Tensor("gradient_tape/node_classification_model/graph_update_1/node_set_update/gat_v2_convolution/Reshape_2:0", shape=(None, 4, 1), dtype=float32), dense_shape=Tensor("gradient_tape/node_classification_model/graph_update_1/node_set_update/gat_v2_convolution/Cast:0", shape=(3,), dtype=int32))) to a dense Tensor of unknown shape. This may consume a large amount of memory.
  warnings.warn(
/home/sidney/Python3102/lib/python3.10/site-packages/tensorflow/python/framework/indexed_slices.py:444: UserWarning: Converting sparse IndexedSlices(IndexedSlices(indices=Tensor("gradient_tape/node_classification_model/graph_update_1/node_set_update/gat_v2_convolution/Reshape_6:0", shape=(None,), dtype=int64), values=Tensor("gradient_tape/node_classification_model/graph_update_1/node_set_update/gat_v2_convolution/Reshape_5:0", shape=(None, 4, 1), dtype=float32), dense_shape=Tensor("gradient_tape/node_classification_model/graph_update_1/node_set_update/gat_v2_convolution/Cast_1:0", shape=(3,), dtype=int32))) to a dense Tensor of unknown shape. This may consume a large amount of memory.
  warnings.warn(
/home/sidney/Python3102/lib/python3.10/site-packages/tensorflow/python/framework/indexed_slices.py:444: UserWarning: Converting sparse IndexedSlices(IndexedSlices(indices=Tensor("gradient_tape/node_classification_model/graph_update/node_set_update/gat_v2_convolution/Reshape_3:0", shape=(None,), dtype=int64), values=Tensor("gradient_tape/node_classification_model/graph_update/node_set_update/gat_v2_convolution/Reshape_2:0", shape=(None, 4, 1), dtype=float32), dense_shape=Tensor("gradient_tape/node_classification_model/graph_update/node_set_update/gat_v2_convolution/Cast:0", shape=(3,), dtype=int32))) to a dense Tensor of unknown shape. This may consume a large amount of memory.
  warnings.warn(
/home/sidney/Python3102/lib/python3.10/site-packages/tensorflow/python/framework/indexed_slices.py:444: UserWarning: Converting sparse IndexedSlices(IndexedSlices(indices=Tensor("gradient_tape/node_classification_model/graph_update/node_set_update/gat_v2_convolution/Reshape_6:0", shape=(None,), dtype=int64), values=Tensor("gradient_tape/node_classification_model/graph_update/node_set_update/gat_v2_convolution/Reshape_5:0", shape=(None, 4, 1), dtype=float32), dense_shape=Tensor("gradient_tape/node_classification_model/graph_update/node_set_update/gat_v2_convolution/Cast_1:0", shape=(3,), dtype=int32))) to a dense Tensor of unknown shape. This may consume a large amount of memory.
  warnings.warn(
711/711 [==============================] - 2571s 4s/step - loss: 2.2800 - accuracy: 0.3815 - top_5: 0.7001 - val_loss: 1.8160 - val_accuracy: 0.4884 - val_top_5: 0.8237
2022-03-23 07:25:10.295444: W tensorflow/python/util/util.cc:368] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them.
/home/sidney/Python3102/lib/python3.10/site-packages/tensorflow/python/framework/indexed_slices.py:444: UserWarning: Converting sparse IndexedSlices(IndexedSlices(indices=Tensor("gradient_tape/node_classification_model/readout_first_node/Reshape_1:0", shape=(128,), dtype=int32), values=Tensor("gradient_tape/node_classification_model/readout_first_node/Reshape:0", shape=(128, 32), dtype=float32), dense_shape=Tensor("gradient_tape/node_classification_model/readout_first_node/Cast:0", shape=(2,), dtype=int32))) to a dense Tensor of unknown shape. This may consume a large amount of memory.
  warnings.warn(
/home/sidney/Python3102/lib/python3.10/site-packages/tensorflow/python/framework/indexed_slices.py:444: UserWarning: Converting sparse IndexedSlices(IndexedSlices(indices=Tensor("gradient_tape/node_classification_model/graph_update_1/node_set_update/gat_v2_convolution/Reshape_3:0", shape=(None,), dtype=int64), values=Tensor("gradient_tape/node_classification_model/graph_update_1/node_set_update/gat_v2_convolution/Reshape_2:0", shape=(None, 4, 1), dtype=float32), dense_shape=Tensor("gradient_tape/node_classification_model/graph_update_1/node_set_update/gat_v2_convolution/Cast:0", shape=(3,), dtype=int32))) to a dense Tensor of unknown shape. This may consume a large amount of memory.
  warnings.warn(
/home/sidney/Python3102/lib/python3.10/site-packages/tensorflow/python/framework/indexed_slices.py:444: UserWarning: Converting sparse IndexedSlices(IndexedSlices(indices=Tensor("gradient_tape/node_classification_model/graph_update_1/node_set_update/gat_v2_convolution/Reshape_6:0", shape=(None,), dtype=int64), values=Tensor("gradient_tape/node_classification_model/graph_update_1/node_set_update/gat_v2_convolution/Reshape_5:0", shape=(None, 4, 1), dtype=float32), dense_shape=Tensor("gradient_tape/node_classification_model/graph_update_1/node_set_update/gat_v2_convolution/Cast_1:0", shape=(3,), dtype=int32))) to a dense Tensor of unknown shape. This may consume a large amount of memory.
  warnings.warn(
/home/sidney/Python3102/lib/python3.10/site-packages/tensorflow/python/framework/indexed_slices.py:444: UserWarning: Converting sparse IndexedSlices(IndexedSlices(indices=Tensor("gradient_tape/node_classification_model/graph_update/node_set_update/gat_v2_convolution/Reshape_3:0", shape=(None,), dtype=int64), values=Tensor("gradient_tape/node_classification_model/graph_update/node_set_update/gat_v2_convolution/Reshape_2:0", shape=(None, 4, 1), dtype=float32), dense_shape=Tensor("gradient_tape/node_classification_model/graph_update/node_set_update/gat_v2_convolution/Cast:0", shape=(3,), dtype=int32))) to a dense Tensor of unknown shape. This may consume a large amount of memory.
  warnings.warn(
/home/sidney/Python3102/lib/python3.10/site-packages/tensorflow/python/framework/indexed_slices.py:444: UserWarning: Converting sparse IndexedSlices(IndexedSlices(indices=Tensor("gradient_tape/node_classification_model/graph_update/node_set_update/gat_v2_convolution/Reshape_6:0", shape=(None,), dtype=int64), values=Tensor("gradient_tape/node_classification_model/graph_update/node_set_update/gat_v2_convolution/Reshape_5:0", shape=(None, 4, 1), dtype=float32), dense_shape=Tensor("gradient_tape/node_classification_model/graph_update/node_set_update/gat_v2_convolution/Cast_1:0", shape=(3,), dtype=int32))) to a dense Tensor of unknown shape. This may consume a large amount of memory.
  warnings.warn(
WARNING:absl:Function _wrapped_model contains input name(s) args_0 with unsupported characters which will be renamed to args_0_6 in the SavedModel.
WARNING:absl:Found untraced functions such as node_set_update_layer_call_fn, node_set_update_layer_call_and_return_conditional_losses, node_set_update_layer_call_fn, node_set_update_layer_call_and_return_conditional_losses, next_state_from_concat_layer_call_fn while saving (showing 5 of 28). These functions will not be directly callable after loading.
/home/sidney/Python3102/lib/python3.10/site-packages/tensorflow/python/saved_model/nested_structure_coder.py:522: UserWarning: Encoding a StructuredValue with type tensorflow_gnn.GraphTensorSpec; loading this StructuredValue will require that this type be imported and registered.
  warnings.warn("Encoding a StructuredValue with type %s; loading this "
/home/sidney/Python3102/lib/python3.10/site-packages/tensorflow/python/saved_model/nested_structure_coder.py:522: UserWarning: Encoding a StructuredValue with type tensorflow_gnn.ContextSpec.v2; loading this StructuredValue will require that this type be imported and registered.
  warnings.warn("Encoding a StructuredValue with type %s; loading this "
/home/sidney/Python3102/lib/python3.10/site-packages/tensorflow/python/saved_model/nested_structure_coder.py:522: UserWarning: Encoding a StructuredValue with type tensorflow_gnn.NodeSetSpec; loading this StructuredValue will require that this type be imported and registered.
  warnings.warn("Encoding a StructuredValue with type %s; loading this "
/home/sidney/Python3102/lib/python3.10/site-packages/tensorflow/python/saved_model/nested_structure_coder.py:522: UserWarning: Encoding a StructuredValue with type tensorflow_gnn.EdgeSetSpec; loading this StructuredValue will require that this type be imported and registered.
  warnings.warn("Encoding a StructuredValue with type %s; loading this "
/home/sidney/Python3102/lib/python3.10/site-packages/tensorflow/python/saved_model/nested_structure_coder.py:522: UserWarning: Encoding a StructuredValue with type tensorflow_gnn.AdjacencySpec; loading this StructuredValue will require that this type be imported and registered.
  warnings.warn("Encoding a StructuredValue with type %s; loading this "

Process finished with exit code 0

SidneyLann commented 2 years ago

Logs for gnn = tf.keras.models.load_model(save_path):

WARNING:absl:Importing a function (__inference_internal_grad_fn_3270131) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
/home/sidney/.pycharm_helpers/pydev/_pydevd_bundle/pydevd_resolver.py:166: UserWarning: Model.state_updates will be removed in a future version. This property should not be used in TensorFlow 2.0, as updates are applied automatically.
    attr = getattr(var, n)
2022-03-23 07:43:39.475666: W tensorflow/python/util/util.cc:368] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them.
/home/sidney/.pycharm_helpers/pydev/_pydevd_bundle/pydevd_resolver.py:166: UserWarning: layer.updates will be removed in a future version. This property should not be used in TensorFlow 2.0, as updates are applied automatically.
    attr = getattr(var, n)
/home/sidney/.pycharm_helpers/pydev/_pydevd_bundle/pydevd_resolver.py:71: UserWarning: Model.state_updates will be removed in a future version. This property should not be used in TensorFlow 2.0, as updates are applied automatically.
    return getattr(var, attribute)
/home/sidney/.pycharm_helpers/pydev/_pydevd_bundle/pydevd_resolver.py:71: UserWarning: layer.updates will be removed in a future version. This property should not be used in TensorFlow 2.0, as updates are applied automatically.
    return getattr(var, attribute)
Traceback (most recent call last):
File "/home/sidney/.pycharm_helpers/pydev/pydevd.py", line 1741, in <module>
    main()
File "/home/sidney/.pycharm_helpers/pydev/pydevd.py", line 1735, in main
    globals = debugger.run(setup['file'], None, None, is_module)
File "/home/sidney/.pycharm_helpers/pydev/pydevd.py", line 1135, in run
    pydev_imports.execfile(file, globals, locals) # execute the script
File "/home/sidney/.pycharm_helpers/pydev/_pydev_imps/_pydev_execfile.py", line 18, in execfile
    exec(compile(contents+"\n", file, 'exec'), glob, loc)
File "/home/sidney/py_proj/pcng-idea-kg/gat/main.py", line 102, in <module>
    main()
File "/home/sidney/py_proj/pcng-idea-kg/gat/main.py", line 70, in main
    gnn.fit(
File "/home/sidney/Python3102/lib/python3.10/site-packages/keras/utils/traceback_utils.py", line 67, in error_handler
    raise e.with_traceback(filtered_tb) from None
File "/home/sidney/Python3102/lib/python3.10/site-packages/tensorflow/python/framework/func_graph.py", line 1147, in autograph_handler
    raise e.ag_error_metadata.to_exception(e)
ValueError: in user code:

File "/home/sidney/Python3102/lib/python3.10/site-packages/keras/engine/training.py", line 1021, in train_function  *
    return step_function(self, iterator)
File "/home/sidney/Python3102/lib/python3.10/site-packages/keras/engine/training.py", line 1010, in step_function  **
    outputs = model.distribute_strategy.run(run_step, args=(data,))
File "/home/sidney/Python3102/lib/python3.10/site-packages/keras/engine/training.py", line 1000, in run_step  **
    outputs = model.train_step(data)
File "/home/sidney/Python3102/lib/python3.10/site-packages/keras/engine/training.py", line 861, in train_step
    self._validate_target_and_loss(y, loss)
File "/home/sidney/Python3102/lib/python3.10/site-packages/keras/engine/training.py", line 818, in _validate_target_and_loss
    raise ValueError(

ValueError: Target data is missing. Your model was compiled with loss=<keras.losses.SparseCategoricalCrossentropy object at 0x7f8bd5a6c040>, and therefore expects target data to be provided in `fit()`.

Process finished with exit code 1

SidneyLann commented 2 years ago

Logs for gnn = tf.saved_model.load(save_path):

WARNING:absl:Importing a function (__inference_internal_grad_fn_3270131) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
Traceback (most recent call last):
File "/home/sidney/py_proj/pcng-idea-kg/gat/main.py", line 102, in <module>
    main()
File "/home/sidney/py_proj/pcng-idea-kg/gat/main.py", line 70, in main
    gnn.fit(
AttributeError: '_UserObject' object has no attribute 'fit'

Process finished with exit code 1

SidneyLann commented 2 years ago
@tf.function(input_signature=[......])
def train_step(self, data: tf.data.Dataset) -> Dict[str, float]:

Does the signature need to be set? If so, what should the values be?

SidneyLann commented 2 years ago

The model can only be reloaded as a Functional model or a _UserObject, not as a NodeClassificationModel. The reloaded object is therefore a different type from the initial NodeClassificationModel, so it can't be trained using the methods defined in NodeClassificationModel.

SidneyLann commented 2 years ago

self.target_node = kwargs.get('target_node', "paper")  # kwargs.pop("target_node")
self.label_name = kwargs.get('label_name', "label")  # kwargs.pop("label_name")

Amend these two lines to store target_node and label_name, and the model can then be reloaded for training. The model can be reloaded and trained now!
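One possible reading of why this change helps, offered as an assumption rather than something confirmed in this thread: if reloading rebuilds the model by calling its constructor without `target_node` and `label_name`, then `kwargs.pop("target_node")` with no default raises `KeyError`, whereas `kwargs.get('target_node', "paper")` falls back to a default. The plain-Python sketch below shows only that difference. `PopModel`, `GetModel`, and `try_rebuild` are hypothetical stand-ins, not the real NodeClassificationModel or any Keras loading code.

```python
class PopModel:
    """Hypothetical stand-in whose constructor requires both keyword arguments."""

    def __init__(self, **kwargs):
        # pop() without a default raises KeyError when the key is absent.
        self.target_node = kwargs.pop("target_node")
        self.label_name = kwargs.pop("label_name")


class GetModel:
    """Hypothetical stand-in whose constructor falls back to defaults."""

    def __init__(self, **kwargs):
        # get() with a default never raises for a missing key.
        self.target_node = kwargs.get("target_node", "paper")
        self.label_name = kwargs.get("label_name", "label")


def try_rebuild(cls):
    """Call the constructor with no arguments, as a loader might (assumption)."""
    try:
        return cls()
    except KeyError as err:
        return f"KeyError: {err}"


pop_result = try_rebuild(PopModel)
get_result = try_rebuild(GetModel)
print(pop_result)
print(get_result.target_node, get_result.label_name)
```

If this reading is right, a reloaded model would silently use "paper" and "label" whenever the original values are not passed back in, so it may be worth checking those two attributes after loading.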

joshcarty commented 2 years ago

Interesting, pleased that you got this working. Not sure I fully understand why the solution works but great that it does.

> Does the signature need to be set? If so, what should the values be?

The model trains for me without an input_signature being set. However, there are warnings around Tensor types and shapes that it might help with. You would pass in the TensorSpec (GraphTensorSpec) that describes your input data, for example the GraphTensorSpec in schema.TYPE_SPECS['ogbn-arxiv'].
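As a rough illustration of what passing a spec to `input_signature` looks like, here is a minimal sketch that uses a plain `tf.TensorSpec` in place of the GraphTensorSpec. The function `row_sums` and the feature shape are made up for the example; in this repo the spec from schema.TYPE_SPECS would presumably take the place of `feature_spec`.

```python
import tensorflow as tf

# Hypothetical spec: a batch of unknown size with 4 float features per row.
# A GraphTensorSpec describing the input graphs would play this role instead.
feature_spec = tf.TensorSpec(shape=[None, 4], dtype=tf.float32)


@tf.function(input_signature=[feature_spec])
def row_sums(features):
    # The function is traced for the declared signature, so inputs with
    # different batch sizes reuse the same traced graph.
    return tf.reduce_sum(features, axis=1)


small = row_sums(tf.ones([2, 4]))
large = row_sums(tf.ones([5, 4]))
print(small.shape, large.shape)
```

Declaring the signature fixes the accepted dtype and rank up front, which is why it might quieten some of the shape-related warnings. It does not by itself change which class the model is reloaded as.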

SidneyLann commented 2 years ago

> The model trains for me without an input_signature being set.

But a model saved with tf.saved_model.save can't be trained after it is reloaded.

Without an input_signature, a model saved with tf.keras.models.save_model can be trained after it is reloaded.