Eventual-Inc / Daft

Distributed DataFrame for Python designed for the cloud, powered by Rust
https://getdaft.io
Apache License 2.0

Model loads after each completed partition #2878

Open conceptofmind opened 18 hours ago

conceptofmind commented 18 hours ago

With the embedding UDF below, the model is reloaded and reinitialized after each partition (Parquet file) is completed and written:

```python
import daft
import torch
import numpy as np
from daft import col
from typing import Optional

BATCH_SIZE = 1
NUM_GPUS = 1
return_dtype = daft.DataType.list(daft.DataType.float64())

@daft.udf(return_dtype=return_dtype, num_gpus=NUM_GPUS, batch_size=BATCH_SIZE)
class STEmbeddingUDF:
    def __init__(
        self,
        model_name: str,
        device: str = "cuda",
        convert_to_tensor: bool = False,
        torch_dtype: torch.dtype = torch.float16,
        set_seq_len: bool = True,
        max_seq_length: Optional[int] = 2048,
    ):
        from sentence_transformers import SentenceTransformer

        # Expensive step: loads the model weights once per UDF instance
        self.model = SentenceTransformer(
            model_name,
            device=device,
            model_kwargs={"torch_dtype": torch_dtype},
        )

        if set_seq_len:
            self.model.max_seq_length = max_seq_length

        self.convert_to_tensor = convert_to_tensor
        self.device = device

    def __call__(self, text_col: daft.Series) -> list:
        embeddings = self.model.encode(
            text_col.to_pylist(),
            batch_size=BATCH_SIZE,
            convert_to_tensor=self.convert_to_tensor,
            device=self.device,
        )
        if self.convert_to_tensor:
            # encode() returned a torch.Tensor of shape (n, dim); copy it to host memory
            embeddings = embeddings.cpu().numpy()
        # One float64 list per input row, so this also works when BATCH_SIZE > 1
        return embeddings.astype(np.float64).tolist()

...

df = daft.read_parquet("hf://datasets/name")

processor = STEmbeddingUDF  # constructor arguments (e.g. model_name) are elided in the original
data = col("text")
df = df.with_column("embeddings", processor(data))

# Daft DataFrames are written with write_parquet (there is no to_parquet method)
df.write_parquet("embeddings")
```

The model reloads after each partition/file is completed and written:

```text
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:49<00:00,  7.09s/it]
We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)
Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:47<00:00,  6.82s/it]
Loading checkpoint shards:  29%|████████████████████████████████▎                                                                                | 2/7 [00:13<00:32,  6.41s/it]
```

Any input would be greatly appreciated.

jaychia commented 17 hours ago

Indeed. @kevinzwang is working on stabilizing our solution for this.

You can try it out now by setting the environment variable DAFT_ENABLE_ACTOR_POOL_PROJECTIONS=1.
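A minimal sketch of enabling the flag for a single run, assuming Daft reads DAFT_ENABLE_ACTOR_POOL_PROJECTIONS from the process environment when the job starts. The helper name run_with_actor_pool_projections is hypothetical; from a shell, DAFT_ENABLE_ACTOR_POOL_PROJECTIONS=1 python embed.py (with embed.py standing in for your script) should be equivalent. Note there is no space around the "=".

```python
import os
import subprocess
import sys
from typing import List


def run_with_actor_pool_projections(args: List[str]) -> subprocess.CompletedProcess:
    """Run `python <args>` in a child interpreter with the flag from this thread set to "1".

    Hypothetical helper: it only controls the environment; what Daft does with the
    flag is up to Daft.
    """
    # Copy the current environment and add the flag; the parent's os.environ is not modified
    env = dict(os.environ, DAFT_ENABLE_ACTOR_POOL_PROJECTIONS="1")
    # check=True raises CalledProcessError if the child exits with a non-zero status
    return subprocess.run([sys.executable, *args], env=env, check=True)
```

Usage would look like run_with_actor_pool_projections(["embed.py"]). Launching a child interpreter keeps the flag scoped to that one run; setting os.environ["DAFT_ENABLE_ACTOR_POOL_PROJECTIONS"] = "1" at the top of the script, before importing daft, is the other common option.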

The fix is currently available for the local executor. We have some local branches available for the Ray Runner as well and are in the process of testing/stabilizing both. Let us know if you try it!
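A toy model of the behaviour discussed here, offered as a hedged sketch rather than Daft's implementation. CountingEmbedder is a hypothetical stand-in for STEmbeddingUDF whose __init__ plays the role of loading the checkpoint shards. run_per_partition mirrors the reported behaviour (one model load per partition), and run_with_actor_pool mirrors the idea behind actor pool projections as assumed here: a fixed pool of long-lived UDF instances that is created once and reused across partitions.

```python
from typing import List, Type


class CountingEmbedder:
    """Hypothetical stand-in for STEmbeddingUDF; __init__ represents the expensive model load."""

    loads = 0  # class-level counter of how many times the "model" has been loaded

    def __init__(self) -> None:
        CountingEmbedder.loads += 1

    def __call__(self, batch: List[str]) -> List[int]:
        # Toy "embedding": the length of each input string
        return [len(text) for text in batch]


def run_per_partition(udf_cls: Type[CountingEmbedder], partitions: List[List[str]]) -> List[List[int]]:
    """Reported behaviour: a fresh UDF instance (and therefore a model load) for every partition."""
    results = []
    for partition in partitions:
        udf = udf_cls()
        results.append(udf(partition))
    return results


def run_with_actor_pool(
    udf_cls: Type[CountingEmbedder], partitions: List[List[str]], pool_size: int
) -> List[List[int]]:
    """Actor-pool style: build pool_size long-lived instances once, then round-robin partitions over them."""
    pool = [udf_cls() for _ in range(pool_size)]
    return [pool[i % pool_size](partition) for i, partition in enumerate(partitions)]
```

With three partitions, run_per_partition loads the "model" three times, while run_with_actor_pool with pool_size=1 loads it once and produces the same outputs; the load count scales with the pool size rather than with the number of partitions.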

jaychia commented 17 hours ago

I see you're also using a GPU in your UDF. We will probably want to assign CUDA_VISIBLE_DEVICES correctly for each instance of your UDF, which the PyRunner doesn't do yet.
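A minimal sketch of what "assigning CUDA_VISIBLE_DEVICES per UDF instance" can mean on one node, assuming each instance runs in its own process and gets a disjoint, contiguous slice of the node's GPUs. The helpers visible_devices_for and env_for_instance are hypothetical and are not the code in the PR linked below. Because CUDA renumbers the visible devices from 0, device="cuda" inside each process then refers to that process's assigned GPU.

```python
import os
from typing import Dict


def visible_devices_for(instance_idx: int, gpus_per_instance: int, total_gpus: int) -> str:
    """Return a CUDA_VISIBLE_DEVICES value giving UDF instance `instance_idx` its own GPU slice."""
    if gpus_per_instance <= 0:
        raise ValueError("gpus_per_instance must be positive")
    start = instance_idx * gpus_per_instance
    end = start + gpus_per_instance
    if instance_idx < 0 or end > total_gpus:
        raise ValueError(
            f"instance {instance_idx} needs GPUs {start}..{end - 1}, but only {total_gpus} exist"
        )
    # e.g. instance 1 with 2 GPUs each -> "2,3"
    return ",".join(str(i) for i in range(start, end))


def env_for_instance(instance_idx: int, gpus_per_instance: int, total_gpus: int) -> Dict[str, str]:
    """Environment for the process hosting one UDF instance.

    It must be in place before that process initializes CUDA (e.g. before torch touches the GPU).
    """
    env = dict(os.environ)
    env["CUDA_VISIBLE_DEVICES"] = visible_devices_for(instance_idx, gpus_per_instance, total_gpus)
    return env
```

With NUM_GPUS = 1 on a 4-GPU node, instances 0 to 3 would see GPUs "0" through "3" respectively, and a fifth instance is rejected instead of silently sharing a GPU.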

conceptofmind commented 17 hours ago

Hi @jaychia,

I will test setting the environment variables for CUDA visible devices and the actor pool.

Thank you!

jaychia commented 11 hours ago

I made a PR with an initial attempt at setting CUDA_VISIBLE_DEVICES: https://github.com/Eventual-Inc/Daft/pull/2882

You'll likely need it if you're running multiple GPUs on a single node with the PyRunner!