SCLBD / DeepfakeBench

A comprehensive benchmark of deepfake detection

[BUG] Mismatched pretrained model architectures in efficientnetb4.py #79

Open bendoesai opened 2 months ago

bendoesai commented 2 months ago

Hi all, I'm doing deepfake detection using the EfficientNetB4 architecture. I ran into an issue when I tried to load the pretrained weights provided.

After some digging, I found the source of the problem in the initializer of the EfficientNetB4 class (\training\networks\efficientnetb4.py). Line 32 is where the class invokes the efficientnet_pytorch package to generate the EfficientNetB4 architecture, either without pretrained weights or pretrained on ImageNet. The problem is that EfficientNet trained on ImageNet doesn't match the architecture of the pretrained weights provided for deepfake detection. There are two possible ways I can think of to resolve this issue:

  1. Let efficientnet_pytorch do all the work by passing weights_path as well as in_channels and num_classes to the from_pretrained() call. This would require that efficientnetb4 be retrained so that the naming conventions of the saved weights match those of efficientnet_pytorch.

  2. (what I did) Load weights after reshaping the model. This option would eliminate the from_pretrained() call and instead load the model architecture (using from_name()), reshape it as needed, and then call

    if efficientnetb4_config['pretrained']:
        self.load_state_dict(torch.load(efficientnetb4_config['pretrained']), strict=False)

    at the end of __init__(). This would allow the current set of pretrained weights to work within DeepfakeBench and allow more control of model output reshaping.
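
One caveat with strict=False is that load_state_dict() skips non-matching keys silently, so a checkpoint whose keys all carry a different prefix would load nothing and raise no error. The following is a minimal sketch of a pre-check in plain Python. compare_keys is a hypothetical helper, and the key names are toy values for illustration, not the full contents of any DeepfakeBench checkpoint.

```python
def compare_keys(model_keys, checkpoint_keys):
    """Return (missing, unexpected) key lists, similar to what load_state_dict reports."""
    model_keys = set(model_keys)
    checkpoint_keys = set(checkpoint_keys)
    missing = sorted(model_keys - checkpoint_keys)      # expected by the model, absent from the file
    unexpected = sorted(checkpoint_keys - model_keys)   # present in the file, unknown to the model
    return missing, unexpected


# Toy key names (assumed for illustration). With a real model one would pass
# model.state_dict().keys() and the keys of the dict returned by torch.load().
model_keys = ["efficientnet._conv_stem.weight", "last_layer.weight"]
checkpoint_keys = [
    "backbone.efficientnet._conv_stem.weight",
    "backbone.last_layer.weight",
]

missing, unexpected = compare_keys(model_keys, checkpoint_keys)
print("missing:", missing)
print("unexpected:", unexpected)
```

If both lists are non-empty and roughly the same length, the weights are probably present under different names, and renaming the keys is safer than relying on strict=False alone.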

I can provide the code I am currently using that I have validated on both training and test scripts if necessary. This is my first time submitting an issue on GitHub, so apologies if there's anything missing.

bendoesai commented 2 months ago

After further investigation, I'm finding a few weight-loading failures caused by strict load_state_dict() calls, because the key names in the provided checkpoints differ from the names the models are initialized with. So far I have found this in EfficientNet and I3D.

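Examples of the error below:

  • EfficientNetB4: Missing key(s) in state_dict: "efficientnet._conv_stem.weight" Unexpected key(s) in state_dict: "backbone.efficientnet._conv_stem.weight" (additional backbone. in front of all keys)
  • I3D: Missing key(s) in state_dict: "resnet.s1.pathway0_stem.conv.weight" Unexpected key(s) in state_dict: "s1.pathway0_stem.conv.weight" (lacking resnet. in front of all keys)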

Let me know if any of you have run into this issue and know a quick solution.

zhude233 commented 2 months ago

You just need to set the path passed to the function that loads the weights to None, and the weights will be downloaded from the internet automatically. No other modifications are necessary.

bendoesai commented 2 months ago

What function are you referring to?

zhude233 commented 2 months ago

image

Please change this address to None. You will then be able to load the pre-trained weights correctly, without a mismatch.

bendoesai commented 2 months ago

If I set that path to None, efficientnet_pytorch loads ImageNet-trained weights. How do I load the weights that are provided by DeepfakeBench?

zhude233 commented 2 months ago

The original wording in the DeepfakeBench README is: "To run the training code, you should first download the pretrained weights for the corresponding backbones (These pre-trained weights are from ImageNet). You can download them from [Link]. After downloading, you need to put all the weights files into the folder ./training/pretrained." They only provide the pre-trained weights from ImageNet.

bendoesai commented 2 months ago

There are pre-trained weights for detectors located at this link, which is what I am trying to load (also linked at the top of the README). These weights do not match the architecture for ImageNet classification, but they do match the detector architecture. For example, they contain last_layer from line 50 of efficientnetb4.py, as opposed to _fc, as efficientnet_pytorch names it.
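
A quick way to tell the two kinds of checkpoint apart is to look at which classifier-head name appears in the keys. This is only a sketch in plain Python: head_names is a hypothetical helper, and the example key lists are assumed for illustration (only the last_layer and _fc names come from the discussion above).

```python
def head_names(keys):
    """Collect which classifier-head names appear as a dotted component of any key."""
    found = set()
    for key in keys:
        parts = key.split(".")
        for name in ("last_layer", "_fc"):
            if name in parts:
                found.add(name)
    return found


# Assumed example keys; a real checkpoint has many more entries.
imagenet_keys = ["_conv_stem.weight", "_fc.weight", "_fc.bias"]
detector_keys = [
    "backbone.efficientnet._conv_stem.weight",
    "backbone.last_layer.weight",
]

print(head_names(imagenet_keys))   # head used by efficientnet_pytorch
print(head_names(detector_keys))   # head used by the detector in efficientnetb4.py
```

With a real file, the keys would come from the dict returned by torch.load().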

zhude233 commented 2 months ago

Sorry, I didn't see the weights before.

RickyZi commented 2 months ago

After further investigation, I'm finding a few weight-loading failures caused by strict load_state_dict() calls, because the key names in the provided checkpoints differ from the names the models are initialized with. So far I have found this in EfficientNet and I3D.

Examples of the error below:

  • EfficientNetB4: Missing key(s) in state_dict: "efficientnet._conv_stem.weight" Unexpected key(s) in state_dict: "backbone.efficientnet._conv_stem.weight" (additional backbone. in front of all keys)
  • I3D: Missing key(s) in state_dict: "resnet.s1.pathway0_stem.conv.weight" Unexpected key(s) in state_dict: "s1.pathway0_stem.conv.weight" (lacking resnet. in front of all keys)

Let me know if any of you have run into this issue and know a quick solution.

Hi @bendoesai, thanks for highlighting this problem. In case you're still struggling with this, here's how I've corrected the mismatched state_dict keys. I've saved a copy of the EfficientNet pretrained weights and modified the keys as follows:

import torch

# `config` is the detector's config dict from the surrounding code
state_dict = torch.load(config['pretrained'])

# Step 1: Create a temporary list for keys to be modified
keys_to_modify = []

# Step 2: Iterate over state_dict.items() and identify keys to modify
for name, weights in state_dict.items():
    if name.startswith('backbone.'):
        # Strip only the leading prefix, not later occurrences in the name
        new_key = name[len('backbone.'):]
        keys_to_modify.append((name, new_key, weights))

# Step 3: Modify state_dict based on the keys identified
# (done in a second pass so the dict is not changed while iterating over it)
for old_key, new_key, weights in keys_to_modify:
    state_dict[new_key] = weights  # Add the new key with the correct name
    del state_dict[old_key]  # Delete the old key

# Step 4: Print all keys in state_dict to verify the renaming
for key in state_dict.keys():
    print(key)

The model seems to load as expected with two output classes on the last layer.

The same solution might also work for the I3D model by slightly changing Step 2 (though I haven't tested it):

# Step 2: Iterate over state_dict.items() and identify keys to modify
for name, weights in state_dict.items():
    # add resnet. to the keys
    new_key = 'resnet.' + name
    keys_to_modify.append((name, new_key, weights))
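
The two remapping variants above could also be folded into small helpers that build a new dict instead of editing the original in place. This is a minimal sketch, not part of DeepfakeBench: strip_prefix and add_prefix are hypothetical names, and the values below are placeholders standing in for tensors. The functions work on any mapping, including the dict returned by torch.load().

```python
def strip_prefix(state_dict, prefix):
    """Return a new dict with `prefix` removed from the start of every key that has it."""
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


def add_prefix(state_dict, prefix):
    """Return a new dict with `prefix` added in front of every key."""
    return {prefix + key: value for key, value in state_dict.items()}


# Placeholder values stand in for weight tensors.
effnet_ckpt = {"backbone.efficientnet._conv_stem.weight": 0}
i3d_ckpt = {"s1.pathway0_stem.conv.weight": 1}

effnet_fixed = strip_prefix(effnet_ckpt, "backbone.")
i3d_fixed = add_prefix(i3d_ckpt, "resnet.")
print(list(effnet_fixed))
print(list(i3d_fixed))
```

Returning a new dict leaves the loaded checkpoint untouched, which makes it easy to compare key sets before and after the renaming.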