SX-SS / GD-ViG

3 stars 0 forks source link

Help on Testing Weights of model #5

Open youssefmohana opened 4 days ago

youssefmohana commented 4 days ago
import torch 
import torch.nn as nn 
from torch.utils.data import Dataset, DataLoader
from typing import  Any , Optional,List , Tuple
import matplotlib.pyplot as plt
from numpy.typing import NDArray
import cv2 
import torchvision.transforms as transforms
from PIL import Image
import pandas as pd
from  glob import glob
import os
import numpy as np 
import re
import pytorch_lightning as pl
import albumentations as A
import seaborn as sns 
from collections import Counter
from albumentations.pytorch.transforms import ToTensorV2
from scipy.ndimage import gaussian_filter
import torch 
import torch.nn as nn 
from torch.utils.data import Dataset, DataLoader
from typing import  Any , Optional,List , Tuple
import matplotlib.pyplot as plt
from numpy.typing import NDArray
import cv2 
import torchvision.transforms as transforms
from PIL import Image
import pandas as pd
from  glob import glob
import os
import numpy as np 
import re
import pytorch_lightning as pl
import albumentations as A
import seaborn as sns 
from collections import Counter
from albumentations.pytorch.transforms import ToTensorV2
from scipy.ndimage import gaussian_filter
import seaborn as sns
from collections import Counter
from wordcloud import WordCloud
import numpy as np 
import re
import pytorch_lightning as pl
import albumentations as A
import seaborn as sns 
from collections import Counter
from albumentations.pytorch.transforms import ToTensorV2
from scipy.ndimage import gaussian_filter
# Fix the global random seed (via Lightning's seed_everything) so runs are reproducible.
pl.seed_everything(seed=42)
class EyeGaze(Dataset):
    """Chest X-ray dataset pairing each image with a smoothed eye-gaze heat map.

    Images are discovered as ``<data_dir>/<phase>/<label>/*.jpg``. Gaze
    fixations are read from ``<data_dir>/gaze/fixations.csv`` and matched to an
    image through the ``DICOM_ID`` column. That column must equal the image file
    name without its ``.jpg`` extension.

    Args:
        data_dir: Dataset root directory.
        phase: Sub-directory to load. ``"train"`` enables random augmentation;
            any other value uses a deterministic resize/normalise pipeline.
        img_size: Side length in pixels of the square output image.
    """

    def __init__(self, data_dir='./mimic_part_jpg', phase="train", img_size=224) -> None:
        self.root: str = data_dir
        self.phase: str = phase
        self.img_size = img_size
        self.T = self.get_transform(self.phase, self.img_size)
        self.csv = pd.read_csv(os.path.join(self.root, "gaze", "fixations.csv"))
        self.labels = ["CHF", "Normal", "pneumonia"]
        # Class-name -> integer target. The ids do not follow the order of self.labels.
        self.labelsdict = {"CHF": 1, "Normal": 0, "pneumonia": 2}
        self.idlist = []
        for label in self.labels:
            # glob() order is filesystem-dependent; sort so sample indices are reproducible.
            self.idlist.extend(sorted(glob(os.path.join(self.root, self.phase, label, "*.jpg"))))

    def __len__(self) -> int:
        return len(self.idlist)

    def __getitem__(self, index) -> Tuple[torch.Tensor, torch.Tensor, int]:
        """Return ``(image, gaze, label)`` for sample ``index``.

        ``gaze`` always carries a leading channel axis of size 1.

        Raises:
            ValueError: if the image's parent directory is not a known label.
        """
        img_path: str = self.idlist[index]
        # os.path is portable (unlike split("/")); avoids shadowing builtin `id`.
        dicom_id: str = os.path.splitext(os.path.basename(img_path))[0]
        label: int = self.labelsdict.get(os.path.basename(os.path.dirname(img_path)), -1)
        if label == -1:
            raise ValueError(f"Label not found for path: {img_path}")

        # Context manager guarantees the image file handle is released.
        with Image.open(img_path) as pil_img:
            img = np.array(pil_img)
        # shape[:2] works for both grayscale (H, W) and colour (H, W, C) arrays.
        img_h, img_w = img.shape[:2]
        gaze_map = self.get_gaze_map(dicom_id, (img_h, img_w))

        if self.T:
            transformed = self.T(image=img, mask=gaze_map)
            img = transformed["image"]
            gaze = transformed["mask"].unsqueeze(0)
        else:
            # Without a transform, still return a (1, H, W) gaze map instead of
            # failing on an unbound variable.
            gaze = gaze_map[np.newaxis]

        return img, gaze, label

    def get_gaze_map(self, id: str, img_size: tuple[int, int]) -> np.ndarray:
        """Build a normalised gaze heat map for one image.

        Each fixation adds a dwell time at its pixel, clamped into the image.
        The dwell time is the gap to the previous fixation's timestamp; the
        first fixation uses its raw timestamp. The map is then Gaussian-smoothed
        (sigma = 25 px) and min-max scaled to [0, 1] when it is not constant.

        Parameters:
        1. id: gaze id, matched against the ``DICOM_ID`` column.
        2. img_size: (height, width) of the image in pixels.
        Returns: float32 gaze_map of shape ``img_size``; all zeros when no
        fixations exist for ``id``.
        """
        img_h, img_w = img_size
        gaze_map = np.zeros(img_size, dtype=np.float32)
        idcsv = self.csv.loc[self.csv["DICOM_ID"] == id]

        if idcsv.empty:
            print(f"No gaze data found for ID: {id}")
            return gaze_map

        xs = idcsv["X_ORIGINAL"].to_numpy(dtype=float)
        ys = idcsv["Y_ORIGINAL"].to_numpy(dtype=float)
        times = idcsv["Time (in secs)"].to_numpy(dtype=float)
        # prepend=0.0 makes the first duration equal its own timestamp and every
        # later one the gap to the previous row.
        durations = np.diff(times, prepend=0.0)

        for x, y, duration in zip(xs, ys, durations):
            # Clamp into the image so off-screen fixations land on the border.
            col = max(0, min(int(x), img_w - 1))
            row = max(0, min(int(y), img_h - 1))
            gaze_map[row, col] += duration

        sigma = 25  # smoothing radius in pixels; adjust as needed
        gaze_map = gaussian_filter(gaze_map, sigma=sigma)
        lo, hi = gaze_map.min(), gaze_map.max()
        if hi > lo:
            gaze_map = (gaze_map - lo) / (hi - lo)
        return gaze_map

    def get_transform(self, phase, img_size):
        """Return the albumentations pipeline for ``phase``.

        Training pipeline:
        - small rotation;
        - resize to ``img_size * 256 / 224`` (256 for the default 224);
        - random resized crop to ``img_size``;
        - horizontal flip and shift/scale/rotate;
        - gamma and brightness/contrast jitter.

        Other phases resize straight to ``img_size``. Both pipelines normalise
        with mean 0.456 / std 0.224 and end with ToTensorV2. ``__getitem__``
        calls the result with ``image=`` and ``mask=`` keyword arguments.
        """
        if phase == 'train':
            # Keep the original 256->224 resize-then-crop ratio for any img_size.
            resize_to = round(img_size * 256 / 224)
            return A.Compose(
                    [
                        A.Rotate(limit=5),
                        A.Resize(resize_to, resize_to),
                        A.RandomResizedCrop(img_size, img_size),
                        A.HorizontalFlip(),
                        A.ShiftScaleRotate(),
                        A.RandomGamma(),
                        A.RandomBrightnessContrast(),
                        A.Normalize(mean=0.456, std=0.224),
                        ToTensorV2(),
                    ])
        return A.Compose(
                [
                    A.Resize(img_size, img_size),
                    A.Normalize(mean=0.456, std=0.224),
                    ToTensorV2(),
                ])

