tensorflow / lucid

A collection of infrastructure and tools for research in neural network interpretability.
Apache License 2.0

Issue calling Model.save() - AssertionError: softmax2 is not in graph #214

Closed: camoconnell closed this issue 4 years ago

camoconnell commented 4 years ago

Hi there, I'm currently playing around with the following style transfer example on Seedbank using the InceptionV1 model.

I am working with the Model.save() function available in the following commit - 0b81881

I have tried to follow the import / load / save instructions as closely as possible, but I can't figure out what the output layer name is.

I am wrapping the training code in a session and calling the save() function after training:

sess = tf.Session()
graph = tf.get_default_graph()

with graph.as_default():
  with sess.as_default():
    param_f = lambda: style_transfer_param(content_image, style_image)

    content_obj = 100 * activation_difference(content_layers, difference_to=CONTENT_INDEX)
    content_obj.description = "Content Loss"

    style_obj = activation_difference(style_layers, transform_f=gram_matrix, difference_to=STYLE_INDEX)
    style_obj.description = "Style Loss"

    objective = - content_obj - style_obj

    vis = render.render_vis(model, objective, param_f=param_f, thresholds=[10], verbose=False, print_objectives=[content_obj, style_obj])[-1]

    model.save(
      "saved_model.pb",
      input_name='input',                     # (eg. 'input')
      image_shape=image_shape,     # (eg. [224, 224, 3])
      output_names=['softmax2'],     # (eg. ['logits'])
      image_value_range=[0, 255],    # (eg. '[-1, 1], [0, 1], [0, 255], or [-117, 138]')
    )

I tried inspecting the layers of the model using Model.layer to find the output layer. [screenshot: Screen-Shot-2019-11-29-at-9 15 36-AM]

I get the following error:

/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/framework/graph_util_impl.py in _assert_nodes_are_present(name_to_node, nodes)
    150   """Assert that nodes are present in the graph."""
    151   for d in nodes:
--> 152     assert d in name_to_node, "%s is not in graph" % d
    153 
    154 

AssertionError: softmax2 is not in graph

I also tried inspecting the graph using Model.show_graph() [screenshot: Screen-Shot-2019-11-29-at-10 10 05-AM], but the output name 'output2' throws the same error: AssertionError: output2 is not in graph.

I'm not sure where else I can look to find output_names.

Thanks
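One way to narrow down a candidate for output_names is to list the node names actually present in the graph being saved and filter them by keyword. The sketch below is only an illustration: find_node_names is a hypothetical helper, and the sample list of names is made up so the snippet runs without TensorFlow. It assumes the names would come from the graph's GraphDef, as shown in the comment.

```python
def find_node_names(node_names, keyword):
    """Return the node names that contain `keyword` (case-insensitive), in order."""
    needle = keyword.lower()
    return [name for name in node_names if needle in name.lower()]


# With TensorFlow 1.x, the real names could be collected with, for example:
#   node_names = [n.name for n in graph.as_graph_def().node]
# The list below is a made-up stand-in (hypothetical names, not taken from
# the issue) so that the sketch runs on its own.
node_names = ["Variable", "random_crop", "import/conv2d0", "import/softmax2"]

print(find_node_names(node_names, "softmax"))  # -> ['import/softmax2']
```

If a layer turns up only under a scope prefix, as in the made-up list above, passing the bare name would fail the "is not in graph" check, because that check looks up the exact node name.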

camoconnell commented 4 years ago

OK, I found the following Stack Overflow answer: https://stackoverflow.com/a/58427053/79803

def analyze_inputs_outputs(graph):
    ops = graph.get_operations()
    outputs_set = set(ops)
    inputs = []
    for op in ops:
        if len(op.inputs) == 0 and op.type != 'Const':
            inputs.append(op)
        else:
            for input_tensor in op.inputs:
                if input_tensor.op in outputs_set:
                    outputs_set.remove(input_tensor.op)
    outputs = list(outputs_set)
    return (inputs, outputs)

Calling print(analyze_inputs_outputs(graph)) returned:

([<tf.Operation 'Variable' type=VariableV2>,
  <tf.Operation 'Variable_1' type=VariableV2>],
 [<tf.Operation 'random_crop' type=Slice>,
  <tf.Operation 'stack_1/values_1' type=Const>,
  <tf.Operation 'strided_slice_4' type=StridedSlice>,
  <tf.Operation 'Variable/Assign' type=Assign>,
  <tf.Operation 'random_crop/Assert/Const' type=Const>,
  <tf.Operation 'random_crop/Assert/Assert' type=Assert>,
  <tf.Operation 'Variable_1/Assign' type=Assign>,
  <tf.Operation 'random_crop_1/Assert/Const' type=Const>,
  <tf.Operation 'stack_2' type=Pack>,
  <tf.Operation 'random_crop_1/Assert/Assert' type=Assert>,
  <tf.Operation 'stack/1' type=Const>,
  <tf.Operation 'random_crop_1/Assert/Assert' type=Assert>])
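To make clearer what analyze_inputs_outputs reports, here is a TensorFlow-free sketch of the same idea on a plain dict. The function find_sources_and_sinks and the toy graph are hypothetical and only mirror the logic above: an "input" is an op with no inputs, and an "output" is any op whose result no other op consumes. Unlike the original, the sketch does not filter out constants.

```python
def find_sources_and_sinks(consumes):
    """Split a toy graph into sources and sinks.

    `consumes` maps each op name to the list of op names it reads from.
    Sources are ops with no inputs; sinks are ops that nothing reads from.
    """
    sinks = set(consumes)
    sources = []
    for op, inputs in consumes.items():
        if not inputs:
            sources.append(op)
        else:
            for producer in inputs:
                # Anything that feeds another op cannot be a sink.
                sinks.discard(producer)
    return sources, sorted(sinks)


# Made-up graph: "check" stands in for a side op such as an assertion.
toy_graph = {
    "image": [],
    "conv": ["image"],
    "softmax": ["conv"],
    "check": ["image"],
}

print(find_sources_and_sinks(toy_graph))  # -> (['image'], ['check', 'softmax'])
```

Note that "check" is reported as a sink even though it is not the prediction layer. The same applies to the Assert and Const ops in the list above: they appear because nothing consumes them, which does not by itself make them model outputs.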

I'll close this ticket, since taking the first and last node names solved the AssertionError.

model.save(
    "saved_model.pb",
    input_name='Variable',
    image_shape=[224, 224, 3],
    output_names=['random_crop_1/Assert/Assert'],
    image_value_range=[-117, 138],
)