Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 51 additions & 0 deletions hindsight-api-slim/hindsight_api/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,14 @@ def normalize_config_dict(config: dict[str, Any]) -> dict[str, Any]:
ENV_EMBEDDINGS_OPENAI_MODEL = "HINDSIGHT_API_EMBEDDINGS_OPENAI_MODEL"
ENV_EMBEDDINGS_OPENAI_BASE_URL = "HINDSIGHT_API_EMBEDDINGS_OPENAI_BASE_URL"

# Gemini/Vertex AI embeddings configuration
ENV_EMBEDDINGS_GEMINI_API_KEY = "HINDSIGHT_API_EMBEDDINGS_GEMINI_API_KEY"
ENV_EMBEDDINGS_GEMINI_MODEL = "HINDSIGHT_API_EMBEDDINGS_GEMINI_MODEL"
ENV_EMBEDDINGS_GEMINI_OUTPUT_DIMENSIONALITY = "HINDSIGHT_API_EMBEDDINGS_GEMINI_OUTPUT_DIMENSIONALITY"
ENV_EMBEDDINGS_VERTEXAI_PROJECT_ID = "HINDSIGHT_API_EMBEDDINGS_VERTEXAI_PROJECT_ID"
ENV_EMBEDDINGS_VERTEXAI_REGION = "HINDSIGHT_API_EMBEDDINGS_VERTEXAI_REGION"
ENV_EMBEDDINGS_VERTEXAI_SERVICE_ACCOUNT_KEY = "HINDSIGHT_API_EMBEDDINGS_VERTEXAI_SERVICE_ACCOUNT_KEY"

# Cohere configuration (separate for embeddings and reranker)
ENV_EMBEDDINGS_COHERE_API_KEY = "HINDSIGHT_API_EMBEDDINGS_COHERE_API_KEY"
ENV_EMBEDDINGS_COHERE_MODEL = "HINDSIGHT_API_EMBEDDINGS_COHERE_MODEL"
Expand Down Expand Up @@ -231,6 +239,11 @@ def normalize_config_dict(config: dict[str, Any]) -> dict[str, Any]:
ENV_RERANKER_ZEROENTROPY_MODEL = "HINDSIGHT_API_RERANKER_ZEROENTROPY_MODEL"
ENV_RERANKER_ZEROENTROPY_BASE_URL = "HINDSIGHT_API_RERANKER_ZEROENTROPY_BASE_URL"

# Google Discovery Engine reranker configuration
ENV_RERANKER_GOOGLE_MODEL = "HINDSIGHT_API_RERANKER_GOOGLE_MODEL"
ENV_RERANKER_GOOGLE_PROJECT_ID = "HINDSIGHT_API_RERANKER_GOOGLE_PROJECT_ID"
ENV_RERANKER_GOOGLE_SERVICE_ACCOUNT_KEY = "HINDSIGHT_API_RERANKER_GOOGLE_SERVICE_ACCOUNT_KEY"

ENV_VECTOR_EXTENSION = "HINDSIGHT_API_VECTOR_EXTENSION"
ENV_TEXT_SEARCH_EXTENSION = "HINDSIGHT_API_TEXT_SEARCH_EXTENSION"

Expand Down Expand Up @@ -403,6 +416,8 @@ def normalize_config_dict(config: dict[str, Any]) -> dict[str, Any]:
DEFAULT_EMBEDDINGS_LOCAL_FORCE_CPU = False # Force CPU mode for local embeddings (avoids MPS/XPC issues on macOS)
DEFAULT_EMBEDDINGS_LOCAL_TRUST_REMOTE_CODE = False # Security: disabled by default, required for some models
DEFAULT_EMBEDDINGS_OPENAI_MODEL = "text-embedding-3-small"
DEFAULT_EMBEDDINGS_GEMINI_MODEL = "gemini-embedding-001"
DEFAULT_EMBEDDINGS_GEMINI_OUTPUT_DIMENSIONALITY = 768
DEFAULT_EMBEDDING_DIMENSION = 384

DEFAULT_RERANKER_PROVIDER = "local"
Expand All @@ -426,6 +441,8 @@ def normalize_config_dict(config: dict[str, Any]) -> dict[str, Any]:

DEFAULT_RERANKER_ZEROENTROPY_MODEL = "zerank-2"

DEFAULT_RERANKER_GOOGLE_MODEL = "semantic-ranker-default-004"

# Vector extension (pgvector, vchord, or pgvectorscale)
DEFAULT_VECTOR_EXTENSION = "pgvector" # Options: "pgvector", "vchord", "pgvectorscale"

