tensorlayer / TensorLayer

Deep Learning and Reinforcement Learning Library for Scientists and Engineers
http://tensorlayerx.com
Other
7.33k stars 1.61k forks source link

How to save tensorlayer2 model as .pb file #1091

Closed haowsun closed 4 years ago

haowsun commented 4 years ago

How can I save a TensorLayer 2 model as a .pb file?

[INSERT CODE HERE]

# this is my model

def get_model(inputs_shape, keep=0.5):
    """Build a small 1-D convolutional classifier as a TensorLayer model.

    The network is: input -> Conv1dLayer (ReLU, VALID padding) -> MaxPool1d
    -> Flatten -> Dropout -> Dense with 2 output units.

    Args:
        inputs_shape: Shape passed to ``tl.layers.Input``.
        keep: Value forwarded to ``tl.layers.Dropout`` as ``keep``.

    Returns:
        A ``tl.models.Model`` named ``'cnn'`` that maps the input layer to the
        output of the final dense layer.
    """
    input_tensor = tl.layers.Input(inputs_shape, name='input_layer')

    # Convolution followed by max pooling.
    conv_layer = tl.layers.Conv1dLayer(
        act=tf.nn.relu,
        shape=[3, 200, 6],
        name='cnn_layer1',
        padding='VALID',
    )
    conv_out = conv_layer(input_tensor)
    pool_layer = tl.layers.MaxPool1d(filter_size=3, strides=3, name='pool_layer1')
    pooled = pool_layer(conv_out)

    # Flatten, apply dropout, then project to the 2-unit output.
    flat = tl.layers.Flatten(name='flatten_layer')(pooled)
    dropped = tl.layers.Dropout(keep=keep, name='drop1')(flat)
    logits = tl.layers.Dense(n_units=2, act=tf.identity, name="output")(dropped)

    return tl.models.Model(inputs=input_tensor, outputs=logits, name='cnn')

# Instantiate the network for inputs of shape [None, 20, 200].
model = get_model(inputs_shape=[None, 20, 200])
Laicheng0830 commented 4 years ago
import tensorflow as tf
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2
from tensorlayer.layers import Dense, Dropout
from tensorlayer.models import Model

## define the network
class CustomModel(Model):
    """TensorLayer model: one dropout layer followed by three dense layers.

    The layers are created in ``__init__`` and chained in ``forward``.
    ``forward`` is wrapped in ``tf.function`` so that a concrete function can
    be obtained from it via ``get_concrete_function``.
    """

    def __init__(self):
        """Create the dropout layer and the 784 -> 800 -> 800 -> 10 dense stack."""
        super().__init__()
        self.dropout1 = Dropout(keep=0.8)
        self.dense1 = Dense(n_units=800, act=tf.nn.relu, in_channels=784)
        self.dense2 = Dense(n_units=800, act=tf.nn.relu, in_channels=800)
        self.dense3 = Dense(n_units=10, act=tf.nn.relu, in_channels=800)

    @tf.function
    def forward(self, x, foo=None):
        """Run the input through dropout and the three dense layers.

        Args:
            x: Input tensor fed to the dropout layer.
            foo: Optional flag. When it is not ``None``, an extra ReLU is
                applied to the output of the last dense layer.

        Returns:
            The output of ``dense3``, optionally passed through ReLU again.
        """
        z = self.dropout1(x)
        z = self.dense1(z)
        z = self.dense2(z)
        out = self.dense3(z)
        if foo is not None:
            out = tf.nn.relu(out)
        return out

if __name__ == '__main__':
    # Build the network and call eval() on it before tracing.
    net = CustomModel()
    net.eval()

    # Tensor spec used to trace `forward` into a concrete function.
    height, width = 800, 784
    input_signature = tf.TensorSpec(shape=(None, height, width))
    traced_forward = net.forward.get_concrete_function(x=input_signature)

    # Alternative when `forward` is not decorated with @tf.function:
    # forward_function = tf.function(lambda x: net.forward(x))
    # traced_forward = forward_function.get_concrete_function(x=input_signature)

    # Convert the variables to constants and write the graph as a binary .pb.
    frozen_func = convert_variables_to_constants_v2(traced_forward)
    tf.io.write_graph(
        graph_or_graph_def=frozen_func.graph.as_graph_def(),
        logdir="./frozen_models",
        name="frozen_graph.pb",
        as_text=False,
    )
Laicheng0830 commented 4 years ago

Reference: https://leimao.github.io/blog/Save-Load-Inference-From-TF2-Frozen-Graph/