snowflakedb / snowflake-ml-python

Apache License 2.0

Snowpark XGBRegressor Ignores Sample Weights, Producing Identical Predictions for Different Models #111

Open robertlessmore opened 1 month ago

robertlessmore commented 1 month ago
  1. What version of Python are you using?

Python 3.11.8 | packaged by Anaconda, Inc. | (main, Feb 26 2024, 21:34:05) [MSC v.1916 64 bit (AMD64)]

  2. What operating system and processor architecture are you using?

Windows-10-10.0.22631-SP0

  3. What are the component versions in the environment?

```
# Name Version Build Channel
_py-xgboost-mutex 2.0 cpu_0
abseil-cpp 20220623.0 h0e60522_0 conda-forge
absl-py 1.4.0 py311haa95532_0
aiobotocore 2.7.0 py311haa95532_0
aiohttp 3.9.3 py311h2bbff1b_0
aioitertools 0.7.1 pyhd3eb1b0_0
aiosignal 1.2.0 pyhd3eb1b0_0
alembic 1.8.1 py311haa95532_0
anyio 3.5.0 py311haa95532_0
appdirs 1.4.4 pyhd3eb1b0_0
argon2-cffi 21.3.0 pyhd3eb1b0_0
argon2-cffi-bindings 21.2.0 py311h2bbff1b_0
arrow-cpp 10.0.1 h9c18f36_4_cpu conda-forge
asn1crypto 1.5.1 py311haa95532_0
asttokens 2.0.5 pyhd3eb1b0_0
async-lru 2.0.4 py311haa95532_0
attrs 23.1.0 py311haa95532_0
aws-c-auth 0.6.19 h2bbff1b_0
aws-c-cal 0.5.20 h2bbff1b_0
aws-c-common 0.8.5 h2bbff1b_0
aws-c-compression 0.2.16 h2bbff1b_0
aws-c-event-stream 0.2.15 hd77b12b_0
aws-c-http 0.6.25 h2bbff1b_0
aws-c-io 0.13.10 h2bbff1b_0
aws-c-mqtt 0.7.13 h2bbff1b_0
aws-c-s3 0.1.51 h2bbff1b_0
aws-c-sdkutils 0.1.6 h2bbff1b_0
aws-checksums 0.1.13 h2bbff1b_0
aws-crt-cpp 0.18.16 hd77b12b_0
aws-sdk-cpp 1.9.379 h2768dcf_5 conda-forge
babel 2.11.0 py311haa95532_0
beautifulsoup4 4.12.2 py311haa95532_0
blas 1.0 mkl
bleach 4.1.0 pyhd3eb1b0_0
blinker 1.6.2 py311haa95532_0
boost-cpp 1.82.0 h59b6b97_2
botocore 1.31.64 py311haa95532_0
bottleneck 1.3.7 py311hd7041d2_0
brotli 1.0.9 h2bbff1b_7
brotli-bin 1.0.9 h2bbff1b_7
brotli-python 1.0.9 py311hd77b12b_7
bzip2 1.0.8 h2bbff1b_5
c-ares 1.19.1 h2bbff1b_0
ca-certificates 2024.7.2 haa95532_0 https://repo.anaconda.com/pkgs/snowflake
cachetools 4.2.2 pyhd3eb1b0_0
certifi 2024.7.4 py311haa95532_0 https://repo.anaconda.com/pkgs/snowflake
cffi 1.16.0 py311h2bbff1b_0
charset-normalizer 2.0.4 pyhd3eb1b0_0
click 8.1.7 py311haa95532_0
cloudpickle 2.2.1 py311haa95532_0
colorama 0.4.6 py311haa95532_0
comm 0.2.1 py311haa95532_0
contourpy 1.2.0 py311h59b6b97_0
cryptography 41.0.7 py311h89fc84f_0
cycler 0.11.0 pyhd3eb1b0_0
databricks-cli 0.17.6 py311haa95532_1
debugpy 1.6.7 py311hd77b12b_0
decorator 5.1.1 pyhd3eb1b0_0
defusedxml 0.7.1 pyhd3eb1b0_0
docker-py 4.4.1 py311haa95532_5
docker-pycreds 0.4.0 pyhd3eb1b0_0
entrypoints 0.4 py311haa95532_0
executing 0.8.3 pyhd3eb1b0_0
filelock 3.13.1 py311haa95532_0
flask 2.2.5 py311haa95532_0
fonttools 4.25.0 pyhd3eb1b0_0
freetype 2.12.1 ha860e81_0
frozenlist 1.4.0 py311h2bbff1b_0
fsspec 2023.10.0 py311haa95532_0
gflags 2.2.2 hd77b12b_1
gitdb 4.0.7 pyhd3eb1b0_0
gitpython 3.1.37 py311haa95532_0
glog 0.6.0 h4797de2_0 conda-forge
greenlet 3.0.1 py311hd77b12b_0
grpc-cpp 1.51.1 h9c18f36_0 conda-forge
icc_rt 2022.1.0 h6049295_2
icu 73.1 h6c2663c_0
idna 3.4 py311haa95532_0
importlib-metadata 6.0.0 py311haa95532_0
importlib_resources 6.1.1 py311haa95532_1
intel-openmp 2023.1.0 h59b6b97_46320
ipykernel 6.28.0 py311haa95532_0
ipython 8.20.0 py311haa95532_0
ipywidgets 8.1.2 py311haa95532_0
itsdangerous 2.0.1 pyhd3eb1b0_0
jedi 0.18.1 py311haa95532_1
jinja2 3.1.3 py311haa95532_0
jmespath 1.0.1 py311haa95532_0
joblib 1.2.0 py311haa95532_0
jpeg 9e h2bbff1b_1
json5 0.9.6 pyhd3eb1b0_0
jsonschema 4.19.2 py311haa95532_0
jsonschema-specifications 2023.7.1 py311haa95532_0
jupyter 1.0.0 py311haa95532_9
jupyter-lsp 2.2.0 py311haa95532_0
jupyter_client 8.6.0 py311haa95532_0
jupyter_console 6.6.3 py311haa95532_0
jupyter_core 5.5.0 py311haa95532_0
jupyter_events 0.8.0 py311haa95532_0
jupyter_server 2.10.0 py311haa95532_0
jupyter_server_terminals 0.4.4 py311haa95532_1
jupyterlab 4.0.11 py311haa95532_0
jupyterlab_pygments 0.1.2 py_0
jupyterlab_server 2.25.1 py311haa95532_0
jupyterlab_widgets 3.0.10 py311haa95532_0
kiwisolver 1.4.4 py311hd77b12b_0
krb5 1.20.1 h5b6d351_0
lerc 3.0 hd77b12b_0
libabseil 20220623.0 cxx17_h1a56200_6 conda-forge
libarrow 10.0.1 h226723c_4_cpu conda-forge
libboost 1.82.0 h3399ecb_2
libbrotlicommon 1.0.9 h2bbff1b_7
libbrotlidec 1.0.9 h2bbff1b_7
libbrotlienc 1.0.9 h2bbff1b_7
libclang 14.0.6 default_hb5a9fac_1
libclang13 14.0.6 default_h8e68704_1
libcrc32c 1.1.2 hd77b12b_0
libcurl 8.5.0 h86230a5_0
libdeflate 1.17 h2bbff1b_1
libevent 2.1.12 h56d1f94_1
libffi 3.4.4 hd77b12b_0
libgoogle-cloud 2.5.0 h5fc25aa_1 conda-forge
libgrpc 1.51.1 h6a6baca_0 conda-forge
libpng 1.6.39 h8cc25b3_0
libpq 12.17 h906ac69_0
libprotobuf 3.21.12 h12be248_2 conda-forge
libsodium 1.0.18 h62dcd97_0
libssh2 1.10.0 he2ea4bf_2
libthrift 0.16.0 h9ce19ad_2 conda-forge
libtiff 4.5.1 hd77b12b_0
libutf8proc 2.8.0 h82a8f57_0 conda-forge
libwebp-base 1.3.2 h2bbff1b_0
libxgboost 1.7.3 hd77b12b_0
libzlib 1.2.13 hcfcfb64_5 conda-forge
lz4-c 1.9.4 h2bbff1b_0
mako 1.2.3 py311haa95532_0
markdown 3.4.1 py311haa95532_0
markupsafe 2.1.3 py311h2bbff1b_0
matplotlib-base 3.8.0 py311hf62ec03_0
matplotlib-inline 0.1.6 py311haa95532_0
mistune 2.0.4 py311haa95532_0
mkl 2023.1.0 h6b88ed4_46358
mkl-service 2.4.0 py311h2bbff1b_1
mkl_fft 1.3.8 py311h2bbff1b_0
mkl_random 1.2.4 py311h59b6b97_0
mlflow 2.3.1 py311hd1fac3c_0
multidict 6.0.4 py311h2bbff1b_0
munkres 1.1.4 py_0
nbclient 0.8.0 py311haa95532_0
nbconvert 7.10.0 py311haa95532_0
nbformat 5.9.2 py311haa95532_0
nest-asyncio 1.6.0 py311haa95532_0
notebook 7.0.8 py311haa95532_0
notebook-shim 0.2.3 py311haa95532_0
numexpr 2.8.7 py311h1fcbade_0
numpy 1.26.4 py311hdab7c0b_0
numpy-base 1.26.4 py311hd01c5d8_0
oauthlib 3.2.2 py311haa95532_0
openjpeg 2.4.0 h4fc8c34_0
openssl 3.3.0 hcfcfb64_0 conda-forge
orc 1.9.0 hada7b9e_1 conda-forge
overrides 7.4.0 py311haa95532_0
packaging 23.1 py311haa95532_0
pandas 1.5.3 py311heda8569_0
pandocfilters 1.5.0 pyhd3eb1b0_0
parso 0.8.3 pyhd3eb1b0_0
patsy 0.5.6 pyhd8ed1ab_0 conda-forge
pillow 10.2.0 py311h2bbff1b_0
pip 23.3.1 py311haa95532_0
platformdirs 3.10.0 py311haa95532_0
ply 3.11 py311haa95532_0
prometheus_client 0.14.1 py311haa95532_0
prompt-toolkit 3.0.43 py311haa95532_0
prompt_toolkit 3.0.43 hd3eb1b0_0
protobuf 4.21.12 py311h12c1d0e_0 conda-forge
psutil 5.9.0 py311h2bbff1b_0
pure_eval 0.2.2 pyhd3eb1b0_0
py-xgboost 1.7.3 py311haa95532_0
pyarrow 10.0.1 py311h8a3a540_0
pycparser 2.21 pyhd3eb1b0_0
pygments 2.15.1 py311haa95532_1
pyjwt 2.4.0 py311haa95532_0
pyopenssl 23.2.0 py311haa95532_0
pyparsing 3.0.9 py311haa95532_0
pyqt 5.15.10 py311hd77b12b_0
pyqt5-sip 12.13.0 py311h2bbff1b_0
pysocks 1.7.1 py311haa95532_0
python 3.11.8 he1021f5_0
python-dateutil 2.8.3+snowflake1 py311haa95532_1 https://repo.anaconda.com/pkgs/snowflake
python-fastjsonschema 2.16.2 py311haa95532_0
python-json-logger 2.0.7 py311haa95532_0
python_abi 3.11 2_cp311 conda-forge
pytimeparse 1.1.8 py311haa95532_0
pytz 2023.3.post1 py311haa95532_0
pywin32 305 py311h2bbff1b_0
pywinpty 2.0.10 py311h5da7b33_0
pyyaml 6.0.1 py311h2bbff1b_0
pyzmq 25.1.2 py311hd77b12b_0
qt-main 5.15.2 h19c9488_10
qtconsole 5.5.1 py311haa95532_0
qtpy 2.4.1 py311haa95532_0
querystring_parser 1.2.4 py311haa95532_0
re2 2022.06.01 h0e60522_1 conda-forge
referencing 0.30.2 py311haa95532_0
requests 2.31.0 py311haa95532_1
retrying 1.3.3 pyhd3eb1b0_2
rfc3339-validator 0.1.4 py311haa95532_0
rfc3986-validator 0.1.1 py311haa95532_0
rpds-py 0.10.6 py311h062c2fa_0
s3fs 2023.10.0 py311haa95532_0
scikit-learn 1.2.2 py311hd77b12b_1
scipy 1.11.4 py311hc1ccb85_0
seaborn 0.13.2 hd8ed1ab_2 conda-forge
seaborn-base 0.13.2 pyhd8ed1ab_2 conda-forge
send2trash 1.8.2 py311haa95532_0
setuptools 68.2.2 py311haa95532_0
sip 6.7.12 py311hd77b12b_0
six 1.16.0 pyhd3eb1b0_1
smmap 4.0.0 pyhd3eb1b0_0
snappy 1.1.10 h6c2663c_1
sniffio 1.3.0 py311haa95532_0
snowflake-connector-python 3.7.0 py311hd77b12b_0
snowflake-ml-python 1.4.0 pypy_0 https://raw.githubusercontent.com/snowflakedb/snowflake-ml-python/conda/releases
snowflake-snowpark-python 1.13.0 py311haa95532_0
sortedcontainers 2.4.0 pyhd3eb1b0_0
soupsieve 2.5 py311haa95532_0
sqlalchemy 2.0.25 py311h2bbff1b_0
sqlite 3.41.2 h2bbff1b_0
sqlparse 0.4.4 py311haa95532_0
stack_data 0.2.0 pyhd3eb1b0_0
statsmodels 0.14.2 py311h0a17f05_0 conda-forge
tabulate 0.9.0 py311haa95532_0
tbb 2021.8.0 h59b6b97_0
terminado 0.17.1 py311haa95532_0
threadpoolctl 2.2.0 pyh0d69192_0
tinycss2 1.2.1 py311haa95532_0
tk 8.6.12 h2bbff1b_0
tomlkit 0.11.1 py311haa95532_0
tornado 6.3.3 py311h2bbff1b_0
traitlets 5.7.1 py311haa95532_0
typing-extensions 4.9.0 py311haa95532_1
typing_extensions 4.9.0 py311haa95532_1
tzdata 2024a h04d1e81_0
ucrt 10.0.20348.0 haa95532_0
urllib3 2.0.7 py311haa95532_0
utf8proc 2.6.1 h2bbff1b_1
vc 14.2 h21ff451_1
vc14_runtime 14.38.33130 h82b7239_18 conda-forge
vs2015_runtime 14.38.33130 hcb4865c_18 conda-forge
waitress 2.0.0 pyhd3eb1b0_0
wcwidth 0.2.5 pyhd3eb1b0_0
webencodings 0.5.1 py311haa95532_1
websocket-client 0.58.0 py311haa95532_4
werkzeug 2.3.8 py311haa95532_0
wheel 0.41.2 py311haa95532_0
widgetsnbextension 4.0.10 py311haa95532_0
win_inet_pton 1.1.0 py311haa95532_0
winpty 0.4.3 4
wrapt 1.14.1 py311h2bbff1b_0
xgboost 1.7.3 py311haa95532_0
xz 5.4.6 h8cc25b3_0
yaml 0.2.5 he774522_0
yarl 1.9.3 py311h2bbff1b_0
zeromq 4.3.5 hd77b12b_0
zipp 3.17.0 py311haa95532_0
zlib 1.2.13 hcfcfb64_5 conda-forge
zstd 1.5.5 hd43e919_0
```

  4. What did you do?

```python
from snowflake.ml.modeling.xgboost import XGBRegressor
from snowflake.snowpark.functions import col, lit, random, sin, when

from utils import get_session  # local helper module that creates the Snowpark session

session = get_session.session()

N = 10**5
_ONE_MILLION = 10**6

# One row per index 1 .. N-1, with a pseudo-random feature x_0.
df = session.range(1, N).to_df("ind").with_column(
    "x_0", (random() % _ONE_MILLION) / _ONE_MILLION
)

# Three weighting schemes: uniform; 1 only where ind < N/10; 1 only where ind > N/10.
df = df.with_columns(
    ["weights1", "weights2", "weights3"],
    [
        lit(1.0),
        when(col("ind") < lit(N / 10), 1.0).otherwise(0.0),
        when(col("ind") > lit(N / 10), 1.0).otherwise(0.0),
    ],
)

# Target: x_0 where ind < N/10 (linear), sin(10 * x_0) where ind > N/10 (sinusoidal).
df = df.with_column(
    "target",
    when(col("ind") < lit(N / 10), 1.0).otherwise(0.0) * col("x_0")
    + when(col("ind") > lit(N / 10), 1.0).otherwise(0.0) * sin(10 * col("x_0")),
)

parameters = {
    "input_cols": ["X_0"],
    "label_cols": ["TARGET"],
}

# Identical models except for the sample weight column and the output column name.
model1 = XGBRegressor(**parameters, sample_weight_col="weights1", output_cols=["PREDICTION1"])
model2 = XGBRegressor(**parameters, sample_weight_col="weights2", output_cols=["PREDICTION2"])
model3 = XGBRegressor(**parameters, sample_weight_col="weights3", output_cols=["PREDICTION3"])

models = [model1, model2, model3]
for m in models:
    m.fit(df)

# Evaluation grid X_0 = -1.00, -0.99, ..., 0.99 plus the true sinusoid for reference.
test = session.range(-1, 1, 0.01).to_df("X_0").with_column("sinus", sin(10 * col("X_0")))

# Each predict call appends that model's PREDICTIONn column.
for m in models:
    test = m.predict(test)

test_snow = test.toPandas()
print(test_snow)
```


Output:

```
      X_0     SINUS  PREDICTION1  PREDICTION2  PREDICTION3
0   -1.00  0.544021     0.515664     0.515664     0.515664
1   -0.99  0.457536     0.405519     0.405519     0.405519
2   -0.98  0.366479     0.183660     0.183660     0.183660
3   -0.97  0.271761     0.211220     0.211220     0.211220
4   -0.96  0.174327     0.039056     0.039056     0.039056
..    ...       ...          ...          ...          ...
195  0.95 -0.075151     0.047328     0.047328     0.047328
196  0.96 -0.174327    -0.060364    -0.060364    -0.060364
197  0.97 -0.271761     0.034832     0.034832     0.034832
198  0.98 -0.366479    -0.278535    -0.278535    -0.278535
199  0.99 -0.457536    -0.390598    -0.390598    -0.390598
```
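The identical prediction columns can be confirmed programmatically rather than by eye. A minimal sketch, assuming `test_snow` is the pandas DataFrame returned by `toPandas()`; the helper `identical_prediction_columns` is hypothetical and not part of snowflake-ml-python. The usage runs it on a few rows copied from the printed output.

```python
import itertools

import numpy as np
import pandas as pd


def identical_prediction_columns(
    df: pd.DataFrame, cols: list[str], atol: float = 1e-9
) -> list[tuple[str, str]]:
    """Hypothetical helper: return every pair of columns in `cols` whose values agree within `atol`."""
    pairs = []
    for a, b in itertools.combinations(cols, 2):
        # Element-wise comparison with an absolute tolerance only.
        if np.allclose(df[a].to_numpy(), df[b].to_numpy(), atol=atol, rtol=0.0):
            pairs.append((a, b))
    return pairs


# A few rows copied from the output above.
sample = pd.DataFrame(
    {
        "X_0": [-1.00, -0.99, 0.99],
        "PREDICTION1": [0.515664, 0.405519, -0.390598],
        "PREDICTION2": [0.515664, 0.405519, -0.390598],
        "PREDICTION3": [0.515664, 0.405519, -0.390598],
    }
)
print(identical_prediction_columns(sample, ["PREDICTION1", "PREDICTION2", "PREDICTION3"]))
# prints [('PREDICTION1', 'PREDICTION2'), ('PREDICTION1', 'PREDICTION3'), ('PREDICTION2', 'PREDICTION3')]
```

On the full `test_snow`, every pair is reported while the sample weights are being ignored; if the weights were honored, one would expect an empty list.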

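  5. What did you expect to see?

I expected the three models to produce different predictions because they were trained with different sample weights (weights1, weights2, weights3). Specifically, model2, which gives nonzero weight only to rows with ind < N/10 (where the target is x_0), should produce roughly linear predictions, and model3, which gives nonzero weight only to rows with ind > N/10 (where the target is sin(10*x_0)), should produce sinusoidal predictions.

However, the Snowpark ML implementation of XGBRegressor (snowflake.ml.modeling.xgboost) appears to ignore the sample weights: all three models produce identical predictions. Running a similar experiment directly with the standard xgboost library outside of Snowflake produces distinct linear and sinusoidal predictions for model2 and model3, respectively.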
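The expectation follows from how sample weights enter a weighted squared-error loss: a row with weight 0 contributes nothing to the objective, so weights2 and weights3 each expose the learner to a disjoint subset of the data. XGBoost multiplies each row's gradient and hessian by its instance weight, so the same reasoning applies there. The sketch below illustrates the effect locally with NumPy, deliberately substituting a weighted Chebyshev least-squares fit for XGBoost (a stand-in, not what xgboost or snowflake-ml-python does internally) and assuming x_0 is uniform on (-1, 1), which matches the test grid. With the standard xgboost library, the equivalent local comparison passes the weights as `XGBRegressor().fit(X, y, sample_weight=w)`.

```python
import numpy as np
from numpy.polynomial import Chebyshev

# Local replica of the issue's data (assumption: x_0 ~ Uniform(-1, 1)).
N = 10**5
rng = np.random.default_rng(0)
ind = np.arange(1, N)  # like session.range(1, N): 1 .. N-1
x0 = rng.uniform(-1.0, 1.0, size=ind.size)
lin_rows = ind < N / 10
sin_rows = ind > N / 10
target = np.where(lin_rows, x0, 0.0) + np.where(sin_rows, np.sin(10 * x0), 0.0)

weights = {
    "weights1": np.ones_like(x0),
    "weights2": lin_rows.astype(float),
    "weights3": sin_rows.astype(float),
}

grid = np.arange(-1.0, 1.0, 0.01)  # same evaluation grid as the test DataFrame
predictions = {}
for name, w in weights.items():
    # Chebyshev.fit weights the *unsquared* residual, so sqrt(w) weights each
    # squared residual by w; rows with w == 0 drop out of the fit entirely.
    model = Chebyshev.fit(x0, target, deg=25, w=np.sqrt(w))
    predictions[name] = model(grid)

print(np.abs(predictions["weights2"] - grid).max())  # small: model fits the linear part
print(np.abs(predictions["weights3"] - np.sin(10 * grid)).max())  # small: fits the sinusoid
```

Because the three weight vectors select different rows, the three fitted curves differ, which is the behavior the report expects from the Snowpark models as well.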

sfc-gh-afero commented 1 month ago

Thank you for reporting this issue. I was able to use your example to reproduce it on my end. We will investigate this as a bug.