gmalivenko / pytorch2keras

PyTorch to Keras model convertor
https://pytorch2keras.readthedocs.io/en/latest/
MIT License

Porting inception based architectures (inception v4) #2

Open snakers4 opened 6 years ago

snakers4 commented 6 years ago

Hi, once again thanks for the awesome work. It really helps shorten the path from PyTorch to production (PyTorch => Keras => TF).

I was running the converter on my PyTorch model, which is based on the Inception v4 architecture.

I encountered this error during the run:

ValueError: Unsuported padding size for convolution

I guess this is because my model is an Inception v4 architecture: it has non-symmetric filters (such as 1x7 and 7x1), and Keras and PyTorch handle them differently. In Keras this is solved via 'same' convolutions, in PyTorch via different per-axis paddings. Investigating deeper, I indeed found that conv layers with asymmetric padding in PyTorch are the culprits.
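To make the shape argument concrete, here is a small sketch of the per-axis output-size formula from the PyTorch Conv2d documentation (dilation fixed at 1). The 17x17 grid size is only an illustrative value; the point is that a 1x7 kernel with padding=(0, 3) keeps both spatial dimensions, which is the effect Keras gets from padding='same'.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output length along one axis for torch.nn.Conv2d with dilation=1."""
    return (size + 2 * padding - kernel) // stride + 1


def conv2d_out_hw(hw, kernel_hw, stride_hw=(1, 1), padding_hw=(0, 0)):
    # PyTorch applies padding[0] to the height and padding[1] to the width,
    # so padding=(0, 3) is an ordinary per-axis setting, not an error case.
    return tuple(
        conv_out(n, k, s, p)
        for n, k, s, p in zip(hw, kernel_hw, stride_hw, padding_hw)
    )


# A 1x7 branch with padding (0, 3) and a 7x1 branch with padding (3, 0).
print(conv2d_out_hw((17, 17), (1, 7), padding_hw=(0, 3)))  # (17, 17)
print(conv2d_out_hw((17, 17), (7, 1), padding_hw=(3, 0)))  # (17, 17)
# Without padding, the width shrinks instead.
print(conv2d_out_hw((17, 17), (1, 7)))                      # (17, 11)
```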

Just compare these links:

I worked around this with what is essentially a hack, but maybe you could give some advice (or a commit) on how to do it properly?

In a nutshell, I did the following. I can open a PR with this change if you want.

        if node.padding[0] != node.padding[1]:
            # originally this line was not commented
            # raise ValueError('Unsuported padding size for convolution')

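            # quick fix for inception architectures
            # refer here for more info https://github.com/fchollet/keras/blob/master/keras/applications/inception_v3.py
            # note: 'same' is guaranteed to reproduce PyTorch only when stride is 1
            # and each axis uses an odd kernel with padding == (kernel_size - 1) // 2
            border_mode = 'same'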
        else:
            # this code initially was under no condition
            padding = node.padding[0]
            if padding > 0:
                padding_name = output_name + '_pad'
                padding_layer = keras.layers.ZeroPadding2D(
                    padding=node.padding,
                    name=padding_name
                )
                layers[padding_name] = padding_layer(layers[input_name])
                input_name = padding_name      

            # this line below also was applied unconditionally
            border_mode = 'valid'
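Switching to border_mode='same' whenever the two padding values differ is only safe when TensorFlow's 'same' rule places the padding exactly where PyTorch does. A rough check, assuming TF's documented 'same' rule (output length ceil(n / stride), with any odd leftover padding added at the bottom/right), might look like this; keras_same_equivalent is a hypothetical helper, not part of pytorch2keras.

```python
def tf_same_pads(size, kernel, stride=1):
    """(before, after) padding that TF/Keras adds along one axis for padding='same'."""
    out = -(-size // stride)  # ceil(size / stride)
    total = max((out - 1) * stride + kernel - size, 0)
    before = total // 2
    return before, total - before


def keras_same_equivalent(input_hw, kernel_hw, padding_hw, stride_hw=(1, 1)):
    # PyTorch pads each axis symmetrically: padding p means p before and p after.
    return all(
        tf_same_pads(n, k, s) == (p, p)
        for n, k, p, s in zip(input_hw, kernel_hw, padding_hw, stride_hw)
    )


print(keras_same_equivalent((17, 17), (1, 7), (0, 3)))          # True
print(keras_same_equivalent((35, 35), (3, 3), (0, 0)))          # False: 'same' would pad, PyTorch does not
print(keras_same_equivalent((34, 34), (3, 3), (1, 1), (2, 2)))  # False: TF pads (0, 1), PyTorch (1, 1)
```

When the check fails, an explicit zero-padding layer followed by a 'valid' convolution is the safer route. Keras' ZeroPadding2D already accepts a tuple of two ints as (symmetric_height_pad, symmetric_width_pad), so passing node.padding = (0, 3) to it, as the else-branch above does, handles per-axis padding without 'same'; the branch would then need to test whether either value is non-zero rather than only node.padding[0].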

Anyway, what would be the proper solution for this edge case?

snakers4 commented 6 years ago

After fixing the concat bug, I managed to convert my model. The whole model shows a seemingly reasonable discrepancy (softmax being the last layer), but I still suspect the error is too large, given the really small discrepancies you get for your models without softmax.

To check what is wrong, I tested conversion of just the Inception v4 encoder with the above hack. It produced this:

Max error: 4.279047966003418

I used this inception implementation:

The encoder that I am using:

import torch
import torch.nn as nn


class InceptionEncoder(nn.Module):
    def __init__(self,
                 inception):

        super(InceptionEncoder, self).__init__()

        self.inception_extractor = InceptionExtractor(inception)    

    def forward(self, x):
        x,x2,x3,x4 = self.inception_extractor(x)
        out = torch.cat( (x,x2,x3,x4), dim=1)
        return out

class InceptionExtractor(nn.Module):
    def __init__(self,
                 inception):
        super(InceptionExtractor, self).__init__()
        self.stem = nn.Sequential(
            inception.features[0],
            inception.features[1],
            inception.features[2],
            inception.features[3],
            inception.features[4],
            inception.features[5],
        )   
        self.inception1 = inception.features[6]
        self.inception2 = inception.features[7]
        self.inception3 = inception.features[8]
        self.inception4 = inception.features[9]    
    def forward(self, x):
        x = self.stem(x)
        x = self.inception1(x)
        x2 = self.inception2(x)
        x3 = self.inception3(x2)
        x4 = self.inception4(x3)        
        return x,x2,x3,x4  
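torch.cat(..., dim=1) joins feature maps along the channel axis of an NCHW tensor. If the converted Keras graph works in channels_last (NHWC) layout (an assumption; it depends on how the converter builds the model), the equivalent concatenation is along the last axis. A NumPy sketch of that correspondence:

```python
import numpy as np

rng = np.random.default_rng(0)
# Two NCHW feature maps with different channel counts and the same spatial size.
a = rng.random((2, 3, 4, 4))
b = rng.random((2, 5, 4, 4))


def to_nhwc(t):
    # (N, C, H, W) -> (N, H, W, C)
    return t.transpose(0, 2, 3, 1)


# PyTorch-style: channel axis is dim=1 in NCHW.
cat_nchw = np.concatenate([a, b], axis=1)
# Keras-style: channel axis is the last one in NHWC.
cat_nhwc = np.concatenate([to_nhwc(a), to_nhwc(b)], axis=-1)

print(cat_nchw.shape, cat_nhwc.shape)               # (2, 8, 4, 4) (2, 4, 4, 8)
print(np.array_equal(to_nhwc(cat_nchw), cat_nhwc))  # True
```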

These are the parameters I use to instantiate the model with the above Inception v4 implementation:

inception4 = inceptionv4(num_classes=1000, pretrained='imagenet')
model = InceptionEncoder(inception4)  # InceptionEncoder takes only the backbone

So I assume that my hack for Inception architectures is buggy or needs to be fixed somehow. Any help or advice is appreciated.

gmalivenko commented 6 years ago

Hello, @snakers4.

I tested Inception v4 today. The conversion seems more or less accurate, but not all PyTorch parameters have a Keras/TF equivalent. For example, the AvgPool2d parameter count_include_pad may cause wrong results when padding is used.

I also tested your InceptionEncoder (average error ~2.91038e-10); you can check it out right there.
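The count_include_pad difference is easy to see in one dimension. PyTorch's AvgPool2d defaults to count_include_pad=True, so the zero padding is counted in the divisor; TensorFlow's 'same' average pooling, as far as I know, divides only by the number of real (non-padding) elements. The pure-Python avg_pool_1d below is a hypothetical helper written for illustration, not either library's implementation.

```python
def avg_pool_1d(x, kernel, stride, padding, count_include_pad):
    """1-D average pooling with symmetric zero padding."""
    padded = [0.0] * padding + list(x) + [0.0] * padding
    is_real = [False] * padding + [True] * len(x) + [False] * padding
    out = []
    for start in range(0, len(padded) - kernel + 1, stride):
        window = padded[start:start + kernel]
        n_real = sum(is_real[start:start + kernel])
        # PyTorch default divides by the full kernel; the other mode by real elements only.
        denom = kernel if count_include_pad else n_real
        out.append(sum(window) / denom)
    return out


x = [1.0, 1.0, 1.0, 1.0]
print(avg_pool_1d(x, kernel=3, stride=1, padding=1, count_include_pad=True))
# [0.6666666666666666, 1.0, 1.0, 0.6666666666666666]
print(avg_pool_1d(x, kernel=3, stride=1, padding=1, count_include_pad=False))
# [1.0, 1.0, 1.0, 1.0]
```

Only the windows that overlap the border differ, which is why such a mismatch can look like a small but systematic error rather than a complete failure.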

aveysov commented 6 years ago

Hi @nerox8664,

Many thanks for your effort and support! I pulled your repo again.

I ran the following tests:

So now I can reasonably assume that the Inception base works fine and that the encoder part is not the problem: only my full model produces worse results, and only after loading the trained weights.

Unfortunately, I cannot share the weights, but maybe you can give a hint based on the model architecture?

import torch
import torch.nn as nn


class IlgSimplifiedNormalized(nn.Module):
    def __init__(self,
                 inception,
                 num_classes=2,
                 num_skip1=256,
                 num_skip2=256,
                 num_skip3=256,
                 num_skip4=256,
                 num_filters=256,
                 num_fmap=7):

        super(IlgSimplifiedNormalized, self).__init__()

        self.inception_extractor = InceptionExtractor(inception)    
        self.inception_connectors = InceptionConnectors(num_skip1,
                                                         num_skip2,
                                                         num_skip3,
                                                         num_skip4,
                                                         num_filters,
                                                         num_fmap)
        self.classifier = nn.Sequential(
            nn.Linear(num_skip1+num_skip2+num_skip3+num_skip4, 1024),
            nn.ReLU(True),
            nn.Dropout(p=0.5),
            nn.Linear(1024, num_classes),
        )
    def forward(self, x):
        x,x2,x3,x4 = self.inception_extractor(x)
        out = self.inception_connectors(x,x2,x3,x4)
        out = self.classifier(out)          
        return out 

class InceptionConnectors(nn.Module):
    def __init__(self,
                num_skip1=256,
                num_skip2=256,
                num_skip3=256,
                num_skip4=256,
                num_filters=256,
                num_fmap=7):
        super(InceptionConnectors, self).__init__()

        self.inception_connector1 = InceptionConnector(num_skip=num_skip1,
                                                        num_filters=num_filters,
                                                        num_fmap=num_fmap)
        self.inception_connector2 = InceptionConnector(num_skip=num_skip2,
                                                        num_filters=num_filters,
                                                        num_fmap=num_fmap)        
        self.inception_connector3 = InceptionConnector(num_skip=num_skip3,
                                                        num_filters=num_filters,
                                                        num_fmap=num_fmap)
        self.inception_connector4 = InceptionConnector(num_skip=num_skip4,
                                                        num_filters=num_filters,
                                                        num_fmap=num_fmap)

    def forward(self,x,x2,x3,x4):
        x1_out = self.inception_connector1(x)
        x2_out = self.inception_connector2(x2)
        x3_out = self.inception_connector3(x3)
        x4_out = self.inception_connector4(x4)
        out = torch.cat((x1_out,x2_out,x3_out,x4_out), dim=1)
        return out

class InceptionConnector(nn.Module):
    def __init__(self,
                num_skip=256,
                num_filters=256,
                num_fmap=3):
        super(InceptionConnector, self).__init__()

        self.ae_block = nn.Sequential(
            nn.AvgPool2d(kernel_size=5, stride=3, padding=0),
            nn.Conv2d(384, num_filters, kernel_size=1, stride=1),
            nn.BatchNorm2d(num_filters),
            nn.ReLU(inplace=True)
        )
        self.fc = nn.Sequential(
            nn.Linear(num_filters*num_fmap*num_fmap, num_skip),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        x = self.ae_block(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x 

class InceptionExtractor(nn.Module):
    def __init__(self,
                 inception):
        super(InceptionExtractor, self).__init__()
        self.stem = nn.Sequential(
            inception.features[0],
            inception.features[1],
            inception.features[2],
            inception.features[3],
            inception.features[4],
            inception.features[5],
        )

        self.inception1 = inception.features[6]
        self.inception2 = inception.features[7]
        self.inception3 = inception.features[8]
        self.inception4 = inception.features[9]

    def forward(self, x):
        x = self.stem(x)
        x = self.inception1(x)
        x2 = self.inception2(x)
        x3 = self.inception3(x2)
        x4 = self.inception4(x3)        
        return x,x2,x3,x4    
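Two shape details of InceptionConnector are worth double-checking during conversion. First, num_fmap must equal the spatial size after AvgPool2d(kernel_size=5, stride=3); assuming the Inception v4 blocks used here output 35x35x384 maps for a 299x299 input (consistent with the hard-coded 384 input channels), that gives 11, the value passed as num_fmap=11 in the script further down. Second, x.view(x.size(0), -1) flattens in (C, H, W) order, whereas flattening channels_last (H, W, C) data orders the same features differently, so the first Linear layer's weights only line up if the conversion accounts for it. A NumPy sketch of both points (with a tiny tensor for the flatten part):

```python
import numpy as np


def pool_out(size, kernel, stride, padding=0):
    """Output length along one axis for torch.nn.AvgPool2d (ceil_mode=False)."""
    return (size + 2 * padding - kernel) // stride + 1


num_fmap = pool_out(35, kernel=5, stride=3)
fc_in = 256 * num_fmap * num_fmap  # num_filters * num_fmap * num_fmap
print(num_fmap, fc_in)  # 11 30976

# Flatten order on a tiny (N, C, H, W) = (1, 2, 2, 2) tensor.
x = np.arange(8).reshape(1, 2, 2, 2)
flat_nchw = x.reshape(1, -1)                        # what x.view(x.size(0), -1) does
flat_nhwc = x.transpose(0, 2, 3, 1).reshape(1, -1)  # flattening channels_last data
print(flat_nchw)  # [[0 1 2 3 4 5 6 7]]
print(flat_nhwc)  # [[0 4 1 5 2 6 3 7]]
```

If the two orders differ in the converted graph, the input columns of the first Linear weight would have to be permuted to match.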

P.S. Do you have some kind of account for tips? I would be more than happy to donate a bit as further incentive to maintain and develop this repository.

aveysov commented 6 years ago

Also, this may simply be because my model is a binary classifier, and on random data it is very confident in one class...

gmalivenko commented 6 years ago

Hello, @aveysov.

How are you using the IlgSimplifiedNormalized class? As far as I can see, it contains nn.Dropout:

        self.classifier = nn.Sequential(
            nn.Linear(num_skip1+num_skip2+num_skip3+num_skip4, 1024),
            nn.ReLU(True),
            nn.Dropout(p=0.5),
            nn.Linear(1024, num_classes),
        )

Maybe the high error is caused by nn.Dropout(p=0.5)? Do you call your_model.eval() before conversion?
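For reference, nn.Dropout implements inverted dropout: in training mode it zeroes each element with probability p and scales the survivors by 1 / (1 - p), and in eval mode it is the identity, so model.eval() should take it out of the comparison. A pure-Python sketch of that behaviour (the dropout function is illustrative, not PyTorch's code):

```python
import random


def dropout(values, p=0.5, training=True, rng=None):
    """Inverted dropout on a list of floats, mirroring nn.Dropout's semantics."""
    if not training or p == 0.0:
        return list(values)  # eval mode: identity
    if p >= 1.0:
        return [0.0] * len(values)  # everything is dropped
    rng = rng or random.Random()
    scale = 1.0 / (1.0 - p)
    return [0.0 if rng.random() < p else v * scale for v in values]


x = [1.0] * 8
print(dropout(x, training=True, rng=random.Random(0)))  # entries are 0.0 or 2.0
print(dropout(x, training=False))                       # identical to x
```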

snakers4 commented 6 years ago

How are you using the IlgSimplifiedNormalized class?

Please see the snippet below. Inception4Cadene is just the original Inception v4 code linked above, and IlgSimplifiedEnsemble is just the above model.

import keras  # work around segfault
import sys
import numpy as np
import torch.nn.functional as F

import torch
import torch.nn as nn
from torch.autograd import Variable

sys.path.append('../pytorch2keras')
from converter import pytorch_to_keras

# import my models the same way as in training loop 
from Inception4Cadene import inceptionv4
from IlgSimplifiedEnsemble import IlgSimplifiedNormalized

if __name__ == '__main__':
    max_error = 0
    for i in range(2):
        inception4 = inceptionv4(num_classes=1000, pretrained='imagenet')

        model = IlgSimplifiedNormalized(inception4,
                                        num_classes=2,
                                        num_skip1=256,
                                        num_skip2=256,
                                        num_skip3=256,
                                        num_skip4=512,
                                        num_filters=256,
                                        num_fmap=11)        

        saved_weights = '../some_folder/some_weights.pth.tar'

        print("=> loading checkpoint '{}'".format(saved_weights))
        checkpoint = torch.load(saved_weights)
        model.load_state_dict(checkpoint['state_dict'])
        print("=> loaded checkpoint (epoch {})"
              .format(checkpoint['epoch']))        

        model.eval()

        input_np = np.random.uniform(0, 1, (4, 3, 299, 299))
        input_var = Variable(torch.FloatTensor(input_np))
        output = model(input_var)

        k_model = pytorch_to_keras((3, 299, 299,), output)

        pytorch_output = output.data.numpy()
        keras_output = k_model.predict(input_np)

        # use the absolute difference so errors in either direction count
        error = np.max(np.abs(pytorch_output - keras_output))
        print(error)
        if max_error < error:
            max_error = error

    print('Max error: {0}'.format(max_error))
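When comparing the two outputs, it helps to use the absolute difference, since a signed maximum of pytorch_output - keras_output can hide errors where Keras overshoots PyTorch. A hypothetical compare_outputs helper that reports both the maximum and the mean absolute error (the mean is closer to the "average error" figure quoted earlier in the thread):

```python
import numpy as np


def compare_outputs(pytorch_output, keras_output):
    """Return (max_abs_error, mean_abs_error) between two output arrays."""
    diff = np.abs(np.asarray(pytorch_output) - np.asarray(keras_output))
    return float(diff.max()), float(diff.mean())


pt = np.array([[0.2, 0.3]])
ks = np.array([[0.7, 0.3]])
print(np.max(pt - ks))          # 0.0: the signed max hides the error
print(compare_outputs(pt, ks))  # max abs error of about 0.5, mean of about 0.25
```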

Do you call your_model.eval() before conversion?

Yes, of course I do; see the snippet above.

Maybe the high error is caused by nn.Dropout(p=0.5)?

Well, with binary classification and a random image that really may be the case, but as far as I know, dropout is disabled when you call .eval().