narwhals-dev / narwhals

Lightweight and extensible compatibility layer between dataframe libraries!
https://narwhals-dev.github.io/narwhals/
MIT License

[Bug]: `lit` does not "broadcast" as expected #853

Status: Open. FBruzzesi opened this issue 1 month ago.

FBruzzesi commented 1 month ago

Describe the bug

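Narwhals' implementation of `lit` breaks when the literal is the leftmost expression in a binary operation (for example `nw.lit(0) + nw.col("a")`) on the pandas and PyArrow backends.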

I found this while implementing a fold that starts from a `lit` initial value, which puts the literal on the left of the first addition.
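The fold pattern can be sketched as follows. This is a minimal pure-Python model, not narwhals code: `lit`, `col`, and `add` are hypothetical stand-ins that evaluate against a dict of lists, and a literal is assumed to evaluate to a length-1 column until it is broadcast. Folding with `functools.reduce` from a literal initial value makes that literal the leftmost operand; the hypothetical `add` helper broadcasts a length-1 operand on either side, which is the behaviour this issue asks for.

```python
from functools import reduce
from typing import Callable, Dict, List

# Hypothetical mini expression model (NOT the narwhals API): an expression is a
# function that takes a frame (column name -> values) and returns a column.
Frame = Dict[str, List[int]]
Expr = Callable[[Frame], List[int]]


def lit(value: int) -> Expr:
    # Assumption: a literal evaluates to a length-1 column until it is broadcast.
    return lambda frame: [value]


def col(name: str) -> Expr:
    return lambda frame: frame[name]


def add(left: Expr, right: Expr) -> Expr:
    def evaluate(frame: Frame) -> List[int]:
        lhs, rhs = left(frame), right(frame)
        # Broadcast a length-1 operand on EITHER side, not only on the right.
        if len(lhs) == 1:
            lhs = lhs * len(rhs)
        if len(rhs) == 1:
            rhs = rhs * len(lhs)
        if len(lhs) != len(rhs):
            raise ValueError(f"length mismatch: {len(lhs)} vs {len(rhs)}")
        return [x + y for x, y in zip(lhs, rhs)]

    return evaluate


# A fold starting from a literal: reduce builds
# add(add(lit(0), col("a")), col("b")), so the literal is the leftmost operand.
total = reduce(add, [col("a"), col("b")], lit(0))

frame = {"a": [1, 2, 3], "b": [10, 20, 30]}
print(total(frame))  # [11, 22, 33]
```

If `add` broadcast only `rhs`, the first step would pair a length-1 `lhs` with a length-3 `rhs` and raise, which is the same shape of failure the pandas and PyArrow backends show in the reproduction.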

Steps or code to reproduce the bug

import narwhals as nw
import pandas as pd
import polars as pl
import pyarrow as pa

data = {"a": [1, 2, 3]}

df_pd = pd.DataFrame(data)
df_pl = pl.DataFrame(data)
df_pa = pa.table(data)

@nw.narwhalify
def sum_on_the_right(df):
    return df.select(nw.col("a") + nw.lit(0))

sum_on_the_right(df_pl), sum_on_the_right(df_pd), sum_on_the_right(df_pa)
# All well and good: this works on all three backends

@nw.narwhalify
def sum_on_the_left(df):
    return df.select(nw.lit(0) + nw.col("a"))

sum_on_the_left(df_pl)  # polars has no issues

sum_on_the_left(df_pd)
# ValueError: Length mismatch: Expected axis has 3 elements, new values have 1 elements

sum_on_the_left(df_pa)
# ArrowInvalid: Array arguments must all be the same length

Expected results

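The literal should broadcast to the length of the other operand regardless of which side of the operator it is on, as it already does with Polars.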
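As a point of reference (an illustration, not narwhals code), NumPy applies its broadcasting rule symmetrically: a length-1 array is stretched to match the other operand whichever side of `+` it sits on. The snippet below uses a length-1 array as an assumed analogue of a literal before broadcasting.

```python
import numpy as np

a = np.array([1, 2, 3])
zero = np.array([0])  # length-1 array, analogous to a literal before broadcasting

left = zero + a   # "literal" on the left
right = a + zero  # "literal" on the right

# NumPy broadcasts the length-1 operand on either side, so both results
# have the length of `a` and are equal.
print(left, right)  # [1 2 3] [1 2 3]
```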

Actual results

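Broadcasting fails when `nw.lit` is the left operand: pandas raises `ValueError: Length mismatch` and PyArrow raises `ArrowInvalid: Array arguments must all be the same length`, as shown in the reproduction.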

Please run narwhals.show_versions() and enter the output below.

System:
    python: 3.10.12 (main, Jul 29 2024, 16:56:48) [GCC 11.4.0]
   machine: Linux-5.15.133.1-microsoft-standard-WSL2-x86_64-with-glibc2.35

Python dependencies:
     narwhals: 1.5.3
       pandas: 2.2.2
       polars: 1.5.0
         cudf: 
        modin: 0.31.0
      pyarrow: 17.0.0
        numpy: 2.0.1

Relevant log output

No response

FBruzzesi commented 1 month ago

Reverted in #858