SBU-BMI / wsinfer

🔥 🚀 Blazingly fast pipeline for patch-based classification in whole slide images
https://wsinfer.readthedocs.io
Apache License 2.0

add TILs VGG16 model #110

Closed kaczmarj closed 1 year ago

kaczmarj commented 1 year ago

this issue tracks the addition of the vgg16 tils model.

weights were pulled from Box: https://stonybrookmedicine.app.box.com/folder/128593362243?v=til-results-new-model

config for the original vgg16 tils model is here: https://stonybrookmedicine.app.box.com/file/757453637020

pixel values are normalized to [-1, 1] (see the code below from the Box folder with the original model).

    def preprocess_input(self, inputs):
        # scale pixel values from [0, 255] to [-1, 1]
        np.clip(inputs, 0, 255, inputs)
        inputs /= 255
        inputs -= 0.5
        inputs *= 2
        inputs = tf.image.resize_images(inputs, (self.cnn_arch.input_img_height, self.cnn_arch.input_img_width))
        inputs = inputs.eval()
        return inputs
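
For reference, the same scaling can be written without TensorFlow. The sketch below is a NumPy equivalent of the clip-and-rescale steps only (it leaves out the resize). The function name `scale_to_unit_range` is made up for illustration and is not part of wsinfer or the original model code.

```python
import numpy as np


def scale_to_unit_range(pixels):
    """Map pixel values from [0, 255] to [-1, 1] (hypothetical helper)."""
    # np.clip returns a new array, so the caller's data is left untouched.
    x = np.clip(np.asarray(pixels, dtype=np.float32), 0, 255)
    x = x / 255.0  # [0, 255] -> [0, 1]
    x = (x - 0.5) * 2.0  # [0, 1] -> [-1, 1]
    return x


patch = np.array([0, 127.5, 255], dtype=np.float32)
scaled = scale_to_unit_range(patch)
```

This is the same as scaling to [0, 1] and then normalizing each channel with a mean of 0.5 and a standard deviation of 0.5.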

the weights were then converted from tensorflow to pytorch using the conversion script in this repo. some weight names were changed to match timm's naming using the code below:

def _filter_fn(state_dict):
    """ rename the classifier keys to the timm names and reshape the fc weights into conv kernels"""
    out_dict = {}
    for k, v in state_dict.items():
        k_r = k
        k_r = k_r.replace('classifier.0', 'pre_logits.fc1')
        k_r = k_r.replace('classifier.3', 'pre_logits.fc2')
        k_r = k_r.replace('classifier.6', 'head.fc')
        if 'classifier.0.weight' in k:
            v = v.reshape(-1, 512, 7, 7)
        if 'classifier.3.weight' in k:
            v = v.reshape(-1, 4096, 1, 1)
        out_dict[k_r] = v
    return out_dict
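
The reshapes above are valid because a fully connected layer over a flattened C x H x W feature map computes the same values as a convolution whose kernel spans the whole map. Here is a small NumPy check of that idea, with toy sizes standing in for 512 x 7 x 7 and 4096. The sizes and variable names are illustrative and do not come from timm.

```python
import numpy as np

rng = np.random.default_rng(0)
C, H, W, OUT = 2, 3, 3, 4  # toy stand-ins for 512, 7, 7, 4096

linear_weight = rng.normal(size=(OUT, C * H * W))  # like classifier.0.weight
features = rng.normal(size=(C, H, W))  # one feature map in C, H, W order

# Fully connected layer applied to the flattened feature map.
linear_out = linear_weight @ features.reshape(-1)

# The same weights viewed as a conv kernel of shape [OUT, C, H, W], applied at
# the single position where the kernel covers the whole feature map.
conv_weight = linear_weight.reshape(-1, C, H, W)
conv_out = np.einsum("ochw,chw->o", conv_weight, features)

same_result = np.allclose(linear_out, conv_out)
```

The second reshape, to `(-1, 4096, 1, 1)`, is the same idea with a 1 x 1 kernel.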

kaczmarj commented 1 year ago

looks like the conversion script had to be updated to transpose the tf weights in fc6.

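import tensorflow as tf
import torch
import torchvision

# `tf_to_pytorch_layers` is assumed to be defined elsewhere in the conversion
# script as (TF layer prefix, PyTorch layer prefix) pairs.


def convert_tf_to_pytorch(input_path, output_path, num_classes: int):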
    try:
        ckpt = tf.train.load_checkpoint(input_path)
    except tf.errors.DataLossError as err:
        raise RuntimeError(
            "Error: could not load checkpoint. Did you pass in the stem of the path?"
            " Pass in the path without '.index' or '.meta' or '.data-00000-of-00001'."
        ) from err

    new_state_dict = {}
    for tf_prefix, torch_prefix in tf_to_pytorch_layers:
        tf_weights = f"{tf_prefix}/weights"
        tf_biases = f"{tf_prefix}/biases"
        torch_weights = f"{torch_prefix}.weight"
        torch_biases = f"{torch_prefix}.bias"
        tf_weight_array = ckpt.get_tensor(tf_weights)
        tf_bias_array = ckpt.get_tensor(tf_biases)
        if "conv" in tf_weights:
            tf_weight_array = tf_weight_array.transpose([3, 2, 0, 1])
        elif "fc" in tf_weights:
            if tf_weights == "vgg_16/fc6/weights":
                # [7, 7, 512, 4096] -> [25088, 4096]
                tf_weight_array = tf_weight_array.transpose([2, 0, 1, 3])
                tf_weight_array = tf_weight_array.reshape((25088, 4096))
            # E.g., go from shape [1, 1, 4096, 1000] to [1000, 4096]
            tf_weight_array = tf_weight_array.squeeze().T
        new_state_dict[torch_weights] = torch.from_numpy(tf_weight_array)
        new_state_dict[torch_biases] = torch.from_numpy(tf_bias_array)

    # Test that conversion was (probably) done correctly.
    true_model = torchvision.models.vgg16()
    true_model.classifier[6] = torch.nn.Linear(4096, num_classes)
    if true_model.state_dict().keys() != new_state_dict.keys():
        raise RuntimeError(
            "Something went wrong... converted model keys do not match TorchVision"
            " VGG16 keys."
        )
    true_state_dict = true_model.state_dict()
    for true_k, new_k in zip(true_state_dict, new_state_dict):
        true_shape = true_state_dict[true_k].shape
        new_shape = new_state_dict[new_k].shape
        if true_shape != new_shape:
            raise ValueError(
                "Shape mismatch between converted parameters and reference parameters."
            )

    torch.save(new_state_dict, output_path)
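
The fc6 transpose matters because the checkpoint stores that layer as a [7, 7, 512, 4096] kernel (per the comment in the script), while torchvision's VGG16 flattens the feature map in channel, height, width order before `classifier.0`. Below is a toy NumPy check of that reasoning. It assumes the TF side uses channels-last features and a [height, width, in_channels, out_channels] kernel layout, and the small sizes and variable names are illustrative only.

```python
import numpy as np

rng = np.random.default_rng(0)
H, W, C, OUT = 2, 2, 3, 4  # toy stand-ins for 7, 7, 512, 4096

tf_kernel = rng.normal(size=(H, W, C, OUT))  # like vgg_16/fc6/weights
features_hwc = rng.normal(size=(H, W, C))  # channels-last feature map

# Reference: the TF "fc6" kernel applied where it covers the whole map.
tf_out = np.einsum("hwc,hwco->o", features_hwc, tf_kernel)

# PyTorch sees the same features channels-first and flattens them as C, H, W.
flat_chw = features_hwc.transpose(2, 0, 1).reshape(-1)

# With the transpose:
# [H, W, C, OUT] -> [C, H, W, OUT] -> [C*H*W, OUT] -> [OUT, C*H*W]
converted = tf_kernel.transpose(2, 0, 1, 3).reshape(C * H * W, OUT).T
# Without it, the columns stay in H, W, C order and no longer line up with flat_chw.
naive = tf_kernel.reshape(H * W * C, OUT).T

matches_with_transpose = np.allclose(converted @ flat_chw, tf_out)
matches_without_transpose = np.allclose(naive @ flat_chw, tf_out)
```

Both matrices have the shape the key and shape checks at the end of the script expect, so a missing transpose would pass those checks and only show up as wrong predictions.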