vanna-ai / vanna

🤖 Chat with your SQL database 📊. Accurate Text-to-SQL Generation via LLMs using RAG 🔄.
https://vanna.ai/docs/
MIT License
12.1k stars 969 forks source link

Error in train() method in linux offline environment #599

Closed 1nterstellar-JD closed 1 month ago

1nterstellar-JD commented 3 months ago
from vanna.base import VannaBase
from vanna.chromadb import ChromaDB_VectorStore
from transformers import AutoModelForCausalLM, AutoTokenizer
from abc import ABC
from langchain.llms.base import LLM
from typing import Any, List, Mapping, Optional
from langchain.callbacks.manager import CallbackManagerForLLMRun
import os

# Expose only GPUs 1 and 3 to this process. This is set before the model is
# loaded below (presumably so CUDA picks it up on first use).
os.environ['CUDA_VISIBLE_DEVICES'] = "1,3"

# Device that tokenized inputs are moved to in Qwen._call.
device = "cuda"

# Local checkpoint directory; tokenizer and model are both loaded from disk.
model_dir = '/home/hello/py/public/Qwen1.5-14B-Chat'
tokenizer = AutoTokenizer.from_pretrained(
    model_dir, 
    trust_remote_code=True
    )
# device_map="auto" leaves weight placement across the visible GPUs to the library.
model = AutoModelForCausalLM.from_pretrained(
    model_dir, 
    device_map="auto", 
    trust_remote_code=True,
    )

