jiangyijin closed this issue 1 month ago.
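```python
import timm
import torch

# Build EfficientNet-B0 as a feature extractor only (no classification head).
# pretrained=False: the weights come from the manually downloaded file below.
self.encoder = timm.create_model(
    'tf_efficientnet_b0_ns_jft_in1k',
    pretrained=False,
    features_only=True,
    out_indices=[4]
)
checkpoint_path = '/home/ubuntu/Downloads/tf_efficientnet_b0.ns_jft_in1k/pytorch_model.bin'
state_dict = torch.load(checkpoint_path)
# strict=False ignores key mismatches instead of raising; here the only mismatches
# are the conv_head/bn2/classifier head entries the feature extractor does not have.
self.encoder.load_state_dict(state_dict, strict=False)
```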
I resolved the issue by ignoring the incompatible layers listed in the error message. Training now runs:
```
Epoch 001 [ 415/1425] (lr: 1.00e-04) [379 (382) ms] chamfer 1.12e-02 (1.39e-02) occupancy 2.03e-01 (2.54e-01) reproj 8.81e-02 (1.14e-01) trans 2.60e-02 (3.50e-02) rot 2.09e+00 (2.20e+00) overlap 7.62e-02 (1.74e-01) taper 8.52e-02 (2.29e-01)
```
Is my solution correct?
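Yes. Only the weights of the encoder part of the CNN are needed. With `features_only=True` set, `timm.create_model` builds the network without the classification head, and when it fetches the pretrained weights itself it skips the head weights automatically. For a manually downloaded checkpoint, `strict=False` achieves the same result.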
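If you would rather not silence every mismatch with `strict=False` (it would also hide missing encoder weights), a stricter variant is to drop only the head entries named in the error and keep `strict=True`. The following is a minimal sketch under that assumption; `strip_head` and `HEAD_PREFIXES` are hypothetical names, not part of timm or PyTorch.

```python
from typing import Dict, Iterable, Mapping, TypeVar

V = TypeVar("V")

# Hypothetical: top-level prefixes of the classification head, taken from the
# "Unexpected key(s)" list in the traceback. Assumed to be the only extra keys.
HEAD_PREFIXES = ("conv_head.", "bn2.", "classifier.")


def strip_head(state_dict: Mapping[str, V],
               prefixes: Iterable[str] = HEAD_PREFIXES) -> Dict[str, V]:
    """Return a copy of state_dict without the classification-head entries.

    Only keys that *start* with a prefix are dropped, so block-level
    parameters such as "blocks.0.0.bn2.weight" are kept.
    """
    prefixes = tuple(prefixes)
    return {k: v for k, v in state_dict.items() if not k.startswith(prefixes)}


# Hypothetical usage with the snippet above:
#   state_dict = torch.load(checkpoint_path)
#   self.encoder.load_state_dict(strip_head(state_dict), strict=True)
```

Keeping `strict=True` means any other mismatch, such as a missing encoder weight, still raises an error instead of passing silently.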
```
/home/ubuntu/anaconda3/envs/jiyuantiqu/bin/python /data/tiqu_jiyuan/sat-sq-recon/tools/train.py --cfg ../experiments/config.yaml
2024-08-22 16:48:17.354611: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2024-08-22 16:48:18.153098: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
/home/ubuntu/anaconda3/envs/jiyuantiqu/lib/python3.10/site-packages/albumentations/check_version.py:49: UserWarning: Error fetching version info
  data = fetch_version_info()
2024/08/22 16:48:23 Outputs (e.g., checkpoints) are saved at: output/spe3r/exp
2024/08/22 16:48:23 Messages and tensorboard logs are saved at: log/spe3r/exp/train_20240822_16_48_23
2024/08/22 16:48:23 Random seed: 42
2024/08/22 16:48:23 GPU-accelerated training: ENABLED
2024/08/22 16:48:23 • 2 x NVIDIA GeForce RTX 4090
2024/08/22 16:48:23 Creating Model ...
2024/08/22 16:48:23 Loaded from checkpoint '/home/ubuntu/Downloads/tf_efficientnet_b0.ns_jft_in1k/tf_efficientnet_b0.ns_jft_in1k/pytorch_model.bin'
Traceback (most recent call last):
  File "/data/tiqu_jiyuan/sat-sq-recon/tools/train.py", line 274, in <module>
    train(cfg)
  File "/data/tiqu_jiyuan/sat-sq-recon/tools/train.py", line 90, in train
    net = Model(cfg, fov=camera['horizontalFOV'], device=device)
  File "/data/tiqu_jiyuan/sat-sq-recon/tools/../core/nets/model.py", line 42, in __init__
    self.encoder = Encoder(cfg.MODEL.LATENT_DIM).to(device)
  File "/data/tiqu_jiyuan/sat-sq-recon/tools/../core/nets/modules/encoder.py", line 16, in __init__
    self.encoder = timm.create_model(
  File "/home/ubuntu/anaconda3/envs/jiyuantiqu/lib/python3.10/site-packages/timm/models/_factory.py", line 122, in create_model
    load_checkpoint(model, checkpoint_path)
  File "/home/ubuntu/anaconda3/envs/jiyuantiqu/lib/python3.10/site-packages/timm/models/_helpers.py", line 84, in load_checkpoint
    incompatible_keys = model.load_state_dict(state_dict, strict=strict)
  File "/home/ubuntu/anaconda3/envs/jiyuantiqu/lib/python3.10/site-packages/torch/nn/modules/module.py", line 2041, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for EfficientNetFeatures:
	Unexpected key(s) in state_dict: "conv_head.weight", "bn2.weight", "bn2.bias", "bn2.running_mean", "bn2.running_var", "bn2.num_batches_tracked", "classifier.weight", "classifier.bias".

Process finished with exit code 1
```

Thank you very much for your work! Because of network issues, I downloaded the tf_efficientnet_b0.ns_jft_in1k model manually, but loading it keeps failing with the error above. Could you please provide the download link for this model?
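To check whether a manually downloaded file is actually damaged or only carries extra head weights, one can compare its key names with the model's before loading. A minimal sketch follows. `compare_keys` is a hypothetical helper, and it compares names only, not tensor shapes. If `missing` comes back empty and `unexpected` lists only `conv_head`, `bn2` and `classifier` entries, the download matches the encoder, and the `RuntimeError` above comes from the strict key check rather than from a bad file.

```python
from typing import Iterable, List, Tuple


def compare_keys(model_keys: Iterable[str],
                 checkpoint_keys: Iterable[str]) -> Tuple[List[str], List[str]]:
    """Return (missing, unexpected) key names, each sorted.

    missing:    parameter names the model expects but the checkpoint lacks
    unexpected: checkpoint entries the model has no parameter for
    Tensor shapes are NOT compared; only the key names are.
    """
    model_set = set(model_keys)
    ckpt_set = set(checkpoint_keys)
    missing = sorted(model_set - ckpt_set)
    unexpected = sorted(ckpt_set - model_set)
    return missing, unexpected


# Hypothetical usage (model built with features_only=True, path as in the log):
#   state_dict = torch.load(checkpoint_path)
#   missing, unexpected = compare_keys(model.state_dict().keys(), state_dict.keys())
#   print("missing:", missing)
#   print("unexpected:", unexpected)
```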