koncle / MVDG


How to draw the Local sharpness curve as shown in Figure 4? #1

Open · GoodMan0 opened this issue 2 years ago

GoodMan0 commented 2 years ago

Hi, thanks for your excellent work! I would like to know the details of the perturbed parameters, which, as stated in the paper, are sampled from multiple different Gaussian distributions; I am confused about the procedure. Could you please provide more details, or the script used to obtain the curve data?

koncle commented 2 years ago

Thank you for the interest. The perturbed parameters are sampled from a single Gaussian distribution, and the code snippet is as follows:

import torch
from tqdm import tqdm

# Loss of the unperturbed model on the selected loader.
origin_loss = self.collect_loss(self.model, self.loaders[mode])
times = 10  # number of random perturbations averaged per radius
for gamma in tqdm(list(range(2, 40, 2))):  # perturbation radius
    avg_loss = 0
    for _ in range(times):
        # Sample a random direction and rescale it so its norm equals gamma.
        delta = torch.randn_like(parameter)
        delta = delta / delta.norm() * gamma
        new_param = parameter + delta
        # Load the perturbed parameters into the model and evaluate its loss.
        self.put_new_param(new_param, self.model)
        loss = self.collect_loss(self.model, self.loaders[mode])
        avg_loss += loss
    # Local sharpness: average perturbed loss minus the original loss.
    flatness = avg_loss / times - origin_loss
    print('{} : {} (origin loss: {})'.format(gamma, flatness, origin_loss))
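
For reference, here is a minimal, self-contained sketch of how the pieces the snippet leaves undefined (the flat `parameter` vector, `collect_loss`, and `put_new_param`) could be realized with standard PyTorch utilities. The helper names, the cross-entropy loss, and the single-loader averaging are assumptions for illustration, not the repository's actual implementation:

import torch
import torch.nn.functional as F
from torch.nn.utils import parameters_to_vector, vector_to_parameters

def get_flat_parameter(model):
    # Concatenate all model parameters into a single 1-D vector.
    return parameters_to_vector(model.parameters()).detach().clone()

def put_new_param(new_param, model):
    # Write a flat parameter vector back into the model's parameters in place.
    vector_to_parameters(new_param, model.parameters())

@torch.no_grad()
def collect_loss(model, loader, device='cuda'):
    # Average cross-entropy loss of the model over one data loader.
    model.eval()
    total, count = 0.0, 0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        total += F.cross_entropy(model(x), y, reduction='sum').item()
        count += y.size(0)
    return total / count
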
GoodMan0 commented 2 years ago

Thanks for your reply. I implemented the sharpness statistics based on the code snippet and settings you provided, but I ran into an issue: as gamma increases, the loss grows uncontrollably and becomes much larger than the original loss, especially for the ERM model. So I would like to know whether you add the perturbation to all of the model's parameters, or whether there are some details I missed, such as scaling.

koncle commented 2 years ago

The norm of the perturbation is normalized to 1 and then multiplied by gamma. The uncontrollable loss should not appear. I have uploaded the visualization file at utils/vis_weight_surface.py; you can check it.
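
In case it helps others reading this thread, below is a minimal plotting sketch (not the contents of utils/vis_weight_surface.py) of how the printed (gamma, flatness) pairs from the snippet above could be drawn as the local sharpness curve; the stored values here are placeholders for illustration only:

import matplotlib.pyplot as plt

# Assumes each iteration of the loop above also stored its result,
# e.g. curve.append((gamma, flatness)); the numbers below are placeholders.
curve = [(2, 0.01), (4, 0.03), (6, 0.07), (8, 0.12)]
gammas = [g for g, _ in curve]
flatness = [f for _, f in curve]

plt.plot(gammas, flatness, marker='o')
plt.xlabel('perturbation radius (gamma)')
plt.ylabel('average perturbed loss - original loss')
plt.title('local sharpness curve')
plt.savefig('local_sharpness_curve.png')
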

GoodMan0 commented 2 years ago

Got it. Thanks a lot!

GoodMan0 commented 2 years ago

Hello, I ran the code for multi-domain generalization. The results on PACS (86.51%) are essentially the same as reported in your paper (86.56%), but the performance on VLCS is lower than your reported numbers. I wonder if I am missing some important detail; I ran both experiments with the same settings as described in your paper. The results are shown below:

| Target domain | MVRML (%) | MVRML+MVP (%) |
| --- | --- | --- |
| CALTECH | 97.93 ± 0.27 | 98.21 ± 0.41 |
| LABELME | 62.76 ± 0.39 | 63.73 ± 0.34 |
| PASCAL | 71.79 ± 0.64 | 72.99 ± 0.15 |
| SUN | 71.67 ± 0.83 | 71.98 ± 0.56 |
| Mean | 76.04 ± 0.44 | 76.73 ± 0.30 |