refuel-ai / autolabel

Label, clean and enrich text datasets with LLMs.
https://docs.refuel.ai/
MIT License
2k stars 137 forks source link

[Feature Request]: Support for Azure OpenAI as a provider #429

Open rishabh-bhargava opened 1 year ago

rishabh-bhargava commented 1 year ago

Is your feature request related to a problem? Please describe. We would like to be able to use the OpenAI models through Azure's OpenAI offering.

Describe the solution you'd like Support for Azure OpenAI as a provider.

nihit commented 1 year ago

We can use Azure OpenAI LLM implementation available in langchain for this: https://python.langchain.com/docs/modules/model_io/models/llms/integrations/azure_openai_example

The integration into Autolabel can be very similar to OpenAI: https://github.com/refuel-ai/autolabel/blob/main/src/autolabel/models/openai.py (which in turn uses langchain's OpenAI LLM implementation)

A reference PR for recently added support for Cohere: https://github.com/refuel-ai/autolabel/pull/419

shril commented 10 months ago

@nihit @rishabh-bhargava @rajasbansal I want to pick this issue. Let me know if no one is working on this.

hellangleZ commented 10 months ago

We can use Azure OpenAI LLM implementation available in langchain for this: https://python.langchain.com/docs/modules/model_io/models/llms/integrations/azure_openai_example

The integration into Autolabel can be very similar to OpenAI: https://github.com/refuel-ai/autolabel/blob/main/src/autolabel/models/openai.py (which in turn uses langchain's OpenAI LLM implementation)

A reference PR for recently added support for Cohere: #419

Hi @nihit, changing only autolabel's openai.py does not seem to resolve the issue, because an embedding step runs before the LLM labeling step. Even after changing many Python files, I still could not communicate with the Azure embedding model.

The error is shown below:

connection broken by 'NewConnectionError('<urllib3.connection.HTTPSConnection object at 0x7f7df3fb0400>: Failed to establish a new connection: [Errno 101] Network is unreachable')': /v1/engines/text-embedding-ada-002/embeddings


KeyboardInterrupt Traceback (most recent call last) /aml/autolabel/examples/banking/example_banking.ipynb Cell 15 line 4 2 from autolabel import AutolabelDataset 3 ds = AutolabelDataset("test.csv", config=config) ----> 4 agent.plan(ds)

File /aml/train/lib/python3.8/site-packages/autolabel/labeler.py:389, in LabelingAgent.plan(self, dataset, max_items, start_index) 380 if ( 381 self.config.explanation_column() 382 and len(seed_examples) > 0 383 and self.config.explanation_column() not in list(seed_examples[0].keys()) 384 ): 385 raise ValueError( 386 f"Explanation column {self.config.explanation_column()} not found in dataset.\nMake sure that explanations were generated using labeler.generate_explanations(seed_file)." 387 ) --> 389 self.example_selector = ExampleSelectorFactory.initialize_selector( 390 self.config, 391 [safe_serialize_to_string(example) for example in seed_examples], 392 dataset.df.keys().tolist(), 393 cache=self.generation_cache is not None, 394 ) 396 if self.config.label_selection(): 397 if self.config.task_type() != TaskType.CLASSIFICATION:

File /aml/train/lib/python3.8/site-packages/autolabel/few_shot/__init__.py:118, in ExampleSelectorFactory.initialize_selector(config, examples, columns, cache) 112 if algorithm not in [ 113 FewShotAlgorithm.FIXED, 114 FewShotAlgorithm.LABEL_DIVERSITY_RANDOM, 115 ]: 116 params["cache"] = cache --> 118 return example_cls.from_examples(**params)

File /aml/train/lib/python3.8/site-packages/langchain/prompts/example_selector/semantic_similarity.py:96, in SemanticSimilarityExampleSelector.from_examples(cls, examples, embeddings, vectorstore_cls, k, input_keys, **vectorstore_cls_kwargs) 94 else: 95 string_examples = [" ".join(sorted_values(eg)) for eg in examples] ---> 96 vectorstore = vectorstore_cls.from_texts( 97 string_examples, embeddings, metadatas=examples, **vectorstore_cls_kwargs 98 ) 99 return cls(vectorstore=vectorstore, k=k, input_keys=input_keys)

