S-aiueo32 / contextual_loss_pytorch

Contextual Loss (CX) and Contextual Bilateral Loss (CoBi).
MIT License
172 stars 45 forks

Bug in functions compute_l1_distance and compute_l2_distance #6

Open DmitryBabichev opened 4 years ago

DmitryBabichev commented 4 years ago

dist = dist.sum(dim=1).abs() in line 162 in contextual_loss/functional.py is not an L1 distance; correctly it is dist = dist.abs().sum(dim=1), taking the absolute values before summing over channels.

Also, in line 162 in contextual_loss/functional.py you should transpose the matrix A: dist = y_s - 2 * A.transpose(1, 2) + x_s.transpose(0, 1)
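The difference between the two orderings is easy to see on a toy tensor (a minimal sketch; `diff` stands in for the per-channel differences computed in the loss):

```python
import torch

# per-channel differences with mixed signs: shape (N=1, C=2, K=1)
diff = torch.tensor([[[3.0], [-5.0]]])

wrong = diff.sum(dim=1).abs()  # |3 + (-5)| = 2: signs cancel before abs
right = diff.abs().sum(dim=1)  # |3| + |-5| = 8: the actual L1 distance
print(wrong.item(), right.item())  # 2.0 8.0
```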

Kai-46 commented 4 years ago

I believe there is also a bug in the compute_l2_distance function. One correct implementation can be:

def compute_l2_distance(x, y):
    N, C, H, W = x.size()
    x_vec = x.view(N, C, -1)
    y_vec = y.view(N, C, -1)
    x_s = torch.sum(x_vec ** 2, dim=1, keepdim=True)
    y_s = torch.sum(y_vec ** 2, dim=1, keepdim=True)

    A = y_vec.transpose(1, 2) @ x_vec
    # print(x.shape, y_s.shape, A.shape, x_s.shape)
    dist = y_s - 2 * A + x_s.transpose(1, 2)
    dist = dist.transpose(1, 2).reshape(N, H*W, H*W)
    dist = dist.clamp(min=0.)

    return dist

Feel free to point out any potential bugs!
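The snippet above can be sanity-checked against torch.cdist; as later comments in this thread note, the missing transpose on A means the result does not match the true pairwise squared distances (a minimal check with made-up shapes):

```python
import torch

def compute_l2_distance(x, y):  # transcription of the snippet posted above
    N, C, H, W = x.size()
    x_vec = x.view(N, C, -1)
    y_vec = y.view(N, C, -1)
    x_s = torch.sum(x_vec ** 2, dim=1, keepdim=True)
    y_s = torch.sum(y_vec ** 2, dim=1, keepdim=True)
    A = y_vec.transpose(1, 2) @ x_vec
    dist = y_s - 2 * A + x_s.transpose(1, 2)
    dist = dist.transpose(1, 2).reshape(N, H * W, H * W)
    return dist.clamp(min=0.)

torch.manual_seed(0)
x, y = torch.randn(1, 3, 4, 4), torch.randn(1, 3, 4, 4)
d = compute_l2_distance(x, y)
# reference pairwise squared distances: ref[n, p, q] = ||y_p - x_q||^2
ref = torch.cdist(y.view(1, 3, -1).transpose(1, 2),
                  x.view(1, 3, -1).transpose(1, 2)) ** 2
print(torch.allclose(d, ref, atol=1e-4))  # False: indices are mixed up
```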

zhaowei11594 commented 4 years ago

(quoting Kai-46's implementation above)

Good!!! You are right. Applying the change suggested in https://github.com/S-aiueo32/contextual_loss_pytorch/issues/6, dist = y_s - 2 * A + x_s.transpose(0, 1), raises

RuntimeError: The size of tensor a (4096) must match the size of tensor b (2) at non-singleton dimension 1

If I do not correct this, it goes wrong whenever batch_size > 1.
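The RuntimeError above is a broadcasting failure, reproducible in isolation (a toy sketch with made-up shapes; in the original report H*W was 4096 and the batch size 2):

```python
import torch

N, HW = 2, 16                 # batch of 2, 4x4 feature map flattened
x_s = torch.randn(N, 1, HW)   # per-position squared norms
dist = torch.randn(N, HW, HW)

try:
    _ = dist + x_s.transpose(0, 1)  # (2,16,16) + (1,2,16): dim 1 is 16 vs 2
    msg = ""
except RuntimeError as e:
    msg = str(e)
print(msg)                          # reports the dim-1 size mismatch

ok = dist + x_s.transpose(1, 2)     # (2,16,16) + (2,16,1): broadcasts fine
print(ok.shape)                     # torch.Size([2, 16, 16])
```

transpose(0, 1) moves the batch dimension, so the failure only shows up once batch_size > 1; transpose(1, 2) swaps the two position dimensions, which is what the formula needs.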

Turlan commented 7 months ago

(quoting the compute_l2_distance implementation above)

As mentioned by @DmitryBabichev, A should be transposed. Namely:

dist = y_s - 2 * A.transpose(1, 2) + x_s.transpose(1, 2)
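Putting the corrections from this thread together, a sketch of the fixed function, verified against torch.cdist (variable names follow the snippets above):

```python
import torch

def compute_l2_distance(x, y):
    # pairwise squared L2 distances between all spatial positions of y and x
    N, C, H, W = x.size()
    x_vec = x.view(N, C, -1)                          # (N, C, H*W)
    y_vec = y.view(N, C, -1)
    x_s = torch.sum(x_vec ** 2, dim=1, keepdim=True)  # (N, 1, H*W)
    y_s = torch.sum(y_vec ** 2, dim=1, keepdim=True)
    A = y_vec.transpose(1, 2) @ x_vec                 # A[n, j, i] = <y_j, x_i>
    # both A and x_s transposed over dims (1, 2), per the comments above
    dist = y_s - 2 * A.transpose(1, 2) + x_s.transpose(1, 2)
    dist = dist.transpose(1, 2).reshape(N, H * W, H * W)
    return dist.clamp(min=0.)  # dist[n, p, q] = ||y_p - x_q||^2

torch.manual_seed(0)
x, y = torch.randn(2, 3, 4, 4), torch.randn(2, 3, 4, 4)
d = compute_l2_distance(x, y)
ref = torch.cdist(y.view(2, 3, -1).transpose(1, 2),
                  x.view(2, 3, -1).transpose(1, 2)) ** 2
print(torch.allclose(d, ref, atol=1e-4))  # True, including batch_size > 1
```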