mehta-lab / VisCy

computer vision models for single-cell phenotyping
BSD 3-Clause "New" or "Revised" License
20 stars 2 forks source link

Contrastive Learning Implementation #90

Open alishbaimran opened 2 weeks ago

alishbaimran commented 2 weeks ago

My understanding of the current classifier and task with 60X res data can be found here: https://docs.google.com/document/d/1j3UePmDJL_1V_9j7v3I4nLgAgKuFXXuqmlW8ZFOyTk0/edit?usp=sharing.

Given this approach, we'd need to modify HCSDataModule to support triplet sampling. Specifically:

The goal of triplet sampling is to minimize the distance between the anchor and the positive while maximizing the distance between the anchor and the negative in the learned embedding space.

# TripletTransform takes a base_transform and applies it to a sample to generate the anchor and positive samples.
# When the __call__ method is invoked with a sample, it applies the base_transform to that sample twice: first to create the anchor and then to create the positive.

class TripletTransform:
    """Callable wrapper that produces two views of one sample.

    The wrapped ``transform`` is invoked twice on the same input; the first
    result is returned as the anchor and the second as the positive.
    """

    def __init__(self, transform):
        # Callable applied to each incoming sample.
        self.transform = transform

    def __call__(self, sample):
        # The generator is consumed left to right, so the anchor view is
        # computed before the positive view.
        anchor_view, positive_view = (self.transform(sample) for _ in range(2))
        return anchor_view, positive_view

# The TripletDataset class is initialized with the dataset and a transform function. When the __getitem__ method is called with an index (idx):
# Anchor and Positive: The same data sample is retrieved for both the anchor and positive.
# Negative Sampling: A different sample is randomly selected as the negative.
#  If a transform is provided:
# The TripletTransform is used to apply the base_transform to both the anchor and positive samples, creating augmented versions.
# The base_transform is applied directly to the negative sample to create its augmented version (if wanted).

