Marsilea-viz / marsilea

Declarative creation of composable visualization for Python (Complex heatmap, Upset plot, Oncoprint and more~)
https://marsilea.rtfd.io/
MIT License

Struggling to get access to the ax for specific plots within combined plot #50

Status: Open. newtonharry opened this issue 1 month ago.

newtonharry commented 1 month ago

Hi,

I have this code that plots three panels in one combined figure: the first is a track plot and the other two are heatmaps. How do I get access to the Axes of the second and third panels so that I can add an x-axis to them?

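import numpy as np
import marsilea as ma
import marsilea.plotter as mp
import matplotlib.pyplot as plt
from sklearn.metrics.pairwise import cosine_similarity

# peak_counts, downsampled_matrix and downsampled_cell_types are my own data (not shown)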

TRACK_HEIGHT = 0.5
TRACK_PAD = 0.1

# Initialize the ZeroHeight canvas for the track
fragment_track = ma.ZeroHeight(10, name="track")

# Add the peak area plot (assuming you have peak_counts defined)
fragment_track.add_bottom(
    mp.Area(
        peak_counts,  # Replace with your actual peak count data
        color="#4F6D7A",
        add_outline=False,
        alpha=1
    ),
    name="peak"
)

# Add a title to the canvas
fragment_track.add_title("A", align="left", pad=0.4)

# fragment_track.render()

# Create the fragment cell heatmap
fragment_cell_heatmap = ma.CatHeatmap(
    downsampled_matrix.toarray(),
    cmap="magma",
    label="Read Count",
    height=10,
    width=10,
    name="fragment_cell_heatmap",
)
fragment_cell_heatmap.group_rows(
    downsampled_cell_types,
    order=np.unique(downsampled_cell_types),
    spacing=0.015,
)

# Add cell type labels on the left with no rotation
fragment_cell_heatmap.add_left(
    ma.plotter.Chunk(np.unique(downsampled_cell_types), rotation=0)
)

# Add title to the heatmap
fragment_cell_heatmap.add_title("B", align="left", pad=0.4)

#fragment_cell_heatmap.render()

# Create cosine similarity heatmap
similarities = cosine_similarity(downsampled_matrix.toarray().T)
cosine_similarity_heatmap = ma.Heatmap(
    similarities,
    cmap="magma",
    width=10,
    height=10,
    label="Cosine Similarity",
    name="cosine_similarities",
)
cosine_similarity_heatmap.add_title("C", align="left", pad=0.4)
#cosine_similarity_heatmap.render()

# Combine all components into one layout
comb = fragment_track / 1.0 / fragment_cell_heatmap / 0.5 / cosine_similarity_heatmap
comb.add_legends("right", stack_size=10, stack_by="column")
comb.render()

# Get the Axes object corresponding to the "track" and "peak" plot
ax = comb.get_ax("track", "peak")

# Explicitly show the x-axis and customize it
ax.set_axis_on()
ax.xaxis.set_visible(True)  # Ensure the x-axis is visible
ax.tick_params(axis='x', which='both', bottom=True, labelbottom=True)  # Ensure ticks and labels are shown

# Define the start and end genomic positions (to map the relative positions)
genomic_start = 58572445
genomic_end = 58573849
num_ticks = 7  # Number of tick positions you want to show

# Evenly spaced integer tick positions (bin indices 0 to len(peak_counts) - 1)
relative_positions = np.linspace(0, len(peak_counts) - 1, num_ticks).astype(int)

# Set the corresponding genomic positions as the labels
genomic_labels = np.linspace(genomic_start, genomic_end, num_ticks).astype(int)

# Set the relative tick positions and map them to the genomic labels
ax.set_xticks(relative_positions)
ax.set_xticklabels(["chr19:" + str(label) for label in genomic_labels], rotation=40, fontsize=10)
ax.set_ylabel("Read counts")
ax.set_xlabel("Genomic position")

Thanks!
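In the tick-mapping step above, the tick positions and the tick labels come from two independent `np.linspace` calls, each truncated to integers. Unless `len(peak_counts)` equals `genomic_end - genomic_start + 1` (one bin per base), a label can drift from the coordinate of the bin it sits under. For example, with 100 bins over this 1,404 bp span, bin 16 sits at about 58572672 but would be labelled 58572679. A minimal sketch, assuming evenly sized bins with bin 0 at `genomic_start` and the last bin at `genomic_end`, derives each label from its own tick index instead. `genomic_ticks` is a hypothetical helper, not part of Marsilea.

```python
import numpy as np


def genomic_ticks(n_bins, genomic_start, genomic_end, num_ticks, chrom="chr19"):
    """Hypothetical helper: tick positions and matching genomic labels for a track.

    Assumes n_bins evenly sized bins, with bin 0 at genomic_start and
    bin n_bins - 1 at genomic_end.
    """
    if n_bins < 2:
        raise ValueError("need at least two bins to map indices to coordinates")
    # Evenly spaced integer bin indices (truncated, as in the original code)
    positions = np.linspace(0, n_bins - 1, num_ticks).astype(int)
    # Convert each chosen index back to a genomic coordinate
    bp_per_bin = (genomic_end - genomic_start) / (n_bins - 1)
    coords = np.rint(genomic_start + positions * bp_per_bin).astype(int)
    labels = [f"{chrom}:{c}" for c in coords]
    return positions, labels


# Example: 100 bins spanning the interval from the question
positions, labels = genomic_ticks(100, 58572445, 58573849, 7)
print(positions.tolist())
print(labels)
```

The returned values would go to `ax.set_xticks(positions)` and `ax.set_xticklabels(labels, rotation=40, fontsize=10)` as in the question. Whether a tick at index i lands exactly on the i-th sample of the area plot depends on how `mp.Area` lays out its x-coordinates, which is not confirmed here.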

Mr-Milk commented 1 month ago

Thanks for using Marsilea!

I forgot to implement the API that lets users retrieve the main Axes after combining plots. It will be available in the next release. Sorry for the confusion! Ideally, you will be able to do something like:

comb.get_main_ax('track') # For track
comb.get_main_ax('heatmap') # For heatmap

For your use case, I would suggest something like the code below. The x-ticks are controlled by the plotter, so otherwise you would need to either implement your own RenderPlan (see how: https://marsilea.readthedocs.io/en/stable/tutorial/new_renderplan.html) or add an empty canvas with .add_canvas().

import numpy as np
import marsilea as ma
import marsilea.plotter as mp

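# Example data: a 1000-position coverage track and a 10 x 1000 matrix
data = np.random.randint(0, 10, 1000)
mat = np.random.randn(10, 1000)

# Genomic interval covered by the 1000 columns (one column per base)
start = 58572000
end = 58573000

track = ma.ZeroHeight(10, name="track")
track.add_bottom(mp.Area(data, label="Read Count"), name="peak")
# Fake tick marks: a "-" every 100 positions, empty strings elsewhere
track.add_bottom(mp.Labels(
    ["" if i % 100 else "-" for i in np.arange(start, end)],
))
# Fake tick labels: the genomic coordinate every 100 positions
track.add_bottom(mp.Labels(
    ["" if i % 100 else i for i in np.arange(start, end)],
    rotation=0,
    fontsize=6,
    label="Genome Position",
    label_loc="right",
))

# Stack the heatmap below the track and draw the combined figure
heatmap = ma.Heatmap(mat, height=1, name="heat")
comb = track / heatmap
comb.render()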
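The two mp.Labels rows above act as a hand-made x-axis: one row places a "-" at every 100th position and the other prints the coordinate there, with empty strings everywhere else. Each list needs exactly one entry per column of the track data (1000 here). A small standard-library sketch builds both rows for any interval and spacing, so the lengths can be checked before plotting; `tick_label_rows` is a hypothetical helper, not a Marsilea function.

```python
def tick_label_rows(start, end, step=100, mark="-"):
    """Hypothetical helper: build the two Labels rows used as a fake x-axis.

    Both lists have one entry per position in range(start, end). Positions
    divisible by `step` get `mark` (first list) and their coordinate as text
    (second list); every other position gets an empty string.
    """
    positions = range(start, end)
    marks = [mark if pos % step == 0 else "" for pos in positions]
    coords = [str(pos) if pos % step == 0 else "" for pos in positions]
    return marks, coords


# Same interval as the example above: 1000 positions, a tick every 100 bp
marks, coords = tick_label_rows(58572000, 58573000)
print(len(marks), marks.count("-"))
print([c for c in coords if c])
```

The two lists would replace the inline comprehensions, e.g. mp.Labels(marks) and mp.Labels(coords, rotation=0, fontsize=6, ...). Converting coordinates with str() is a choice made in this sketch; the original passes the raw integers.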
