LTH14 / rcg

PyTorch implementation of RCG https://arxiv.org/abs/2312.03701
MIT License

Question about training RDM with main_rdm.py #18

Open wyyy04 opened 11 months ago

wyyy04 commented 11 months ago

Thanks for your excellent work! I am facing difficulties while training RDM, and I hope to receive your assistance.

When training RDM, the get_input function in ddpm.py (line 564) runs feature extraction on the input x of shape (32, 256, 256, 3) and produces a tensor of shape (32, 197, 768), where 32 is the batch size. An error then occurs at line 578, "rep = self.pretrained_encoder.head(rep)", with the following traceback:

  File "/home/user/anaconda3/envs/ldm/lib/python3.8/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/user/Diffusion_FSAR/rcg-main/rdm/models/diffusion/ddpm.py", line 578, in get_input
    rep = self.pretrained_encoder.head(rep)
  File "/home/user/anaconda3/envs/ldm/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/user/anaconda3/envs/ldm/lib/python3.8/site-packages/torch/nn/modules/container.py", line 217, in forward
    input = module(input)
  File "/home/user/anaconda3/envs/ldm/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/user/anaconda3/envs/ldm/lib/python3.8/site-packages/torch/nn/modules/batchnorm.py", line 171, in forward
    return F.batch_norm(
  File "/home/user/anaconda3/envs/ldm/lib/python3.8/site-packages/torch/nn/functional.py", line 2450, in batch_norm
    return torch.batch_norm(
RuntimeError: running_mean should contain 197 elements not 4096

This appears to be a shape mismatch between the tensor passed in and the layers of self.pretrained_encoder.head. I am not sure what causes it and would appreciate your help. Thank you!
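
One way to read the error message is with shape bookkeeping alone. The sketch below assumes (this is not confirmed in the thread) that the head begins with a linear layer mapping 768 to 4096 features followed by a 1-D batch norm over 4096 channels, which would be consistent with the "4096" in the error. It relies on the fact that PyTorch's BatchNorm1d treats dimension 1 as the channel dimension for both (N, C) and (N, C, L) input. The helper names are hypothetical and no PyTorch import is needed.

```python
# Pure-Python shape bookkeeping; nothing here calls PyTorch.
# Assumption: head = Linear(768 -> 4096) followed by BatchNorm1d(4096).

def linear_out_shape(in_shape, out_features):
    """A linear layer only changes the last dimension."""
    return tuple(in_shape[:-1]) + (out_features,)

def batchnorm1d_channels(in_shape):
    """BatchNorm1d reads dim 1 as channels for (N, C) and (N, C, L) input."""
    if len(in_shape) not in (2, 3):
        raise ValueError("BatchNorm1d expects 2-D or 3-D input")
    return in_shape[1]

def head_accepts(rep_shape, hidden=4096):
    """True when the batch norm would see `hidden` channels."""
    after_linear = linear_out_shape(rep_shape, hidden)
    return batchnorm1d_channels(after_linear) == hidden

tokens = (32, 197, 768)  # all tokens per image, as reported above
pooled = (32, 768)       # hypothetical: one feature vector per image

print(head_accepts(tokens))  # False: the batch norm sees 197 channels
print(head_accepts(pooled))  # True: the batch norm sees 4096 channels
```

Under these assumptions, the head works only when the encoder returns one 768-dimensional vector per image rather than all 197 tokens.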

LTH14 commented 11 months ago

Thanks for your interest! Please make sure your timm version is 0.3.2, as later versions use a different forward_features implementation. Please see issue #9 for a similar problem and its solution.
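
A quick way to confirm the installed version before training is sketched below. It uses only the standard library's importlib.metadata; the helper names (installed_version, timm_version_ok, describe) are hypothetical and not part of the RCG repository.

```python
from importlib import metadata

REQUIRED_TIMM = "0.3.2"  # version named in this thread

def installed_version(package="timm"):
    """Return the installed version string, or None if the package is absent."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None

def timm_version_ok(installed, required=REQUIRED_TIMM):
    """True only for an exact match with the required version."""
    return installed is not None and installed.strip() == required

def describe(installed, required=REQUIRED_TIMM):
    """Build a short human-readable status message."""
    if installed is None:
        return f"timm is not installed; try: pip install timm=={required}"
    if timm_version_ok(installed, required):
        return f"timm {installed} matches the required version"
    return f"timm {installed} found, expected {required}; try: pip install timm=={required}"

print(describe(installed_version()))
```

An exact match is used here because the reply names a single version; whether neighbouring releases also work is not stated in the thread.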

wyyy04 commented 11 months ago

Thank you very much for your guidance. I have resolved the issue by switching timm to version 0.3.2. Best wishes!