ActiveVisionLab / DFNet

DFNet: Enhance Absolute Pose Regression with Direct Feature Matching (ECCV 2022)
https://dfnet.active.vision
MIT License

Question about fixing feature extractor #25

Open ymtoo opened 5 days ago

ymtoo commented 5 days ago

Thanks for the great work!

In the paper: In stage two, fixing the histogram-assisted NeRF and the feature extractor G, ...

Referring to the code, model is responsible for estimating the pose, while feat_model generates the features. Both the pose and the features are involved in the loss computation. During backpropagation, the gradients from loss.backward() flow through both model and feat_model.

Could you point me to the part of the code where the feature extractor in feat_model is fixed?

chenusc11 commented 4 days ago

Hi, although the loss will be backpropagated through the feat_model, the parameters in the feat_model are fixed and will not be updated.

https://github.com/ActiveVisionLab/DFNet/blob/2c8fa7e324f8d17352ed469a8b793e0167e4c592/script/train.py#L123

ymtoo commented 9 hours ago

Thanks @chenusc11

Doesn't feat_model.eval() only change the behavior of some layers such as dropout and batch normalization? To disable gradient calculations, shouldn't requires_grad=False or torch.no_grad() be used instead?
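The three mechanisms named above behave differently under autograd. The following is a minimal sketch of that difference, not DFNet code: each `nn.Linear` layer is a hypothetical stand-in for a feature extractor, and `x` is a hypothetical stand-in for an upstream output that still needs gradients.

```python
import torch
from torch import nn

torch.manual_seed(0)

# 1) eval(): only switches layer behaviour (dropout, batch norm); autograd is untouched.
layer_eval = nn.Linear(5, 1)  # hypothetical stand-in for a feature extractor
layer_eval.eval()
x = torch.rand(10, 5, requires_grad=True)  # stand-in for an upstream output
layer_eval(x).sum().backward()
eval_weight_has_grad = layer_eval.weight.grad is not None
eval_input_has_grad = x.grad is not None

# 2) requires_grad_(False): parameters get no gradient, but gradients
#    still pass through the layer to its input.
layer_frozen = nn.Linear(5, 1)
layer_frozen.requires_grad_(False)
x = torch.rand(10, 5, requires_grad=True)
layer_frozen(x).sum().backward()
frozen_weight_has_grad = layer_frozen.weight.grad is not None
frozen_input_has_grad = x.grad is not None

# 3) torch.no_grad(): no graph is recorded, so nothing upstream can
#    receive a gradient through this output.
layer_nograd = nn.Linear(5, 1)
x = torch.rand(10, 5, requires_grad=True)
with torch.no_grad():
    out = layer_nograd(x)
no_grad_output_tracks_graph = out.requires_grad
```

If the loss has to reach anything upstream of the frozen module, `requires_grad_(False)` keeps that path open, whereas `torch.no_grad()` would cut it.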

Here is a minimal working example showing that eval() alone does not prevent a parameter update.

>>> import torch
>>> from torch import nn
>>> x = torch.rand(10,5)
>>> y = torch.rand(10,1)
>>> model = nn.Linear(5,1)
>>> model.eval()
Linear(in_features=5, out_features=1, bias=True)
>>> loss_function = lambda x, y : (model(x) - y).abs().sum()
>>> optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
>>> model.weight
Parameter containing:
tensor([[-0.2051, -0.0533,  0.1585, -0.2878, -0.1016]], requires_grad=True)
>>> loss = loss_function(x, y)
>>> loss.backward()
>>> optimizer.step()
>>> model.weight
Parameter containing:
tensor([[-0.1579, -0.0045,  0.2045, -0.2340, -0.0499]], requires_grad=True) # the weights changed

ymtoo commented 8 hours ago

It seems the optimizer only updates model.parameters(), not feat_model.parameters(), so the feature extractor's weights stay fixed even though gradients flow through it.
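This observation can be checked in isolation. The sketch below is not DFNet's training loop: the two `nn.Linear` layers are hypothetical stand-ins for model and feat_model, chained only so that the loss depends on both. It assumes, as described above, that the optimizer is given model.parameters() only.

```python
import torch
from torch import nn

torch.manual_seed(0)

model = nn.Linear(5, 3)       # hypothetical stand-in for the pose regressor
feat_model = nn.Linear(3, 1)  # hypothetical stand-in for the feature extractor
feat_model.eval()

# Only model's parameters are registered with the optimizer.
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

x = torch.rand(10, 5)
y = torch.rand(10, 1)
model_before = model.weight.detach().clone()
feat_before = feat_model.weight.detach().clone()

# The loss depends on both modules, so backward() fills gradients for both.
loss = (feat_model(model(x)) - y).abs().sum()
loss.backward()
optimizer.step()

feat_grad_computed = feat_model.weight.grad is not None
feat_unchanged = torch.equal(feat_model.weight.detach(), feat_before)
model_changed = not torch.equal(model.weight.detach(), model_before)
```

Gradients for feat_model are still computed and stored in this setup. Calling `feat_model.requires_grad_(False)` would skip that work without blocking the gradient path back to model.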