Rayhane-mamah / Efficient-VDVAE

Official Pytorch and JAX implementation of "Efficient-VDVAE: Less is more"
https://arxiv.org/abs/2203.13751
MIT License

How to preprocess a new dataset? #9

Open turian opened 1 year ago

turian commented 1 year ago

I have a new dataset of 128x128 images. Can you provide README instructions on how to preprocess it?

Vanlogh commented 1 year ago

Hello @turian,

Thank you for bringing this to our attention. We hope to extend the README to cover this in the future; for now, here are the steps to load your own data:

  1. Check the preprocessing pipeline in efficient_vdvae_torch/data/generic_data_loader.py (or efficient_vdvae_jax/data/generic_data_loader.py) and verify that it is compatible with your dataset's requirements. The generic pipeline has 3 parts: a bit-depth normalization that depends on the number of bits you want your images to have, a normalization that maps images to [-1, 1], and an optional horizontal flip for your training dataset (which you can control through hparams).

  2. Open efficient_vdvae_torch/hparams.cfg (or efficient_vdvae_jax/hparams.cfg) and, in the data section, set your data paths and all the other parameters, such as the size of your images (128x128 here). Make sure to set a new string as your dataset_source (new_data in this example).

  3. Open efficient_vdvae_torch/train.py and add your new dataset_source string to the list of supported datasets that use the generic data loader pipeline:

    if hparams.data.dataset_source in ['ffhq', 'celebAHQ', 'celebA', 'new_data']:
        train_files, train_filenames = create_filenames_list(hparams.data.train_data_path)
        val_files, val_filenames = create_filenames_list(hparams.data.val_data_path)
        train_loader, val_loader = train_val_data_generic(
            train_files, train_filenames, val_files, val_filenames,
            hparams.run.num_gpus, local_rank)
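The torch snippet relies on the repository's create_filenames_list helper, whose exact behavior isn't shown in this thread. As a rough stand-in for sanity-checking your data layout, the sketch below assumes that train_data_path and val_data_path each point to a flat folder of .png/.jpg files. list_image_files and IMAGE_EXTENSIONS are hypothetical names, not part of Efficient-VDVAE.

```python
from pathlib import Path

# Hypothetical stand-in for create_filenames_list; the real helper may differ.
# Assumes a flat directory of image files (subfolders are not searched).
IMAGE_EXTENSIONS = {".png", ".jpg", ".jpeg"}


def list_image_files(data_path):
    """Return (full_paths, filenames) for image files directly under data_path, sorted by path."""
    paths = sorted(
        p for p in Path(data_path).iterdir()
        if p.is_file() and p.suffix.lower() in IMAGE_EXTENSIONS
    )
    full_paths = [str(p) for p in paths]
    filenames = [p.name for p in paths]
    return full_paths, filenames
```

For example, list_image_files('/path/to/new_data/train') should list every training image; if it returns empty lists, check whether your images sit in subfolders, which this sketch does not descend into.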

For JAX, the change is similar: open efficient_vdvae_jax/train.py and add the new dataset_source to the same list:

    # Load datasets
    if hparams.data.dataset_source in ('ffhq', 'celebAHQ', 'celebA', 'new_data'):
        train_data, val_data = create_generic_datasets()

Hopefully that answers your question. Let me know if there's something that's not clear :). Otherwise, please feel free to close this issue.

Thank you! Louay Hazami

Rayhane-mamah commented 1 year ago

Hello @turian and thanks for showing interest in our work.

We have added custom dataset support in our latest commit.

Instructions on how to use it are available in a new section of the README. That section also explains the utility scripts we provide to split your data into train/val sets or resize it if needed.
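The repository's own split and resize scripts are documented in its README. Purely to illustrate the train/val split idea, here is a hypothetical standard-library sketch that assumes a flat source folder of images; train_val_split, val_fraction and seed are illustrative names, not the repository's script or its options. Resizing is not covered, since it would need an image library such as Pillow.

```python
import random
import shutil
from pathlib import Path


def train_val_split(src_dir, dst_dir, val_fraction=0.1, seed=0):
    """Copy files from a flat src_dir into dst_dir/train and dst_dir/val.

    Hypothetical helper, not the repository's split script.
    Returns the number of files copied into each split.
    """
    files = sorted(p for p in Path(src_dir).iterdir() if p.is_file())
    random.Random(seed).shuffle(files)  # seeded shuffle for a reproducible split
    num_val = int(round(len(files) * val_fraction))
    splits = {"val": files[:num_val], "train": files[num_val:]}
    for split, split_files in splits.items():
        out_dir = Path(dst_dir) / split
        out_dir.mkdir(parents=True, exist_ok=True)
        for f in split_files:
            shutil.copy2(f, out_dir / f.name)
    return {split: len(split_files) for split, split_files in splits.items()}
```

The resulting dst_dir/train and dst_dir/val folders would then be the values to use for train_data_path and val_data_path in hparams.cfg.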

Hope this helps, let us know if there are still any pending issues concerning this feature.

Best, Rayhane.