pola-rs / polars

Dataframes powered by a multithreaded, vectorized query engine, written in Rust
https://docs.pola.rs

DataFrame.filter freezes when used in torch.utils.data.DataLoader with non-zero number of workers #17526

Open kozlov721 opened 2 months ago

kozlov721 commented 2 months ago


Reproducible example

```python
import polars as pl
from torch.utils.data import DataLoader, Dataset

class DummyDataset(Dataset):
    def __init__(self):
        table = {
            "class": ["dog", "cat", "cat", "cat"],
            "path": ["0.png", "1.png", "2.png", "3.png"]
        }
        self.df = pl.DataFrame(table)

    def __getitem__(self, idx: int) -> str:
        # Look up the row whose path matches this index and return its class
        file_name = f"{idx}.png"
        sub_df = self.df.filter(pl.col("path") == file_name)
        return sub_df["class"].to_list()[0]

    def __len__(self) -> int:
        return len(self.df)

dataset = DummyDataset()

# This works fine
for cls in dataset:
    print(cls)

loader_0_workers = DataLoader(dataset, num_workers=0, batch_size=1)
loader_2_workers = DataLoader(dataset, num_workers=2, batch_size=1)

# This works as well
for cls in loader_0_workers:
    print(cls)

# This hangs
for cls in loader_2_workers:
    print(cls)
```

Log output

```
dataframe filtered
dataframe filtered
dataframe filtered
dataframe filtered
dataframe filtered
dataframe filtered
dataframe filtered
dataframe filtered
dataframe filtered
Traceback (most recent call last):
  File "/home/martin/polars/example.py", line 37, in <module>
    for cls in loader_2_workers:
  File "/home/martin/miniconda3/envs/pl/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 631, in __next__
    data = self._next_data()
  File "/home/martin/miniconda3/envs/pl/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1329, in _next_data
    idx, data = self._get_data()
  File "/home/martin/miniconda3/envs/pl/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1295, in _get_data
    success, data = self._try_get_data()
  File "/home/martin/miniconda3/envs/pl/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1133, in _try_get_data
    data = self._data_queue.get(timeout=timeout)
  File "/home/martin/miniconda3/envs/pl/lib/python3.10/multiprocessing/queues.py", line 113, in get
    if not self._poll(timeout):
  File "/home/martin/miniconda3/envs/pl/lib/python3.10/multiprocessing/connection.py", line 257, in poll
    return self._poll(timeout)
  File "/home/martin/miniconda3/envs/pl/lib/python3.10/multiprocessing/connection.py", line 424, in _poll
    r = wait([self], timeout)
  File "/home/martin/miniconda3/envs/pl/lib/python3.10/multiprocessing/connection.py", line 931, in wait
    ready = selector.select(timeout)
  File "/home/martin/miniconda3/envs/pl/lib/python3.10/selectors.py", line 416, in select
    fd_event_list = self._selector.poll(timeout)
KeyboardInterrupt
```

Issue description

When polars.DataFrame is used inside torch.utils.data.DataLoader with num_workers set to a non-zero value, the DataFrame.filter method freezes.

The version of torch used is v2.3.0.

Expected behavior

DataFrame.filter should not freeze when used in torch.utils.data.DataLoader with non-zero number of workers.

Installed versions

```
--------Version info---------
Polars: 1.1.0
Index type: UInt32
Platform: Linux-6.8.0-36-generic-x86_64-with-glibc2.39
Python: 3.10.14 (main, May 6 2024, 19:42:50) [GCC 11.2.0]

----Optional dependencies----
adbc_driver_manager:
cloudpickle: 3.0.0
connectorx:
deltalake:
fastexcel:
fsspec: 2024.6.0
gevent:
great_tables:
hvplot:
matplotlib: 3.8.4
nest_asyncio: 1.6.0
numpy: 1.26.4
openpyxl:
pandas: 2.2.2
pyarrow: 15.0.2
pydantic: 2.7.1
pyiceberg:
sqlalchemy: 2.0.30
torch: 2.3.0+cu121
xlsx2csv:
xlsxwriter:
```
ritchie46 commented 2 months ago

That's a bug in torch's dataloader, or in your way of calling it. They fork the process with multiprocessing, which is outright dangerous. As they confirm in their docs:

https://pytorch.org/docs/stable/notes/multiprocessing.html#avoiding-and-fighting-deadlocks

> There are a lot of things that can go wrong when a new process is spawned, with the most common cause of deadlocks being background threads. If there’s any thread that holds a lock or imports a module, and fork is called, it’s very likely that the subprocess will be in a corrupted state and will deadlock or fail in a different way. Note that even if you don’t, Python built in libraries do - no need to look further than [multiprocessing](https://docs.python.org/3/library/multiprocessing.html#module-multiprocessing). [multiprocessing.Queue](https://docs.python.org/3/library/multiprocessing.html#multiprocessing.Queue) is actually a very complex class, that spawns multiple threads used to serialize, send and receive objects, and they can cause aforementioned problems too. If you find yourself in such situation try using a SimpleQueue, that doesn’t use any additional threads.
>
> We’re trying our best to make it easy for you and ensure these deadlocks don’t happen but some things are out of our control. If you have any issues you can’t cope with for a while, try reaching out on forums, and we’ll see if it’s an issue we can fix.
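The quoted note says `multiprocessing.Queue` relies on helper threads while `SimpleQueue` does not. A small stdlib-only sketch makes that visible by counting live threads in the current process before and after the first `put()`. The helper name `extra_threads_after_put` is hypothetical and not part of any library.

```python
import multiprocessing as mp
import threading


def extra_threads_after_put(queue) -> int:
    """Return how many threads appeared in this process during queue.put()."""
    before = threading.active_count()
    queue.put("item")
    after = threading.active_count()
    # Read the item back so the underlying pipe is empty again.
    assert queue.get() == "item"
    return after - before


if __name__ == "__main__":
    q = mp.Queue()
    # mp.Queue starts a background feeder thread lazily on the first put()
    print("Queue:", extra_threads_after_put(q))
    q.close()
    q.join_thread()

    sq = mp.SimpleQueue()
    # SimpleQueue writes to its pipe directly from the calling thread
    print("SimpleQueue:", extra_threads_after_put(sq))
```

Any thread like that feeder thread is one more place where a lock can be held at the moment `fork()` runs, which is why the quoted docs suggest `SimpleQueue` when forking.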

This is also covered in our user guide: https://docs.pola.rs/user-guide/misc/multiprocessing/

TL;DR: if a library or user forks a process while any other library in that process holds a mutex, that is a logic bug!
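The TL;DR can be reproduced without polars or torch. Below is a minimal stdlib-only sketch (POSIX only, because it forces the `fork` start method; all function names are hypothetical). A background thread holds an ordinary `threading.Lock` at the moment the process forks. The lock stands in for whatever mutex polars' multithreaded engine may hold; the thread does not say which lock is actually involved, so that mapping is an assumption. The child inherits the lock in its locked state but not the thread that would release it, so a timed `acquire()` in the child gives up instead of hanging forever.

```python
import multiprocessing as mp
import threading
import time

# Module-level lock standing in for a mutex owned by some library's thread pool.
lock = threading.Lock()


def hold_lock(seconds: float) -> None:
    """Background thread: keep the lock held for a while."""
    with lock:
        time.sleep(seconds)


def child_try_acquire(results) -> None:
    """Runs in the forked child: report whether the inherited lock can be taken."""
    # Without the timeout this would block forever: the thread holding the lock
    # exists only in the parent, so nothing in the child will ever release it.
    acquired = lock.acquire(timeout=1.0)
    results.put(acquired)


def child_can_acquire_after_fork() -> bool:
    holder = threading.Thread(target=hold_lock, args=(2.0,))
    holder.start()
    time.sleep(0.2)  # give the background thread time to take the lock

    ctx = mp.get_context("fork")
    results = ctx.SimpleQueue()  # no feeder thread, see the quoted torch note
    child = ctx.Process(target=child_try_acquire, args=(results,))
    child.start()  # fork() happens here, while the lock is still held
    acquired = results.get()
    child.join()
    holder.join()
    return acquired


if __name__ == "__main__":
    print("child acquired lock:", child_can_acquire_after_fork())
```

With the `spawn` start method the child would instead start from a fresh interpreter that re-creates `lock` unlocked, which is the approach the linked polars user guide page recommends. For `DataLoader`, the analogous change would presumably be passing `multiprocessing_context="spawn"`; that workaround was not tested in this thread.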

kozlov721 commented 2 months ago

Thanks for the response. I was confused because the snippet above works just fine when I replace polars with pandas, so I assumed there was a bug somewhere in polars.