marcoancona / DeepExplain

A unified framework of perturbation- and gradient-based attribution methods for Deep Neural Network interpretability. DeepExplain also includes support for Shapley Values sampling. (ICLR 2018)
https://arxiv.org/abs/1711.06104
MIT License

Problem recreating the graph with trained weights (TensorFlow) #13

Open maosi-chen opened 6 years ago

maosi-chen commented 6 years ago

I have a trained model (a one-layer Bi-LSTM followed by two fully connected layers). Following the MNIST example, I tried to "recreate the network graph" inside the DeepExplain context.

The problem with recreating the graph via logits = model(X) is that it creates a new copy of everything, with names that are similar to, but different from, those in the original graph. For example, the recreated graph has the tensor Tensor("Prediction_1/predicted:0", shape=(?,), dtype=float32), whose original counterpart is Tensor("Prediction/predicted:0", shape=(?,), dtype=float32). As a result, session_run fails, and I suspect this is because the weights restored into the original graph are not picked up by the recreated graph. How can I solve this problem?

Thanks.
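A minimal sketch of the symptom described above, assuming TensorFlow's tf.compat.v1 API. The one-weight model() is a hypothetical stand-in for the Bi-LSTM + FC network, and initializing only the first copy's variable stands in for saver.restore(). Calling the builder a second time in the same graph re-enters the name scope as Prediction_1 and, because it uses tf.Variable, creates a second, separate weight that the restored values never reach.

```python
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()


def model(x):
    """Hypothetical stand-in for the Bi-LSTM + FC model: a single weight matrix."""
    with tf.name_scope("Prediction"):
        # tf.Variable creates a brand-new variable every time model() runs.
        w = tf.Variable(tf.ones([3, 1]), name="w")
        return tf.squeeze(tf.matmul(x, w), axis=1, name="predicted")


graph = tf.Graph()
with graph.as_default():
    x = tf.placeholder(tf.float32, shape=[None, 3], name="x")
    original = model(x)   # the copy that was trained and checkpointed
    recreated = model(x)  # rebuilt later, e.g. inside the DeepExplain context
    variables = tf.global_variables()
    # Stand-in for saver.restore(): give values only to the original copy.
    restore_like = tf.variables_initializer(
        [v for v in variables if v.name == "Prediction/w:0"])

print(original.name, recreated.name)
print([v.name for v in variables])

with tf.Session(graph=graph) as sess:
    sess.run(restore_like)
    sess.run(original, feed_dict={x: [[1.0, 2.0, 3.0]]})  # runs fine
    try:
        sess.run(recreated, feed_dict={x: [[1.0, 2.0, 3.0]]})
        recreated_failed = False
    except tf.errors.FailedPreconditionError:
        # Prediction_1/w was never given a value.
        recreated_failed = True
```

In this sketch the original tensor evaluates normally while the recreated one raises FailedPreconditionError, consistent with the suspicion that the restored weights are not visible to the recreated ops.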

marcoancona commented 6 years ago

Can you share the code you use to reload the weights?

maosi-chen commented 6 years ago

I have a class A whose __init__ receives the model parameters (e.g. the numbers of LSTM and FC layers, number of n, dropout, etc.) and builds the graph. The input tensor is extracted from a feedable tf.data pipeline, and the predicted tensor is the output of the LSTM + FC layers. The run member function of class A takes the running mode (TRAIN, EVAL, PREDICT, or ATTRIBUTION) and runs the graph defined in __init__ accordingly. For modes other than TRAIN and EVAL, the weights are restored by restore_checkpoint:

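```python
import tensorflow as tf  # TensorFlow 1.x API


def restore_checkpoint(self, sess):
    """Restore the latest checkpoint; return its global step, or -1 on failure."""
    try:
        # Look up the most recent checkpoint recorded in the checkpoint directory.
        ckpt = tf.train.get_checkpoint_state(self.FP_checkpoints)
        if ckpt and ckpt.model_checkpoint_path:
            # The Saver matches checkpoint entries to this graph's variables by name.
            self.saver.restore(sess, ckpt.model_checkpoint_path)
            last_global_step = sess.run(self.global_step)
            return last_global_step
        else:
            raise Exception('Message: Checkpoint was not restored correctly.')
    except Exception as err:
        # Report the failure and signal it to the caller with -1.
        print(err.args)
        return -1
```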

Currently, the workaround I found is to apply graph.gradient_override_map(...) in the original graph to the ops between the extracted input tensor and the predicted tensor, before restoring the trained weights. In addition, for methods whose attribution is [g * x], I have to build the attribution tensors in my own graph instead of computing them in the get_symbolic_attribution function (this is not necessary when the attribution is g itself, because g somehow already appears to be part of the original graph). I know this workaround is messy and moves substantial parts of your code into the original graph. Could you help me improve it? Thanks.
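A minimal sketch of the workaround described above, assuming the tf.compat.v1 API. A gradient function registered under a hypothetical name (HypotheticalPassThroughGrad, not DeepExplain's own override) is attached to a ReLU through graph.gradient_override_map while the graph is built, and the g * x attribution tensor is then built in that same graph. The toy ReLU layer stands in for the real LSTM + FC ops, and the pass-through rule is chosen only because its effect is easy to see.

```python
import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()


# Hypothetical gradient name for illustration; DeepExplain registers its own override.
@tf.RegisterGradient("HypotheticalPassThroughGrad")
def _pass_through_grad(op, grad):
    """Ignore the ReLU mask and pass the upstream gradient through unchanged."""
    return grad


graph = tf.Graph()
with graph.as_default():
    x = tf.placeholder(tf.float32, shape=[None, 3], name="x")
    w = tf.constant([1.0, 2.0, 3.0])

    # Ordinary ReLU: its gradient is masked where the activation is zero.
    plain = tf.reduce_sum(tf.nn.relu(x * w), axis=1)

    # Ops of type "Relu" created inside this block use the overridden gradient.
    with graph.gradient_override_map({"Relu": "HypotheticalPassThroughGrad"}):
        hidden = tf.nn.relu(x * w)
    overridden = tf.reduce_sum(hidden, axis=1)

    g_plain = tf.gradients(plain, x)[0]
    g = tf.gradients(overridden, x)[0]
    # The [g * x] attribution is built in the same graph, next to the model.
    attribution = g * x

xs = np.array([[1.0, -1.0, 2.0]], dtype=np.float32)
with tf.Session(graph=graph) as sess:
    g_plain_val, g_val, attribution_val = sess.run(
        [g_plain, g, attribution], feed_dict={x: xs})
```

gradient_override_map only affects ops created inside its with-block, because the override is recorded on each op as it is built. That is why it has to wrap the graph-construction code rather than be applied to a graph that already exists.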

maosi-chen commented 5 years ago

I solved the issue by wrapping the graph-building components with the tf.make_template function. I set the "name" property of the model-class instance to a unique string, so that every time a method of that instance (with that unique "name") is called, the Variables in that method are reused instead of new ones being created. Note that Tensors are still duplicated with a suffix if the method is called multiple times, but that doesn't affect restoring weights and biases from checkpoints, because those are Variables and have unique names.

Reference: https://gist.github.com/danijar/720394a9071a03413be8a60852374aa4
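A minimal sketch of the make_template fix, assuming the tf.compat.v1 API. The hypothetical _model_fn stands in for the Bi-LSTM + FC code, the template name "Prediction" plays the role of the instance's unique "name", and a tf.assign op stands in for training. The model is called twice per graph; because the template reuses tf.get_variable variables after its first call, each graph holds a single Prediction/w, and a checkpoint written from one such graph restores cleanly into a rebuilt one.

```python
import os
import tempfile
from types import SimpleNamespace

import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()


def _model_fn(x):
    """Hypothetical stand-in for the Bi-LSTM + FC stack: a single weight matrix."""
    # make_template shares variables created through tf.get_variable.
    w = tf.get_variable("w", shape=[3, 1], initializer=tf.zeros_initializer())
    return tf.squeeze(tf.matmul(x, w), axis=1, name="predicted")


def build_graph():
    """Build a fresh graph that calls the model twice, as in the issue's workflow."""
    graph = tf.Graph()
    with graph.as_default():
        # The unique template name plays the role of the model instance's "name".
        model = tf.make_template("Prediction", _model_fn)
        x = tf.placeholder(tf.float32, shape=[None, 3], name="x")
        original = model(x)   # first call creates the variable Prediction/w
        recreated = model(x)  # later calls reuse it; only op names get a suffix
        (w,) = tf.global_variables()  # exactly one variable despite two calls
        # Stand-in for training: write fixed values into the shared weights.
        fake_train = tf.assign(w, [[1.0], [2.0], [3.0]])
        init = tf.global_variables_initializer()
        saver = tf.train.Saver()
    return SimpleNamespace(graph=graph, x=x, w=w, original=original,
                           recreated=recreated, train=fake_train,
                           init=init, saver=saver)


ckpt_path = os.path.join(tempfile.mkdtemp(), "model.ckpt")

# "Training" run: set the weights and write a checkpoint.
trained = build_graph()
with tf.Session(graph=trained.graph) as sess:
    sess.run(trained.init)
    sess.run(trained.train)
    trained.saver.save(sess, ckpt_path)

# Later run (e.g. attribution): rebuild, restore, and evaluate both calls.
restored = build_graph()
with tf.Session(graph=restored.graph) as sess:
    restored.saver.restore(sess, ckpt_path)
    out_original, out_recreated = sess.run(
        [restored.original, restored.recreated],
        feed_dict={restored.x: [[1.0, 1.0, 1.0]]})
```

make_template only shares variables created through tf.get_variable (directly or via layer code built on it); if a call after the first one creates a new trainable variable, for example with tf.Variable, TensorFlow raises a ValueError. In the DeepExplain setting, calling the same template again inside the DeepExplain context should build fresh ops, which the context can attach its gradient overrides to, while reading the already-restored weights.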