Open jackaixin opened 1 month ago
Yeah, that does not look right:
import polars as pl

df = pl.DataFrame({
    'foo': [[1, 2, 3, 4]]
})
df.select(
    pl.col.foo.list.eval(pl.element().slice(0, 2)).alias('x'),
    pl.col.foo.list.eval(pl.element().slice(2, 2)).alias('y'),
    pl.col.foo.list.eval(
        pl.element().slice(0, 2) + pl.element().slice(2, 2)
    ).alias('x + y')
)
# shape: (1, 3)
# ┌───────────┬───────────┬───────────┐
# │ x         ┆ y         ┆ x + y     │
# │ ---       ┆ ---       ┆ ---       │
# │ list[i64] ┆ list[i64] ┆ list[i64] │
# ╞═══════════╪═══════════╪═══════════╡
# │ [1, 2]    ┆ [3, 4]    ┆ [2, 4]    │ # <- ERROR: x + x?
# └───────────┴───────────┴───────────┘
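For reference, here is a minimal plain-Python sketch of what the element-wise sum would be expected to give for this row. It does not use Polars; the list slices only mirror the two `slice` calls above, and the names `expected_sum` and `x_plus_x` are made up for illustration.

```python
# Plain-Python reference for the expected values; no Polars involved.
foo = [1, 2, 3, 4]

x = foo[0:2]  # mirrors pl.element().slice(0, 2): offset 0, length 2
y = foo[2:4]  # mirrors pl.element().slice(2, 2): offset 2, length 2

# Element-wise x + y, which is what the 'x + y' column is expected to hold.
expected_sum = [a + b for a, b in zip(x, y)]

# Element-wise x + x, which is what the reported output matches instead.
x_plus_x = [a + a for a in x]

print(expected_sum)  # [4, 6]
print(x_plus_x)      # [2, 4]
```

The reported `[2, 4]` equals `x + x` rather than `x + y`, which is consistent with the second slice being ignored inside `list.eval`.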
Also, please kindly suggest the best way to perform dot product in my example above.
You can use explode and group_by:
df = pl.DataFrame({
    'values': [[0], [0, 2], [0, 2, 4], [2, 4, 0], [4, 0, 8]],
    'weights': [[3], [2, 3], [1, 2, 3], [1, 2, 3], [1, 2, 3]]
})
(
    df
    .lazy()
    .with_row_index()
    .explode('values', 'weights')
    .group_by('index', maintain_order=True)
    .agg(
        'values',
        'weights',
        (pl.col.values * pl.col.weights).sum().alias('dot')
    )
    .drop('index')
    .collect()
)
shape: (5, 3)
┌───────────┬───────────┬─────┐
│ values    ┆ weights   ┆ dot │
│ ---       ┆ ---       ┆ --- │
│ list[i64] ┆ list[i64] ┆ i64 │
╞═══════════╪═══════════╪═════╡
│ [0]       ┆ [3]       ┆ 0   │
│ [0, 2]    ┆ [2, 3]    ┆ 6   │
│ [0, 2, 4] ┆ [1, 2, 3] ┆ 16  │
│ [2, 4, 0] ┆ [1, 2, 3] ┆ 10  │
│ [4, 0, 8] ┆ [1, 2, 3] ┆ 28  │
└───────────┴───────────┴─────┘
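To cross-check the dot column by hand, the same explode-then-aggregate idea can be sketched in plain Python. This is only an illustration of the logic, not how Polars executes the query; the names `exploded` and `dot` are chosen here for the sketch.

```python
values = [[0], [0, 2], [0, 2, 4], [2, 4, 0], [4, 0, 8]]
weights = [[3], [2, 3], [1, 2, 3], [1, 2, 3], [1, 2, 3]]

# "explode": one (row index, value, weight) record per list element.
exploded = [
    (idx, v, w)
    for idx, (vs, ws) in enumerate(zip(values, weights))
    for v, w in zip(vs, ws)
]

# "group_by + agg": sum value * weight per original row index.
dot = [0] * len(values)
for idx, v, w in exploded:
    dot[idx] += v * w

print(dot)  # [0, 6, 16, 10, 28]
```

The row index plays the same role as the column added by with_row_index: it is the key that lets the flattened elements be folded back into one result per original row.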
@ruoyu0088 thanks for this. I tried another version with explode:
q2 = (
    df2
    .lazy()
    .with_row_index()
    .select(
        'values',
        'weights',
        (pl.col('values').explode() * pl.col('weights').explode()).sum().over('index').alias('dot')
    )
)
q2.collect()
which returns the same result. Performance is similar on the small example above (mine is slightly faster).
However, when I applied your version (df.explode.group_by.agg) and mine (explode + over) to a larger dataframe (~3m rows), yours is 2x faster than mine. Do you have an idea why that might be the case?
Another comment on your explode implementation is that it seems to consume much more memory than the list.eval version, although the explode.group_by version is faster than list.eval.
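One plausible contributor to the memory difference, not confirmed anywhere in this thread, is that explode produces a long frame with one row per list element, and each of those rows also carries a copy of the row index. The sketch below only counts rows and cells under that assumption. `exploded_cells` is a hypothetical helper, and the actual memory behaviour of the lazy engine may differ.

```python
def exploded_cells(list_lengths):
    """Scalar cells in the long frame, assuming three columns
    (index, values, weights) are materialized per list element."""
    n_elements = sum(list_lengths)
    return 3 * n_elements


# List lengths of the five rows in the example frame above.
lengths = [1, 2, 3, 3, 3]

rows_before = len(lengths)   # rows in the original frame
rows_after = sum(lengths)    # rows after explode

print(rows_before)               # 5
print(rows_after)                # 12
print(exploded_cells(lengths))   # 36
```

With longer lists the long frame grows proportionally, whereas an approach that works inside each list does not need the repeated index column.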
Checks
Reproducible example
Log output
No response
Issue description
I was trying to get the dot product of values and weights, and would like to use functions in the list namespace. I haven't found any built-in list.dot, so I ended up using list.eval in the hacky way above. But the code above was returning:
We see that values1 and values2 are what I expected from the pl.element().slice operations, but dot and sum seem to operate on the first slice itself instead of computing first_slice.dot(second_slice) or first_slice + second_slice.
Expected behavior
I expect the dot column to be exactly the same as the dot2 column, and the sum column to be the same as sum2.
Installed versions