keller-mark / spatialdata-vis-interop

0 stars 0 forks source link

`convert_plotting_tree_to_vis_interop_config` function #11

Open keller-mark opened 1 week ago

keller-mark commented 1 week ago

name subject to change

Merge our functions from the hackathon:

Sonja-Stockhaus commented 1 week ago
def create_viewconfig(sdata):
    """Assemble the top-level view configuration for a spatialdata object.

    The "data", "scales" and "marks" entries are produced by the
    corresponding ``_create_*_block`` helpers; the schema URL and the figure
    dimensions are currently hard-coded.

    Raises:
        ValueError: if ``sdata.is_backed()`` is falsy.
    """
    backed = sdata.is_backed()
    if not backed:
        raise ValueError("Given spatialdata object must be backed by .zarr storage.")

    # The helpers are invoked in the same order as the keys appear below:
    # data, then scales, then marks.
    return {
        "$schema": "https://spatialdata-plot.github.io/schema/viewconfig/v1.json",
        "height": 4.8,  # height in inches (get from figsize?)
        "width": 6.4,  # width in inches
        "data": _create_data_block(sdata),
        "scales": _create_scales_block(sdata),
        "marks": _create_marks_block(sdata),
    }
keller-mark commented 1 week ago
def sdata_element_to_uuid(element_name):
    """Return the identifier used for an element in the config's "data" block.

    Currently this is just ``element_name`` with a ``"_uuid"`` suffix.
    """
    uuid_suffix = "_uuid"
    return element_name + uuid_suffix

def get_scale_name(pl_call_params):
    """Return the name of the scale (in the config's "scales" block) for a call.

    Placeholder: ``pl_call_params`` is not inspected yet and the constant
    ``"cmap"`` is returned for every call.
    """
    scale_name = "cmap"
    return scale_name

def get_shapes_color_encoding(pl_call_params):
    """Build the color encoding entry for a mark from its render-call params.

    Returns:
        ``{"scale": ..., "field": ...}`` when "col_for_color" is set,
        ``{"value": ...}`` when only "color" is set, and ``None`` when
        neither is set.

    Open questions carried over from the draft:
      * Need to consider whether to use "cmap", "color", "col_for_color",
        or "palette".
      * If "groups" is present, it should be used when generating the domain
        for the corresponding scale.
      * "color" and "col_for_color" are expected to be mutually exclusive.
    """
    # TODO: Note that this mark's "data" section refers to the shapes
    # spatialdata element, but "field" below may refer to a column in the
    # table that annotates these shapes.
    color_column = pl_call_params.get("col_for_color")
    if color_column is not None:
        # TODO: should `get_scale_name` be passed the scale type (like
        # categorical or quantitative)? Idea: fetch the column from the
        # annotating table (e.g. via sc.get.obs_df) and use a categorical
        # color scale when its dtype is pd.CategoricalDtype.
        return {"scale": get_scale_name(pl_call_params), "field": color_column}

    constant_color = pl_call_params.get("color")
    if constant_color is not None:
        # If "color" is present in the plotting tree, the value is assumed
        # to be a color-like string.
        return {"value": constant_color}

    return None

# TODO: The marks and scales seem to be tightly coupled; it may be easier to generate both blocks simultaneously (e.g., update both in each for-loop iteration).
def _mark_skeleton(mark_type, params):
    # Fields shared by every mark: its type, its data source and its z-order.
    return {
        "type": mark_type,
        "from": {"data": sdata_element_to_uuid(params["element"])},
        "zindex": params["zorder"],
    }


def plotting_tree_to_marks(plotting_tree):
    """Translate a plotting tree into the list for the config's "marks" block.

    The tree is first dumped to a dict via ``dump_plotting_tree``; each entry
    is then dispatched on the suffix of its key:

      * ``_render_images`` -> one "raster_image" mark per entry of "channel"
      * ``_render_shapes`` -> one "shape" mark
      * ``_render_points`` -> one "point" mark
      * ``_render_labels`` -> one "raster_labels" mark

    Entries with any other suffix produce no mark. "element" and "zorder"
    are required keys of each handled entry; all other parameters are read
    with ``.get`` and therefore default to ``None``.

    Returns:
        The list of mark dicts; the caller sets ``{ ..., "marks": <list> }``.
    """
    # TODO: do not rely on dict; use plotting tree directly
    render_calls = dump_plotting_tree(plotting_tree)
    marks = []
    for call_id, params in render_calls.items():
        if call_id.endswith("_render_images"):
            for channel in params["channel"]:
                mark = _mark_skeleton("raster_image", params)  # TODO: what name to use here?
                mark["encode"] = {
                    "opacity": {"value": params.get("alpha")},
                    "color": {"scale": get_scale_name(params), "field": channel},
                }
                marks.append(mark)
        elif call_id.endswith("_render_shapes"):
            mark = _mark_skeleton("shape", params)
            mark["encode"] = {
                "fillOpacity": {"value": params.get("fill_alpha")},
                "fillColor": get_shapes_color_encoding(params),
                # TODO: check whether this is the key used in the spatial plotting tree
                # TODO: what are the units?
                "strokeWidth": {"value": params.get("outline_width")},
                "strokeColor": {"value": params.get("outline_color")},
                "strokeOpacity": {"value": params.get("outline_alpha")},
            }
            marks.append(mark)
        elif call_id.endswith("_render_points"):
            mark = _mark_skeleton("point", params)
            mark["encode"] = {
                "opacity": {"value": params.get("alpha")},
                "color": get_shapes_color_encoding(params),
                "size": {"value": params.get("size")},
            }
            marks.append(mark)
        elif call_id.endswith("_render_labels"):
            mark = _mark_skeleton("raster_labels", params)  # TODO: what name to use here?
            mark["encode"] = {
                "opacity": {"value": params.get("alpha")},
                # Two separate calls so fill and stroke do not share one dict.
                "fillColor": get_shapes_color_encoding(params),
                "strokeColor": get_shapes_color_encoding(params),
                # TODO: check whether the three keys below are the ones used
                # in the spatial plotting tree
                "strokeWidth": {"value": params.get("contour_px")},
                "strokeOpacity": {"value": params.get("outline_alpha")},
                "fillOpacity": {"value": params.get("fill_alpha")},
            }
            marks.append(mark)
    return marks
Sonja-Stockhaus commented 1 week ago
def _create_data_block(sdata):
    """Build the list for the config's "data" block.

    The first entry describes the spatialdata store itself (using
    ``sdata.path`` as its URL). It is followed by one derived entry per key
    of ``sdata.plotting_tree``, named ``render_call_<i>`` in iteration order,
    whose format type is taken from the last ``_``-separated token of the key
    and whose filter expression selects the render call's ``element``.
    """
    root_name = "sdata_base_level"
    entries = [
        {
            "name": root_name,
            "url": sdata.path,
            "format": {"type": "spatialdata"},
            "version": "0.2.0",  # TODO: version of spatialdata object?
        }
    ]

    # One derived dataset per render call in the plotting tree.
    for index, (call_key, call) in enumerate(sdata.plotting_tree.items()):
        element_kind = call_key.rsplit("_", 1)[-1]
        element_filter = {
            "type": "filter",
            "expr": f"datum['{call.element}']",
        }
        entries.append(
            {
                "name": f"render_call_{index}",
                "format": {"type": f"spatialdata_{element_kind}"},
                "version": "0.1.0",  # TODO: ???
                "source": root_name,
                "transform": [element_filter],
            }
        )

    # TODO: add a block for each table? (or only the ones we need for color annotation)

    return entries