Open alishbaimran opened 2 weeks ago
This code implements triplet contrastive learning training with the following setup:
Updated dataloader code:
import os
import random
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from iohub import open_ome_zarr
from monai.transforms import Compose, RandAdjustContrastd, RandAffined, RandGaussianNoised
import pytorch_lightning as pl
class TripletCellDataset(Dataset):
    """Dataset yielding (anchor, positive, negative) triplets of single-cell patches.

    Each HCS position in the OME-Zarr store is one cell. The positive is an
    augmented copy of the anchor; the negative is loaded from a different,
    randomly chosen position.
    """

    def __init__(self, base_path, z_slices=None, transforms=None):
        """
        Args:
            base_path: path to the HCS OME-Zarr store.
            z_slices: slice of z-planes to load; all planes when None.
            transforms: MONAI dict-style transforms applied to the positive
                sample only (keyed on "image").
        """
        self.base_path = base_path
        self.transforms = transforms
        self.positions = self._get_positions()
        self.z_slices = z_slices if z_slices else slice(None)

    def _get_positions(self):
        # List the path to every single cell (position) in the dataset.
        ds = open_ome_zarr(self.base_path, layout="hcs", mode="r")
        return [path for path, _ in ds.positions()]

    def _load_data(self, position):
        # Re-open the store on every call so the Dataset stays picklable for
        # multi-worker DataLoader use. Slicing is (T, C, z_slices, Y, X) —
        # presumably TCZYX layout; TODO confirm against the store metadata.
        ds = open_ome_zarr(self.base_path, layout="hcs", mode="r")
        return ds[position]['0'][:, :, self.z_slices, :, :]

    def __len__(self):
        return len(self.positions)

    def __getitem__(self, idx):
        """Return (anchor, positive, negative) float32 tensors.

        Each tensor has shape (48, 2, selected_z_slices, 200, 200) per the
        issue description — TODO confirm for other stores.
        """
        anchor_position = self.positions[idx]
        anchor_data = self._load_data(anchor_position)

        # Positive: augmented copy of the anchor.
        positive_data = anchor_data.copy()
        if self.transforms:
            positive_data = self.transforms({'image': positive_data})['image']

        # Negative: any other position. Guard first — with fewer than two
        # positions the rejection loop below would never terminate.
        if len(self.positions) < 2:
            raise RuntimeError(
                "Need at least two positions to sample a negative example."
            )
        negative_position = random.choice(self.positions)
        while negative_position == anchor_position:
            negative_position = random.choice(self.positions)
        negative_data = self._load_data(negative_position)

        # torch.as_tensor avoids the copy/UserWarning when the MONAI transform
        # has already produced a tensor for the positive sample.
        return (
            torch.tensor(anchor_data, dtype=torch.float32),
            torch.as_tensor(positive_data, dtype=torch.float32),
            torch.tensor(negative_data, dtype=torch.float32),
        )
class TripletCellDataModule(pl.LightningDataModule):
    """LightningDataModule that wraps TripletCellDataset for train/val/test.

    Only the training split receives augmentations; they are what turns the
    anchor copy into a distinct positive sample.
    """

    def __init__(self, base_path, z_slices=None, batch_size=8, num_workers=4):
        super().__init__()
        self.base_path = base_path
        self.z_slices = z_slices
        self.batch_size = batch_size
        self.num_workers = num_workers
        # Augmentation pipeline for the positive sample; other
        # transformations etc. can be set here.
        self.transforms = Compose([
            RandAdjustContrastd(keys=["image"], prob=0.5, gamma=(0.5, 1.5)),
            RandAffined(keys=["image"], prob=0.5, rotate_range=(0.1, 0.1), scale_range=(0.1, 0.1)),
            RandGaussianNoised(keys=["image"], prob=0.5, mean=0.0, std=0.1),
        ])

    def _make_dataset(self, transforms=None):
        # One-liner factory so the three splits stay in sync.
        return TripletCellDataset(self.base_path, z_slices=self.z_slices, transforms=transforms)

    def setup(self, stage=None):
        # NOTE(review): all three splits currently cover the full store —
        # no held-out partition; confirm whether that is intentional.
        self.train_dataset = self._make_dataset(self.transforms)
        self.val_dataset = self._make_dataset()
        self.test_dataset = self._make_dataset()

    def _loader(self, dataset, shuffle):
        # Shared DataLoader configuration for every split.
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
        )

    def train_dataloader(self):
        return self._loader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        return self._loader(self.val_dataset, shuffle=False)

    def test_dataloader(self):
        return self._loader(self.test_dataset, shuffle=False)
`__getitem__` returns three tensors (anchor, positive, negative); each tensor has a shape of (48, 2, selected_z_slices, 200, 200).
Training code:
from pytorch_lightning import Trainer, LightningModule
import torch.nn.functional as F
import torch
from torchvision import models
class TripletNet(LightningModule):
    """ResNet-18 backbone producing embeddings, trained with triplet margin loss."""

    def __init__(self, embedding_dim=128):
        """
        Args:
            embedding_dim: size of the output embedding (default 128,
                matching the original hard-coded head).
        """
        super().__init__()
        # `pretrained=True` is deprecated in torchvision; the weights enum is
        # the supported spelling and DEFAULT resolves to the same
        # IMAGENET1K_V1 checkpoint for resnet18. Can be swapped for the
        # project's own ResNet.
        self.resnet = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
        # Replace the classification head with an embedding projection.
        self.resnet.fc = torch.nn.Linear(self.resnet.fc.in_features, embedding_dim)
        self.triplet_loss = torch.nn.TripletMarginLoss(
            margin=1.0, p=2, eps=1e-6, swap=False, reduction='mean'
        )

    def forward(self, x):
        # NOTE(review): resnet18 expects (B, 3, H, W) input, but the
        # dataloader described above yields (B, 48, 2, z, 200, 200) —
        # reshaping / channel selection is needed upstream. TODO confirm.
        return self.resnet(x)

    def training_step(self, batch, batch_idx):
        anchor, positive, negative = batch
        anchor_embed = self(anchor)
        positive_embed = self(positive)
        negative_embed = self(negative)
        loss = self.triplet_loss(anchor_embed, positive_embed, negative_embed)
        # Log so the loss is visible in the progress bar / logger; the
        # original returned it silently.
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)
if __name__ == "__main__":
    # Guard is required: with num_workers > 0, DataLoader workers re-import
    # this module under the "spawn" start method and, without the guard,
    # would recursively relaunch training.
    base_path = '/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/6-patches/patch_final.zarr'
    z_slices = slice(10, 20)  # z-planes 10..19 of each patch

    datamodule = TripletCellDataModule(base_path, z_slices=z_slices)
    model = TripletNet()
    trainer = Trainer(max_epochs=10)
    trainer.fit(model, datamodule)
training_step:
A few thoughts:
My understanding of the current classifier and task with 60X res data can be found here: https://docs.google.com/document/d/1j3UePmDJL_1V_9j7v3I4nLgAgKuFXXuqmlW8ZFOyTk0/edit?usp=sharing.
Given this approach, we'd need to modify HCSDataModule to support triplet sampling. Specifically:
The goal of triplet sampling is to minimize the distance between the anchor and the positive while maximizing the distance between the anchor and the negative in the learned embedding space.
Here the TripletTransform class takes a base transformation (defined in base_transform) and applies it to create the anchor and positive samples.
Modify HCSDataModule:
Using this updated dataloader:
Model details:
Other ideas: try SimCLR vs. triplet sampling
Good resource: https://lilianweng.github.io/posts/2021-05-31-contrastive/