JirongZhang / DeepHomography

Content-Aware Unsupervised Deep Homography Estimation
MIT License

Question about code: patch_indices and torch.gather #54

Open seedlingfl opened 5 months ago

seedlingfl commented 5 months ago

I see that your data generation outputs patch_indices (mesh indices), and that many places in the code call torch.gather with those indices. Could you please explain why this is needed? Could I use simple slicing such as tensor[:, :, y:y + patch_h, x:x + patch_w] instead?

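import torch

def getPatchFromFullimg(patch_size_h, patch_size_w, patchIndices, batch_indices_tensor, img_full):
    num_batch, num_channels, height, width = img_full.size()
    # Flatten the image batch and the patch indices to 1-D
    warped_images_flat = img_full.reshape(-1)
    patch_indices_flat = patchIndices.reshape(-1)

    # Add batch_indices_tensor to the patch indices to address the flattened batch
    pixel_indices = patch_indices_flat.long() + batch_indices_tensor
    # Look up one pixel value per flat index
    mask_patch = torch.gather(warped_images_flat, 0, pixel_indices)
    # The result is reshaped to a single-channel patch per sample
    mask_patch = mask_patch.reshape([num_batch, 1, patch_size_h, patch_size_w])

    return mask_patch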
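
To make the comparison concrete, here is a minimal sketch (not taken from the repository) that builds flat patch indices by hand and checks that torch.gather returns the same pixels as slicing. It assumes single-channel images, axis-aligned patches, and that batch_indices_tensor holds per-sample offsets of height * width into the flattened batch; make_patch_indices, corners, and batch_offsets are hypothetical names used only for illustration.

```python
import torch


def make_patch_indices(y, x, patch_h, patch_w, width):
    """Row-major flat indices of a patch inside a single H x W image plane."""
    rows = torch.arange(y, y + patch_h).unsqueeze(1)  # (patch_h, 1)
    cols = torch.arange(x, x + patch_w).unsqueeze(0)  # (1, patch_w)
    return rows * width + cols                        # (patch_h, patch_w)


num_batch, height, width = 2, 6, 8
patch_h, patch_w = 3, 4

# Single-channel images, so the flat layout is simply batch * H * W.
img_full = torch.arange(
    num_batch * height * width, dtype=torch.float32
).reshape(num_batch, 1, height, width)

# A different top-left corner (y, x) for each sample in the batch.
corners = [(1, 2), (3, 0)]
patch_indices = torch.stack(
    [make_patch_indices(y, x, patch_h, patch_w, width) for y, x in corners]
)  # (num_batch, patch_h, patch_w)

# Assumed role of batch_indices_tensor: offset of each sample in the flat batch.
batch_offsets = (torch.arange(num_batch) * height * width).reshape(num_batch, 1, 1)
pixel_indices = (patch_indices + batch_offsets).reshape(-1)

# Gather-based extraction, mirroring the quoted function.
gathered = torch.gather(img_full.reshape(-1), 0, pixel_indices)
gathered = gathered.reshape(num_batch, 1, patch_h, patch_w)

# Slicing-based extraction needs a Python loop when corners differ per sample.
sliced = torch.stack(
    [img_full[b, :, y:y + patch_h, x:x + patch_w] for b, (y, x) in enumerate(corners)]
)

patches_match = torch.equal(gathered, sliced)
print(patches_match)
```

Under these assumptions the two approaches pick the same pixels. The practical difference is that gather handles a different patch location for every sample in one vectorized call, while plain slicing with per-sample y and x needs a loop over the batch. Whether that is the authors' actual reason is not something this sketch can confirm.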