LeapLabTHU / ARC

ICCV 2023: Adaptive Rotated Convolution for Rotated Object Detection
Apache License 2.0

Some confusion about how the rotation matrix is generated #1

Open GluttonK opened 10 months ago

GluttonK commented 10 months ago

```python
import torch


def _get_rotation_matrix(thetas):
    bs, g = thetas.shape
    device = thetas.device
    thetas = thetas.reshape(-1)  # [bs, g] --> [bs * g]

    x = torch.cos(thetas)
    y = torch.sin(thetas)
    x = x.unsqueeze(0).unsqueeze(0)  # shape = [1, 1, bs * g]
    y = y.unsqueeze(0).unsqueeze(0)
    a = x - y
    b = x * y
    c = x + y

    # Used where thetas >= 0; each inner torch.cat builds one row of the 9 x 9 matrix
    rot_mat_positive = torch.cat((
        torch.cat((a, 1-a, torch.zeros(1, 7, bs*g, device=device)), dim=1),
        torch.cat((torch.zeros(1, 1, bs*g, device=device), x-b, b, torch.zeros(1, 1, bs*g, device=device), 1-c+b, y-b, torch.zeros(1, 3, bs*g, device=device)), dim=1),
        torch.cat((torch.zeros(1, 2, bs*g, device=device), a, torch.zeros(1, 2, bs*g, device=device), 1-a, torch.zeros(1, 3, bs*g, device=device)), dim=1),
        torch.cat((b, y-b, torch.zeros(1,1 , bs*g, device=device), x-b, 1-c+b, torch.zeros(1, 4, bs*g, device=device)), dim=1),
        torch.cat((torch.zeros(1, 4, bs*g, device=device), torch.ones(1, 1, bs*g, device=device), torch.zeros(1, 4, bs*g, device=device)), dim=1),
        torch.cat((torch.zeros(1, 4, bs*g, device=device), 1-c+b, x-b, torch.zeros(1, 1, bs*g, device=device), y-b, b), dim=1),
        torch.cat((torch.zeros(1, 3, bs*g, device=device), 1-a, torch.zeros(1, 2, bs*g, device=device), a, torch.zeros(1, 2, bs*g, device=device)), dim=1),
        torch.cat((torch.zeros(1, 3, bs*g, device=device), y-b, 1-c+b, torch.zeros(1, 1, bs*g, device=device), b, x-b, torch.zeros(1, 1, bs*g, device=device)), dim=1),
        torch.cat((torch.zeros(1, 7, bs*g, device=device), 1-a, a), dim=1)
    ), dim=0)  # shape = [k^2, k^2, bs*g]

    # Used where thetas < 0
    rot_mat_negative = torch.cat((
        torch.cat((c, torch.zeros(1, 2, bs*g, device=device), 1-c, torch.zeros(1, 5, bs*g, device=device)), dim=1),
        torch.cat((-b, x+b, torch.zeros(1, 1, bs*g, device=device), b-y, 1-a-b, torch.zeros(1, 4, bs*g, device=device)), dim=1),
        torch.cat((torch.zeros(1, 1, bs*g, device=device), 1-c, c, torch.zeros(1, 6, bs*g, device=device)), dim=1),
        torch.cat((torch.zeros(1, 3, bs*g, device=device), x+b, 1-a-b, torch.zeros(1, 1, bs*g, device=device), -b, b-y, torch.zeros(1, 1, bs*g, device=device)), dim=1),
        torch.cat((torch.zeros(1, 4, bs*g, device=device), torch.ones(1, 1, bs*g, device=device), torch.zeros(1, 4, bs*g, device=device)), dim=1),
        torch.cat((torch.zeros(1, 1, bs*g, device=device), b-y, -b, torch.zeros(1, 1, bs*g, device=device), 1-a-b, x+b, torch.zeros(1, 3, bs*g, device=device)), dim=1),
        torch.cat((torch.zeros(1, 6, bs*g, device=device), c, 1-c, torch.zeros(1, 1, bs*g, device=device)), dim=1),
        torch.cat((torch.zeros(1, 4, bs*g, device=device), 1-a-b, b-y, torch.zeros(1, 1, bs*g, device=device), x+b, -b), dim=1),
        torch.cat((torch.zeros(1, 5, bs*g, device=device), 1-c, torch.zeros(1, 2, bs*g, device=device), c), dim=1)
    ), dim=0)  # shape = [k^2, k^2, bs*g]

    mask = (thetas >= 0).unsqueeze(0).unsqueeze(0)
    mask = mask.float()                                                   # shape = [1, 1, bs*g]
    rot_mat = mask * rot_mat_positive + (1 - mask) * rot_mat_negative     # shape = [k*k, k*k, bs*g]
    rot_mat = rot_mat.permute(2, 0, 1)                                    # shape = [bs*g, k*k, k*k]
    rot_mat = rot_mat.reshape(bs, g, rot_mat.shape[1], rot_mat.shape[2])  # shape = [bs, g, k*k, k*k]
    return rot_mat
```

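In this function that generates the rotation matrix, I don't quite understand how the two matrices rot_mat_positive and rot_mat_negative are assembled. The code only works for rotating a standard 3x3 convolution kernel, and I would like to understand it so that I can try to extend it to larger kernels. Could you provide some theoretical background? Thank you very much.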

yifanpu001 commented 10 months ago

Can you schedule a remote meeting with me? I'll explain the meaning of this code to you.

GluttonK commented 10 months ago

Thanks for your answer. Here is my email and QQ: 2287512349@qq.com. You can send me an email or add me as a friend to schedule a remote meeting. I will set aside sufficient time and wait for your reply.

Anm-pinellia commented 10 months ago

I have the same confusion about the code; it seems to work only with a kernel size of 3. Could you give more explanation about it? Thanks very much.

yifanpu001 commented 10 months ago

@Anm-pinellia @GluttonK Hi, this can be regarded as a fast implementation of torch.nn.functional.affine_grid and torch.nn.functional.grid_sample. These two operations can be simplified when only rotating a 3x3 kernel.
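To make the "simplified grid_sample" remark concrete, here is a small check that is not from the authors. It assumes one reading of the snippet: the second row of rot_mat_positive (index 1, the top-middle tap) holds the bilinear interpolation weights for the point reached by rotating that tap by theta around the kernel center. The helper name `edge_tap_weights` is hypothetical, and plain `math` is used instead of torch to keep the sketch dependency-free.

```python
import math


def edge_tap_weights(theta):
    """Compare row 1 of rot_mat_positive with plain bilinear weights (theta >= 0)."""
    x, y = math.cos(theta), math.sin(theta)
    b, c = x * y, x + y

    # Non-zero entries of row 1 as written in the snippet (columns 1, 2, 4, 5):
    # top-middle, top-right, center, middle-right.
    from_snippet = (x - b, b, 1 - c + b, y - b)

    # Bilinear weights for a point that sits cos(theta) away from the center
    # towards the top row and sin(theta) towards the right column.
    bilinear = (x * (1 - y), x * y, (1 - x) * (1 - y), (1 - x) * y)
    return from_snippet, bilinear


snippet_row, bilinear_row = edge_tap_weights(0.3)
for s, w in zip(snippet_row, bilinear_row):
    print(f"{s:.6f}  {w:.6f}")
```

The two columns printed agree, and the four weights sum to 1, which is what one would expect from a bilinear sample. Under this reading, `a`, `b` and `c` are just shorthand for products of sines and cosines that appear when the bilinear weights are expanded.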

aleeyang commented 7 months ago

I'm also a bit confused about this code. After this convolution, is the output feature map already corrected for rotation?

wind-waves-fll commented 3 months ago

I don't understand rot_mat_positive and rot_mat_negative. Why can we get rot_mat this way? Maybe my math is bad.
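One way to see where the entries could come from is to build the same matrix generically. The sketch below is not the authors' code and has not been confirmed by them. It assumes each kernel tap is rotated about the kernel center, that the rotated position is clamped to the kernel border, and that it is then sampled bilinearly. The function name `rotation_matrix` and the use of plain Python lists are choices made for illustration only. Under these assumptions the k = 3 case reproduces rows of rot_mat_positive and rot_mat_negative for small angles (checked here at ±0.3 rad), and the same loop runs for larger odd k.

```python
import math


def rotation_matrix(theta, k=3):
    """Hypothetical k*k x k*k interpolation matrix for rotating a k x k kernel.

    Row i holds the bilinear weights used to sample tap i after rotating its
    offset from the kernel center by theta. Assumes odd k >= 3.
    """
    h = (k - 1) / 2
    cos_t, sin_t = math.cos(theta), math.sin(theta)
    mat = [[0.0] * (k * k) for _ in range(k * k)]
    for r in range(k):
        for col in range(k):
            dx, dy = col - h, r - h  # offset from the kernel center
            # Rotate the offset, clamp to the kernel border, shift to index space.
            px = min(max(dx * cos_t - dy * sin_t, -h), h) + h
            py = min(max(dx * sin_t + dy * cos_t, -h), h) + h
            x0 = min(int(math.floor(px)), k - 2)
            y0 = min(int(math.floor(py)), k - 2)
            fx, fy = px - x0, py - y0
            # Spread the weight over the 2 x 2 cell that contains the point.
            for yy, wy in ((y0, 1 - fy), (y0 + 1, fy)):
                for xx, wx in ((x0, 1 - fx), (x0 + 1, fx)):
                    mat[r * k + col][yy * k + xx] += wy * wx
    return mat


m = rotation_matrix(0.3)
x, y = math.cos(0.3), math.sin(0.3)
print(m[0][:2], (x - y, 1 - (x - y)))  # corner row versus (a, 1 - a)
```

The split into a positive and a negative matrix then has a natural interpretation: when theta changes sign, each rotated tap lands in a different 2 x 2 cell, so a different set of columns becomes non-zero. Whether border clamping is what the authors intend for kernels larger than 3x3 is an open question in this thread.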

yifanpu001 commented 3 months ago

Hi, the math derivation behind this code snippet is quite complicated, so I cannot make it clear in a few sentences. I will make some slides to show the derivation process when I'm free (maybe during the summer vacation).

wind-waves-fll commented 3 months ago

OK, thank you very much.

DenverLiao commented 3 months ago

For the mathematical formula here, I also have some confusion. Could you please provide a detailed derivation of the formula?
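Until the authors publish their derivation, the following is one reading that is consistent with the entries in the snippet, not an official explanation. It assumes bilinear sampling of the rotated tap positions with clamping at the kernel border, and it appears to hold for $|\theta| \le \pi/4$. Write $x = \cos\theta$, $y = \sin\theta$, $a = x - y$, $b = xy$, $c = x + y$, as in the code. A tap at offset $(d_x, d_y)$ from the kernel center, with $d_y = -1$ for the top row, is sampled at

$$
p = \bigl(d_x x - d_y y,\; d_x y + d_y x\bigr), \qquad \text{each coordinate clamped to } [-1, 1].
$$

For the top-middle tap $(0, -1)$ the sample point is $p = (y, -x)$. The bilinear weights, first for $\theta \ge 0$ and then for $\theta < 0$, are

$$
\begin{aligned}
\theta \ge 0:&\quad w_{\text{top-mid}} = x(1-y) = x - b, \quad w_{\text{top-right}} = xy = b, \\
&\quad w_{\text{center}} = (1-x)(1-y) = 1 - c + b, \quad w_{\text{mid-right}} = (1-x)\,y = y - b, \\
\theta < 0:&\quad w_{\text{top-left}} = x(-y) = -b, \quad w_{\text{top-mid}} = x(1+y) = x + b, \\
&\quad w_{\text{mid-left}} = (1-x)(-y) = b - y, \quad w_{\text{center}} = (1-x)(1+y) = 1 - a - b.
\end{aligned}
$$

For the top-left corner tap $(-1, -1)$ the sample point is $p = (-a, -c)$. When $0 \le \theta \le \pi/4$ we have $c \ge 1$, so the vertical coordinate clamps to $-1$ and only a horizontal interpolation remains: weight $a$ on the top-left tap and $1 - a$ on the top-middle tap. When $-\pi/4 \le \theta < 0$ we have $a > 1$, so the horizontal coordinate clamps instead: weight $c$ on the top-left tap and $1 - c$ on the middle-left tap. These match the second and first rows (indices 1 and 0) of rot_mat_positive and rot_mat_negative. The remaining rows would follow by symmetry, and the center tap always maps to itself.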

qing-yao commented 2 months ago

Hello, I don't quite understand the principles of obtaining the matrices rot_mat_positive and rot_mat_negative. Could you please explain? Thank you.

Nu1sance commented 2 months ago

Looking forward to further details from the authors.