marcoaversa / diffinfinite

DiffInfinite Official Code
MIT License

Dataset structure and training script #6

Open HongLiuuuuu opened 3 months ago

HongLiuuuuu commented 3 months ago

Thanks for this repo and this impressive work! I am trying to train on my own dataset. Could you please provide the dataset structure and training script?

joihn commented 3 months ago

Here is what I've reverse-engineered from the original dataloader:

path_to_your_dataset/
   img1.jpg
   img2.jpg
   img3.jpg
   img1_mask.png
   img2_mask.png
   img3_mask.png
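A small helper can pair each image with its mask under this layout. This is a sketch that assumes exactly the naming convention in the listing (`<name>.jpg` next to `<name>_mask.png`); the function name `pair_images_and_masks` is hypothetical and not part of the repo.

```python
from pathlib import Path


def pair_images_and_masks(root):
    """Return sorted (image_path, mask_path) pairs from a flat dataset folder.

    Assumes every <name>.jpg has a <name>_mask.png next to it.
    Raises FileNotFoundError if a mask is missing.
    """
    root = Path(root)
    pairs = []
    # "*.jpg" never matches the *_mask.png files, so masks are not picked up as images
    for img_path in sorted(root.glob("*.jpg")):
        mask_path = img_path.with_name(img_path.stem + "_mask.png")
        if not mask_path.exists():
            raise FileNotFoundError(f"no mask for {img_path.name} (expected {mask_path.name})")
        pairs.append((img_path, mask_path))
    return pairs
```

For example, `pairs = pair_images_and_masks("path_to_your_dataset")` gives `[(.../img1.jpg, .../img1_mask.png), ...]`, which you can hand to your own Dataset.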

Make sure you pass path_to_your_dataset to the Trainer class in train.py (take inspiration from config/image_gen_train.yaml).

Make sure you use your own config system, unless you are in a Jupyter notebook: the vanilla one has some issues (see https://github.com/marcoaversa/diffinfinite/issues/3).
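If you roll your own config, a dataclass with argparse overrides is one small option that works in a plain script. The field names below (`data_folder`, `image_size`, `batch_size`) are hypothetical placeholders; copy the real keys from config/image_gen_train.yaml and from the Trainer class in train.py.

```python
import argparse
from dataclasses import dataclass, fields


@dataclass
class TrainConfig:
    # Hypothetical fields for illustration only; mirror the keys
    # that config/image_gen_train.yaml and Trainer actually use.
    data_folder: str = "path_to_your_dataset"
    image_size: int = 512
    batch_size: int = 4


def parse_config(argv=None):
    """Build a TrainConfig from its defaults, overridden by --<field> <value> flags."""
    parser = argparse.ArgumentParser()
    for f in fields(TrainConfig):
        # f.type is the annotated class (str, int), so argparse converts the string value
        parser.add_argument(f"--{f.name}", type=f.type, default=f.default)
    return TrainConfig(**vars(parser.parse_args(argv)))
```

From the command line this reads `sys.argv`, e.g. `--data_folder /data/mine --batch_size 8`; in a notebook, call `parse_config([])` so argparse does not try to parse the kernel's own arguments.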

Then run:

python -m accelerate.commands.launch train.py

PS: the original dataloader has quite a lot of logic hardcoded for their specific dataset, so you might be better off writing your own.

HongLiuuuuu commented 3 months ago

Thanks for your help!

HongLiuuuuu commented 3 months ago

I am trying to write the dataloader for my own dataset. However, it is hard for me to understand the mask and label returned in dataset_masks.py without the 'labels.csv' and 'labels.pickle' files. Could you please show some examples of these two files or explain their format?

joihn commented 3 months ago

Here is a hint for your custom dataloader:

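from torch.utils.data import Dataset, DataLoader


class DatasetCustom(Dataset):
    """
    Returns (img, label), where label is the mask described below.
    Output must be (C x H x W):
      img is a torch tensor of shape 3x512x512, in the range [0.0, 1.0]
      mask is a torch tensor of shape 1x512x512
        the index values in the mask should be:
          0: unknown class (mandatory class)
          1: your_first_class
          2: your_second_class

      With probability p=0.5, feed an empty mask (0 everywhere),
      i.e. treat the sample as unlabeled data.
    """

    def __init__(self, root):
        # your code here: index the image/mask files under root
        self.root = root

    def __len__(self):
        raise NotImplementedError  # your code here: number of samples

    def __getitem__(self, idx):
        # your code here: build img and label for sample idx, then
        # return img, label
        raise NotImplementedError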
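One way to fill in the hint for the flat folder layout described earlier is sketched below. It assumes each `*_mask.png` is single-channel and already stores class indices (0 = unknown, 1, 2, ...), that image and mask have the same size of at least 512x512, and it takes a random 512x512 crop to reach the required shape. The class name `FolderMaskDataset` and the `pairs`/`p_unlabeled` parameters are hypothetical, not taken from the repo.

```python
import random

import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset


class FolderMaskDataset(Dataset):
    """Hypothetical filled-in version of the DatasetCustom hint.

    pairs: list of (image_path, mask_path) tuples.
    Assumes masks store class indices directly and match their image's size.
    """

    def __init__(self, pairs, size=512, p_unlabeled=0.5):
        self.pairs = list(pairs)
        self.size = size
        self.p_unlabeled = p_unlabeled

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        img_path, mask_path = self.pairs[idx]
        with Image.open(img_path) as im:
            img = np.asarray(im.convert("RGB"), dtype=np.float32) / 255.0  # H x W x 3 in [0, 1]
        with Image.open(mask_path) as im:
            mask = np.asarray(im, dtype=np.int64)  # H x W class indices
        if mask.ndim == 3:  # assumption: an RGB-saved mask repeats the index in every channel
            mask = mask[..., 0]

        # Random size x size crop, same window for image and mask
        h, w = mask.shape
        top = random.randint(0, h - self.size)
        left = random.randint(0, w - self.size)
        img = img[top:top + self.size, left:left + self.size]
        mask = mask[top:top + self.size, left:left + self.size]

        img = torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1)  # HWC -> 3 x H x W
        label = torch.from_numpy(np.ascontiguousarray(mask)).unsqueeze(0)   # 1 x H x W
        if random.random() < self.p_unlabeled:
            label = torch.zeros_like(label)  # unlabeled sample: every pixel is class 0
        return img, label
```

You can wrap it in a DataLoader as usual, e.g. `DataLoader(FolderMaskDataset(pairs), batch_size=4, shuffle=True)`. Cropping keeps the mask indices intact; if you resize instead, use nearest-neighbour interpolation for the mask so no fractional class values appear.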