huggingface / pytorch-image-models

The largest collection of PyTorch image encoders / backbones. Including train, eval, inference, export scripts, and pretrained weights -- ResNet, ResNeXT, EfficientNet, NFNet, Vision Transformer (ViT), MobileNetV4, MobileNet-V3 & V2, RegNet, DPN, CSPNet, Swin Transformer, MaxViT, CoAtNet, ConvNeXt, and more
https://huggingface.co/docs/timm
Apache License 2.0

[FEATURE] Support Tied-Augment #1828

Open ekurtulus opened 1 year ago

ekurtulus commented 1 year ago

Recently, we introduced Tied-Augment, a simple framework that combines self-supervised and supervised learning by making forward passes on two augmented views of the data with tied (shared) weights. In addition to the classification loss, it adds a similarity term to enforce invariance between the features of the augmented views. We found that the framework improves the effectiveness of both simple flips-and-crops (Crop-Flip) and aggressive augmentations (RandAugment), even for few-epoch training. As the effect of data augmentation is amplified, the sample efficiency of the data increases.
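
In pseudocode, the objective is just a classification loss on both views plus a feature-similarity term. A minimal sketch (the negative-cosine-similarity choice and the similarity_weight knob are illustrative, not prescribed here):

    import torch.nn.functional as F

    def tied_augment_loss(feats1, logits1, feats2, logits2, target, similarity_weight=1.0):
        # Supervised term: standard cross-entropy on each augmented view.
        ce = F.cross_entropy(logits1, target) + F.cross_entropy(logits2, target)
        # Invariance term: pull the tied features of the two views together
        # (negative cosine similarity here; an L2 distance would also fit).
        sim = -F.cosine_similarity(feats1, feats2, dim=-1).mean()
        return ce + similarity_weight * sim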

I believe Tied-Augment would be a nice addition to the timm training script. It can significantly improve mixup/RandAugment (77.6% → 79.6%) with marginal extra cost. Here is my reference implementation.

pdedeler commented 1 year ago

👍🏻 It would be great if you could implement Tied-Augment.

rwightman commented 1 year ago

@ekurtulus that sounds interesting, can it be implemented like AugMix + JSD loss, where most of the detail w.r.t. the data splits, etc. is in the dataset wrapper and the loss?

ekurtulus commented 1 year ago

> @ekurtulus that sounds interesting, can it be implemented like AugMix + JSD loss, where most of the detail w.r.t. the data splits, etc. is in the dataset wrapper and the loss?

@rwightman Yes. The only difference is that Tied-Augment also requires the features of the augmented views, so an additional wrapper is needed around the model as well.
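
A two-view dataset wrapper in the spirit of timm's AugMixDataset could look like this (a sketch; the class and argument names are illustrative, not timm's actual API):

    from torch.utils.data import Dataset

    class TwoViewDataset(Dataset):
        # Wraps a dataset so each item yields two augmented views of the
        # same image; the default collate fn batches them into two tensors.
        def __init__(self, dataset, transform):
            self.dataset = dataset
            self.transform = transform

        def __len__(self):
            return len(self.dataset)

        def __getitem__(self, index):
            img, target = self.dataset[index]
            # Two independent draws from the same augmentation pipeline.
            return (self.transform(img), self.transform(img)), target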

Example (for a timm model created with num_classes=0):

import torch.nn as nn

class TimmWrapper(nn.Module):
    # Exposes both features and logits for a timm backbone
    # created with num_classes=0 (i.e. classifier head removed).
    def __init__(self, model, num_classes):
        super().__init__()
        self.model = model  # timm backbone returning pooled features
        self.fc = nn.Linear(model.num_features, num_classes)

    def forward(self, x, return_features=False):
        if self.training or return_features:
            # Tied-Augment needs the features as well as the logits.
            features = self.model(x)
            logits = self.fc(features)
            return features, logits
        else:
            return self.fc(self.model(x))
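
Putting the pieces together, a training step would look roughly like this (again a sketch: tied_augment_loss and TwoViewDataset are the illustrative helpers above, and the loader and hyperparameters are placeholders):

    import timm
    import torch

    model = TimmWrapper(timm.create_model('resnet50', num_classes=0), num_classes=1000)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    model.train()
    for (view1, view2), target in loader:  # loader built on TwoViewDataset above
        feats1, logits1 = model(view1)  # wrapper returns features + logits in train mode
        feats2, logits2 = model(view2)
        loss = tied_augment_loss(feats1, logits1, feats2, logits2, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()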