jamesdolezal / slideflow

Deep learning library for digital pathology, with both Tensorflow and PyTorch support.
https://slideflow.dev
GNU General Public License v3.0
230 stars, 38 forks

[BUG] Heatmap generation error #360

Open jziggles opened 2 months ago

jziggles commented 2 months ago

I receive the following error when executing the train_mil function below. I am using clam_mb as my MIL model. Note that the code runs to completion when I do not request attention heatmaps, and with other MIL models (attention_mil) the heatmaps are generated without a problem. It appears that imshow is receiving a 3D array with invalid dimensions; I suspect the final dimension of size 2 is the cause.
for train, val in splits:
    P.train_mil(
        config=config,
        outcomes='ER_Status_By_IHC',
        train_dataset=train,
        val_dataset=val,
        bags=('./path/pt_files/' + 'resnet50_postconv' + '/'),
        attention_heatmaps=True,
        cmap='cividis',
        interpolation=None
    )

Python 3.12.2, Slideflow 2.3.1
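The hypothesis above is consistent with matplotlib's documented input shapes for imshow: image data must be (M, N) scalar values mapped through a colormap, (M, N, 3) RGB, or (M, N, 4) RGBA. The sketch below mirrors only that documented rule; the helper name is hypothetical and is not a matplotlib or slideflow function. It shows why an array whose trailing dimension is 2 is rejected.

```python
from typing import Tuple


def is_valid_imshow_shape(shape: Tuple[int, ...]) -> bool:
    """Hypothetical helper mirroring matplotlib's documented imshow shapes."""
    if len(shape) == 2:
        return True  # (M, N): scalar data, colored via the colormap (e.g. cmap='cividis')
    if len(shape) == 3 and shape[-1] in (3, 4):
        return True  # (M, N, 3) RGB or (M, N, 4) RGBA
    return False  # anything else, such as (M, N, 2), is not a valid image


print(is_valid_imshow_shape((27, 66)))     # True
print(is_valid_imshow_shape((27, 66, 2)))  # False
```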

Traceback (most recent call last):
  File "/XXX/path/slideflow/full_mil_clam.py", line 69, in <module>
    P.train_mil(
  File "/home/jrg97/.conda/envs/slideflow-env/lib/python3.12/site-packages/slideflow/project.py", line 3985, in train_mil
    return train_mil(
           ^^^^^^^^^^
  File "/home/jrg97/.conda/envs/slideflow-env/lib/python3.12/site-packages/slideflow/mil/train/__init__.py", line 80, in train_mil
    return _train_mil(config, **mil_kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jrg97/.conda/envs/slideflow-env/lib/python3.12/site-packages/slideflow/mil/train/__init__.py", line 157, in _train_mil
    return train_fn(
           ^^^^^^^^^
  File "/home/jrg97/.conda/envs/slideflow-env/lib/python3.12/site-packages/slideflow/mil/train/__init__.py", line 786, in train_fastai
    generate_attention_heatmaps(
  File "/home/jrg97/.conda/envs/slideflow-env/lib/python3.12/site-packages/slideflow/mil/eval.py", line 824, in generate_attention_heatmaps
    sf.util.location_heatmap(
  File "/home/jrg97/.conda/envs/slideflow-env/lib/python3.12/site-packages/slideflow/util/__init__.py", line 1361, in location_heatmap
    ax.imshow(
  File "/home/jrg97/.conda/envs/slideflow-env/lib/python3.12/site-packages/matplotlib/__init__.py", line 1465, in inner
    return func(ax, *map(sanitize_sequence, args), **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jrg97/.conda/envs/slideflow-env/lib/python3.12/site-packages/matplotlib/axes/_axes.py", line 5759, in imshow
    im.set_data(X)
  File "/home/jrg97/.conda/envs/slideflow-env/lib/python3.12/site-packages/matplotlib/image.py", line 723, in set_data
    self._A = self._normalize_image_array(A)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jrg97/.conda/envs/slideflow-env/lib/python3.12/site-packages/matplotlib/image.py", line 693, in _normalize_image_array
    raise TypeError(f"Invalid shape {A.shape} for image data")
TypeError: Invalid shape (27, 66, 2) for image data
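One plausible reading of this error, not confirmed in the thread: CLAM-MB (multi-branch) computes a separate attention branch per outcome class, so with a binary outcome such as ER_Status_By_IHC each tile may carry two attention scores. Once those scores are laid out on the slide's tile grid, the array would have shape (rows, cols, 2) instead of the (rows, cols) grid that attention_mil yields. The NumPy sketch below collapses such a grid to 2D before plotting, either by keeping one class's branch or by averaging across branches. The function name and the random grid are hypothetical stand-ins, not slideflow code, and slideflow may store CLAM-MB attention differently.

```python
from typing import Optional

import numpy as np


def collapse_attention_grid(grid: np.ndarray, class_idx: Optional[int] = None) -> np.ndarray:
    """Reduce a (rows, cols, n_branches) attention grid to (rows, cols).

    Hypothetical helper, not part of slideflow. If class_idx is given, keep
    only that branch; otherwise average across branches. A 2D grid is
    returned unchanged.
    """
    if grid.ndim == 2:
        return grid
    if grid.ndim != 3:
        raise ValueError(f"expected a 2D or 3D grid, got shape {grid.shape}")
    if class_idx is not None:
        return grid[..., class_idx]  # attention from a single class branch
    return grid.mean(axis=-1)  # average attention over all branches


# Stand-in for a per-class attention grid with the shape from the traceback.
rng = np.random.default_rng(0)
grid = rng.random((27, 66, 2))

print(collapse_attention_grid(grid, class_idx=1).shape)  # (27, 66)
print(collapse_attention_grid(grid).shape)               # (27, 66)
```

Either 2D result is a valid input for imshow with a colormap. Whether to show a single class branch (for example, the positive class) or an average across branches is a design choice this thread does not settle.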