Expand Down Expand Up @@ -706,6 +723,13 @@ class HindsightConfig:
embeddings_litellm_sdk_model: str
embeddings_litellm_sdk_api_base: str | None
embeddings_litellm_sdk_output_dimensions: int | None
# Gemini/Vertex AI embeddings
embeddings_gemini_api_key: str | None
embeddings_gemini_model: str
embeddings_gemini_output_dimensionality: int | None
embeddings_vertexai_project_id: str | None
embeddings_vertexai_region: str | None
embeddings_vertexai_service_account_key: str | None

# Reranker
reranker_provider: str
Expand Down Expand Up @@ -733,6 +757,9 @@ class HindsightConfig:
reranker_zeroentropy_api_key: str | None
reranker_zeroentropy_model: str
reranker_zeroentropy_base_url: str | None
reranker_google_model: str
reranker_google_project_id: str | None
reranker_google_service_account_key: str | None

# Server
host: str
Expand Down Expand Up @@ -882,6 +909,10 @@ class HindsightConfig:
"reranker_zeroentropy_base_url",
# Service Account Keys
"llm_vertexai_service_account_key",
"embeddings_vertexai_service_account_key",
"reranker_google_service_account_key",
# Embeddings API keys
"embeddings_gemini_api_key",
# File storage credentials
"file_storage_s3_access_key_id",
"file_storage_s3_secret_access_key",
Expand Down Expand Up @@ -1160,6 +1191,20 @@ def from_env(cls) -> "HindsightConfig":
embeddings_litellm_sdk_output_dimensions=int(v)
if (v := os.getenv(ENV_EMBEDDINGS_LITELLM_SDK_OUTPUT_DIMENSIONS))
else None,
# Gemini/Vertex AI embeddings (with fallback to LLM keys)
embeddings_gemini_api_key=os.getenv(ENV_EMBEDDINGS_GEMINI_API_KEY) or os.getenv(ENV_LLM_API_KEY),
embeddings_gemini_model=os.getenv(ENV_EMBEDDINGS_GEMINI_MODEL, DEFAULT_EMBEDDINGS_GEMINI_MODEL),
embeddings_gemini_output_dimensionality=int(
os.getenv(
ENV_EMBEDDINGS_GEMINI_OUTPUT_DIMENSIONALITY,
str(DEFAULT_EMBEDDINGS_GEMINI_OUTPUT_DIMENSIONALITY),
)
),
embeddings_vertexai_project_id=os.getenv(ENV_EMBEDDINGS_VERTEXAI_PROJECT_ID)
or os.getenv(ENV_LLM_VERTEXAI_PROJECT_ID),
embeddings_vertexai_region=os.getenv(ENV_EMBEDDINGS_VERTEXAI_REGION) or os.getenv(ENV_LLM_VERTEXAI_REGION),
embeddings_vertexai_service_account_key=os.getenv(ENV_EMBEDDINGS_VERTEXAI_SERVICE_ACCOUNT_KEY)
or os.getenv(ENV_LLM_VERTEXAI_SERVICE_ACCOUNT_KEY),
# Reranker
reranker_provider=os.getenv(ENV_RERANKER_PROVIDER, DEFAULT_RERANKER_PROVIDER),
reranker_local_model=os.getenv(ENV_RERANKER_LOCAL_MODEL, DEFAULT_RERANKER_LOCAL_MODEL),
Expand Down Expand Up @@ -1209,6 +1254,12 @@ def from_env(cls) -> "HindsightConfig":
reranker_zeroentropy_api_key=os.getenv(ENV_RERANKER_ZEROENTROPY_API_KEY),
reranker_zeroentropy_model=os.getenv(ENV_RERANKER_ZEROENTROPY_MODEL, DEFAULT_RERANKER_ZEROENTROPY_MODEL),
reranker_zeroentropy_base_url=os.getenv(ENV_RERANKER_ZEROENTROPY_BASE_URL) or None,
# Google Discovery Engine reranker (with fallback to LLM Vertex AI keys)
reranker_google_model=os.getenv(ENV_RERANKER_GOOGLE_MODEL, DEFAULT_RERANKER_GOOGLE_MODEL),
reranker_google_project_id=os.getenv(ENV_RERANKER_GOOGLE_PROJECT_ID)
or os.getenv(ENV_LLM_VERTEXAI_PROJECT_ID),
reranker_google_service_account_key=os.getenv(ENV_RERANKER_GOOGLE_SERVICE_ACCOUNT_KEY)
or os.getenv(ENV_LLM_VERTEXAI_SERVICE_ACCOUNT_KEY),
# Server
host=os.getenv(ENV_HOST, DEFAULT_HOST),
port=int(os.getenv(ENV_PORT, DEFAULT_PORT)),
Expand Down
174 changes: 173 additions & 1 deletion hindsight-api-slim/hindsight_api/engine/cross_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
DEFAULT_RERANKER_COHERE_MODEL,
DEFAULT_RERANKER_FLASHRANK_CACHE_DIR,
DEFAULT_RERANKER_FLASHRANK_MODEL,
DEFAULT_RERANKER_GOOGLE_MODEL,
DEFAULT_RERANKER_LITELLM_MAX_TOKENS_PER_DOC,
DEFAULT_RERANKER_LITELLM_MODEL,
DEFAULT_RERANKER_LITELLM_SDK_MODEL,
Expand All @@ -36,6 +37,7 @@
ENV_RERANKER_COHERE_MODEL,
ENV_RERANKER_FLASHRANK_CACHE_DIR,
ENV_RERANKER_FLASHRANK_MODEL,
ENV_RERANKER_GOOGLE_PROJECT_ID,
ENV_RERANKER_LITELLM_SDK_API_KEY,
ENV_RERANKER_LOCAL_FORCE_CPU,
ENV_RERANKER_LOCAL_MAX_CONCURRENT,
Expand Down Expand Up @@ -1266,6 +1268,164 @@ async def predict(self, pairs: list[tuple[str, str]]) -> list[float]:
return await loop.run_in_executor(None, self._predict_sync, pairs)