class TripletDataset(Dataset):
    """Wrap an indexable collection so each item is an (anchor, positive, negative) triplet.

    The anchor and the positive are both taken from ``data[idx]``; the
    negative is taken from a different, randomly chosen index. When a
    ``transform`` is given it is applied independently to each of the three
    elements.

    Args:
        data: Indexable collection supporting ``len()`` and ``data[i]``.
        transform: Optional callable applied to the anchor, the positive and
            the negative.
    """

    def __init__(self, data, transform=None):
        self.data = data
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return the triplet for ``idx``.

        Raises:
            ValueError: If the dataset holds fewer than two samples, because
                no distinct negative can be drawn.
        """
        import random

        n_samples = len(self.data)
        if n_samples < 2:
            raise ValueError(
                "TripletDataset needs at least two samples to draw a negative"
            )

        anchor = self.data[idx]
        positive = self.data[idx]
        # Simple negative sampling: a non-zero offset modulo the dataset
        # length can never land back on the anchor index, and every other
        # index is equally likely.
        negative_idx = (idx + random.randrange(1, n_samples)) % n_samples
        negative = self.data[negative_idx]
        if self.transform:
            anchor = self.transform(anchor)
            positive = self.transform(positive)
            negative = self.transform(negative)
        return (anchor, positive, negative)

Here the TripletTransform class takes a base transformation (defined in base_transform) and applies it to create the anchor and positive samples.

Modify HCSDataModule:

class TripletHCSDataModule(HCSDataModule):
    """HCSDataModule variant whose train/val datasets are wrapped in TripletDataset.

    Every constructor argument is forwarded positionally to
    ``HCSDataModule.__init__``. In addition, the normalizations and
    augmentations are composed into a single transform that is wrapped in a
    ``TripletTransform`` and stored on ``self.triplet_transform``.

    Args:
        normalizations: Transforms composed first; ``None`` means no
            normalizations.
        augmentations: Transforms composed after the normalizations;
            ``None`` means no augmentations.
        (remaining arguments are passed through to ``HCSDataModule``)
    """

    def __init__(
        self,
        data_path: str,
        source_channel: Union[str, Sequence[str]],
        target_channel: Union[str, Sequence[str]],
        z_window_size: int,
        split_ratio: float = 0.8,
        batch_size: int = 16,
        num_workers: int = 8,
        architecture: Literal["2D", "2.1D", "2.2D", "2.5D", "3D", "fcmae"] = "2.5D",
        yx_patch_size: tuple[int, int] = (256, 256),
        normalizations: Optional[list[MapTransform]] = None,
        augmentations: Optional[list[MapTransform]] = None,
        caching: bool = False,
        ground_truth_masks: Optional[Path] = None,
    ):
        # Build fresh lists per call instead of sharing one mutable default
        # list object between every instance of the class.
        normalizations = [] if normalizations is None else normalizations
        augmentations = [] if augmentations is None else augmentations
        super().__init__(
            data_path,
            source_channel,
            target_channel,
            z_window_size,
            split_ratio,
            batch_size,
            num_workers,
            architecture,
            yx_patch_size,
            normalizations,
            augmentations,
            caching,
            ground_truth_masks,
        )
        # NOTE(review): TripletTransform returns an (anchor, positive) pair,
        # while TripletDataset applies its transform separately to anchor,
        # positive and negative - confirm the resulting item structure is
        # what the model expects.
        self.triplet_transform = TripletTransform(
            transforms.Compose(normalizations + augmentations)
        )

    def setup(self, stage: Optional[str] = None):
        """Run the parent setup, then re-wrap the train/val datasets as triplets.

        The wrapping only happens when ``stage`` is ``"fit"`` or
        ``"validate"``; for any other value (including ``None``) the datasets
        created by the parent are left untouched.
        """
        super().setup(stage)
        if stage in ("fit", "validate"):
            # NOTE(review): assumes the datasets built by HCSDataModule expose
            # a ``.data`` attribute - TODO confirm against the parent class.
            self.train_dataset = TripletDataset(self.train_dataset.data, transform=self.triplet_transform)
            # The validation set receives the same transform (including the
            # augmentations) as the training set.
            self.val_dataset = TripletDataset(self.val_dataset.data, transform=self.triplet_transform)

    def train_dataloader(self):
        """Return the shuffled training loader; incomplete final batches are dropped."""
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size // 3,  # adjust batch size for triplets
            num_workers=self.num_workers,
            shuffle=True,
            persistent_workers=bool(self.num_workers),
            prefetch_factor=4 if self.num_workers else None,
            drop_last=True,
        )

    def val_dataloader(self):
        """Return the validation loader (no shuffling, final partial batch kept)."""
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size // 3,  # adjust batch size for triplets
            num_workers=self.num_workers,
            shuffle=False,
            prefetch_factor=4 if self.num_workers else None,
            persistent_workers=bool(self.num_workers),
        )
# Example of a pipeline that could be supplied through the augmentations list.
# The steps run in the order listed; the conversion to a tensor comes last.
base_transform = transforms.Compose(
    [
        # Colour jitter is applied to only half of the samples (p=0.5).
        transforms.RandomApply(
            [
                transforms.ColorJitter(
                    brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2
                )
            ],
            p=0.5,
        ),
        transforms.RandomRotation(degrees=10),
        transforms.RandomResizedCrop(size=128, scale=(0.8, 1.0)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.ToTensor(),
    ]
)

Using this updated dataloader:

# Example construction of the triplet data module for 2D patches of 128 x 128.
data_module = TripletHCSDataModule(
    # The constructor's first parameter is named ``data_path``; passing
    # ``dataset_path`` would be rejected as an unexpected keyword argument.
    data_path="...",
    source_channel=["Phase", "Sensor"],
    target_channel=["Inf_mask"],
    yx_patch_size=[128, 128],
    split_ratio=0.8,
    z_window_size=1,
    architecture="2D",
    num_workers=4,
    batch_size=64,
    normalizations=[
        NormalizeSampled(
            keys=["Sensor", "Phase"],
            level="fov_statistics",
            subtrahend="median",
            divisor="iqr",
        )
    ],
    augmentations=[
        RandWeightedCropd(
            num_samples=8,
            spatial_size=[-1, 128, 128],
            keys=["Sensor", "Phase", "Inf_mask"],
            w_key="Inf_mask",
        )
    ],
)

Model details:

Other ideas: try SimCLR vs. triplet sampling

Good resource: https://lilianweng.github.io/posts/2021-05-31-contrastive/

alishbaimran commented 5 days ago

This code implements triplet-based contrastive learning training with the following setup:

Updated dataloader code:

import os
import random
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from iohub import open_ome_zarr
from monai.transforms import Compose, RandAdjustContrastd, RandAffined, RandGaussianNoised
import pytorch_lightning as pl

class TripletCellDataset(Dataset):
    """Dataset yielding (anchor, positive, negative) triplets from an HCS OME-Zarr store.

    Each position in the store is one sample. The anchor is the data loaded
    for the requested position, the positive is a copy of the anchor passed
    through ``transforms`` (when provided), and the negative is the
    untransformed data of a different, randomly chosen position.

    Args:
        base_path: Path of the store, opened with ``layout="hcs"``.
        z_slices: Index or slice applied to the third array axis; ``None``
            selects everything along that axis.
        transforms: Optional callable that receives ``{'image': array}`` and
            returns a mapping with an ``'image'`` entry.
    """

    def __init__(self, base_path, z_slices=None, transforms=None):
        self.base_path = base_path
        self.transforms = transforms
        self.positions = self._get_positions()
        # Compare with None explicitly so a falsy but meaningful selector
        # (for example the integer 0) is not silently replaced.
        self.z_slices = z_slices if z_slices is not None else slice(None)

    # list of paths to every position (single cell) in the dataset
    def _get_positions(self):
        # The store is only needed while listing, so close it afterwards.
        with open_ome_zarr(self.base_path, layout="hcs", mode="r") as ds:
            return [path for path, _ in ds.positions()]

    def _load_data(self, position):
        """Load array ``'0'`` of ``position``, restricted to ``self.z_slices``."""
        # The indexing happens inside the ``with`` block, i.e. while the
        # store is still open; the store is closed before returning.
        with open_ome_zarr(self.base_path, layout="hcs", mode="r") as ds:
            return ds[position]['0'][:, :, self.z_slices, :, :]

    def __len__(self):
        return len(self.positions)

    def __getitem__(self, idx):
        """Return ``(anchor, positive, negative)`` as float32 tensors.

        Raises:
            ValueError: If the store has fewer than two positions, because
                no distinct negative can be drawn.
        """
        n_positions = len(self.positions)
        if n_positions < 2:
            raise ValueError(
                "TripletCellDataset needs at least two positions to draw a negative"
            )

        anchor_position = self.positions[idx]
        anchor_data = self._load_data(anchor_position)

        positive_data = anchor_data.copy()
        if self.transforms:
            positive_data = self.transforms({'image': positive_data})['image']

        # A non-zero offset modulo the number of positions can never land on
        # the anchor again, so no retry loop is required and every other
        # position is equally likely.
        negative_index = (idx + random.randrange(1, n_positions)) % n_positions
        negative_data = self._load_data(self.positions[negative_index])

        # Three tensors are returned. Per the original author's note each has
        # shape (48, 2, selected_z_slices, 200, 200) for their dataset.
        return (
            torch.tensor(anchor_data, dtype=torch.float32),
            torch.tensor(positive_data, dtype=torch.float32),
            torch.tensor(negative_data, dtype=torch.float32),
        )

class TripletCellDataModule(pl.LightningDataModule):
    """Lightning data module serving TripletCellDataset for train, val and test.

    All three datasets are built from the same ``base_path``; only the
    training dataset receives the augmentation pipeline.
    """

    def __init__(self, base_path, z_slices=None, batch_size=8, num_workers=4):
        super().__init__()
        self.base_path, self.z_slices = base_path, z_slices
        self.batch_size, self.num_workers = batch_size, num_workers

        # Augmentations for the training set; further transforms can be
        # appended to this list.
        augmentation_steps = [
            RandAdjustContrastd(keys=["image"], prob=0.5, gamma=(0.5, 1.5)),
            RandAffined(keys=["image"], prob=0.5, rotate_range=(0.1, 0.1), scale_range=(0.1, 0.1)),
            RandGaussianNoised(keys=["image"], prob=0.5, mean=0.0, std=0.1),
        ]
        self.transforms = Compose(augmentation_steps)

    def _make_dataset(self, transforms=None):
        # Single construction point so that every split uses the same
        # path and z-selection.
        return TripletCellDataset(self.base_path, z_slices=self.z_slices, transforms=transforms)

    def _make_loader(self, dataset, shuffle):
        # Loader settings shared by every split; only shuffling differs.
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
        )

    def setup(self, stage=None):
        """Create the train/val/test datasets; ``stage`` is not consulted."""
        self.train_dataset = self._make_dataset(transforms=self.transforms)
        self.val_dataset = self._make_dataset()
        self.test_dataset = self._make_dataset()

    def train_dataloader(self):
        return self._make_loader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        return self._make_loader(self.val_dataset, shuffle=False)

    def test_dataloader(self):
        return self._make_loader(self.test_dataset, shuffle=False)

`__getitem__` returns three tensors (anchor, positive, negative); each tensor has a shape of (48, 2, selected_z_slices, 200, 200).

Training code:

from pytorch_lightning import Trainer, LightningModule
import torch.nn.functional as F
import torch
from torchvision import models

class TripletNet(LightningModule):
    """ResNet-18 encoder trained with a triplet margin loss.

    The final fully connected layer of the backbone is replaced so that the
    network emits a 128-dimensional embedding.
    """

    def __init__(self):
        super().__init__()
        # The backbone can be swapped for whichever ResNet the project uses.
        backbone = models.resnet18(pretrained=True)
        # The embedding size is set here (128); change it to resize the output.
        backbone.fc = torch.nn.Linear(backbone.fc.in_features, 128)
        self.resnet = backbone
        self.triplet_loss = torch.nn.TripletMarginLoss(
            margin=1.0,
            p=2,
            eps=1e-6,
            swap=False,
            reduction='mean',
        )

    def forward(self, x):
        """Embed ``x`` with the backbone."""
        return self.resnet(x)

    def training_step(self, batch, batch_idx):
        """Embed the three batch elements and return their triplet loss."""
        anchor, positive, negative = batch
        # Each element goes through the network in its own forward pass,
        # in anchor -> positive -> negative order.
        anchor_embed, positive_embed, negative_embed = (
            self(view) for view in (anchor, positive, negative)
        )
        return self.triplet_loss(anchor_embed, positive_embed, negative_embed)

    def configure_optimizers(self):
        """Use Adam over all parameters with a learning rate of 1e-3."""
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

# Store holding the extracted patches used for training.
base_path = '/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/6-patches/patch_final.zarr'
# Selection applied along the z axis by the dataset (indices 10 to 19).
z_slices = slice(10, 20)

datamodule = TripletCellDataModule(base_path=base_path, z_slices=z_slices)
model = TripletNet()

# Train for ten epochs with the trainer's remaining settings left at defaults.
trainer = Trainer(max_epochs=10)
trainer.fit(model, datamodule=datamodule)

training_step:

A few thoughts: