pytorch / vision

Datasets, Transforms and Models specific to Computer Vision
https://pytorch.org/vision
BSD 3-Clause "New" or "Revised" License

mnasnet training schedule #1163

Open a-maci opened 5 years ago

a-maci commented 5 years ago

What is the training schedule for MNASNet 1.0 that gives 73.5 top-1? I tried the MobileNet schedule, but it does not reach the accuracy listed on the torchvision page.

On a side note, the MNASNet paper reports 74% top-1, but the torchvision model is 73.5% top-1.

fmassa commented 5 years ago

All the discussion regarding MNASNet training can be found in https://github.com/pytorch/vision/pull/829; it was contributed by @1e100.

Dmitry, can you give some more details on the final set of hyperparameters you used to train MNASNet?

1e100 commented 5 years ago

Something along the lines of https://github.com/1e100/mnasnet_trainer/blob/master/train.py. To sustain this batch size you will need a quad-GPU machine. This training code was extracted from my own private experimental training pipeline. There's a validator there as well; you can run the model against ImageNet and see what accuracy you get. It won't be exactly as quoted, but it should be close.

Note that there are a couple of bugs in the current implementation of MNASNet:

  1. There's not supposed to be a ReLU after the first conv. This one affects all models, but, curiously, seems to make them a little easier to train. Maybe we should keep it. :-)
  2. The layers before the repeated blocks are supposed to scale in accordance with the width multiplier. This does not affect the case where width_multiplier == 1.0.
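The two points above can be sketched as follows. The helper names and the channel-rounding rule here are illustrative, not the actual torchvision implementation:

```python
# Sketch of the two points discussed: (1) whether a ReLU follows the
# first conv, and (2) scaling the stem width by the multiplier.
# Names and the rounding-to-multiple-of-8 rule are assumptions.
import torch
import torch.nn as nn

def _scale(depth: int, multiplier: float) -> int:
    """Scale a channel count by the width multiplier, rounding to a multiple of 8."""
    return max(8, int(depth * multiplier + 0.5) // 8 * 8)

def stem(multiplier: float, relu_after_first_conv: bool) -> nn.Sequential:
    # Point 2: the stem width should scale with the multiplier, so it is
    # 32 channels only when multiplier == 1.0.
    out = _scale(32, multiplier)
    layers = [nn.Conv2d(3, out, 3, stride=2, padding=1, bias=False),
              nn.BatchNorm2d(out)]
    # Point 1: whether a ReLU belongs here was the question; per the
    # later comments, the reference TF implementation does include one.
    if relu_after_first_conv:
        layers.append(nn.ReLU(inplace=True))
    return nn.Sequential(*layers)
```

With `multiplier=0.5` the stem produces 16 channels instead of 32, which is exactly the case the second bug affected.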

I have a patch that fixes both of these and I'm currently training the checkpoints in order to be able to submit a PR. Progress is slow, I'm not quite there yet. I'm currently exploring the tricks @rwightman has used to get his results. I'm focusing on 0.5 which was the hardest one to train last time. I figure if that one works, 1.0 will be a piece of cake. At the moment I got 0.5 up to 67.1% top1.

Speaking of which, @fmassa, if you guys are going to cut a release anytime soon, it might be best to either revert my PR (and then re-apply it with new checkpoints) or take my fix (removing the checkpoints and adding them back when they are available). I have carefully compared Google's official models with my own in Netron side by side; they match exactly now.

Let me know which you prefer.

EDIT: as pointed out by @rwightman, #1 on that list is not a bug. It's the same as the reference TFLite implementation.

rwightman commented 5 years ago

@1e100 where did the idea that there is no relu after first conv come from? It's right here: https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mnasnet_model.py#L433

The stem and first depthwise separable are supposed to scale with multiplier as you say though.

1e100 commented 5 years ago

They must have noted the same thing I did and added it. If you look at their official checkpoints at https://www.tensorflow.org/lite/guide/hosted_models, they don't have a ReLU there. They look like this:

[image "no_relu": graph screenshot showing no ReLU after the first conv's batchnorm]

which tripped me up, because I'm used to seeing an activation after every batchnorm.

Now that I think about it, maybe their checkpoints are buggy then, and my model was "more correct". Thanks for pointing this out!

rwightman commented 5 years ago

I think the graph optimization performed when exporting to TFLite usually involves folding the batch norm and activation together with the conv/bias. So it's likely in there...

I haven't looked at the original mnasnet checkpoints in a while, but when I did my original impl I verified I could load the TPU Tensorflow checkpoints in my PyTorch A1 & B1 models with reasonable accuracy.
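The folding mentioned above can be sketched generically: after fusing, the batchnorm no longer appears as a separate op in the exported graph, which would explain its absence in the TFLite screenshots. This is a standard conv + BN fusion, not tied to any particular exporter:

```python
# Generic conv + batchnorm folding: fold the BN's affine transform
# (using its running statistics) into the conv's weight and bias.
import torch
import torch.nn as nn

@torch.no_grad()
def fold_bn(conv: nn.Conv2d, bn: nn.BatchNorm2d) -> nn.Conv2d:
    fused = nn.Conv2d(conv.in_channels, conv.out_channels,
                      conv.kernel_size, conv.stride, conv.padding,
                      bias=True)
    # BN in eval mode computes y = (x - mean) * gamma / sqrt(var + eps) + beta;
    # apply that per-output-channel scale and shift to the conv parameters.
    scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)
    fused.weight.copy_(conv.weight * scale.reshape(-1, 1, 1, 1))
    bias = conv.bias if conv.bias is not None else torch.zeros(conv.out_channels)
    fused.bias.copy_((bias - bn.running_mean) * scale + bn.bias)
    return fused
```

The fused conv produces the same output as conv followed by BN in eval mode, up to floating-point rounding.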

1e100 commented 5 years ago

This is an unoptimized TF graph though.



1e100 commented 5 years ago

#1224 contains the fix for the MNASNet models other than 1.0. 1.0 remains the same, so the old checkpoint still works. I'm currently working on a better checkpoint for 1.0. The PR contains a slightly better checkpoint for 0.5, though. Pity it was too late for the 0.4 release train, but better late than never.