class MIMICEyeDataModule(pl.LightningDataModule):
    """Lightning data module wrapping EyeGaze train/val/test splits.

    Args:
        data_dir: Dataset root passed to EyeGaze.
        batch_size: Samples per batch for every loader.
        img_size: Output image side length passed to EyeGaze.
        num_workers: DataLoader worker processes (default 8, as before).
    """

    def __init__(self, data_dir, batch_size=32, img_size=224, num_workers=8):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.img_size = img_size
        self.num_workers = num_workers

    def setup(self, stage: Optional[str] = None):
        """Create the datasets needed for ``stage``.

        ``stage`` is one of "fit", "validate", "test", or None for all splits.
        """
        # "validate" (trainer.validate) also needs val_dataset; without it,
        # val_dataloader() would fail with AttributeError.
        if stage in ('fit', 'validate') or stage is None:
            self.train_dataset = EyeGaze(self.data_dir, phase='train', img_size=self.img_size)
            # NOTE(review): validation reuses the "test" split, so validation
            # metrics are computed on the test images.
            self.val_dataset = EyeGaze(self.data_dir, phase='test', img_size=self.img_size)
        if stage == 'test' or stage is None:
            self.test_dataset = EyeGaze(self.data_dir, phase='test', img_size=self.img_size)

    def _make_loader(self, dataset, shuffle):
        """Build a DataLoader with this module's shared loader settings."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            pin_memory=True,
            collate_fn=self.collate_fn,
        )

    def train_dataloader(self):
        return self._make_loader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        return self._make_loader(self.val_dataset, shuffle=False)

    def test_dataloader(self):
        return self._make_loader(self.test_dataset, shuffle=False)

    def collate_fn(self, batch):
        """Stack ``(image, gaze, label)`` samples into a dict of batched tensors.

        Returns ``{'image': ..., 'gaze': ..., 'labels': ...}``.
        """
        images = torch.stack([item[0] for item in batch])
        gazes = torch.stack([item[1] for item in batch])
        # EyeGaze samples are 3-tuples, so the label is at index 2. The previous
        # item[3] raised IndexError on every batch.
        labels = torch.tensor([item[2] for item in batch])
        return {'image': images, "gaze": gazes, 'labels': labels}

[image attachment]

Please help. Dataset used: https://drive.google.com/file/d/1jB0jENWn8NqCB0w9YCuEKpgm0Uiu5fdv/view?usp=share_link

When I test the model on this dataset with the weights you provided, the results are very poor.

SX-SS commented 1 day ago

I suggest you test it using the code from the repository. From your results, your predictions seem to be random.