Hello, the random hparam sweep provided the best results for us.
A few comments:
python main.py --download-models --model-location <where models will be stored>
@mitchellnw Thanks for your reply. Actually, we would like to add our proposed technique in the fine-tuning stage to further improve the model soup. Model Soups is excellent work and we cite the paper to support our research.
The minimal fine-tuning example already has learning rate, weight decay, and epochs. The things we need to add are label smoothing, mixup, and augmentation, is that right?
It would be very convenient if you could provide the fine-tuning script used in the CLIP ViT-B/32 experiment. If the fine-tuning script is unavailable, we may bother you again to check whether our reimplementation is consistent with the paper's.
The reimplementation code is a work in progress. We think we will post it in a few days.
Best!
Ah okay, yes that's correct.
For mixup we took the following from the original mixup repository:
import numpy as np
import torch

def mixup_data(x, y, beta=0.8, device='cpu'):
    '''Returns mixed inputs, pairs of targets, and lambda'''
    if beta > 0:
        lam = np.random.beta(beta, beta)
    else:
        lam = 1
    batch_size = x.size()[0]
    index = torch.randperm(batch_size).to(device)
    # Convex combination of each example with a randomly permuted example.
    mixed_x = lam * x + (1 - lam) * x[index, :]
    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, lam

def mixup_criterion(criterion, pred, y_a, y_b, lam):
    # The loss is the same convex combination applied to the two targets.
    return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)
Then instead of the standard forward pass we do:
inputs, targets_a, targets_b, lam = mixup_data(inputs, labels, beta=args.beta, device=args.gpu)
logits = model(inputs)
loss = mixup_criterion(self.loss_fn, logits, targets_a, targets_b, lam)
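For context, here is a minimal sketch of how that mixup forward pass could sit inside a full training step; model, optimizer, loss_fn, train_loader, args.gpu, and args.beta are assumed to be defined elsewhere, so treat this as an illustration rather than the repo's exact loop:
# Illustrative training step with mixup (assumes model, optimizer, loss_fn,
# train_loader, args.gpu, and args.beta are defined elsewhere).
for inputs, labels in train_loader:
    inputs, labels = inputs.to(args.gpu), labels.to(args.gpu)
    inputs, targets_a, targets_b, lam = mixup_data(
        inputs, labels, beta=args.beta, device=args.gpu)
    logits = model(inputs)
    loss = mixup_criterion(loss_fn, logits, targets_a, targets_b, lam)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()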
For label smoothing we do:
loss_fn = LabelSmoothingCrossEntropy(args.smoothing)
where LabelSmoothingCrossEntropy is defined as:
import torch.nn as nn
import torch.nn.functional as F

class LabelSmoothingCrossEntropy(nn.Module):
    def __init__(self, smoothing=0.1):
        """
        Constructor for the LabelSmoothing module.
        :param smoothing: label smoothing factor
        """
        super(LabelSmoothingCrossEntropy, self).__init__()
        assert smoothing < 1.0
        self.smoothing = smoothing
        self.confidence = 1. - smoothing

    def forward(self, x, target):
        logprobs = F.log_softmax(x, dim=-1)
        # Negative log-likelihood of the target class.
        nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1))
        nll_loss = nll_loss.squeeze(1)
        # Uniform (mean over classes) component for the smoothed part of the loss.
        smooth_loss = -logprobs.mean(dim=-1)
        loss = self.confidence * nll_loss + self.smoothing * smooth_loss
        return loss.mean()
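As a quick usage check, this loss is a drop-in replacement for nn.CrossEntropyLoss and composes with the mixup criterion above. The tensors below are made-up shapes purely for illustration:
import torch

loss_fn = LabelSmoothingCrossEntropy(smoothing=0.1)
logits = torch.randn(4, 10)            # 4 examples, 10 classes
targets = torch.randint(0, 10, (4,))

plain_loss = loss_fn(logits, targets)  # standard use
mixed_loss = mixup_criterion(loss_fn, logits, targets, targets.flip(0), lam=0.7)
print(plain_loss.item(), mixed_loss.item())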
For data augmentation, with probability 1/3 we use the augmentation currently in the repo. With probability 1/3 we use:
from timm.data.transforms_factory import transforms_imagenet_train
train_preprocess = transforms_imagenet_train(
img_size=224,
mean=(0.48145466, 0.4578275, 0.40821073),
std=(0.26862954, 0.26130258, 0.27577711)
)
With probability 1/3 we do:
from timm.data.transforms_factory import transforms_imagenet_train
train_preprocess = transforms_imagenet_train(
img_size=224,
auto_augment=RANDOM_AUG_STR,
mean=(0.48145466, 0.4578275, 0.40821073),
std=(0.26862954, 0.26130258, 0.27577711)
)
where RANDOM_AUG_STR is generated as described in the paper; you can find examples of everything we used here: https://github.com/mlfoundations/model-soups/blob/main/hparam_info.json
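Putting the three augmentation options together, one way the per-run choice could look is sketched below. Here repo_preprocess is a placeholder for whatever transform the repo currently uses, and aug stands for the sampled RandAugment string from the sampling code further down, so this is an illustration of the description rather than the exact implementation:
import numpy as np
from timm.data.transforms_factory import transforms_imagenet_train

CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073)
CLIP_STD = (0.26862954, 0.26130258, 0.27577711)

u = np.random.rand()
if u < 1 / 3:
    # Option 1: the augmentation already in the repo (placeholder name).
    train_preprocess = repo_preprocess
elif u < 2 / 3:
    # Option 2: timm's default ImageNet training transforms.
    train_preprocess = transforms_imagenet_train(
        img_size=224, mean=CLIP_MEAN, std=CLIP_STD)
else:
    # Option 3: timm's transforms with the sampled RandAugment policy.
    train_preprocess = transforms_imagenet_train(
        img_size=224, auto_augment=aug, mean=CLIP_MEAN, std=CLIP_STD)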
Thanks for your kind help!
Sorry, I still have a few remaining questions:
What about the other fixed training configs, like --batch-size and --warmup-length? I see --batch-size=256 and --warmup-length=500 in the fine-tuning script. Does --warmup-length=500 need to change with --epochs?
How many GPUs did you use to accelerate fine-tuning the 72 CLIP ViT-B/32 models?
Each run used 8 GPUs. We used warmup length 500 and batch size 512. The number of epochs was chosen randomly from the range (4, 6).
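For reference, the warmup length of 500 was kept fixed across runs even though the number of epochs varied. If it counts optimization steps, then with batch size 512 and roughly 1.28M ImageNet training images one epoch is about 2,500 steps, so 500 warmup steps cover only part of the first epoch. A generic sketch of a linear-warmup-then-cosine schedule (a common choice, not necessarily the repo's exact scheduler) is:
import math

def lr_at_step(step, base_lr, warmup_length, total_steps):
    # Linear ramp from ~0 up to base_lr over the first `warmup_length` steps.
    if step < warmup_length:
        return base_lr * (step + 1) / warmup_length
    # Cosine decay over the remaining steps.
    progress = (step - warmup_length) / (total_steps - warmup_length)
    return 0.5 * (1 + math.cos(math.pi * progress)) * base_lr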
Here is the code which goes along with our description in the paper:
import random

import numpy as np

# `info` collects the sampled hyperparameters for one run.
info = {}

# Augmentation: none with probability 1/3, otherwise a random RandAugment string.
if np.random.rand() < 0.333:
    aug = None
else:
    m_aug = random.randint(0, 20)
    n_aug = random.randint(0, 2)
    aug = f'rand-m{m_aug}-n{n_aug}'

# Weight decay and learning rate are sampled log-uniformly.
wd = 10 ** (-(0.2 + np.random.rand() * 4))
lr = 10 ** (-(4 + np.random.rand() * 2))

# Mixup: off with probability 1/2, otherwise uniform in [0, 0.9).
if np.random.rand() < 0.5:
    mix = 0
else:
    mix = np.random.rand() * 0.9

# Label smoothing: off with probability 1/2, otherwise uniform in [0, 0.25).
if np.random.rand() < 0.5:
    smoothing = 0
else:
    smoothing = np.random.rand() * 0.25

epochs = random.randint(4, 16)

info['aug'] = aug
info['wd'] = wd
info['lr'] = lr
info['mix'] = mix
info['smoothing'] = smoothing
info['epochs'] = epochs
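If it helps, here is a small sketch of how a sweep of such configurations could be collected and dumped to JSON, loosely mirroring the role of the hparam_info.json file linked above; sample_hparams is a hypothetical wrapper around the sampling code, and the exact schema of that file is not reproduced here:
import json

configs = []
for _ in range(72):
    info = sample_hparams()  # hypothetical helper wrapping the sampling code above
    configs.append(info)

with open('my_hparam_sweep.json', 'w') as f:
    json.dump(configs, f, indent=2)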
Now everything is clear to me. Thanks!! 👍 We will work on reproducing this with your advice. May we ask for permission to open-source our reproduced fine-tuning code? It will be released in easyrobust, and we will cite the paper.
Sounds good, thanks!
I am trying to reproduce the results of the hyperparameter sweep on CLIP ViT-B/32 using this script.
Could you help with the following: in J.2.1, a standard grid, an extreme grid, and random search are adopted for the CLIP ViT-B/32 sweep. Which search method performs best and was finally used for CLIP ViT-B/32?