mahmoodlab / UNI

Towards a general-purpose foundation model for computational pathology - Nature Medicine
Other
335 stars 44 forks source link

Release trained model on SegPath data #26

Open abs51295 opened 6 months ago

abs51295 commented 6 months ago

Hello,

Are you planning to release the model trained on the SegPath data, so that we can run inference directly on our own samples?

pidemal commented 4 months ago

@abs51295, not sure if this helps, but I made the following wrapper. I would appreciate it if @Richarizardd could confirm what type of adapters they used; my understanding is that they did an end-to-end fine-tune.

## Wrap UNI so it can serve as the Mask2Former pixel-level backbone
class UNIWrapper(nn.Module):
    """Expose the UNI ViT (loaded via timm) as a multi-scale backbone for Mask2Former.

    The image is encoded by UNI. The patch tokens, with the class/prefix tokens
    dropped, are laid out as a ``(B, C, H/p, W/p)`` grid, where ``p`` is UNI's
    patch size (16 for the ViT-L/16 UNI checkpoint; confirm). Each adapter
    projects that grid with a 1x1 convolution and upsamples it by
    ``2 ** (num_features - 1 - i)``. The result is a list of maps ordered from
    largest to smallest.

    Args:
        out_channels_list: Channel width of each output map. Defaults to
            ``[96, 192, 384, 768]``, presumably chosen to mirror the Swin-tiny
            encoder this wrapper replaces. Verify it matches the pixel
            decoder's expected input widths.
        num_features: Sets the upsampling factors. With the default of 4 and a
            16-pixel patch, the maps have strides 2, 4, 8 and 16 relative to
            the input.
        freeze_backbone: If True, UNI's parameters get ``requires_grad=False``
            and only the adapters (and whatever consumes them) are trained.
    """

    def __init__(self, out_channels_list=None, num_features=4, freeze_backbone=False):
        super().__init__()
        # Avoid a shared mutable default argument.
        if out_channels_list is None:
            out_channels_list = [96, 192, 384, 768]

        self.uni_model = timm.create_model(
            "hf-hub:MahmoodLab/uni",
            pretrained=True,
            init_values=1e-5,
            dynamic_img_size=True,  # allow inputs other than the pretraining size
        )

        if freeze_backbone:
            for param in self.uni_model.parameters():
                param.requires_grad = False

        embed_dim = self.uni_model.embed_dim
        self.adapters = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(embed_dim, out_channels, kernel_size=1),
                nn.Upsample(scale_factor=2 ** (num_features - 1 - i)),
            )
            for i, out_channels in enumerate(out_channels_list)
        ])

    def forward(self, x):
        """Encode ``x`` and return an object whose ``feature_maps`` is the pyramid.

        The attribute name ``feature_maps`` is what the Mask2Former pixel-level
        module reads from its encoder output (presumably; confirm against the
        installed transformers version).
        """
        from types import SimpleNamespace

        grid = self._patch_grid(x)
        multi_scale_features = [adapter(grid) for adapter in self.adapters]
        return SimpleNamespace(feature_maps=multi_scale_features)

    def _patch_grid(self, x):
        """Return UNI's patch tokens for ``x`` reshaped to ``(B, C, H/p, W/p)``.

        The pooled ``(B, C)`` embedding returned by calling the model directly
        has no spatial layout: upsampling it only produces constant maps. Dense
        prediction therefore needs the per-patch tokens from ``forward_features``.

        Raises:
            ValueError: if the number of patch tokens does not match the grid
                implied by the input size and the patch size.
        """
        tokens = self.uni_model.forward_features(x)
        # Drop the class (and any register) tokens that precede the patch tokens.
        tokens = tokens[:, self.uni_model.num_prefix_tokens:, :]
        batch_size, num_patches, channels = tokens.shape

        patch_h, patch_w = self.uni_model.patch_embed.patch_size
        grid_h = x.shape[-2] // patch_h
        grid_w = x.shape[-1] // patch_w
        if grid_h * grid_w != num_patches:
            raise ValueError(
                f"expected {grid_h}x{grid_w}={grid_h * grid_w} patch tokens for "
                f"input of size {tuple(x.shape[-2:])}, got {num_patches}"
            )
        # Tokens are row-major over the patch grid: (B, N, C) -> (B, C, gh, gw).
        return tokens.transpose(1, 2).reshape(batch_size, channels, grid_h, grid_w)

def load_model(num_classes=1):
    """Build a Mask2Former segmentation model whose pixel-level encoder is UNI.

    The architecture comes from the
    ``facebook/mask2former-swin-tiny-cityscapes-semantic`` config. The model is
    instantiated from that config (not ``from_pretrained``), so only the UNI
    encoder carries pretrained weights.

    Args:
        num_classes: Number of semantic classes to predict. Written to the
            config's ``num_labels`` before the model is built.

    Returns:
        ``(model, image_processor)``: the Mask2Former model and a matching
        ``Mask2FormerImageProcessor`` configured for 224x224 inputs.

    Raises:
        ValueError: if ``num_classes`` is less than 1.
    """
    if num_classes < 1:
        raise ValueError(f"num_classes must be >= 1, got {num_classes}")

    image_processor = Mask2FormerImageProcessor(
        # Per the HF docs, label 0 becomes ignore_index and the other labels
        # shift down by one. ``do_reduce_labels`` replaces the deprecated
        # ``reduce_labels`` kwarg.
        do_reduce_labels=True,
        do_resize=True,
        ignore_index=255,
        size={"height": 224, "width": 224},
    )

    model_config = Mask2FormerConfig.from_pretrained(
        "facebook/mask2former-swin-tiny-cityscapes-semantic"
    )
    # The Cityscapes config carries its own label set. Without this line,
    # ``num_classes`` would be silently ignored and the class head would keep
    # the Cityscapes size.
    model_config.num_labels = num_classes

    # UNI weights are presumably gated on the Hub, so log in before
    # UNIWrapper downloads them.
    login(token=HUGGINGFACE_TOKEN)
    # Instantiate Mask2Former, then swap its Swin encoder for UNI.
    # NOTE(review): the pixel decoder is built from the Swin config before the
    # swap. UNIWrapper's output channel widths presumably need to match the
    # Swin-tiny stage widths — verify if either side changes.
    model = Mask2FormerForUniversalSegmentation(model_config)
    model.model.pixel_level_module.encoder = UNIWrapper()

    total_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"total trainable parameters: {total_parameters:,}")

    return model, image_processor