Parskatt / DKM

[CVPR 2023] DKM: Dense Kernelized Feature Matching for Geometry Estimation
https://parskatt.github.io/DKM/
Other

Question about Coordinate Embeddings implementations #58

Closed: xjiangan closed this issue 4 months ago.

xjiangan commented 4 months ago

Hi Johan,

First of all, I want to express my gratitude for your recent contributions to the field, particularly your papers DKM and RoMa. They have been a great source of inspiration and motivation for my own research endeavors.

However, while reviewing the accompanying code, I noticed some discrepancies in the coordinate embedding with random Fourier features, caused by PyTorch's default behaviour.

Issue Details

The paper states that the entries of the Fourier basis frequencies are sampled from a normal distribution. In the code, the projection is implemented as a Conv2d layer in PyTorch, but PyTorch's default initialization for convolution weights is Kaiming uniform, not normal.
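A quick way to confirm this is sketched below, assuming `pos_conv` is a plain 1x1 `nn.Conv2d(2, 256)` as in the histogram code below. PyTorch's default `Conv2d` initialization is `kaiming_uniform_` with `a=sqrt(5)`, which for a fan-in of 2 reduces to U(-1/sqrt(2), 1/sqrt(2)), i.e. a standard deviation of about 0.41. The `sigma` used for the normal re-initialization is an arbitrary placeholder, not a value from the paper.

```python
import math

import torch
from torch import nn

torch.manual_seed(0)

# 1x1 conv from 2D coordinates to 256 projections; same shape as pos_conv in the
# histogram code below (assumed to be representative of the real layer).
conv = nn.Conv2d(2, 256, kernel_size=1)

# Default init is kaiming_uniform_(a=sqrt(5)), i.e. U(-1/sqrt(fan_in), 1/sqrt(fan_in)).
fan_in = conv.in_channels * conv.kernel_size[0] * conv.kernel_size[1]  # 2 * 1 * 1 = 2
bound = 1.0 / math.sqrt(fan_in)
default_max = conv.weight.abs().max().item()
default_std = conv.weight.std().item()
print(f"bound={bound:.3f}, expected std={bound / math.sqrt(3):.3f}, empirical std={default_std:.3f}")

# Re-initialize to match the paper's description: entries ~ N(0, sigma^2).
sigma = 1.0  # hypothetical value, not taken from the paper or the released weights
nn.init.normal_(conv.weight, mean=0.0, std=sigma)
print(f"empirical std after normal_ init: {conv.weight.std().item():.3f}")
```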

In addition, the Fourier basis frequencies here are trainable parameters, whereas in a previous work they are sampled once and kept fixed.
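For comparison, the fixed-sampling approach can be sketched by drawing the weights once and excluding them from training. The helper `make_fixed_fourier_conv` and its `sigma` argument are hypothetical and not part of DKM or RoMa; the bias is dropped here purely for simplicity.

```python
import torch
from torch import nn


def make_fixed_fourier_conv(in_dim: int = 2, num_freqs: int = 256, sigma: float = 1.0) -> nn.Conv2d:
    """Hypothetical helper: 1x1 conv whose frequencies are sampled once from N(0, sigma^2) and frozen."""
    conv = nn.Conv2d(in_dim, num_freqs, kernel_size=1, bias=False)
    nn.init.normal_(conv.weight, mean=0.0, std=sigma)
    # Frozen: no gradient is computed, so the optimizer (and AdamW's weight decay) skips it.
    conv.weight.requires_grad_(False)
    return conv


torch.manual_seed(0)
pos_conv = make_fixed_fourier_conv()
trainable = [p for p in pos_conv.parameters() if p.requires_grad]
print(len(trainable))  # 0
```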

I've visualized the histogram of the pretrained pos_conv weights:

[Figure: Histogram of the weights pos_conv (Init, ROMA, DKM)]

The plot was generated with the following code:

import torch
from torch import nn

from dkm import DKMv3_outdoor
from roma import roma_outdoor

import matplotlib.pyplot as plt
import seaborn as sns
sns.set_style('darkgrid')

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
dkm_weight=DKMv3_outdoor(device=device).decoder.gps['16'].pos_conv.weight.detach().cpu().numpy()
roma_weight=roma_outdoor(device=device, coarse_res=560, upsample_res=(864, 1152)).decoder.gps['16'].pos_conv.weight.detach().cpu().numpy()
random_weight=nn.Conv2d(2, 256, kernel_size=(1, 1), stride=(1, 1)).weight.detach().cpu().numpy()

plt.figure(figsize=(8,5))
plt.hist(random_weight.flatten(), bins=25, alpha=0.5,density=True, label='Init',range=(-1.5,1.5))
plt.hist(roma_weight.flatten(), bins=25, alpha=0.5,density=True, label='ROMA',range=(-1.5,1.5))
plt.hist(dkm_weight.flatten(), bins=25, alpha=0.5,density=True, label='DKM',range=(-1.5,1.5))
plt.legend()
plt.title('Histogram of the weights pos_conv')

The trained weights look closer to a Gaussian, possibly because of the weight decay in the AdamW optimizer, which acts roughly like a Gaussian prior on the parameters.
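One hedged way to put a number on "more Gaussian" is the excess kurtosis of the flattened weights: it is -1.2 for a uniform distribution (the default init) and 0 for a Gaussian. The helper `excess_kurtosis` is illustrative; applying it to `dkm_weight` and `roma_weight` from the code above needs the pretrained checkpoints, so the check here uses synthetic samples.

```python
import numpy as np


def excess_kurtosis(w: np.ndarray) -> float:
    """Illustrative helper: sample excess kurtosis, E[(w - mu)^4] / sigma^4 - 3."""
    w = np.asarray(w, dtype=np.float64).ravel()
    centered = w - w.mean()
    var = np.mean(centered**2)
    return float(np.mean(centered**4) / var**2 - 3.0)


rng = np.random.default_rng(0)
uniform_like_init = rng.uniform(-1 / np.sqrt(2), 1 / np.sqrt(2), size=200_000)
gaussian = rng.normal(0.0, 0.4, size=200_000)
print(excess_kurtosis(uniform_like_init))  # close to -1.2
print(excess_kurtosis(gaussian))           # close to 0
```

With only 512 entries per pos_conv layer, the estimate is noisy (standard error around 0.2 for Gaussian data), so only large differences would be meaningful.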

Questions

  1. Would it make a difference to use a different distribution for initialization?
  2. Is it necessary to train the Fourier basis instead of fixing it?

Based on insights from a previous work, I suspect that the initialization distribution might not matter greatly and that only the standard deviation is crucial. However, I'm curious whether training could lead to better Fourier basis frequencies. I would greatly appreciate it if you could investigate these aspects further.

Thank you once again for your exceptional contributions to the field.

Parskatt commented 4 months ago

Hi, thanks for the interesting observation.

Also interesting to see that DKM and RoMa do not seem to converge to the same parameters. As you say, probably the std is the most important param.

Another way to visualize the embeddings is to scatter-plot the (D, 2) weight vectors. This should give some intuition about whether they are uniform on the circle, i.e. whether their directions are rotationally symmetric.

I don't think fixing the distribution is necessarily optimal. Sampling the frequencies from a normal distribution makes the correlation between embedded coordinates tend toward the RBF kernel as the number of features grows, but you could consider other distributions that lead to other kernels in the limit.
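A small Monte Carlo check of this statement, assuming the cosine-and-sine form of random Fourier features (not necessarily the exact form used in DKM or RoMa): with frequencies w ~ N(0, sigma^2 I), the normalized inner product of the embeddings of x and y approaches the RBF kernel exp(-sigma^2 ||x - y||^2 / 2). The function `rff_embed` is a hypothetical helper.

```python
import numpy as np


def rff_embed(x: np.ndarray, W: np.ndarray) -> np.ndarray:
    """Cos/sin random Fourier features, scaled so that <phi(x), phi(y)> = mean_i cos(w_i . (x - y))."""
    proj = x @ W.T  # (N, D)
    return np.concatenate([np.cos(proj), np.sin(proj)], axis=-1) / np.sqrt(W.shape[0])


rng = np.random.default_rng(0)
sigma, D = 1.0, 100_000
W = rng.normal(0.0, sigma, size=(D, 2))  # Gaussian frequencies

x = np.array([[0.1, -0.3]])
y = np.array([[0.4, 0.1]])  # ||x - y|| = 0.5
approx = (rff_embed(x, W) @ rff_embed(y, W).T).item()
exact = np.exp(-sigma**2 * np.sum((x - y) ** 2) / 2)
print(approx, exact)  # both close to 0.8825
```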

Regarding the questions:

  1. I think it's mainly the std, and even then the network can kind of adapt to this.
  2. Based on your results, possibly not, given that you know fairly accurately what std to use. I think it's more of a convenience thing, since the right value might depend on image resolution etc. It also allows us to analyze things, like you did :)

Parskatt commented 4 months ago

I think two priors you might always want are that the correlation is (A) rotation invariant and (B) position invariant. Beyond that, something more long-tailed than a Gaussian, or even non-monotonic, could have some uses.
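As one concrete illustration of a longer-tailed choice (an example, not something proposed in the thread): drawing the frequencies from an isotropic 2D Cauchy distribution (a multivariate t with one degree of freedom) gives, by Bochner's theorem, the exponential kernel exp(-s ||x - y||), which decays much more slowly than the RBF kernel. Because the distribution is spherically symmetric, priors (A) and (B) still hold when both cosine and sine are used. The scale `s` is a hypothetical parameter.

```python
import numpy as np

rng = np.random.default_rng(0)
s, D = 1.0, 200_000

# Isotropic 2D Cauchy frequencies: w = s * z / sqrt(u), z ~ N(0, I), u ~ chi^2(1)
z = rng.normal(size=(D, 2))
u = rng.chisquare(1, size=(D, 1))
W = s * z / np.sqrt(u)

d = np.array([0.3, -0.4])  # x - y, ||d|| = 0.5
approx = np.mean(np.cos(W @ d))  # equals <phi(x), phi(y)> for normalized cos/sin features
exact = np.exp(-s * np.linalg.norm(d))
print(approx, exact)  # both close to 0.6065
```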

xjiangan commented 4 months ago

Hi,

Thank you for your response and for sharing additional insights on the observation. Following your suggestion, I visualized the embeddings as scatterplots, and I also used kdeplot to better show overlapping points. Here are the visualizations:

[Figures: Scatter plot of the weights pos_conv; KDE plot of the weights pos_conv (panels: All, Init, ROMA, DKM)]

I used the following code to generate these plots:

roma_weight=roma_weight.reshape(-1,2)
dkm_weight=dkm_weight.reshape(-1,2)
random_weight=random_weight.reshape(-1,2)

xlim=(-1.5,1.5)
ylim=(-1.1,1.1)

plt.figure(figsize=(8,5))
plt.scatter(random_weight[:,0], random_weight[:,1], alpha=0.5, label='Init',s=4)
plt.scatter(roma_weight[:,0], roma_weight[:,1], alpha=0.5, label='ROMA',s=4)
plt.scatter(dkm_weight[:,0], dkm_weight[:,1], alpha=0.5, label='DKM',s=4)
plt.gca().set_aspect('equal', adjustable='box')
plt.legend()
plt.title('Scatter plot of the weights pos_conv')

bw=0.1
fig,ax=plt.subplots(2,2,figsize=(10,10),sharey=True,sharex=True)
sns.kdeplot(x=random_weight[:,0], y=random_weight[:,1], label='Init',bw_adjust=bw,levels=255,fill=True,ax=ax[0,0])
sns.kdeplot(x=roma_weight[:,0], y=roma_weight[:,1], label='ROMA',bw_adjust=bw,levels=255,fill=True,ax=ax[0,0])
sns.kdeplot(x=dkm_weight[:,0], y=dkm_weight[:,1], label='DKM',bw_adjust=bw,levels=255,fill=True,ax=ax[0,0])
ax[0,0].set_title('All')

sns.kdeplot(x=random_weight[:,0], y=random_weight[:,1], bw_adjust=bw,levels=255,fill=True,ax=ax[0,1])
ax[0,1].set_title('Init')
sns.kdeplot(x=roma_weight[:,0], y=roma_weight[:,1], bw_adjust=bw,levels=255,fill=True,ax=ax[1,0])
ax[1,0].set_title('ROMA')
sns.kdeplot(x=dkm_weight[:,0], y=dkm_weight[:,1], bw_adjust=bw,levels=255,fill=True,ax=ax[1,1])
ax[1,1].set_title('DKM')

for i in range(2):
    for j in range(2):
        ax[i,j].set_xlim(xlim)
        ax[i,j].set_ylim(ylim)
        ax[i,j].set_aspect('equal', adjustable='box')
# set the suptitle before tight_layout so the layout leaves room for it
fig.suptitle('KDE plot of the weights pos_conv')
plt.tight_layout()

The plots show that the weights are not uniform on the circle: axis-aligned patterns appear, most noticeably in the RoMa embeddings, which suggests that rotation invariance may not be well preserved.
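One simple way to quantify the axis-aligned pattern, sketched here as an assumption rather than an established test: take the angle theta of each 2D weight vector and average cos(4 theta). Directions uniform on the circle give a mean near 0, while mass concentrated along the x and y axes pushes it toward +1 (and along the diagonals toward -1). The function `axis_alignment` is hypothetical; in the code above it would be applied to the (D, 2) arrays `roma_weight` and `dkm_weight`.

```python
import numpy as np


def axis_alignment(weight: np.ndarray) -> float:
    """Hypothetical statistic: mean of cos(4 * theta) over the directions of the rows of a (D, 2) array."""
    w = np.asarray(weight, dtype=np.float64).reshape(-1, 2)
    theta = np.arctan2(w[:, 1], w[:, 0])
    return float(np.mean(np.cos(4 * theta)))


rng = np.random.default_rng(0)
isotropic = rng.normal(size=(100_000, 2))  # uniformly distributed directions
axis_aligned = np.array([[1.0, 0.0], [0.0, -2.0], [-0.5, 0.0], [0.0, 3.0]])
print(axis_alignment(isotropic))     # close to 0
print(axis_alignment(axis_aligned))  # 1.0 up to floating point
```

For 256 isotropic vectors the standard error of this mean is about sqrt(0.5 / 256), roughly 0.044, so values well above 0.1 would point to genuine axis alignment.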

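In addition, because only the cosine is used, the kernel is not guaranteed to be stationary (translation invariant). Stationarity follows from cos(a)cos(b) + sin(a)sin(b) = cos(a - b), which requires the sine term; with the cosine alone, cos(a)cos(b) = [cos(a - b) + cos(a + b)] / 2, and the cos(a + b) term depends on absolute position rather than only on the offset.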
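A tiny deterministic check of this argument, using two hand-picked frequencies instead of the learned weights (an illustrative assumption): shifting both points by the same offset t leaves the cosine-and-sine kernel unchanged but changes the cosine-only kernel. The helper `kernel` is hypothetical.

```python
import numpy as np


def kernel(x: np.ndarray, y: np.ndarray, W: np.ndarray, cos_only: bool) -> float:
    """Normalized inner product of Fourier embeddings of points x and y (each of shape (2,))."""
    px, py = W @ x, W @ y
    k = np.cos(px) * np.cos(py)
    if not cos_only:
        k = k + np.sin(px) * np.sin(py)  # cos(a)cos(b) + sin(a)sin(b) = cos(a - b)
    return float(np.mean(k))


W = np.array([[1.0, 0.0], [0.0, 2.0]])  # hand-picked frequencies (illustrative)
x = y = np.zeros(2)
t = np.array([np.pi / 2, np.pi / 4])    # common shift; W @ t = [pi/2, pi/2]

for cos_only in (True, False):
    before = kernel(x, y, W, cos_only)
    after = kernel(x + t, y + t, W, cos_only)
    print(f"cos_only={cos_only}: k(x, y)={before:.3f}, k(x+t, y+t)={after:.3f}")
# The cosine-only kernel drops from 1 to about 0; the cos+sin kernel stays at 1.
```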

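The violation of rotation and translation invariance by the cosine-only RoMa embedding is shown in the following visualization:

[Figure: cosine embedding only (panels: kernel to origin, kernel to shifted, kernel diff)]

The code is as follows: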

import numpy as np
def get_embedding(x,weight,cos_only=True):
    x=x.reshape(-1,2)
    weight=weight.reshape(-1,2)
    proj=8*np.pi* (x@weight.T)
    cos_emb=np.cos(proj)
    if cos_only:
        return cos_emb
    sin_emb=np.sin(proj)
    emb_all = np.concatenate([cos_emb,sin_emb],axis=1)
    return emb_all

x=y=np.linspace(-2,2,512)
X,Y=np.meshgrid(x,y)
pos=np.stack([X.flatten(),Y.flatten()],axis=1)

roma_embedding=get_embedding(pos,roma_weight)
roma_embedding=roma_embedding.reshape(512,512,-1)
origin_embedding = roma_embedding[256,256]
shift_embedding = roma_embedding[128,128]
fig,ax=plt.subplots(1,3,figsize=(15,5),sharey=True)

k_origin=np.sum(roma_embedding*origin_embedding,axis=-1)
min_k,max_k=np.amin(k_origin),np.amax(k_origin)
ax[0].imshow(k_origin[128:384,128:384],vmin=min_k,vmax=max_k)
ax[0].set_title('kernel to origin')

k_shift=np.sum(roma_embedding*shift_embedding,axis=-1)
ax[1].imshow(k_shift[0:256,0:256],vmin=min_k,vmax=max_k)
ax[1].set_title('kernel to shifted')

k_diff=k_origin[128:384,128:384]-k_shift[0:256,0:256]
ax[2].imshow(k_diff,vmin=min_k,vmax=max_k)
ax[2].set_title('kernel diff')

fig.suptitle("cosine embedding only")
plt.tight_layout()

In contrast, when both cosine and sine embeddings are used, as shown in the following visualization, translation invariance holds: the kernel then depends only on the offset between the two points, so the difference panel is essentially zero.

roma_embedding=get_embedding(pos,roma_weight,cos_only=False)
roma_embedding=roma_embedding.reshape(512,512,-1)
origin_embedding = roma_embedding[256,256]
shift_embedding = roma_embedding[128,128]
fig,ax=plt.subplots(1,3,figsize=(15,5),sharey=True)

k_origin=np.sum(roma_embedding*origin_embedding,axis=-1)
min_k,max_k=np.amin(k_origin),np.amax(k_origin)
ax[0].imshow(k_origin[128:384,128:384],vmin=min_k,vmax=max_k)
ax[0].set_title('kernel to origin')

k_shift=np.sum(roma_embedding*shift_embedding,axis=-1)
ax[1].imshow(k_shift[0:256,0:256],vmin=min_k,vmax=max_k)
ax[1].set_title('kernel to shifted')

k_diff=k_origin[128:384,128:384]-k_shift[0:256,0:256]
ax[2].imshow(k_diff,vmin=min_k,vmax=max_k)
ax[2].set_title('kernel diff')

fig.suptitle("cosine and sine embedding")
plt.tight_layout()

[Figure: cosine and sine embedding (panels: kernel to origin, kernel to shifted, kernel diff)]

Furthermore, enforcing rotation invariance by fixing an initial sampling and learning only the standard deviation could be another avenue to explore.
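A minimal PyTorch sketch of that idea, under stated assumptions: a standard-normal frequency matrix is sampled once and stored as a non-trainable buffer, and only a scalar log standard deviation is learned, so the effective frequencies are always a rescaled isotropic Gaussian sample. The class `ScaledFourierEmbedding`, the `init_sigma` argument, and the choice to output both cosine and sine are hypothetical, not code from DKM or RoMa. Note that rotation invariance then holds only in distribution; any finite sample still has some anisotropy.

```python
import math

import torch
from torch import nn


class ScaledFourierEmbedding(nn.Module):
    """Hypothetical module: fixed isotropic Gaussian directions, learnable global scale."""

    def __init__(self, in_dim: int = 2, num_freqs: int = 128, init_sigma: float = 1.0):
        super().__init__()
        # Sampled once; a buffer is saved in the state_dict but never updated by the optimizer.
        self.register_buffer("base_freqs", torch.randn(num_freqs, in_dim))
        # Only the (log) standard deviation is trainable; the log keeps sigma positive.
        self.log_sigma = nn.Parameter(torch.tensor(math.log(init_sigma)))

    def forward(self, coords: torch.Tensor) -> torch.Tensor:
        # coords: (B, in_dim, H, W) -> embedding: (B, 2 * num_freqs, H, W)
        freqs = self.log_sigma.exp() * self.base_freqs
        proj = torch.einsum("fc,bchw->bfhw", freqs, coords)
        return torch.cat([torch.cos(proj), torch.sin(proj)], dim=1)


torch.manual_seed(0)
emb = ScaledFourierEmbedding()
coords = torch.rand(1, 2, 8, 8) * 2 - 1
out = emb(coords)
out.sum().backward()
print(out.shape)  # torch.Size([1, 256, 8, 8])
print([name for name, _ in emb.named_parameters()])  # ['log_sigma']
```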

I find your proposal to enforce these invariances intriguing and believe it could potentially improve the performance of RoMa and DKM. Further experimentation in this direction would indeed be valuable.

Thank you once again for your collaboration and for sharing your insights. Your contributions greatly enrich the discussion and pave the way for exciting future research directions.

Parskatt commented 4 months ago

Wow, this is really cool! The axis-aligned pattern must be related to the regression-by-classification loss in RoMa. I bet this can cause issues when changing resolution at test time. Perhaps it also gives some minor performance gains, though? Not sure. Also, as you say, it may be sensitive to rotations.

Regarding using both cosine and sine, I think that makes sense, and I guess it's the most common combination. I thought it wouldn't make much difference in the end, but perhaps some nice properties are better preserved, as you demonstrate.

xjiangan commented 4 months ago

Hi, thanks for your insights! The regression-by-classification loss in RoMa does seem like a plausible explanation for the observed patterns. I appreciate our collaboration and the opportunity to explore these ideas together. It's been a fruitful discussion, and I believe we've covered some interesting ground. Closing the issue now. If you have any more thoughts, feel free to reach out.

Best regards.