Yukariin / NatSR_pytorch

PyTorch implementation of Natural and Realistic Single Image Super-Resolution
MIT License

Transfer Learning using custom dataset #1

nelaturuharsha closed this issue 4 years ago

nelaturuharsha commented 4 years ago

Hello, I had two queries:

  1. I want to fine-tune your existing model on images I've collected for this purpose. What steps would I need to follow, for example to generate the dataset in the right format?
  2. The models you've uploaded are for 2x upsampling. How should I proceed to train a 3x/4x model?

Thank you so much once again for open sourcing your implementation.

Thank you, Sree Harsha

Yukariin commented 4 years ago

Hi @SreeHarshaNelaturu!

  1. You can use DatasetFromFolder, DatasetFromList or SQLDataset. Personally, I use SQLDataset. You can generate it with gen_data.py: you need to pick the patch scale, size and stride, and optionally waifu2x-like noise. To train an x2 model with no noise, you can use the default parameters and just specify your input/output paths:

         python gen_data.py --input_dir data/ --output dataset/anime_train_x2_n0.db

     To train the FRSR baseline, use train_frsr.py. Download the pre-trained FRSR model and place it in the snapshots/ckpt dir, then run:

         python train_frsr.py --root dataset/anime_train_x2_n0.db --scale 2 --max_iter 100000 --resume 50000

     To train the full NSR baseline, you need to train NMD first with train_nmd.py. Training NMD also requires a validation dataset, so generate one too. In my experience 60000-90000 iters is more than enough:

         python train_nmd.py --root dataset/anime_train_x2_n0.db --val dataset/anime_val_x2_n0.db --scale 2 --max_iter 60000

     Once NMD training has finished, you can start NSR training. Download the pre-trained NSR model and place it in the snapshots/ckpt dir, then run:

         python train.py --root dataset/anime_train_x2_n0.db --nmd snapshots/ckpt/NMD_60000.pth --scale 2 --max_iter 320000 --resume 220000 --transfer

  2. In my experience a double x2 pass produces better results than a single x4 pass. I haven't tried x3 yet. To train an x3/x4 model, you need to generate a dataset with the corresponding patch size:

         python gen_data.py --scale 3 --stride 90 --input_dir data/ --output dataset/anime_train_x3_n0.db

     or

         python gen_data.py --scale 4 --stride 120 --input_dir data/ --output dataset/anime_train_x4_n0.db

     Then train the corresponding model:

         python train_frsr.py --root dataset/anime_train_x3_n0.db --scale 3

     or

         python train_frsr.py --root dataset/anime_train_x4_n0.db --scale 4
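
The NSR fine-tuning steps above can be collected into a small dry-run script. This is only a sketch: the flags are copied from the commands in this reply, while the helper names (`gen_data_cmd`, `train_nmd_cmd`, `train_nsr_cmd`) and the idea of building the commands as lists are assumptions made for illustration, not part of the repository. It prints the commands rather than running them.

```python
import shlex

# Flags are copied from the commands quoted in this thread; the helper names
# and the dry-run layout are illustrative assumptions, not repository code.


def gen_data_cmd(input_dir, output_db, scale=None, stride=None):
    """Build a gen_data.py call; omitting scale/stride keeps the script defaults."""
    cmd = ["python", "gen_data.py"]
    if scale is not None:
        cmd += ["--scale", str(scale)]
    if stride is not None:
        cmd += ["--stride", str(stride)]
    return cmd + ["--input_dir", input_dir, "--output", output_db]


def train_nmd_cmd(train_db, val_db, scale, max_iter):
    """Build the NMD training call (needs a separate validation dataset)."""
    return ["python", "train_nmd.py", "--root", train_db, "--val", val_db,
            "--scale", str(scale), "--max_iter", str(max_iter)]


def train_nsr_cmd(train_db, nmd_ckpt, scale, max_iter, resume):
    """Build the NSR transfer-learning call that resumes from a pre-trained model."""
    return ["python", "train.py", "--root", train_db, "--nmd", nmd_ckpt,
            "--scale", str(scale), "--max_iter", str(max_iter),
            "--resume", str(resume), "--transfer"]


# Paths and iteration counts are the example values from the reply.
train_db = "dataset/anime_train_x2_n0.db"
val_db = "dataset/anime_val_x2_n0.db"

plan = [
    gen_data_cmd("data/", train_db),
    train_nmd_cmd(train_db, val_db, scale=2, max_iter=60000),
    train_nsr_cmd(train_db, "snapshots/ckpt/NMD_60000.pth",
                  scale=2, max_iter=320000, resume=220000),
]

for cmd in plan:
    print(shlex.join(cmd))  # dry run: show each command in order
```

To execute the steps instead of printing them, each list could be passed to `subprocess.run(cmd, check=True)`, assuming the script is started from the repository root. The validation dataset still has to be generated separately before the NMD step.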
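
The "double x2 pass" mentioned above simply means feeding the output of one x2 upscale back into the same upscaler. The sketch below shows only that chaining, assuming the upscaler can be treated as a plain callable. `nearest_x2` is a hypothetical nearest-neighbour stand-in working on nested lists, not the NatSR model, and `upscale_repeated` is a made-up helper name; loading and running the actual model is not covered in this thread, so it is left out.

```python
def nearest_x2(image):
    """Stand-in x2 upscaler: duplicate every pixel horizontally and vertically."""
    out = []
    for row in image:
        wide = [px for px in row for _ in range(2)]  # double the width
        out.append(wide)
        out.append(list(wide))                       # double the height
    return out


def upscale_repeated(upscale_fn, image, passes=2):
    """Apply the same upscaler several times; two x2 passes give x4 overall."""
    for _ in range(passes):
        image = upscale_fn(image)
    return image


image = [[1, 2, 3],
         [4, 5, 6]]                                  # 2 rows x 3 columns
result = upscale_repeated(nearest_x2, image, passes=2)
print(len(result), len(result[0]))                   # rows and columns after two passes
```

With a real x2 model in place of `nearest_x2`, the same loop would produce the x4 result from two x2 passes.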

nelaturuharsha commented 4 years ago

Thank you so much for the detailed reply, this is awesome. I should be able to get started on this right away. I believe what you mentioned about running 2x twice is true, and it would be interesting to look into why that's the case.

Thank you once again!