113 changes: 104 additions & 9 deletions elm/base.py
@@ -26,6 +26,9 @@ class ApiBase(ABC):
EMBEDDING_MODEL = 'text-embedding-ada-002'
"""Default model to do text embeddings."""

USE_CLIENT_EMBEDDINGS = False
"""Option to use AzureOpenAI client for embedding calls."""

EMBEDDING_URL = 'https://api.openai.com/v1/embeddings'
"""OpenAI embedding API URL"""

@@ -148,6 +151,41 @@ async def call_api(url, headers, request_json):

return out

@staticmethod
async def call_client_embedding(client, request_json):
"""Call OpenAI embedding API using client.

Parameters
----------
client : openai.AzureOpenAI | openai.OpenAI
OpenAI client to use for embedding calls.
request_json : mapping
Mapping of request json for embedding call (to be passed
to ``client.embeddings.create()``).

Returns
-------
str | dict
Embeddings response as a JSON string, or a dict with an
'error' key if there was an error while processing the API
call.
"""
out = None
kwargs = dict(request_json)

try:
response = client.embeddings.create(**kwargs)
out = response.model_dump_json(indent=2)
except Exception as e:
logger.debug(f'Error in OpenAI API call from '
f'`client.embeddings.create(**kwargs)` with '
f'kwargs: {kwargs}')
logger.exception('Error in OpenAI API call! Turn on debug logging '
'to see full query that caused error.')
out = {'error': str(e)}

return out
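
For reference, a minimal sketch of exercising this helper directly;
the endpoint, key, and API version below are placeholders, and on
Azure the ``model`` value is typically the deployment name:

    import asyncio

    import openai

    from elm.base import ApiBase

    client = openai.AzureOpenAI(
        azure_endpoint='https://example.openai.azure.com',  # placeholder
        api_key='...',  # placeholder
        api_version='2023-05-15')  # placeholder
    request = {'input': 'text to embed',
               'model': 'text-embedding-ada-002'}
    # ``call_client_embedding`` is an async staticmethod, so drive it
    # with ``asyncio.run()`` when not already inside an event loop
    out = asyncio.run(ApiBase.call_client_embedding(client, request))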

async def call_api_async(self, url, headers, all_request_jsons,
ignore_error=None, rate_limit=40e3):
"""Use GPT to clean raw pdf text in parallel calls to the OpenAI API.
@@ -320,8 +358,7 @@ async def generic_async_query(self, queries, model_role=None,

return out

@classmethod
def get_embedding(cls, text):
def get_embedding(self, text):
"""Get the 1D array (list) embedding of a text string.

Parameters
@@ -334,9 +371,23 @@ def get_embedding(cls, text):
embedding : list
List of float that represents the numerical embedding of the text
"""
kwargs = dict(url=cls.EMBEDDING_URL,
headers=cls.HEADERS,
json={'model': cls.EMBEDDING_MODEL,
if self.USE_CLIENT_EMBEDDINGS:
kwargs = dict(input=text, model=self.EMBEDDING_MODEL)
response = self._client.embeddings.create(**kwargs)

try:
embedding = response.data[0].embedding
except Exception as exc:
msg = 'Embedding request failed: {}'.format(exc)
logger.error(msg)
raise RuntimeError(msg) from exc

return embedding

kwargs = dict(url=self.EMBEDDING_URL,
headers=self.HEADERS,
json={'model': self.EMBEDDING_MODEL,
'input': text})

out = requests.post(**kwargs)
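
With ``USE_CLIENT_EMBEDDINGS`` left at its default of ``False``, the
method still posts to the REST endpoint via ``requests`` as before, so
existing callers are unaffected. A quick sketch of the unchanged usage,
where ``MyApi`` stands in for a hypothetical concrete subclass of
``ApiBase``:

    api = MyApi()  # hypothetical subclass; default REST embedding path
    vector = api.get_embedding('The quick brown fox')  # list of floats
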
@@ -485,10 +536,10 @@ def submit_jobs(self):

elif tokens < avail_tokens:
token_count += tokens
task = asyncio.create_task(ApiBase.call_api(self.url,
self.headers,
request),
name=self.job_names[ijob])
task = asyncio.create_task(
self._get_call_api_coro(request),
name=self.job_names[ijob])

self.api_jobs[ijob] = task
self.tries[ijob] += 1
self._tsub = time.time()
@@ -506,6 +557,10 @@ def submit_jobs(self):
token_count = 0
break

def _get_call_api_coro(self, request):
"""Convenience function to get the appropriate API call coroutine"""
return ApiBase.call_api(self.url, self.headers, request)

async def collect_jobs(self):
"""Collect asyncronous API calls and API outputs. Store outputs in the
`out` attribute."""
@@ -582,3 +637,43 @@ async def run(self):
time.sleep(5)

return self.out


class ClientEmbeddingsApiQueue(ApiQueue):
"""Class to manage the parallel API embedding submissions using a client"""

def __init__(self, client, request_jsons, ignore_error=None,
rate_limit=40e3, max_retries=10):
"""

Parameters
----------
client : openai.AzureOpenAI | openai.OpenAI
OpenAI client object to use for API calls.
request_jsons : list
List of API data input, one entry typically looks like this for
an embedding call:
{"input": "text to embed",
"model": "text-embedding-ada-002"}
ignore_error : None | callable
Optional callable to parse API error string. If the callable
returns True, the error will be ignored, the API call will not be
tried again, and the output will be an empty string.
rate_limit : float
OpenAI API rate limit (tokens / minute). Note that the
gpt-3.5-turbo limit is 90k as of 4/2023, but we're using a large
factor of safety (~1/2) because we can only count the tokens on the
input side and assume the output is about the same count.
max_retries : int
Number of times to retry an API call with an error response
before giving up.
"""
super().__init__(url=None, headers=None, request_jsons=request_jsons,
ignore_error=ignore_error,
rate_limit=rate_limit, max_retries=max_retries)
self.client = client

def _get_call_api_coro(self, request):
"""Convenience function to get the appropriate API call coroutine"""
return ApiBase.call_client_embedding(self.client, request)
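
A minimal sketch of driving the new queue directly; the client
construction and request inputs are placeholders:

    import asyncio

    import openai

    from elm.base import ClientEmbeddingsApiQueue

    client = openai.OpenAI()  # or openai.AzureOpenAI(...)
    request_jsons = [{'input': 'first chunk',
                      'model': 'text-embedding-ada-002'},
                     {'input': 'second chunk',
                      'model': 'text-embedding-ada-002'}]
    queue = ClientEmbeddingsApiQueue(client, request_jsons)
    # ApiQueue.run() is a coroutine that submits and collects all jobs
    out = asyncio.run(queue.run())
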
69 changes: 61 additions & 8 deletions elm/embed.py
@@ -3,11 +3,12 @@
ELM text embedding
"""
import openai
import json
import re
import os
import logging

from elm.base import ApiBase
from elm.base import ApiBase, ApiQueue, ClientEmbeddingsApiQueue
from elm.chunk import Chunker


@@ -20,7 +21,7 @@ class ChunkAndEmbed(ApiBase):
DEFAULT_MODEL = 'text-embedding-ada-002'
"""Default model to do embeddings."""

def __init__(self, text, model=None, **chunk_kwargs):
def __init__(self, text, model=None, client=None, **chunk_kwargs):
"""
Parameters
----------
@@ -30,13 +31,19 @@ def __init__(self, text, model=None, **chunk_kwargs):
model : None | str
Optional specification of OpenAI model to use. Default is
cls.DEFAULT_MODEL
client : openai.AzureOpenAI | openai.OpenAI | None
Optional OpenAI client to use for embedding calls. If
``None``, a client is set up using environment variables.
By default, ``None``.
chunk_kwargs : dict | None
kwargs for initialization of :class:`elm.chunk.Chunker`
"""

super().__init__(model)

self.text = text
if client is not None:
self._client = client

if os.path.isfile(text):
logger.info('Loading text file: {}'.format(text))
@@ -73,7 +80,7 @@ def clean_tables(text):

return '\n'.join(lines)

def run(self, rate_limit=175e3):
def run(self, rate_limit=175e3): # pylint: disable=unused-argument
"""Run text embedding in serial

Parameters
@@ -142,17 +149,17 @@ async def run_async(self, rate_limit=175e3):
for chunk in self.text_chunks:
req = {"input": chunk, "model": self.model}

if 'azure' in str(openai.api_type).lower():
if 'embedding' not in str(self.model).lower():
req['engine'] = self.model

all_request_jsons.append(req)

embeddings = await self.call_api_async(self.EMBEDDING_URL,
self.HEADERS,
all_request_jsons,
rate_limit=rate_limit)
embeddings = await self.call_embedding_async(all_request_jsons,
rate_limit=rate_limit)

for i, chunk in enumerate(embeddings):
if self.USE_CLIENT_EMBEDDINGS:
chunk = json.loads(chunk)
try:
embeddings[i] = chunk['data'][0]['embedding']
except Exception:
@@ -164,3 +171,49 @@
logger.info('Finished all embeddings.')

return embeddings

async def call_embedding_async(self, all_request_jsons,
ignore_error=None, rate_limit=40e3):
"""Use an OpenAI API client to generate embeddings for text.

NOTE: you need to call this using the await keyword in ipython or
jupyter, e.g.: `out = await ChunkAndEmbed.call_embedding_async()`

Parameters
----------
all_request_jsons : list
List of API data input, one entry typically looks like this for
an embedding call:
{"input": "text to embed",
"model": "text-embedding-ada-002"}
ignore_error : None | callable
Optional callable to parse API error string. If the callable
returns True, the error will be ignored, the API call will not be
tried again, and the output will be an empty string.
rate_limit : float
OpenAI API rate limit (tokens / minute). Note that the
gpt-3.5-turbo limit is 90k as of 4/2023, but we're using a large
factor of safety (~1/2) because we can only count the tokens on the
input side and assume the output is about the same count.

Returns
-------
out : list
List of API outputs where each list entry is an embedding response
for the corresponding request in the all_request_jsons input.
"""
if self.USE_CLIENT_EMBEDDINGS:
self.api_queue = ClientEmbeddingsApiQueue(self._client,
all_request_jsons,
ignore_error,
rate_limit=rate_limit)
else:
self.api_queue = ApiQueue(self.EMBEDDING_URL, self.HEADERS,
all_request_jsons,
ignore_error=ignore_error,
rate_limit=rate_limit)

out = await self.api_queue.run()
return out
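
Putting the embed.py changes together, an end-to-end sketch of the
client path; flipping the class attribute at runtime and the Azure
client kwargs below are illustrative assumptions, not part of this PR:

    import asyncio

    import openai

    from elm.embed import ChunkAndEmbed

    ChunkAndEmbed.USE_CLIENT_EMBEDDINGS = True  # opt in for this sketch
    client = openai.AzureOpenAI(
        azure_endpoint='https://example.openai.azure.com',  # placeholder
        api_key='...',  # placeholder
        api_version='2023-05-15')  # placeholder
    ce = ChunkAndEmbed('Some long document text...', client=client)
    embeddings = asyncio.run(ce.run_async(rate_limit=175e3))
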
2 changes: 1 addition & 1 deletion elm/version.py
@@ -2,4 +2,4 @@
ELM version number
"""

__version__ = "0.0.35"
__version__ = "0.0.36"
4 changes: 2 additions & 2 deletions tests/test_wizard.py
@@ -26,8 +26,8 @@ class MockObject:
class MockClass:
"""Dummy class to mock various api calls"""

@staticmethod
def get_embedding(*args, **kwargs): # pylint: disable=unused-argument
# pylint: disable=unused-argument
def get_embedding(self, *args, **kwargs):
"""Mock for ChunkAndEmbed.call_api()"""
embedding = np.random.uniform(0, 1, 10)
return embedding