File /aml/train/lib/python3.8/site-packages/autolabel/few_shot/vector_store.py:453, in VectorStoreWrapper.from_texts(cls, texts, embedding, metadatas, cache, **kwargs) 436 """Create a vectorstore from raw text. 437 The data will be ephemeral in-memory. 438 Args: (...) 444 vector_store: Vectorstore with seedset embeddings 445 """ 446 vector_store = cls( 447 embedding_function=embedding, 448 corpus_embeddings=None, (...) 451 **kwargs, 452 ) --> 453 vector_store.add_texts(texts=texts, metadatas=metadatas) 454 return vector_store

File /aml/train/lib/python3.8/site-packages/autolabel/few_shot/vector_store.py:244, in VectorStoreWrapper.add_texts(self, texts, metadatas) 236 """Run texts through the embeddings and add to the vectorstore. Currently, the vectorstore is reinitialized each time, because we do not require a persistent vector store for example selection. 237 Args: 238 texts (Iterable[str]): Texts to add to the vectorstore. (...) 241 List[str]: List of IDs of the added texts. 242 """ 243 if self._embedding_function is not None: --> 244 embeddings = self._get_embeddings(texts) 246 self._corpus_embeddings = torch.tensor(embeddings) 247 self._texts = texts

File /aml/train/lib/python3.8/site-packages/autolabel/few_shot/vector_store.py:195, in VectorStoreWrapper._get_embeddings(self, texts) 192 uncached_texts.append(text) 193 uncached_texts_indices.append(idx) --> 195 uncached_embeddings = self._embedding_function.embed_documents( 196 uncached_texts 197 ) 198 self._add_embeddings_to_cache(uncached_texts, uncached_embeddings) 199 for idx, embedding in zip(uncached_texts_indices, uncached_embeddings):

File /aml/train/lib/python3.8/site-packages/langchain/embeddings/openai.py:476, in OpenAIEmbeddings.embed_documents(self, texts, chunk_size) 464 """Call out to OpenAI's embedding endpoint for embedding search docs. 465 466 Args: (...) 472 List of embeddings, one for each text. 473 """ 474 # NOTE: to keep things simple, we assume the list may contain texts longer 475 # than the maximum context and use length-safe embedding function. --> 476 return self._get_len_safe_embeddings(texts, engine=self.deployment)

File /aml/train/lib/python3.8/site-packages/langchain/embeddings/openai.py:326, in OpenAIEmbeddings._get_len_safe_embeddings(self, texts, engine, chunk_size) 323 _iter = range(0, len(tokens), _chunk_size) 325 for i in _iter: --> 326 response = embed_with_retry( 327 self, 328 input=tokens[i : i + _chunk_size], 329 **self._invocation_params, 330 ) 331 batched_embeddings += [r["embedding"] for r in response["data"]] 333 results: List[List[List[float]]] = [[] for _ in range(len(texts))]

File /aml/train/lib/python3.8/site-packages/langchain/embeddings/openai.py:107, in embed_with_retry(embeddings, **kwargs) 104 response = embeddings.client.create(**kwargs) 105 return _check_response(response) --> 107 return _embed_with_retry(**kwargs)

File /aml/train/lib/python3.8/site-packages/tenacity/__init__.py:289, in BaseRetrying.wraps.<locals>.wrapped_f(*args, **kw) 287 @functools.wraps(f) 288 def wrapped_f(*args: t.Any, **kw: t.Any) -> t.Any: --> 289 return self(f, *args, **kw)

File /aml/train/lib/python3.8/site-packages/tenacity/__init__.py:379, in Retrying.__call__(self, fn, *args, **kwargs) 377 retry_state = RetryCallState(retry_object=self, fn=fn, args=args, kwargs=kwargs) 378 while True: --> 379 do = self.iter(retry_state=retry_state) 380 if isinstance(do, DoAttempt): 381 try:

File /aml/train/lib/python3.8/site-packages/tenacity/__init__.py:314, in BaseRetrying.iter(self, retry_state) 312 is_explicit_retry = fut.failed and isinstance(fut.exception(), TryAgain) 313 if not (is_explicit_retry or self.retry(retry_state)): --> 314 return fut.result() 316 if self.after is not None: 317 self.after(retry_state)

File /aml/train/lib/python3.8/concurrent/futures/_base.py:437, in Future.result(self, timeout) 435 raise CancelledError() 436 elif self._state == FINISHED: --> 437 return self.__get_result() 439 self._condition.wait(timeout) 441 if self._state in [CANCELLED, CANCELLED_AND_NOTIFIED]:

