jamesdolezal / slideflow

Deep learning library for digital pathology, with both Tensorflow and PyTorch support.
https://slideflow.dev
GNU General Public License v3.0

[BUG] sf.Heatmap() function incompatible with MIL model #296

Closed: quark412 closed this issue 1 year ago

quark412 commented 1 year ago

It seems that sf.Heatmap does not work with the clam_sb model: it assumes the existence of a file called params.json, whereas the directory of an MIL model contains a file called mil_params.json instead.

My stack trace is as follows:

heatmap = sf.Heatmap(slide_path, model_path)
ModelParamsNotFoundError: Model parameters file (params.json) not found.
---------------------------------------------------------------------------
ModelParamsNotFoundError                  Traceback (most recent call last)
File <command-1127483361868555>:1
----> 1 heatmap = sf.Heatmap(slide_path, model_path)

File /local_disk0/.ephemeral_nfs/cluster_libraries/python/lib/python3.10/site-packages/slideflow/heatmap.py:131, in Heatmap.__init__(self, slide, model, stride_div, batch_size, num_threads, num_processes, img_format, generate, generator_kwargs, device, load_method, **wsi_kwargs)
    127     raise ValueError("Invalid argument: cannot supply both "
    128                      "num_processes and num_threads")
    129 self.insets = []  # type: List[Inset]
--> 131 model_config = sf.util.get_model_config(model)
    132 self.uq = model_config['hp']['uq']
    133 if img_format == 'auto' and 'img_format' not in model_config:

File /local_disk0/.ephemeral_nfs/cluster_libraries/python/lib/python3.10/site-packages/slideflow/util/__init__.py:677, in get_model_config(model_path)
    675     config = load_json(join(dirname(model_path), 'params.json'))
    676 else:
--> 677     raise errors.ModelParamsNotFoundError
    678 # Compatibility for pre-1.1
    679 if 'norm_mean' in config:

ModelParamsNotFoundError: Model parameters file (params.json) not found.
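Until the docs spell out which models sf.Heatmap accepts, a small pre-check can tell the two kinds of model directory apart before calling it. The sketch below is not part of slideflow. It assumes only what this thread states: tile-based models ship params.json and MIL models ship mil_params.json. Like the get_model_config frame in the traceback above, it checks both the given path and its parent directory. The helper name `find_params_file` is hypothetical.

```python
from pathlib import Path


def find_params_file(model_path):
    """Return (kind, path) for the params file belonging to a model.

    Hypothetical helper. Assumes tile-based models store params.json and
    MIL models store mil_params.json, either in the model directory itself
    or in the parent directory of a model file. 'kind' is 'tile', 'mil',
    or None when neither file is found.
    """
    path = Path(model_path)
    # Check the path itself (if it is a directory), then its parent,
    # mirroring the dirname() fallback seen in the traceback.
    candidates = [path] if path.is_dir() else []
    candidates.append(path.parent)
    for directory in candidates:
        if (directory / 'params.json').exists():
            return 'tile', directory / 'params.json'
        if (directory / 'mil_params.json').exists():
            return 'mil', directory / 'mil_params.json'
    return None, None
```

If `find_params_file(model_path)[0] == 'mil'`, sf.Heatmap is the wrong entry point, and the attention heatmaps produced during MIL training/validation are the available route.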

This isn't too bad an issue, though, because I am able to successfully generate heatmaps by setting attention_heatmaps=True when training clam_sb.

As a separate question, how feasible would it be to make it possible for non-attention-based models such as mil_fc to generate heatmaps (based on their activations at a given layer, say)?
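One way an activation-based heatmap could work, sketched independently of slideflow: given per-tile activations from some hidden layer of a non-attention model such as mil_fc, as an array of shape (n_tiles, n_units), reduce each tile to one scalar and rescale to [0, 1] for display. How to capture those activations from a trained model is not covered in this thread and is left as an assumption. The function `tile_activation_scores` is hypothetical, and averaging over units is just one possible reduction.

```python
import numpy as np


def tile_activation_scores(activations, unit=None):
    """Reduce per-tile layer activations to one score per tile in [0, 1].

    activations: array of shape (n_tiles, n_units), e.g. captured from a
        hidden layer of a non-attention MIL model (hypothetical input).
    unit: index of a single unit to display; if None, average over units.
    """
    acts = np.asarray(activations, dtype=float)
    if acts.ndim != 2:
        raise ValueError("expected activations of shape (n_tiles, n_units)")
    scores = acts.mean(axis=1) if unit is None else acts[:, unit]
    # Min-max rescale so the heatmap uses the full color range.
    lo, hi = scores.min(), scores.max()
    if hi == lo:
        return np.zeros_like(scores)
    return (scores - lo) / (hi - lo)
```

Picking a single unit makes the map easier to interpret; averaging is simpler but can wash out unit-specific patterns.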

jamesdolezal commented 1 year ago

Hi quark412,

sf.Heatmap only works with tile-based models (models trained with Project.train()). It does not work with MIL models.

MIL heatmaps are a bit trickier for two reasons:

1) MIL models use feature bags (features generated from images using a feature generator / encoder). When generating a heatmap, should we require the user to pass in the precomputed bags of features? Or should the user provide a feature generator / encoder, so that the heatmap function auto-generates features for each tile? Each approach has advantages and disadvantages.

2) What is actually being displayed by the heatmap? For attention-based MIL models, it makes sense to display attention as the heatmap (which is what we do by default when training and validating MIL models). But, as you mention, what about MIL models that don't use attention? Should the interface let the user generate heatmaps from layer activations instead of attention, as you suggest?
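The first design question can be made concrete with a small sketch of an input-resolution step that accepts either option. This is not a slideflow API. `resolve_bag` is a hypothetical name, and the encoder is assumed to be any callable that maps one tile to a 1-D feature vector.

```python
from typing import Callable, Optional, Sequence

import numpy as np


def resolve_bag(bag: Optional[np.ndarray] = None,
                tiles: Optional[Sequence] = None,
                encoder: Optional[Callable] = None) -> np.ndarray:
    """Return a (n_tiles, n_features) feature bag for one slide.

    Hypothetical helper: the caller either passes a precomputed bag, or
    passes tiles plus an encoder that maps one tile to a 1-D feature vector.
    """
    if bag is not None:
        if tiles is not None or encoder is not None:
            raise ValueError("pass either a bag or tiles + encoder, not both")
        return np.asarray(bag, dtype=float)
    if tiles is None or encoder is None:
        raise ValueError("tiles and encoder are both required without a bag")
    # Encode one tile at a time; real code would batch this on a GPU.
    return np.stack([np.asarray(encoder(t), dtype=float) for t in tiles])
```

Roughly, the precomputed-bag path is fast and reuses features already extracted for training, but the bag must match the slide and its tile order. The encoder path is self-contained but slower, and only valid if the same encoder and preprocessing are used as during training.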

These nuances are part of the reason why we don't yet have a unified interface for generating heatmaps from trained MIL models, although it is a high priority. We probably won't use the sf.Heatmap interface that tile-based models use, since it's a fundamentally different technical approach that would utilize different configuration parameters.

When you train or validate an MIL model with attention_heatmaps=True, it uses a lower-level function, sf.util.location_heatmap(), to render the attention heatmaps on a slide. If you need a custom heatmap for your application before we have a chance to build the unified MIL heatmap interface, I would recommend building on that function. This is the function we will likely use under the hood for heatmap rendering regardless.
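Whatever the scores are (attention or activations), the rendering step comes down to placing one value per tile at that tile's location on the slide. The sketch below shows that idea with NumPy only. It does not reproduce the signature or internals of sf.util.location_heatmap(). `scores_to_grid` is hypothetical, and it assumes tile locations are (x, y) top-left pixel coordinates on a non-overlapping grid with stride `tile_px`.

```python
import numpy as np


def scores_to_grid(locations, scores, tile_px, slide_shape):
    """Place per-tile scores on a 2-D grid with one cell per tile position.

    locations: (n_tiles, 2) array of (x, y) top-left pixel coordinates.
    scores: (n_tiles,) values to display (attention, activations, ...).
    tile_px: tile stride in pixels (assumes non-overlapping tiles).
    slide_shape: (width, height) of the slide in pixels.
    Cells without a tile are NaN so they can render as background.
    """
    locs = np.asarray(locations, dtype=int)
    width, height = slide_shape
    # Ceiling division so partial tiles at the slide edge still fit.
    n_rows = -(-height // tile_px)
    n_cols = -(-width // tile_px)
    grid = np.full((n_rows, n_cols), np.nan)
    cols = locs[:, 0] // tile_px
    rows = locs[:, 1] // tile_px
    grid[rows, cols] = np.asarray(scores, dtype=float)
    return grid
```

The resulting array could then be drawn with matplotlib's `imshow` over a slide thumbnail, with NaN cells left transparent.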

Sorry about the lack of clarity regarding the use of sf.Heatmap and its restriction to tile-based models. I'll work on clarifying this in the documentation. I'm hoping MIL heatmaps will be available in the upcoming 2.1 release (ETA 1-2 months), although it may get pushed to 2.2.