class Qwen(LLM, ABC):
    """LangChain ``LLM`` wrapper around the locally loaded Qwen chat model.

    Text generation uses the module-level ``tokenizer`` and ``model`` objects.

    NOTE: ``max_token``, ``temperature``, ``top_p`` and ``history_len`` are
    only reported through ``_identifying_params``; ``_call`` does not pass
    them to ``model.generate``, which always runs with ``max_new_tokens=512``.
    """

    max_token: int = 30000
    temperature: float = 0.1
    # Annotated like the other fields: pydantic v2 rejects non-annotated
    # attributes on a model class.
    top_p: float = 0.9
    history_len: int = 5

    @property
    def _llm_type(self) -> str:
        """Return the identifier LangChain uses for this LLM type."""
        return "Qwen"

    @property
    def _history_len(self) -> int:
        """Return the configured history length."""
        return self.history_len

    def set_history_len(self, history_len: int = 10) -> None:
        """Set the history length (defaults to 10 when called without a value)."""
        self.history_len = history_len

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Generate a completion for ``prompt`` with the local model.

        Args:
            prompt: User text; it is wrapped in a fixed system/user chat
                template before generation.
            stop: Optional stop sequences. The returned text is cut at the
                earliest occurrence of any of them.
            run_manager: Accepted for interface compatibility; not used.
            **kwargs: Accepted for interface compatibility; not used.

        Returns:
            The decoded completion, without the prompt tokens.
        """
        messages = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": prompt},
        ]
        text = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
        model_inputs = tokenizer([text], return_tensors="pt").to(device)
        # Pass the full encoding so the attention mask travels with input_ids.
        generated_ids = model.generate(
            **model_inputs,
            max_new_tokens=512,
        )
        # generate() returns prompt + completion; keep only the new tokens.
        generated_ids = [
            output_ids[len(input_ids):]
            for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
        ]

        response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]

        if stop:
            # Honour stop sequences instead of silently ignoring them.
            positions = (response.find(token) for token in stop if token)
            cut = min((pos for pos in positions if pos != -1), default=len(response))
            response = response[:cut]
        return response

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {
            "max_token": self.max_token,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "history_len": self.history_len,
        }

# Last (prompt, answer) pair handled by MyCustomLLM.submit_prompt; None until
# the first call.
DEBUG_INFO = None


class MyCustomLLM(VannaBase):
    """Vanna LLM backend that sends prompts to the local ``Qwen`` wrapper."""

    def __init__(self, config=None):
        """Create the ``Qwen`` client; ``config`` is accepted but not used here."""
        self.client = Qwen()

    def system_message(self, message: str) -> Any:
        """Wrap ``message`` as a system-role chat message dict."""
        return {"role": "system", "content": message}

    def user_message(self, message: str) -> Any:
        """Wrap ``message`` as a user-role chat message dict."""
        return {"role": "user", "content": message}

    def assistant_message(self, message: str) -> Any:
        """Wrap ``message`` as an assistant-role chat message dict."""
        return {"role": "assistant", "content": message}

    def submit_prompt(self, prompt, **kwargs) -> str:
        """Send ``prompt`` to the Qwen client and return its text answer.

        Args:
            prompt: Either a plain string or a sequence of message dicts as
                built by ``system_message`` / ``user_message`` /
                ``assistant_message``.
            **kwargs: Accepted for interface compatibility; not used.

        Returns:
            The generated answer. The ``(prompt, answer)`` pair is also stored
            in the module-level ``DEBUG_INFO``.
        """
        global DEBUG_INFO

        # ``Qwen`` is a LangChain LLM that takes one prompt string and returns
        # a string. It has no DashScope-style ``call(messages=...)`` method or
        # ``.output.choices`` response, so flatten the messages first.
        if isinstance(prompt, str):
            text = prompt
        else:
            text = "\n\n".join(
                f"{message['role']}: {message['content']}" for message in prompt
            )

        answer = self.client.invoke(text)
        DEBUG_INFO = (prompt, answer)
        return answer

class MyVanna(ChromaDB_VectorStore, MyCustomLLM):
    """Vanna class pairing the ChromaDB vector store with ``MyCustomLLM``."""

    def __init__(self, config=None):
        # Initialise each base explicitly with the same config, vector store first.
        for base in (ChromaDB_VectorStore, MyCustomLLM):
            base.__init__(self, config=config)

# Build the combined vector-store + LLM object with no explicit config.
vn = MyVanna()

# Connect to the local MySQL "Chinook" database as root with an empty password.
vn.connect_to_mysql(
    host='localhost', 
    dbname='Chinook', 
    user='root', 
    password='',
    port=3306
    )

# Read column metadata for the whole server and build a training plan from it.
df_information_schema = vn.run_sql("SELECT * FROM INFORMATION_SCHEMA.COLUMNS")
plan = vn.get_training_plan_generic(df_information_schema)

# Train on the plan. Per the traceback in this report, training reaches
# ChromaDB's embedding function, which tries to download its ONNX model.
vn.train(plan=plan)

vn.train(plan=plan)

Describe the bug

ConnectError                              Traceback (most recent call last)
File ~/anaconda3/envs/jd_lc/lib/python3.9/site-packages/httpx/_transports/default.py:69, in map_httpcore_exceptions()
     68 try:
---

File ~/anaconda3/envs/jd_lc/lib/python3.9/site-packages/httpx/_transports/default.py:86, in map_httpcore_exceptions()
     83     raise
     85 message = str(exc)
---> 86 raise mapped_exc(message) from exc

ConnectError: [Errno 104] Connection reset by peer

Desktop (please complete the following information where):

zainhoda commented 3 months ago

Thanks. Is there more in the traceback that can map to something in the vanna package? Otherwise there's no way for us to know if this is to do with the package or one of the third party components.

1nterstellar-JD commented 3 months ago

> Thanks. Is there more in the traceback that can map to something in the vanna package? Otherwise there's no way for us to know if this is to do with the package or one of the third party components.

Thank you for your reply.

---------------------------------------------------------------------------
ConnectError                              Traceback (most recent call last)
File ~/anaconda3/envs/jd_lc/lib/python3.9/site-packages/httpx/_transports/default.py:69, in map_httpcore_exceptions()
     68 try:
---> 69     yield
     70 except Exception as exc:

File ~/anaconda3/envs/jd_lc/lib/python3.9/site-packages/httpx/_transports/default.py:233, in HTTPTransport.handle_request(self, request)
    232 with map_httpcore_exceptions():
--> 233     resp = self._pool.handle_request(req)
    235 assert isinstance(resp.stream, typing.Iterable)

File ~/anaconda3/envs/jd_lc/lib/python3.9/site-packages/httpcore/_sync/connection_pool.py:216, in ConnectionPool.handle_request(self, request)
    215     self._close_connections(closing)
--> 216     raise exc from None
    218 # Return the response. Note that in this case we still have to manage
    219 # the point at which the response is closed.

File ~/anaconda3/envs/jd_lc/lib/python3.9/site-packages/httpcore/_sync/connection_pool.py:196, in ConnectionPool.handle_request(self, request)
    194 try:
    195     # Send the request on the assigned connection.
--> 196     response = connection.handle_request(
    197         pool_request.request
    198     )
    199 except ConnectionNotAvailable:
    200     # In some cases a connection may initially be available to
    201     # handle a request, but then become unavailable.
    202     #
    203     # In this case we clear the connection and try again.

File ~/anaconda3/envs/jd_lc/lib/python3.9/site-packages/httpcore/_sync/connection.py:99, in HTTPConnection.handle_request(self, request)
     98     self._connect_failed = True
---> 99     raise exc
    101 return self._connection.handle_request(request)

File ~/anaconda3/envs/jd_lc/lib/python3.9/site-packages/httpcore/_sync/connection.py:76, in HTTPConnection.handle_request(self, request)
     75 if self._connection is None:
---> 76     stream = self._connect(request)
     78     ssl_object = stream.get_extra_info("ssl_object")

File ~/anaconda3/envs/jd_lc/lib/python3.9/site-packages/httpcore/_sync/connection.py:154, in HTTPConnection._connect(self, request)
    153 with Trace("start_tls", logger, request, kwargs) as trace:
--> 154     stream = stream.start_tls(**kwargs)
    155     trace.return_value = stream

File ~/anaconda3/envs/jd_lc/lib/python3.9/site-packages/httpcore/_backends/sync.py:168, in SyncStream.start_tls(self, ssl_context, server_hostname, timeout)
    167         self.close()
--> 168         raise exc
    169 return SyncStream(sock)

File ~/anaconda3/envs/jd_lc/lib/python3.9/contextlib.py:137, in _GeneratorContextManager.__exit__(self, typ, value, traceback)
    136 try:
--> 137     self.gen.throw(typ, value, traceback)
    138 except StopIteration as exc:
    139     # Suppress StopIteration *unless* it's the same exception that
    140     # was passed to throw().  This prevents a StopIteration
    141     # raised inside the "with" statement from being suppressed.

File ~/anaconda3/envs/jd_lc/lib/python3.9/site-packages/httpcore/_exceptions.py:14, in map_exceptions(map)
     13     if isinstance(exc, from_exc):
---> 14         raise to_exc(exc) from exc
     15 raise

ConnectError: [Errno 104] Connection reset by peer

The above exception was the direct cause of the following exception:

ConnectError                              Traceback (most recent call last)
Cell In[7], line 11
      1 ddl = '''
      2 CREATE TABLE `Album` (
      3     `AlbumId` INTEGER NOT NULL, 
   (...)
      8 )ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE utf8mb4_0900_ai_ci
      9 '''
---> 11 vn.train(ddl=ddl)

File ~/anaconda3/envs/jd_lc/lib/python3.9/site-packages/vanna/base/base.py:1775, in VannaBase.train(self, question, sql, ddl, documentation, plan)
   1773 if ddl:
   1774     print("Adding ddl:", ddl)
-> 1775     return self.add_ddl(ddl)
   1777 if plan:
   1778     for item in plan._plan:

File ~/anaconda3/envs/jd_lc/lib/python3.9/site-packages/vanna/chromadb/chromadb_vector.py:86, in ChromaDB_VectorStore.add_ddl(self, ddl, **kwargs)
     82 def add_ddl(self, ddl: str, **kwargs) -> str:
     83     id = deterministic_uuid(ddl) + "-ddl"
     84     self.ddl_collection.add(
     85         documents=ddl,
---> 86         embeddings=self.generate_embedding(ddl),
     87         ids=id,
     88     )
     89     return id

File ~/anaconda3/envs/jd_lc/lib/python3.9/site-packages/vanna/chromadb/chromadb_vector.py:60, in ChromaDB_VectorStore.generate_embedding(self, data, **kwargs)
     59 def generate_embedding(self, data: str, **kwargs) -> List[float]:
---> 60     embedding = self.embedding_function([data])
     61     if len(embedding) == 1:
     62         return embedding[0]

File ~/anaconda3/envs/jd_lc/lib/python3.9/site-packages/chromadb/api/types.py:211, in EmbeddingFunction.__init_subclass__.<locals>.__call__(self, input)
    210 def __call__(self: EmbeddingFunction[D], input: D) -> Embeddings:
--> 211     result = call(self, input)
    212     return validate_embeddings(maybe_cast_one_to_many_embedding(result))

File ~/anaconda3/envs/jd_lc/lib/python3.9/site-packages/chromadb/utils/embedding_functions/onnx_mini_lm_l6_v2.py:199, in ONNXMiniLM_L6_V2.__call__(self, input)
    197 def __call__(self, input: Documents) -> Embeddings:
    198     # Only download the model when it is actually used
--> 199     self._download_model_if_not_exists()
    200     return cast(Embeddings, self._forward(input).tolist())

File ~/anaconda3/envs/jd_lc/lib/python3.9/site-packages/chromadb/utils/embedding_functions/onnx_mini_lm_l6_v2.py:226, in ONNXMiniLM_L6_V2._download_model_if_not_exists(self)
    219 os.makedirs(self.DOWNLOAD_PATH, exist_ok=True)
    220 if not os.path.exists(
    221     os.path.join(self.DOWNLOAD_PATH, self.ARCHIVE_FILENAME)
    222 ) or not _verify_sha256(
    223     os.path.join(self.DOWNLOAD_PATH, self.ARCHIVE_FILENAME),
    224     self._MODEL_SHA256,
    225 ):
--> 226     self._download(
    227         url=self.MODEL_DOWNLOAD_URL,
    228         fname=os.path.join(self.DOWNLOAD_PATH, self.ARCHIVE_FILENAME),
    229     )
    230 with tarfile.open(
    231     name=os.path.join(self.DOWNLOAD_PATH, self.ARCHIVE_FILENAME),
    232     mode="r:gz",
    233 ) as tar:
    234     tar.extractall(path=self.DOWNLOAD_PATH)

File ~/anaconda3/envs/jd_lc/lib/python3.9/site-packages/tenacity/__init__.py:336, in BaseRetrying.wraps.<locals>.wrapped_f(*args, **kw)
    334 copy = self.copy()
    335 wrapped_f.statistics = copy.statistics  # type: ignore[attr-defined]
--> 336 return copy(f, *args, **kw)

File ~/anaconda3/envs/jd_lc/lib/python3.9/site-packages/tenacity/__init__.py:475, in Retrying.__call__(self, fn, *args, **kwargs)
    473 retry_state = RetryCallState(retry_object=self, fn=fn, args=args, kwargs=kwargs)
    474 while True:
--> 475     do = self.iter(retry_state=retry_state)
    476     if isinstance(do, DoAttempt):
    477         try:

File ~/anaconda3/envs/jd_lc/lib/python3.9/site-packages/tenacity/__init__.py:376, in BaseRetrying.iter(self, retry_state)
    374 result = None
    375 for action in self.iter_state.actions:
--> 376     result = action(retry_state)
    377 return result

File ~/anaconda3/envs/jd_lc/lib/python3.9/site-packages/tenacity/__init__.py:398, in BaseRetrying._post_retry_check_actions.<locals>.<lambda>(rs)
    396 def _post_retry_check_actions(self, retry_state: "RetryCallState") -> None:
    397     if not (self.iter_state.is_explicit_retry or self.iter_state.retry_run_result):
--> 398         self._add_action_func(lambda rs: rs.outcome.result())
    399         return
    401     if self.after is not None:

File ~/anaconda3/envs/jd_lc/lib/python3.9/concurrent/futures/_base.py:439, in Future.result(self, timeout)
    437     raise CancelledError()
    438 elif self._state == FINISHED:
--> 439     return self.__get_result()
    441 self._condition.wait(timeout)
    443 if self._state in [CANCELLED, CANCELLED_AND_NOTIFIED]:

File ~/anaconda3/envs/jd_lc/lib/python3.9/concurrent/futures/_base.py:391, in Future.__get_result(self)
    389 if self._exception:
    390     try:
--> 391         raise self._exception
    392     finally:
    393         # Break a reference cycle with the exception in self._exception
    394         self = None

File ~/anaconda3/envs/jd_lc/lib/python3.9/site-packages/tenacity/__init__.py:478, in Retrying.__call__(self, fn, *args, **kwargs)
    476 if isinstance(do, DoAttempt):
    477     try:
--> 478         result = fn(*args, **kwargs)
    479     except BaseException:  # noqa: B902
    480         retry_state.set_exception(sys.exc_info())  # type: ignore[arg-type]

File ~/anaconda3/envs/jd_lc/lib/python3.9/site-packages/chromadb/utils/embedding_functions/onnx_mini_lm_l6_v2.py:100, in ONNXMiniLM_L6_V2._download(self, url, fname, chunk_size)
     85 @retry(  # type: ignore
     86     reraise=True,
     87     stop=stop_after_attempt(3),
   (...)
     90 )
     91 def _download(self, url: str, fname: str, chunk_size: int = 1024) -> None:
     92     """
     93     Download the onnx model from the URL and save it to the file path.
     94 
   (...)
     98     elegant way, please do so.
     99     """
--> 100     with httpx.stream("GET", url) as resp:
    101         total = int(resp.headers.get("content-length", 0))
    102         with open(fname, "wb") as file, self.tqdm(
    103             desc=str(fname),
    104             total=total,
   (...)
    107             unit_divisor=1024,
    108         ) as bar:

File ~/anaconda3/envs/jd_lc/lib/python3.9/contextlib.py:119, in _GeneratorContextManager.__enter__(self)
    117 del self.args, self.kwds, self.func
    118 try:
--> 119     return next(self.gen)
    120 except StopIteration:
    121     raise RuntimeError("generator didn't yield") from None

File ~/anaconda3/envs/jd_lc/lib/python3.9/site-packages/httpx/_api.py:160, in stream(method, url, params, content, data, files, json, headers, cookies, auth, proxy, proxies, timeout, follow_redirects, verify, cert, trust_env)
    141 """
    142 Alternative to `httpx.request()` that streams the response body
    143 instead of loading it into memory at once.
   (...)
    149 [0]: /quickstart#streaming-responses
    150 """
    151 with Client(
    152     cookies=cookies,
    153     proxy=proxy,
   (...)
    158     trust_env=trust_env,
    159 ) as client:
--> 160     with client.stream(
    161         method=method,
    162         url=url,
    163         content=content,
    164         data=data,
    165         files=files,
    166         json=json,
    167         params=params,
    168         headers=headers,
    169         auth=auth,
    170         follow_redirects=follow_redirects,
    171     ) as response:
    172         yield response

File ~/anaconda3/envs/jd_lc/lib/python3.9/contextlib.py:119, in _GeneratorContextManager.__enter__(self)
    117 del self.args, self.kwds, self.func
    118 try:
--> 119     return next(self.gen)
    120 except StopIteration:
    121     raise RuntimeError("generator didn't yield") from None

File ~/anaconda3/envs/jd_lc/lib/python3.9/site-packages/httpx/_client.py:870, in Client.stream(self, method, url, content, data, files, json, params, headers, cookies, auth, follow_redirects, timeout, extensions)
    847 """
    848 Alternative to `httpx.request()` that streams the response body
    849 instead of loading it into memory at once.
   (...)
    855 [0]: /quickstart#streaming-responses
    856 """
    857 request = self.build_request(
    858     method=method,
    859     url=url,
   (...)
    868     extensions=extensions,
    869 )
--> 870 response = self.send(
    871     request=request,
    872     auth=auth,
    873     follow_redirects=follow_redirects,
    874     stream=True,
    875 )
    876 try:
    877     yield response

File ~/anaconda3/envs/jd_lc/lib/python3.9/site-packages/httpx/_client.py:914, in Client.send(self, request, stream, auth, follow_redirects)
    906 follow_redirects = (
    907     self.follow_redirects
    908     if isinstance(follow_redirects, UseClientDefault)
    909     else follow_redirects
    910 )
    912 auth = self._build_request_auth(request, auth)
--> 914 response = self._send_handling_auth(
    915     request,
    916     auth=auth,
    917     follow_redirects=follow_redirects,
    918     history=[],
    919 )
    920 try:
    921     if not stream:

File ~/anaconda3/envs/jd_lc/lib/python3.9/site-packages/httpx/_client.py:942, in Client._send_handling_auth(self, request, auth, follow_redirects, history)
    939 request = next(auth_flow)
    941 while True:
--> 942     response = self._send_handling_redirects(
    943         request,
    944         follow_redirects=follow_redirects,
    945         history=history,
    946     )
    947     try:
    948         try:

File ~/anaconda3/envs/jd_lc/lib/python3.9/site-packages/httpx/_client.py:979, in Client._send_handling_redirects(self, request, follow_redirects, history)
    976 for hook in self._event_hooks["request"]:
    977     hook(request)
--> 979 response = self._send_single_request(request)
    980 try:
    981     for hook in self._event_hooks["response"]:

File ~/anaconda3/envs/jd_lc/lib/python3.9/site-packages/httpx/_client.py:1015, in Client._send_single_request(self, request)
   1010     raise RuntimeError(
   1011         "Attempted to send an async request with a sync Client instance."
   1012     )
   1014 with request_context(request=request):
-> 1015     response = transport.handle_request(request)
   1017 assert isinstance(response.stream, SyncByteStream)
   1019 response.request = request

File ~/anaconda3/envs/jd_lc/lib/python3.9/site-packages/httpx/_transports/default.py:233, in HTTPTransport.handle_request(self, request)
    220 req = httpcore.Request(
    221     method=request.method,
    222     url=httpcore.URL(
   (...)
    230     extensions=request.extensions,
    231 )
    232 with map_httpcore_exceptions():
--> 233     resp = self._pool.handle_request(req)
    235 assert isinstance(resp.stream, typing.Iterable)
    237 return Response(
    238     status_code=resp.status,
    239     headers=resp.headers,
    240     stream=ResponseStream(resp.stream),
    241     extensions=resp.extensions,
    242 )

File ~/anaconda3/envs/jd_lc/lib/python3.9/contextlib.py:137, in _GeneratorContextManager.__exit__(self, typ, value, traceback)
    135     value = typ()
    136 try:
--> 137     self.gen.throw(typ, value, traceback)
    138 except StopIteration as exc:
    139     # Suppress StopIteration *unless* it's the same exception that
    140     # was passed to throw().  This prevents a StopIteration
    141     # raised inside the "with" statement from being suppressed.
    142     return exc is not value

File ~/anaconda3/envs/jd_lc/lib/python3.9/site-packages/httpx/_transports/default.py:86, in map_httpcore_exceptions()
     83     raise
     85 message = str(exc)
---> 86 raise mapped_exc(message) from exc

ConnectError: [Errno 104] Connection reset by peer
zainhoda commented 1 month ago

Chroma needs an embedding model downloaded in order for it to function. Since you're offline, it can't download the embedding model.