Status: Open — issue opened by keller-mark 1 week ago.
def create_viewconfig(sdata):
    """Assemble the full viewconfig dictionary for a backed SpatialData object.

    The "data", "scales" and "marks" sections are produced, in that order, by
    ``_create_data_block``, ``_create_scales_block`` and ``_create_marks_block``.
    The figure size is currently fixed at 6.4 x 4.8 inches.

    Raises
    ------
    ValueError
        If ``sdata.is_backed()`` is false, i.e. there is no .zarr store for the
        viewconfig to point at.
    """
    if sdata.is_backed():
        sections = {
            "data": _create_data_block(sdata),
            "scales": _create_scales_block(sdata),
            "marks": _create_marks_block(sdata),
        }
        return {
            "$schema": "https://spatialdata-plot.github.io/schema/viewconfig/v1.json",
            # Figure size in inches. These match matplotlib's default figure.figsize;
            # TODO: derive them from the figsize actually requested by the caller.
            "height": 4.8,
            "width": 6.4,
            **sections,
        }
    raise ValueError("Given spatialdata object must be backed by .zarr storage.")
def sdata_element_to_uuid(element_name):
    """Return the data identifier used for the spatialdata element *element_name*.

    The identifier is the element name with ``"_uuid"`` appended. It is intended to
    match the dataset names in the viewconfig "data" block; marks use it in their
    ``"from": {"data": ...}`` references.
    """
    suffix = "_uuid"
    return element_name + suffix
def get_scale_name(pl_call_params):
    """Return the name of the scale a mark's color encoding should reference.

    Placeholder: every render call currently maps to the single scale ``"cmap"``,
    which is meant to mirror the "scales" block of the config. *pl_call_params* is
    accepted for future use but is not inspected yet.
    """
    shared_scale = "cmap"
    return shared_scale
def get_shapes_color_encoding(pl_call_params):
    """Build the color encoding dict for a mark from its render-call parameters.

    Returns
    -------
    dict or None
        ``{"scale": <scale name>, "field": <column>}`` when ``col_for_color`` is set.
        Otherwise ``{"value": <color>}`` when ``color`` is set.
        Otherwise ``None``.
        ``col_for_color`` takes precedence; the two are expected to be mutually exclusive.
    """
    # TODO: also consider "cmap" and "palette"; if "groups" is present, use it when
    # generating the domain of the corresponding scale.
    # TODO: this mark's "data" refers to the spatialdata element, but "field" may name
    # a column in the table that annotates the element.
    # TODO: should get_scale_name receive the scale type (categorical vs quantitative)?
    # One option is to inspect the annotating column's dtype (e.g. a
    # pd.CategoricalDtype column should get a categorical color scale).
    column = pl_call_params.get("col_for_color")
    if column is not None:
        return {"scale": get_scale_name(pl_call_params), "field": column}

    literal_color = pl_call_params.get("color")
    if literal_color is not None:
        # A literal color-like string rather than a data-driven encoding.
        return {"value": literal_color}

    return None
