ShanghaiTech-IMPACT / TeethDreamer

[MICCAI 2024] TeethDreamer: 3D Teeth Reconstruction from Five Intra-oral Photographs
MIT License

Problem when training with my own data #3

Open WKangC opened 3 weeks ago

WKangC commented 3 weeks ago

Hello! Following your instructions, I tried to fine-tune the pretrained zero123 model on my own data and got the following error: "RuntimeError: Given groups=1, weight of size [320, 8, 3, 3], expected input[8, 20, 32, 32] to have 8 channels, but got 20 channels instead". If I change the in_channels parameter in the YAML config file to 8, training runs normally, but running inference with the resulting ckpt file then fails with: "RuntimeError: Error(s) in loading state_dict for TeethDreamer: size mismatch for model.diffusion_model.input_blocks.0.0.weight: copying a param with shape torch.Size([320, 8, 3, 3]) from checkpoint, the shape in current model is torch.Size([320, 20, 3, 3])." How can I fix this? Looking forward to your reply, thank you!

Xcf-xcf commented 2 weeks ago

Did you train with a single image as the condition?

WKangC commented 2 weeks ago

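No. Following your instructions, I organized the training data in the following layout: (image) Each folder looks like this: (image) (image) These folders and images were all generated with the blender --background command from the instructions. The contents of splits.pkl are as follows: (image) Which step did I fail to reproduce correctly?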

Xcf-xcf commented 2 weeks ago

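During training, the latents of the 4 condition images are concatenated with the noisy target latent, so the first convolution layer of diffusion_model should have 20 input channels. I don't know why the first convolution layer had 8 input channels in your training run, though. Could you check whether something went wrong when loading the pretrained zero123 model, or whether the model was initialized as expected?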
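The channel counts in the two error messages can be reproduced with a little arithmetic. The following is a minimal sketch, not code from the repository: the helper name `expected_in_channels` is hypothetical, and it assumes each latent has 4 channels (consistent with the "Working with z of shape (1, 4, 32, 32)" line in the training log) and that all condition latents are concatenated with the noisy target latent along the channel axis.

```python
def expected_in_channels(num_condition_views: int, latent_channels: int = 4) -> int:
    """Input channels of the first UNet convolution when the condition latents
    are concatenated with the noisy target latent along the channel axis.

    Hypothetical helper for illustration; latent_channels=4 is an assumption
    taken from the latent shape printed in the training log.
    """
    if num_condition_views < 1:
        raise ValueError("need at least one condition view")
    # One noisy target latent plus one latent per condition view.
    return latent_channels * (num_condition_views + 1)


multi_view = expected_in_channels(4)   # four condition images
single_view = expected_in_channels(1)  # one condition image
print(multi_view, single_view)
```

Under these assumptions, four condition images give 20 channels and a single condition image gives 8, which are the two values that appear in the reported size mismatch.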

WKangC commented 2 weeks ago

$ python TeethDreamer.py -b configs/TeethDreamer.yaml --gpus 0 --finetune_from ckpt/zero123-xl.ckpt data.target_dir=data/csy/train/target/ data.input_dir=data/csy/train/input/ data.uid_set_pkl=data/csy/train/splits.pkl data.validation_dir=data/csy/train/input/
Global seed set to 6033
Running on GPUs 0
making attention of type 'vanilla' with 512 in_channels
Working with z of shape (1, 4, 32, 32) = 4096 dimensions.
making attention of type 'vanilla' with 512 in_channels
Selected timesteps for ddim sampler: [ 1 6 11 16 21 26 31 36 41 46 51 56 61 66 71 76 81 86 91 96 101 106 111 116 121 126 131 136 141 146 151 156 161 166 171 176 181 186 191 196 201 206 211 216 221 226 231 236 241 246 251 256 261 266 271 276 281 286 291 296 301 306 311 316 321 326 331 336 341 346 351 356 361 366 371 376 381 386 391 396 401 406 411 416 421 426 431 436 441 446 451 456 461 466 471 476 481 486 491 496 501 506 511 516 521 526 531 536 541 546 551 556 561 566 571 576 581 586 591 596 601 606 611 616 621 626 631 636 641 646 651 656 661 666 671 676 681 686 691 696 701 706 711 716 721 726 731 736 741 746 751 756 761 766 771 776 781 786 791 796 801 806 811 816 821 826 831 836 841 846 851 856 861 866 871 876 881 886 891 896 901 906 911 916 921 926 931 936 941 946 951 956 961 966 971 976 981 986 991 996]
Attempting to load state from ckpt/zero123-xl.ckpt
Manual init: model.diffusion_model.input_blocks.0.0.weight
/home/xxx/miniconda3/envs/TeethDreamer/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/accelerator_connector.py:467: LightningDeprecationWarning: Setting `Trainer(gpus='0')` is deprecated in v1.7 and will be removed in v2.0. Please use `Trainer(accelerator='gpu', devices='0')` instead.
  rank_zero_deprecation(
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
============= length of train_dataset 8 =============
============= length of val_dataset 2 =============
accumulate_grad_batches = 1
++++ NOT USING LR SCALING ++++
Setting learning rate to 5.00e-05
============= length of train_dataset 8 =============
============= length of val_dataset 2 =============
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]
setting learning rate to 0.0001 ...
Setting up LambdaLR scheduler...
model:
  base_learning_rate: 5.0e-05
  target: ldm.models.diffusion.teeth_dreamer.TeethDreamer
  params:
    view_num: 8
    image_size: 256
    cfg_scale: 2.0
    output_num: 8
    batch_view_num: 4
    finetune_unet: false
    finetune_projection: true
    drop_conditions: false
    clip_image_encoder_path: ckpt/ViT-L-14.pt
    scheduler_config:
      target: ldm.lr_scheduler.LambdaLinearScheduler
      params:
        warm_up_steps:

Xcf-xcf commented 2 weeks ago

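Sorry, I forgot to document the pkl filename format in the README. When the dataset is loaded, the code checks whether the pkl filename contains mv to decide whether a single image is used as the condition. So please rename splits.pkl to mv-splits.pkl. In addition, the case IDs in that file should have the form XXX_norm_lower or XXX_norm_upper, without _front, _left, _right, _upper, or _down.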
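Both conventions described above can be checked before starting a training run. The sketch below is only an illustration of those two rules, not the repository's actual loading code: the function names `is_multiview_split` and `is_valid_case_id` are hypothetical, and it assumes the "contains mv" check applies to the file's base name and that XXX stands for any non-empty case prefix.

```python
import os
import re

# Assumed pattern: any non-empty prefix followed by _norm_lower or _norm_upper,
# with no trailing view suffix such as _front or _left.
CASE_ID_PATTERN = re.compile(r"^.+_norm_(lower|upper)$")


def is_multiview_split(pkl_path: str) -> bool:
    """Return True if the split file name contains 'mv' (assumed to be
    checked on the base name rather than the full path)."""
    return "mv" in os.path.basename(pkl_path)


def is_valid_case_id(case_id: str) -> bool:
    """Return True if the case ID ends in _norm_lower or _norm_upper."""
    return CASE_ID_PATTERN.fullmatch(case_id) is not None


print(is_multiview_split("data/csy/train/splits.pkl"))
print(is_multiview_split("data/csy/train/mv-splits.pkl"))
print(is_valid_case_id("001_norm_lower"))
print(is_valid_case_id("001_norm_lower_front"))
```

Here `001` is a made-up case prefix. With these definitions, the original splits.pkl path is not treated as multi-view while the renamed mv-splits.pkl is, and a case ID carrying a view suffix such as _front is rejected.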

WKangC commented 2 weeks ago

Thank you very much for your patient reply and guidance. Training now works.