Eladlev / AutoPrompt

A framework for prompt tuning using Intent-based Prompt Calibration
Apache License 2.0

Is there support for local LLM? #58

Open 17Reset opened 2 months ago

17Reset commented 2 months ago

Is there any plan to support local offline models?

Eladlev commented 2 months ago

Local models are supported via langchain HuggingFacePipeline: https://github.com/Eladlev/AutoPrompt/issues/40#issuecomment-2016365671

17Reset commented 2 months ago

The configuration file has no field for a path to a local LLM. How do I set this up correctly?

Eladlev commented 2 months ago

See here for more information on the LangChain HuggingFacePipeline: https://python.langchain.com/docs/integrations/llms/huggingface_pipelines/

Put the model ID in the config file as the 'name'. See here: https://github.com/Eladlev/AutoPrompt/issues/40#issuecomment-2016365671

This is the list of supported models (and their IDs): https://huggingface.co/models
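
To illustrate how a model ID in the 'name' field might reach the LangChain loader, here is a minimal sketch. The config keys other than 'name' (`type`, `max_new_tokens`) and the helper names `build_pipeline_kwargs` and `load_llm` are assumptions for illustration, not AutoPrompt's confirmed code. The import path for `HuggingFacePipeline` also varies between LangChain versions.

```python
from typing import Any, Dict

# Hypothetical LLM config block; every key except 'name' is an assumption.
llm_config: Dict[str, Any] = {
    "type": "HuggingFacePipeline",
    "name": "openai-community/gpt2",  # Hugging Face model ID
    "max_new_tokens": 50,
}


def build_pipeline_kwargs(config: Dict[str, Any]) -> Dict[str, Any]:
    """Translate the config block into keyword arguments for from_model_id."""
    if config.get("type") != "HuggingFacePipeline":
        raise ValueError(f"unsupported LLM type: {config.get('type')!r}")
    return {
        "model_id": config["name"],
        "task": "text-generation",
        "pipeline_kwargs": {"max_new_tokens": config.get("max_new_tokens", 50)},
    }


def load_llm(config: Dict[str, Any]):
    """Build the LangChain LLM. Imported lazily so the helper above needs no extras."""
    # Import path may differ by LangChain version (e.g. langchain_huggingface).
    from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline

    return HuggingFacePipeline.from_model_id(**build_pipeline_kwargs(config))
```

Calling `load_llm(llm_config)` would download the model on first use, so it needs the `langchain-community` and `transformers` packages plus network access for that first run.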

17Reset commented 2 months ago

Sorry, I meant operating completely offline. The link you provided above describes running a model hosted on Hugging Face locally. What I'm trying to understand is how to go offline first and then run an LLM that already exists locally on my own server.

Eladlev commented 2 months ago

The Hugging Face pipeline downloads the model once and then uses the locally stored copy. You can also download the model manually by going to the relevant model card and downloading all the files (for example, these are the files for GPT-2): https://huggingface.co/openai-community/gpt2/tree/main

Then set the model name to the path of the folder that contains all the model files.
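
For a fully offline machine, the manual-download route could be sketched as follows. `HF_HUB_OFFLINE` and `TRANSFORMERS_OFFLINE` are the environment variables the Hugging Face libraries read to skip network calls. The helper names and the `config.json` sanity check are assumptions for illustration only.

```python
import os
from pathlib import Path


def enable_offline_mode() -> None:
    """Tell the Hugging Face libraries not to attempt any network access."""
    # Set these before importing transformers or huggingface_hub.
    os.environ["HF_HUB_OFFLINE"] = "1"
    os.environ["TRANSFORMERS_OFFLINE"] = "1"


def resolve_local_model(folder: str) -> str:
    """Return the absolute path of a local model folder, or raise if it looks incomplete."""
    path = Path(folder).expanduser().resolve()
    # A downloaded model folder is expected to contain config.json.
    if not (path / "config.json").is_file():
        raise FileNotFoundError(f"no config.json found in {path}")
    return str(path)
```

The string returned by `resolve_local_model` would then go in the config's 'name' field in place of a hub model ID.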

17Reset commented 2 months ago

I got an error:

/mnt/Agent/AutoPrompt/myenv/lib/python3.10/site-packages/transformers/utils/generic.py:441: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
  _torch_pytree._register_pytree_node(
Describe the task: Assistant is a large language model that is tasked with writing reporting.
Initial prompt: Create an annual performance report for the Human Resources department, which is divided into six sections, each providing a comprehensive review of the department's key operations during the year.
/mnt/Agent/AutoPrompt/myenv/lib/python3.10/site-packages/transformers/utils/generic.py:309: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
  _torch_pytree._register_pytree_node(
╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ /mnt/Agent/AutoPrompt/myenv/lib/python3.10/site-packages/httpx/_transports/default.py:6 │
│ 7 in map_httpcore_exceptions                                                                     │
│                                                                                                  │
│    64 @contextlib.contextmanager                                                                 │
│    65 def map_httpcore_exceptions() -> typing.Iterator[None]:                                    │
│    66 │   try:                                                                                   │
│ ❱  67 │   │   yield                                                                              │
│    68 │   except Exception as exc:                                                               │
│    69 │   │   mapped_exc = None                                                                  │
│    70                                                                                            │
│                                                                                                  │
│ /mnt/Agent/AutoPrompt/myenv/lib/python3.10/site-packages/httpx/_transports/default.py:2 │
│ 31 in handle_request                                                                             │
│                                                                                                  │
│   228 │   │   │   extensions=request.extensions,                                                 │
│   229 │   │   )                                                                                  │
│   230 │   │   with map_httpcore_exceptions():                                                    │
│ ❱ 231 │   │   │   resp = self._pool.handle_request(req)                                          │
│   232 │   │                                                                                      │
│   233 │   │   assert isinstance(resp.stream, typing.Iterable)                                    │
│   234                                                                                            │
│                                                                                                  │
│ /mnt/Agent/AutoPrompt/myenv/lib/python3.10/site-packages/httpcore/_sync/connection_pool │
│ .py:216 in handle_request                                                                        │
│                                                                                                  │
│   213 │   │   │   │   closing = self._assign_requests_to_connections()                           │
│   214 │   │   │                                                                                  │
│   215 │   │   │   self._close_connections(closing)                                               │
│ ❱ 216 │   │   │   raise exc from None                                                            │
│   217 │   │                                                                                      │
│   218 │   │   # Return the response. Note that in this case we still have to manage              │
│   219 │   │   # the point at which the response is closed.                                       │
│                                                                                                  │
│ /mnt/Agent/AutoPrompt/myenv/lib/python3.10/site-packages/httpcore/_sync/connection_pool │
│ .py:196 in handle_request                                                                        │
│                                                                                                  │
│   193 │   │   │   │                                                                              │
│   194 │   │   │   │   try:                                                                       │
│   195 │   │   │   │   │   # Send the request on the assigned connection.                         │
│ ❱ 196 │   │   │   │   │   response = connection.handle_request(                                  │
│   197 │   │   │   │   │   │   pool_request.request                                               │
│   198 │   │   │   │   │   )                                                                      │
│   199 │   │   │   │   except ConnectionNotAvailable:                                             │
│                                                                                                  │
│ /mnt/Agent/AutoPrompt/myenv/lib/python3.10/site-packages/httpcore/_sync/connection.py:9 │
│ 9 in handle_request                                                                              │
│                                                                                                  │
│    96 │   │   │   │   │   │   )                                                                  │
│    97 │   │   except BaseException as exc:                                                       │
│    98 │   │   │   self._connect_failed = True                                                    │
│ ❱  99 │   │   │   raise exc                                                                      │
│   100 │   │                                                                                      │
│   101 │   │   return self._connection.handle_request(request)                                    │
│   102                                                                                            │
│                                                                                                  │
│ /mnt/Agent/AutoPrompt/myenv/lib/python3.10/site-packages/httpcore/_sync/connection.py:7 │
│ 6 in handle_request                                                                              │
│                                                                                                  │
│    73 │   │   try:                                                                               │
│    74 │   │   │   with self._request_lock:                                                       │
│    75 │   │   │   │   if self._connection is None:                                               │
│ ❱  76 │   │   │   │   │   stream = self._connect(request)                                        │
│    77 │   │   │   │   │                                                                          │
│    78 │   │   │   │   │   ssl_object = stream.get_extra_info("ssl_object")                       │
│    79 │   │   │   │   │   http2_negotiated = (                                                   │
│                                                                                                  │
│ /mnt/Agent/AutoPrompt/myenv/lib/python3.10/site-packages/httpcore/_sync/connection.py:1 │
│ 22 in _connect                                                                                   │
│                                                                                                  │
│   119 │   │   │   │   │   │   "socket_options": self._socket_options,                            │
│   120 │   │   │   │   │   }                                                                      │
│   121 │   │   │   │   │   with Trace("connect_tcp", logger, request, kwargs) as trace:           │
│ ❱ 122 │   │   │   │   │   │   stream = self._network_backend.connect_tcp(**kwargs)               │
│   123 │   │   │   │   │   │   trace.return_value = stream                                        │
│   124 │   │   │   │   else:                                                                      │
│   125 │   │   │   │   │   kwargs = {                                                             │
│                                                                                                  │
│ /mnt/Agent/AutoPrompt/myenv/lib/python3.10/site-packages/httpcore/_backends/sync.py:205 │
│ in connect_tcp                                                                                   │
│                                                                                                  │
│   202 │   │   │   OSError: ConnectError,                                                         │
│   203 │   │   }                                                                                  │
│   204 │   │                                                                                      │
│ ❱ 205 │   │   with map_exceptions(exc_map):                                                      │
│   206 │   │   │   sock = socket.create_connection(                                               │
│   207 │   │   │   │   address,                                                                   │
│   208 │   │   │   │   timeout,                                                                   │
│                                                                                                  │
│ /usr/lib/python3.10/contextlib.py:153 in __exit__                                                │
│                                                                                                  │
│   150 │   │   │   │   # tell if we get the same exception back                                   │
│   151 │   │   │   │   value = typ()                                                              │
│   152 │   │   │   try:                                                                           │
│ ❱ 153 │   │   │   │   self.gen.throw(typ, value, traceback)                                      │
│   154 │   │   │   except StopIteration as exc:                                                   │
│   155 │   │   │   │   # Suppress StopIteration *unless* it's the same exception that             │
│   156 │   │   │   │   # was passed to throw().  This prevents a StopIteration                    │
│                                                                                                  │
│ /mnt/Agent/AutoPrompt/myenv/lib/python3.10/site-packages/httpcore/_exceptions.py:14 in  │
│ map_exceptions                                                                                   │
│                                                                                                  │
│   11 │   except Exception as exc:  # noqa: PIE786                                                │
│   12 │   │   for from_exc, to_exc in map.items():                                                │
│   13 │   │   │   if isinstance(exc, from_exc):                                                   │
│ ❱ 14 │   │   │   │   raise to_exc(exc) from exc                                                  │
│   15 │   │   raise  # pragma: nocover                                                            │
│   16                                                                                             │
│   17                                                                                             │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
ConnectError: [Errno 111] Connection refused

The above exception was the direct cause of the following exception:

╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ /mnt/Agent/AutoPrompt/myenv/lib/python3.10/site-packages/argilla/client/sdk/client.py:1 │
│ 24 in inner                                                                                      │
│                                                                                                  │
│   121 │   │   @functools.wraps(func)                                                             │
│   122 │   │   def inner(self, *args, **kwargs):                                                  │
│   123 │   │   │   try:                                                                           │
│ ❱ 124 │   │   │   │   result = func(self, *args, **kwargs)                                       │
│   125 │   │   │   │   return result                                                              │
│   126 │   │   │   except httpx.ConnectError as err:                                              │
│   127 │   │   │   │   err_str = f"Your Api endpoint at {self.base_url} is not available or not   │
│                                                                                                  │
│ /mnt/Agent/AutoPrompt/myenv/lib/python3.10/site-packages/argilla/client/sdk/client.py:1 │
│ 35 in get                                                                                        │
│                                                                                                  │
│   132 │   @with_httpx_error_handler                                                              │
│   133 │   def get(self, path: str, *args, **kwargs):                                             │
│   134 │   │   path = self._normalize_path(path)                                                  │
│ ❱ 135 │   │   response = self.__httpx__.get(                                                     │
│   136 │   │   │   url=path,                                                                      │
│   137 │   │   │   headers=self.get_headers(),                                                    │
│   138 │   │   │   *args,                                                                         │
│                                                                                                  │
│ /mnt/Agent/AutoPrompt/myenv/lib/python3.10/site-packages/httpx/_client.py:1055 in get   │
│                                                                                                  │
│   1052 │   │                                                                                     │
│   1053 │   │   **Parameters**: See `httpx.request`.                                              │
│   1054 │   │   """                                                                               │
│ ❱ 1055 │   │   return self.request(                                                              │
│   1056 │   │   │   "GET",                                                                        │
│   1057 │   │   │   url,                                                                          │
│   1058 │   │   │   params=params,                                                                │
│                                                                                                  │
│ /mnt/Agent/AutoPrompt/myenv/lib/python3.10/site-packages/httpx/_client.py:828 in        │
│ request                                                                                          │
│                                                                                                  │
│    825 │   │   │   timeout=timeout,                                                              │
│    826 │   │   │   extensions=extensions,                                                        │
│    827 │   │   )                                                                                 │
│ ❱  828 │   │   return self.send(request, auth=auth, follow_redirects=follow_redirects)           │
│    829 │                                                                                         │
│    830 │   @contextmanager                                                                       │
│    831 │   def stream(                                                                           │
│                                                                                                  │
│ /mnt/Agent/AutoPrompt/myenv/lib/python3.10/site-packages/httpx/_client.py:915 in send   │
│                                                                                                  │
│    912 │   │                                                                                     │
│    913 │   │   auth = self._build_request_auth(request, auth)                                    │
│    914 │   │                                                                                     │
│ ❱  915 │   │   response = self._send_handling_auth(                                              │
│    916 │   │   │   request,                                                                      │
│    917 │   │   │   auth=auth,                                                                    │
│    918 │   │   │   follow_redirects=follow_redirects,                                            │
│                                                                                                  │
│ /mnt/Agent/AutoPrompt/myenv/lib/python3.10/site-packages/httpx/_client.py:943 in        │
│ _send_handling_auth                                                                              │
│                                                                                                  │
│    940 │   │   │   request = next(auth_flow)                                                     │
│    941 │   │   │                                                                                 │
│    942 │   │   │   while True:                                                                   │
│ ❱  943 │   │   │   │   response = self._send_handling_redirects(                                 │
│    944 │   │   │   │   │   request,                                                              │
│    945 │   │   │   │   │   follow_redirects=follow_redirects,                                    │
│    946 │   │   │   │   │   history=history,                                                      │
│                                                                                                  │
│ /mnt/Agent/AutoPrompt/myenv/lib/python3.10/site-packages/httpx/_client.py:980 in        │
│ _send_handling_redirects                                                                         │
│                                                                                                  │
│    977 │   │   │   for hook in self._event_hooks["request"]:                                     │
│    978 │   │   │   │   hook(request)                                                             │
│    979 │   │   │                                                                                 │
│ ❱  980 │   │   │   response = self._send_single_request(request)                                 │
│    981 │   │   │   try:                                                                          │
│    982 │   │   │   │   for hook in self._event_hooks["response"]:                                │
│    983 │   │   │   │   │   hook(response)                                                        │
│                                                                                                  │
│ /mnt/Agent/AutoPrompt/myenv/lib/python3.10/site-packages/httpx/_client.py:1016 in       │
│ _send_single_request                                                                             │
│                                                                                                  │
│   1013 │   │   │   )                                                                             │
│   1014 │   │                                                                                     │
│   1015 │   │   with request_context(request=request):                                            │
│ ❱ 1016 │   │   │   response = transport.handle_request(request)                                  │
│   1017 │   │                                                                                     │
│   1018 │   │   assert isinstance(response.stream, SyncByteStream)                                │
│   1019                                                                                           │
│                                                                                                  │
│ /mnt/Agent/AutoPrompt/myenv/lib/python3.10/site-packages/httpx/_transports/default.py:2 │
│ 30 in handle_request                                                                             │
│                                                                                                  │
│   227 │   │   │   content=request.stream,                                                        │
│   228 │   │   │   extensions=request.extensions,                                                 │
│   229 │   │   )                                                                                  │
│ ❱ 230 │   │   with map_httpcore_exceptions():                                                    │
│   231 │   │   │   resp = self._pool.handle_request(req)                                          │
│   232 │   │                                                                                      │
│   233 │   │   assert isinstance(resp.stream, typing.Iterable)                                    │
│                                                                                                  │
│ /usr/lib/python3.10/contextlib.py:153 in __exit__                                                │
│                                                                                                  │
│   150 │   │   │   │   # tell if we get the same exception back                                   │
│   151 │   │   │   │   value = typ()                                                              │
│   152 │   │   │   try:                                                                           │
│ ❱ 153 │   │   │   │   self.gen.throw(typ, value, traceback)                                      │
│   154 │   │   │   except StopIteration as exc:                                                   │
│   155 │   │   │   │   # Suppress StopIteration *unless* it's the same exception that             │
│   156 │   │   │   │   # was passed to throw().  This prevents a StopIteration                    │
│                                                                                                  │
│ /mnt/Agent/AutoPrompt/myenv/lib/python3.10/site-packages/httpx/_transports/default.py:8 │
│ 4 in map_httpcore_exceptions                                                                     │
│                                                                                                  │
│    81 │   │   │   raise                                                                          │
│    82 │   │                                                                                      │
│    83 │   │   message = str(exc)                                                                 │
│ ❱  84 │   │   raise mapped_exc(message) from exc                                                 │
│    85                                                                                            │
│    86                                                                                            │
│    87 HTTPCORE_EXC_MAP = {                                                                       │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
ConnectError: [Errno 111] Connection refused

The above exception was the direct cause of the following exception:

╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ /mnt/Agent/AutoPrompt/estimator/estimator_argilla.py:22 in __init__                     │
│                                                                                                  │
│    19 │   │   """                                                                                │
│    20 │   │   try:                                                                               │
│    21 │   │   │   self.opt = opt                                                                 │
│ ❱  22 │   │   │   rg.init(                                                                       │
│    23 │   │   │   │   api_url=opt.api_url,                                                       │
│    24 │   │   │   │   api_key=opt.api_key,                                                       │
│    25 │   │   │   │   workspace=opt.workspace                                                    │
│                                                                                                  │
│ /mnt/Agent/AutoPrompt/myenv/lib/python3.10/site-packages/argilla/client/singleton.py:95 │
│ in init                                                                                          │
│                                                                                                  │
│    92 │   │   >>> headers = {"X-Client-id":"id","X-Secret":"secret"}                             │
│    93 │   │   >>> rg.init(api_url="http://localhost:9090", api_key="4AkeAPIk3Y", extra_headers   │
│    94 │   """                                                                                    │
│ ❱  95 │   ArgillaSingleton.init(                                                                 │
│    96 │   │   api_url=api_url,                                                                   │
│    97 │   │   api_key=api_key,                                                                   │
│    98 │   │   workspace=workspace,                                                               │
│                                                                                                  │
│ /mnt/Agent/AutoPrompt/myenv/lib/python3.10/site-packages/argilla/client/singleton.py:47 │
│ in init                                                                                          │
│                                                                                                  │
│    44 │   ) -> Argilla:                                                                          │
│    45 │   │   cls._INSTANCE = None                                                               │
│    46 │   │                                                                                      │
│ ❱  47 │   │   cls._INSTANCE = Argilla(                                                           │
│    48 │   │   │   api_url=api_url,                                                               │
│    49 │   │   │   api_key=api_key,                                                               │
│    50 │   │   │   timeout=timeout,                                                               │
│                                                                                                  │
│ /mnt/Agent/AutoPrompt/myenv/lib/python3.10/site-packages/argilla/client/client.py:164   │
│ in __init__                                                                                      │
│                                                                                                  │
│   161 │   │   │   httpx_extra_kwargs=httpx_extra_kwargs,                                         │
│   162 │   │   )                                                                                  │
│   163 │   │                                                                                      │
│ ❱ 164 │   │   self._user = users_api.whoami(client=self.http_client)  # .parsed                  │
│   165 │   │                                                                                      │
│   166 │   │   if not workspace and self._user.username == DEFAULT_USERNAME and DEFAULT_USERNAM   │
│   167 │   │   │   warnings.warn(                                                                 │
│                                                                                                  │
│ /mnt/Agent/AutoPrompt/myenv/lib/python3.10/site-packages/argilla/client/sdk/users/api.p │
│ y:39 in whoami                                                                                   │
│                                                                                                  │
│    36 │   """                                                                                    │
│    37 │   url = "/api/me"                                                                        │
│    38 │                                                                                          │
│ ❱  39 │   response = client.get(url)                                                             │
│    40 │   return UserModel(**response)                                                           │
│    41                                                                                            │
│    42                                                                                            │
│                                                                                                  │
│ /mnt/Agent/AutoPrompt/myenv/lib/python3.10/site-packages/argilla/client/sdk/client.py:1 │
│ 28 in inner                                                                                      │
│                                                                                                  │
│   125 │   │   │   │   return result                                                              │
│   126 │   │   │   except httpx.ConnectError as err:                                              │
│   127 │   │   │   │   err_str = f"Your Api endpoint at {self.base_url} is not available or not   │
│ ❱ 128 │   │   │   │   raise BaseClientError(err_str) from err                                    │
│   129 │   │                                                                                      │
│   130 │   │   return inner                                                                       │
│   131                                                                                            │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
BaseClientError: Your Api endpoint at http://localhost:6900 is not available or not responding: [Errno 111] Connection refused

During handling of the above exception, another exception occurred:

╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ /mnt/Agent/AutoPrompt/run_pipeline.py:41 in <module>                                    │
│                                                                                                  │
│   38 │   initial_prompt = opt.prompt                                                             │
│   39                                                                                             │
│   40 # Initializing the pipeline                                                                 │
│ ❱ 41 pipeline = OptimizationPipeline(config_params, task_description, initial_prompt, output_    │
│   42 if (opt.load_path != ''):                                                                   │
│   43 │   pipeline.load_state(opt.load_path)                                                      │
│   44 best_prompt = pipeline.run_pipeline(opt.num_steps)                                          │
│                                                                                                  │
│ /mnt/Agent/AutoPrompt/optimization_pipeline.py:57 in __init__                           │
│                                                                                                  │
│    54 │   │   self.cur_prompt = initial_prompt                                                   │
│    55 │   │                                                                                      │
│    56 │   │   self.predictor = give_estimator(config.predictor)                                  │
│ ❱  57 │   │   self.annotator = give_estimator(config.annotator)                                  │
│    58 │   │   self.eval = Eval(config.eval, self.meta_chain.error_analysis, self.dataset.label   │
│    59 │   │   self.batch_id = 0                                                                  │
│    60 │   │   self.patient = 0                                                                   │
│                                                                                                  │
│ /mnt/Agent/AutoPrompt/estimator/__init__.py:31 in give_estimator                        │
│                                                                                                  │
│   28                                                                                             │
│   29 def give_estimator(opt):                                                                    │
│   30 │   if opt.method == 'argilla':                                                             │
│ ❱ 31 │   │   return ArgillaEstimator(opt.config)                                                 │
│   32 │   elif opt.method == 'llm':                                                               │
│   33 │   │   return LLMEstimator(opt.config)                                                     │
│   34 │   elif opt.method == 'llm_batch':                                                         │
│                                                                                                  │
│ /mnt/Agent/AutoPrompt/estimator/estimator_argilla.py:29 in __init__                     │
│                                                                                                  │
│    26 │   │   │   )                                                                              │
│    27 │   │   │   self.time_interval = opt.time_interval                                         │
│    28 │   │   except:                                                                            │
│ ❱  29 │   │   │   raise Exception("Failed to connect to argilla, check connection details")      │
│    30 │                                                                                          │
│    31 │   @staticmethod                                                                          │
│    32 │   def initialize_dataset(dataset_name: str, label_schema: set[str]):                     │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
Exception: Failed to connect to argilla, check connection details

Eladlev commented 2 months ago

This is an Argilla connection error. The pipeline is trying to connect to an Argilla server. If you want to work offline, you should run the Argilla server locally (or use the LLM annotator option instead).
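A quick way to confirm whether this is the problem is to check that something is actually listening at the Argilla address before starting the pipeline. The snippet below is only a minimal sketch using the Python standard library; `endpoint_reachable` is a hypothetical helper (not part of AutoPrompt or Argilla), and the URL is the one reported in the traceback above.

```python
import socket
from urllib.parse import urlparse


def endpoint_reachable(api_url: str, timeout: float = 2.0) -> bool:
    """Return True if a TCP connection to the host/port in api_url succeeds.

    Hypothetical helper for illustration; it only checks that a server is
    listening, not that it is a healthy Argilla instance.
    """
    parsed = urlparse(api_url)
    host = parsed.hostname or "localhost"
    # Fall back to the scheme's default port when the URL does not name one.
    port = parsed.port or (443 if parsed.scheme == "https" else 80)
    try:
        with socket.create_connection((host, port), timeout=timeout):
            return True
    except OSError:
        # Covers "Connection refused" (Errno 111), timeouts, DNS failures.
        return False


if __name__ == "__main__":
    url = "http://localhost:6900"  # address taken from the error message
    status = "reachable" if endpoint_reachable(url) else "not reachable"
    print(f"Argilla endpoint {url} is {status}")
```

If this reports the endpoint as not reachable, either start an Argilla server at that address or switch the annotator to the LLM option.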

17Reset commented 2 months ago

Could you tell me how to use the GPU when running offline? Does the config file provide a GPU entry?

2024-04-26 10:38:41,707 - WARNING - Device has 2 GPUs available. Provide device={deviceId} to `from_model_id` to use availableGPUs for execution. deviceId is -1 (default) for CPU and can be a positive integer associated with CUDA device id.
2024-04-26 10:41:37,745 - WARNING - Device has 2 GPUs available. Provide device={deviceId} to `from_model_id` to use availableGPUs for execution. deviceId is -1 (default) for CPU and can be a positive integer associated with CUDA device id.

17Reset commented 2 months ago

Here are the changes I made to config: config_default.yml

use_wandb: False
dataset:
    name: 'dataset'
    records_path: null
    initial_dataset: ''
    # label_schema: ["Yes", "No"]
    label_schema: ["Action", "Comedy", "Drama", "Romance", "Horror"]
    max_samples: 50
    semantic_sampling: False # Change to True in case you don't have M1. Currently there is an issue with faiss and M1

# annotator:
    # method : 'argilla'
    # config:
        # api_url: ''
        # api_key: 'admin.apikey'
        # workspace: 'admin'
        # time_interval: 5

predictor:
    method : 'llm'
    config:
        llm:
            type: 'OpenAI'
            name: 'gpt-3.5-turbo-0613'
#            async_params:
#                retry_interval: 10
#                max_retries: 2
            model_kwargs: {"seed": 220}
        num_workers: 5
        prompt: 'prompts/predictor_completion/prediction.prompt'
        mini_batch_size: 1  #change to >1 if you want to include multiple samples in the one prompt
        mode: 'prediction'

meta_prompts:
    folder: 'prompts/meta_prompts_classification'
    num_err_prompt: 1  # Number of error examples per sample in the prompt generation
    num_err_samples: 2 # Number of error examples per sample in the sample generation
    history_length: 4 # Number of sample in the meta-prompt history
    num_generated_samples: 10 # Number of generated samples at each iteration
    num_initialize_samples: 10 # Number of generated samples at iteration 0, in zero-shot case
    samples_generation_batch: 10 # Number of samples generated in one call to the LLM
    num_workers: 5 #Number of parallel workers
    warmup: 4 # Number of warmup steps

eval:
    function_name: 'accuracy'
    num_large_errors: 4
    num_boundary_predictions : 0
    error_threshold: 0.5

llm:
    # type: 'OpenAI'
    # name: 'gpt-4-1106-preview'
    # temperature: 0.8
    type: 'HuggingFacePipeline'
    name: '/mnt/Model/Abacus/llm/Smaug-34B-v0.1'
    max_new_tokens: 2048
    temperature: 0.7

stop_criteria:
    max_usage: 2 #In $ in case of OpenAI models, otherwise number of tokens
    patience: 10 # Number of patience steps
    min_delta: 0.01 # Delta for the improvement definition

The run took a very long time and I did not wait for the results:

(myenv) xaccel@xaccel:/mnt/Agent/AutoPrompt/repo$ python run_pipeline.py --num_steps 2
/mnt/Agent/AutoPrompt/myenv/lib/python3.10/site-packages/transformers/utils/generic.py:441: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
  _torch_pytree._register_pytree_node(
Describe the task: Assistant is an expert cinema critic for all genres, and is tasked with classifying other movie reviews.
Initial prompt: Based on the following movie review, what genre is this movie? Select between Action, Comedy, Drama, Romance or Horror.
/mnt/Agent/AutoPrompt/myenv/lib/python3.10/site-packages/transformers/utils/generic.py:309: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
  _torch_pytree._register_pytree_node(
Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████| 15/15 [00:11<00:00,  1.29it/s]
Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████| 15/15 [00:11<00:00,  1.30it/s]
Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████| 15/15 [00:13<00:00,  1.15it/s]

Once again the process spiked the RAM from 10 GB to 500 GB and counting.

Eladlev commented 2 months ago

> Could you tell me how to use the GPU when using it offline and does the config file provide a GPU entry?
>
> 2024-04-26 10:38:41,707 - WARNING - Device has 2 GPUs available. Provide device={deviceId} to `from_model_id` to use availableGPUs for execution. deviceId is -1 (default) for CPU and can be a positive integer associated with CUDA device id.
> 2024-04-26 10:41:37,745 - WARNING - Device has 2 GPUs available. Provide device={deviceId} to `from_model_id` to use availableGPUs for execution. deviceId is -1 (default) for CPU and can be a positive integer associated with CUDA device id.

I added support for inference on a local GPU: https://github.com/Eladlev/AutoPrompt/pull/59

To use it, change the `llm` section of the config to either:

llm:
    type: 'HuggingFacePipeline'
    name: '/mnt/Model/Abacus/llm/Smaug-34B-v0.1'
    max_new_tokens: 2048
    temperature: 0.7
    gpu_device: 0

Or, better, let accelerate place the model across the available devices:

llm:
    type: 'HuggingFacePipeline'
    name: '/mnt/Model/Abacus/llm/Smaug-34B-v0.1'
    max_new_tokens: 2048
    temperature: 0.7
    device_map: 'auto'
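
For context, the warning quoted above says the device is selected through the `device` argument of langchain's `HuggingFacePipeline.from_model_id`, which also accepts `device_map`. The following is a rough sketch of how the two config variants could map onto that call. It is not necessarily how the PR implements it: `pipeline_kwargs_from_config` is a hypothetical helper, the `text-generation` task is an assumption, and `temperature` is left out for brevity.

```python
from typing import Any, Dict


def pipeline_kwargs_from_config(llm_config: Dict[str, Any]) -> Dict[str, Any]:
    """Build keyword arguments for HuggingFacePipeline.from_model_id.

    Hypothetical helper for illustration only; the config keys mirror the
    `llm` section shown above.
    """
    if "gpu_device" in llm_config and "device_map" in llm_config:
        # The two options are alternatives; setting both is ambiguous.
        raise ValueError("Set either gpu_device or device_map, not both")

    kwargs: Dict[str, Any] = {
        "model_id": llm_config["name"],  # HF model id or local folder
        "task": "text-generation",       # assumption: causal LM
        "pipeline_kwargs": {
            "max_new_tokens": llm_config.get("max_new_tokens", 256),
        },
    }
    if "device_map" in llm_config:
        # e.g. 'auto': accelerate decides where each layer is placed.
        kwargs["device_map"] = llm_config["device_map"]
    else:
        # -1 means CPU; a non-negative integer is a CUDA device id.
        kwargs["device"] = llm_config.get("gpu_device", -1)
    return kwargs


if __name__ == "__main__":
    config = {
        "type": "HuggingFacePipeline",
        "name": "/mnt/Model/Abacus/llm/Smaug-34B-v0.1",
        "max_new_tokens": 2048,
        "device_map": "auto",
    }
    print(pipeline_kwargs_from_config(config))
    # With langchain installed, the result could then be passed on as:
    #   from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
    #   llm = HuggingFacePipeline.from_model_id(**pipeline_kwargs_from_config(config))
```

With `gpu_device: 0` the whole model is put on a single GPU, while `device_map: 'auto'` lets accelerate spread a model that does not fit on one card across both GPUs (and CPU memory if needed), which is likely why it is the better choice for a 34B model.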