ray-project / ray

Ray is a unified framework for scaling AI and Python applications. Ray consists of a core distributed runtime and a set of AI Libraries for accelerating ML workloads.
https://ray.io
Apache License 2.0

Issues with Batch Overflow during exceptions while utilizing map_batches #47162

Open sayanbiswas59 opened 4 weeks ago

sayanbiswas59 commented 4 weeks ago

What happened + What you expected to happen

We are using Ray v2.23 for batch inference on multi-modal data, leveraging LMMs.

We have observed an issue where, if an exception arises while processing an item in a batch, the pending items from that batch carry over into the next batch of the succeeding task. This causes subsequent tasks to fail due to the overflow. Can anyone identify what we might be overlooking?

We are trying to find a solution that allows us to skip the problematic item in the batch and proceed with processing the remaining items. While we have considered skipping the entire batch if an exception occurs, this does not resolve the overflow issue.
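For illustration, here is a rough sketch of the per-item fallback we have in mind. It assumes the model can be invoked on one input at a time through a hypothetical predict_one callable, and the column names are placeholders like in the reproduction script further down:

from typing import Callable, Dict, List

import numpy as np


def predict_with_item_fallback(
    batch: Dict[str, np.ndarray],
    predict_one: Callable[[bytes], str],
) -> Dict[str, np.ndarray]:
    # Run inference item by item so one failing item does not affect the
    # rest of the batch. predict_one is a hypothetical per-item wrapper
    # around the model.
    outputs: List[str] = []
    for raw in batch["<image_column_name>"]:
        try:
            outputs.append(predict_one(raw))
        except Exception as error:
            # Skip only the failing item; keep the row count unchanged.
            print(f"Skipping item due to error: {error}")
            outputs.append("")
    batch["<output_label>"] = np.array(outputs, dtype=object)
    return batch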

Below is an example of an actor log where we observe the batch overflow after an exception. Batch size = 1000. The exception occurs while processing prompt 125/960, and from that point the pending prompt count in subsequent batches keeps growing.

Processed prompts:   0%|          | 0/960 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]
Processed prompts:   0%|          | 1/960 [00:47<12:33:23, 47.14s/it, est. speed input: 23.23 toks/s, output: 0.85 toks/s]
Processed prompts:   1%|          | 10/960 [00:48<56:54,  3.59s/it, est. speed input: 476.70 toks/s, output: 8.18 toks/s] 
Processed prompts:  13%|█▎        | 125/960 [01:58<13:11,  1.05it/s, est. speed input: 7946.17 toks/s, output: 104.40 toks/s]

Processed prompts:   0%|          | 0/1603 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]
Processed prompts:   0%|          | 0/1603 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts:   0%|          | 0/2371 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]
Processed prompts:   0%|          | 0/2371 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts:   0%|          | 0/3139 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]
Processed prompts:   0%|          | 0/3139 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts:   0%|          | 0/4123 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]
Processed prompts:   0%|          | 0/4123 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]
.
.
.

Processed prompts:   0%|          | 0/1377874 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]
Processed prompts:   0%|          | 0/1377874 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Versions / Dependencies

Ray v2.23
Python 3.10
CUDA 12.1

Reproduction script

import base64
import io
import time
from typing import Dict

import numpy as np
import ray
from PIL import Image

dataset = ray.data.read_parquet("file_path")
class MyPredictor:

    def __init__(self):
        self.my_model = MyModel(model_path="<model_path>",
                                tensor_parallel_size=1)

    def __call__(self, batch: Dict[str, np.ndarray]) -> Dict[str, list]:

        try:
            start_time = time.time()

            inputs = [{"input": input, "data": {
                "image": Image.open(io.BytesIO(base64.b64decode(batch["<image_column_name>"][i])))}} for i in
                       range(len(batch["<image_column_name>"]))]

            predictions = self.my_model.generate(
                inputs, sampling_params="<sampling_params>")
            batch["<output_label>"] = [pred.outputs[0].text for pred in predictions]
            end_time = time.time()
            print(f'Total Inference Time for {len(inputs)} - {end_time - start_time}')

        except OSError as os_error:
            print(f"OS error: {os_error}")
            batch["<output_label>"] = ["" for _ in range(len(batch["<image_column_name>"]))]

        except Exception as error:
            print(f"Misc error: {error}")
            batch["<output_label>"] = ["" for _ in range(len(batch["<image_column_name>"]))]

        finally:
            del batch['<image_bytes_column>']
            return batch

dataset = dataset.map_batches(
    MyPredictor,
    concurrency=int("<num_workers>") * int("<num_gpus>"),
    batch_size=int("<batch_size>"),
    num_gpus=1,
)

Issue Severity

High: It blocks me from completing my task.

scottjlee commented 3 weeks ago

The underlying format of the returned batch from __call__ should still be Dict[str, np.ndarray]. Most likely the output batch is being interpreted as a different format, which is why the batch counts look off.
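For example, something along these lines (a rough sketch; the column names are placeholders from the script above) keeps every column as an ndarray when returning the batch:

from typing import Dict, List

import numpy as np


def format_output_batch(batch: Dict[str, np.ndarray], texts: List[str]) -> Dict[str, np.ndarray]:
    # Keep every column a NumPy array so the returned batch stays in the
    # default numpy batch format; texts holds one prediction per input row.
    batch["<output_label>"] = np.array(texts, dtype=object)
    # Drop the raw bytes column; the row count of the remaining columns is unchanged.
    batch.pop("<image_bytes_column>", None)
    return batch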