AlonzoLeeeooo / SketchRefiner

The official code implementation of "Towards Interactive Image Inpainting via Sketch Refinement".
https://ieeexplore.ieee.org/abstract/document/10533842
32 stars 1 forks source link

关于训练的问题 #10

Open liuna0211 opened 6 days ago

liuna0211 commented 6 days ago

不好意思又打扰您,我有一些问题想要咨询 我目前使用celeba-hq数据集,28000张作为训练, 1.parser.add_argument('--max_iters', default=280000, type=int, help='max iterations of training') parser.add_argument('--epochs', default=10, type=int, help='epochs of training') parser.add_argument('--batch_size', default=32, type=int, help='batch size') parser.add_argument('--num_workers', default=16, type=int, help='workers number of data loader') parser.add_argument('--sample_interval', default=28000, type=int, help='the interval of saving training samples') parser.add_argument('--checkpoint_interval', default=28000, type=int, help='the interval of saving checkpoints') 我想咨询如何确定----sample_interval的值,相关代码为self.iteration = epoch len(self.dataset) + (batch_index + 1) self.configs.batch_size if self.iteration % self.configs.sample_interval == 0: 很多时候 我的代码不会进入if self.iteration % self.configs.sample_interval == 0这句判断,导致不能保存模型,我想每个epoch结束保存一次,也就是每处理完28000张之后,保存一次 2.上一步会生成俩个权重文件,分别为checkpoint_epoch1_iters13984_optimizer_rm.pth,与checkpoint_epoch1_iters13984_registration_module.pth。我想咨询,进行EM训练,即这个命令行--train_EM --RM_checkpoint /path/to/model/weights/of/RM 使用哪个权重文件。 3.推理第一部分 python SRN_test.py --images /path/to/test/source/images --masks /path/to/test/masks --edge_prefix /path/to/detected/edges --sketch_prefix /path/to/input/sketches --output /path/to/output/dir --RM_checkpoint /path/to/RM/checkpoint --EM_checkpoint /path/to/EM/checkpoint 这些命令输入是否都是test数据, --edge_prefix /path/to/detected/edges --sketch_prefix /path/to/input/sketches 这俩句是否也是1000张test数据对应的边缘图与草图, --RM_checkpoint /path/to/RM/checkpoint --EM_checkpoint /path/to/EM/checkpoint这俩句命令对应的是哪个权重文件,问题同2,下载作者的权重文件connector.pth,registrator.pth是否是对应这俩个呢。 4.推理第二部分 python SIN_test.py --images /path/to/source/images --masks /path/to/masks --edges /path/to/detected/edges --sketches /path/to/input/sketches --refined_sketches /path/to/refined/sketches --num_samples maximum_samples 这些命令行对应的是train文件吗也就是28000张。 还是说对应的也是test文件。 --checkpoint /path/to/model/weights/of/SIM 关于这个命令行 是否对应下载权重文件中的celebahq.pth。 问题有点多,主要集中在我不清楚放的是训练集文件还是测试集文件。以及权重文件的对应,非常感谢您

AlonzoLeeeooo commented 6 days ago

不好意思又打扰您,我有一些问题想要咨询 我目前使用celeba-hq数据集,28000张作为训练, 1.parser.add_argument('--max_iters', default=280000, type=int, help='max iterations of training') parser.add_argument('--epochs', default=10, type=int, help='epochs of training') parser.add_argument('--batch_size', default=32, type=int, help='batch size') parser.add_argument('--num_workers', default=16, type=int, help='workers number of data loader') parser.add_argument('--sample_interval', default=28000, type=int, help='the interval of saving training samples') parser.add_argument('--checkpoint_interval', default=28000, type=int, help='the interval of saving checkpoints') 我想咨询如何确定----sample_interval的值,相关代码为self.iteration = epoch len(self.dataset) + (batch_index + 1) self.configs.batch_size if self.iteration % self.configs.sample_interval == 0: 很多时候 我的代码不会进入if self.iteration % self.configs.sample_interval == 0这句判断,导致不能保存模型,我想每个epoch结束保存一次,也就是每处理完28000张之后,保存一次 2.上一步会生成俩个权重文件,分别为checkpoint_epoch1_iters13984_optimizer_rm.pth,与checkpoint_epoch1_iters13984_registration_module.pth。我想咨询,进行EM训练,即这个命令行--train_EM --RM_checkpoint /path/to/model/weights/of/RM 使用哪个权重文件。 3.推理第一部分 python SRN_test.py --images /path/to/test/source/images --masks /path/to/test/masks --edge_prefix /path/to/detected/edges --sketch_prefix /path/to/input/sketches --output /path/to/output/dir --RM_checkpoint /path/to/RM/checkpoint --EM_checkpoint /path/to/EM/checkpoint 这些命令输入是否都是test数据, --edge_prefix /path/to/detected/edges --sketch_prefix /path/to/input/sketches 这俩句是否也是1000张test数据对应的边缘图与草图, --RM_checkpoint /path/to/RM/checkpoint --EM_checkpoint /path/to/EM/checkpoint这俩句命令对应的是哪个权重文件,问题同2,下载作者的权重文件connector.pth,registrator.pth是否是对应这俩个呢。 4.推理第二部分 python SIN_test.py --images /path/to/source/images --masks /path/to/masks --edges /path/to/detected/edges --sketches /path/to/input/sketches --refined_sketches /path/to/refined/sketches --num_samples maximum_samples 这些命令行对应的是train文件吗也就是28000张。 还是说对应的也是test文件。 --checkpoint /path/to/model/weights/of/SIM 关于这个命令行 是否对应下载权重文件中的celebahq.pth。 问题有点多,主要集中在我不清楚放的是训练集文件还是测试集文件。以及权重文件的对应,非常感谢您

  1. --sample_interval的意思是保存中间结果的频率,如果你想每个epoch都做一个可视化的话,个人感觉两种做法:(1)设置--sample_interval 28000;(2)在每一个epoch training loop结尾加一段可视化和保存模型权重的代码。你所说的不会进入我感觉有可能是batch_size导致的问题,因为如果你的batch_size > 1,那么pytorch dataloader的实现只会在这个epoch走len(dataloader) // batch_size的iteration,所以说最好是将--sample_interval设置小一些,或者是照前面的说法修改代码;
  2. checkpoint_epoch1_iters13984_optimizer_rm.pth是optimizer的权重,checkpoint_epoch1_iters13984_registration_module.pth是RM网络的权重,读取后者即可;
  3. 推理第一部分的命令都对应测试数据;--sketches_prefix--edges_prefix对应的是本地的文件夹路径,代码会自动读取文件夹中所有的图片;registrator.pth对应--RM_checkpointconnector.pth对应--EM_checkpiont
  4. 推理第二部分也都对应测试数据,--checkpoint对应的是celebahq.pth