chokyungjin opened this issue 2 years ago
@chokyungjin I don't totally understand your question, but to clarify: `pred_pixel_values` and `masked_patches` are both in pixel space from the original image. They have just been im2col'd per patch.
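For intuition, here is a minimal sketch of that patchify step, matching the `Rearrange` pattern at the front of vit-pytorch's patch embedding. The patch size of 16 is an assumption inferred from the shapes discussed below (256 pixel values per patch on a single-channel image):

```python
import torch
from einops import rearrange

img = torch.randn(1, 1, 512, 512)   # (batch, channels, height, width)
p = 16                              # assumed patch size: 16 * 16 * 1 channel = 256 values

# Same pattern as vit-pytorch's patch embedding Rearrange
patches = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=p, p2=p)
print(patches.shape)                # torch.Size([1, 1024, 256])

# Masking 75% of the 1024 patches leaves 768, which matches the
# pred_pixel_values shape of (1, 768, 256) reported below.
```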
I tried to reshape `pred_pixel_values` back to the original image size, but the shapes don't match. The original image shape is (512, 512, 1), but `pred_pixel_values` has shape (1, 768, 256). Wasn't the whole original image given to the ViT as input?
It only contains the masked patches; getting back the whole image will take some more code. You'd need to unsort the masked and unmasked patches together, and then do your reshaping. Roughly, something like the sketch below.
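A toy sketch of the unsorting step, with illustrative shapes and variable names (it assumes you have access to the masked/unmasked index tensors and the unmasked patch pixels from the `forward` call):

```python
import torch

# Toy shapes: 4 patches total, 2 masked, 3 pixel values per patch
batch, num_patches, ppp = 1, 4, 3
masked_indices = torch.tensor([[1, 3]])
unmasked_indices = torch.tensor([[0, 2]])
unmasked_patches = torch.ones(batch, 2, ppp)           # stand-in for original pixels
pred_pixel_values = torch.full((batch, 2, ppp), 0.5)   # stand-in for predictions

# Scatter both sets of patches back into their original positions
full = torch.zeros(batch, num_patches, ppp)
batch_range = torch.arange(batch)[:, None]
full[batch_range, unmasked_indices] = unmasked_patches
full[batch_range, masked_indices] = pred_pixel_values

# `full` is now in original patch order and can be rearranged back to an image
```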
I'm sorry, but could I request some pseudo-code? Thanks.
I'm uncertain whether you're asking for code to compute the loss over the full image, or to reconstruct the full image for viewing. If the latter, this is the code I ended up using to reconstruct the input image from the patches. @lucidrains feel free to fold this in if you think it would be useful. The function ingests the raw patches and, depending on whether `masked_indices` and `pred_pixel_values` are passed, can return the raw image, the masked image, or the predicted image.
```python
import numpy as np
from einops.layers.torch import Rearrange


def reconstruct_image(self, patches, model_input, mean, std, masked_indices=None,
                      pred_pixel_values=None, patch_size=8):
    """
    Reconstructs the image from its patches. Can also reconstruct the masked image
    as well as the predicted image.

    To reconstruct the raw image from the patches, set masked_indices=None and
    pred_pixel_values=None. To reconstruct the masked image, pass the masked_indices
    tensor created in the `forward` call. To reconstruct the predicted image, pass
    both masked_indices and pred_pixel_values created in the `forward` call.

    ARGS:
        patches (torch.Tensor): The raw patches (pre patch embedding) generated for
            the given model input. Shape: (batch_size, num_patches, patch_size**2 * channels)
        model_input (torch.Tensor): The input images to the given model
            (batch_size, channels, height, width)
        mean (list[float]): Per-channel mean of the dataset, used to denormalize
            the input and predicted pixels. Shape: (1, channels)
        std (list[float]): Per-channel std of the dataset, used to denormalize
            the input and predicted pixels. Shape: (1, channels)
        masked_indices (torch.Tensor): The patch indices that are masked
            (batch_size, masking_ratio * num_patches)
        pred_pixel_values (torch.Tensor): The predicted pixel values for the masked
            patches (batch_size, masking_ratio * num_patches, patch_size**2 * channels)

    RETURN:
        reconstructed_image (np.ndarray): The reconstructed image(s)
            (batch_size, channels, height, width)
    """
    patches = patches.cpu()
    masked_indices_in = masked_indices is not None
    predicted_pixels_in = pred_pixel_values is not None
    if masked_indices_in:
        masked_indices = masked_indices.cpu()
    if predicted_pixels_in:
        pred_pixel_values = pred_pixel_values.cpu()

    patch_width = patch_height = patch_size
    reconstructed_image = patches.clone()

    # Overwrite the masked patches with the predicted pixels,
    # or zero them out to produce the masked view
    if masked_indices_in or predicted_pixels_in:
        for i in range(reconstructed_image.shape[0]):
            if masked_indices_in and predicted_pixels_in:
                reconstructed_image[i, masked_indices[i]] = pred_pixel_values[i].float()
            elif masked_indices_in:
                reconstructed_image[i, masked_indices[i]] = 0

    # Fold the patch sequence back into an image (inverse of the patchify Rearrange)
    invert_patch = Rearrange('b (h w) (p1 p2 c) -> b c (h p1) (w p2)',
                             h=model_input.shape[2] // patch_height,
                             w=model_input.shape[3] // patch_width,
                             c=model_input.shape[1],
                             p1=patch_height, p2=patch_width)
    reconstructed_image = invert_patch(reconstructed_image)

    # Denormalize with channels last so the per-channel stats broadcast correctly;
    # detach in case pred_pixel_values still carries gradients
    reconstructed_image = reconstructed_image.detach().numpy().transpose(0, 2, 3, 1)
    reconstructed_image *= np.array(std)
    reconstructed_image += np.array(mean)

    # Back to (batch_size, channels, height, width). Note: transposing with
    # (0, 3, 2, 1) here would swap height and width, which only looks right
    # for square images.
    return reconstructed_image.transpose(0, 3, 1, 2)
```
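For reference, a hedged sketch of how one might call this, assuming the function is added as a method of the MAE class and `forward` is modified to also return `patches`, `masked_indices`, and `pred_pixel_values` (the stock `forward` in mae.py returns only the loss):

```python
# Hypothetical usage; requires the forward() modification described above
loss, patches, masked_indices, pred_pixel_values = mae(images)

raw       = mae.reconstruct_image(patches, images, mean, std)
masked    = mae.reconstruct_image(patches, images, mean, std,
                                  masked_indices=masked_indices)
predicted = mae.reconstruct_image(patches, images, mean, std,
                                  masked_indices=masked_indices,
                                  pred_pixel_values=pred_pixel_values)
```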
```python
reconstructed_image = reconstructed_image.detach().numpy().transpose(0, 2, 3, 1)  # bgr
reconstructed_image *= std
reconstructed_image += mean
```

Can these lines of code be omitted?
If you omit those lines, then I believe `reconstructed_image` will be close to zero mean and unit standard deviation (since the network is trained on normalized output), which isn't great for visualization. However, if you didn't train your model on normalized data, then you should be fine to remove those lines.
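As a small hedged sketch of the visualization point: matplotlib treats float images outside [0, 1] as over-range, which can render as all white, so clip before plotting. Here `recon` is assumed to be the (batch, channels, height, width) array returned by the function above:

```python
import numpy as np
import matplotlib.pyplot as plt

img = recon[0].transpose(1, 2, 0)        # first image, channels-last for matplotlib
img = np.clip(img, 0.0, 1.0)             # values outside [0, 1] can render all white
plt.imshow(img.squeeze(), cmap='gray')   # squeeze in case of a single channel
plt.show()
```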
OK, thank you very much. When I add those lines, the predicted image and the masked image both come out all white. By the way, should `std` and `mean` have shape (1, 3)? Also, the predicted image from a model I trained for a while seems to have color only in the masked region.
@kcetskcaz did you add the function to the MAE class? If so, how did you call it? If not, how did you implement it? I'm a bit confused by the whole thing.
Thank you for your efforts, but I have a question about the MAE code.
https://github.com/lucidrains/vit-pytorch/blob/dc57c75478c98241fd232a64a7bb4c23c5861730/vit_pytorch/mae.py#L91
The MSE loss is calculated between patch vectors rather than the original image. What code should I add to check, via an output image, whether the reconstruction works well?
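For what it's worth, one hedged way to dump reconstructions to disk for visual inspection, assuming the `forward` modification and `reconstruct_image` method sketched earlier in this thread:

```python
import numpy as np
import matplotlib.pyplot as plt

# Assumes forward() also returns patches, masked_indices and pred_pixel_values
loss, patches, masked_indices, pred_pixel_values = mae(images)
recon = mae.reconstruct_image(patches, images, mean, std,
                              masked_indices=masked_indices,
                              pred_pixel_values=pred_pixel_values)

img = np.clip(recon[0].transpose(1, 2, 0), 0.0, 1.0)   # channels-last, clipped
plt.imsave('recon_check.png', img.squeeze(), cmap='gray')
```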