pola-rs / polars

Dataframes powered by a multithreaded, vectorized query engine, written in Rust
https://docs.pola.rs

StringCacheMismatchError when using joblib.Parallel and Categorical data #18528

Open AndreiPashkin opened 2 months ago

AndreiPashkin commented 2 months ago

Reproducible example

import polars as pl
from joblib import Parallel, delayed, __version__ as joblib_version

pl.show_versions()
print(joblib_version)
pl.enable_string_cache()

test_df = pl.DataFrame({
    "item_id": [str(g) for g in range(100)]
}).with_columns(pl.col('item_id').cast(pl.Categorical))

def process(df, game_id):
    # Return the result so it is sent back from the worker.
    return df.filter(pl.col('item_id') == pl.lit(game_id, pl.Categorical))

result = Parallel(
    n_jobs=-1,
)(delayed(process)(test_df, str(g)) for g in range(10))
# [process(test_df, str(g)) for g in range(10)]

Log output

--------Version info---------
Polars:               1.5.0
Index type:           UInt32
Platform:             Linux-5.15.154+-x86_64-with-glibc2.31
Python:               3.10.14 | packaged by conda-forge | (main, Mar 20 2024, 12:45:18) [GCC 12.3.0]

----Optional dependencies----
adbc_driver_manager:  <not installed>
cloudpickle:          3.0.0
connectorx:           <not installed>
deltalake:            <not installed>
fastexcel:            <not installed>
fsspec:               2024.6.1
gevent:               <not installed>
great_tables:         <not installed>
hvplot:               <not installed>
matplotlib:           3.7.5
nest_asyncio:         1.6.0
numpy:                1.26.4
openpyxl:             3.1.5
pandas:               2.2.2
pyarrow:              17.0.0
pydantic:             2.8.2
pyiceberg:            <not installed>
sqlalchemy:           2.0.30
torch:                2.4.0+cpu
xlsx2csv:             <not installed>
xlsxwriter:           <not installed>
1.4.2
---------------------------------------------------------------------------
_RemoteTraceback                          Traceback (most recent call last)
_RemoteTraceback: 
"""
Traceback (most recent call last):
  File "/opt/conda/lib/python3.10/site-packages/joblib/externals/loky/process_executor.py", line 463, in _process_worker
    r = call_item()
  File "/opt/conda/lib/python3.10/site-packages/joblib/externals/loky/process_executor.py", line 291, in __call__
    return self.fn(*self.args, **self.kwargs)
  File "/opt/conda/lib/python3.10/site-packages/joblib/parallel.py", line 598, in __call__
    return [func(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/joblib/parallel.py", line 598, in <listcomp>
    return [func(*args, **kwargs)
  File "/tmp/ipykernel_2632/317073648.py", line 15, in process
  File "/opt/conda/lib/python3.10/site-packages/polars/dataframe/frame.py", line 4554, in filter
    return self.lazy().filter(*predicates, **constraints).collect(_eager=True)
  File "/opt/conda/lib/python3.10/site-packages/polars/lazyframe/frame.py", line 2027, in collect
    return wrap_df(ldf.collect(callback))
polars.exceptions.StringCacheMismatchError: cannot compare categoricals coming from different sources, consider setting a global StringCache.

Help: if you're using Python, this may look something like:

    with pl.StringCache():
        # Initialize Categoricals.
        df1 = pl.DataFrame({'a': ['1', '2']}, schema={'a': pl.Categorical})
        df2 = pl.DataFrame({'a': ['1', '3']}, schema={'a': pl.Categorical})
    # Your operations go here.
    pl.concat([df1, df2])

Alternatively, if the performance cost is acceptable, you could just set:

    import polars as pl
    pl.enable_string_cache()

on startup.
"""

The above exception was the direct cause of the following exception:

StringCacheMismatchError                  Traceback (most recent call last)
Cell In[76], line 17
     14 def process(df, game_id):
     15     df.filter(pl.col('item_id') == pl.lit(game_id, pl.Categorical))
---> 17 result = Parallel(
     18     n_jobs=-1,
     19 )(delayed(process)(test_df, str(g)) for g in range(10))
     20 # [process(test_df, str(g)) for g in range(10)]

File /opt/conda/lib/python3.10/site-packages/joblib/parallel.py:2007, in Parallel.__call__(self, iterable)
   2001 # The first item from the output is blank, but it makes the interpreter
   2002 # progress until it enters the Try/Except block of the generator and
   2003 # reaches the first `yield` statement. This starts the asynchronous
   2004 # dispatch of the tasks to the workers.
   2005 next(output)
-> 2007 return output if self.return_generator else list(output)

File /opt/conda/lib/python3.10/site-packages/joblib/parallel.py:1650, in Parallel._get_outputs(self, iterator, pre_dispatch)
   1647     yield
   1649     with self._backend.retrieval_context():
-> 1650         yield from self._retrieve()
   1652 except GeneratorExit:
   1653     # The generator has been garbage collected before being fully
   1654     # consumed. This aborts the remaining tasks if possible and warn
   1655     # the user if necessary.
   1656     self._exception = True

