ekurtulus opened this issue 1 year ago
👍🏻 It would be great if you could implement Tied-Augment.
@ekurtulus that sounds interesting, can it be implemented similar to augmix + jsd loss, where most of the detail w.r.t. the splits of the data, etc. is in the dataset wrapper and loss?
@rwightman Yes, the only difference is that Tied-Augment requires the features of the augmented views, so an additional wrapper is needed around the model as well.
Example (for a timm model created with num_classes=0):

```python
import torch.nn as nn

class TimmWrapper(nn.Module):
    def __init__(self, model, num_classes):
        super().__init__()
        self.model = model  # timm backbone created with num_classes=0, so it returns pooled features
        self.fc = nn.Linear(model.num_features, num_classes)

    def forward(self, x, return_features=False):
        if self.training or return_features:
            features = self.model(x)
            logits = self.fc(features)
            return features, logits
        return self.fc(self.model(x))
```
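To illustrate, here is a minimal sketch of how such a wrapper could be combined with a Tied-Augment style objective. The MSE similarity term, the `similarity_weight` factor, and the function name are illustrative assumptions based on the description below, not the reference implementation:

```python
import torch.nn.functional as F

def tied_augment_loss(model, view1, view2, targets, similarity_weight=1.0):
    # Forward both augmented views through the same (tied) weights.
    feats1, logits1 = model(view1, return_features=True)
    feats2, logits2 = model(view2, return_features=True)

    # Supervised classification loss on both views.
    ce = F.cross_entropy(logits1, targets) + F.cross_entropy(logits2, targets)

    # Similarity term enforcing invariance between the features of the two views
    # (MSE chosen here purely for illustration).
    sim = F.mse_loss(feats1, feats2)

    return ce + similarity_weight * sim
```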
Recently, we introduced Tied-Augment, a simple framework that combines self-supervised and supervised learning by making forward passes on two augmented views of the data with tied (shared) weights. In addition to the classification loss, it adds a similarity term to enforce invariance between the features of the augmented views. We found that the framework improves the effectiveness of both simple flips-and-crops (Crop-Flip) and aggressive augmentations (RandAugment), even for few-epoch training. As the effect of data augmentation is amplified, the sample efficiency of the data increases.
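On the data side, a wrapper in the spirit of the existing augmix/jsd dataset wrapper could return two independently augmented views per sample; a rough sketch follows (the `TwoViewDataset` name and structure are just an illustration, not existing timm code):

```python
from torch.utils.data import Dataset

class TwoViewDataset(Dataset):
    """Wraps a dataset so each sample yields two augmented views plus the label."""

    def __init__(self, dataset, transform):
        self.dataset = dataset      # underlying dataset returning (untransformed image, label)
        self.transform = transform  # stochastic augmentation pipeline, e.g. crop/flip or RandAugment

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        img, target = self.dataset[index]
        # Applying the same stochastic pipeline twice yields two different views.
        return self.transform(img), self.transform(img), target
```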
I believe Tied-Augment would be a nice addition to the timm training script. It can significantly improve mixup/RandAugment (77.6% → 79.6%) with marginal extra cost. Here is my reference implementation.