Open wence- opened 4 months ago
If you are grouping by multiple keys, then sorting each key individually does not give you a sorted keyset.
Consider the following data:
key    |
A    B | data
-------+-----
1    1 | 0.1
1    2 | 0.8
2    1 | 0.3
2    2 | 0.2
The key as a whole is sorted, but B on its own is not. If you were to sort columns A and B separately, you would distort your data and end up with:
key    |
A    B | data
-------+-----
1    1 | 0.1
1    1 | 0.8  <-- Not same corresponding data value
2    2 | 0.3  <-- Not same corresponding data value
2    2 | 0.2
I admit the example is a bit contrived, but just suppose both keys are sorted "somehow". You can compute the breaks in linear time with a scan, and don't need to hash.
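A minimal sketch of that scan (plain Python, not polars internals; the helper name is made up): when every key column is sorted, equal composite keys sit in contiguous runs, so the breaks are exactly the positions where at least one key column changes value.

def group_breaks(*key_columns):
    # One linear pass; no hashing. Assumes every column is already sorted.
    n = len(key_columns[0])
    breaks = [0]
    for i in range(1, n):
        if any(col[i] != col[i - 1] for col in key_columns):
            breaks.append(i)
    breaks.append(n)
    return breaks  # consecutive entries delimit each group as a slice

print(group_breaks([1, 1, 2, 2], [1, 1, 2, 2]))  # [0, 2, 4]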
That would work in theory, but polars doesn't keep track of multi-column sorts; sorted flags are on a per-column basis. If I recall correctly, they considered it at one point but things start getting really complicated when you try to propagate multi-column flags.
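For reference, those per-column flags are visible from Python. A small sketch (assuming the py-polars Series.flags and Expr.set_sorted APIs; note that set_sorted asserts sortedness rather than checking it):

import polars as pl

s = pl.Series("a", [3, 1, 2])
print(s.flags)         # {'SORTED_ASC': False, 'SORTED_DESC': False}
print(s.sort().flags)  # {'SORTED_ASC': True, 'SORTED_DESC': False}

# set_sorted only makes sense on data you already know is sorted.
df = pl.DataFrame({"key1": [1, 1, 2], "value": [10, 20, 30]})
df = df.with_columns(pl.col("key1").set_sorted())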
> That would work in theory, but polars doesn't keep track of multi-column sorts; sorted flags are on a per-column basis. If I recall correctly, they considered it at one point but things start getting really complicated when you try to propagate multi-column flags.
AFAIK DataFusion does it; it might be worth seeing how they do it.
Also, I'm not sure there's any guarantee that the fast slice path you want returns the groups in maintained order anyway, since the groups are computed in separate threads. You'll need to debug to figure that out.
FYI you can use a struct here to probably get the fast path you want, at the expense of constructing the key column:

df = df.with_columns(pl.struct("key1", "key2").alias("key"))
print(df.group_by("key").agg(pl.col("value")).collect())
> That would work in theory, but polars doesn't keep track of multi-column sorts; sorted flags are on a per-column basis. If I recall correctly, they considered it at one point but things start getting really complicated when you try to propagate multi-column flags.
You don't need to maintain multi-column sortedness. If you know all your key columns are sorted, that is enough information to compute the breaks with a scan.
If A is sorted and B is sorted, then (A, B) is probably not sorted, as per the example above. If it's not sorted, then your groups are not contiguous and we cannot reference each group as a slice.
> If A is sorted and B is sorted, then (A, B) is probably not sorted, as per the example above.
I don't follow this at all. Suppose A is sorted in ascending order and B is sorted in descending order. Then for the pair of columns (A, B), for every i, it holds:

A[i] <= A[i+1] && B[i] >= B[i+1]

which sounds sorted to me.
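A quick check of that claim with concrete (made-up) arrays: if every adjacent pair satisfies the condition, both columns are monotone, so a given (A, B) pair can only occur in one contiguous run.

# A ascending, B descending, as in the claim above.
A = [1, 1, 2, 2, 3]
B = [9, 7, 7, 4, 4]
assert all(A[i] <= A[i + 1] and B[i] >= B[i + 1] for i in range(len(A) - 1))
# A pair that stops appearing never reappears later, so each group of
# equal (A, B) pairs occupies a single contiguous slice.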
> If it's not sorted, then your groups are not contiguous and we cannot reference each group as a slice.
I think my initial example was misleading. Ignore the data (or pretend it's random): in particular, note that in the example I'm not trying to maintain the association of the value column with the key columns, though if you want, pretend I did this:
# Deliberately reordering the frame, but sorting each key column.
key1, key2, value = pl.col("key1"), pl.col("key2"), pl.col("value")
reordered = df.select(key1.sort(), key2.sort(), value)
q = reordered.group_by(key1, key2).agg(value)
I claim that both keys are sorted, and therefore, even though there's more than one key, one can still utilise that information: the slice breakpoints are the union of the breakpoints for key1 and key2.
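A sketch of that union (plain Python with made-up helper names, not a polars API): take the change positions of each sorted key column separately; their union, plus the two endpoints, delimits the groups.

def breaks_of(col):
    # Change positions of a single sorted column.
    return {i for i in range(1, len(col)) if col[i] != col[i - 1]}

def group_slices(*key_columns):
    n = len(key_columns[0])
    cuts = sorted(set().union(*(breaks_of(c) for c in key_columns)) | {0, n})
    return [slice(lo, hi) for lo, hi in zip(cuts, cuts[1:])]

key1 = [1, 1, 1, 2, 2]  # sorted
key2 = [1, 1, 2, 2, 2]  # sorted
print(group_slices(key1, key2))  # [slice(0, 2), slice(2, 3), slice(3, 5)]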
Sorry, you're right: if all of the key columns are already sorted, then the union of the breaks does indeed define the breaks between the groups.
This is a pretty special case though, as getting there usually requires sorting the columns independently of each other, which means breaking any relationship between them. For example:
A B C
-------
1 1 a
1 2 b
1 1 c
1 2 d
Sorting A and B separately above essentially breaks any relationship between them, and there is no way to sort them individually while maintaining that relationship.
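For instance, a small polars reproduction of that effect (column values taken from the table above):

import polars as pl

df = pl.DataFrame({"A": [1, 1, 1, 1], "B": [1, 2, 1, 2], "C": ["a", "b", "c", "d"]})
# Each key column ends up sorted, but row i of B no longer belongs with
# row i of C: "b" was originally paired with B == 2, yet now sits next to B == 1.
print(df.select(pl.col("A").sort(), pl.col("B").sort(), pl.col("C")))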
It is possible to have a fast path for when the group keys are all sorted, and it would definitely help with performance.
Description

Consider:

import polars as pl

df = pl.LazyFrame(
    {
        "key1": [1, 1, 1, 2, 3, 1, 4, 6, 7],
        "key2": [2, 2, 2, 2, 6, 1, 4, 6, 8],
        "value": [1, 2, 3, 4, 5, 6, 7, 8, 9],
    }
)
q = df.group_by(pl.col("key1").sort(), pl.col("key2").sort(), maintain_order=True).agg(pl.col("value"))
q.collect()

I was expecting this to come out in sorted order; however, it does not.
If I only group on a single key, then the pre-sorting appears to be utilised.
Is this not done because finding the segment boundaries is more expensive in the multi-key case?