WarmongeringBeaver / PhenDiff

Train and fine-tune diffusion models. Perform image-to-image class transfer experiments.
9 stars · 1 fork · source link

cuDNN error when prob_unconditional ≠ 1.0 #5

Open nicoboou opened 1 month ago

nicoboou commented 1 month ago

Hi!

I've been working a bit with your code — thanks again for providing it; nice work!

I adapted it to my needs, especially the Dataset class (I needed to load images stored in .npz files). It works fine when I train a DDIM unconditionally (prob_unconditional = 1.0).

But I get an annoying cuDNN error when I try to train the same DDIM with the following parameters (i.e. prob_unconditional ≠ 1.0):

ERROR LOG error_phendiff

PARAMETERS [screenshot taken 2024-06-10 at 14:23:06]

Any idea how to troubleshoot this? Thanks for your help! :)

nicoboou commented 1 month ago

NOTE: here is the Dataset class that I use:

class LRUCache(OrderedDict):
    """Least Recently Used cache built on :class:`collections.OrderedDict`.

    Entries are kept in recency order, oldest first. :meth:`get` and
    :meth:`put` mark a key as most recently used. When the number of entries
    exceeds ``capacity``, :meth:`put` evicts the least recently used one.
    Plain ``cache[key]`` reads and writes bypass the recency bookkeeping and
    eviction.
    """

    def __init__(self, capacity: int):
        super().__init__()
        self.capacity = capacity

    def __reduce__(self):
        # OrderedDict's default __reduce__ rebuilds instances via ``cls()`` with
        # no arguments, which raises TypeError here because ``capacity`` is
        # required. That broke pickle/copy of the cache and of any object that
        # holds one (e.g. a Dataset sent to DataLoader worker processes).
        state = dict(vars(self)) or None
        return (type(self), (self.capacity,), state, None, iter(self.items()))

    def get(self, key, default=None):
        """Return the value for ``key`` and mark it most recently used.

        Returns ``default`` (``None`` unless given, matching ``dict.get``)
        when ``key`` is absent; a miss does not change the recency order.
        """
        try:
            # move_to_end raises KeyError for a missing key, so this is the
            # membership test and the recency update in one lookup.
            self.move_to_end(key)
        except KeyError:
            return default
        return self[key]

    def put(self, key, value) -> None:
        """Insert or overwrite ``key`` as most recently used, evicting the oldest entry if over capacity."""
        self[key] = value
        self.move_to_end(key)  # an overwrite must also refresh recency
        if len(self) > self.capacity:
            self.popitem(last=False)  # last=False pops the oldest entry

