toshas / torch-fidelity

High-fidelity performance metrics for generative models in PyTorch

How to use ResNet34 instead of Inception v3 to compute FID and KID #60

Closed: Xingcun-Li closed this issue 3 months ago.

Xingcun-Li commented 3 months ago

Could you please advise on the steps I should take to replace the original Inception v3 with a ResNet34 model pre-trained on my own dataset, which produces 512-dimensional features, for calculating FID and KID?

toshas commented 3 months ago

Sure, you just need to subclass and register a new feature extractor as explained here: https://torch-fidelity.readthedocs.io/en/latest/extensibility.html#register-a-new-feature-extractor

This way you can either use the new feature extractor within the calculate_metrics API, or copy-paste the fidelity.py file, pre-register your new feature extractor there, and use it as a shell app that supports your new feature extractor.

For an example of subclassing, check here: https://github.com/toshas/torch-fidelity/blob/master/torch_fidelity/feature_extractor_dinov2.py

Xingcun-Li commented 3 months ago
```python
import torch
import torchvision
import torch_fidelity
from torch_fidelity.feature_extractor_base import FeatureExtractorBase


class FeatureExtractorResNet34(FeatureExtractorBase):
    INPUT_IMAGE_SIZE = 256
    MEAN = [0.5, 0.5, 0.5]
    STD = [0.5, 0.5, 0.5]
    DEFAULT_WEIGHTS_PATH = "./models/datasetkidney4classes_modelResNet34.pt"

    def __init__(self, name, features_list, feature_extractor_weights_path=None, **kwargs):
        super().__init__(name, features_list)
        # Load the whole pickled ResNet-34 module (without its classification head).
        # weights_only=False is needed to unpickle full modules on PyTorch >= 2.6.
        weights_path = feature_extractor_weights_path or self.DEFAULT_WEIGHTS_PATH
        self.model = torch.load(weights_path, weights_only=False)
        self.eval()
        self.requires_grad_(False)

    def forward(self, x):
        if not torch.is_tensor(x):
            raise TypeError("Expecting input as torch.Tensor")
        if x.dtype != torch.float32:
            x = x.to(torch.float32)

        if not (x.dim() == 4 and x.shape[1] == 3):
            raise ValueError(f"Input is not Bx3xHxW: {x.shape}")

        x = torchvision.transforms.functional.resize(x, [self.INPUT_IMAGE_SIZE, self.INPUT_IMAGE_SIZE])
        x = torchvision.transforms.functional.normalize(x, mean=self.MEAN, std=self.STD)

        features = self.model(x)                        # (B, 512, 1, 1) after avgpool
        features = features.view(features.size(0), -1)  # flatten to (B, 512)

        # Keys must match get_provided_features_list() and the metric defaults below.
        features_dict = {
            "2048": features,
        }
        return tuple(features_dict[a] for a in self.features_list)

    @staticmethod
    def get_provided_features_list():
        return ("2048",)

    @staticmethod
    def get_default_feature_layer_for_metric(metric):
        return {
            "isc": "2048",
            "fid": "2048",
            "kid": "2048",
            "prc": "2048",
        }[metric]

    @staticmethod
    def can_be_compiled():
        return True

    @staticmethod
    def get_dummy_input_for_compile():
        return (torch.rand([1, 3, 256, 256]) * 255).to(torch.float32)


# Register once at import time; calculate_metrics looks the extractor up by name.
torch_fidelity.register_feature_extractor("resnet34-fe", FeatureExtractorResNet34)


def calculate_fid_kid(model_path, real_images_path, generated_images_path, kid_subset_size=100):
    # calculate_metrics creates its own extractor instance, so the weights path is
    # passed here rather than by constructing FeatureExtractorResNet34 manually.
    metrics = torch_fidelity.calculate_metrics(
        input1=real_images_path,
        input2=generated_images_path,
        cuda=True,               # GPU
        isc=False,               # Inception Score
        fid=True,                # FID
        kid=True,                # KID
        samples_find_deep=True,  # recursively search directories
        kid_subset_size=kid_subset_size,
        feature_extractor="resnet34-fe",
        feature_extractor_weights_path=model_path,
    )
    return metrics
```

@toshas I defined and registered a custom feature extractor, as shown in the code above. However, the FID and KID outputs seem unusually large. If I don't use '2048' in the code, it keeps throwing errors. Is it feasible to use the 512-dimensional features from ResNet34? Could you provide some suggestions? Thank you very much~

toshas commented 3 months ago

I can make a few educated guesses without running this code:

  1. You seem to be loading a non-standard ResNet-34 from a file. It might be custom-trained, but I think a default ImageNet-trained one is a better universal feature extractor than anything custom-made, so maybe try loading it from torchvision.
  2. A standard ResNet is trained on 224x224 images, not 256x256. Since it might be fully convolutional, your features may be a 4D tensor rather than a 2D one, with unpooled spatial dimensions. Instead of doing features = features.view(features.size(0), -1), you'd have to resize to an input size at which the spatial dimensions are (1,1), after which you can simply squeeze them and obtain a (B, C) tensor (C = 512 for ResNet-34).
  3. You seem to normalize with 0.5; the standard ImageNet statistics are different (mean 0.485, 0.456, 0.406; std 0.229, 0.224, 0.225), and this may affect the magnitudes of your features.
  4. Double-check the normalization procedure: make sure values in the [0, 255] range never skip normalization and are never normalized directly with the 0.5, 0.5 stats. Statistics of that magnitude (below 1) mean the input should be in [0, 1], not [0, 255].

Xingcun-Li commented 3 months ago

I'm following and reproducing an image-to-image translation paper from TIP24, where the authors evaluate the FID and KID of the generated images with a ResNet34 retrained on a specific dataset. The torch.load call here loads a standard ResNet-34 without the classification head. I also tried the pretrained ResNet-34 from torchvision, but the FID and KID were still abnormally high. The features obtained from the avgpool layer have shape [64, 512, 1, 1], hence the features = features.view(features.size(0), -1) operation that reshapes them to [64, 512]. I'm using the normalization statistics of the retrained ResNet34. I'm wondering whether there is an issue with how I define and register my feature extractor, or whether torch-fidelity doesn't support 512-dimensional features.

toshas commented 3 months ago

torch-fidelity does not restrict the feature dimensionality, so 512 is fine. Inception v3 also exposes other features, such as the 192- and 768-dimensional ones; you can try computing the metrics on those to see whether the resulting magnitudes are also abnormal.
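
FID and KID only consume an (N, D) feature matrix per image set, so nothing in them depends on D. A minimal NumPy/SciPy sketch of both metrics from precomputed features makes this concrete. It follows the standard definitions (Fréchet distance between Gaussian fits; KID as the unbiased MMD² with the cubic polynomial kernel (a·b/D + 1)³, averaged over random subsets) rather than reproducing torch-fidelity's own code, and `fid_from_features`, `polynomial_mmd2` and `kid_from_features` are hypothetical names.

```python
import numpy as np
import scipy.linalg


def fid_from_features(f1, f2):
    """Frechet distance between Gaussians fitted to two (N, D) feature matrices."""
    mu1, mu2 = f1.mean(axis=0), f2.mean(axis=0)
    sigma1 = np.cov(f1, rowvar=False)
    sigma2 = np.cov(f2, rowvar=False)
    # Matrix square root of sigma1 @ sigma2; drop tiny imaginary parts from round-off.
    covmean = np.real(scipy.linalg.sqrtm(sigma1 @ sigma2))
    diff = mu1 - mu2
    return float(diff @ diff + np.trace(sigma1) + np.trace(sigma2) - 2.0 * np.trace(covmean))


def polynomial_mmd2(x, y, degree=3, coef0=1.0):
    """Unbiased MMD^2 estimate with the kernel k(a, b) = (a.b / D + coef0) ** degree."""
    gamma = 1.0 / x.shape[1]
    k_xx = (gamma * x @ x.T + coef0) ** degree
    k_yy = (gamma * y @ y.T + coef0) ** degree
    k_xy = (gamma * x @ y.T + coef0) ** degree
    m, n = len(x), len(y)
    # Exclude the diagonal (self-similarity) terms from the within-set averages.
    term_xx = (k_xx.sum() - np.trace(k_xx)) / (m * (m - 1))
    term_yy = (k_yy.sum() - np.trace(k_yy)) / (n * (n - 1))
    return float(term_xx + term_yy - 2.0 * k_xy.mean())


def kid_from_features(f1, f2, subsets=100, subset_size=1000, seed=0):
    """Mean and std of the polynomial MMD^2 over random subsets (KID)."""
    rng = np.random.default_rng(seed)
    size = min(subset_size, len(f1), len(f2))
    scores = []
    for _ in range(subsets):
        x = f1[rng.choice(len(f1), size, replace=False)]
        y = f2[rng.choice(len(f2), size, replace=False)]
        scores.append(polynomial_mmd2(x, y))
    return float(np.mean(scores)), float(np.std(scores))
```

The same functions accept 512-dimensional ResNet-34 features and 2048-dimensional Inception features alike. What does change between extractors is the scale of the feature values, so scores computed with different extractors are not directly comparable with each other.
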

Xingcun-Li commented 3 months ago

Many thanks~