# TODO: the marks and scales seem to be pretty coupled. It may be easier to generate both
# blocks simultaneously (e.g., modify both in each for-loop iteration).
def plotting_tree_to_marks(plotting_tree):
    """Translate a plotting tree into the list placed under the viewconfig "marks" key.

    The tree is first flattened with ``dump_plotting_tree``. Each entry is dispatched
    on the suffix of its key:

    - ``_render_images`` yields one ``raster_image`` mark per item of its ``channel``
      parameter.
    - ``_render_shapes`` yields one ``shape`` mark.
    - ``_render_points`` yields one ``point`` mark.
    - ``_render_labels`` yields one ``raster_labels`` mark.

    Entries with any other suffix are skipped. Optional scalar parameters read with
    ``.get`` become ``{"value": None}`` when absent from a call.
    """

    def _mark_header(mark_type, params):
        # Fields common to every mark: its type, the data entry it reads from, and z-order.
        return {
            "type": mark_type,
            "from": {"data": sdata_element_to_uuid(params["element"])},
            "zindex": params["zorder"],
        }

    marks = []  # caller will set { ..., "marks": marks }
    # TODO: do not rely on the dumped dict; use the plotting tree directly.
    for call_id, params in dump_plotting_tree(plotting_tree).items():
        if call_id.endswith("_render_images"):
            # NOTE(review): assumes params["channel"] is an iterable of channel ids
            # (never None) - confirm against spatialdata-plot.
            for channel in params["channel"]:
                mark = _mark_header("raster_image", params)  # TODO: what name to use here?
                mark["encode"] = {
                    "opacity": {"value": params.get("alpha")},
                    "color": {"scale": get_scale_name(params), "field": channel},
                }
                marks.append(mark)
        elif call_id.endswith("_render_shapes"):
            mark = _mark_header("shape", params)
            mark["encode"] = {
                "fillOpacity": {"value": params.get("fill_alpha")},
                "fillColor": get_shapes_color_encoding(params),
                # TODO: confirm "outline_width" is the plotting-tree key; what are the units?
                "strokeWidth": {"value": params.get("outline_width")},
                "strokeColor": {"value": params.get("outline_color")},
                "strokeOpacity": {"value": params.get("outline_alpha")},
            }
            marks.append(mark)
        elif call_id.endswith("_render_points"):
            mark = _mark_header("point", params)
            mark["encode"] = {
                "opacity": {"value": params.get("alpha")},
                "color": get_shapes_color_encoding(params),
                "size": {"value": params.get("size")},
            }
            marks.append(mark)
        elif call_id.endswith("_render_labels"):
            mark = _mark_header("raster_labels", params)  # TODO: what name to use here?
            # TODO: confirm "contour_px", "outline_alpha" and "fill_alpha" are the keys
            # used in the spatial plotting tree.
            mark["encode"] = {
                "opacity": {"value": params.get("alpha")},
                "fillColor": get_shapes_color_encoding(params),
                "strokeColor": get_shapes_color_encoding(params),
                "strokeWidth": {"value": params.get("contour_px")},
                "strokeOpacity": {"value": params.get("outline_alpha")},
                "fillOpacity": {"value": params.get("fill_alpha")},
            }
            marks.append(mark)
    return marks
def _create_data_block(sdata):
    """Build the viewconfig "data" block for a backed SpatialData object.

    The first entry points at the on-disk store (``sdata.path``). Each element drawn
    by a render call in ``sdata.plotting_tree`` then gets one derived entry. That entry
    filters the base level down to the element and is named with
    ``sdata_element_to_uuid``, the same helper marks use for their
    ``"from": {"data": ...}`` references, so those references resolve.

    Parameters
    ----------
    sdata
        Backed SpatialData object exposing ``path`` and ``plotting_tree``. The tree is
        assumed to map render-call keys (e.g. ``"0_render_images"``) to parameter
        objects with an ``element`` attribute; confirm against spatialdata-plot.

    Returns
    -------
    list[dict]
        The base-level entry followed by one entry per distinct rendered element.
    """
    # Start with the base level, which points at the whole store.
    base_level = {
        "name": "sdata_base_level",
        "url": sdata.path,
        "format": {"type": "spatialdata"},
        "version": "0.2.0",  # TODO: version of spatialdata object?
    }
    data_block = [base_level]

    # Data names must be unique for mark references to be unambiguous. Several render
    # calls may draw the same element, and they would produce identical entries, so
    # keep only the first.
    emitted_names = set()
    for render_call, call_params in sdata.plotting_tree.items():
        element = call_params.element
        name = sdata_element_to_uuid(element)
        if name in emitted_names:
            continue
        emitted_names.add(name)

        # Render-call keys end with the element kind, e.g. "..._render_shapes" -> "shapes".
        element_type = render_call.split("_")[-1]
        data_block.append({
            "name": name,
            "format": {"type": f"spatialdata_{element_type}"},
            "version": "0.1.0",  # TODO: ???
            "source": base_level["name"],
            "transform": [
                {
                    "type": "filter",
                    # NOTE(review): element names containing a single quote would break
                    # this expression; confirm spatialdata forbids them.
                    "expr": f"datum['{element}']",
                }
            ],
        })
    # TODO: add a block for each table? (or only the ones we need for color annotation)
    return data_block
Note: the names above are subject to change.
Merge our functions from the hackathon:
- `plotting_tree_to_data`
- `plotting_tree_to_scales`
- `plotting_tree_to_marks`