Open GoodMan0 opened 2 years ago
Thank you for your interest. The perturbed parameters are sampled from a single Gaussian distribution; the code snippet is as follows:
import torch
from tqdm import tqdm

origin_loss = self.collect_loss(self.model, self.loaders[mode])  # loss of the unperturbed model
times = 10  # number of random perturbations per radius
for gamma in tqdm(list(range(2, 40, 2))):
    avg_loss = 0
    for _ in range(times):
        # Gaussian direction, rescaled so that its norm equals gamma
        delta = torch.randn_like(parameter)
        delta = delta / delta.norm() * gamma
        new_param = parameter + delta
        self.put_new_param(new_param, self.model)
        loss = self.collect_loss(self.model, self.loaders[mode])
        avg_loss += loss
    # average loss increase at this radius
    flatness = avg_loss / times - origin_loss
    print('{} : {}'.format(gamma, flatness))
Thanks for your reply. I implemented the sharpness statistics based on the code snippet and settings you provided. However, as gamma increases, the loss grows uncontrollably and becomes much larger than the original loss, especially for the ERM model. Do you add the perturbation to all of the model's parameters, or is there a detail I missed, such as scaling?
The perturbation is normalized to unit norm and then multiplied by gamma, so the loss should not grow uncontrollably. I have uploaded the visualization script to utils/vis_weight_surface.py; you can check it.
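The normalization described above can be sketched on a toy problem. This is only a minimal illustration, not the repository's code: `toy_loss`, `perturb`, and `flatness` are hypothetical names, a quadratic stands in for the model loss, and only the standard library is used so that it runs without PyTorch.

```python
import math
import random


def toy_loss(params):
    # Hypothetical stand-in for collect_loss: a simple quadratic bowl.
    return sum(p * p for p in params)


def perturb(params, gamma, rng):
    # Sample a Gaussian direction, rescale it to norm gamma, add it to params.
    delta = [rng.gauss(0.0, 1.0) for _ in params]
    norm = math.sqrt(sum(d * d for d in delta))
    return [p + d / norm * gamma for p, d in zip(params, delta)]


def flatness(params, gamma, times=10, seed=0):
    # Average loss increase over `times` random perturbations of norm gamma.
    rng = random.Random(seed)
    origin_loss = toy_loss(params)
    total = 0.0
    for _ in range(times):
        total += toy_loss(perturb(params, gamma, rng))
    return total / times - origin_loss


if __name__ == "__main__":
    params = [0.0] * 100  # minimum of the toy loss
    for gamma in range(2, 40, 2):
        print("{} : {:.4f}".format(gamma, flatness(params, gamma)))
```

At the minimum of this toy loss, the measured flatness equals gamma squared, because every perturbation has norm exactly gamma regardless of how many parameters there are.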
Got it. Thanks a lot!
Hello, I ran the code for multi-domain generalization. The result on PACS (86.51%) is close to the one reported in your paper (86.56%), but the performance on VLCS is lower than reported. I ran both experiments with the same settings as in your paper, so I wonder whether I am missing something important. The results are shown below:
**MVRML:**
CALTECH : 97.93+-0.27
LABELME : 62.76+-0.39
PASCAL : 71.79+-0.64
SUN : 71.67+-0.83
Mean : 76.04+-0.44
**MVRML+MVP:**
CALTECH : 98.21+-0.41
LABELME : 63.73+-0.34
PASCAL : 72.99+-0.15
SUN : 71.98+-0.56
Mean : 76.73+-0.30
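As a quick sanity check on the numbers above, the reported Mean appears to be the unweighted average of the four per-domain accuracies. The sketch below recomputes it from the values posted in this thread; the dictionary layout and the `domain_mean` helper are illustrative assumptions, not part of the repository.

```python
def domain_mean(accs):
    # Unweighted average over the target domains.
    return sum(accs.values()) / len(accs)


# Per-domain accuracies copied from the post above.
mvrml = {"CALTECH": 97.93, "LABELME": 62.76, "PASCAL": 71.79, "SUN": 71.67}
mvrml_mvp = {"CALTECH": 98.21, "LABELME": 63.73, "PASCAL": 72.99, "SUN": 71.98}

print("MVRML     : {:.4f}".format(domain_mean(mvrml)))
print("MVRML+MVP : {:.4f}".format(domain_mean(mvrml_mvp)))
```

The two averages are 76.0375 and 76.7275, which agree with the posted Mean values (76.04 and 76.73) to within rounding.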
Hi, thanks for your excellent work! I would like to know the details of the perturbed parameters. The paper says they are sampled from multiple different Gaussian distributions, and I am confused by the procedure. Could you please provide more details, or the script used to get the curve data?