tpark94 / sat-sq-recon

PyTorch implementation of the AIAA SciTech paper titled "Rapid Abstraction of Spacecraft 3D Structure from Single 2D Image"
MIT License

Model Error #1

Closed: jiangyijin closed this issue 1 month ago

jiangyijin commented 1 month ago

```
/home/ubuntu/anaconda3/envs/jiyuantiqu/bin/python /data/tiqu_jiyuan/sat-sq-recon/tools/train.py --cfg ../experiments/config.yaml
2024-08-22 16:48:17.354611: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2024-08-22 16:48:18.153098: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
/home/ubuntu/anaconda3/envs/jiyuantiqu/lib/python3.10/site-packages/albumentations/check_version.py:49: UserWarning: Error fetching version info
  data = fetch_version_info()
2024/08/22 16:48:23 Outputs (e.g., checkpoints) are saved at: output/spe3r/exp
2024/08/22 16:48:23 Messages and tensorboard logs are saved at: log/spe3r/exp/train_20240822_16_48_23
2024/08/22 16:48:23 Random seed: 42
2024/08/22 16:48:23 GPU-accelerated training: ENABLED
2024/08/22 16:48:23    • 2 x NVIDIA GeForce RTX 4090
2024/08/22 16:48:23 Creating Model ...
2024/08/22 16:48:23 Loaded from checkpoint '/home/ubuntu/Downloads/tf_efficientnet_b0.ns_jft_in1k/tf_efficientnet_b0.ns_jft_in1k/pytorch_model.bin'
Traceback (most recent call last):
  File "/data/tiqu_jiyuan/sat-sq-recon/tools/train.py", line 274, in <module>
    train(cfg)
  File "/data/tiqu_jiyuan/sat-sq-recon/tools/train.py", line 90, in train
    net = Model(cfg, fov=camera['horizontalFOV'], device=device)
  File "/data/tiqu_jiyuan/sat-sq-recon/tools/../core/nets/model.py", line 42, in __init__
    self.encoder = Encoder(cfg.MODEL.LATENT_DIM).to(device)
  File "/data/tiqu_jiyuan/sat-sq-recon/tools/../core/nets/modules/encoder.py", line 16, in __init__
    self.encoder = timm.create_model(
  File "/home/ubuntu/anaconda3/envs/jiyuantiqu/lib/python3.10/site-packages/timm/models/_factory.py", line 122, in create_model
    load_checkpoint(model, checkpoint_path)
  File "/home/ubuntu/anaconda3/envs/jiyuantiqu/lib/python3.10/site-packages/timm/models/_helpers.py", line 84, in load_checkpoint
    incompatible_keys = model.load_state_dict(state_dict, strict=strict)
  File "/home/ubuntu/anaconda3/envs/jiyuantiqu/lib/python3.10/site-packages/torch/nn/modules/module.py", line 2041, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for EfficientNetFeatures:
	Unexpected key(s) in state_dict: "conv_head.weight", "bn2.weight", "bn2.bias", "bn2.running_mean", "bn2.running_var", "bn2.num_batches_tracked", "classifier.weight", "classifier.bias".

Process finished with exit code 1
```

Thank you very much for your work! Because of network issues, I downloaded the tf_efficientnet_b0.ns_jft_in1k weights manually, but loading them keeps failing with the error above. Could you please provide the download link for this model?
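The traceback indicates that the downloaded file is a full classification checkpoint, while `features_only=True` builds an `EfficientNetFeatures` module without the classifier head, so under strict loading the head's keys have nowhere to go. A minimal sketch of diagnosing such a mismatch, assuming only key names matter: the hypothetical helper `diff_state_dict_keys` compares checkpoint keys with model keys. With PyTorch, its inputs would be `torch.load(path).keys()` and `model.state_dict().keys()`; the backbone key below is an assumed example.

```python
from typing import Iterable, List, Tuple


def diff_state_dict_keys(
    checkpoint_keys: Iterable[str], model_keys: Iterable[str]
) -> Tuple[List[str], List[str]]:
    """Return (unexpected, missing) keys, like a strict load_state_dict reports.

    Hypothetical helper for illustration; not part of timm or sat-sq-recon.
    """
    ckpt = set(checkpoint_keys)
    model = set(model_keys)
    unexpected = sorted(ckpt - model)  # in the checkpoint, absent from the model
    missing = sorted(model - ckpt)     # needed by the model, absent from the checkpoint
    return unexpected, missing


# Head keys copied from the error message above.
head_keys = [
    "conv_head.weight", "bn2.weight", "bn2.bias", "bn2.running_mean",
    "bn2.running_var", "bn2.num_batches_tracked", "classifier.weight", "classifier.bias",
]
# One backbone key name, assumed for illustration.
backbone_keys = ["conv_stem.weight"]

unexpected, missing = diff_state_dict_keys(backbone_keys + head_keys, backbone_keys)
print("unexpected:", unexpected)
print("missing:", missing)
```

An empty `missing` list with only head keys in `unexpected` suggests the file is usable for the encoder once the head entries are ignored.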

jiangyijin commented 1 month ago

```python
import timm
import torch

# Inside Encoder.__init__ (core/nets/modules/encoder.py)
self.encoder = timm.create_model(
    'tf_efficientnet_b0.ns_jft_in1k',
    pretrained=False,
    features_only=True,
    out_indices=[4],
)
checkpoint_path = '/home/ubuntu/Downloads/tf_efficientnet_b0.ns_jft_in1k/pytorch_model.bin'
state_dict = torch.load(checkpoint_path)
# strict=False skips the classifier-head keys (conv_head, bn2, classifier)
# that EfficientNetFeatures does not define
self.encoder.load_state_dict(state_dict, strict=False)
```

I resolved the issue by ignoring the inapplicable layers listed in the error above. Training now runs:

```
Epoch 001 [ 415/1425] (lr: 1.00e-04) [379 (382) ms] chamfer 1.12e-02 (1.39e-02) occupancy 2.03e-01 (2.54e-01) reproj 8.81e-02 (1.14e-01) trans 2.60e-02 (3.50e-02) rot 2.09e+00 (2.20e+00) overlap 7.62e-02 (1.74e-01) taper 8.52e-02 (2.29e-01)
```
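`strict=False` works here, but it also silently skips any encoder weight that happens to be absent from the file. A slightly stricter alternative, sketched under the assumption that the checkpoint contains every key the feature model needs: keep only the keys the model defines, fail loudly if any are missing, then load with the default `strict=True`. The helper name `filter_to_model_keys` is hypothetical.

```python
from typing import Dict, Iterable, TypeVar

V = TypeVar("V")


def filter_to_model_keys(state_dict: Dict[str, V], model_keys: Iterable[str]) -> Dict[str, V]:
    """Keep only the entries the model expects; raise if any expected key is absent.

    Hypothetical helper for illustration; not part of timm or sat-sq-recon.
    """
    wanted = list(model_keys)
    missing = [k for k in wanted if k not in state_dict]
    if missing:
        raise KeyError(f"checkpoint is missing {len(missing)} model key(s): {missing[:5]}")
    # Extra entries (e.g. classifier-head weights) are dropped here.
    return {k: state_dict[k] for k in wanted}


# Usage with PyTorch (not executed here):
#   ckpt = torch.load(checkpoint_path, map_location="cpu")
#   self.encoder.load_state_dict(
#       filter_to_model_keys(ckpt, self.encoder.state_dict().keys())
#   )
```

Because the head entries are removed before loading, the strict load still raises if a backbone weight is genuinely missing.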

jiangyijin commented 1 month ago

Is my solution correct?

tpark94 commented 1 month ago

Yes. Only the weights of the encoder (feature-extractor) part of the CNN are loaded, and timm.create_model does this automatically when features_only=True is set.