class GoogleCrossEncoder(CrossEncoderModel):
    """
    Google Discovery Engine cross-encoder using the Ranking REST API.

    Uses httpx + google-auth for lightweight REST calls (no gRPC/protobuf).
    Supports ADC (Application Default Credentials) or service account key file.

    Available models:
    - semantic-ranker-default-004: Best quality, 1024 tokens/record (recommended)
    - semantic-ranker-fast-004: Lower latency, 1024 tokens/record

    Max 200 records per API request. Location is always "global".
    """

    # Hard limit imposed by the Ranking API on records per rank() call.
    MAX_RECORDS_PER_REQUEST = 200
    API_BASE = "https://discoveryengine.googleapis.com/v1"
    SCOPES = ["https://www.googleapis.com/auth/cloud-platform"]

    def __init__(
        self,
        project_id: str,
        model: str = DEFAULT_RERANKER_GOOGLE_MODEL,
        service_account_key: str | None = None,
        location: str = "global",
        timeout: float = 60.0,
    ):
        """
        Initialize Google Discovery Engine cross-encoder.

        Args:
            project_id: Google Cloud project ID
            model: Ranking model name (default: semantic-ranker-default-004)
            service_account_key: Path to service account JSON key file.
                If None, uses Application Default Credentials (ADC).
            location: API location (default: "global")
            timeout: Request timeout in seconds (default: 60.0)
        """
        self.project_id = project_id
        self.model = model
        self.service_account_key = service_account_key
        self.location = location
        self.timeout = timeout
        self._credentials = None
        self._client: httpx.Client | None = None
        self._rank_url: str | None = None

    @property
    def provider_name(self) -> str:
        return "google"

    def _get_auth_headers(self) -> dict[str, str]:
        """Get Authorization header with a fresh access token."""
        # Lazy import: google-auth is only required when this provider is used.
        import google.auth.transport.requests

        if self._credentials is None:
            # Guard: without this, calling before initialize() surfaces as an
            # opaque AttributeError on None.
            raise RuntimeError("Reranker not initialized. Call initialize() first.")
        if not self._credentials.valid:
            self._credentials.refresh(google.auth.transport.requests.Request())
        return {"Authorization": f"Bearer {self._credentials.token}"}

    async def initialize(self) -> None:
        """Initialize credentials and HTTP client. Idempotent: later calls are no-ops."""
        if self._client is not None:
            return

        auth_method = "ADC" if not self.service_account_key else "service_account"
        logger.info(
            f"Reranker: initializing Google Discovery Engine provider "
            f"(project={self.project_id}, model={self.model}, auth={auth_method})"
        )
        if self.service_account_key:
            try:
                from google.oauth2 import service_account
            except ImportError as e:
                raise ImportError(
                    "google-auth is required for GoogleCrossEncoder. Install it with: pip install google-auth"
                ) from e
            self._credentials = service_account.Credentials.from_service_account_file(
                self.service_account_key,
                scopes=self.SCOPES,
            )
        else:
            try:
                import google.auth
            except ImportError as e:
                raise ImportError(
                    "google-auth is required for GoogleCrossEncoder. Install it with: pip install google-auth"
                ) from e
            self._credentials, _ = google.auth.default(scopes=self.SCOPES)

        ranking_config = f"projects/{self.project_id}/locations/{self.location}/rankingConfigs/default_ranking_config"
        self._rank_url = f"{self.API_BASE}/{ranking_config}:rank"
        self._client = httpx.Client(timeout=self.timeout)

        logger.info("Reranker: Google Discovery Engine provider initialized")

    def _predict_sync(self, pairs: list[tuple[str, str]]) -> list[float]:
        """Synchronous predict via REST API.

        The Ranking API scores one query against many records, so pairs are
        grouped by query first, then each group is sent in batches of at most
        MAX_RECORDS_PER_REQUEST records.
        """
        if not pairs:
            return []

        # Group (original index, document) by query so each request covers one query.
        query_groups: dict[str, list[tuple[int, str]]] = {}
        for idx, (query, text) in enumerate(pairs):
            query_groups.setdefault(query, []).append((idx, text))

        # Default 0.0 so any record absent from the response keeps a valid score.
        all_scores = [0.0] * len(pairs)

        for query, indexed_texts in query_groups.items():
            texts = [text for _, text in indexed_texts]
            indices = [idx for idx, _ in indexed_texts]

            # Process in batches of MAX_RECORDS_PER_REQUEST
            for batch_start in range(0, len(texts), self.MAX_RECORDS_PER_REQUEST):
                batch_texts = texts[batch_start : batch_start + self.MAX_RECORDS_PER_REQUEST]
                batch_indices = indices[batch_start : batch_start + self.MAX_RECORDS_PER_REQUEST]

                # Record ids are batch-local positions; mapped back via batch_indices.
                records = [{"id": str(i), "content": text} for i, text in enumerate(batch_texts)]

                response = self._client.post(
                    self._rank_url,
                    headers=self._get_auth_headers(),
                    json={
                        "model": self.model,
                        "query": query,
                        "records": records,
                        "topN": len(records),
                    },
                )
                response.raise_for_status()
                result = response.json()

                for record in result.get("records", []):
                    local_idx = int(record["id"])
                    # "score" may be omitted by the REST API for zero-relevance
                    # records (proto3 drops default-valued fields); fall back to 0.0.
                    all_scores[batch_indices[local_idx]] = record.get("score", 0.0)

        return all_scores

    async def predict(self, pairs: list[tuple[str, str]]) -> list[float]:
        """
        Score query-document pairs using Google Discovery Engine Ranking API.

        Args:
            pairs: List of (query, document) tuples to score

        Returns:
            List of relevance scores (0-1, higher = more relevant)
        """
        if self._client is None:
            raise RuntimeError("Reranker not initialized. Call initialize() first.")

        if not pairs:
            return []

        # get_running_loop() rather than the deprecated get_event_loop():
        # predict() is a coroutine, so a running loop is guaranteed here.
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self._predict_sync, pairs)


