I continued investigating. This pattern of the upper half not learning seems to apply across all bands.
Top row is input channels, middle is output and bottom is normalized difference.
https://github.com/Clay-foundation/model/assets/434029/d0d4684e-adfd-42d2-9591-fd18c0ad8d53
I ran one last test with another image set. Same pattern:
Top row is input channels, middle is output and bottom is normalized difference.
https://github.com/Clay-foundation/model/assets/434029/fb58836c-f7cd-4e2f-8ce2-4be5898f2b7c
Trying to understand this:
If helpful, here is the code that creates the images above.
It selects the image with the maximum variance from the first 3 batches (as a way to pick an image that has features rather than a flat one), then plots the RGB bands as RGB and each remaining band group as the average of its bands:
```python
from itertools import islice

import lightning as L
import matplotlib.pyplot as plt
import numpy as np
import torch
import wandb
from einops import rearrange

# get_wandb_logger is the repo's own helper for fetching the WandbLogger (not shown here).


class LogIntermediatePredictions(L.Callback):
    """Visualize the model results at the end of every epoch."""

    def __init__(self):
        """
        Instantiates with wandb-logger.
        """
        super().__init__()
        self.selected_image = None  # global index across validation batches

    def on_validation_end(
        self,
        trainer: L.Trainer,
        pl_module: L.LightningModule,
    ) -> None:
        """
        Called when the validation loop ends.
        At the end of each epoch, logs the model predictions for the selected
        validation image to wandb-logger, so humans can interpret how the
        model evolves over time.
        """
        with torch.no_grad():
            # Get WandB logger
            self.logger = get_wandb_logger(trainer=trainer)
            if self.selected_image is None:
                self.selected_image = self.select_image(trainer, pl_module)
            self.log_images(trainer, pl_module)

    def select_image(self, trainer, pl_module):
        """Return the global index of the highest-variance image in the first 3 batches."""
        print("Selecting image with max variance")
        batch_size = trainer.val_dataloaders.batch_size
        batches = islice(iter(trainer.val_dataloaders), 3)
        max_variance = -1.0
        selected = None
        for batch_idx, ibatch in enumerate(batches):
            # Shape: [batch_size, channels, height, width]
            images = ibatch["pixels"].to(pl_module.device)
            # Variance across C, H, W for each image in the batch
            variances = images.var(dim=[1, 2, 3])
            max_var_index = torch.argmax(variances).item()
            if variances[max_var_index].item() > max_variance:
                max_variance = variances[max_var_index].item()
                # Store a global index so log_images can find the right batch
                selected = batch_idx * batch_size + max_var_index
        assert selected is not None
        print(f"Selected image with max variance: {selected}")
        return selected

    def log_images(self, trainer, pl_module):
        # Split the global index into (batch number, position within that batch)
        batch_idx, local_idx = divmod(
            self.selected_image, trainer.val_dataloaders.batch_size
        )
        batch = next(islice(iter(trainer.val_dataloaders), batch_idx, None))
        batch = {
            k: v.to(pl_module.device)
            for k, v in batch.items()
            if isinstance(v, torch.Tensor)
        }

        # ENCODER
        (
            encoded_unmasked_patches,
            unmasked_indices,
            masked_indices,
            masked_matrix,
        ) = pl_module.model.encoder(batch)

        # DECODER
        pixels = pl_module.model.decoder(
            encoded_unmasked_patches, unmasked_indices, masked_indices
        )
        # Unpatchify: [B, C, L, P*P] -> [B, C, H, W]
        pixels = rearrange(
            pixels,
            "b c (h w) (p1 p2) -> b c (h p1) (w p2)",
            h=pl_module.model.image_size // pl_module.model.patch_size,
            p1=pl_module.model.patch_size,
        )
        assert pixels.shape == batch["pixels"].shape

        band_groups = {
            "rgb": (2, 1, 0),
            "<rededge>": (3, 4, 5, 7),
            "<ir>": (6, 8, 9),
            "<sar>": (10, 11),
            "dem": (12,),
        }

        n_rows, n_cols = 3, len(band_groups)  # Rows: Input, Prediction, Difference
        fig, axs = plt.subplots(n_rows, n_cols, figsize=(n_cols * 5, n_rows * 5))

        def normalize_img(img):
            # Clip to the 1st-99th percentile range, then rescale to [0, 1]
            lower_bound, upper_bound = np.percentile(img, [1, 99])
            img_clipped = np.clip(img, lower_bound, upper_bound)
            value_range = img_clipped.max() - img_clipped.min()
            # Guard against division by zero on flat images
            return (img_clipped - img_clipped.min()) / max(value_range, 1e-8)

        for col, (group_name, bands) in enumerate(band_groups.items()):
            # [H, W, len(bands)] for the selected image
            input_img = (
                batch["pixels"][local_idx, list(bands)]
                .detach().cpu().numpy().transpose(1, 2, 0)
            )
            pred_img = (
                pixels[local_idx, list(bands)]
                .detach().cpu().numpy().transpose(1, 2, 0)
            )

            if group_name == "rgb":
                # Show the three bands directly as an RGB image
                input_norm = normalize_img(input_img)
                pred_norm = normalize_img(pred_img)
            else:
                # Average the bands of the group into one grayscale image
                input_norm = normalize_img(input_img.mean(axis=2))
                pred_norm = normalize_img(pred_img.mean(axis=2))
            # Absolute difference of the normalized images
            diff = np.abs(input_norm - pred_norm)

            cmap = None if group_name == "rgb" else "gray"
            for row, (img, label) in enumerate(
                [(input_norm, "Input"), (pred_norm, "Pred"), (diff, "Diff")]
            ):
                axs[row, col].imshow(img, cmap=cmap)
                axs[row, col].set_title(f"{group_name} {label}")
                axs[row, col].axis("off")

        plt.tight_layout()
        self.logger.experiment.log({"Images": wandb.Image(fig)})
        plt.close(fig)
```
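The selection step can be tried without Lightning or a model. This is a minimal sketch, assuming each batch is a NumPy array of shape (batch, channels, height, width) and that all batches share the same size; `select_max_variance` is a hypothetical helper, not part of the repo. It returns an index that counts across batches, which `divmod` then splits back into a batch number and a position within that batch.

```python
from itertools import islice

import numpy as np


def select_max_variance(batches, num_batches=3):
    """Hypothetical helper: global index of the highest-variance sample.

    Assumes every batch has the same batch size, so offsets add up correctly.
    """
    best_index, best_var, offset = None, -np.inf, 0
    for batch in islice(batches, num_batches):
        # Variance over channels, height and width for each sample
        variances = batch.reshape(batch.shape[0], -1).var(axis=1)
        i = int(np.argmax(variances))
        if variances[i] > best_var:
            best_var, best_index = variances[i], offset + i
        offset += batch.shape[0]
    return best_index


rng = np.random.default_rng(0)
batches = [np.zeros((4, 2, 8, 8)) for _ in range(3)]
batches[1][2] = rng.normal(size=(2, 8, 8))  # the only non-flat sample
idx = select_max_variance(iter(batches))
print(idx, divmod(idx, 4))  # 6 (1, 2)
```

Storing the global index and splitting it with `divmod` keeps "which batch" and "which sample in that batch" from getting mixed up when the chosen image is not in the first batch.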
Another case, training over Bali.
https://github.com/Clay-foundation/model/assets/434029/d865ee5b-c2af-4a5e-a5a8-03951e26c83c
Maybe it's just me, but I can't play the videos on my end. In any case, I can see a trace of what you describe illustrated in #170.
Can you try to reproduce the same issue with a finer or coarser patching? Is it always exactly the upper half of the patches, and only the upper half, that displays this behaviour?
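One way to check whether the boundary sits exactly at half of the patch rows, for a given patch size, is to average the reconstruction error per row of patches. A minimal NumPy sketch, assuming the input and prediction for one image are (channels, height, width) arrays and the height is a multiple of the patch size; `row_error_profile` is a hypothetical helper, not part of the repo.

```python
import numpy as np


def row_error_profile(target, pred, patch_size):
    """Mean absolute error for each row of patches, top to bottom.

    target, pred: arrays of shape (channels, height, width) for one image.
    """
    _, height, _ = target.shape
    if height % patch_size:
        raise ValueError("height must be a multiple of patch_size")
    # Average over channels and width, leaving one value per pixel row
    per_pixel_row = np.abs(target - pred).mean(axis=(0, 2))
    # Group pixel rows into patch rows and average within each group
    return per_pixel_row.reshape(height // patch_size, patch_size).mean(axis=1)


# Toy example: the prediction is wrong only in the upper half of the image
target = np.zeros((3, 8, 8))
pred = target.copy()
pred[:, :4, :] = 1.0
print(row_error_profile(target, pred, patch_size=2))  # [1. 1. 0. 0.]
print(row_error_profile(target, pred, patch_size=4))  # [1. 0.]
```

Running this on real input/prediction pairs at two patch sizes would show whether the jump always falls at the midpoint of the patch rows.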
Thanks for checking, @alkalait. I haven't tested thoroughly, but wanted to file an issue in the meantime. I've put all the videos in this public folder in case others cannot play them (you might need VLC): https://drive.google.com/drive/folders/1BBOG7dWC5wqzmjjS-YL3svJBY-p6IA6r?usp=drive_link
Thanks for highlighting this issue @brunosan.
The issue is addressed in PR #193. In the validation phase, because shuffle was False, we were masking only the upper half of the image, which generated the artifacts.
We have solved this, but I'm not sure whether it has been merged yet. @srmsoumya?
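A minimal sketch of why a non-shuffled ordering masks the upper half, assuming patches are numbered in row-major order (top-left first) and the first `mask_ratio` fraction of the ordering is masked. `split_patch_indices` is hypothetical and is not the repo's masking code.

```python
import numpy as np


def split_patch_indices(num_patches, mask_ratio, shuffle, rng=None):
    """Hypothetical helper: return (masked, unmasked) patch indices.

    Patches are assumed to be numbered row-major, starting at the top-left.
    """
    if shuffle:
        if rng is None:
            rng = np.random.default_rng()
        order = rng.permutation(num_patches)
    else:
        # No shuffling: patches stay in row-major order
        order = np.arange(num_patches)
    num_masked = int(num_patches * mask_ratio)
    return order[:num_masked], order[num_masked:]


grid = 8  # 8 x 8 patches
masked, unmasked = split_patch_indices(grid * grid, 0.5, shuffle=False)
# Row of each masked patch: without shuffling, only the top half of the rows
print(sorted(set((masked // grid).tolist())))  # [0, 1, 2, 3]
```

With `shuffle=True` the masked patches are spread over the whole grid, so every region of the image gets reconstructed across validation steps.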
This has been solved, and is also no longer an issue for v1. So I am closing this here.
I took 1k chips and trained for 2k epochs.
Based on the images saved from the wandb hook, it seems clear that the upper part of the images is not really learning, or not at a similar rate, as if its weights were not being updated.
This is the animated video (made with ffmpeg).
https://github.com/Clay-foundation/model/assets/434029/ca432461-2d5f-4e46-a624-040dd1c5a02b
I'm not clear how this could happen, nor whether it is actually a difference in the learning or in the unpatchify code used for visualization.
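To separate the two hypotheses, the unpatchify step can be tested on its own: if patchify followed by unpatchify does not return the original image, the visualization is at fault. A minimal NumPy sketch, assuming square images and patches in row-major order with shape (batch, channels, num_patches, patch_size**2), the layout that the `rearrange` pattern `"b c (h w) (p1 p2) -> b c (h p1) (w p2)"` expects; `patchify` and `unpatchify` are hypothetical helpers.

```python
import numpy as np


def patchify(images, patch_size):
    """(B, C, H, W) -> (B, C, num_patches, patch_size**2), row-major patches."""
    b, c, height, width = images.shape
    h, w = height // patch_size, width // patch_size
    x = images.reshape(b, c, h, patch_size, w, patch_size)
    x = x.transpose(0, 1, 2, 4, 3, 5)  # b c h w p1 p2
    return x.reshape(b, c, h * w, patch_size * patch_size)


def unpatchify(patches, image_size, patch_size):
    """Inverse of patchify, same layout as the einops pattern above."""
    b, c, _, _ = patches.shape
    h = w = image_size // patch_size
    x = patches.reshape(b, c, h, w, patch_size, patch_size)
    x = x.transpose(0, 1, 2, 4, 3, 5)  # b c h p1 w p2
    return x.reshape(b, c, h * patch_size, w * patch_size)


rng = np.random.default_rng(0)
images = rng.normal(size=(2, 13, 16, 16))
roundtrip = unpatchify(patchify(images, 4), image_size=16, patch_size=4)
print(np.array_equal(images, roundtrip))  # True

# Fill each patch with its own index: patches 0-7 should land in the top half
patch_ids = np.repeat(np.arange(16, dtype=float)[None, None, :, None], 16, axis=3)
layout = unpatchify(patch_ids, image_size=16, patch_size=4)
print(layout[0, 0, ::4, ::4])  # one value per patch, in grid order
```

The second check also shows that, under this layout, the first half of the patch indices maps to the upper half of the image, which is where a bias in which patches get masked would become visible.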