pandas-dev / pandas

Flexible and powerful data analysis / manipulation library for Python, providing labeled data structures similar to R data.frame objects, statistical functions, and much more
https://pandas.pydata.org
BSD 3-Clause "New" or "Revised" License

BUG: Unexpected cast to float for `DataFrame.groupby().agg(engine="numba")` #58869

Open willsthompson opened 4 months ago

willsthompson commented 4 months ago

Reproducible Example

import numpy as np
import pandas as pd
import numba as nb
from numba import njit

# Explicit signature: takes the group's values and index arrays, returns an int64.
@njit(nb.int64(nb.int64[:], nb.int64[:]))
def return_one(values, index):
    return np.int64(1)

df = pd.DataFrame({"group": [1, 1, 2, 2, 2], "y": [1, 1, 2, 3, 2]})

# Call the function directly, through the default engine, and through the numba engine.
noagg = return_one(df["y"].values, df.index.values)
aggs1 = df.groupby("group").agg(lambda s: return_one(s.values, s.index.values))
aggs2 = df.groupby("group").agg(return_one, engine="numba")

print(type(noagg))  # int
print(aggs1["y"].dtype)  # np.int64
print(aggs2["y"].dtype)  # np.float64 (unexpected; np.int64 expected)

Issue Description

When executing the numba-compiled function standalone or with groupby.agg() without the numba engine, ints are returned as expected. However, when using Pandas' numba-compiled agg function, ints are coerced into floats.

As best I can tell, the issue is at https://github.com/pandas-dev/pandas/blob/b162331554d7c7f6fd46ddde1ff3908f2dc8bcce/pandas/core/groupby/numba_.py#L114, which appears to always initialize the result array as float. When I add dtype=values.dtype to that constructor, it works as expected.
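
The effect can be reproduced with plain NumPy, without pandas or numba. This is a minimal sketch, assuming the result buffer is allocated with np.empty and no explicit dtype; the names fill_results, begin, and end are hypothetical and this is not the actual pandas code:

```python
import numpy as np

def fill_results(values, begin, end, dtype=None):
    """Hypothetical stand-in for a per-group aggregation loop."""
    # np.empty falls back to float64 when no dtype is given.
    result = np.empty(len(begin), dtype=dtype)
    for i in range(len(begin)):
        # Each integer result is cast to the buffer's dtype on assignment.
        result[i] = values[begin[i]:end[i]].max()
    return result

values = np.array([1, 1, 2, 3, 2], dtype=np.int64)
begin = np.array([0, 2])  # start offset of each group
end = np.array([2, 5])    # end offset (exclusive) of each group

default_result = fill_results(values, begin, end)
typed_result = fill_results(values, begin, end, dtype=values.dtype)

print(default_result.dtype)  # float64
print(typed_result.dtype)    # int64
```

Passing dtype=values.dtype only preserves the input dtype; it would not help an aggregation whose result type differs from its input type.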

Expected Behavior

Ideally the numba engine would match the behavior of Pandas' default engine and use the dtype of the aggregate function's result, but I'm not sure that's possible. A more practical solution would be to accept an optional result_dtype argument that defaults to np.float64 (making that default explicit) and pass it to the result array constructor. It would also be helpful for this nuance to be documented.
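
To illustrate the proposal, here is a rough sketch of what an opt-in result_dtype could look like. grouped_apply is a hypothetical helper written in plain NumPy for illustration; neither it nor the result_dtype parameter exists in pandas:

```python
import numpy as np

def grouped_apply(values, group_ids, func, result_dtype=np.float64):
    """Hypothetical helper: apply func per group into a typed result buffer."""
    unique_ids = np.unique(group_ids)
    # The caller chooses the output dtype; float64 stays the explicit default.
    result = np.empty(len(unique_ids), dtype=result_dtype)
    for i, gid in enumerate(unique_ids):
        result[i] = func(values[group_ids == gid])
    return unique_ids, result

group = np.array([1, 1, 2, 2, 2])
y = np.array([1, 1, 2, 3, 2], dtype=np.int64)

_, as_float = grouped_apply(y, group, lambda v: 1)
_, as_int = grouped_apply(y, group, lambda v: 1, result_dtype=np.int64)

print(as_float.dtype)  # float64
print(as_int.dtype)    # int64
```

Keeping float64 as the default would leave existing callers unaffected, while callers who know their aggregation returns integers could opt in.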

Installed Versions

INSTALLED VERSIONS
------------------
commit : 2e218d10984e9919f0296931d92ea851c6a6faf5
python : 3.11.9.final.0
python-bits : 64
OS : Darwin
OS-release : 23.4.0
Version : Darwin Kernel Version 23.4.0: Fri Mar 15 00:10:42 PDT 2024; root:xnu-10063.101.17~1/RELEASE_ARM64_T6000
machine : arm64
processor : arm
byteorder : little
LC_ALL : None
LANG : None
LOCALE : en_US.UTF-8

pandas : 1.5.3
numpy : 1.26.4
pytz : 2024.1
dateutil : 2.8.2
setuptools : 65.5.0
pip : 24.0
Cython : 3.0.10
pytest : 7.2.2
hypothesis : None
sphinx : None
blosc : None
feather : None
xlsxwriter : None
lxml.etree : 5.2.1
html5lib : None
pymysql : None
psycopg2 : 2.9.3
jinja2 : 3.1.3
IPython : None
pandas_datareader: None
bs4 : 4.12.3
bottleneck : None
brotli : None
fastparquet : None
fsspec : 2024.3.1
gcsfs : None
matplotlib : 3.8.4
numba : 0.59.1
numexpr : None
odfpy : None
openpyxl : None
pandas_gbq : None
pyarrow : 16.0.0
pyreadstat : None
pyxlsb : None
s3fs : None
scipy : 1.13.0
snappy : None
sqlalchemy : 2.0.30
tables : None
tabulate : 0.9.0
xarray : None
xlrd : None
xlwt : None
zstandard : None
tzdata : 2024.1

rhshadrach commented 4 months ago

https://github.com/pandas-dev/pandas/pull/35759#discussion_r473598046

cc @mroeschke