minglong-huang opened 2 months ago
I found a similar open issue that addresses turning off asynchronous execution when running a local LLM and embeddings. The solution involves setting the `is_async` parameter to `False` in the `generate` method call of your LLM class and ensuring that the `embed_texts` method in your Embeddings class also handles the `is_async` parameter correctly.

Here is the modified code snippet for your `MyLLM` class:
```python
import asyncio
import traceback
import typing as t

from langchain_core.callbacks import Callbacks
from langchain_core.outputs import Generation, LLMResult
from langchain_core.prompt_values import PromptValue
from transformers import AutoModel, AutoTokenizer

from ragas.llms import BaseRagasLLM


class MyLLM(BaseRagasLLM):
    def __init__(self, llm_path):
        self.tokenizer = AutoTokenizer.from_pretrained(llm_path, trust_remote_code=True)
        self.base_llm = AutoModel.from_pretrained(llm_path, trust_remote_code=True)
        self.base_llm = self.base_llm.eval()

    @property
    def llm(self):
        return self.base_llm

    def get_llm_result(self, prompt):
        generations = []
        llm_output = {}
        token_total = 0
        content = prompt.to_string()
        print(content)
        text, history = self.base_llm.chat(self.tokenizer, content, history=[])
        print('*' * 15)
        print(f"Generated text: {text}")
        generations.append([Generation(text=text)])
        token_total += len(text)
        llm_output['token_total'] = token_total
        return LLMResult(generations=generations, llm_output=llm_output)

    def generate_text(
        self,
        prompt: PromptValue,
        n: int = 1,
        temperature: float = 1e-8,
        stop: t.Optional[t.List[str]] = None,
        callbacks: Callbacks = None,
    ) -> LLMResult:
        print('Running generate_text function...')
        return self.get_llm_result(prompt)

    async def agenerate_text(
        self,
        prompt: PromptValue,
        n: int = 1,
        temperature: float = 1e-8,
        stop: t.Optional[t.List[str]] = None,
        callbacks: Callbacks = None,
    ) -> LLMResult:
        generations = []
        llm_output = {}
        token_total = 0
        content = prompt.to_string()
        try:
            # Run the blocking chat() call in a thread-pool executor with a timeout.
            text, history = await asyncio.wait_for(
                asyncio.get_event_loop().run_in_executor(
                    None, self.base_llm.chat, self.tokenizer, content
                ),
                timeout=42,
            )
        except asyncio.TimeoutError:
            print("Operation timed out, please check the code or increase the timeout duration")
            text = ""
        except asyncio.CancelledError:
            print("Task was cancelled, please check the code")
            print(f"info = {traceback.format_exc()}")
            text = ""
        except Exception as e:
            print(f"An unknown error occurred: {e}")
            text = ""
        generations.append([Generation(text=text)])
        token_total += len(text)
        llm_output['token_total'] = token_total
        return LLMResult(generations=generations, llm_output=llm_output)

    async def generate(
        self,
        prompt: PromptValue,
        n: int = 1,
        temperature: t.Optional[float] = None,
        stop: t.Optional[t.List[str]] = None,
        callbacks: Callbacks = None,
        is_async: bool = True,
    ) -> LLMResult:
        if temperature is None:
            temperature = 1e-8
        if is_async:
            return await self.agenerate_text(prompt, n, temperature, stop, callbacks)
        return self.generate_text(prompt, n, temperature, stop, callbacks)
```
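For reference, here is a minimal smoke test of the dispatch (a sketch: the model path and prompt text are placeholders, and `StringPromptValue` is just one concrete `PromptValue` from `langchain_core`). Note that `generate` is itself declared `async def`, so it must still be awaited even with `is_async=False`; the flag only selects which helper does the actual work:

```python
import asyncio
from langchain_core.prompt_values import StringPromptValue

llm = MyLLM("path/to/chatglm")  # placeholder model path
prompt = StringPromptValue(text="What is retrieval-augmented generation?")

# is_async=False routes the call to the synchronous generate_text();
# generate() is still a coroutine, so it has to be driven by an event loop.
result = asyncio.run(llm.generate(prompt, is_async=False))
print(result.generations[0][0].text)
```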
For the `TestEmbedding` class, ensure you set `is_async=False` when calling the `embed_texts` method:
```python
import asyncio
import typing as t
from abc import ABC
from typing import List

from langchain_core.embeddings import Embeddings
from transformers import AutoTokenizer

from ragas.run_config import RunConfig, add_async_retry, add_retry


class TestEmbedding(Embeddings, ABC):
    run_config: RunConfig

    def __init__(self, model_path, batch_size: int = 16, max_length: int = 512):
        # NOTE: stored as self.model so it no longer shadows the embed_texts()
        # method below (the original assigned it to self.embed_texts, making
        # the method unreachable). encode_corpus()/encode_queries() assume a
        # FlagEmbedding-style model wrapper rather than a bare AutoTokenizer.
        self.model = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        self.batch_size = batch_size
        self.max_length = max_length

    def embed_text(self, text: str) -> List[float]:
        embs = self.embed_texts([text])
        return embs[0]

    def embed_texts(self, texts: List[str]) -> t.List[t.List[float]]:
        embed_documents_with_retry = add_retry(
            self.embed_documents, self.run_config
        )
        return embed_documents_with_retry(texts)

    async def aembed_text(self, text: str, is_async=True) -> List[float]:
        # Delegate to the async variant; the original awaited the synchronous
        # embed_texts() with an is_async kwarg it does not accept.
        embs = await self.aembed_texts([text], is_async=False)
        return embs[0]

    async def aembed_texts(
        self, texts: List[str], is_async: bool = True
    ) -> t.List[t.List[float]]:
        if is_async:
            # aembed_documents is inherited from the Embeddings base class.
            aembed_documents_with_retry = add_async_retry(
                self.aembed_documents, self.run_config
            )
            return await aembed_documents_with_retry(texts)
        else:
            # Synchronous path: run the blocking call in a thread-pool executor.
            loop = asyncio.get_event_loop()
            embed_documents_with_retry = add_retry(
                self.embed_documents, self.run_config
            )
            return await loop.run_in_executor(None, embed_documents_with_retry, texts)

    def set_run_config(self, run_config: RunConfig):
        self.run_config = run_config

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        return self.model.encode_corpus(texts, self.batch_size, self.max_length).tolist()

    def embed_query(self, text: str) -> List[float]:
        return self.model.encode_queries(text, self.batch_size, self.max_length).tolist()
```
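Analogously, a quick check of the embeddings side (again a sketch; the model path is a placeholder, and `set_run_config` must be called first so the retry wrappers have a `RunConfig` to read):

```python
import asyncio
from ragas.run_config import RunConfig

emb = TestEmbedding("path/to/embedding-model")  # placeholder model path
emb.set_run_config(RunConfig())

# Even with is_async=False the method returns an awaitable; the embedding
# work itself runs synchronously inside a thread-pool executor.
vectors = asyncio.run(emb.aembed_texts(["hello world"], is_async=False))
print(len(vectors), len(vectors[0]))
```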
By setting `is_async=False`, you ensure that the execution will be synchronous [1].

To continue talking to Dosu, mention @dosu.
Here is my code:
I have set `is_async=False`, but it still runs `async def agenerate_text(...)` and then reports an error.