liuqk3 / PUT

Paper 'Transformer based Pluralistic Image Completion with Reduced Information Loss' in TPAMI 2024 and 'Reduce Information Loss in Transformers for Pluralistic Image Inpainting' in CVPR2022
MIT License

Train problem #14

Open zhangbaijin opened 1 year ago

zhangbaijin commented 1 year ago

Thanks for your contribution. I hit a problem when training on a 512x512 dataset in the second stage (Transformer_XX.yaml): RuntimeError: The size of tensor a (1024) must match the size of tensor b (512) at non-singleton dimension 1. How should I change the yaml if I want to train on my 512x512 dataset?

liuqk3 commented 1 year ago

Thanks for your interest in our work.

If you use our provided P-VQVAE to train the second-stage transformer on 512x512 images, keep in mind that it divides images into 8x8 patches, so the number of tokens (the sequence length) for an image is 64x64 = 4096. Change 1024 in the following line to 4096:

https://github.com/liuqk3/PUT/blob/7de8ce0ada1e63e8c5300857a463d68380e142f0/configs/put_cvpr2022/ffhq/transformer_ffhq.yaml#L5

Then change [32, 32] in the following line to [64, 64]:

https://github.com/liuqk3/PUT/blob/7de8ce0ada1e63e8c5300857a463d68380e142f0/configs/put_cvpr2022/ffhq/transformer_ffhq.yaml#L25
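The arithmetic behind these two values can be sketched as follows. This is only an illustration: the 8-pixel patch side is inferred from the 64x64 token grid quoted above for 512x512 input (512 / 64 = 8), and the function names are hypothetical, not part of the PUT code base.

```python
def tokens_per_side(image_size: int, patch_size: int = 8) -> int:
    """Number of tokens along one image side for square, non-overlapping patches."""
    if image_size % patch_size != 0:
        raise ValueError("image_size must be divisible by patch_size")
    return image_size // patch_size


def sequence_length(image_size: int, patch_size: int = 8) -> int:
    """Total number of tokens, i.e. the transformer sequence length."""
    side = tokens_per_side(image_size, patch_size)
    return side * side


# Compare the default 256x256 setting with the 512x512 setting discussed here.
for size in (256, 512):
    side = tokens_per_side(size)
    print(f"{size}x{size} image: token grid [{side}, {side}], "
          f"sequence length {sequence_length(size)}")
```

With these assumptions, 256x256 gives a [32, 32] grid and 1024 tokens, while 512x512 gives a [64, 64] grid and 4096 tokens, which matches the two config changes.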

These modifications should work. Best wishes!

zhangbaijin commented 1 year ago

Thanks for your help. Can you give an example of how to train the transformer? The P-VQVAE training log is below:

pvqvae_Bihua: val generator: Epoch 1/150 | rec_loss: 0.3322 | loss: 0.3337 | quantize_loss: 0.0015 | used_unmasked_quantize_embed: 1.0000 | used_masked_quantize_embed: 1.0000 | unmasked_num_token: 13061.0000 | masked_num_token: 3323.0000 discriminator: Epoch 1/150 | loss: 0.0000
pvqvae_Bihua: train: Epoch 2/150 iter 6/92 || generator | rec_loss: 0.2488 | loss: 0.2501 | quantize_loss: 0.0013 | used_unmasked_quantize_embed: 1.0000 | used_masked_quantize_embed: 1.0000 | unmasked_num_token: 14845.0000 | masked_num_token: 1539.0000 || discriminator | loss: 0.0000 || generator_lr: 7.64e-06, discriminator_lr: 7.64e-07 || data_time: 0.2s | fbward_time: 0.2s | iter_time: 0.4s | iter_avg_time: 0.4s | epoch_time: 02s | spend_time: 01m:10s | left_time: 01h:23m:57s
pvqvae_Bihua: train: Epoch 2/150 iter 16/92 || generator | rec_loss: 0.2533 | loss: 0.2547 | quantize_loss: 0.0014 | used_unmasked_quantize_embed: 1.0000 | used_masked_quantize_embed: 1.0000 | unmasked_num_token: 13854.0000 | masked_num_token: 2530.0000 || discriminator | loss: 0.0000 || generator_lr: 8.04e-06, discriminator_lr: 8.04e-07 || data_time: 0.2s | fbward_time: 0.2s | iter_time: 0.4s | iter_avg_time: 0.4s | epoch_time: 06s | spend_time: 01m:14s | left_time: 01h:23m:47s
pvqvae_Bihua: train: Epoch 2/150 iter 26/92 || generator | rec_loss: 0.3294 | loss: 0.3312 | quantize_loss: 0.0018 | used_unmasked_quantize_embed: 1.0000 | used_masked_quantize_embed: 1.0000 | unmasked_num_token: 11445.0000 | masked_num_token: 4939.0000 || discriminator | loss: 0.0000 || generator_lr: 8.44e-06, discriminator_lr: 8.44e-07 || data_time: 0.2s | fbward_time: 0.2s | iter_time: 0.4s | iter_avg_time: 0.4s | epoch_time: 10s | spend_time: 01m:17s | left_time: 01h:23m:34s
pvqvae_Bihua: train: Epoch 2/150 iter 36/92 || generator | rec_loss: 0.2410 | loss: 0.2427 | quantize_loss: 0.0017 | used_unmasked_quantize_embed: 1.0000 | used_masked_quantize_embed: 1.0000 | unmasked_num_token: 14810.0000 | masked_num_token: 1574.0000 || discriminator | loss: 0.0000 || generator_lr: 8.84e-06, discriminator_lr: 8.84e-07 || data_time: 0.1s | fbward_time: 0.2s | iter_time: 0.4s | iter_avg_time: 0.4s | epoch_time: 13s | spend_time: 01m:21s | left_time: 01h:23m:19s
pvqvae_Bihua: train: Epoch 2/150 iter 46/92 || generator | rec_loss: 0.2721 | loss: 0.2745 | quantize_loss: 0.0024 | used_unmasked_quantize_embed: 1.0000 | used_masked_quantize_embed: 1.0000 | unmasked_num_token: 11951.0000 | masked_num_token: 4433.0000 || discriminator | loss: 0.0000 || generator_lr: 9.24e-06, discriminator_lr: 9.24e-07 || data_time: 0.2s | fbward_time: 0.2s | iter_time: 0.4s | iter_avg_time: 0.4s | epoch_time: 17s | spend_time: 01m:24s | left_time: 01h:23m:09s
pvqvae_Bihua: train: Epoch 2/150 iter 56/92 || generator | rec_loss: 0.2303 | loss: 0.2324 | quantize_loss: 0.0021 | used_unmasked_quantize_embed: 1.0000 | used_masked_quantize_embed: 1.0000 | unmasked_num_token: 15426.0000 | masked_num_token: 958.0000 || discriminator | loss: 0.0000 || generator_lr: 9.64e-06, discriminator_lr: 9.64e-07 || data_time: 0.1s | fbward_time: 0.2s | iter_time: 0.4s | iter_avg_time: 0.4s | epoch_time: 20s | spend_time: 01m:28s | left_time: 01h:22m:57s

My training set has 1048 images. Is this too fast? Also, what is the command to train the transformer? Is this right?

python train_net.py --name pvqvae_Bihua --config_file ./configs/put_cvpr2022/bihua/transformer_bihua.yaml --num_node 1 --tensorboard --auto_resume

liuqk3 commented 1 year ago

Hi @zhangbaijin, as you said, your dataset is small. The training time depends on the number of images in the training set and the number of epochs you set. After training, you can check the reconstruction results on an image.
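As a rough sanity check on the numbers in the log above, the total training time can be estimated from the iterations per epoch, the number of epochs, and the average iteration time. This is a back-of-the-envelope sketch with a hypothetical helper name; it ignores validation, checkpointing, and logging overhead.

```python
def estimate_training_seconds(iters_per_epoch: int,
                              num_epochs: int,
                              seconds_per_iter: float) -> float:
    """Rough wall-clock estimate: iterations x epochs x time per iteration."""
    return iters_per_epoch * num_epochs * seconds_per_iter


# Values read from the log: 92 iterations per epoch, 150 epochs,
# and an iter_avg_time of about 0.4s (rounded in the log).
total_seconds = estimate_training_seconds(92, 150, 0.4)
print(f"estimated total: about {total_seconds / 60:.0f} minutes")
```

The estimate comes out at roughly an hour and a half, the same order of magnitude as the left_time the log reports, so the speed is plausible for a dataset of this size.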

To train the transformer, you need a suitable config file based on the pre-trained P-VQVAE. If you have written the transformer config file correctly, the command you gave should work. If I were training the transformer myself, though, I would set --name transformer_Bihua to distinguish the run from the pre-trained P-VQVAE on the Bihua dataset.

For more details, please refer to the README.md file.

zhangbaijin commented 1 year ago

Thanks a lot. There is a problem when I run inference.py:

if len(list(set(data_i['relative_path']) & processed_image)) == len(data_i['relative_path']):
KeyError: 'relative_path'

I added "relative_path" in transformer_xxx.yaml like this:

use_provided_mask_ratio: [0.4, 1.0]
use_provided_mask: 0.8
mask: 1.0
mask_low_to_high: 0.0
mask_low_size: [32, 32]
multi_image_mask: False
return_data_keys: [image, mask,relative_path]

But it doesn't work. Can you tell me how to set it? Thanks for your time. Best wishes

liuqk3 commented 1 year ago

Did you add return_data_keys: [image, mask,relative_path] in the validation_datasets section? In my experience, if you did, it should work. If it still fails, you will need to debug it on your side.
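One possible starting point for that debugging is to check which keys a batch actually contains before inference.py indexes it. The sketch below assumes only that a batch is a dict, as the data_i['relative_path'] lookup in the traceback suggests; the helper name and the stand-in batch are hypothetical.

```python
def missing_keys(batch: dict,
                 required=("image", "mask", "relative_path")) -> list:
    """Return the required keys that are absent from a batch dict."""
    return [key for key in required if key not in batch]


# Hypothetical stand-in for one batch from the validation dataloader;
# in a real run, inspect the actual data_i right before the failing line.
data_i = {"image": None, "mask": None}
print(missing_keys(data_i))  # ['relative_path']
```

If 'relative_path' shows up as missing for validation batches, the dataset entry used at inference time is probably not the one where return_data_keys was edited.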