Yangsenqiao / vida

[ICLR 2024] ViDA: Homeostatic Visual Domain Adapter for Continual Test Time Adaptation
MIT License

About Gradient Problem and Resnet Model #11

Open steven12138 opened 6 months ago

steven12138 commented 6 months ago

According to your paper, only the adapters are updated. However, although your ImageNet code explicitly enables gradients for the adapter parameters, the model is already initialized with all of its parameters requiring gradients:

def inject_trainable_vida(...):
    # Note: `model` arrives here with all parameters already requiring gradients.
    for _module in model.modules():
        if _module.__class__.__name__ in target_replace_module:
            for name, _child_module in _module.named_modules():
                if _child_module.__class__.__name__ == "Linear":
                    # ... inject the adapter
                    _module._modules[name].vida_up.weight.requires_grad = True
                    _module._modules[name].vida_down.weight.requires_grad = True

                    require_grad_params.extend(
                        list(_module._modules[name].vida_up2.parameters())
                    )
                    require_grad_params.extend(
                        list(_module._modules[name].vida_down2.parameters())
                    )
                    _module._modules[name].vida_up2.weight.requires_grad = True
                    _module._modules[name].vida_down2.weight.requires_grad = True
                    names.append(name)

    # Prints every parameter of the model (backbone + ViDA adapters).
    print([name for name, param in model.named_parameters() if param.requires_grad])

This means that, when the code actually runs, all parts of the model (backbone and adapters) are updated, not just the adapters.
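The behaviour described above follows from a PyTorch default: every parameter of a freshly constructed `nn.Module` has `requires_grad=True`, so setting `requires_grad = True` on the adapter weights changes nothing unless the backbone is frozen first. Below is a minimal sketch of the freeze-then-unfreeze pattern. `ToyAdapterLinear` and `freeze_all_but_adapters` are hypothetical names for illustration, and matching adapters by the `vida_` prefix is an assumption based on the attribute names in the snippet above, not the repository's actual code.

```python
import torch
import torch.nn as nn


class ToyAdapterLinear(nn.Module):
    """Hypothetical stand-in for a Linear layer wrapped with two ViDA-style
    adapter branches (vida_down/vida_up and vida_down2/vida_up2)."""

    def __init__(self, dim: int, low_rank: int = 1, high_rank: int = 16):
        super().__init__()
        self.linear = nn.Linear(dim, dim)  # backbone weight, meant to stay frozen
        self.vida_down = nn.Linear(dim, low_rank, bias=False)
        self.vida_up = nn.Linear(low_rank, dim, bias=False)
        self.vida_down2 = nn.Linear(dim, high_rank, bias=False)
        self.vida_up2 = nn.Linear(high_rank, dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return (self.linear(x)
                + self.vida_up(self.vida_down(x))
                + self.vida_up2(self.vida_down2(x)))


def freeze_all_but_adapters(model: nn.Module, prefix: str = "vida_") -> list:
    """Freeze every parameter, then re-enable only adapter parameters.

    Assumption: adapter submodules are the ones whose attribute name starts
    with `prefix`; adjust the match if the real code uses other names.
    """
    model.requires_grad_(False)  # step 1: undo PyTorch's requires_grad=True default
    adapter_params = []
    for name, param in model.named_parameters():
        # e.g. "0.vida_up2.weight" -> ["0", "vida_up2", "weight"]
        if any(part.startswith(prefix) for part in name.split(".")):
            param.requires_grad_(True)  # step 2: train adapters only
            adapter_params.append(param)
    return adapter_params


model = nn.Sequential(ToyAdapterLinear(8), nn.ReLU(), ToyAdapterLinear(8))
print(all(p.requires_grad for p in model.parameters()))  # True: PyTorch default

adapter_params = freeze_all_but_adapters(model)
print(len(adapter_params))  # 8: four adapter weights per wrapped layer
# Pass only the adapter parameters to the optimizer (lr mirrors Adapter_LR in this issue).
optimizer = torch.optim.Adam(adapter_params, lr=2e-7)
```

Note that `requires_grad` only controls whether gradients are computed. Which weights actually change also depends on which parameters are handed to the optimizer (here only the adapter list), so it is worth checking both when auditing the training loop.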

When I disabled gradients for the backbone, I could not reach the expected performance. The results below were obtained with Adapter_LR = 2e-7 and EMA_MT = 0.8; the mean error is almost 1 percentage point higher than the result reported in the paper.

| Metric | Gaussian | Shot | Impulse | Defocus | Glass | Motion | Zoom | Snow | Frost | Fog | Brightness | Contrast | ElasticTransform | Pixelate | JPEG | Mean |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| Error (%) | 48.68 | 42.72 | 42.80 | 52.56 | 59.10 | 44.78 | 49.74 | 39.22 | 42.36 | 40.28 | 24.34 | 58.50 | 50.64 | 33.64 | 32.96 | 44.15 |
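For anyone reproducing the setting above, EMA_MT presumably sets the momentum of the mean-teacher (EMA) update. The sketch below assumes the CoTTA-style convention `teacher = alpha * teacher + (1 - alpha) * student` with `alpha = EMA_MT`; this convention, and the helper name `ema_update`, are assumptions for illustration and have not been checked against the ViDA code. Under this convention, EMA_MT = 0.8 lets the teacher move 20% of the way toward the student at every step.

```python
import copy

import torch
import torch.nn as nn


@torch.no_grad()
def ema_update(teacher: nn.Module, student: nn.Module, alpha: float) -> None:
    """Hypothetical mean-teacher update: teacher = alpha * teacher + (1 - alpha) * student.

    Assumption: `alpha` plays the role of EMA_MT (teacher momentum).
    """
    for t_param, s_param in zip(teacher.parameters(), student.parameters()):
        t_param.mul_(alpha).add_(s_param, alpha=1.0 - alpha)


student = nn.Linear(4, 4)
teacher = copy.deepcopy(student)
teacher.requires_grad_(False)  # the teacher is never trained by backprop

with torch.no_grad():
    student.weight.add_(1.0)  # pretend one adaptation step shifted every weight by +1

ema_update(teacher, student, alpha=0.8)  # alpha = EMA_MT from this issue
# Each teacher weight moved 20% of the way toward the student (+0.2).
print(torch.allclose(teacher.weight, student.weight - 0.8))  # True
```

One consequence worth keeping in mind: if the backbone is frozen in both student and teacher, averaging identical backbone weights leaves them unchanged, so with a frozen backbone only the adapter weights are effectively smoothed by EMA_MT.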

Do you have any idea why this happens? Any insight would be a great help to us.

Also, could you please provide the code for the ResNet part and the warm-up model, or just the warm-up model? That would greatly help our research.

hhhyyeee commented 3 months ago

Any updates? 👀