kappazeta / km_predict

S2 full image prediction
Apache License 2.0

Possible to get model weights as .pb file? #18

Open JohnMBrandt opened 2 years ago

JohnMBrandt commented 2 years ago

Hi,

Incredible paper. It's really cool to see how well it picks up small clouds and cloud shadows compared to Maja and Fmask.

I'd like to test the L2A U-Net model in my existing TensorFlow-based image segmentation pipeline. I currently use something like the following to freeze and save my model weights:

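```python
import tensorflow as tf

# `sess`, `saver` and `meta_path` come from the existing TF1-style training code
saver.restore(sess, tf.train.latest_checkpoint(meta_path))
output_node_names = ['conv2d_13/Sigmoid']

# Freeze the graph: replace variables with constants holding the restored weights
frozen_graph_def = tf.compat.v1.graph_util.convert_variables_to_constants(
    sess, sess.graph_def, output_node_names)

with open('../../models/620-240-apr/predict_graph.pb', 'wb') as f:
    f.write(frozen_graph_def.SerializeToString())
```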

And then something like the following to load them:

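```python
import tensorflow as tf

predict_graph_def = tf.compat.v1.GraphDef()
with tf.io.gfile.GFile(args.predict_model_path + "predict_graph.pb", 'rb') as predict_file:
    predict_graph_def.ParseFromString(predict_file.read())
# import_graph_def() returns None; it adds the nodes to the default graph
tf.import_graph_def(predict_graph_def, name='predict')
```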

Do you know if it would be possible to make something like this from the .hdf5 file with the U-Net weights? And do you know the input and output tensor names? This is the only way I know to keep multiple TF graphs in memory at once: I have one for model predictions and one for super-resolving the 20m bands to 10m, and would have a third graph with your model weights.
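For a Keras model stored as .hdf5, a TF 2.x route to a similar frozen .pb could look roughly like the sketch below. It is a minimal sketch under assumptions: the model has a single input, it loads with `tf.keras.models.load_model(..., compile=False)` (a U-Net with custom layers or losses may also need `custom_objects=`), and `freeze_keras_model` is a hypothetical helper name. `convert_variables_to_constants_v2` lives in the internal `tensorflow.python` package, so its import path may differ between TensorFlow versions. The helper also returns the input and output tensor names, which is the information asked about above.

```python
import tensorflow as tf
from tensorflow.python.framework.convert_to_constants import (
    convert_variables_to_constants_v2,
)


def freeze_keras_model(model, pb_dir, pb_name="frozen_graph.pb"):
    """Freeze a single-input Keras model into a binary GraphDef file.

    Returns the input and output tensor names of the frozen graph.
    """
    spec = tf.TensorSpec(model.inputs[0].shape, model.inputs[0].dtype)
    # Trace the inference forward pass once with a fixed input signature.
    concrete = tf.function(lambda x: model(x, training=False)).get_concrete_function(spec)
    # Replace every variable read with a constant holding the current weights.
    frozen = convert_variables_to_constants_v2(concrete)
    tf.io.write_graph(frozen.graph.as_graph_def(), pb_dir, pb_name, as_text=False)
    return [t.name for t in frozen.inputs], [t.name for t in frozen.outputs]


# Hypothetical usage; the .hdf5 path and output names are placeholders:
# model = tf.keras.models.load_model("l2a_unet.hdf5", compile=False)
# print(freeze_keras_model(model, "models", "l2a_unet.pb"))
```

The frozen file can then be parsed with `tf.compat.v1.GraphDef` and imported with `tf.import_graph_def`, as in the loading snippet above.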

tanyashtym commented 2 years ago

Hi @JohnMBrandt!

Thanks for your question!

I understand your question, and we're currently working on it. We will let you know about the result as soon as possible.

tanyashtym commented 2 years ago

Hi @JohnMBrandt :)

I tried to convert the L2A U-Net model to .pb. I tested it on one of the examples, and it seems to work correctly. Here is the link to the file: https://drive.google.com/file/d/1zr92-v9dtZDBb8vy1PnZCoqLtxlJh2Up/view?usp=sharing

Please let me know if there are any problems with the file or if there is anything we could help you with to make KappaMask work.

JohnMBrandt commented 2 years ago

Thanks! This is super helpful. I was able to pull in the graph def and load the model weights. I'm trying to find the name of the tensor to run in order to get the sigmoid output.

I loaded the .pb file as:

```python
import tensorflow as tf

graph_def = tf.compat.v1.GraphDef()
with tf.io.gfile.GFile("l2a_unet.pb", 'rb') as pb_file:
    graph_def.ParseFromString(pb_file.read())

# import_graph_def() returns None; it adds the nodes to the default graph
tf.import_graph_def(graph_def, name='predict')
cloud_sess = tf.compat.v1.Session()  # session on the default graph
```
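A likely reason for the `predict_1/` prefix in the node names further down: `tf.import_graph_def()` returns `None` and adds the nodes to the current default graph, so the session above runs on the default graph. If the pipeline's own graph was already imported there under `name='predict'`, TensorFlow makes the second name scope unique as `predict_1`. A minimal sketch of keeping each model in its own `tf.Graph` with its own `Session`, which also fits the goal of holding several graphs in memory at once; `load_frozen_graph` is a hypothetical helper name.

```python
import tensorflow as tf


def load_frozen_graph(pb_path, prefix="predict"):
    """Import a frozen GraphDef into its own tf.Graph and open a session on it."""
    graph_def = tf.compat.v1.GraphDef()
    with tf.io.gfile.GFile(pb_path, "rb") as f:
        graph_def.ParseFromString(f.read())
    graph = tf.Graph()
    with graph.as_default():
        # The nodes land in the fresh `graph`, so every node is named
        # "<prefix>/<original name>" regardless of what other graphs exist.
        tf.import_graph_def(graph_def, name=prefix)
    return graph, tf.compat.v1.Session(graph=graph)
```

Each call returns an independent graph/session pair, so the prediction, super-resolution and cloud-mask models can each get their own pair without their node names colliding.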

and can print the names of all the nodes with:

```python
[n.name for n in tf.compat.v1.get_default_graph().as_graph_def().node]
```

But I can't find any node with "Sigmoid" in its name that would indicate the output.
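Rather than searching for "Sigmoid" by name, one option is to look at the graph structure: in a frozen inference graph the outputs are normally nodes that no other node consumes, and the inputs are `Placeholder` nodes. A minimal sketch, assuming `graph_def` is the parsed `tf.compat.v1.GraphDef` from above; `graph_endpoints` is a hypothetical helper, and it only reads the `node`, `name`, `op` and `input` fields of the GraphDef protobuf.

```python
def graph_endpoints(graph_def):
    """Return (placeholder_names, sink_names) for a GraphDef.

    Sinks are nodes whose outputs no other node consumes; in a frozen
    inference graph these are usually the model outputs.
    """
    consumed = set()
    for node in graph_def.node:
        for name in node.input:
            # Inputs look like "op", "op:1" or "^op" (control dependency).
            consumed.add(name.lstrip("^").split(":")[0])
    placeholders = [n.name for n in graph_def.node if n.op == "Placeholder"]
    sinks = [n.name for n in graph_def.node
             if n.name not in consumed and n.op not in ("Const", "NoOp")]
    return placeholders, sinks
```

The names come from the GraphDef before import, so they lack the import prefix; the tensor to run is the prefix actually used (for example `predict/` or `predict_1/`) plus the node name plus `:0`.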

The last convolution layer has the following nodes:

```
predict_1/Model/conv2d_23/kernel
predict_1/Model/conv2d_23/bias
predict_1/Model/conv2d_23/Conv2D/ReadVariableOp
predict_1/Model/conv2d_23/Conv2D
predict_1/Model/conv2d_23/BiasAdd/ReadVariableOp
predict_1/Model/conv2d_23/BiasAdd
predict_1/Model/conv2d_23/Max/reduction_indices
predict_1/Model/conv2d_23/Max
predict_1/Model/conv2d_23/sub
predict_1/Model/conv2d_23/Exp
predict_1/Model/conv2d_23/Sum/reduction_indices
predict_1/Model/conv2d_23/Sum
predict_1/Model/conv2d_23/truediv
```
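The Max → sub → Exp → Sum → truediv sequence under `conv2d_23` matches a numerically stable softmax, which is how standalone Keras 2.x expanded `softmax` for tensors with more than two dimensions. If that reading is right, the last layer uses softmax rather than sigmoid, and the output would be the `truediv` node (tensor `predict_1/Model/conv2d_23/truediv:0` with the prefix shown). The NumPy sketch below mirrors that node sequence to make the correspondence concrete; the function name and the channels-last class axis are assumptions.

```python
import numpy as np


def stable_softmax(logits, axis=-1):
    """Softmax computed the same way as the conv2d_23 node sequence."""
    m = np.max(logits, axis=axis, keepdims=True)  # .../Max (with reduction_indices)
    e = np.exp(logits - m)                        # .../sub, then .../Exp
    s = np.sum(e, axis=axis, keepdims=True)       # .../Sum (with reduction_indices)
    return e / s                                  # .../truediv: per-pixel class probabilities


# Subtracting the max leaves the result unchanged but avoids overflow in exp():
print(stable_softmax(np.array([1000.0, 1000.0])))  # → [0.5 0.5]
```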

Any ideas? Once I know which tensor to run for the output, I think I'll have a very streamlined way to run this architecture within an existing Python script where the Sentinel-2 tile patches are already loaded into RAM, which I'll be happy to share.
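With the patches already in RAM, one way to run the imported graph is to feed them in mini-batches and stack the results, which keeps memory use bounded for a full tile. A minimal sketch, assuming the patches are a NumPy array shaped (N, height, width, bands); `predict_in_batches`, `batch_size` and `INPUT_NAME` are hypothetical, and the input placeholder name still has to be read from the graph.

```python
import numpy as np


def predict_in_batches(run_fn, patches, batch_size=8):
    """Apply run_fn to consecutive slices of `patches` and stack the outputs."""
    outputs = [run_fn(patches[i:i + batch_size])
               for i in range(0, len(patches), batch_size)]
    return np.concatenate(outputs, axis=0)


# Hypothetical usage with the session from the thread, assuming the output is
# the truediv node listed above; INPUT_NAME stands for the graph's Placeholder:
# out = cloud_sess.graph.get_tensor_by_name("predict_1/Model/conv2d_23/truediv:0")
# inp = cloud_sess.graph.get_tensor_by_name("predict_1/" + INPUT_NAME + ":0")
# probs = predict_in_batches(lambda b: cloud_sess.run(out, feed_dict={inp: b}), patches)
```

Passing a callable keeps the batching logic independent of TensorFlow, so the same helper works for the other graphs in the pipeline.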