TissueImageAnalytics / tiatoolbox

Computational Pathology Toolbox developed by TIA Centre, University of Warwick.
https://warwick.ac.uk/tia

Integrate foundation models available through `timm`: UNI, Virchow, Hibou, H-optimus-0, etc. #855

Open GeorgeBatch opened 2 months ago

GeorgeBatch commented 2 months ago

Description

I think it would be useful to integrate pre-trained foundation models from other labs into tiatoolbox/models/architecture/vanilla.py.

Currently, the _get_architecture() function allows the use of models from torchvision.models.

A second function, _get_timm_architecture(), could be added to incorporate foundation models that are available through timm with weights on the HuggingFace Hub. All the models from timm that I've used require users to sign a licence agreement with the authors, so the licensing question seems to resolve itself: users cannot obtain the model weights through TIAToolbox without first having their access request approved by the authors.

What I Did

To add them myself, I copied the definition of CNNBackbone and made two changes:

  1. set self.feat_extract = _get_timm_architecture(backbone)
  2. removed global average pooling, because given a batch of images these pathology foundation models already output a feature vector of size (batch_size, embedding_size)

https://github.com/TissueImageAnalytics/tiatoolbox/blob/015652cc5c2357070592c8d46f5f2ff9a905a5c6/tiatoolbox/models/architecture/vanilla.py#L176-L270
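The two changes can be sketched roughly as follows. This is an illustration, not the code from the linked file: the class name TimmBackbone is hypothetical, and the feature extractor is passed in directly so the example runs without downloading gated weights. In the version described above, it would come from _get_timm_architecture(backbone).

```python
import torch
from torch import nn


class TimmBackbone(nn.Module):
    """Hypothetical wrapper around a foundation-model feature extractor."""

    def __init__(self, feat_extract: nn.Module) -> None:
        super().__init__()
        # Change 1: in the issue, this is _get_timm_architecture(backbone).
        self.feat_extract = feat_extract

    def forward(self, imgs: torch.Tensor) -> torch.Tensor:
        # Change 2: no global average pooling. The extractor is assumed to
        # return features of shape (batch_size, embedding_size) already.
        return self.feat_extract(imgs)


# Stand-in extractor mapping (N, 3, 8, 8) images to (N, 16) embeddings.
dummy_extractor = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 16))
model = TimmBackbone(dummy_extractor).eval()

with torch.no_grad():
    features = model(torch.zeros(2, 3, 8, 8))
```

With the stand-in extractor, features has one 16-dimensional embedding per input image, so no pooling step is needed afterwards.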

Suggestion

Would you be interested in adding this functionality? If yes, I can make a pull request.

shaneahmed commented 2 months ago

This would be great. Please go ahead and create a PR. You can use logger.info to explain how to access weights.
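One possible shape for that log message, as a sketch only. It uses the standard library logging module as a stand-in for the toolbox logger, and the helper names gated_weights_message and log_weight_access are hypothetical.

```python
import logging

logger = logging.getLogger(__name__)


def gated_weights_message(arch_name: str, hub_repo: str) -> str:
    """Build a message explaining how to get access to gated weights."""
    return (
        f"{arch_name} weights are gated on the HuggingFace Hub. "
        f"Request access at https://huggingface.co/{hub_repo}, "
        "then authenticate locally (for example with `huggingface-cli login`) "
        "before loading the model."
    )


def log_weight_access(arch_name: str, hub_repo: str) -> str:
    """Log the access instructions at INFO level and return the message."""
    msg = gated_weights_message(arch_name, hub_repo)
    logger.info(msg)
    return msg


message = log_weight_access("UNI", "MahmoodLab/uni")
```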