lewandofskee / DiAD

Official implementation of DiAD: A Diffusion-based Framework for Multi-class Anomaly Detection.
Apache License 2.0

RuntimeError: Error(s) in loading state_dict for FeatureListNet: Unexpected key(s) in state_dict: "fc.weight", "fc.bias". #22

Closed. boxbox2 closed this issue 8 months ago

boxbox2 commented 8 months ago
Traceback (most recent call last):
  File "train.py", line 50, in <module>
    trainer.fit(model, train_dataloaders=train_dataloader, val_dataloaders=test_dataloader)
  File "/opt/conda/envs/diad/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 735, in fit
    self._call_and_handle_interrupt(
  File "/opt/conda/envs/diad/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 682, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/opt/conda/envs/diad/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 770, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/opt/conda/envs/diad/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1193, in _run
    self._dispatch()
  File "/opt/conda/envs/diad/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1272, in _dispatch
    self.training_type_plugin.start_training(self)
  File "/opt/conda/envs/diad/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py", line 202, in start_training
    self._results = trainer.run_stage()
  File "/opt/conda/envs/diad/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1282, in run_stage
    return self._run_train()
  File "/opt/conda/envs/diad/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1304, in _run_train
    self._run_sanity_check(self.lightning_module)
  File "/opt/conda/envs/diad/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1368, in _run_sanity_check
    self._evaluation_loop.run()
  File "/opt/conda/envs/diad/lib/python3.8/site-packages/pytorch_lightning/loops/base.py", line 140, in run
    self.on_run_start(*args, **kwargs)
  File "/opt/conda/envs/diad/lib/python3.8/site-packages/pytorch_lightning/loops/dataloader/evaluation_loop.py", line 96, in on_run_start
    self._on_evaluation_epoch_start()
  File "/opt/conda/envs/diad/lib/python3.8/site-packages/pytorch_lightning/loops/dataloader/evaluation_loop.py", line 214, in _on_evaluation_epoch_start
    self.trainer.call_hook("on_validation_epoch_start", *args, **kwargs)
  File "/opt/conda/envs/diad/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1483, in call_hook
    output = model_fx(*args, **kwargs)
  File "/opt/conda/envs/diad/lib/python3.8/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/workspace/DiAD/ldm/models/diffusion/ddpm.py", line 484, in on_validation_epoch_start
    pretrained_model = timm.create_model("resnet50", pretrained=False, features_only=True,checkpoint_path="models/resnet50_a1_0-14fe96d1.pth")
  File "/opt/conda/envs/diad/lib/python3.8/site-packages/timm/models/factory.py", line 74, in create_model
    load_checkpoint(model, checkpoint_path)
  File "/opt/conda/envs/diad/lib/python3.8/site-packages/timm/models/helpers.py", line 75, in load_checkpoint
    incompatible_keys = model.load_state_dict(state_dict, strict=strict)
  File "/opt/conda/envs/diad/lib/python3.8/site-packages/torch/nn/modules/module.py", line 2041, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for FeatureListNet:
        Unexpected key(s) in state_dict: "fc.weight", "fc.bias". 

Why do I get this RuntimeError when I run train.py with the model generated by build_model.py? The error is: Unexpected key(s) in state_dict: "fc.weight", "fc.bias". I looked into setting strict to False:

model.load_state_dict(x, False)

But isn't strict already False in the original code?

lewandofskee commented 8 months ago

The error occurs in /workspace/DiAD/ldm/models/diffusion/ddpm.py (line 484, in on_validation_epoch_start), at this call:

pretrained_model = timm.create_model("resnet50", pretrained=False, features_only=True,checkpoint_path="models/resnet50_a1_0-14fe96d1.pth")
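
A possible workaround, offered as a sketch rather than the repository's own fix: the traceback shows timm's create_model calling load_checkpoint(model, checkpoint_path) without a strict argument, and PyTorch only raises on unexpected keys when strict is true, so a strict=False set elsewhere does not reach this call. With features_only=True the model is a FeatureListNet with no fc layer, while the checkpoint still contains "fc.weight" and "fc.bias". One option is to drop those keys before loading. The helper name strip_classifier_keys and the toy checkpoint below are hypothetical, and the sketch assumes the checkpoint is a flat mapping from parameter names to tensors.

```python
from collections import OrderedDict


def strip_classifier_keys(state_dict, prefixes=("fc.",)):
    """Return a copy of state_dict without keys starting with any prefix.

    Hypothetical helper: "fc." matches the two keys named in the error
    message ("fc.weight", "fc.bias"). The input mapping is not modified.
    """
    prefixes = tuple(prefixes)
    return OrderedDict(
        (key, value)
        for key, value in state_dict.items()
        if not key.startswith(prefixes)
    )


# Toy stand-in for a ResNet checkpoint; real values would be tensors.
checkpoint = OrderedDict(
    [
        ("conv1.weight", 0),
        ("layer1.0.conv1.weight", 1),
        ("fc.weight", 2),
        ("fc.bias", 3),
    ]
)

filtered = strip_classifier_keys(checkpoint)
print(list(filtered))
```

With a real model, this would mean creating the network without checkpoint_path, reading the file with torch.load(path, map_location="cpu"), and passing the filtered mapping to model.load_state_dict. Calling model.load_state_dict(state, strict=False) on the unfiltered mapping should also skip the two fc keys, at the cost of silencing any other key mismatch.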