# Status: closed by haowsun 4 years ago.
import tensorflow as tf
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2
from tensorlayer.layers import Dense, Dropout
from tensorlayer.models import Model
## define the network
class CustomModel(Model):
    """Fully connected network built with TensorLayer 2: 784 -> 800 -> 800 -> 10.

    An input Dropout layer (keep=0.8) feeds three Dense layers, each created
    with ``act=tf.nn.relu``.
    """

    def __init__(self):
        super().__init__()
        # Sub-layers are assigned in forward-pass order; attribute names are
        # kept stable because the model tracks them as its layers.
        # keep=0.8 is presumably the keep probability of TensorLayer's Dropout.
        self.dropout1 = Dropout(keep=0.8)
        self.dense1 = Dense(n_units=800, act=tf.nn.relu, in_channels=784)
        self.dense2 = Dense(n_units=800, act=tf.nn.relu, in_channels=800)
        self.dense3 = Dense(n_units=10, act=tf.nn.relu, in_channels=800)

    @tf.function
    def forward(self, x, foo=None):
        """Apply dropout1, dense1, dense2 and dense3 to ``x`` in sequence.

        Args:
            x: Input batch. dense1 is declared with in_channels=784, so the
                last dimension is presumably 784 — confirm against callers.
            foo: When not None, ``tf.nn.relu`` is applied once more to the
                final output.

        Returns:
            The (optionally re-activated) output of dense3, which has 10 units.
        """
        # NOTE(review): Python-side state such as the train/eval mode seen by
        # Dropout is captured when tf.function traces; switching modes later
        # may not retrace — confirm before reusing a traced function.
        h = x
        for layer in (self.dropout1, self.dense1, self.dense2, self.dense3):
            h = layer(h)
        # dense3 already uses act=tf.nn.relu and relu is idempotent, so this
        # extra activation is expected to leave the values unchanged.
        return h if foo is None else tf.nn.relu(h)
if __name__ == '__main__':
    # Build the model and put it in evaluation mode before tracing, so the
    # exported graph reflects inference behaviour.
    save_model = CustomModel()
    save_model.eval()

    # dense1 is declared with in_channels=784, so every sample must be a
    # 784-feature vector: the export signature is a rank-2 (batch, 784)
    # float32 tensor. A (None, 800, 784) signature does not match the network.
    n_features = 784
    input_signature = tf.TensorSpec(shape=(None, n_features), dtype=tf.float32)

    # Trace forward() once for this signature (foo keeps its default, None).
    concrete_function = save_model.forward.get_concrete_function(x=input_signature)
    # If forward() were NOT decorated with @tf.function, wrap it instead:
    #   forward_function = tf.function(lambda x: save_model.forward(x))
    #   concrete_function = forward_function.get_concrete_function(x=input_signature)

    # Fold the variables into constants so the GraphDef is self-contained.
    frozen_graph = convert_variables_to_constants_v2(concrete_function)
    frozen_graph_def = frozen_graph.graph.as_graph_def()

    # The tensor names are needed to feed/fetch the graph after loading the .pb.
    print("Frozen graph inputs:", frozen_graph.inputs)
    print("Frozen graph outputs:", frozen_graph.outputs)

    tf.io.write_graph(
        graph_or_graph_def=frozen_graph_def,
        logdir="./frozen_models",
        name="frozen_graph.pb",
        as_text=False,
    )
# Question: how can a TensorLayer 2 model be saved as a frozen .pb file?
# [INSERT CODE HERE]