class RawMicroscopeCropped(Dataset):
    """Class-balanced dataset of cell crops stored in per-image ``.npz`` files.

    ``data`` is an iterable of dict-like entries with the keys ``"slide_id"``,
    ``"image_id"``, ``"label"`` and ``"num_cells"``. Each entry maps to the
    file ``<root_dir>/<slide_id>_<image_id>_<label>.npz``. That file's
    ``"data"`` array is indexed as ``data[block_idx][img_idx]``.

    An entry contributes ``num_cells // cells_per_block`` full blocks
    ("bags") of ``cells_per_block`` samples each. Bags are balanced across
    labels by truncating every label to the smallest per-label bag count,
    keeping the earliest bags in ``data`` order.
    """

    # Location that was previously hard-coded; override with ``root_dir``.
    DEFAULT_ROOT_DIR = (
        "/projects/smala/Nicolas/malaria_detection/src/data/datasets/"
        "microscope/cropped/processed_images"
    )

    def __init__(
        self,
        data,
        train_type: str = "train",
        transform=None,
        no_labels: bool = False,
        cache_size: int = 2000,
        root_dir: str = DEFAULT_ROOT_DIR,
        cells_per_block: int = 100,
        seed: int = 42,
    ):
        """Build the sample index from ``data``.

        Args:
            data: entries as described in the class docstring. It is iterated
                twice, so it must be a re-iterable collection, not a generator.
            train_type: stored as ``self.train_type``; not used by this class.
            transform: optional callable applied to a ``PIL.Image`` of each
                crop. Without it, crops go through ``transforms.ToTensor()``.
            no_labels: if True, ``__getitem__`` returns only the image.
            cache_size: number of loaded ``.npz`` arrays kept in the LRU cache.
            root_dir: directory containing the ``.npz`` files.
            cells_per_block: samples per block (bag) in each ``.npz`` file.
            seed: seed of the private RNG used by
                :meth:`get_balanced_subset_indices`.
        """
        # A private RNG keeps subset sampling reproducible without reseeding
        # the process-global ``random`` module (previously ``random.seed(42)``).
        # For the default seed the samples are identical to that old behaviour.
        self._rng = random.Random(seed)
        self.data = data
        self.train_type = train_type
        self.transform = transform
        self.cache = LRUCache(capacity=cache_size)
        self.root_dir = root_dir
        self.cells_per_block = cells_per_block

        self.index = []  # one (file_path, block_idx, img_idx) per sample
        self.labels = []  # raw label value per sample
        self.slide_ids = []  # slide id per sample
        self.slide_to_bags = defaultdict(list)  # kept for compatibility; not populated here

        # NOTE: this includes labels whose entries yield no full block, so
        # such a class may be listed here yet have no samples.
        self.classes = sorted({entry["label"] for entry in data})
        # Contiguous ids 0..len(classes)-1. Returning the raw ``int(label)``
        # instead produces out-of-range class ids for anything sized by
        # ``len(self.classes)`` (e.g. a class-embedding table) whenever labels
        # are not already 0..K-1. On CUDA that can surface as an opaque
        # device-side/cuDNN error.
        self.class_to_idx = {label: i for i, label in enumerate(self.classes)}

        self.no_labels = no_labels

        self._build_index(data)

    def _build_index(self, data):
        """Populate ``index``/``labels``/``slide_ids`` with a class-balanced set of bags."""
        labels_to_bags = defaultdict(list)
        for entry in data:
            slide_id = entry["slide_id"]
            label = entry["label"]
            file_name = f"{slide_id}_{entry['image_id']}_{label}"
            file_path = os.path.join(self.root_dir, file_name + ".npz")
            # Only complete blocks are used; trailing cells are dropped.
            num_blocks = int(entry["num_cells"]) // self.cells_per_block
            labels_to_bags[label].extend(
                (file_path, block_idx, slide_id) for block_idx in range(num_blocks)
            )

        if not labels_to_bags:
            # No entry has a full block (or ``data`` is empty): leave the index
            # empty rather than failing on min() of an empty sequence.
            return

        # Balance by truncating every label to the smallest bag count,
        # keeping the first bags so the selection is deterministic.
        min_bags = min(len(bags) for bags in labels_to_bags.values())

        for label, bags in labels_to_bags.items():
            for file_path, block_idx, slide_id in bags[:min_bags]:
                for img_idx in range(self.cells_per_block):
                    self.index.append((file_path, block_idx, img_idx))
                    self.labels.append(label)
                    self.slide_ids.append(slide_id)

    def apply_subset(self, subset_indices):
        """Restrict the dataset, in place, to the samples at ``subset_indices`` (in that order)."""
        self.index = [self.index[i] for i in subset_indices]
        self.labels = [self.labels[i] for i in subset_indices]
        self.slide_ids = [self.slide_ids[i] for i in subset_indices]

    def _load_data(self, path):
        """Return the ``"data"`` array of the ``.npz`` file at ``path``, going through the LRU cache."""
        data = self.cache.get(path)
        if data is None:
            # NOTE(review): allow_pickle=True lets np.load unpickle object
            # arrays, which can execute arbitrary code. Only load trusted
            # files, and drop the flag if "data" is a plain numeric array.
            with np.load(path, allow_pickle=True) as loaded_data:
                data = loaded_data["data"]
            self.cache.put(path, data)
        return data

    def __len__(self):
        return len(self.index)

    def __getitem__(self, idx):
        """Return ``img`` if ``no_labels``, else ``(img, label)`` for sample ``idx``.

        ``label`` is a 0-dim int64 tensor holding the position of the sample's
        label in ``self.classes``.
        """
        path, block_idx, img_idx = self.index[idx]
        img = self._load_data(path)[block_idx][img_idx]

        if self.transform:
            img = self.transform(Image.fromarray(img))
        else:
            img = transforms.ToTensor()(img)

        if self.no_labels:
            return img

        return img, torch.tensor(self.class_to_idx[self.labels[idx]])

    def get_balanced_subset_indices(self, percentage):
        """Sample ``percentage`` % of each label's indices, with at least one per label.

        Sampling uses this dataset's private RNG, so results are reproducible
        for a freshly constructed dataset and the global ``random`` state is
        left untouched. Pass the result to :meth:`apply_subset`.

        Raises:
            ValueError: if ``percentage`` is outside ``[1, 100]``.
        """
        if not 1 <= percentage <= 100:
            raise ValueError("Percentage must be between 1 and 100")

        labels_to_indices = defaultdict(list)
        for idx, label in enumerate(self.labels):
            labels_to_indices[label].append(idx)

        subset_indices = []
        for indices in labels_to_indices.values():
            num_samples = max(1, int(len(indices) * (percentage / 100)))
            subset_indices.extend(self._rng.sample(indices, num_samples))
        return subset_indices