scverse / spatialdata-plot

Static plotting for spatialdata
BSD 3-Clause "New" or "Revised" License
27 stars 13 forks source link

Plotting multiple elements in the same `ax` seems to work only when `show()` is not called. #71

Open LucaMarconato opened 1 year ago

LucaMarconato commented 1 year ago

I refer to the code mentioned in this other issue: https://github.com/scverse/spatialdata-plot/issues/68

This code here:

    ax = plt.gca()
    sdata.pl.render_shapes(element='s', na_color=(0.5, 0.5, 0.5, 0.5)).pl.render_points().pl.show(ax=ax)
    sdata.pl.render_shapes(element='c', na_color=(0.7, 0.7, 0.7, 0.5)).pl.show(ax=ax)
    plt.show()

doesn't work if I run the code as a script, but it works in interactive mode (where because of a bug the plots are not shown until I call plt.show()). I suggest to do like scanpy and having a parameter show: bool. I suggest also that if the parameter ax is not None, then show is set to False. I don't remember if this one is also a behavior of scanpy, but I think it's reasonable.

josenimo commented 1 week ago

Hello devs,

I have a really cool function on my hands, and I saving a summary plot is proving to be quite difficult. So I am kinda restarting this issue.

My function would take an image as an input, perform segmentation of the image using Cellpose via SOPA, and produce a PNG file with a hyperparameter search, to decide what is the best segmentation.

Currently I am running this code for plotting each ax object, in a fig that has many axes.

sdata.pl.render_images(
    element=args.image_key, alpha=0.85, channel=config['channels'], palette=['green']
).pl.render_shapes(
    element=title, fill_alpha=0.0, outline=True, outline_width=1.1, outline_color="yellow", outline_alpha=0.32
).pl.show(ax=ax, title=title, save=os.path.join(args.output, "pngs", "segment_search.png"))

When this line is reached in the CLI a matplotlib popup comes up with the entire figure, but a single filled ax object. I have to manually close this first figure, and then the other axes are plotted, and then the entire figure saved (I think overwriting itself).

I have looked into matplotlib docs but I found no clear answer.

Any tips, ideas, or comments, very welcome. For the plotting or the function in general.

Best, Jose