File /aml/train/lib/python3.8/concurrent/futures/_base.py:389, in Future.__get_result(self) 387 if self._exception: 388 try: --> 389 raise self._exception 390 finally: 391 # Break a reference cycle with the exception in self._exception 392 self = None

File /aml/train/lib/python3.8/site-packages/tenacity/__init__.py:382, in Retrying.__call__(self, fn, *args, **kwargs) 380 if isinstance(do, DoAttempt): 381 try: --> 382 result = fn(*args, **kwargs) 383 except BaseException: # noqa: B902 384 retry_state.set_exception(sys.exc_info()) # type: ignore[arg-type]

File /aml/train/lib/python3.8/site-packages/langchain/embeddings/openai.py:104, in embed_with_retry.<locals>._embed_with_retry(**kwargs) 102 @retry_decorator 103 def _embed_with_retry(**kwargs: Any) -> Any: --> 104 response = embeddings.client.create(**kwargs) 105 return _check_response(response)

File /aml/train/lib/python3.8/site-packages/openai/api_resources/embedding.py:33, in Embedding.create(cls, *args, **kwargs) 31 while True: 32 try: ---> 33 response = super().create(*args, **kwargs) 35 # If a user specifies base64, we'll just return the encoded string. 36 # This is only for the default case. 37 if not user_provided_encoding_format:

File /aml/train/lib/python3.8/site-packages/openai/api_resources/abstract/engine_api_resource.py:155, in EngineAPIResource.create(cls, api_key, api_base, api_type, request_id, api_version, organization, **params) 129 @classmethod 130 def create( 131 cls, (...) 138 **params, 139 ): 140 ( 141 deployment_id, 142 engine, (...) 152 api_key, api_base, api_type, api_version, organization, **params 153 ) --> 155 response, _, api_key = requestor.request( 156 "post", 157 url, 158 params=params, 159 headers=headers, 160 stream=stream, 161 request_id=request_id, 162 request_timeout=request_timeout, 163 ) 165 if stream: 166 # must be an iterator 167 assert not isinstance(response, OpenAIResponse)

File /aml/train/lib/python3.8/site-packages/openai/api_requestor.py:289, in APIRequestor.request(self, method, url, params, headers, files, stream, request_id, request_timeout) 278 def request( 279 self, 280 method, (...) 287 request_timeout: Optional[Union[float, Tuple[float, float]]] = None, 288 ) -> Tuple[Union[OpenAIResponse, Iterator[OpenAIResponse]], bool, str]: --> 289 result = self.request_raw( 290 method.lower(), 291 url, 292 params=params, 293 supplied_headers=headers, 294 files=files, 295 stream=stream, 296 request_id=request_id, 297 request_timeout=request_timeout, 298 ) 299 resp, got_stream = self._interpret_response(result, stream) 300 return resp, got_stream, self.api_key

File /aml/train/lib/python3.8/site-packages/openai/api_requestor.py:606, in APIRequestor.request_raw(self, method, url, params, supplied_headers, files, stream, request_id, request_timeout) 604 _thread_context.session_create_time = time.time() 605 try: --> 606 result = _thread_context.session.request( 607 method, 608 abs_url, 609 headers=headers, 610 data=data, 611 files=files, 612 stream=stream, 613 timeout=request_timeout if request_timeout else TIMEOUT_SECS, 614 proxies=_thread_context.session.proxies, 615 ) 616 except requests.exceptions.Timeout as e: 617 raise error.Timeout("Request timed out: {}".format(e)) from e

File /aml/train/lib/python3.8/site-packages/requests/sessions.py:589, in Session.request(self, method, url, params, data, headers, cookies, files, auth, timeout, allow_redirects, proxies, hooks, stream, verify, cert, json) 584 send_kwargs = { 585 "timeout": timeout, 586 "allow_redirects": allow_redirects, 587 } 588 send_kwargs.update(settings) --> 589 resp = self.send(prep, **send_kwargs) 591 return resp

File /aml/train/lib/python3.8/site-packages/requests/sessions.py:703, in Session.send(self, request, **kwargs) 700 start = preferred_clock() 702 # Send the request --> 703 r = adapter.send(request, **kwargs) 705 # Total elapsed time of the request (approximately) 706 elapsed = preferred_clock() - start

