Skip to content

Commit

Permalink
Update Cohere to 5.13.4 v2 API (#45267)
Browse files Browse the repository at this point in the history
  • Loading branch information
okirialbert authored Dec 29, 2024
1 parent 7229e2f commit 2a78648
Show file tree
Hide file tree
Showing 6 changed files with 95 additions and 35 deletions.
2 changes: 1 addition & 1 deletion generated/provider_dependencies.json
Original file line number Diff line number Diff line change
Expand Up @@ -383,7 +383,7 @@
"cohere": {
"deps": [
"apache-airflow>=2.9.0",
"cohere>=4.37,<5"
"cohere>=5.13.4"
],
"devel-deps": [],
"plugins": [],
Expand Down
70 changes: 55 additions & 15 deletions providers/src/airflow/providers/cohere/hooks/cohere.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
Expand All @@ -15,25 +14,38 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

import logging
import warnings
from functools import cached_property
from typing import Any
from typing import TYPE_CHECKING, Any

import cohere
from cohere.types import UserChatMessageV2

from airflow.exceptions import AirflowProviderDeprecationWarning
from airflow.hooks.base import BaseHook

if TYPE_CHECKING:
from cohere.core.request_options import RequestOptions
from cohere.types import ChatMessages, EmbedByTypeResponseEmbeddings


logger = logging.getLogger(__name__)


class CohereHook(BaseHook):
"""
Use Cohere Python SDK to interact with Cohere platform.
Use Cohere Python SDK to interact with Cohere platform using API v2.
.. seealso:: https://docs.cohere.com/docs
:param conn_id: :ref:`Cohere connection id <howto/connection:cohere>`
:param timeout: Request timeout in seconds.
:param max_retries: Maximal number of retries for requests.
:param timeout: Request timeout in seconds. Optional.
:param max_retries: Maximal number of retries for requests. Deprecated, use request_options instead. Optional.
:param request_options: Dictionary for function-specific request configuration. Optional.
"""

conn_name_attr = "conn_id"
Expand All @@ -46,23 +58,45 @@ def __init__(
conn_id: str = default_conn_name,
timeout: int | None = None,
max_retries: int | None = None,
request_options: RequestOptions | None = None,
) -> None:
super().__init__()
self.conn_id = conn_id
self.timeout = timeout
self.max_retries = max_retries
self.request_options = request_options

if self.max_retries:
warnings.warn(
"Argument `max_retries` is deprecated. Use `request_options` dict for function-specific request configuration.",
AirflowProviderDeprecationWarning,
stacklevel=2,
)
if self.request_options is None:
self.request_options = {"max_retries": self.max_retries}
else:
self.request_options.update({"max_retries": self.max_retries})

@cached_property
def get_conn(self) -> cohere.Client: # type: ignore[override]
def get_conn(self) -> cohere.ClientV2: # type: ignore[override]
conn = self.get_connection(self.conn_id)
return cohere.Client(
api_key=conn.password, timeout=self.timeout, max_retries=self.max_retries, api_url=conn.host
return cohere.ClientV2(
api_key=conn.password,
timeout=self.timeout,
base_url=conn.host or None,
)

def create_embeddings(
self, texts: list[str], model: str = "embed-multilingual-v2.0"
) -> list[list[float]]:
response = self.get_conn.embed(texts=texts, model=model)
self, texts: list[str], model: str = "embed-multilingual-v3.0"
) -> EmbedByTypeResponseEmbeddings:
logger.info("Creating embeddings with model: embed-multilingual-v3.0")
response = self.get_conn.embed(
texts=texts,
model=model,
input_type="search_document",
embedding_types=["float"],
request_options=self.request_options,
)
embeddings = response.embeddings
return embeddings

Expand All @@ -75,9 +109,15 @@ def get_ui_field_behaviour(cls) -> dict[str, Any]:
},
}

def test_connection(self) -> tuple[bool, str]:
def test_connection(
self,
model: str = "command-r-plus-08-2024",
messages: ChatMessages | None = None,
) -> tuple[bool, str]:
try:
self.get_conn.generate("Test", max_tokens=10)
return True, "Connection established"
if messages is None:
messages = [UserChatMessageV2(role="user", content="hello world!")]
self.get_conn.chat(model=model, messages=messages)
return True, "Connection successfully established."
except Exception as e:
return False, str(e)
return False, f"Unexpected error: {str(e)}"
25 changes: 23 additions & 2 deletions providers/src/airflow/providers/cohere/operators/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@
from airflow.providers.cohere.hooks.cohere import CohereHook

if TYPE_CHECKING:
from cohere.core.request_options import RequestOptions
from cohere.types import EmbedByTypeResponseEmbeddings

from airflow.utils.context import Context


