Lightning-AI / litgpt

20+ high-performance LLMs with recipes to pretrain, finetune and deploy at scale.
https://lightning.ai
Apache License 2.0

CUDA error when serving with `workers_per_device` > 1 and concurrent requests #1733

Closed (puppyapple closed this issue 1 month ago)

puppyapple commented 1 month ago

Bug description

I modified serve.py slightly to add a `workers_per_device` parameter and then served the model with `workers_per_device` > 1. When more than one request is sent concurrently, I get:

lIndex: block: [0,0,0], thread: [119,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
../aten/src/ATen/native/cuda/Indexing.cu:1231: indexSelectSmallIndex: block: [0,0,0], thread: [120,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
../aten/src/ATen/native/cuda/Indexing.cu:1231: indexSelectSmallIndex: block: [0,0,0], thread: [121,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
../aten/src/ATen/native/cuda/Indexing.cu:1231: indexSelectSmallIndex: block: [0,0,0], thread: [122,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
../aten/src/ATen/native/cuda/Indexing.cu:1231: indexSelectSmallIndex: block: [0,0,0], thread: [123,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
../aten/src/ATen/native/cuda/Indexing.cu:1231: indexSelectSmallIndex: block: [0,0,0], thread: [124,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
../aten/src/ATen/native/cuda/Indexing.cu:1231: indexSelectSmallIndex: block: [0,0,0], thread: [125,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
../aten/src/ATen/native/cuda/Indexing.cu:1231: indexSelectSmallIndex: block: [0,0,0], thread: [126,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
../aten/src/ATen/native/cuda/Indexing.cu:1231: indexSelectSmallIndex: block: [0,0,0], thread: [127,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
LitAPI ran into an error while processing the request uid=a24a7f9e-90ff-422c-b5ed-724680a485fe.
Please check the error trace for more details.
Traceback (most recent call last):
  File "/home/puppyapple/anaconda3/envs/bigmodel/lib/python3.10/site-packages/litserve/server.py", line 151, in run_single_loop
    y = _inject_context(
  File "/home/puppyapple/anaconda3/envs/bigmodel/lib/python3.10/site-packages/litserve/server.py", line 72, in _inject_context
    return func(*args, **kwargs)
  File "/home/puppyapple/Server/BigAI/Chinese_LLM_From_Scratch/Journey/Day11/service.py", line 92, in predict
    output = self.llm.generate(
  File "/home/puppyapple/anaconda3/envs/bigmodel/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/home/puppyapple/anaconda3/envs/bigmodel/lib/python3.10/site-packages/litgpt/api.py", line 445, in generate
    outputs = generate_fn(
  File "/home/puppyapple/anaconda3/envs/bigmodel/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/home/puppyapple/anaconda3/envs/bigmodel/lib/python3.10/site-packages/litgpt/generate/base.py", line 140, in generate
    token = next_token(
  File "/home/puppyapple/anaconda3/envs/bigmodel/lib/python3.10/site-packages/litgpt/generate/base.py", line 77, in next_token
    logits = model(x, input_pos)
  File "/home/puppyapple/anaconda3/envs/bigmodel/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/puppyapple/anaconda3/envs/bigmodel/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/puppyapple/anaconda3/envs/bigmodel/lib/python3.10/site-packages/lightning/fabric/wrappers.py", line 141, in forward
    output = self._forward_module(*args, **kwargs)
  File "/home/puppyapple/anaconda3/envs/bigmodel/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/puppyapple/anaconda3/envs/bigmodel/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/puppyapple/anaconda3/envs/bigmodel/lib/python3.10/site-packages/litgpt/model.py", line 83, in forward
    mask = self.mask_cache.index_select(2, input_pos)
RuntimeError: CUDA error: device-side assert triggered
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
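The assertion `srcIndex < srcSelectDimSize` is raised by the CUDA `index_select` kernel when an index is outside the size of the dimension being indexed. In the traceback above, the failing call is `self.mask_cache.index_select(2, input_pos)`, so presumably some entry of `input_pos` was at or beyond `mask_cache.size(2)`. The minimal sketch below mirrors that bounds check in plain Python so it runs without a GPU; the helper `check_index_select` is hypothetical and not part of LitGPT or PyTorch.

```python
from typing import Sequence


def check_index_select(dim_size: int, indices: Sequence[int]) -> None:
    """Raise if any index would trip the `srcIndex < srcSelectDimSize` assert.

    Hypothetical helper: `dim_size` plays the role of `mask_cache.size(2)`
    and `indices` plays the role of `input_pos` in the traceback.
    """
    bad = [i for i in indices if not 0 <= i < dim_size]
    if bad:
        raise IndexError(
            f"index_select: indices {bad} out of range for dimension of size {dim_size}"
        )


# Positions 0..7 fit in a cache of length 8 ...
check_index_select(8, range(8))

# ... but position 8 does not; this is the situation the CUDA kernel asserts on.
try:
    check_index_select(8, [6, 7, 8])
except IndexError as err:
    print(err)
```

Because CUDA errors are reported asynchronously, the Python line shown in such a traceback is not always the one that failed; launching the server with the environment variable `CUDA_LAUNCH_BLOCKING=1` (as PyTorch's own error message suggests) usually makes the reported location accurate, at the cost of speed.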

The modified server script is below:

# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
import click
from pathlib import Path
from pprint import pprint
from typing import Dict, Any, Optional, Literal
import torch
from litgpt.api import LLM
from litgpt.utils import auto_download_checkpoint
from litserve import LitAPI, LitServer

class BaseLitAPI(LitAPI):
    def __init__(
        self,
        checkpoint_dir: Path,
        quantize: Optional[
            Literal["bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq", "bnb.int8"]
        ] = None,
        precision: Optional[str] = None,
        temperature: float = 0.8,
        top_k: int = 50,
        top_p: float = 1.0,
        max_new_tokens: int = 50,
        devices: int = 1,
    ) -> None:

        super().__init__()
        self.checkpoint_dir = checkpoint_dir
        self.quantize = quantize
        self.precision = precision
        self.temperature = temperature
        self.top_k = top_k
        self.max_new_tokens = max_new_tokens
        self.top_p = top_p
        self.devices = devices

    def setup(self, device: str) -> None:
        if ":" in device:
            accelerator, device = device.split(":")
            device = f"[{int(device)}]"
        else:
            accelerator = device
            device = 1

        print("Initializing model...")
        self.llm = LLM.load(model=self.checkpoint_dir, distribute=None)

        self.llm.distribute(
            devices=self.devices,
            accelerator=accelerator,
            quantize=self.quantize,
            precision=self.precision,
            generate_strategy=(
                "sequential" if self.devices is not None and self.devices > 1 else None
            ),
        )
        print("Model successfully initialized.")

    def decode_request(self, request: Dict[str, Any]) -> Any:
        # Convert the request payload to your model input.
        prompt = str(request["prompt"])
        return prompt

class SimpleLitAPI(BaseLitAPI):
    def __init__(
        self,
        checkpoint_dir: Path,
        quantize: Optional[str] = None,
        precision: Optional[str] = None,
        temperature: float = 0.8,
        top_k: int = 50,
        top_p: float = 1.0,
        max_new_tokens: int = 50,
        devices: int = 1,
    ):
        super().__init__(
            checkpoint_dir,
            quantize,
            precision,
            temperature,
            top_k,
            top_p,
            max_new_tokens,
            devices,
        )

    def setup(self, device: str):
        super().setup(device)

    def predict(self, inputs: str) -> Any:
        output = self.llm.generate(
            inputs,
            temperature=self.temperature,
            top_k=self.top_k,
            top_p=self.top_p,
            max_new_tokens=self.max_new_tokens,
        )
        return output

    def encode_response(self, output: str) -> Dict[str, Any]:
        # Convert the model output to a response payload.
        return {"output": output}

class StreamLitAPI(BaseLitAPI):
    def __init__(
        self,
        checkpoint_dir: Path,
        quantize: Optional[str] = None,
        precision: Optional[str] = None,
        temperature: float = 0.8,
        top_k: int = 50,
        top_p: float = 1.0,
        max_new_tokens: int = 50,
        devices: int = 1,
    ):
        super().__init__(
            checkpoint_dir,
            quantize,
            precision,
            temperature,
            top_k,
            top_p,
            max_new_tokens,
            devices,
        )

    def setup(self, device: str):
        super().setup(device)

    def predict(self, inputs: str) -> Any:
        # Run the model on the input and return the output.
        yield from self.llm.generate(
            inputs,
            temperature=self.temperature,
            top_k=self.top_k,
            top_p=self.top_p,
            max_new_tokens=self.max_new_tokens,
            stream=True,
        )

    def encode_response(self, output):
        for out in output:
            yield {"output": out}

@click.command()
@click.option("--checkpoint_dir", type=str)
@click.option("--quantize", type=str, default=None)
@click.option("--precision", type=str, default=None)
@click.option("--temperature", type=float, default=0.8)
@click.option("--top_k", type=int, default=50)
@click.option("--top_p", type=float, default=1.0)
@click.option("--max_new_tokens", type=int, default=50)
@click.option("--devices", type=int, default=1)
@click.option("--workers_per_device", type=int, default=20)
@click.option("--port", type=int, default=8000)
@click.option("--stream", type=bool, default=False)
@click.option("--accelerator", type=str, default="auto")
def run_server(
    checkpoint_dir: Path,
    quantize: Optional[
        Literal["bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq", "bnb.int8"]
    ] = None,
    precision: Optional[str] = None,
    temperature: float = 0.8,
    top_k: int = 50,
    top_p: float = 1.0,
    max_new_tokens: int = 50,
    devices: int = 1,
    port: int = 8000,
    accelerator: str = "auto",
    workers_per_device: int = 20,
    stream: bool = False,
    access_token: Optional[str] = None,
) -> None:
    """Serve a LitGPT model using LitServe.

    Arguments:
        checkpoint_dir: The checkpoint directory to load the model from.
        quantize: Whether to quantize the model and, if so, which method to use:
            - bnb.nf4, bnb.nf4-dq, bnb.fp4, bnb.fp4-dq: 4-bit quantization from bitsandbytes
            - bnb.int8: 8-bit quantization from bitsandbytes
            for more details, see https://github.com/Lightning-AI/litgpt/blob/main/tutorials/quantize.md
        precision: Optional precision setting to instantiate the model weights in. By default, this will
            automatically be inferred from the metadata in the given ``checkpoint_dir`` directory.
        temperature: Temperature setting for the text generation. Values above 1 increase randomness.
            Values below 1 decrease randomness.
        top_k: The size of the pool of potential next tokens. Values larger than 1 result in more novel
            generated text but can also lead to more incoherent text.
        top_p: If specified, it represents the cumulative probability threshold to consider in the sampling process.
            In top-p sampling, the next token is sampled from the highest probability tokens
            whose cumulative probability exceeds the threshold `top_p`. When specified,
            it must be `0 <= top_p <= 1`. Here, `top_p=0` is equivalent
            to sampling the most probable token, while `top_p=1` samples from the whole distribution.
            It can be used in conjunction with `top_k` and `temperature` with the following order
            of application:

            1. `top_k` sampling
            2. `temperature` scaling
            3. `top_p` sampling

            For more details, see https://arxiv.org/abs/1904.09751
            or https://huyenchip.com/2024/01/16/sampling.html#top_p
        max_new_tokens: The number of generation steps to take.
        workers_per_device: How many workers to use per device.
        devices: How many devices/GPUs to use.
        accelerator: The type of accelerator to use. For example, "auto", "cuda", "cpu", or "mps".
            The "auto" setting (default) chooses a GPU if available, and otherwise uses a CPU.
        port: The network port number on which the model is configured to be served.
        stream: Whether to stream the responses.
        access_token: Optional API token to access models with restrictions.
    """
    checkpoint_dir = auto_download_checkpoint(
        model_name=checkpoint_dir, access_token=access_token
    )
    pprint(locals())

    if not stream:
        server = LitServer(
            SimpleLitAPI(
                checkpoint_dir=checkpoint_dir,
                quantize=quantize,
                precision=precision,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                max_new_tokens=max_new_tokens,
                devices=devices,
            ),
            workers_per_device=workers_per_device,
            accelerator=accelerator,
            devices=devices,  # We need to use the devices inside the `SimpleLitAPI` class
        )

    else:
        server = LitServer(
            StreamLitAPI(
                checkpoint_dir=checkpoint_dir,
                quantize=quantize,
                precision=precision,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                max_new_tokens=max_new_tokens,
                devices=devices,  # We need to use the devices inside the `StreamLitAPI` class
            ),
            workers_per_device=workers_per_device,
            accelerator=accelerator,
            devices=1,
            stream=True,
        )

    server.run(port=port, generate_client_file=False)

if __name__ == "__main__":
    run_server()
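LitServe hands each worker's `setup` a device string such as `cuda:0` (see the log output later in this thread). In `BaseLitAPI.setup` above, that string is split into an accelerator and the string `"[0]"`, but the parsed index is not used afterwards: `distribute()` receives `devices=self.devices`. A minimal sketch of a parser that returns the index as a list of integers, so it can at least be logged or checked per worker; the name `parse_worker_device` is hypothetical, and whether a given LitGPT version's `LLM.distribute()` accepts a list of indices is an assumption to verify against its documentation.

```python
from typing import List, Tuple


def parse_worker_device(device: str) -> Tuple[str, List[int]]:
    """Split a worker device string such as "cuda:0" into accelerator and indices.

    Hypothetical helper mirroring the parsing in `BaseLitAPI.setup`, but
    returning a list of ints instead of the string "[0]".
    """
    if ":" in device:
        accelerator, index = device.split(":", 1)
        return accelerator, [int(index)]
    # No index given (e.g. "cpu"): keep the accelerator and leave indices empty.
    return device, []


print(parse_worker_device("cuda:0"))  # ('cuda', [0])
print(parse_worker_device("cpu"))     # ('cpu', [])
```

Unlike the original `setup`, the sketch returns an empty list rather than `1` when no index is present, so a caller can tell "no explicit device" apart from "device 1".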

The test script is below:

import asyncio
import aiohttp
import json
import argparse
import hashlib
import time
from tqdm import tqdm
from litgpt.prompts import MicroStories

def hash_prompt(prompt):
    return hashlib.md5(prompt.encode()).hexdigest()

async def generate_response(session, prompt, semaphore, cache):
    prompt_hash = hash_prompt(prompt)
    if prompt_hash in cache:
        return cache[prompt_hash]

    async with semaphore:
        async with session.post(
            "http://127.0.0.1:8000/predict", json={"prompt": prompt}
        ) as response:
            result = await response.json()
            cache[prompt_hash] = result
            return result

async def main(concurrency, test_mode):
    ms = MicroStories()

    with open(
        "../../Data/TinyStoriesInstruct/sft_data_v2.json", "r", encoding="utf-8"
    ) as f:
        sft_data = json.load(f)

    if test_mode:
        sft_data = sft_data[:8]

    # Load the cache
    try:
        with open("dpo_cache.json", "r", encoding="utf-8") as f:
            cache = json.load(f)
    except FileNotFoundError:
        cache = {}

    semaphore = asyncio.Semaphore(concurrency)

    async with aiohttp.ClientSession() as session:
        tasks = []
        for case in tqdm(sft_data, desc="Generating DPO data"):
            prompt = ms.apply(prompt=case["instruction"], input=case["input"])
            task = asyncio.create_task(
                generate_response(session, prompt, semaphore, cache)
            )
            tasks.append(task)

        responses = await asyncio.gather(*tasks)

    dpo_data = []
    for case, response in zip(sft_data, responses):
        prompt = ms.apply(prompt=case["instruction"], input=case["input"])
        dpo_sample = {
            "prompt": prompt,
            "rejected": response["output"],
            "chosen": case["output"],
        }
        dpo_data.append(dpo_sample)

    # Save the updated cache
    with open("dpo_cache.json", "w", encoding="utf-8") as f:
        json.dump(cache, f, ensure_ascii=False, indent=2)

    output_file = "dpo_data_test.json" if test_mode else "dpo_data.json"
    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(dpo_data, f, ensure_ascii=False, indent=2)

    print(f"DPO data generated and saved to {output_file}")
    print("Cache updated and saved to dpo_cache.json")

    end_time = time.time()
    execution_time = end_time - start_time
    print(f"Total execution time: {execution_time:.2f} s")

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Generate DPO data")
    parser.add_argument("--concurrency", type=int, default=10, help="Number of concurrent requests")
    parser.add_argument("--test", action="store_true", help="Test mode: only process the first 8 samples")
    args = parser.parse_args()

    start_time = time.time()
    asyncio.run(main(args.concurrency, args.test))
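The client caps the number of in-flight requests with `asyncio.Semaphore(concurrency)`, so `--concurrency` is the number of requests the server sees at the same time. A self-contained sketch of that pattern, with the HTTP call replaced by `asyncio.sleep` (an assumption made so it runs without a server), records the peak number of requests inside the semaphore:

```python
import asyncio


async def fake_request(semaphore: asyncio.Semaphore, state: dict) -> None:
    """Stand-in for `session.post(...)`: sleeps instead of calling the server."""
    async with semaphore:
        state["in_flight"] += 1
        state["peak"] = max(state["peak"], state["in_flight"])
        await asyncio.sleep(0.01)  # simulated server latency
        state["in_flight"] -= 1


async def run_requests(num_requests: int, concurrency: int) -> int:
    """Fire `num_requests` fake requests and return the peak concurrency observed."""
    semaphore = asyncio.Semaphore(concurrency)
    state = {"in_flight": 0, "peak": 0}
    await asyncio.gather(*(fake_request(semaphore, state) for _ in range(num_requests)))
    return state["peak"]


if __name__ == "__main__":
    # With 8 requests and concurrency=2, at most 2 are ever in flight at once.
    print(asyncio.run(run_requests(num_requests=8, concurrency=2)))  # 2
```

All tasks are created up front, as in the original script; the semaphore alone decides how many reach the server simultaneously.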

What operating system are you using?

Linux

LitGPT Version


Version: 0.4.10
rasbt commented 1 month ago

Thanks for reporting. I am not exactly sure how `workers_per_device` works under the hood in LitGPT. Maybe @aniketmaurya can chime in here.

aniketmaurya commented 1 month ago

@puppyapple It seems like a wrong device ID is being set. Could you print the device in the `setup` method and see what it prints?

puppyapple commented 1 month ago

So I added a print to the `setup` method of `BaseLitAPI` above


which gives me:

INFO:     Waiting for application startup.
INFO:     Application startup complete.
device passed in : cuda:0
Initializing model...
device='[0]'
self.devices=1
device passed in : cuda:0
Initializing model...
device='[0]'
self.devices=1
Model successfully initialized.
Setup complete for worker 1.
Model successfully initialized.
Setup complete for worker 0.

Is this normal?

puppyapple commented 1 month ago

@rasbt @aniketmaurya I tried another test without LitGPT serve: I used multiprocessing to load n models that process chunks of data in parallel, and I get a similar error:

e` failed.
../aten/src/ATen/native/cuda/Indexing.cu:1231: indexSelectSmallIndex: block: [0,0,0], thread: [61,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
../aten/src/ATen/native/cuda/Indexing.cu:1231: indexSelectSmallIndex: block: [0,0,0], thread: [62,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
../aten/src/ATen/native/cuda/Indexing.cu:1231: indexSelectSmallIndex: block: [0,0,0], thread: [63,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
multiprocessing.pool.RemoteTraceback: 
"""
Traceback (most recent call last):
  File "/home/puppyapple/anaconda3/envs/bigmodel/lib/python3.10/multiprocessing/pool.py", line 125, in worker
    result = (True, func(*args, **kwds))
  File "/home/puppyapple/anaconda3/envs/bigmodel/lib/python3.10/multiprocessing/pool.py", line 48, in mapstar
    return list(map(*args))
  File "/home/puppyapple/Server/BigAI/Chinese_LLM_From_Scratch/Journey/Day11/multi_model_inference.py", line 26, in process_chunk
    response = model.generate(prompt=prompt, max_new_tokens=350)
  File "/home/puppyapple/anaconda3/envs/bigmodel/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/home/puppyapple/anaconda3/envs/bigmodel/lib/python3.10/site-packages/litgpt/api.py", line 445, in generate
    outputs = generate_fn(
  File "/home/puppyapple/anaconda3/envs/bigmodel/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/home/puppyapple/anaconda3/envs/bigmodel/lib/python3.10/site-packages/litgpt/generate/base.py", line 140, in generate
    token = next_token(
  File "/home/puppyapple/anaconda3/envs/bigmodel/lib/python3.10/site-packages/litgpt/generate/base.py", line 77, in next_token
    logits = model(x, input_pos)
  File "/home/puppyapple/anaconda3/envs/bigmodel/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/puppyapple/anaconda3/envs/bigmodel/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/puppyapple/anaconda3/envs/bigmodel/lib/python3.10/site-packages/lightning/fabric/wrappers.py", line 141, in forward
    output = self._forward_module(*args, **kwargs)
  File "/home/puppyapple/anaconda3/envs/bigmodel/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/puppyapple/anaconda3/envs/bigmodel/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/puppyapple/anaconda3/envs/bigmodel/lib/python3.10/site-packages/litgpt/model.py", line 94, in forward
    x = block(x, cos, sin, mask, input_pos)
  File "/home/puppyapple/anaconda3/envs/bigmodel/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/puppyapple/anaconda3/envs/bigmodel/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/puppyapple/anaconda3/envs/bigmodel/lib/python3.10/site-packages/litgpt/model.py", line 197, in forward
    attention_output = self.attn(x_normed, cos, sin, mask, input_pos)
  File "/home/puppyapple/anaconda3/envs/bigmodel/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/puppyapple/anaconda3/envs/bigmodel/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/puppyapple/anaconda3/envs/bigmodel/lib/python3.10/site-packages/litgpt/model.py", line 237, in forward
    qkv = self.attn(x)
  File "/home/puppyapple/anaconda3/envs/bigmodel/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/puppyapple/anaconda3/envs/bigmodel/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/puppyapple/anaconda3/envs/bigmodel/lib/python3.10/site-packages/torch/nn/modules/linear.py", line 117, in forward
    return F.linear(input, self.weight, self.bias)
RuntimeError: CUDA error: CUBLAS_STATUS_EXECUTION_FAILED when calling `cublasGemmEx( handle, opa, opb, m, n, k, &falpha, a, CUDA_R_16BF, lda, b, CUDA_R_16BF, ldb, &fbeta, c, CUDA_R_16BF, ldc, compute_type, CUBLAS_GEMM_DEFAULT_TENSOR_OP)`

The script I used is below:

import json
import multiprocessing
from functools import partial
from litgpt import LLM
from litgpt.prompts import MicroStories
import click
import torch

# Set the multiprocessing start method to "spawn"
multiprocessing.set_start_method("spawn", force=True)

def init_model():
    model = LLM.load(
        model="../../Experiments/Output/sft/microstories/mask_prompt_5e-4/final"
    )
    return model

def process_chunk(model, chunk):
    ms = MicroStories()
    results = []
    for case in chunk:
        prompt = ms.apply(prompt=case["instruction"], input=case["input"])
        with torch.no_grad():
            response = model.generate(prompt=prompt, max_new_tokens=350)
        results.append(
            {"prompt": prompt, "rejected": response, "chosen": case["output"]}
        )
    return results

@click.command()
@click.option("-n", "--num_processes", default=4, help="Number of concurrent processes")
@click.option("--test", is_flag=True, help="Test mode: only process the first 100 samples")
def main(num_processes, test):
    # Load the SFT data
    with open(
        "../../Data/TinyStoriesInstruct/sft_data_v2.json", "r", encoding="utf-8"
    ) as f:
        sft_data = json.load(f)

    if test:
        sft_data = sft_data[:100]

    # Determine the number of processes
    n_processes = min(multiprocessing.cpu_count(), num_processes)

    # Initialize the model (once, in the parent process)
    model = init_model()

    # Use partial to create a new function with the model bound as its first argument
    process_chunk_with_model = partial(process_chunk, model)

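    # Split the data into chunks of len(sft_data) // n_processes samples each
    # (at least 1, so the range step is never zero)
    chunk_size = max(1, len(sft_data) // n_processes)
    chunks = [sft_data[i : i + chunk_size] for i in range(0, len(sft_data), chunk_size)]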

    # Process the chunks in parallel with a process pool
    with multiprocessing.Pool(n_processes) as pool:
        results = pool.map(process_chunk_with_model, chunks)

    # Merge the results
    dpo_samples = [item for sublist in results for item in sublist]

    # Save the results
    output_file = "dpo_samples_test.json" if test else "dpo_samples.json"
    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(dpo_samples, f, ensure_ascii=False, indent=2)

    print(f"Done. Results saved to {output_file}")

if __name__ == "__main__":
    main()
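In the script above, the model is created once in the parent process and sent to the workers through `partial`. A common alternative, sketched here as an assumption and not as a confirmed fix for this issue, is to have each worker load its own model through `Pool`'s `initializer` argument. The sketch also splits the data into at most `n` chunks of near-equal size; with `len(sft_data) // n_processes` as the step, the script above yields one extra, smaller chunk whenever the sample count is not divisible by the process count. `load_model` is a hypothetical placeholder for `LLM.load(...)` that returns a dummy object so the code runs anywhere.

```python
import multiprocessing
from typing import Any, Dict, List

# Per-process model handle, set once in each worker by `init_worker`.
_model: Any = None


def load_model(checkpoint_dir: str) -> Any:
    """Hypothetical stand-in for `LLM.load(model=checkpoint_dir)`.

    Returns a dummy object so the sketch runs without a GPU or checkpoint.
    """
    return {"checkpoint": checkpoint_dir}


def init_worker(checkpoint_dir: str) -> None:
    # Runs once in each worker process, so every worker loads its own model
    # instead of receiving one created in the parent process.
    global _model
    _model = load_model(checkpoint_dir)


def process_chunk(chunk: List[Dict[str, str]]) -> List[Dict[str, str]]:
    # Uses the model loaded by `init_worker` in this process. The real script
    # would call `_model.generate(prompt=..., max_new_tokens=350)` here.
    return [{"prompt": case["instruction"], "model": _model["checkpoint"]} for case in chunk]


def split_into_chunks(data: List[Any], n: int) -> List[List[Any]]:
    """Split `data` into at most `n` contiguous chunks whose sizes differ by at most 1."""
    n = max(1, min(n, len(data)))
    size, extra = divmod(len(data), n)
    chunks, start = [], 0
    for i in range(n):
        end = start + size + (1 if i < extra else 0)
        chunks.append(data[start:end])
        start = end
    return chunks


def generate_all(data: List[Dict[str, str]], num_processes: int, checkpoint_dir: str) -> List[Dict[str, str]]:
    chunks = split_into_chunks(data, num_processes)
    ctx = multiprocessing.get_context("spawn")  # same start method as the original script
    with ctx.Pool(len(chunks), initializer=init_worker, initargs=(checkpoint_dir,)) as pool:
        results = pool.map(process_chunk, chunks)
    return [item for sublist in results for item in sublist]
```

With the `spawn` start method, `generate_all(...)` should be called from under an `if __name__ == "__main__":` guard, as the original script does with `main()`. Note that each worker then holds a full model in GPU memory, so the number of processes is bounded by memory rather than CPU count.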

My model is loaded from an SFT checkpoint produced by `litgpt finetune_full`. I'm not sure whether this error has the same root cause as the serve case above. If it does, maybe the problem is not in LitServe, since I did not use anything related to it this time.

puppyapple commented 1 month ago

Update: I upgraded LitGPT to the latest version (0.4.12), and all the errors above have disappeared for now: 25 workers handling 25 concurrent requests for large-scale data generation, with no CUDA errors after 30 minutes of running so far. I'm not sure which update(s) between 0.4.10 and 0.4.12 fixed this.
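Since the errors stopped after upgrading from 0.4.10 to 0.4.12, a deployment script could check the installed version before starting workers. A minimal sketch using only the standard library's `importlib.metadata`; the threshold `0.4.12` is simply the version the reporter upgraded to, not a version the maintainers identified as containing a fix, and the helper names are hypothetical.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional, Tuple


def version_tuple(v: str) -> Tuple[int, ...]:
    """Turn "0.4.12" into (0, 4, 12), stopping at the first part that is not purely numeric."""
    parts = []
    for piece in v.split("."):
        if not piece.isdigit():
            break
        parts.append(int(piece))
    return tuple(parts)


def installed_litgpt_version() -> Optional[str]:
    """Return the installed litgpt version string, or None if it is not installed."""
    try:
        return version("litgpt")
    except PackageNotFoundError:
        return None


def is_at_least(v: str, minimum: str = "0.4.12") -> bool:
    # 0.4.12 is the version the reporter upgraded to, not a confirmed fix version.
    return version_tuple(v) >= version_tuple(minimum)


if __name__ == "__main__":
    v = installed_litgpt_version()
    print(v, "OK" if v and is_at_least(v) else "consider upgrading")
```

For pre-release or local version strings, `packaging.version.Version` handles the full PEP 440 ordering and is the more robust choice.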

rasbt commented 1 month ago

Nice, glad to hear that this works fine now without requiring any additional fix! I'll close this issue as completed, but please let us know if any issues come up later.