lisaner000 opened this issue 6 months ago
During training, I ran into the following error. Can you give me some advice?
Epoch 1/45
Size of input s_feat: torch.Size([32, 9, 2048])
Traceback (most recent call last):
File "train.py", line 142, in <module>
main(cfg)
File "train.py", line 133, in main
debug_freq=cfg.DEBUG_FREQ,
File "/root/autodl-tmp/STAF-main/lib/core/trainer.py", line 344, in fit
self.train()
File "/root/autodl-tmp/STAF-main/lib/core/trainer.py", line 188, in train
preds, scores = self.generator(inp, is_train=True)
File "/root/miniconda3/envs/yzy_staf/lib/python3.7/site-packages/torch/nn/modules/module.py", line 722, in _call_impl
result = self.forward(*input, **kwargs)
File "/root/autodl-tmp/STAF-main/lib/models/staf.py", line 321, in forward
mid_s_feat = s_feat.permute(0, 2, 3, 4, 1)
RuntimeError: number of dims don't match in permute
s_feat should have the shape [batch_size, 9, 2048, 7, 7]. My guess is that you forgot to remove the final average-pooling layer of ResNet-50, so you got an s_feat of shape [32, 9, 2048] instead.
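To see why the missing spatial dimensions trigger this error, here is a minimal pure-Python sketch (no PyTorch needed) of the shape bookkeeping. It assumes 224×224 input crops: ResNet-50's convolutional trunk has a total stride of 32, giving a 2048-channel 7×7 map per frame, and the final global average pool collapses that map to a single 2048-vector. The helper names are hypothetical; `permuted_shape` only mirrors the rank check that `Tensor.permute` performs, which requires one index per tensor dimension.

```python
def resnet50_feature_shape(batch, seq_len=9, input_hw=224, keep_avgpool=False):
    """Shape of per-frame ResNet-50 features stacked over a clip (illustrative sketch)."""
    # conv1, maxpool, and layer2/3/4 each halve the resolution: total stride 32.
    spatial = input_hw // 32
    if keep_avgpool:
        # Global average pooling (plus the flatten after it) removes the spatial dims.
        return (batch, seq_len, 2048)
    return (batch, seq_len, 2048, spatial, spatial)


def permuted_shape(shape, dims):
    """Mimic torch.Tensor.permute on a shape tuple: one index per dimension."""
    if len(dims) != len(shape):
        raise RuntimeError("number of dims don't match in permute")
    return tuple(shape[d] for d in dims)


good = resnet50_feature_shape(32)  # (32, 9, 2048, 7, 7)
print(permuted_shape(good, (0, 2, 3, 4, 1)))  # → (32, 2048, 7, 7, 9)

bad = resnet50_feature_shape(32, keep_avgpool=True)  # (32, 9, 2048)
try:
    permuted_shape(bad, (0, 2, 3, 4, 1))  # 5 indices for a 3-D shape
except RuntimeError as err:
    print("RuntimeError:", err)
```

For reference only: with a plain torchvision ResNet-50, the usual way to keep the 7×7 maps is `torch.nn.Sequential(*list(resnet.children())[:-2])`, which drops `avgpool` and `fc`. STAF itself builds its backbone differently, so follow the repository's own extractor.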
Thank you very much! Could you be more specific, please? Which Python file should I check?
The STAF backbone is defined as the class PyMAF_ResNet in lib/models/pymaf_resnet.py. You can initialize the feature extractor the same way I did in demo.py.
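However the extractor is built, a cheap check on the extracted features can catch this mismatch before training starts. The sketch below is a hypothetical helper, not part of the STAF code; its defaults (9 frames, 2048 channels, 7×7 maps) come from the expected shape quoted earlier in this thread. It accepts any tuple-like shape, so a PyTorch `torch.Size` (a tuple subclass) can be passed directly.

```python
def check_s_feat_shape(shape, seq_len=9, channels=2048, spatial=(7, 7)):
    """Validate the s_feat shape before it reaches the model (hypothetical helper)."""
    shape = tuple(shape)
    expected = f"(batch_size, {seq_len}, {channels}, {spatial[0]}, {spatial[1]})"
    # A 3-D (batch, seq_len, channels) tensor means the spatial dims were pooled away.
    if len(shape) == 3 and shape[1:] == (seq_len, channels):
        raise ValueError(
            f"s_feat has shape {shape}, expected {expected}: the spatial dims are "
            "missing, which usually means the final average pool was not removed."
        )
    if len(shape) != 5 or shape[1:] != (seq_len, channels) + tuple(spatial):
        raise ValueError(f"s_feat has shape {shape}, expected {expected}")
    return shape
```

Calling `check_s_feat_shape(s_feat.shape)` right after feature extraction would turn the opaque `permute` error into a message that points at the backbone.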