File /aml/train/lib/python3.8/site-packages/requests/adapters.py:486, in HTTPAdapter.send(self, request, stream, timeout, verify, cert, proxies) 483 timeout = TimeoutSauce(connect=timeout, read=timeout) 485 try: --> 486 resp = conn.urlopen( 487 method=request.method, 488 url=url, 489 body=request.body, 490 headers=request.headers, 491 redirect=False, 492 assert_same_host=False, 493 preload_content=False, 494 decode_content=False, 495 retries=self.max_retries, 496 timeout=timeout, 497 chunked=chunked, 498 ) 500 except (ProtocolError, OSError) as err: 501 raise ConnectionError(err, request=request)

File /aml/train/lib/python3.8/site-packages/urllib3/connectionpool.py:826, in HTTPConnectionPool.urlopen(self, method, url, body, headers, retries, redirect, assert_same_host, timeout, pool_timeout, release_conn, chunked, body_pos, **response_kw) 821 if not conn: 822 # Try again 823 log.warning( 824 "Retrying (%r) after connection broken by '%r': %s", retries, err, url 825 ) --> 826 return self.urlopen( 827 method, 828 url, 829 body, 830 headers, 831 retries, 832 redirect, 833 assert_same_host, 834 timeout=timeout, 835 pool_timeout=pool_timeout, 836 release_conn=release_conn, 837 chunked=chunked, 838 body_pos=body_pos, 839 **response_kw 840 ) 842 # Handle redirect? 843 redirect_location = redirect and response.get_redirect_location()

File /aml/train/lib/python3.8/site-packages/urllib3/connectionpool.py:714, in HTTPConnectionPool.urlopen(self, method, url, body, headers, retries, redirect, assert_same_host, timeout, pool_timeout, release_conn, chunked, body_pos, **response_kw) 711 self._prepare_proxy(conn) 713 # Make the request on the httplib connection object. --> 714 httplib_response = self._make_request( 715 conn, 716 method, 717 url, 718 timeout=timeout_obj, 719 body=body, 720 headers=headers, 721 chunked=chunked, 722 ) 724 # If we're going to release the connection in finally:, then 725 # the response doesn't need to know about the connection. Otherwise 726 # it will also try to release it and we'll have a double-release 727 # mess. 728 response_conn = conn if not release_conn else None

File /aml/train/lib/python3.8/site-packages/urllib3/connectionpool.py:403, in HTTPConnectionPool._make_request(self, conn, method, url, timeout, chunked, **httplib_request_kw) 401 # Trigger any extra validation we need to do. 402 try: --> 403 self._validate_conn(conn) 404 except (SocketTimeout, BaseSSLError) as e: 405 # Py2 raises this as a BaseSSLError, Py3 raises it as socket timeout. 406 self._raise_timeout(err=e, url=url, timeout_value=conn.timeout)

File /aml/train/lib/python3.8/site-packages/urllib3/connectionpool.py:1053, in HTTPSConnectionPool._validate_conn(self, conn) 1051 # Force connect early to allow us to validate the connection. 1052 if not getattr(conn, "sock", None): # AppEngine might not have .sock --> 1053 conn.connect() 1055 if not conn.is_verified: 1056 warnings.warn( 1057 ( 1058 "Unverified HTTPS request is being made to host '%s'. " (...) 1063 InsecureRequestWarning, 1064 )

File /aml/train/lib/python3.8/site-packages/urllib3/connection.py:363, in HTTPSConnection.connect(self) 361 def connect(self): 362 # Add certificate verification --> 363 self.sock = conn = self._new_conn() 364 hostname = self.host 365 tls_in_tls = False

File /aml/train/lib/python3.8/site-packages/urllib3/connection.py:174, in HTTPConnection._new_conn(self) 171 extra_kw["socket_options"] = self.socket_options 173 try: --> 174 conn = connection.create_connection( 175 (self._dns_host, self.port), self.timeout, **extra_kw 176 ) 178 except SocketTimeout: 179 raise ConnectTimeoutError( 180 self, 181 "Connection to %s timed out. (connect timeout=%s)" 182 % (self.host, self.timeout), 183 )

File /aml/train/lib/python3.8/site-packages/urllib3/util/connection.py:85, in create_connection(address, timeout, source_address, socket_options) 83 if source_address: 84 sock.bind(source_address) ---> 85 sock.connect(sa) 86 return sock 88 except socket.error as e:

KeyboardInterrupt: