jamesdolezal / slideflow

Deep learning library for digital pathology, with both Tensorflow and PyTorch support.
https://slideflow.dev
GNU General Public License v3.0
234 stars, 39 forks

`generate_mil_features` for extracting last layer activations in TransMIL and AttentionMIL models #304

Closed: matte-esse closed this 8 months ago.

matte-esse commented 11 months ago

Added a new function, `generate_mil_features`, that returns a `MILFeatures` object built from the last-layer activations of a TransMIL, AttentionMIL, CLAM_sb, or CLAM_mb model.

Required dependencies: None

Args:

- weights (str): Path to model weights to load.
- config (:class:`slideflow.mil.TrainerConfigFastAI` or :class:`slideflow.mil.TrainerConfigCLAM`): Configuration for building the model. If `weights` is a path to a model directory, will attempt to read `mil_params.json` from this location and load the saved configuration. Defaults to None.
- dataset (:class:`slideflow.Dataset`): Dataset.
- outcomes (str, list(str)): Outcomes.
- bags (str, list(str)): Path to bags, or list of bag file paths. Each bag should contain a PyTorch tensor of features from all tiles in a slide, with shape (n_tiles, n_features).

`get_last_layer_activations` in the model classes is a copy of the `forward` method, but returns the last-layer activations before they are converted to scores.
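The split between `forward` and `get_last_layer_activations` can be illustrated with a toy attention-MIL model. This is a minimal NumPy sketch, not slideflow's implementation: `ToyAttentionMIL`, `collect_activations`, the layer sizes, and the random weights are all hypothetical. It assumes each bag is an array of shape (n_tiles, n_features), as the `bags` argument describes, and uses attention pooling in the style of Ilse et al. (softmax over `w^T tanh(V h)`).

```python
import numpy as np


class ToyAttentionMIL:
    """Hypothetical attention-MIL model (illustration only, not slideflow code)."""

    def __init__(self, n_features, n_hidden, n_classes, n_att=8, seed=0):
        rng = np.random.default_rng(seed)
        self.w_enc = rng.normal(size=(n_features, n_hidden))  # tile encoder
        self.w_v = rng.normal(size=(n_hidden, n_att))         # attention projection V
        self.w_att = rng.normal(size=(n_att, 1))              # attention vector w
        self.w_cls = rng.normal(size=(n_hidden, n_classes))   # final classification layer

    def _pool(self, bag):
        # bag: (n_tiles, n_features) -> ReLU tile embeddings: (n_tiles, n_hidden)
        h = np.maximum(bag @ self.w_enc, 0.0)
        # One attention logit per tile, then a numerically stable softmax over tiles.
        logits = (np.tanh(h @ self.w_v) @ self.w_att).ravel()
        a = np.exp(logits - logits.max())
        a /= a.sum()
        # Attention-weighted slide embedding: (n_hidden,)
        return a @ h

    def get_last_layer_activations(self, bag):
        # Same computation as forward, stopping before the classification layer.
        return self._pool(bag)

    def forward(self, bag):
        # Class scores: (n_classes,)
        return self._pool(bag) @ self.w_cls


def collect_activations(model, bags):
    """Stack per-slide activations into a (n_slides, n_hidden) matrix."""
    slides = sorted(bags)
    acts = np.stack([model.get_last_layer_activations(bags[s]) for s in slides])
    return slides, acts


# Three hypothetical slides with different tile counts and 32 features per tile.
rng = np.random.default_rng(1)
bags = {f"slide_{i}": rng.normal(size=(n, 32)) for i, n in enumerate([50, 120, 80])}
model = ToyAttentionMIL(n_features=32, n_hidden=16, n_classes=2)
slides, acts = collect_activations(model, bags)
print(acts.shape)  # (3, 16)
```

Here the shared computation lives in one helper instead of being duplicated, so the two methods cannot drift apart if the model changes.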

matte-esse commented 10 months ago

Thanks, Matt. I've completed my first-pass changes. Could you confirm that the functionality still works as intended on your end? I will work on more rigorous testing in the next few days.

Yes, absolutely. I will comment here once I'm done with my testing.

matte-esse commented 9 months ago

@CostanzaSiniscalchi tested the code and pushed some changes. Everything looks good on our end.