pmorerio / admd

Tensorflow code for the paper 'Learning with privileged information via adversarial discriminative modality distillation', TPAMI 2019
MIT License

Help to change the first convolution #5

Closed: Scienceseb closed this issue 4 years ago

Scienceseb commented 4 years ago

Hi, I need to change the first convolution of the model from rgb/resnet_v1_50/conv1/weights:0 (float32_ref 7x7x3x64) to rgb/resnet_v1_50/conv1/weights:0 (float32_ref 7x7x4x64), so basically increasing the number of input channels from 3 to 4 to accept 4-channel images, while keeping the pretrained weights everywhere else (just the additional channel initialized randomly).

Do you have an idea of how to do that in TensorFlow (I'm more of a PyTorch guy...)?

In PyTorch I do:

    net = model.resnet50(num_classes=dataset_train.num_classes(), pretrained=True)

    # new first conv taking 4 input channels instead of 3
    new_conv1 = nn.Conv2d(4, 64, kernel_size=7, stride=2, padding=3, bias=False)
    conv1 = net.conv1

    with torch.no_grad():
        # copy the pretrained weights into the first 3 input channels;
        # the 4th channel keeps its random initialization
        new_conv1.weight[:, :3, :, :] = conv1.weight
        new_conv1.bias = conv1.bias  # both are None, since conv1 has bias=False

    net.conv1 = new_conv1

Thanks a lot for your help!

pmorerio commented 4 years ago

Ok, in TensorFlow this is not as straightforward as in PyTorch. Not sure exactly how to do it. The resnet class returns a python dictionary of all the layers (end_points, see here). You can access the first resnet layer there. Plus, TensorFlow has a list of all trainable variables from which you can access the weights. You can try something similar to what you do in PyTorch. When you create a new layer, be careful with the scope.
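
For instance, something along these lines (an untested sketch; the end_point key and variable name assume the 'rgb' modality scope used in this thread):

    import tensorflow as tf

    # end_points is the dictionary returned by the slim resnet_v1_50(...) call
    first_layer_out = end_points['rgb/resnet_v1_50/conv1']

    # the conv1 kernel can be looked up among the trainable variables by name
    conv1_weights = [v for v in tf.trainable_variables()
                     if v.name == 'rgb/resnet_v1_50/conv1/weights:0'][0]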

Scienceseb commented 4 years ago

Ok I just need to change rgb/resnet_v1_50/conv1/weights:0 (float32_ref 7x7x3x64) to rgb/resnet_v1_50/conv1/weights:0 (float32_ref 7x7x4x64)...

conv1 is:

    rgb/resnet_v1_50/conv1/weights:0 (float32_ref 7x7x3x64) [9408, bytes: 37632]
    rgb/resnet_v1_50/conv1/BatchNorm/gamma:0 (float32_ref 64) [64, bytes: 256]
    rgb/resnet_v1_50/conv1/BatchNorm/beta:0 (float32_ref 64) [64, bytes: 256]

so I can do something with:

    net = end_points[modality + '/resnet_v1_50/conv1']

    with tf.variable_scope(modality + '/resnet_v1_50', reuse=reuse):
        bottleneck = slim.conv2d(net, 4, [7, 7], padding='VALID',
                                 activation_fn=tf.nn.relu, scope='f_repr')
        net = slim.conv2d(bottleneck, self.no_classes, [1, 1],
                          activation_fn=None, scope='_logits_')

pmorerio commented 4 years ago

Well, maybe the best option is to create a local copy of the slim resnet_v1 file. Instead of from tensorflow.contrib.slim.nets import resnet_v1 you can import your local copy and modify that.
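
Something like this (the local module name is just an example):

    # instead of
    # from tensorflow.contrib.slim.nets import resnet_v1
    # import a local, editable copy of the slim file:
    import resnet_v1_local as resnet_v1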

Scienceseb commented 4 years ago

> Well, maybe the best option is to create a local copy of the slim resnet_v1 file. Instead of from tensorflow.contrib.slim.nets import resnet_v1 you can import your local copy and modify that.

Ok, but what would be your approach for that: how can I change the first convolution in a local copy of the resnet_v1 file? I really don't know how to make the 3 become a 4...

pmorerio commented 4 years ago

Ok, maybe you do not need to modify the network but only the placeholder, which is the way the computational graph gets its input (4 channels instead of 3). Then the problem will be in loading the checkpoint. You can load the checkpoint and assign values to the variables, excluding the conv1 variables (similarly to what is already done for the logits). The assignment for conv1 should be done manually because of the size mismatch.
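
Roughly something like this (an untested sketch for TF 1.x / slim; the scope and variable names follow the ones in this thread):

    import tensorflow as tf
    slim = tf.contrib.slim

    # 4-channel input placeholder (RGB + depth); conv1 is then built as 7x7x4x64
    images = tf.placeholder(tf.float32, [None, 224, 224, 4], name='rgb_images')

    # restore everything from the checkpoint except conv1 (and the logits,
    # which are already excluded in the code)
    variables_to_restore = slim.get_variables_to_restore(
        exclude=['rgb/resnet_v1_50/conv1', 'rgb/resnet_v1_50/logits'])
    restorer = tf.train.Saver(variables_to_restore)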

Scienceseb commented 4 years ago

Ok, so self.images = tf.placeholder(tf.float32, [None, 224, 224, 4], modality + '_images') effectively changes the 3 to a 4:

    rgb/resnet_v1_50/conv1/weights:0 (float32_ref 7x7x4x64) [12544, bytes: 50176]

but like you said, the problem is now with the checkpoint:

    tensorflow.python.framework.errors_impl.InvalidArgumentError: Assign requires shapes of both tensors to match.
    lhs shape= [7,7,4,64] rhs shape= [7,7,3,64]
    [[{{node save/Assign_264}}]]

So, according to what you said previously, do I have to change the code in def single_stream like you did with the logits? But what do you mean by "the assignment for conv1 should be done manually because of the size mismatch"? How can I do that?

Scienceseb commented 4 years ago

For example, when I do this:

    net = end_points[modality + '/resnet_v1_50/conv1']
    with tf.variable_scope(modality + '/resnet_v1_50', reuse=tf.AUTO_REUSE):
        net = slim.conv2d(net, 64, [7, 7], activation_fn=None, scope='conv1')

I get the following error:

    ValueError: Trying to share variable rgb/resnet_v1_50/conv1/weights, but specified shape (7, 7, 64, 64) and found shape (7, 7, 4, 64)

It's because /resnet_v1_50/conv1 already outputs 64 filters, so the reused conv sees an input depth of 64. But how can I build it with an input depth of 4 when it's the first convolution of the network? The input of the first convolution is the image, so the depth is 4... But when I do:

    net = slim.conv2d(images, 64, [7, 7], scope='conv1')

I get:

    tensorflow.python.framework.errors_impl.NotFoundError: Key rgb/resnet_v1_50/conv1/biases not found in checkpoint
    [[{{node save/RestoreV2}}]]

So I'm clearly doing something wrong and, to be frank, I don't really know what I'm doing; I've been trying all day and nothing works.

pmorerio commented 4 years ago

No. The computational graph for the net is OK, so you must not modify model.py anymore. The problem is in solver.py, when you restore the checkpoint. You cannot automatically restore the weights for conv1, so you must exclude that tensor from variables_to_restore. Now the first layer is randomly initialized and the code should run. If you want to assign the pretrained values for the 3 RGB channels, you can do it manually by reading that tensor from the checkpoint and assigning it to the RGB part of conv1.
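
For the manual assignment, something along these lines should work (a rough, untested sketch; the checkpoint path and the name of the pretrained tensor are guesses and may differ in your setup):

    import tensorflow as tf
    slim = tf.contrib.slim

    checkpoint_path = 'resnet_v1_50.ckpt'  # example path to the pretrained checkpoint

    # restore everything except conv1 (and the logits), as discussed above
    restorer = tf.train.Saver(slim.get_variables_to_restore(
        exclude=['rgb/resnet_v1_50/conv1', 'rgb/resnet_v1_50/logits']))

    # read the pretrained 7x7x3x64 kernel directly from the checkpoint file
    reader = tf.train.NewCheckpointReader(checkpoint_path)
    pretrained = reader.get_tensor('resnet_v1_50/conv1/weights')  # name may differ

    # the new, randomly initialized 7x7x4x64 conv1 variable in the graph
    conv1_var = [v for v in tf.global_variables()
                 if v.name == 'rgb/resnet_v1_50/conv1/weights:0'][0]

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        restorer.restore(sess, checkpoint_path)
        kernel = sess.run(conv1_var)            # shape (7, 7, 4, 64)
        kernel[:, :, :3, :] = pretrained        # copy the pretrained RGB slice
        sess.run(tf.assign(conv1_var, kernel))  # the extra depth channel stays random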

Scienceseb commented 4 years ago

Yep, this is working, thanks a lot. To normalize the depth channel, would you suggest a value of 122.5 (for the mean)?

pmorerio commented 4 years ago

Be careful: it is working, but conv1 is randomly initialized at the moment. Concerning the depth channel (which I guess is your 4th channel), yes, I would suggest so.
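
For instance, something like this (the RGB means are the usual slim/VGG preprocessing values; the 122.5 for depth is just the value discussed above):

    import numpy as np

    # per-channel means: R, G, B, depth
    channel_means = np.array([123.68, 116.78, 103.94, 122.5], dtype=np.float32)
    images_normalized = images - channel_means  # broadcasts over the channel axis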

Scienceseb commented 4 years ago

Do you have any idea why, when I run train_rgb with the 4th channel all zeros (so basically it's just RGB), I get after about 10 epochs: rgb train acc [0.9378], rgb test acc [0.2549]...? Is it because of the randomly initialized weights?

pmorerio commented 4 years ago

Yes, I think so.