huggingface / pytorch-image-models

The largest collection of PyTorch image encoders / backbones, including train, eval, inference, and export scripts, and pretrained weights -- ResNet, ResNeXT, EfficientNet, NFNet, Vision Transformer (ViT), MobileNetV4, MobileNet-V3 & V2, RegNet, DPN, CSPNet, Swin Transformer, MaxViT, CoAtNet, ConvNeXt, and more.
https://huggingface.co/docs/timm
Apache License 2.0

[FEATURE] TrivialAugment Integration #780

Closed: SamuelGabriel closed this issue 1 year ago

SamuelGabriel commented 3 years ago

Is your feature request related to a problem? Please describe. So far the repo integrates AutoAugment and RandAugment, as well as some settings that seem to be inspired by UniformAugment and/or TrivialAugment (see https://github.com/rwightman/pytorch-image-models/commit/79640fcc1f72241431b534420f2f7c9868157e93). I am the author of TrivialAugment (https://github.com/automl/trivialaugment), which was just accepted as an oral at ICCV, and I used your repository for my latest experiments on EfficientNet-B1, where, as in my previous experiments, I saw improvements over the more complicated methods. The current augmentation pipeline has a lot of choices, which makes it difficult to get a training run going, since these have to be set first. TrivialAugment, at least in my comparison against my own RandAugment implementation, outperforms RA on EffNet-B1 without any hyper-parameters. Of course, this is me pushing my own baseline for this repo (:D), but I would also be totally up for providing some not-too-expensive experiments to show whether integrating TA actually makes sense.

Describe the solution you'd like. I would like to stick to your general augmentation setup and use what is there with the TrivialAugment algorithm. This would require only very few changed lines, as it is almost implemented already. Alternatively, one could integrate the auglib from https://github.com/automl/trivialaugment, which is what I have tested so far.

Describe alternatives you've considered. RA and AA have hyper-parameters and performed worse in my experiments.

What do you think?
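The TrivialAugment step described in this post can be sketched as follows. This is a minimal illustration, not code from the auglib or from timm: `trivial_augment`, the `(image, magnitude)` op signature, and the toy ops are hypothetical. The default of 31 strength bins is an assumption based on the wide augmentation space described in the TrivialAugment paper.

```python
import random
from typing import Callable, Optional, Sequence

# Hypothetical op signature: (image, magnitude in [0, 1]) -> new image.
Op = Callable[[object, float], object]


def trivial_augment(img: object, ops: Sequence[Op], num_bins: int = 31,
                    rng: Optional[random.Random] = None) -> object:
    """Sketch of one TrivialAugment step: a single op chosen uniformly at
    random, applied at a strength drawn uniformly from num_bins levels.
    There is no N (number of ops) or M (magnitude) to tune."""
    rng = rng or random.Random()
    op = rng.choice(ops)                # uniform over the op set
    level = rng.randrange(num_bins)     # uniform over 0 .. num_bins - 1
    magnitude = level / (num_bins - 1)  # map the level to [0, 1]
    return op(img, magnitude)


# Toy ops on a float "image" so the sketch runs without an image library.
toy_ops = [
    lambda x, m: x,              # identity
    lambda x, m: x + m,          # stand-in for a brightness shift
    lambda x, m: x * (1.0 - m),  # stand-in for a contrast reduction
]
rng = random.Random(0)
augmented = [trivial_augment(1.0, toy_ops, rng=rng) for _ in range(5)]
print(augmented)
```

In practice the ops would be image transforms (PIL or tensor based); torchvision also provides a ready-made version as `torchvision.transforms.TrivialAugmentWide`.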

mrT23 commented 3 years ago

@SamuelGabriel I got to know your article through this issue. That's a very nice paper you have :-) I always thought that the "auto" part of auto-augmentation algorithms is unreliable, to say the least. The main driver of improvement there was the prior (human) selection of the candidate augmentations.

I am happy you are getting an oral at ICCV. I hope your work will also inspire more scrutiny of other "auto" algorithms, which rely more on human tuning and knowledge than on any real contribution from the automatic parts.

rwightman commented 3 years ago

@SamuelGabriel thanks for bringing TrivialAugment to my attention; I took a quick look through the code and paper. There is a lot of overlap with RandAugment, especially with some of the 'fixes' and modifications I've made to my impl here. I use gaussian noise w/ a specified std dev for the magnitude in almost all of my RA training. That's been implemented for close to 2 years now. Someone else added the PR for uniform sampling more recently.

Is sampling uniformly over a min/max range for the augmentations the key difference for TA? Are there any other key differences I'm missing?
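The magnitude-sampling variants discussed here (a fixed M, Gaussian noise with a given std dev, and a uniform range) can be contrasted in a small sketch. This is not timm's implementation; `sample_magnitude`, the 0..10 magnitude scale, and the convention that `mstd=inf` selects uniform sampling are assumptions chosen for illustration.

```python
import random
from typing import Optional

MAX_LEVEL = 10.0  # assumption: RA-style magnitude scale of 0..10


def sample_magnitude(m: float, mstd: float = 0.0,
                     rng: Optional[random.Random] = None) -> float:
    """Hypothetical per-op magnitude sampling.

    mstd == 0   -> fixed magnitude m (original RandAugment behaviour)
    mstd > 0    -> Gaussian noise around m with std dev mstd
    mstd == inf -> uniform over [0, m]
    """
    rng = rng or random.Random()
    if mstd == float("inf"):
        mag = rng.uniform(0.0, m)
    elif mstd > 0:
        mag = rng.gauss(m, mstd)
    else:
        mag = m
    # Clip so noisy samples stay on the valid scale.
    return max(0.0, min(mag, MAX_LEVEL))


rng = random.Random(0)
print(sample_magnitude(9, 0.0, rng))           # fixed
print(sample_magnitude(9, 1.0, rng))           # Gaussian around 9
print(sample_magnitude(10, float("inf"), rng)) # uniform over the full scale
```

Calling `sample_magnitude(MAX_LEVEL, float("inf"))` draws uniformly over the whole scale, which is the TA-style behaviour of having no magnitude to choose.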

I'm open to examples of TA being better and/or worth considering over my usual RA pipeline. If it can beat training w/ (--aa rand-m9-mstd1.0-inc1 --reprob 0.4 --remode pixel) or similar for ResNet-50 and EfficientNet-B0/B1-level models, I'd consider adding TA. I agree that AA was a bit silly, but RA needs fairly little tuning and works quite well. There were numerous 'bugs' in the original TF impls, though, that made adjusting the magnitudes problematic (i.e. no consistency across the augs when moving from one M value to another w/ the official impl).
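To make the fields in `--aa rand-m9-mstd1.0-inc1` concrete, here is a sketch of how such a config string decomposes. It is a hypothetical parser written for this illustration, not timm's own. The field meanings are assumptions drawn from this thread and from reading timm: m is the magnitude, n the number of ops per image, mstd the magnitude std dev, and inc a flag assumed to select increasing-severity variants of the ops. The default of 2 ops reflects the "usual" layer count mentioned in this thread.

```python
import re
from typing import Dict


def parse_rand_config(config: str) -> Dict[str, float]:
    """Hypothetical parser: 'rand-m9-mstd1.0-inc1' -> {'n': 2, 'm': 9.0, ...}."""
    name, *fields = config.split("-")
    if name != "rand":
        raise ValueError(f"expected a config starting with 'rand', got {name!r}")
    params: Dict[str, float] = {"n": 2}  # assumed default: 2 ops per image
    for field in fields:
        # Each field is a lowercase key immediately followed by a number.
        match = re.fullmatch(r"([a-z]+)(\d+(?:\.\d+)?)", field)
        if match is None:
            raise ValueError(f"cannot parse field {field!r}")
        key, value = match.groups()
        params[key] = float(value)
    return params


print(parse_rand_config("rand-m9-mstd1.0-inc1"))
```

A string carrying both `m6` and `n5` would yield a magnitude of 6 with 5 ops per image, the m and n values discussed later in this thread.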

I'd point out that, compared to papers, I've trained ResNet50 (vanilla) to 79.5 and ResNet50-D to 80.5 @ 224x224 w/ my RA pipeline, and those scale really well with higher test res. That said, I can train a vanilla R50 to 78 w/o any augs besides standard 'inception' preprocessing and some extra reg. So, as I see it so far, the results for TA in the paper are somewhat 'in the noise'.

For EfficientNet, I've not actually bothered to train the B1 'well'; it uses an old set of hparams, but my B3 and more recent V2 models are good. So if you can match those, I'd be interested. You'll note my B3 is above the official one w/ AdvProp; it can't hit NoisyStudent B3 levels, but that's expected.

SamuelGabriel commented 3 years ago

Hi @rwightman, thanks for the long answer. I saw your fixes in the pipeline when I used it to run EffNet trainings and found it interesting that you came up with similar fixes. I think the usefulness of TA could be that it simplifies things, as there are no important augmentation hyper-parameters anymore.

I would like to try TrivialAugment in your RN-50 setup, as you suggested. I have never tried the 'fixing the ImageNet resolution' tricks, but I am open to them. Do you have a full specification of the arguments somewhere, so that I can change only the relevant things, or are you referring, as a first step, to the one from the training examples (https://rwightman.github.io/pytorch-image-models/training_hparam_examples/)?

I am not sure about the cost of a B3 training, but I can try. Since you changed a lot of things mid-training in the example runs, I am not sure which hyper-parameters you prefer here. I'd be glad if you could share some with me; I can also run a baseline myself if these trainings are not too expensive.

rwightman commented 3 years ago

@SamuelGabriel I don't believe you've shown that TA is 'hyper-parameter free'; the experiments in the paper max out at ResNet-50 on ImageNet, with fairly minor improvements.

Being able to adjust the magnitude and # of layers (and, in timm's RA impl, the random sampling of the magnitude) is important as you train models of different sizes on different datasets. I don't believe you've shown that TA would actually work well in the cases where you need to increase M and N for RA. Try it on a 200-300 layer ResNet, a really large EfficientNet, or a vision transformer. NFNets are recent networks with a fairly specific optimal hparam range that relies on RA w/ more aug layers than usual (4 instead of 2). I'd recommend trying TA as is on larger models.

SamuelGabriel commented 3 years ago

Hi @rwightman, thanks for the second reply as well.

A quick update on the requested experiments:

I am currently training a big model: EfficientNet-V2-M with hyper-parameters inspired by your https://gist.github.com/rwightman/e69d5f456047c16773a77182cea68c3c (I just changed the lr with linear scaling for my batch size). I am not sure if that is what you meant; the RA baseline uses m=6 layers. I can also add a training run with the NFNet setting from that gist. I already trained an R50 with the setup mentioned in https://rwightman.github.io/pytorch-image-models/training_hparam_examples/, where I just replaced RA with TA. Here, I get 79.093 final accuracy, slightly better than the RA model in your documentation, but likely within noise. I haven't run RA in the setting I used yet, though. The difference: I used a batch size of 2048 with a linearly scaled lr, while you used 128.

rwightman commented 3 years ago

@SamuelGabriel yes, something like the V2-M using the same hparams would be an interesting data point.

m=6 is magnitude 6, the # layers is n which is 5 there.
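The distinction between m and n can be illustrated with a small RandAugment sketch: n ops are drawn per image and all of them use the same magnitude m, normalized from a 0..10 scale. This is a hypothetical sketch, not timm's implementation; the function name, the 0..10 scale, and drawing the ops with replacement are assumptions, and timm's version additionally perturbs m per op when `mstd` is set.

```python
import random
from typing import Callable, List, Optional, Sequence

# Hypothetical op signature: (image, magnitude in [0, 1]) -> new image.
Op = Callable[[object, float], object]


def rand_augment(img: object, ops: Sequence[Op], n: int, m: float,
                 max_level: float = 10.0,
                 rng: Optional[random.Random] = None) -> object:
    """Sketch of one RandAugment step: n ops drawn uniformly (with
    replacement), each applied at the same magnitude m / max_level."""
    rng = rng or random.Random()
    magnitude = m / max_level
    for op in rng.choices(ops, k=n):
        img = op(img, magnitude)
    return img


def record_op(trace: List[float], mag: float) -> List[float]:
    """Toy op that records the magnitude it was called with."""
    return trace + [mag]


# With n=5 and m=6, five ops each run at 6 / 10 = 0.6.
print(rand_augment([], [record_op], n=5, m=6))
```

Setting n=1 and drawing the magnitude uniformly per image gives a TA-style step, which is one way to see the overlap between the two methods discussed in this thread.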

The RA resnet50 model that's the current default in timm was trained with RA + consistency loss (modified AugMix). It was a good result at the time, but I've eclipsed it with pure RA + other regularization and no consistency loss. Currently I get 79.4-79.5 with rmsprop-based hparams evolved from some of the older ones posted, and 79.7-79.8 with something closer to the AGC + SGD you see in the NFNet and EffNetV2 gists...

SamuelGabriel commented 3 years ago

@rwightman OK, good! Oh yes, of course! I should know that :D The first runs of V2-M are done. My config diff for the RA version on 32 workers is:

```diff
< batch_size: 256
---
> batch_size: 16
> lr: 0.032
28c29
< epochs: 656
---
> epochs: 350
39d39
< lr: 0.5
52c52
< model: efficientnet_v2m
---
> model: efficientnetv2_rw_m
```

The epochs are 350 in the paper, and the model used in the gist was not available in timm. Besides that, I only scaled the learning rate. I suppose these changes should be OK, no? I got scores well below the paper's, though. Still, TA outperforms the RA setup out of the box.
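The linear scaling rule mentioned here keeps the learning rate proportional to the batch size. A minimal sketch, assuming the two `batch_size` values in the config diff are per-process and that both runs use the same number of processes: 0.5 × 16 / 256 = 0.03125, close to the 0.032 used. If the gist ran on a different number of processes, the global batch sizes, and with them the scaling ratio, would differ. `linear_scaled_lr` is a hypothetical helper name.

```python
def linear_scaled_lr(base_lr: float, base_batch: int, new_batch: int) -> float:
    """Linear scaling rule: keep the ratio lr / batch_size constant."""
    return base_lr * new_batch / base_batch


# lr 0.5 at batch_size 256, rescaled to batch_size 16 (values from the diff).
print(linear_scaled_lr(0.5, 256, 16))  # 0.03125
```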

| Aug method | Final top-1 acc. (as in summary.yml) |
|---|---|
| RA from the gist | 81.61 |
| TA | 81.78 |

The difference is larger than the variance I see in the accuracies of the last epochs, but do you have any idea why the performance is so sub-par? Another thing I noticed is that the model actually seemed to train for 360 epochs (according to summary.yml), even though the config was set differently. (Sorry to bother you with this...)
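One possible explanation for 360 vs 350 epochs, offered as an assumption and not confirmed in this thread: timm's training script of that period appended a cooldown phase, controlled by `--cooldown-epochs` (default 10), after the main cosine schedule, so the logged epoch count would be `epochs + cooldown_epochs`. If that holds, setting `--cooldown-epochs 0` (or lowering `epochs` by the cooldown length) would make the totals match. `reported_epochs` is a hypothetical helper.

```python
def reported_epochs(epochs: int, cooldown_epochs: int = 10) -> int:
    """Assumed total: main schedule epochs plus a cooldown phase appended
    at the end (default 10, mirroring timm's --cooldown-epochs default)."""
    return epochs + cooldown_epochs


print(reported_epochs(350))  # 360, matching the count seen in summary.yml
```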