taichi-dev / taichi

Productive, portable, and performant GPU programming in Python.
https://taichi-lang.org
Apache License 2.0

Using taichi kernel within pytorch dataloader multiprocessing #6725

Open binarydaddy opened 1 year ago

binarydaddy commented 1 year ago

Hi, I am having trouble using a Taichi kernel within PyTorch's DataLoader. Currently I have a setup with a separate class dedicated to image augmentation using a Taichi kernel (taichi.init() is called when this class is initialized), and PyTorch's Dataset class holds this class and calls it for augmentation on every __getitem__ call.

The issue I have been experiencing is that when num_workers is 0, everything works fine. However, when I use num_workers > 0, the program hangs forever.

I believe this has something to do with calling taichi.init() inside Python's worker processes, but I am not entirely sure how to solve this issue.

Any help would be much appreciated. Thank you.
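For illustration, a minimal sketch of the setup described above (the class names, method names, and CPU arch are hypothetical, not taken from the actual code):

# sketch of the setup described above (hypothetical names)

import numpy as np
import taichi as ti
from torch.utils.data import Dataset

@ti.kernel
def _augment(img: ti.types.ndarray()):
    # augmentation kernel body ...
    pass

class TaichiAugmenter:
    def __init__(self):
        # ti.init() is called when the augmentation class is constructed
        ti.init(arch=ti.cpu)

    def __call__(self, img: np.ndarray) -> np.ndarray:
        out = np.copy(img)
        _augment(out)
        return out

class AugmentedDataset(Dataset):
    def __init__(self, images):
        self.images = images
        self.augmenter = TaichiAugmenter()  # ti.init() runs here, in the parent process

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        # with num_workers > 0 this call hangs; with num_workers = 0 it works
        return self.augmenter(self.images[idx])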

ailzhang commented 1 year ago

@binarydaddy would you mind sharing a minimal repro for this? Thanks!

li-yanhao commented 1 year ago

Same issue. I wrote a function with Taichi for data preprocessing in a script, and I call this function in my custom dataset class derived from torch.utils.data.Dataset. When the main training process is sent to a node of a Slurm cluster with num_workers > 0 in my DataLoader, the training process hangs forever. But with num_workers = 0 everything is fine.

I also tried submitting a single Taichi program to the node as a unit test, and it worked normally. So I think the problem is not Taichi on the cluster, but Taichi inside the PyTorch DataLoader.

Basically I first have my preprocess.py:

# preprocess.py

import numpy as np
import taichi as ti

ti.init(arch=ti.cpu)

def cross_difference(img):
    H, W, C = img.shape

    img_out = np.copy(img)
    _cross_diff(img_out)
    return img_out

@ti.kernel
def _cross_diff(img: ti.types.ndarray()):
    H, W, C = img.shape
    # some processing code ...

Then I have my dataset defined by:

# dataset.py

import numpy as np
from torch.utils.data import Dataset
from preprocess import cross_difference

class MyDataset(Dataset):
    def __getitem__(self, idx):
        img = ... # load an image
        img = cross_difference(img)
        # some other processing ...

In my main script I used it like:

# main.py

from torch.utils.data import DataLoader

from dataset import MyDataset

my_dataset = MyDataset()
train_loader = DataLoader(my_dataset, batch_size=4, shuffle=False, num_workers=4)
# then use train_loader for data loading

From the printed log I got:

[Taichi] version 1.3.0, llvm 16.0.0git, commit 0f25b95e, linux, python 3.8.13
[Taichi] Starting on arch=x64

Then the training process hangs. But with num_workers=0 everything is normal.

Does anyone know the possible causes and solutions? Thanks!

pableeto commented 1 year ago

Same issue here - is there any solution or workaround?

eleboss commented 10 months ago

Same issue here

bobcao3 commented 10 months ago

Using the "spawn" start method might be required... and ti.init needs to be called in the workers as well.
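A minimal, untested sketch of that workaround, reusing MyDataset from the repro above. Wiring ti.init into each worker via worker_init_fn and passing a "spawn" multiprocessing context to the DataLoader are assumptions about how to apply the suggestion, not something confirmed in this thread:

# main.py -- sketch of the suggested workaround (untested)

import taichi as ti
import torch
from torch.utils.data import DataLoader

from dataset import MyDataset

def worker_init_fn(worker_id: int):
    # Each worker is a separate process, so it needs its own Taichi runtime.
    ti.init(arch=ti.cpu)

if __name__ == "__main__":
    train_loader = DataLoader(
        MyDataset(),
        batch_size=4,
        shuffle=False,
        num_workers=4,
        worker_init_fn=worker_init_fn,
        # "spawn" starts clean worker processes instead of fork()-ing the
        # parent, which already holds an initialized Taichi runtime.
        multiprocessing_context=torch.multiprocessing.get_context("spawn"),
    )

    for batch in train_loader:
        ...  # training loop

Note that preprocess.py above calls ti.init() at import time; with "spawn", each worker re-imports that module anyway, so moving ti.init() out of module scope (or into worker_init_fn as shown) keeps the initialization explicit and out of the parent process.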