diff --git a/elm/base.py b/elm/base.py
index ef987c3e..f75ac3b2 100644
--- a/elm/base.py
+++ b/elm/base.py
@@ -26,6 +26,9 @@ class ApiBase(ABC):
     EMBEDDING_MODEL = 'text-embedding-ada-002'
     """Default model to do text embeddings."""
 
+    USE_CLIENT_EMBEDDINGS = False
+    """Option to use AzureOpenAI client for embedding calls."""
+
     EMBEDDING_URL = 'https://api.openai.com/v1/embeddings'
     """OpenAI embedding API URL"""
 
@@ -148,6 +151,41 @@ async def call_api(url, headers, request_json):
 
         return out
 
+    @staticmethod
+    async def call_client_embedding(client, request_json):
+        """Call OpenAI embedding API using client.
+
+        Parameters
+        ----------
+        client : openai.AzureOpenAI
+            OpenAI client to use for embedding calls.
+        request_json : mapping
+            Mapping of request json for embedding call (to be passed
+            to ``client.embeddings.create()``).
+
+        Returns
+        -------
+        str | dict
+            Embeddings response as a JSON string. Will be a dict with
+            an 'error' key if there was an error while processing the
+            API call.
+        """
+        out = None
+        kwargs = dict(request_json)
+
+        try:
+            response = client.embeddings.create(**kwargs)
+            out = response.model_dump_json(indent=2)
+        except Exception as e:
+            logger.debug(f'Error in OpenAI API call from '
+                         f'`client.embeddings.create(**kwargs)` with '
+                         f'kwargs: {kwargs}')
+            logger.exception('Error in OpenAI API call! Turn on debug logging '
+                             'to see full query that caused error.')
+            out = {'error': str(e)}
+
+        return out
+
     async def call_api_async(self, url, headers, all_request_jsons,
                              ignore_error=None, rate_limit=40e3):
         """Use GPT to clean raw pdf text in parallel calls to the OpenAI API.
@@ -320,8 +358,7 @@ async def generic_async_query(self, queries, model_role=None,
 
         return out
 
-    @classmethod
-    def get_embedding(cls, text):
+    def get_embedding(self, text):
         """Get the 1D array (list) embedding of a text string.
 
         Parameters
@@ -334,9 +371,23 @@ def get_embedding(cls, text):
         embedding : list
             List of float that represents the numerical embedding of the text
         """
-        kwargs = dict(url=cls.EMBEDDING_URL,
-                      headers=cls.HEADERS,
-                      json={'model': cls.EMBEDDING_MODEL,
+        if self.USE_CLIENT_EMBEDDINGS:
+            kwargs = dict(input=text, model=self.EMBEDDING_MODEL)
+            response = self._client.embeddings.create(**kwargs)
+
+            try:
+                embedding = response.data[0].embedding
+            except Exception as exc:
+                msg = ('Embedding request failed: {}'
+                       .format(exc))
+                logger.error(msg)
+                raise RuntimeError(msg) from exc
+
+            return embedding
+
+        kwargs = dict(url=self.EMBEDDING_URL,
+                      headers=self.HEADERS,
+                      json={'model': self.EMBEDDING_MODEL,
                             'input': text})
 
         out = requests.post(**kwargs)
@@ -485,10 +536,10 @@ def submit_jobs(self):
 
             elif tokens < avail_tokens:
                 token_count += tokens
-                task = asyncio.create_task(ApiBase.call_api(self.url,
-                                                            self.headers,
-                                                            request),
-                                           name=self.job_names[ijob])
+                task = asyncio.create_task(
+                    self._get_call_api_coro(request),
+                    name=self.job_names[ijob])
+
                 self.api_jobs[ijob] = task
                 self.tries[ijob] += 1
                 self._tsub = time.time()
@@ -506,6 +557,10 @@ def submit_jobs(self):
                 token_count = 0
                 break
 
+    def _get_call_api_coro(self, request):
+        """Convenience function to get the appropriate API call coroutine"""
+        return ApiBase.call_api(self.url, self.headers, request)
+
     async def collect_jobs(self):
         """Collect asyncronous API calls and API outputs.
        Store outputs in the `out` attribute."""
 
@@ -582,3 +637,43 @@ async def run(self):
             time.sleep(5)
 
         return self.out
+
+
+class ClientEmbeddingsApiQueue(ApiQueue):
+    """Class to manage parallel API embedding submissions using a client"""
+
+    def __init__(self, client, request_jsons, ignore_error=None,
+                 rate_limit=40e3, max_retries=10):
+        """
+
+        Parameters
+        ----------
+        client : openai.AzureOpenAI | openai.OpenAI
+            OpenAI client object to use for API calls.
+        request_jsons : list
+            List of API data input, one entry typically looks like this
+            for an embedding call:
+                {"input": "text to embed",
+                 "model": "text-embedding-ada-002"}
+        ignore_error : None | callable
+            Optional callable to parse API error string. If the callable
+            returns True, the error will be ignored, the API call will
+            not be tried again, and the output will be an empty string.
+        rate_limit : float
+            OpenAI API rate limit (tokens / minute). Note that the
+            gpt-3.5-turbo limit is 90k as of 4/2023, but we're using a
+            large factor of safety (~1/2) because we can only count the
+            tokens on the input side and assume the output is about the
+            same count.
+        max_retries : int
+            Number of times to retry an API call with an error response
+            before raising an error.
+        """
+        super().__init__(url=None, headers=None, request_jsons=request_jsons,
+                         ignore_error=ignore_error,
+                         rate_limit=rate_limit, max_retries=max_retries)
+        self.client = client
+
+    def _get_call_api_coro(self, request):
+        """Convenience function to get the appropriate API call coroutine"""
+        return ApiBase.call_client_embedding(self.client, request)
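
For review context, here is a minimal sketch of how the new queue is meant to
be driven on its own. The endpoint, key, and model/deployment name below are
placeholders, not part of this diff (on Azure, 'model' is typically your
deployment name):

import asyncio

from openai import AzureOpenAI

from elm.base import ClientEmbeddingsApiQueue

# Placeholder client setup; real values come from your Azure deployment.
client = AzureOpenAI(azure_endpoint='https://example.openai.azure.com',
                     api_key='<your-key>',
                     api_version='2023-05-15')

# One request json per text chunk, matching client.embeddings.create() kwargs.
request_jsons = [{'input': chunk, 'model': 'text-embedding-ada-002'}
                 for chunk in ('first chunk', 'second chunk')]

queue = ClientEmbeddingsApiQueue(client, request_jsons)

# Each entry of `out` is a JSON string from response.model_dump_json().
out = asyncio.run(queue.run())
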
diff --git a/elm/embed.py b/elm/embed.py
index b1470779..b0bc189e 100644
--- a/elm/embed.py
+++ b/elm/embed.py
@@ -3,11 +3,12 @@
 ELM text embedding
 """
 import openai
+import json
 import re
 import os
 import logging
 
-from elm.base import ApiBase
+from elm.base import ApiBase, ApiQueue, ClientEmbeddingsApiQueue
 from elm.chunk import Chunker
 
@@ -20,7 +21,7 @@ class ChunkAndEmbed(ApiBase):
     DEFAULT_MODEL = 'text-embedding-ada-002'
     """Default model to do embeddings."""
 
-    def __init__(self, text, model=None, **chunk_kwargs):
+    def __init__(self, text, model=None, client=None, **chunk_kwargs):
         """
         Parameters
         ----------
@@ -30,6 +31,10 @@ def __init__(self, text, model=None, **chunk_kwargs):
         model : None | str
             Optional specification of OpenAI model to use. Default is
             cls.DEFAULT_MODEL
+        client : openai.AzureOpenAI | None
+            Optional OpenAI client to use for embedding calls. If
+            ``None``, a client is set up using environment variables.
+            By default, ``None``.
         chunk_kwargs : dict | None
             kwargs for initialization of :class:`elm.chunk.Chunker`
         """
@@ -37,6 +42,8 @@ def __init__(self, text, model=None, **chunk_kwargs):
         super().__init__(model)
 
         self.text = text
+        if client is not None:
+            self._client = client
 
         if os.path.isfile(text):
             logger.info('Loading text file: {}'.format(text))
@@ -73,7 +80,7 @@ def clean_tables(text):
 
         return '\n'.join(lines)
 
-    def run(self, rate_limit=175e3):
+    def run(self, rate_limit=175e3):  # pylint: disable=unused-argument
         """Run text embedding in serial
 
         Parameters
@@ -142,17 +149,17 @@ async def run_async(self, rate_limit=175e3):
         for chunk in self.text_chunks:
             req = {"input": chunk, "model": self.model}
 
-            if 'azure' in str(openai.api_type).lower():
+            if 'embedding' not in str(self.model).lower():
                 req['engine'] = self.model
 
             all_request_jsons.append(req)
 
-        embeddings = await self.call_api_async(self.EMBEDDING_URL,
-                                               self.HEADERS,
-                                               all_request_jsons,
-                                               rate_limit=rate_limit)
+        embeddings = await self.call_embedding_async(all_request_jsons,
+                                                     rate_limit=rate_limit)
 
         for i, chunk in enumerate(embeddings):
+            if self.USE_CLIENT_EMBEDDINGS and isinstance(chunk, str):
+                chunk = json.loads(chunk)
             try:
                 embeddings[i] = chunk['data'][0]['embedding']
             except Exception:
@@ -164,3 +171,49 @@ async def run_async(self, rate_limit=175e3):
 
         logger.info('Finished all embeddings.')
 
         return embeddings
+
+    async def call_embedding_async(self, all_request_jsons,
+                                   ignore_error=None, rate_limit=40e3):
+        """Generate embeddings for text in parallel, using either an
+        OpenAI client or the REST embeddings endpoint.
+
+        NOTE: you need to call this using the await command in ipython or
+        jupyter, e.g.:
+        `out = await ce.call_embedding_async(all_request_jsons)`
+
+        Parameters
+        ----------
+        all_request_jsons : list
+            List of API data input, one entry typically looks like this
+            for an embedding call:
+                {"input": "text to embed",
+                 "model": "text-embedding-ada-002"}
+        ignore_error : None | callable
+            Optional callable to parse API error string. If the callable
+            returns True, the error will be ignored, the API call will
+            not be tried again, and the output will be an empty string.
+        rate_limit : float
+            OpenAI API rate limit (tokens / minute). Note that the
+            gpt-3.5-turbo limit is 90k as of 4/2023, but we're using a
+            large factor of safety (~1/2) because we can only count the
+            tokens on the input side and assume the output is about the
+            same count.
+
+        Returns
+        -------
+        out : list
+            List of API outputs where each list entry is an embedding
+            response for the corresponding request in the
+            all_request_jsons input.
+ """ + if self.USE_CLIENT_EMBEDDINGS: + self.api_queue = ClientEmbeddingsApiQueue(self._client, + all_request_jsons, + ignore_error, + rate_limit=rate_limit) + else: + self.api_queue = ApiQueue(self.EMBEDDING_URL, self.EMBEDDING_URL, + all_request_jsons, + ignore_error=ignore_error, + rate_limit=rate_limit) + + out = await self.api_queue.run() + return out diff --git a/elm/version.py b/elm/version.py index f9963abf..2627c20a 100644 --- a/elm/version.py +++ b/elm/version.py @@ -2,4 +2,4 @@ ELM version number """ -__version__ = "0.0.35" +__version__ = "0.0.36" diff --git a/tests/test_wizard.py b/tests/test_wizard.py index cf64353b..c69c3e31 100644 --- a/tests/test_wizard.py +++ b/tests/test_wizard.py @@ -26,8 +26,8 @@ class MockObject: class MockClass: """Dummy class to mock various api calls""" - @staticmethod - def get_embedding(*args, **kwargs): # pylint: disable=unused-argument + # pylint: disable=unused-argument + def get_embedding(self, *args, **kwargs): """Mock for ChunkAndEmbed.call_api()""" embedding = np.random.uniform(0, 1, 10) return embedding