
Add EfficientNet to the objax.zoo #60

Open · david-berthelot opened this issue 3 years ago

david-berthelot commented 3 years ago

EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks https://arxiv.org/abs/1905.11946

sathish-a commented 3 years ago

Hi David, can I work on this too?

david-berthelot commented 3 years ago

Sure, I assigned it to you.

sathish-a commented 3 years ago

Hi David, I have a question about the kernel initializer (VarianceScaling) used in the Keras implementation of EfficientNet. Keras has a fan_out mode where it computes the product of the shape only up to shape[:-2], whereas Objax's default is fan_in mode, taking the product up to shape[:-1]. Should I use the existing Objax function, or do I need to implement a new one? I'm a bit confused here.
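For concreteness, the two conventions being compared boil down to something like this (an illustration only, not the exact Keras/Objax internals):

```python
import numpy as np

# Illustration only: an HWIO conv kernel, e.g. 3x3 kernel, 32 input channels, 64 output channels.
shape = (3, 3, 32, 64)

# Objax-style fan_in: product over all axes except the last one.
fan_in = int(np.prod(shape[:-1]))           # 3 * 3 * 32 = 288

# Keras VarianceScaling-style fan_out: receptive field size times output channels.
receptive_field = int(np.prod(shape[:-2]))  # 3 * 3 = 9
fan_out = receptive_field * shape[-1]       # 9 * 64 = 576
```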

rwightman commented 3 years ago

@sathish-a I implemented EfficientNets in PyTorch and the weight init is important, as it alters the behaviour when training from scratch. In my original attempt I messed up the fan in/out calc due to discrepancies in the group (depthwise) packing between TF and PyTorch. I'd try to match the original TPU TensorFlow impl.

It should be noted that the original TPU TF impl (https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet) doesn't use a truncated normal, just a normal. That will likely have less impact than messing up the fan_in/out calc, especially for the depthwise convs where the fan_out is calculated as 1.
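As a rough sketch of that point, a fan_out-based (non-truncated) normal initializer for Objax might look something like the following; the fan_out formula here is one reading of the TPU reference and should be double-checked against it, especially for depthwise kernels:

```python
import numpy as np
import objax
from objax.typing import JaxArray

def fan_out_normal(shape, gain: float = 2.0) -> JaxArray:
    # HWIO kernel assumed: fan_out = kernel_h * kernel_w * out_channels.
    # For a depthwise kernel the last axis is the channel multiplier rather than
    # the true output channel count, which is where the depthwise gotcha comes from.
    fan_out = int(np.prod(shape[:2])) * shape[-1]
    return objax.random.normal(shape, stddev=np.sqrt(gain / fan_out))
```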

david-berthelot commented 3 years ago

@sathish-a You can make your own initializer:

```python
import numpy as np
import objax
from typing import Tuple
from objax.typing import JaxArray

def my_initializer(shape: Tuple[int, ...]) -> JaxArray:
    return objax.nn.init.truncated_normal(shape, stddev=np.sqrt(2 / shape[-1]))
```
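A minimal usage sketch, assuming the layer's w_init argument (as in objax.nn.Conv2D) accepts such a callable:

```python
conv = objax.nn.Conv2D(nin=32, nout=64, k=3, w_init=my_initializer)
```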

Eventually (maybe in a separate future PR), if there's a need for a fan_out initializer, we could add it to objax.nn.init and design it similarly to kaiming_truncated_normal or kaiming_normal.

For now, I wouldn't worry about the elegance of the design; you could technically put everything in efficient_net.py and we can refine it during the review process.

sathish-a commented 3 years ago

@rwightman and @david-berthelot Thanks for your valuable inputs. Which one should I follow, the original TPU implementation or the Keras application? I felt the latter was much clearer.

rwightman commented 3 years ago

The Keras version is easier to follow, but it does deviate slightly from the original, so you should be aware of that. Some notes by the author here: https://github.com/keras-team/keras-applications/pull/113 ... Also, the stride=2 padding is handled with an explicit pad layer + 'valid' instead of using 'same'; not quite sure why, as I figured Keras 'same' padding should match TF.

Also, the Keras impl does not have defs for the B8, L2, lite, and edge-tpu models, and it only features the auto-augment weight sets, not the better noisy-student and advprop weights.

david-berthelot commented 3 years ago

Based on @rwightman's feedback, I would be more inclined to use TensorFlow as the true reference.

sathish-a commented 3 years ago

I started working on porting the TF implementation. The TF repo exposes many options for users to customize the model into different variants, and I often get lost finding the right set of options for the standard EfficientNet variant. Do I need to support all this customization as well? I'd appreciate some input on how to approach this.

rwightman commented 3 years ago

@sathish-a there is a lot of customization in their MBConv block, and it's hard to follow. I'd recommend referencing the Keras model structurally and the official TensorFlow TPU impl for the details, like initialization, padding, etc. I assume for Objax you'd want something that ends up looking more like the Keras or PyTorch variants than the original TF1-style code.

My PyTorch models also match the TensorFlow TPU impl down to weight compatibility, with roughly .1-2% top-1 differences, but remain OO and easier to follow. I factored out CondConv and EdgeTpu into different block variants. I covered all the major variants with the exception of the newer TPU-optimized versions that have fused_conv and space2depth options enabled (called EfficientNet-X); I don't think weights for those have been released yet.

david-berthelot commented 3 years ago

One thing to consider is being compatible with an implementation that offers pre-trained weights we could later import into Objax. So using either Keras or PyTorch works; I'm not sure which I would pick.

sathish-a commented 3 years ago

Hi @david-berthelot, I'm almost done with the implementation. Many thanks to @rwightman; his PyTorch implementation helped me a lot. Can you guide me on porting model weights from PyTorch / Keras? Sorry for the delay, I'm working on this in my free time.

david-berthelot commented 3 years ago

Hi @sathish-a, thanks for the update. @kihyuks has ported some models from PyTorch and Keras to Objax. Here's some code he contributed to automatically import pre-trained models from Keras (load_pretrained_weights_from_keras): https://github.com/google/objax/blob/master/objax/zoo/resnet_v2.py#L428

Basically the methodology would be similar for PyTorch (reshape variables if needed, map the variable names).
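A minimal sketch of that methodology, with a hypothetical name_map and .npz layout (this is not the actual load_pretrained_weights_from_keras code):

```python
import numpy as np
import jax.numpy as jn
import objax

def load_weights_from_npz(model: objax.Module, npz_path: str, name_map: dict):
    """Assign converted weights to an Objax model. `name_map` maps Objax variable
    names to keys in the .npz archive; both the map and the file are hypothetical."""
    weights = np.load(npz_path)
    for objax_name, var in model.vars().items():
        w = weights[name_map[objax_name]]
        # PyTorch conv kernels are stored as (O, I, H, W); Objax expects (H, W, I, O).
        if w.ndim == 4 and w.shape != var.value.shape:
            w = w.transpose((2, 3, 1, 0))
        assert w.shape == var.value.shape, f'shape mismatch for {objax_name}'
        var.assign(jn.asarray(w))
```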

@rwightman also mentioned he was working on porting some models between PyTorch and Objax but I don't have specifics.

rwightman commented 3 years ago

@david-berthelot @sathish-a I had been working in fits and starts on some JAX models. It does overlap somewhat with this, but it has a different focus and I'm not sure exactly at this stage where my models will end up.

I've implemented all of the models in the same family as EfficientNet https://github.com/rwightman/efficientnet-jax

For a canonical Objax model to include here, @sathish-a can likely save time by leveraging the weights I've ported, as it was a big pain to bring them over from TensorFlow to PyTorch (lots of inconsistent layer name indexing in typical TF fashion). Converting my PyTorch weights to a generic npz format (not unlike the npz format used here) was much quicker.

If you do use the weights / conversion, a mention / link would be appreciated; quite a bit of time has gone into getting those weights and models to this point. Based on past experience, there will likely be quite a few gotchas going from 'I think the models are done' to they ARE done and work with the original weights at the original accuracies.

sathish-a commented 3 years ago

@david-berthelot, @rwightman's efficientnet-jax repository looks pretty interesting and well built. My version of the model is more aligned with @rwightman's PyTorch implementation and only supports the basic variant of EfficientNet. Should I continue working on this? Let me know what you think.

rwightman commented 3 years ago

@sathish-a I built my models with my own aims in mind, mentioned in that repo. There is value in having a concise Objax version that is part of this code base and maintained in sync. My version is not that: it has wider coverage of models beyond the EfficientNet variants, a more complex builder to achieve that, multiple files, variations that cover PyTorch and TF weight origins, and support for a Flax Linen backend in addition to a (tweaked) Objax one.

If you are close to finishing the models, the weights I've ported should accelerate your effort, as I've spent lots of time in the past getting from 'model done' to porting the weights and making the results match. Time is often wasted sorting out differences in padding, batchnorm epsilons, mixed-up activation types, default image interpolations, etc.
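One way to catch those gotchas early is a numerical sanity check against a reference implementation. A sketch, under the assumption that reference logits have been saved and that the ported model takes a training flag:

```python
import numpy as np
import jax.numpy as jn
import objax

def check_against_reference(model: objax.Module, images: np.ndarray,
                            reference_logits: np.ndarray, tol: float = 1e-4):
    """Compare the ported model's outputs with logits saved from the reference impl."""
    predict = objax.Jit(lambda x: model(x, training=False), model.vars())
    logits = predict(jn.asarray(images))
    max_abs_diff = float(jn.max(jn.abs(logits - jn.asarray(reference_logits))))
    print(f'max abs diff vs reference: {max_abs_diff:.6f}')
    assert max_abs_diff < tol, 'outputs diverge from the reference implementation'
```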

david-berthelot commented 3 years ago

The only constraints from my perspective are:

  1. It should be functionally an EfficientNet
  2. Its performance should match that of existing implementations so people can use it for experiments in their publications, and pre-trained model performance should match too.
  3. Keep the code simple

That's all, you have absolute freedom within those constraints.