Andrei997 opened this issue 3 years ago
Update: changing `attn_output_weights += attn_mask` to `attn_output_weights += tf.cast(attn_mask, tf.float32)` at line 320 of `transformer.py` fixes the initial problem, but another problem appears when exporting the model.
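The `AddV2` error comes from adding a `bool` mask to `float32` attention logits. Note that `tf.cast(attn_mask, tf.float32)` turns `True`/`False` into `1.0`/`0.0`. That removes the dtype error but does not actually mask anything: adding `1.0` to a logit only multiplies its unnormalized softmax weight by e and never drives it to zero. The sketch below assumes (not confirmed from the repo) that a boolean `attn_mask` uses `True` for positions to ignore; `apply_attention_mask` is a hypothetical helper, not part of `transformer.py`.

```python
import tensorflow as tf


def apply_attention_mask(scores, mask):
    """Hypothetical helper: add an attention mask to float logits.

    Assumes a bool mask uses True for positions to ignore, and that a
    float mask is already additive (0.0 to keep, large negative to ignore).
    """
    if mask is None:
        return scores
    if mask.dtype == tf.bool:
        # True -> large negative logit, so softmax gives it ~0 weight
        neg = tf.constant(-1e9, dtype=scores.dtype)
        zero = tf.constant(0.0, dtype=scores.dtype)
        return scores + tf.where(mask, neg, zero)
    # Already additive: only align the dtype with the logits
    return scores + tf.cast(mask, scores.dtype)


scores = tf.zeros((1, 2, 2))
bool_mask = tf.constant([[[False, True], [False, False]]])

masked = tf.nn.softmax(apply_attention_mask(scores, bool_mask), axis=-1)
naive = tf.nn.softmax(scores + tf.cast(bool_mask, tf.float32), axis=-1)
print(masked.numpy()[0, 0])  # masked position gets ~0 weight
print(naive.numpy()[0, 0])   # masked position gets MORE weight than the other
```

If `attn_mask` is always `None` in practice, the cast is harmless; the difference only matters once a real mask is passed in.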
The new error is `ValueError: None values not supported.`, and it seems to come from `query_tgt = key_tgt = target + query_encoding` at line 211 of `transformer.py`. Any ideas?
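`None values not supported.` is typically what TensorFlow reports when it has to convert `None` into a tensor in graph mode, which suggests that `query_encoding` (or `target`) is `None` when the model is traced for export. The original PyTorch DETR guards this with a `with_pos_embed` helper that adds the positional encoding only when it is present. A minimal TensorFlow sketch of that guard, assuming `query_encoding` can legitimately be `None` on this code path:

```python
from typing import Optional

import tensorflow as tf


def with_pos_embed(tensor: tf.Tensor, pos: Optional[tf.Tensor]) -> tf.Tensor:
    """Add a positional encoding only when one is provided."""
    return tensor if pos is None else tensor + pos


target = tf.zeros((2, 4))
query_encoding = tf.ones((2, 4))

# Replacement for: query_tgt = key_tgt = target + query_encoding
query_tgt = key_tgt = with_pos_embed(target, query_encoding)
no_pos = with_pos_embed(target, None)  # returns target unchanged instead of failing
```

Whether skipping the encoding is correct depends on why it is `None` at trace time; if the decoder always needs it, the real fix belongs where `query_encoding` is created.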
Update: it seems possible to save the model to a `.pb` file (frozen graph), though I haven't yet tried loading it and using it for inference.
For reference, I managed to save the model with:
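```python
import tensorflow as tf
from detr_tf.networks.detr import get_detr_model  # module path assumed from the detr-tensorflow repo

g = tf.Graph()
with g.as_default():
    # `config` is the training config object used elsewhere in the script
    model = get_detr_model(config, include_top=False, nb_class=2, weights="detr")
    # Dummy forward pass so the whole network is built into the graph
    dummy_input = tf.random.uniform((1, 512, 512, 3), minval=0, maxval=1, dtype=tf.float32)
    output = model(dummy_input)
    # Writes the graph structure only; variable values are not converted to constants
    tf.io.write_graph(g.as_graph_def(), './exports', 'model.pb', as_text=False)
```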
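`g.as_graph_def()` serializes the ops in the graph; variables show up as ops, and their trained values are not stored in the `GraphDef`, so the `.pb` above probably does not carry the weights. A common TF2 way to actually freeze a model is `convert_variables_to_constants_v2`, which lives in the non-public module `tensorflow.python.framework.convert_to_constants` and may therefore change between releases. A minimal sketch on a small stand-in Keras model; swapping in the DETR model with a `[1, 512, 512, 3]` input spec is an untested assumption:

```python
import tensorflow as tf
from tensorflow.python.framework.convert_to_constants import (
    convert_variables_to_constants_v2,
)

# Hypothetical stand-in for the DETR model returned by get_detr_model(...)
model = tf.keras.Sequential([tf.keras.Input(shape=(8,)), tf.keras.layers.Dense(4)])

# Trace the forward pass once with a fixed input spec
forward = tf.function(lambda x: model(x))
concrete = forward.get_concrete_function(tf.TensorSpec([1, 8], tf.float32))

# Replace variable reads with constants holding the current weight values
frozen = convert_variables_to_constants_v2(concrete)
graph_def = frozen.graph.as_graph_def()

pb_path = tf.io.write_graph(graph_def, "./exports", "frozen_model.pb", as_text=False)
```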
It would still be nice to save the model the recommended way, using `model.save(...)`, so any suggestions on that would be appreciated.
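Once the dtype and `None` problems are fixed, a SavedModel for TensorFlow Serving or Triton can also be written with `tf.saved_model.save` and an explicit serving signature, so the model is traced for exactly one input spec. A minimal sketch with a small stand-in model; `ServingModule`, the input name `images` and the output key `outputs` are hypothetical, and for DETR the spec would presumably be float32 images such as `[None, None, None, 3]`:

```python
import tensorflow as tf


class ServingModule(tf.Module):
    """Hypothetical wrapper exposing a single, fixed-signature serving function."""

    def __init__(self, model):
        super().__init__()
        self.model = model  # tracked attribute, so the model's variables are saved too

    @tf.function(input_signature=[tf.TensorSpec([None, 8], tf.float32, name="images")])
    def serve(self, images):
        return {"outputs": self.model(images, training=False)}


# Stand-in for the DETR model returned by get_detr_model(...)
model = tf.keras.Sequential([tf.keras.Input(shape=(8,)), tf.keras.layers.Dense(4)])
module = ServingModule(model)

export_dir = "./exports/saved_model/1"  # TF Serving expects a numeric version subdirectory
tf.saved_model.save(module, export_dir, signatures={"serving_default": module.serve})

# Reload without the original Python classes to check the signature works
loaded = tf.saved_model.load(export_dir)
result = loaded.signatures["serving_default"](images=tf.ones((2, 8)))
```

Wrapping the model in a plain `tf.Module` keeps the exported interface limited to the one signature, rather than whatever call variants Keras traces during `model.save(...)`.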
I use the following approach to save the model's weights to a checkpoint and load them again:
```python
import os

os.makedirs("/data/W2FORMS/", exist_ok=True)
checkpoint_path = "/data/W2FORMS/detr.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)

for epoch in range(0, 600):
    training.fit(detr, dataset, optimzers, config, epoch_nb=epoch, class_names=CLASS_NAMES)
    detr.save_weights(checkpoint_path)  # overwrite the checkpoint after every epoch
```
Then, to load these weights back into the model:
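The usual Keras counterpart of `save_weights` is `load_weights` on a freshly rebuilt model with the same architecture. A minimal round-trip sketch with a small stand-in model (`build_model` is hypothetical; for DETR it would be `get_detr_model(...)` with the training arguments). It uses a `.weights.h5` path because Keras 3 requires that suffix for `save_weights`; the `.ckpt` path above works with Keras 2 (TF 2.15 and earlier).

```python
import os

import numpy as np
import tensorflow as tf


def build_model():
    """Hypothetical stand-in for get_detr_model(...): rebuild the same architecture."""
    return tf.keras.Sequential([tf.keras.Input(shape=(8,)), tf.keras.layers.Dense(4)])


os.makedirs("./checkpoints", exist_ok=True)
weights_path = "./checkpoints/detr.weights.h5"

trained = build_model()
trained.save_weights(weights_path)

# Later, e.g. in an inference script: rebuild the model, then restore the weights
restored = build_model()
restored.load_weights(weights_path)

x = tf.ones((1, 8))
same_outputs = np.allclose(trained(x).numpy(), restored(x).numpy())
print(same_outputs)
```

This restores weights into Python code that can rebuild the model, which is why it does not replace a SavedModel for serving tools.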
Hi, and first of all, thank you for your contribution. In short, I cannot save a trained model in the SavedModel format; it fails with the following error:
```
TypeError: Input 'y' of 'AddV2' Op has type bool that does not match type float32 of argument 'x'
```
The code I'm using to save the model is:
```python
model.save("./exports", include_optimizer=False)
```
I should emphasize that I do not want to save the model via `save_weights`, because I need the SavedModel format for inference with tools like TensorFlow Serving and Triton Inference Server. Any help would be much appreciated!
Problem: Cannot save the model via `model.save(...)` or `tf.saved_model.save(...)`.

Minimal steps to reproduce the issue