Clay-foundation / model

The Clay Foundation Model (in development)
https://clay-foundation.github.io/model/
Apache License 2.0

Upper image does not train? #156

Closed brunosan closed 2 months ago

brunosan commented 5 months ago

I took 1k chips and trained on them for 2k epochs.

Based on the images saved by the wandb hook, it seems clear that the upper part of the images is not really learning, or at least not at a similar rate, as if its weights were not being updated.

This is the animated video (made with ffmpeg).

https://github.com/Clay-foundation/model/assets/434029/ca432461-2d5f-4e46-a624-040dd1c5a02b

I'm not clear how this could happen, or whether it is an actual difference in the learning or a problem in the unpatchify code used for visualization.
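
One way to rule out the visualization path is a round-trip check: split an image into patches, put it back together, and confirm nothing moved. The following is a minimal NumPy sketch under stated assumptions (square images, row-major patch order); `patchify` and `unpatchify` are hypothetical helpers written for illustration, not the repository's functions.

```python
import numpy as np


def patchify(img: np.ndarray, patch: int) -> np.ndarray:
    """Split (C, H, W) into (C, n_patches, patch * patch), row-major patch order."""
    c, height, width = img.shape
    h, w = height // patch, width // patch
    x = img.reshape(c, h, patch, w, patch)
    x = x.transpose(0, 1, 3, 2, 4)  # (C, h, w, p1, p2)
    return x.reshape(c, h * w, patch * patch)


def unpatchify(patches: np.ndarray, patch: int, size: int) -> np.ndarray:
    """Inverse of patchify for a square image of side `size`."""
    c = patches.shape[0]
    h = w = size // patch
    x = patches.reshape(c, h, w, patch, patch)
    x = x.transpose(0, 1, 3, 2, 4)  # (C, h, p1, w, p2)
    return x.reshape(c, h * patch, w * patch)


# Hypothetical test image: 3 channels, 32x32 pixels, 8x8 patches.
rng = np.random.default_rng(0)
image = rng.normal(size=(3, 32, 32))
restored = unpatchify(patchify(image, patch=8), patch=8, size=32)
roundtrip_ok = np.array_equal(image, restored)
print("round trip exact:", roundtrip_ok)
```

If the round trip is exact, the reshaping itself is unlikely to be what confines the artifact to the upper half, and the cause is more likely upstream, in what the model sees or predicts.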

brunosan commented 5 months ago

I continued investigating. This pattern of not learning in the upper half seems to apply across all bands.

Top row is input channels, middle is output and bottom is normalized difference.

https://github.com/Clay-foundation/model/assets/434029/d0d4684e-adfd-42d2-9591-fd18c0ad8d53

brunosan commented 5 months ago

I made a last test with another image set. Same pattern:

Top row is input channels, middle is output and bottom is normalized difference.

https://github.com/Clay-foundation/model/assets/434029/fb58836c-f7cd-4e2f-8ce2-4be5898f2b7c

Trying to understand this:

brunosan commented 4 months ago

If helpful, the code to create the images above is here

It selects the image with the maximum variance from the first 3 batches (a way to pick an image that has features rather than a flat one), then plots the RGB bands as RGB and each remaining band group as the average of its bands:

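# Imports the callback needs; assumed from usage, since the snippet omitted them.
from itertools import islice

import lightning as L
import matplotlib.pyplot as plt
import numpy as np
import torch
import wandb
from einops import rearrange

# get_wandb_logger is a helper defined elsewhere in the repository (not shown here).


class LogIntermediatePredictions(L.Callback):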
    """Visualize the model results at the end of every epoch."""

    def __init__(self):
        """
        Instantiates with wandb-logger.
        """
        super().__init__()
        self.selected_image = None

    def on_validation_end(
        self,
        trainer: L.Trainer,
        pl_module: L.LightningModule,
    ) -> None:
        """
        Called when the validation loop ends.
        At the end of each epoch, takes the first batch from validation dataset
        & logs the model predictions to wandb-logger for humans to interpret
        how model evolves over time.
        """
        with torch.no_grad():
            # Get WandB logger
            self.logger = get_wandb_logger(trainer=trainer)

            if self.selected_image is None:
                self.selected_image = self.select_image(trainer, pl_module)
            self.log_images(trainer, pl_module)

    def select_image(self, trainer, pl_module):
        """Pick the highest-variance image among the first 3 validation batches."""
        print("Selecting image with max variance")
        batches = islice(iter(trainer.val_dataloaders), 3)
        max_variance = -1
        for batch_idx, ibatch in enumerate(batches):
            batch = {
                k: v.to(pl_module.device)
                for k, v in ibatch.items()
                if isinstance(v, torch.Tensor)
            }
            images = batch["pixels"]  # Shape: [batch_size, channels, height, width]
            variances = images.var(
                dim=[1, 2, 3], keepdim=False
            )  # Calculate variance across C, H, W dimensions
            max_var_index = torch.argmax(variances).item()
            if variances[max_var_index] > max_variance:
                max_variance = variances[max_var_index]
                # Remember the batch as well: the in-batch index alone is always
                # smaller than batch_size, so it would resolve to the first batch.
                self.selected_batch = batch_idx
                self.selected_image = max_var_index
        assert self.selected_image is not None
        print(f"Selected image with max variance: {self.selected_image}")
        return self.selected_image

    def log_images(self, trainer, pl_module):
        # Skip ahead to the batch that holds the selected image.
        batch = next(
            islice(iter(trainer.val_dataloaders), self.selected_batch, None)
        )

        batch = {
            k: v.to(pl_module.device)
            for k, v in batch.items()
            if isinstance(v, torch.Tensor)
        }
        # ENCODER
        (
            encoded_unmasked_patches,
            unmasked_indices,
            masked_indices,
            masked_matrix,
        ) = pl_module.model.encoder(batch)

        # DECODER
        pixels = pl_module.model.decoder(
            encoded_unmasked_patches, unmasked_indices, masked_indices
        )
        pixels = rearrange(
            pixels,
            "b c (h w) (p1 p2) -> b c (h p1) (w p2)",
            h=pl_module.model.image_size // pl_module.model.patch_size,
            p1=pl_module.model.patch_size,
        )

        assert pixels.shape == batch["pixels"].shape

        band_groups = {
            "rgb": (2, 1, 0),
            "<rededge>": (3, 4, 5, 7),
            "<ir>": (6, 8, 9),
            "<sar>": (10, 11),
            "dem": (12,),
        }

        n_rows, n_cols = (
            3,
            len(band_groups),
        )  # Rows for Input, Prediction, Difference
        fig, axs = plt.subplots(n_rows, n_cols, figsize=(n_cols * 5, n_rows * 5))

        def normalize_img(img):
            lower_percentile, upper_percentile = 1, 99
            lower_bound = np.percentile(img, lower_percentile)
            upper_bound = np.percentile(img, upper_percentile)
            img_clipped = np.clip(img, lower_bound, upper_bound)
            return (img_clipped - img_clipped.min()) / (
                img_clipped.max() - img_clipped.min()
            )

        for col, (group_name, bands) in enumerate(band_groups.items()):
            input_img = batch["pixels"][:, bands, :, :]
            pred_img = pixels[:, bands, :, :]
            input_img = (
                input_img[self.selected_image].detach().cpu().numpy().transpose(1, 2, 0)
            )
            pred_img = (
                pred_img[self.selected_image].detach().cpu().numpy().transpose(1, 2, 0)
            )

            if group_name == "rgb":
                # Normalize RGB images
                input_norm = normalize_img(input_img)
                pred_norm = normalize_img(pred_img)
                # Calculate absolute difference for RGB
                diff_rgb = np.abs(input_norm - pred_norm)
            else:
                # Calculate mean for non-RGB bands if necessary
                input_mean = (
                    input_img.mean(axis=2) if input_img.ndim > 2 else input_img  # noqa: PLR2004
                )
                pred_mean = pred_img.mean(axis=2) if pred_img.ndim > 2 else pred_img  # noqa: PLR2004
                # Normalize and calculate difference
                input_norm = normalize_img(input_mean)
                pred_norm = normalize_img(pred_mean)
                diff_rgb = np.abs(input_norm - pred_norm)

            axs[0, col].imshow(input_norm, cmap="gray" if group_name != "rgb" else None)
            axs[1, col].imshow(pred_norm, cmap="gray" if group_name != "rgb" else None)
            axs[2, col].imshow(diff_rgb, cmap="gray" if group_name != "rgb" else None)

            for ax in axs[:, col]:
                ax.set_title(
                    f"""{group_name} {'Input' if ax == axs[0, col] else
                                     'Pred' if ax == axs[1, col] else
                                     'Diff'}"""
                )
                ax.axis("off")

        plt.tight_layout()
        self.logger.experiment.log({"Images": wandb.Image(fig)})
        plt.close(fig)

brunosan commented 4 months ago

Another case, training over Bali.

https://github.com/Clay-foundation/model/assets/434029/d865ee5b-c2af-4a5e-a5a8-03951e26c83c

alkalait commented 4 months ago

Maybe it's me but I can't play the videos on my end. In any case, I can see a trace of what you describe illustrated in #170.

Can you try to reproduce the same issue but with a finer or coarser patching? Is it always exactly the upper half of the patches and only the upper half that display this behaviour?
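
To answer the "exactly the upper half?" question with numbers rather than by eye, the reconstruction error could be compared per half and per patch row. This is a rough sketch on synthetic stand-in data, assuming `(C, H, W)` arrays; `half_errors` and `row_profile` are hypothetical helpers, and the 64x64 size and patch size of 16 are illustrative only.

```python
import numpy as np


def half_errors(target: np.ndarray, pred: np.ndarray) -> tuple[float, float]:
    """Mean absolute error of the upper and lower halves of (C, H, W) arrays."""
    abs_err = np.abs(target - pred)
    mid = target.shape[1] // 2
    return float(abs_err[:, :mid].mean()), float(abs_err[:, mid:].mean())


def row_profile(target: np.ndarray, pred: np.ndarray, patch: int) -> np.ndarray:
    """Mean absolute error per patch row, to see where the error changes."""
    abs_err = np.abs(target - pred).mean(axis=(0, 2))  # average over C and W -> (H,)
    return abs_err.reshape(-1, patch).mean(axis=1)


# Synthetic stand-in data: the prediction is noisy in the upper half only.
rng = np.random.default_rng(0)
target = rng.normal(size=(4, 64, 64))
pred = target.copy()
pred[:, :32] += rng.normal(scale=1.0, size=(4, 32, 64))

upper, lower = half_errors(target, pred)
profile = row_profile(target, pred, patch=16)
print("upper:", upper, "lower:", lower)
print("per patch row:", profile)
```

Running the same profile with a finer or coarser patch size would show whether the boundary follows the patch grid or stays at the image midline.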

brunosan commented 4 months ago

Thanks for checking @alkalait. I haven't tested thoroughly, but I wanted to file an issue in the meantime. I've put all the videos in this public folder, in case others cannot play them (you might need VLC). https://drive.google.com/drive/folders/1BBOG7dWC5wqzmjjS-YL3svJBY-p6IA6r?usp=drive_link

srmsoumya commented 4 months ago

Thanks for highlighting this issue @brunosan.

The issue is addressed in PR #193. In the validation phase, because shuffle was False, we were masking only the upper half of the image, and that was generating the artifacts.
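
The effect can be illustrated with a small NumPy sketch. It is not the repository's masking code: it assumes patches are indexed in row-major order and that, without shuffling, the first fraction of indices is the one masked. The 8x8 grid, the 0.5 mask ratio, and the helper names `mask_indices` and `mask_grid` are all hypothetical.

```python
import numpy as np


def mask_indices(n_patches: int, mask_ratio: float, shuffle: bool, rng):
    """Return (masked, unmasked) patch indices for one image."""
    order = rng.permutation(n_patches) if shuffle else np.arange(n_patches)
    n_masked = int(n_patches * mask_ratio)
    return order[:n_masked], order[n_masked:]


def mask_grid(masked: np.ndarray, grid: int) -> np.ndarray:
    """Boolean (grid, grid) map, True where a patch is masked (row-major)."""
    flat = np.zeros(grid * grid, dtype=bool)
    flat[masked] = True
    return flat.reshape(grid, grid)


rng = np.random.default_rng(0)
grid = 8  # hypothetical 8x8 patch grid
fixed, _ = mask_indices(grid * grid, mask_ratio=0.5, shuffle=False, rng=rng)
random_, _ = mask_indices(grid * grid, mask_ratio=0.5, shuffle=True, rng=rng)

fixed_map = mask_grid(fixed, grid)
random_map = mask_grid(random_, grid)
print(fixed_map.astype(int))   # masked patches fill the top rows
print(random_map.astype(int))  # masked patches are scattered
```

Under these assumptions, the unshuffled mask always hides the same top rows, so every logged validation image shows the decoder reconstructing the upper half from scratch while the lower half is visible to the encoder.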

yellowcap commented 3 months ago

We have solved this, but I'm not sure if it has been merged yet, @srmsoumya?

yellowcap commented 2 months ago

This has been solved, and it is also no longer an issue for v1, so I am closing this here.