def create_cross_encoder_from_env() -> CrossEncoderModel:
"""
Create a CrossEncoderModel instance based on configuration.
Expand Down Expand Up @@ -1341,11 +1501,23 @@ def create_cross_encoder_from_env() -> CrossEncoderModel:
api_key=api_key,
model=config.reranker_zeroentropy_model,
)
elif provider == "google":
project_id = config.reranker_google_project_id
if not project_id:
raise ValueError(
f"{ENV_RERANKER_GOOGLE_PROJECT_ID} (or HINDSIGHT_API_LLM_VERTEXAI_PROJECT_ID) "
f"is required when {ENV_RERANKER_PROVIDER} is 'google'"
)
return GoogleCrossEncoder(
project_id=project_id,
model=config.reranker_google_model,
service_account_key=config.reranker_google_service_account_key,
)
elif provider == "rrf":
return RRFPassthroughCrossEncoder()
elif provider == "jina-mlx":
return JinaMLXCrossEncoder()
else:
raise ValueError(
f"Unknown reranker provider: {provider}. Supported: 'local', 'tei', 'cohere', 'zeroentropy', 'flashrank', 'litellm', 'litellm-sdk', 'rrf', 'jina-mlx'"
f"Unknown reranker provider: {provider}. Supported: 'local', 'tei', 'cohere', 'zeroentropy', 'google', 'flashrank', 'litellm', 'litellm-sdk', 'rrf', 'jina-mlx'"
)
Loading
Loading