keras-team / keras-applications

Reference implementations of popular deep learning models.

EfficientNet Implementation #113

Closed: Callidior closed this pull request 4 years ago

Callidior commented 5 years ago

This pull request is a translation of the reference implementation of EfficientNet (ICML 2019) from TensorFlow to Keras.

There are two deviations from the description of EfficientNet in the paper and one deviation from the reference implementation:

Pre-trained weights have been converted from the TensorFlow checkpoints provided by the authors, which were trained with AutoAugment and perform better than reported in the paper. The converted weights achieve accuracies comparable to those stated by the authors (top-1 accuracy on the ImageNet validation set):

Model  Reported in Paper  Official Checkpoint  Converted Weights
B0     76.3%              77.3%                77.2%
B1     78.8%              79.2%                79.1%
B2     79.8%              80.3%                80.2%
B3     81.1%              81.7%                81.6%
B4     82.6%              83.0%                83.0%
B5     83.3%              83.7%                83.7%
B6     84.0%              84.2%                84.1%
B7     84.4%              84.5%                84.4%

The weights are currently hosted on my GitHub repository and will be downloaded automatically by the EfficientNet implementation. Upon merge, however, it would be reasonable to transfer them to the keras-team/keras-applications repository.


To Do

Callidior commented 5 years ago

To obtain the validation accuracies reported above, I employed the same pre-processing as the reference implementation, using the following code:

import numpy as np
import keras
from keras.preprocessing.image import load_img, img_to_array
from keras_applications.efficientnet import preprocess_input
import skimage.transform

def center_crop_and_resize(image, image_size, crop_padding=32):
    # Take a centered square crop whose side is
    # image_size / (image_size + crop_padding) of the shorter image side.
    h, w = image.shape[:2]
    padded_center_crop_size = int((image_size / (image_size + crop_padding)) * min(h, w))
    offset_height = ((h - padded_center_crop_size) + 1) // 2
    offset_width = ((w - padded_center_crop_size) + 1) // 2
    image_crop = image[offset_height:padded_center_crop_size + offset_height,
                       offset_width:padded_center_crop_size + offset_width]
    # Resize the crop to image_size x image_size with bicubic interpolation
    # (order=3), keeping the original pixel value range.
    return skimage.transform.resize(image_crop, (image_size, image_size),
                                    order=3,
                                    mode='reflect',
                                    anti_aliasing=True,
                                    preserve_range=True,
                                    clip=False)

def load_img_for_eval(filename, size=224):
    # Load, center-crop, resize, and normalize a single image
    # (returns an array without a batch dimension).
    return preprocess_input(center_crop_and_resize(img_to_array(load_img(filename)), size))
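The crop arithmetic in center_crop_and_resize can be pulled out into a small pure-Python helper, which makes it easy to check which region of an image is actually evaluated. This is only a sketch: center_crop_box is a hypothetical name, and it reproduces just the cropping step, not the bicubic resize or preprocess_input.

```python
def center_crop_box(h, w, image_size=224, crop_padding=32):
    """Return (top, left, size) of the square center crop.

    Hypothetical helper mirroring the arithmetic of center_crop_and_resize:
    the crop side is image_size / (image_size + crop_padding) of the
    shorter image side, and the crop is centered (rounding the offset up).
    """
    size = int((image_size / (image_size + crop_padding)) * min(h, w))
    top = ((h - size) + 1) // 2
    left = ((w - size) + 1) // 2
    return top, left, size


# Example: a 375x500 (height x width) image evaluated at 224 pixels.
top, left, size = center_crop_box(375, 500)
print(top, left, size)  # 24 86 328
```

With the defaults, the crop covers 224 / 256 = 87.5% of the shorter side before it is resized to 224x224.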
Callidior commented 4 years ago

The authors have released new checkpoints, trained with AutoAugment, for all models (including B6 and B7, which were previously missing). These models generally perform better than reported in the paper. I have converted the new checkpoints to Keras weight files and updated the implementation to use them, so pre-trained weights are now available for all EfficientNet variants.

wingman-jr-addon commented 4 years ago

As a user of Keras, I'm excited to see this get added to the official set of applications. It's not clear to me from this issue whether anything is holding it up from being merged. What else is left to do before it's ready?

Callidior commented 4 years ago

I guess it would just need one of the maintainers to pay attention to it, but nothing has happened on this repository since July.

wingman-jr-addon commented 4 years ago

I suspect it has something to do with a focus on TF 2.0. @taehoonlee it appears that you've been active in the community here. Do you happen to have any suggestions on the next steps for this PR?

taehoonlee commented 4 years ago

@Callidior, @wingman-jr-addon, sorry for the late response. Since you and many users have been waiting so long, I revised the PR as much as possible instead of reviewing it. The failures in Travis seem to be caused by the official release of TensorFlow 2.0, so I checked the functionality of the revised PR locally with TF 1.x. The original PR and the revised one produce the same inference results.

@Callidior, I wonder whether you have tested it with CNTK. Thank you for your great PR!

Callidior commented 4 years ago

@taehoonlee Thanks for your revision! I didn't test it with the CNTK backend explicitly, but trusted the positive results of Travis in this regard. The initial PR at least passed the CNTK tests in Travis.

taehoonlee commented 4 years ago

@Callidior, did you also test it with Theano? The initial PR passed Travis, as you mentioned, but it failed locally for me (numpy==1.16, theano==1.0 or 1.0.4). As you know, keras-applications should support all three backends.

taehoonlee commented 4 years ago

The error was related to pattern_broadcast in the Lambda layer.

Callidior commented 4 years ago

I've now tested it with all three backends and, along the way, fixed a warning related to the output_shape of the broadcasting layer, which is needed for Theano. As a sanity check, I classified the default tests/data/elephant.jpg image with EfficientNet-B2 and got African_elephant as the top prediction with all three backends, but with different confidence: 75.5% with the TensorFlow backend, but only 68.1% with both of the other backends. So far I have no idea where this difference comes from, especially since the code is exactly the same for TensorFlow and CNTK.

taehoonlee commented 4 years ago

@Callidior, thanks for the checks and the revision. The discrepancy across the three backends is caused by their different behavior for Conv2D(strides=2, padding='same'), and I have now resolved the issue. For more information, you can refer to my notes. My test code is:

# test_cat.py
import keras
import numpy as np

from tensornets.utils import load_img
from keras_applications.efficientnet import EfficientNetB0
from keras_applications.efficientnet import preprocess_input, decode_predictions

# keras_applications expects the Keras submodules to be passed in explicitly.
kwargs = {'backend': keras.backend, 'layers': keras.layers, 'models': keras.models, 'utils': keras.utils}

model = EfficientNetB0(weights='imagenet', **kwargs)
# Resize to 256 pixels, then take a 224x224 crop.
img = load_img('cat.png', target_size=256, crop_size=224)
preds = model.predict(preprocess_input(img, **kwargs))
print(decode_predictions(preds, top=3, **kwargs)[0])

The results are:

$ python test_cat.py
$ CUDA_VISIBLE_DEVICES= KERAS_BACKEND=theano python test_cat.py
$ CUDA_VISIBLE_DEVICES= KERAS_BACKEND=cntk python test_cat.py
# before
[('n02124075', 'Egyptian_cat', 0.25851214), ('n02123045', 'tabby', 0.06481709), ('n02127052', 'lynx', 0.057544794)]
[('n02124075', 'Egyptian_cat', 0.23895964), ('n02123394', 'Persian_cat', 0.121028826), ('n02123159', 'tiger_cat', 0.0978675)]
[('n02124075', 'Egyptian_cat', 0.2389593), ('n02123394', 'Persian_cat', 0.12102842), ('n02123159', 'tiger_cat', 0.09786791)]

# after
[('n02124075', 'Egyptian_cat', 0.25851214), ('n02123045', 'tabby', 0.06481709), ('n02127052', 'lynx', 0.057544794)]
[('n02124075', 'Egyptian_cat', 0.25851145), ('n02123045', 'tabby', 0.064817175), ('n02127052', 'lynx', 0.057544947)]
[('n02124075', 'Egyptian_cat', 0.2585115), ('n02123045', 'tabby', 0.06481718), ('n02127052', 'lynx', 0.057544786)]
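The notes referenced above are not reproduced here; the following is only a sketch of why Conv2D(strides=2, padding='same') can disagree across backends. TensorFlow computes 'SAME' padding with the formula below and places any odd extra pixel at the bottom/right. If another backend padded the same layer symmetrically (an assumption for illustration; symmetric_padding is a hypothetical helper, not taken from the Theano or CNTK sources), a stride-2 convolution on even-sized input would sample a grid shifted by one pixel: the output shape stays the same, but the values change.

```python
import math


def tf_same_padding(in_size, kernel_size, stride):
    """(pad_before, pad_after) along one axis for TensorFlow padding='SAME'."""
    out_size = math.ceil(in_size / stride)
    pad_total = max((out_size - 1) * stride + kernel_size - in_size, 0)
    pad_before = pad_total // 2         # smaller half goes first (top/left)
    pad_after = pad_total - pad_before  # any odd extra pixel goes last (bottom/right)
    return pad_before, pad_after


def symmetric_padding(kernel_size):
    """Hypothetical contrast: kernel_size // 2 pixels on both sides."""
    half = kernel_size // 2
    return half, half


# Stride-2 layers on even-sized inputs: TensorFlow pads asymmetrically.
for in_size, k in [(224, 3), (112, 5)]:
    print(in_size, k, tf_same_padding(in_size, k, 2), symmetric_padding(k))
# 224 3 (0, 1) (1, 1)
# 112 5 (1, 2) (2, 2)
```

Under these assumptions, one backend-independent approach is to apply the TensorFlow-style padding explicitly, for example with a ZeroPadding2D layer, which accepts ((top, bottom), (left, right)), and then use padding='valid' in the convolution.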
wingman-jr-addon commented 4 years ago

(Thanks @taehoonlee and @Callidior - I've been excited to see the progress!)

Callidior commented 4 years ago

Thanks a lot, @taehoonlee! I can confirm from my side that everything is now working as it should.

taehoonlee commented 4 years ago

@fchollet, a new model, EfficientNet, is almost ready. I wonder whether you have any concerns (e.g., compatibility with recent TF versions) about adding a new model. As you may know, it was published at ICML 2019 and is the current state of the art. I think EfficientNet is worth adding.