ashleve / lightning-hydra-template

PyTorch Lightning + Hydra. A very user-friendly template for ML experimentation. ⚡🔥⚡

Add Custom Dataloader Example into template #452

Open abhijeetdhakane opened 2 years ago

abhijeetdhakane commented 2 years ago

@ashleve this is an excellent template, but adding a few more files would make it usable for a much wider audience. I found the data modules a little tedious to set up. The template ships an MNIST data module, which is the 'standard' case, but in many use cases people work with a custom Dataset. An example of that would be very helpful.

Here is an example I tried:

`src/datamodules/components/dataloader.py`:

```python
import os

import numpy as np
import pandas as pd
import torch
from skimage import io
from torch.utils.data import Dataset


class FaceLandmarksDataset(Dataset):
    """Face Landmarks dataset."""

    def __init__(self, csv_file, root_dir, transform=None):
        """
        Args:
            csv_file (string): Path to the csv file with annotations.
            root_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        self.landmarks_frame = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return len(self.landmarks_frame)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        # First CSV column is the image file name; the remaining columns are x, y landmark pairs
        img_name = os.path.join(self.root_dir,
                                self.landmarks_frame.iloc[idx, 0])
        image = io.imread(img_name)
        landmarks = self.landmarks_frame.iloc[idx, 1:]
        landmarks = np.array([landmarks])
        landmarks = landmarks.astype('float').reshape(-1, 2)
        sample = {'image': image, 'landmarks': landmarks}

        if self.transform:
            sample = self.transform(sample)

        return sample
```

Taken from the PyTorch data loading tutorial: https://pytorch.org/tutorials/beginner/data_loading_tutorial.html
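
Before wiring a dataset like this into a DataModule, it can help to check that its dict-shaped samples batch correctly. The sketch below assumes a hypothetical in-memory `ToyLandmarksDataset` with fixed-size random tensors in place of the CSV and image files. It shows that PyTorch's default collate function batches a dict sample key by key. The real `FaceLandmarksDataset` returns NumPy images whose size can differ from file to file, so it needs a transform that resizes them and converts them to tensors (as the PyTorch tutorial does) before default batching will work.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class ToyLandmarksDataset(Dataset):
    """Hypothetical stand-in for FaceLandmarksDataset: random images and landmarks held in memory."""

    def __init__(self, num_samples: int = 8, num_landmarks: int = 68):
        generator = torch.Generator().manual_seed(0)
        self.images = torch.rand(num_samples, 3, 32, 32, generator=generator)
        self.landmarks = torch.rand(num_samples, num_landmarks, 2, generator=generator)

    def __len__(self) -> int:
        return len(self.images)

    def __getitem__(self, idx: int) -> dict:
        # Same sample layout as FaceLandmarksDataset: a dict with 'image' and 'landmarks'
        return {"image": self.images[idx], "landmarks": self.landmarks[idx]}


loader = DataLoader(ToyLandmarksDataset(), batch_size=4, shuffle=True)
batch = next(iter(loader))
# The default collate function stacks each dict key into its own batched tensor
print(batch["image"].shape)      # torch.Size([4, 3, 32, 32])
print(batch["landmarks"].shape)  # torch.Size([4, 68, 2])
```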

Then, in `src/datamodules/XYZ_datamodule.py`:

A DataModule implements six key methods:

```python
from pytorch_lightning import LightningDataModule


class XYZDataModule(LightningDataModule):
    def prepare_data(self):
        # things to do on 1 GPU/TPU (not on every GPU/TPU in DDP)
        # download data, pre-process, split, save to disk, etc...
        ...

    def setup(self, stage):
        # things to do on every process in DDP
        # load data, set variables, etc...
        ...

    def train_dataloader(self):
        ...  # return train dataloader

    def val_dataloader(self):
        ...  # return validation dataloader

    def test_dataloader(self):
        ...  # return test dataloader

    def teardown(self, stage):
        # called on every process in DDP
        # clean up after fit or test
        ...
```

I ran into errors while doing the above. My `datamodule_XYZ.yaml` looked like this:

```yaml
_target_: src.datamodules.XYZ_datamodule.XYZDataModule
datasets:
  train:
    _target_: src.datamodules.components.dataloader.FaceLandmarksDataset
    csv_file: ...
    root_dir: ...
    transform:
      _target_: ...

  val:

  test:

batches:
  train: 10
  val: 10
  test: 10
```

...
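
The actual error message was never posted, so this is only one plausible culprit worth ruling out: YAML treats `key: value` as a mapping entry only when the colon is followed by whitespace. Without the space, the text is read as part of a plain string. A small check with PyYAML, the parser OmegaConf (and therefore Hydra) uses for config files, shows the difference. The same rule applies to every key, including `_target_`.

```python
import yaml  # PyYAML

# With a space after each colon, every line is a key/value pair
good = yaml.safe_load("""
batches:
  train: 10
  val: 10
  test: 10
""")

# Without the space, `train:10` is not a key: the indented lines fold into one string
bad = yaml.safe_load("""
batches:
  train:10
  val:10
  test:10
""")

print(good)  # {'batches': {'train': 10, 'val': 10, 'test': 10}}
print(bad)   # {'batches': 'train:10 val:10 test:10'}
```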

JohannesK14 commented 2 years ago

Hi @abhijeetdhakane, could you post your error message and your implementation of the DataModule?

abhijeetdhakane commented 2 years ago

Resolved, no worries!

ashleve commented 2 years ago

Hey @abhijeetdhakane

Sorry for the late response. I agree it would be valuable to add an example of a custom Dataset class. Let me reopen this for now as a to-do for the near future.