Expand All @@ -41,6 +44,17 @@ class CohereEmbeddingOperator(BaseOperator):
information for Cohere. Defaults to "cohere_default".
:param timeout: Timeout in seconds for Cohere API.
:param max_retries: Number of times to retry before failing.
:param request_options: Request-specific configuration.
Fields:
- timeout_in_seconds: int. The number of seconds to await an API call before timing out.
- max_retries: int. The max number of retries to attempt if the API call fails.
- additional_headers: typing.Dict[str, typing.Any]. A dictionary containing additional parameters to spread into the request's header dict
- additional_query_parameters: typing.Dict[str, typing.Any]. A dictionary containing additional parameters to spread into the request's query parameters dict
- additional_body_parameters: typing.Dict[str, typing.Any]. A dictionary containing additional parameters to spread into the request's body parameters dict
"""

template_fields: Sequence[str] = ("input_text",)
Expand All @@ -51,6 +65,7 @@ def __init__(
conn_id: str = CohereHook.default_conn_name,
timeout: int | None = None,
max_retries: int | None = None,
request_options: RequestOptions | None = None,
**kwargs: Any,
):
super().__init__(**kwargs)
Expand All @@ -60,12 +75,18 @@ def __init__(
self.input_text = input_text
self.timeout = timeout
self.max_retries = max_retries
self.request_options = request_options

@cached_property
def hook(self) -> CohereHook:
"""Return an instance of the CohereHook."""
return CohereHook(conn_id=self.conn_id, timeout=self.timeout, max_retries=self.max_retries)
return CohereHook(
conn_id=self.conn_id,
timeout=self.timeout,
max_retries=self.max_retries,
request_options=self.request_options,
)

def execute(self, context: Context) -> list[list[float]]:
def execute(self, context: Context) -> EmbedByTypeResponseEmbeddings:
"""Embed texts using Cohere embed services."""
return self.hook.create_embeddings(self.input_text)
2 changes: 1 addition & 1 deletion providers/src/airflow/providers/cohere/provider.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ integrations:

dependencies:
- apache-airflow>=2.9.0
- cohere>=4.37,<5
- cohere>=5.13.4

hooks:
- integration-name: Cohere
Expand Down
13 changes: 5 additions & 8 deletions providers/tests/cohere/hooks/test_cohere.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,19 +31,16 @@ class TestCohereHook:

def test__get_api_key(self):
api_key = "test"
api_url = "http://some_host.com"
base_url = "http://some_host.com"
timeout = 150
max_retries = 5
with (
patch.object(
CohereHook,
"get_connection",
return_value=Connection(conn_type="cohere", password=api_key, host=api_url),
return_value=Connection(conn_type="cohere", password=api_key, host=base_url),
),
patch("cohere.Client") as client,
patch("cohere.ClientV2") as client,
):
hook = CohereHook(timeout=timeout, max_retries=max_retries)
hook = CohereHook(timeout=timeout)
_ = hook.get_conn
client.assert_called_once_with(
api_key=api_key, timeout=timeout, max_retries=max_retries, api_url=api_url
)
client.assert_called_once_with(api_key=api_key, timeout=timeout, base_url=base_url)
18 changes: 10 additions & 8 deletions providers/tests/cohere/operators/test_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@


@patch("airflow.providers.cohere.hooks.cohere.CohereHook.get_connection")
@patch("cohere.Client")
@patch("cohere.ClientV2")
def test_cohere_embedding_operator(cohere_client, get_connection):
"""
Test Cohere client is getting called with the correct key and that
Expand All @@ -35,22 +35,24 @@ class resp:
embeddings = embedded_obj

api_key = "test"
api_url = "http://some_host.com"
base_url = "http://some_host.com"
timeout = 150
max_retries = 5
texts = ["On Kernel-Target Alignment. We describe a family of global optimization procedures"]
request_options = None

get_connection.return_value = Connection(conn_type="cohere", password=api_key, host=api_url)
get_connection.return_value = Connection(conn_type="cohere", password=api_key, host=base_url)
client_obj = MagicMock()
cohere_client.return_value = client_obj
client_obj.embed.return_value = resp

op = CohereEmbeddingOperator(
task_id="embed", conn_id="some_conn", input_text=texts, timeout=timeout, max_retries=max_retries
task_id="embed",
conn_id="some_conn",
input_text=texts,
timeout=timeout,
request_options=request_options,
)

val = op.execute(context={})
cohere_client.assert_called_once_with(
api_key=api_key, api_url=api_url, timeout=timeout, max_retries=max_retries
)
cohere_client.assert_called_once_with(api_key=api_key, base_url=base_url, timeout=timeout)
assert val == embedded_obj

0 comments on commit 2a78648

Please sign in to comment.