File /opt/conda/lib/python3.10/site-packages/joblib/parallel.py:1754, in Parallel._retrieve(self)
   1747 while self._wait_retrieval():
   1748 
   1749     # If the callback thread of a worker has signaled that its task
   1750     # triggered an exception, or if the retrieval loop has raised an
   1751     # exception (e.g. `GeneratorExit`), exit the loop and surface the
   1752     # worker traceback.
   1753     if self._aborting:
-> 1754         self._raise_error_fast()
   1755         break
   1757     # If the next job is not ready for retrieval yet, we just wait for
   1758     # async callbacks to progress.

File /opt/conda/lib/python3.10/site-packages/joblib/parallel.py:1789, in Parallel._raise_error_fast(self)
   1785 # If this error job exists, immediately raise the error by
   1786 # calling get_result. This job might not exists if abort has been
   1787 # called directly or if the generator is gc'ed.
   1788 if error_job is not None:
-> 1789     error_job.get_result(self.timeout)

File /opt/conda/lib/python3.10/site-packages/joblib/parallel.py:745, in BatchCompletionCallBack.get_result(self, timeout)
    739 backend = self.parallel._backend
    741 if backend.supports_retrieve_callback:
    742     # We assume that the result has already been retrieved by the
    743     # callback thread, and is stored internally. It's just waiting to
    744     # be returned.
--> 745     return self._return_or_raise()
    747 # For other backends, the main thread needs to run the retrieval step.
    748 try:

File /opt/conda/lib/python3.10/site-packages/joblib/parallel.py:763, in BatchCompletionCallBack._return_or_raise(self)
    761 try:
    762     if self.status == TASK_ERROR:
--> 763         raise self._result
    764     return self._result
    765 finally:

StringCacheMismatchError: cannot compare categoricals coming from different sources, consider setting a global StringCache.

Help: if you're using Python, this may look something like:

    with pl.StringCache():
        # Initialize Categoricals.
        df1 = pl.DataFrame({'a': ['1', '2']}, schema={'a': pl.Categorical})
        df2 = pl.DataFrame({'a': ['1', '3']}, schema={'a': pl.Categorical})
    # Your operations go here.
    pl.concat([df1, df2])

Alternatively, if the performance cost is acceptable, you could just set:

    import polars as pl
    pl.enable_string_cache()

on startup.

Issue description

It seems like Polars has a problem maintaining the internal representation of Categorical values, either due to multiprocessing or due to how joblib serializes and deserializes data. Or maybe I'm just doing something wrong :)

P.S. Switching from Categorical to Enum makes this problem go away.

Expected behavior

I expect no exception to be raised.

Installed versions

```
--------Version info---------
Polars:               1.5.0
Index type:           UInt32
Platform:             Linux-5.15.154+-x86_64-with-glibc2.31
Python:               3.10.14 | packaged by conda-forge | (main, Mar 20 2024, 12:45:18) [GCC 12.3.0]

----Optional dependencies----
adbc_driver_manager:  <not installed>
cloudpickle:          3.0.0
connectorx:           <not installed>
deltalake:            <not installed>
fastexcel:            <not installed>
fsspec:               2024.6.1
gevent:               <not installed>
great_tables:         <not installed>
hvplot:               <not installed>
matplotlib:           3.7.5
nest_asyncio:         1.6.0
numpy:                1.26.4
openpyxl:             3.1.5
pandas:               2.2.2
pyarrow:              17.0.0
pydantic:             2.8.2
pyiceberg:            <not installed>
sqlalchemy:           2.0.30
torch:                2.4.0+cpu
xlsx2csv:             <not installed>
xlsxwriter:           <not installed>
```

Joblib version: `1.4.2`.
c-peters commented 2 months ago

This has nothing to do with joblib or multiprocessing:

import polars as pl
df = pl.DataFrame(pl.Series("a",["1","2","3"],dtype=pl.Categorical))
df.filter(pl.col("a") == pl.lit("1",dtype = pl.Categorical))

The error indicates that the literal comes from a different StringCache than the original column (see https://docs.pola.rs/user-guide/concepts/data-types/categoricals/).

I suppose that for literals we could cast them into the correct StringCache, as we do for strings; that makes sense to me.

For now, removing the dtype on the literal should work: `pl.lit(game_id)`.

barak1412 commented 1 month ago

@c-peters for learning purposes, could you please point me to the place where the right StringCache is fetched, given a string literal? Thanks.

c-peters commented 1 month ago

Ignore my comment, I did not notice you enabled the string cache globally. I am not too familiar with joblib, but I assume joblib starts an entirely separate process, which receives a new cache. Setting the backend to threading should work.

AndreiPashkin commented 1 month ago

@c-peters, I was thinking it is because Polars does not play very well with forking; are you sure that is not the case? When I set joblib's backend to threading, it indeed works.

ritchie46 commented 1 month ago

If you fork you start completely new processes and a global string cache cannot be global between processes. This is a bug in your query/usage.

AndreiPashkin commented 1 month ago

> If you fork you start completely new processes and a global string cache cannot be global between processes. This is a bug in your query/usage.

Yes, that's true. I just wanted to confirm that.