daa233 / generative-inpainting-pytorch

A PyTorch reimplementation for paper Generative Image Inpainting with Contextual Attention (https://arxiv.org/abs/1801.07892)
MIT License
472 stars 97 forks

Mask processing method in ContextualAttention #63

Closed herbiezhao closed 2 years ago

herbiezhao commented 2 years ago

Here is the relevant code in `ContextualAttention`:

    # m shape: [N, C, k, k, L]
    m = m.view(int_ms[0], int_ms[1], self.ksize, self.ksize, -1)
    m = m.permute(0, 4, 1, 2, 3)    # m shape: [N, L, C, k, k]
    m = m[0]    # m shape: [L, C, k, k]
    # mm shape: [L, 1, 1, 1]
    mm = (reduce_mean(m, axis=[1, 2, 3], keepdim=True)==0.).to(torch.float32)
    mm = mm.permute(1, 0, 2, 3) # mm shape: [1, L, 1, 1]

The meaning of this code is: a batch uses the same mask. If a batch uses different masks, does the code need to be modified?
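For reference, a toy illustration (my own, not from the repo) of what `mm` computes: a patch is flagged as usable background only if every pixel of its mask patch is 0.

```python
import torch

# Toy tensor standing in for m after m = m[0]: L=3 mask patches,
# each of shape [C=1, k=2, k=2]. Values are 1 inside the hole, 0 outside.
m = torch.zeros(3, 1, 2, 2)
m[1] = 1.0           # patch 1 lies fully inside the hole
m[2, 0, 0, 0] = 1.0  # patch 2 partially overlaps the hole

# Same reduction as the snippet: mean over [C, k, k], then test for 0
mm = (m.mean(dim=[1, 2, 3], keepdim=True) == 0.).to(torch.float32)
mm = mm.permute(1, 0, 2, 3)  # mm shape: [1, L, 1, 1]

print(mm.view(-1))  # patch 0 -> 1.0 (usable); patches 1 and 2 -> 0.0
```

Any patch that overlaps the hole at all gets weight 0, so only fully-known background patches contribute to the attention.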

daa233 commented 2 years ago

> The meaning of this code is: a batch uses the same mask.

Yes, it is consistent with the description in the paper.

> If a batch uses different masks, does the code need to be modified?

Yes, the code would need to be modified so that each sample uses its own mask.