Entire script Function ```python #system from loguru import logger import argparse import sys import os import time import spatialdata import spatialdata_plot #imports import skimage.segmentation as segmentation import skimage.io as io import numpy as np #yaml import yaml import math import matplotlib.pyplot as plt import re import os import matplotlib.gridspec as gridspec #sopa import sopa.segmentation import sopa.io def get_args(): """ Get arguments from command line """ description = """Expand labeled masks by a certain number of pixels.""" parser = argparse.ArgumentParser(description=description, formatter_class=argparse.RawDescriptionHelpFormatter) inputs = parser.add_argument_group(title="Required Input", description="Path to required input file") inputs.add_argument("-i", "--input", dest="input", action="store", required=True, help="File path to input mask or folders with many masks") inputs.add_argument("-c", "--config", dest="config", action="store", required=True, help="Path to config.yaml for cellpose parameters") inputs.add_argument("-o", "--output", dest="output", action="store", required=True, help="Path to output mask, or folder where to save the output masks") inputs.add_argument("-l", "--log-level",dest="loglevel", default='INFO', choices=["DEBUG", "INFO"], help='Set the log level (default: INFO)') arg = parser.parse_args() arg.input = os.path.abspath(arg.input) arg.config = os.path.abspath(arg.config) arg.output = os.path.abspath(arg.output) return arg def check_input_outputs(args): """ Check if input and output files exist """ #input assert os.path.isfile(args.input), "Input must be a file" assert args.input.endswith((".tif", ".tiff")), "Input file must be a .tif or .tiff file" #config assert os.path.isfile(args.config), "Config must exist" assert args.config.endswith(".yaml"), "Config file must be a .yaml file" #output if not os.path.exists(args.output): os.makedirs(args.output) assert os.path.isdir(args.output), "Output must be a folder" #create output folders os.makedirs(os.path.join(args.output, "pngs"), exist_ok=True) ### os.makedirs(os.path.join(args.output, "zarrs"), exist_ok=True) args.filename = os.path.basename(args.input).split(".")[0] args.zarr_path = os.path.join(args.output, f"{args.filename}.zarr") logger.info(f"Input, output and config files exist and checked.") def create_sdata(args): """ Create sdata object """ logger.info(f"Creating spatialdata object.") time_start = time.time() sdata = sopa.io.ome_tif(args.input) args.image_key = list(sdata.images.keys())[0] time_end = time.time() logger.info(f"Creating spatialdata object took {time_end - time_start} seconds.") return sdata def prepare_for_segmentation_search(sdata, args): """ Search for segments in sopa data """ logger.info(f"Preparing for segmentation search.") time_start = time.time() patches = sopa.segmentation.Patches2D(sdata, element_name=args.image_key, patch_width=1000, patch_overlap=100) patches.write() #reset channel names to their indexes, metadata to inconsistent new_c = list(range(len(sdata.images[args.image_key]['scale0'].coords['c'].values))) sdata.images[args.image_key] = sdata.images[args.image_key].assign_coords(c=new_c) time_end = time.time() logger.info(f"Preparation for segmentation took {time_end - time_start} seconds.") return sdata def read_yaml(file_path): """ Read yaml file """ logger.info(f"Reading yaml file.") with open(file_path, 'r') as file: data = yaml.safe_load(file) return data def segmentation_loop(sdata, args, config): """ Loop through different cellpose parameters """ logger.info(f"Starting segmentation loop.") for ft in config['flow_thresholds']: for cpt in config['cellprob_thresholds']: logger.info(f"Segmenting with FT: {ft} and CT: {cpt}") FT_str = str(ft).replace(".", "") #create method for segmenting method = sopa.segmentation.methods.cellpose_patch( diameter=config['cell_pixel_diameter'], channels=config['channels'], flow_threshold=ft, cellprob_threshold=cpt, model_type=config['model_type'] ) segmentation = sopa.segmentation.StainingSegmentation(sdata, method, channels=config['channels'], min_area=config['min_area']) #create temp dir to store segmentation of each tile cellpose_temp_dir = os.path.join(args.output, ".sopa_cache", "cellpose", f"run_FT{FT_str}_CPT{cpt}") #segment segmentation.write_patches_cells(cellpose_temp_dir) #read and solve conflicts cells = sopa.segmentation.StainingSegmentation.read_patches_cells(cellpose_temp_dir) cells = sopa.segmentation.shapes.solve_conflicts(cells) #save segmentation of entire image as shapes sopa.segmentation.StainingSegmentation.add_shapes( sdata, cells, image_key=args.image_key, shapes_key=f"cellpose_boundaries_FT{FT_str}_CT{cpt}") logger.info(f"Saving zarr to {args.zarr_path}") sdata.write(args.zarr_path, overwrite=True) logger.info(f"Segmentation loop finished.") def extract_ft_values(shape_titles): """Extract all unique ft values from a list of shape titles.""" ft_values = set() cpt_values = set() for title in shape_titles: match = re.search(r'_FT(\d+)_CT(\d+)', title) if match: ft_values.add(match.group(1)) cpt_values.add(match.group(2)) else: print(f"Warning: {title} does not match the expected pattern.") return sorted(ft_values), sorted(cpt_values) def plot(sdata, args, config): shape_titles = list(sdata.shapes.keys()) shape_titles.remove("sopa_patches") logger.info(f"Plotting {shape_titles} segmentations") logger.info unique_ft_values, unique_cpt_values = extract_ft_values(shape_titles) num_cols = len(unique_ft_values) num_rows = len(unique_cpt_values) logger.info(f"Unique FT values: {unique_ft_values} and Unique CT values: {unique_cpt_values}") ft_to_index = {ft: i for i, ft in enumerate(unique_ft_values)} cpt_to_index = {cpt: i for i, cpt in enumerate(unique_cpt_values)} fig, axes = plt.subplots(num_rows, num_cols, figsize=(num_cols*6, num_rows*6), facecolor='black') gs = gridspec.GridSpec(num_rows, num_cols, wspace=0.1, hspace=0.1) for i, title in enumerate(shape_titles): #print number of title out of all titles logger.info(f"Rendering {i+1}/{len(shape_titles)} ||| {title}") ft, cpt = re.search(r'FT(\d+)_CT(\d+)', title).groups() row = cpt_to_index[cpt] col = ft_to_index[ft] ax = fig.add_subplot(gs[row, col]) ax.set_facecolor('black') ax.title.set_color('white') ax.set_xticks([]) ax.set_yticks([]) ax.spines['top'].set_visible(False) ax.spines['right'].set_visible(False) ax.spines['bottom'].set_visible(False) ax.spines['left'].set_visible(False) try: logger.info(f" Rendering image") sdata.pl.render_images( element=args.image_key, alpha=0.85, channel=config['channels'], palette=['green'] ).pl.render_shapes( element=title, fill_alpha=0.0, outline=True, outline_width=1.1, outline_color="yellow", outline_alpha=0.32 ).pl.show(ax=ax, title=title, save=os.path.join(args.output, "pngs", "segment_search.png")) logger.info(f"Saving plot to {os.path.join(args.output, 'pngs', 'segment_search.png')}") # plt.savefig(os.path.join(args.output, "pngs", "segment_search.png")) except: print(f"Could not render shapes of {title}") def main(): args = get_args() logger.remove() logger.add(sys.stdout, format="{time:HH:mm:ss.SS} | {level} | {message}", level=args.loglevel.upper()) check_input_outputs(args) sdata = create_sdata(args) sdata = prepare_for_segmentation_search(sdata, args) segmentation_loop(sdata, args, config=read_yaml(args.config)) plot(sdata, args, config=read_yaml(args.config)) if __name__ == "__main__": main() """ Example: python ./scripts/segment_search.py \ --input ./data/input/Exemplar001.ome.tif \ --config ./data/configs/config.yaml \ --output ./data/output/ """ ```