Kitware / trame

Trame lets you weave various components and technologies into a Web Application solely written in Python.
https://kitware.github.io/trame/

Plotly figure resets unexpectedly #228

Closed: cardinalgeo closed this issue 1 year ago

cardinalgeo commented 1 year ago

Describe the bug

Opening or closing the drawer in an app that contains a plotly figure resets the figure. For example, if the user zooms into a plotly plot and then opens or closes the drawer, the axis ranges reset and the plot "unzooms." This is unexpected: a user may want to zoom into a region of interest in, for example, a scatter plot to better resolve individual points when making selections. Opening or closing the drawer should not reset the figure, because that forces the user to re-zoom into the region of interest again and again, which is cumbersome.

To Reproduce

Steps to reproduce the behavior:

  1. Go to the RemoteSelection example in the Trame repo
  2. Change the following lines
    • Change line 15 from `from trame.ui.vuetify import SinglePageLayout` to `from trame.ui.vuetify import SinglePageWithDrawerLayout`
    • Change line 320 from `with SinglePageLayout(server) as layout:` to `with SinglePageWithDrawerLayout(server) as layout:`
    • Remove line 322 (`layout.icon.click = ctrl.view_reset_camera`)
  3. Zoom into the scatter plot on the right-hand side
  4. Click the layout icon in the top-left corner to open or close the drawer
  5. Observe that the scatter plot resets to its original, unzoomed axis ranges

The full example with the changes above applied:

r"""
Version for trame 1.x - https://github.com/Kitware/trame/blob/release-v1/examples/VTK/Applications/RemoteSelection/app.py
Delta v1..v2          - https://github.com/Kitware/trame/commit/03f28bb0084490acabf218264b96a1dbb3a17f19
"""

import pandas as pd

# Plotly/chart imports
import plotly.graph_objects as go
import plotly.express as px

# Trame imports
from trame.app import get_server
from trame.assets.remote import HttpFile
from trame.ui.vuetify import SinglePageWithDrawerLayout
from trame.widgets import vuetify, plotly, trame, vtk as vtk_widgets

# VTK imports
from vtkmodules.vtkIOXML import vtkXMLUnstructuredGridReader
from vtkmodules.numpy_interface import dataset_adapter as dsa
from vtkmodules.vtkCommonDataModel import vtkSelection, vtkSelectionNode, vtkDataObject
from vtkmodules.vtkCommonCore import vtkIdTypeArray
from vtkmodules.vtkFiltersExtraction import vtkExtractSelection
from vtkmodules.vtkFiltersGeometry import vtkGeometryFilter
from vtkmodules.vtkRenderingCore import (
    vtkActor,
    vtkDataSetMapper,
    vtkRenderer,
    vtkRenderWindow,
    vtkRenderWindowInteractor,
    vtkHardwareSelector,
    vtkRenderedAreaPicker,
)

from vtkmodules.vtkInteractionStyle import (
    vtkInteractorStyleRubberBandPick,
    vtkInteractorStyleSwitch,
)  # noqa
import vtkmodules.vtkRenderingOpenGL2  # noqa

# -----------------------------------------------------------------------------
# Data file information
# -----------------------------------------------------------------------------

dataset_file = HttpFile(
    "./data/disk_out_ref.vtu",
    "https://github.com/Kitware/trame/raw/master/examples/data/disk_out_ref.vtu",
    __file__,
)

# -----------------------------------------------------------------------------
# Trame setup
# -----------------------------------------------------------------------------

server = get_server()
state, ctrl = server.state, server.controller

# -----------------------------------------------------------------------------
# VTK
# -----------------------------------------------------------------------------

reader = vtkXMLUnstructuredGridReader()
reader.SetFileName(dataset_file.path)
reader.Update()
dataset = reader.GetOutput()

renderer = vtkRenderer()
renderer.SetBackground(1, 1, 1)
render_window = vtkRenderWindow()
render_window.AddRenderer(renderer)

rw_interactor = vtkRenderWindowInteractor()
rw_interactor.SetRenderWindow(render_window)
rw_interactor.GetInteractorStyle().SetCurrentStyleToTrackballCamera()

interactor_trackball = rw_interactor.GetInteractorStyle()
interactor_selection = vtkInteractorStyleRubberBandPick()
area_picker = vtkRenderedAreaPicker()
rw_interactor.SetPicker(area_picker)

surface_filter = vtkGeometryFilter()
surface_filter.SetInputConnection(reader.GetOutputPort())
surface_filter.SetPassThroughPointIds(True)

mapper = vtkDataSetMapper()
mapper.SetInputConnection(surface_filter.GetOutputPort())
actor = vtkActor()
actor.GetProperty().SetOpacity(0.5)
actor.SetMapper(mapper)

# Selection
selection_extract = vtkExtractSelection()
selection_mapper = vtkDataSetMapper()
selection_mapper.SetInputConnection(selection_extract.GetOutputPort())
selection_actor = vtkActor()
selection_actor.GetProperty().SetColor(1, 0, 1)
selection_actor.GetProperty().SetPointSize(5)
selection_actor.SetMapper(selection_mapper)
selection_actor.SetVisibility(0)

renderer.AddActor(actor)
renderer.AddActor(selection_actor)
renderer.ResetCamera()

selector = vtkHardwareSelector()
selector.SetRenderer(renderer)
selector.SetFieldAssociation(vtkDataObject.FIELD_ASSOCIATION_POINTS)

# vtkDataSet to DataFrame
py_ds = dsa.WrapDataObject(dataset)
pt_data = py_ds.PointData
cols = {}
for name in pt_data.keys():
    array = pt_data[name]
    shp = array.shape
    if len(shp) == 1:
        cols[name] = array
    else:
        for i in range(shp[1]):
            cols[name + "_%d" % i] = array[:, i]
DATAFRAME = pd.DataFrame(cols)
FIELD_NAMES = list(cols.keys())
SELECTED_IDX = []

# -----------------------------------------------------------------------------
# Callbacks
# -----------------------------------------------------------------------------

@state.change("figure_size", "scatter_x", "scatter_y")
def update_figure(figure_size, scatter_x, scatter_y, **kwargs):
    if figure_size is None:
        return

    # Generate figure
    bounds = figure_size.get("size", {})
    fig = px.scatter(
        DATAFRAME,
        x=scatter_x,
        y=scatter_y,
        width=bounds.get("width", 200),
        height=bounds.get("height", 200),
    )

    # Update selection settings
    fig.data[0].update(
        selectedpoints=SELECTED_IDX,
        selected={"marker": {"color": "red"}},
        unselected={"marker": {"opacity": 0.5}},
    )

    # Update chart
    ctrl.update_figure(fig)

# -----------------------------------------------------------------------------

@state.change("vtk_selection")
def update_interactor(vtk_selection, **kwargs):
    if vtk_selection:
        # remote view
        rw_interactor.SetInteractorStyle(interactor_selection)
        interactor_selection.StartSelect()
        # local view
        state.interactorSettings = VIEW_SELECT
    else:
        # remote view
        rw_interactor.SetInteractorStyle(interactor_trackball)
        # local view
        state.interactorSettings = VIEW_INTERACT

# -----------------------------------------------------------------------------

def on_chart_selection(selected_point_idxs):
    global SELECTED_IDX
    SELECTED_IDX = selected_point_idxs if selected_point_idxs else []
    npts = len(SELECTED_IDX)

    ids = vtkIdTypeArray()
    ids.SetNumberOfTuples(npts)
    # Copy the plot point indices into the VTK id array (enumerate advances idx)
    for idx, p_id in enumerate(SELECTED_IDX):
        ids.SetTuple1(idx, p_id)

    sel_node = vtkSelectionNode()
    sel_node.GetProperties().Set(
        vtkSelectionNode.CONTENT_TYPE(), vtkSelectionNode.INDICES
    )
    sel_node.GetProperties().Set(vtkSelectionNode.FIELD_TYPE(), vtkSelectionNode.POINT)
    sel_node.SetSelectionList(ids)
    sel = vtkSelection()
    sel.AddNode(sel_node)

    selection_extract.SetInputDataObject(0, py_ds.VTKObject)
    selection_extract.SetInputDataObject(1, sel)
    selection_extract.Update()
    selection_actor.SetVisibility(1)

    # Update 3D view
    ctrl.view_update()

def on_box_selection_change(selection):
    global SELECTED_IDX
    if selection.get("mode") == "remote":
        actor.GetProperty().SetOpacity(1)
        selector.SetArea(
            int(renderer.GetPickX1()),
            int(renderer.GetPickY1()),
            int(renderer.GetPickX2()),
            int(renderer.GetPickY2()),
        )
    elif selection.get("mode") == "local":
        camera = renderer.GetActiveCamera()
        camera_props = selection.get("camera")

        # Sync client view to server one
        camera.SetPosition(camera_props.get("position"))
        camera.SetFocalPoint(camera_props.get("focalPoint"))
        camera.SetViewUp(camera_props.get("viewUp"))
        camera.SetParallelProjection(camera_props.get("parallelProjection"))
        camera.SetParallelScale(camera_props.get("parallelScale"))
        camera.SetViewAngle(camera_props.get("viewAngle"))
        render_window.SetSize(selection.get("size"))

        actor.GetProperty().SetOpacity(1)
        render_window.Render()

        area = selection.get("selection")
        selector.SetArea(
            int(area[0]),
            int(area[2]),
            int(area[1]),
            int(area[3]),
        )

    # Common server selection
    s = selector.Select()
    n = s.GetNode(0)
    ids = dsa.vtkDataArrayToVTKArray(n.GetSelectionData().GetArray("SelectedIds"))
    surface = dsa.WrapDataObject(surface_filter.GetOutput())
    SELECTED_IDX = surface.PointData["vtkOriginalPointIds"][ids].tolist()

    selection_extract.SetInputConnection(surface_filter.GetOutputPort())
    selection_extract.SetInputDataObject(1, s)
    selection_extract.Update()
    selection_actor.SetVisibility(1)
    actor.GetProperty().SetOpacity(0.5)

    # Update scatter plot with selection
    update_figure(**state.to_dict())

    # Update 3D view
    ctrl.view_update()

    # disable selection mode
    state.vtk_selection = False

# -----------------------------------------------------------------------------
# Settings
# -----------------------------------------------------------------------------

DROPDOWN_STYLES = {
    "dense": True,
    "hide_details": True,
    "classes": "px-2",
    "style": "max-width: calc(25vw - 10px);",
}

CHART_STYLE = {
    "style": "position: absolute; left: 50%; transform: translateX(-50%);",
    "display_mode_bar": ("true",),
    "mode_bar_buttons_to_remove": (
        "chart_buttons",
        [
            "toImage",
            "resetScale2d",
            "zoomIn2d",
            "zoomOut2d",
            "toggleSpikelines",
            "hoverClosestCartesian",
            "hoverCompareCartesian",
        ],
    ),
    "display_logo": ("false",),
}

VTK_VIEW_SETTINGS = {
    "interactive_ratio": 1,
    "interactive_quality": 80,
}

VIEW_INTERACT = [
    {"button": 1, "action": "Rotate"},
    {"button": 2, "action": "Pan"},
    {"button": 3, "action": "Zoom", "scrollEnabled": True},
    {"button": 1, "action": "Pan", "alt": True},
    {"button": 1, "action": "Zoom", "control": True},
    {"button": 1, "action": "Pan", "shift": True},
    {"button": 1, "action": "Roll", "alt": True, "shift": True},
]

VIEW_SELECT = [{"button": 1, "action": "Select"}]

# -----------------------------------------------------------------------------
# UI
# -----------------------------------------------------------------------------

state.trame__title = "VTK selection"
ctrl.on_server_ready.add(ctrl.view_update)

with SinglePageWithDrawerLayout(server) as layout:
    layout.title.set_text("VTK & plotly")

    with layout.toolbar as tb:
        tb.dense = True
        vuetify.VSpacer()
        vuetify.VSelect(
            v_model=("scatter_y", FIELD_NAMES[1]),
            items=("fields", FIELD_NAMES),
            **DROPDOWN_STYLES,
        )
        vuetify.VSelect(
            v_model=("scatter_x", FIELD_NAMES[0]),
            items=("fields", FIELD_NAMES),
            **DROPDOWN_STYLES,
        )

    with layout.content:
        with vuetify.VContainer(fluid=True, classes="fill-height pa-0 ma-0"):
            with vuetify.VRow(dense=True, style="height: 100%;"):
                with vuetify.VCol(
                    classes="pa-0",
                    style="border-right: 1px solid #ccc; position: relative;",
                ):
                    view = vtk_widgets.VtkRemoteView(
                        # view = vtk_widgets.VtkLocalView(
                        render_window,
                        box_selection=("vtk_selection",),
                        box_selection_change=(on_box_selection_change, "[$event]"),
                        # For VtkRemoteView
                        **VTK_VIEW_SETTINGS,
                        # For VtkLocalView
                        interactor_settings=("interactorSettings", VIEW_SELECT),
                    )
                    ctrl.view_update = view.update
                    ctrl.view_reset_camera = view.reset_camera
                    vuetify.VCheckbox(
                        small=True,
                        on_icon="mdi-selection-drag",
                        off_icon="mdi-rotate-3d",
                        v_model=("vtk_selection", False),
                        style="position: absolute; top: 0; right: 0; z-index: 1;",
                        dense=True,
                        hide_details=True,
                    )
                with vuetify.VCol(classes="pa-0"):
                    with trame.SizeObserver("figure_size"):
                        html_plot = plotly.Figure(
                            selected=(
                                on_chart_selection,
                                "[$event?.points.map(({pointIndex}) => pointIndex)]",
                            ),
                            **CHART_STYLE,
                        )
                        ctrl.update_figure = html_plot.update

# -----------------------------------------------------------------------------
# Main
# -----------------------------------------------------------------------------

if __name__ == "__main__":
    server.start()

Expected behavior

I expect the plotly figure not to reset when the drawer is opened or closed.

cardinalgeo commented 1 year ago

Note that the same reset also happens when the app is refreshed. By contrast, the VTK view does not reset in either situation (opening/closing the drawer or refreshing the app).

jourdain commented 1 year ago

This is not a trame bug, this is a bug in your application.

Let me explain.

The issue occurs because you are using a SizeObserver, which triggers a regeneration of the chart (Python side) on every size change. Since you are not tracking the zoom (using events so the server knows about it), each regeneration resets the chart and therefore the zoom.

So you have two possible solutions.

First one: fix the size (mainly the height, since the width can stay dynamic at 100%) and do not use the size observer.

Second one: keep track of the zoomed area by listening to the chart's events, and make sure that whenever you update the chart from the server, you take that information into account.

HTH,

Seb
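A minimal sketch of the second approach (not taken from the thread), written in plain Python so the zoom bookkeeping can be tested without a running trame server. It relies on Plotly's `relayout` event, which reports a box zoom or pan as `xaxis.range[0]`/`xaxis.range[1]` keys and an autoscale or double-click reset as `xaxis.autorange: true`. How that event reaches the server is an assumption: something like `relayout=(on_relayout, "[$event]")` on the `plotly.Figure` widget, mirroring the `selected=` wiring in the example, but check the trame plotly widget documentation for the event names it actually exposes. `ZoomState`, `merge_relayout` and `layout_overrides` are hypothetical names.

```python
from typing import Any, Dict, List, Optional

# Hypothetical server-side memory of the zoom: axis name -> [min, max],
# or None when the axis should autorange to its full extent.
ZoomState = Dict[str, Optional[List[float]]]

AXES = ("xaxis", "yaxis")


def merge_relayout(zoom: ZoomState, event: Dict[str, Any]) -> ZoomState:
    """Fold one Plotly relayout payload into the stored zoom state.

    A box zoom or pan arrives as {"xaxis.range[0]": a, "xaxis.range[1]": b, ...};
    an autoscale / double-click arrives as {"xaxis.autorange": True, ...}.
    Keys for unrelated layout properties are ignored. Returns a new dict.
    """
    updated = dict(zoom)
    for axis in AXES:
        lo_key, hi_key = f"{axis}.range[0]", f"{axis}.range[1]"
        if event.get(f"{axis}.autorange"):
            updated[axis] = None  # back to the full extent
        elif lo_key in event and hi_key in event:
            updated[axis] = [event[lo_key], event[hi_key]]
        elif f"{axis}.range" in event:
            updated[axis] = list(event[f"{axis}.range"])
    return updated


def layout_overrides(zoom: ZoomState) -> Dict[str, Any]:
    """Keyword arguments for fig.update_layout() that re-apply the stored zoom."""
    return {
        axis: {"range": rng, "autorange": False}
        for axis, rng in zoom.items()
        if rng is not None
    }


# Example: the user zooms in, then double-clicks to reset.
zoom: ZoomState = {axis: None for axis in AXES}
zoom = merge_relayout(zoom, {"xaxis.range[0]": 0.5, "xaxis.range[1]": 2.0,
                             "yaxis.range[0]": -1.0, "yaxis.range[1]": 1.0})
print(layout_overrides(zoom))
zoom = merge_relayout(zoom, {"xaxis.autorange": True, "yaxis.autorange": True})
print(layout_overrides(zoom))
```

In the example app, a hypothetical `on_relayout(event)` callback would store `merge_relayout(ZOOM, event)` in a module-level `ZOOM` dict, and `update_figure` would call `fig.update_layout(**layout_overrides(ZOOM))` just before `ctrl.update_figure(fig)`, so a size change regenerates the figure with the current ranges instead of the defaults. Clearing `ZOOM` when `scatter_x` or `scatter_y` changes is probably wise, since ranges chosen for one field rarely make sense for another. Because the ranges are kept on the server, this should also preserve the zoom across a page refresh, which the follow-up comment reports as resetting too.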