plotly / plotly.py

The interactive graphing library for Python :sparkles: This project now includes Plotly Express!
https://plotly.com/python/
MIT License

non-pandas (e.g. Polars) plot raises if an argument contains a mix of column names and Series #4841

Open MarcoGorelli opened 4 days ago

MarcoGorelli commented 4 days ago
```python
import polars as pl
import plotly.express as px

vendors = ["A", "B", "C", "D", "E", "F", "G", "H"]
sectors = [
    "Tech",
    "Tech",
    "Finance",
    "Finance",
    "Tech",
    "Tech",
    "Finance",
    "Finance",
]
regions = ["North", "North", "North", "North", "South", "South", "South", "South"]
values = [1, 3, 2, 4, 2, 2, 1, 4]
total = ["total"] * 8
data = pl.DataFrame(
    dict(
        vendors=vendors,
        sectors=sectors,
        regions=regions,
        values=values,
        total=total,
    )
)
# `path` mixes a Polars Series with plain column names
path = [data["total"], "regions", "sectors", "vendors"]
fig = px.sunburst(data, path=path)  # raises TypeError
fig
```

This raises:

```
Traceback (most recent call last):
  File "/home/marcogorelli/scratch/.venv/lib/python3.12/site-packages/marimo/_runtime/executor.py", line 157, in execute_cell
    exec(cell.body, glbls)
  Cell marimo:///home/marcogorelli/scratch/untitled.py#cell=cell-0, line 28, in <module>
    fig = px.sunburst(data, path=path)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/marcogorelli/scratch/.venv/lib/python3.12/site-packages/plotly/express/_chart_types.py", line 1688, in sunburst
    return make_figure(
           ^^^^^^^^^^^^
  File "/home/marcogorelli/scratch/.venv/lib/python3.12/site-packages/plotly/express/_core.py", line 2117, in make_figure
    args = build_dataframe(args, constructor)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/marcogorelli/scratch/.venv/lib/python3.12/site-packages/plotly/express/_core.py", line 1455, in build_dataframe
    necessary_columns.update(i for i in args[field] if i in columns)
  File "/home/marcogorelli/scratch/.venv/lib/python3.12/site-packages/plotly/express/_core.py", line 1455, in <genexpr>
    necessary_columns.update(i for i in args[field] if i in columns)
                                                       ^^^^^^^^^^^^
  File "/home/marcogorelli/scratch/.venv/lib/python3.12/site-packages/pandas/core/indexes/base.py", line 5175, in __contains__
    hash(key)
TypeError: unhashable type: 'Series'
```
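The last frames show the mechanism: `build_dataframe` evaluates `i in columns` for every entry of `path`, and the pandas `Index.__contains__` calls `hash(key)` before looking the key up. A Series is unhashable, so the membership test raises instead of returning `False`. A minimal pandas-only sketch of that behaviour follows; here a pandas Series stands in for the Polars one, so the error message text differs, but the exception type is the same. `membership_raises` is a hypothetical helper for illustration only.

```python
import pandas as pd

# Column labels held in a pandas Index, as `columns` is in the traceback
columns = pd.Index(["total", "regions", "sectors", "vendors"])

# A path that mixes a Series with plain column names, as in the report
path = [pd.Series(["total"] * 8), "regions", "sectors", "vendors"]


def membership_raises(key, index):
    """Hypothetical helper: True if `key in index` raises TypeError."""
    try:
        _ = key in index  # Index.__contains__ hashes the key first
    except TypeError:
        return True
    return False


results = [membership_raises(p, columns) for p in path]
print(results)  # [True, False, False, False]
```

Only the Series entry triggers the error; the string entries are looked up normally.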

On the other hand, it works if:

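The error comes from this membership check in `build_dataframe` (`plotly/express/_core.py`), where `i in columns` is evaluated for every entry of `path`, including the Series: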

https://github.com/plotly/plotly.py/blob/72bacb569b5e571f3af9ada082e80dd948321d75/packages/python/plotly/plotly/express/_core.py#L1448-L1459
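One way to make such a check tolerate mixed input is to look up only entries that can be hashed and treat everything else (for example a Series passed directly) as data rather than as a column name. The sketch below is hypothetical: `is_hashable` and `collect_column_names` are illustrative names, not plotly's code, and this is not necessarily how the issue gets fixed upstream. It only covers the membership test, not the later handling of the Series values themselves.

```python
import pandas as pd


def is_hashable(obj):
    """Return True if hash(obj) succeeds.

    A try/except is used instead of isinstance(obj, collections.abc.Hashable)
    because pandas objects define __hash__ but raise TypeError when it is called.
    """
    try:
        hash(obj)
    except TypeError:
        return False
    return True


def collect_column_names(entries, columns):
    """Hypothetical helper: return the entries that name existing columns.

    Unhashable entries are skipped, so `in` never calls hash() on them.
    """
    return {e for e in entries if is_hashable(e) and e in columns}


columns = pd.Index(["total", "regions", "sectors", "vendors"])
path = [pd.Series(["total"] * 8), "regions", "sectors", "vendors"]
necessary = collect_column_names(path, columns)
print(sorted(necessary))  # ['regions', 'sectors', 'vendors']
```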

MarcoGorelli commented 4 days ago

I think #4790 would address this nicely: