LTH14 / rcg

PyTorch implementation of RCG https://arxiv.org/abs/2312.03701
MIT License

How to train RCG with small-scale dataset? #19

Closed: fungtion closed this issue 8 months ago

fungtion commented 8 months ago

Hi, I have a small-scale dataset of 10k+ images, mostly of human crowds. I trained the RDM on these images only and generated images using the MoCo v3 and MAGE models pre-trained on ImageNet, but the FID of the results is very high (~50). I have no idea what is wrong. Could it be the size of the dataset?
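To make the FID number easier to reason about, here is a minimal NumPy sketch of the Fréchet distance that FID is built on: the squared distance between feature means plus a covariance term. It assumes you already have `(N, D)` feature arrays for real and generated images. Standard FID uses Inception-v3 pool features, and that extraction step is not shown here. `frechet_distance` is a hypothetical helper, not part of the RCG code. Note that FID values computed from different numbers of samples are not directly comparable.

```python
import numpy as np


def frechet_distance(feats_real: np.ndarray, feats_gen: np.ndarray) -> float:
    """Fréchet distance between Gaussians fitted to two (N, D) feature arrays (D >= 2)."""
    mu1, mu2 = feats_real.mean(axis=0), feats_gen.mean(axis=0)
    sigma1 = np.cov(feats_real, rowvar=False)
    sigma2 = np.cov(feats_gen, rowvar=False)

    # Tr(sqrt(S1 @ S2)) equals the sum of square roots of the eigenvalues of
    # S1^{1/2} @ S2 @ S1^{1/2}, which is symmetric PSD, so eigh/eigvalsh are safe.
    w, v = np.linalg.eigh(sigma1)
    sqrt_sigma1 = (v * np.sqrt(np.clip(w, 0.0, None))) @ v.T
    m = sqrt_sigma1 @ sigma2 @ sqrt_sigma1
    eig = np.linalg.eigvalsh((m + m.T) / 2.0)  # symmetrize against round-off
    tr_covmean = np.sum(np.sqrt(np.clip(eig, 0.0, None)))

    diff = mu1 - mu2
    return float(diff @ diff + np.trace(sigma1) + np.trace(sigma2) - 2.0 * tr_covmean)
```

Identical feature sets give a distance of about 0. Shifting every feature vector by a constant offset `c` leaves the covariances unchanged, so the distance reduces to `||c||^2`.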

LTH14 commented 8 months ago

One thing that might affect the performance is the VQGAN tokenizer: the provided tokenizer is pre-trained on ImageNet and may not achieve good reconstruction performance on your dataset. The provided MoCo v3 is also pre-trained on ImageNet, which might affect performance as well. Dataset scale might be a factor, but 10k+ images should already be enough to produce reasonable results.
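One way to check the tokenizer concern is to encode and decode a sample of your own images with the provided VQGAN and measure reconstruction quality. You can then compare that result with the same measurement on ImageNet images. The sketch below shows only the metric (PSNR). It assumes you have already produced the reconstructions as float arrays scaled to `[0, 1]`. How to call the tokenizer in this repository is not shown, and `psnr` / `mean_psnr` are hypothetical helpers.

```python
import numpy as np


def psnr(original: np.ndarray, reconstructed: np.ndarray, max_val: float = 1.0) -> float:
    """Peak signal-to-noise ratio in dB between two images with values in [0, max_val]."""
    original = np.asarray(original, dtype=np.float64)
    reconstructed = np.asarray(reconstructed, dtype=np.float64)
    if original.shape != reconstructed.shape:
        raise ValueError(f"shape mismatch: {original.shape} vs {reconstructed.shape}")
    mse = np.mean((original - reconstructed) ** 2)
    if mse == 0:
        return float("inf")  # perfect reconstruction
    return float(10.0 * np.log10(max_val ** 2 / mse))


def mean_psnr(originals, reconstructions, max_val: float = 1.0) -> float:
    """Average PSNR over paired lists of original and reconstructed images."""
    scores = [psnr(o, r, max_val) for o, r in zip(originals, reconstructions)]
    return float(np.mean(scores))
```

If the average PSNR on the crowd images is clearly lower than on ImageNet images, that would point to the tokenizer as one source of the high FID.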

fungtion commented 8 months ago

Maybe I should fine-tune all these models on my custom data. Is fine-tuning MoCo, the RDM, and the pixel generator any different from fine-tuning other common deep learning models, e.g. in terms of freezing layers or reducing the learning rate?

LTH14 commented 8 months ago

Unfortunately, I don't have any experience with this. Maybe first try fine-tuning with a small learning rate and watch the losses.
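The suggestion above (small learning rate, watch the losses), combined with the layer freezing asked about earlier, can be sketched in generic PyTorch. This is not how the RCG repository does fine-tuning. The toy model, the `frozen_prefixes` parameter-name prefixes, and the default learning rate of `1e-5` are all hypothetical. For the real RDM, MAGE, or MoCo v3 models, you would need to inspect `model.named_parameters()` to decide what to freeze.

```python
import torch
from torch import nn


def build_finetune_optimizer(model: nn.Module, frozen_prefixes, lr: float = 1e-5):
    """Freeze parameters whose names start with any prefix in `frozen_prefixes`
    and return an AdamW optimizer over the remaining trainable parameters."""
    trainable = []
    for name, param in model.named_parameters():
        if any(name.startswith(p) for p in frozen_prefixes):
            param.requires_grad_(False)  # excluded from gradient computation
        else:
            trainable.append(param)
    # Small learning rate and no weight decay: a conservative starting point (assumption).
    return torch.optim.AdamW(trainable, lr=lr, weight_decay=0.0)


def finetune(model: nn.Module, optimizer, batches, loss_fn, log_every: int = 10):
    """Plain training loop that records every loss so the curve can be inspected."""
    model.train()
    losses = []
    for step, (x, y) in enumerate(batches):
        optimizer.zero_grad()
        loss = loss_fn(model(x), y)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
        if step % log_every == 0:
            print(f"step {step}: loss {loss.item():.4f}")
    return losses
```

For example, with a toy `nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 1))`, passing `frozen_prefixes=("0.",)` keeps the first layer fixed while the last layer is updated. If the recorded losses diverge or spike early, that is a sign the learning rate is still too high.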

fungtion commented 8 months ago

Ok, I will try it myself, thank you.