152 changes: 101 additions & 51 deletions newrelic/hooks/mlmodel_gemini.py
@@ -175,20 +175,24 @@ def _record_embedding_success(transaction, embedding_id, linking_metadata, kwarg
embedding_content = str(embedding_content)
request_model = kwargs.get("model")

embedding_token_count = (
settings.ai_monitoring.llm_token_count_callback(request_model, embedding_content)
if settings.ai_monitoring.llm_token_count_callback
else None
)

full_embedding_response_dict = {
"id": embedding_id,
"span_id": span_id,
"trace_id": trace_id,
"token_count": (
settings.ai_monitoring.llm_token_count_callback(request_model, embedding_content)
if settings.ai_monitoring.llm_token_count_callback
else None
),
"request.model": request_model,
"duration": ft.duration * 1000,
"vendor": "gemini",
"ingest_source": "Python",
}
if embedding_token_count:
full_embedding_response_dict["response.usage.total_tokens"] = embedding_token_count

if settings.ai_monitoring.record_content.enabled:
full_embedding_response_dict["input"] = embedding_content

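The `llm_token_count_callback` consulted above is user-supplied. A minimal sketch of registering one — assuming the agent's documented `newrelic.agent.set_llm_token_count_callback` entry point, with a deliberately crude chars/4 heuristic that is illustrative only:

import newrelic.agent

def count_tokens(model, content):
    # Called with the request model and the message/embedding content;
    # should return a positive int, which the hook records instead of
    # falling back to any response-reported usage.
    return max(1, len(content) // 4)  # rough 4-chars-per-token estimate

newrelic.agent.set_llm_token_count_callback(count_tokens)
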
@@ -300,15 +304,13 @@ def _record_generation_error(transaction, linking_metadata, completion_id, kwarg
"Unable to parse input message to Gemini LLM. Message content and role will be omitted from "
"corresponding LlmChatCompletionMessage event. "
)
# Extract the input message content and role from the input message if it exists
input_message_content, input_role = _parse_input_message(input_message) if input_message else (None, None)

generation_config = kwargs.get("config")
if generation_config:
request_temperature = getattr(generation_config, "temperature", None)
request_max_tokens = getattr(generation_config, "max_output_tokens", None)
else:
request_temperature = None
request_max_tokens = None
# Extract data from generation config object
request_temperature, request_max_tokens = _extract_generation_config(kwargs)

# Prepare error attributes
notice_error_attributes = {
"http.statusCode": getattr(exc, "code", None),
"error.message": getattr(exc, "message", None),
Expand Down Expand Up @@ -348,15 +350,17 @@ def _record_generation_error(transaction, linking_metadata, completion_id, kwarg

create_chat_completion_message_event(
transaction,
input_message,
input_message_content,
input_role,
completion_id,
span_id,
trace_id,
# Passing the request model as the response model here since we do not have access to a response model
request_model,
request_model,
llm_metadata,
output_message_list,
# We do not record token counts in error cases, so set all_token_counts to True so the pipeline tokenizer does not run
all_token_counts=True,
)
except Exception:
_logger.warning(RECORD_EVENTS_FAILURE_LOG_MESSAGE, exc_info=True)
@@ -377,6 +381,7 @@ def _handle_generation_success(transaction, linking_metadata, completion_id, kwa


def _record_generation_success(transaction, linking_metadata, completion_id, kwargs, ft, response):
settings = transaction.settings or global_settings()
span_id = linking_metadata.get("span.id")
trace_id = linking_metadata.get("trace.id")
try:
@@ -385,12 +390,14 @@ def _record_generation_success(transaction, linking_metadata, completion_id, kwa
# finish_reason is an enum, so grab just the stringified value from it to report
finish_reason = response.get("candidates")[0].get("finish_reason").value
output_message_list = [response.get("candidates")[0].get("content")]
token_usage = response.get("usage_metadata") or {}
else:
# Set all values to NoneTypes since we cannot access them through kwargs or another method that doesn't
# require the response object
response_model = None
output_message_list = []
finish_reason = None
token_usage = {}

request_model = kwargs.get("model")

@@ -412,13 +419,44 @@ def _record_generation_success(transaction, linking_metadata, completion_id, kwa
"corresponding LlmChatCompletionMessage event. "
)

generation_config = kwargs.get("config")
if generation_config:
request_temperature = getattr(generation_config, "temperature", None)
request_max_tokens = getattr(generation_config, "max_output_tokens", None)
input_message_content, input_role = _parse_input_message(input_message) if input_message else (None, None)

# Parse output message content
# This list should have a length of 1 to represent the output message
# Parse the message text out to pass to any registered token counting callback
output_message_content = output_message_list[0].get("parts")[0].get("text") if output_message_list else None

# Extract token counts from response object
if token_usage:
response_prompt_tokens = token_usage.get("prompt_token_count")
response_completion_tokens = token_usage.get("candidates_token_count")
response_total_tokens = token_usage.get("total_token_count")

else:
request_temperature = None
request_max_tokens = None
response_prompt_tokens = None
response_completion_tokens = None
response_total_tokens = None

# Calculate token counts by checking if a callback is registered and if we have the necessary content to pass
# to it. If not, then we use the token counts provided in the response object
prompt_tokens = (
settings.ai_monitoring.llm_token_count_callback(request_model, input_message_content)
if settings.ai_monitoring.llm_token_count_callback and input_message_content
else response_prompt_tokens
)
completion_tokens = (
settings.ai_monitoring.llm_token_count_callback(response_model, output_message_content)
if settings.ai_monitoring.llm_token_count_callback and output_message_content
else response_completion_tokens
)
total_tokens = (
prompt_tokens + completion_tokens if all([prompt_tokens, completion_tokens]) else response_total_tokens
)

all_token_counts = bool(prompt_tokens and completion_tokens and total_tokens)

# Extract generation config
request_temperature, request_max_tokens = _extract_generation_config(kwargs)

full_chat_completion_summary_dict = {
"id": completion_id,
@@ -438,66 +476,78 @@ def _record_generation_success(transaction, linking_metadata, completion_id, kwa
"response.number_of_messages": 1 + len(output_message_list),
}

if all_token_counts:
full_chat_completion_summary_dict["response.usage.prompt_tokens"] = prompt_tokens
full_chat_completion_summary_dict["response.usage.completion_tokens"] = completion_tokens
full_chat_completion_summary_dict["response.usage.total_tokens"] = total_tokens

llm_metadata = _get_llm_attributes(transaction)
full_chat_completion_summary_dict.update(llm_metadata)
transaction.record_custom_event("LlmChatCompletionSummary", full_chat_completion_summary_dict)

create_chat_completion_message_event(
transaction,
input_message,
input_message_content,
input_role,
completion_id,
span_id,
trace_id,
response_model,
request_model,
llm_metadata,
output_message_list,
all_token_counts,
)
except Exception:
_logger.warning(RECORD_EVENTS_FAILURE_LOG_MESSAGE, exc_info=True)


def _parse_input_message(input_message):
# The input_message will be a string if generate_content was called directly. In this case, we don't have
# access to the role, so we default to user since this was an input message
if isinstance(input_message, str):
return input_message, "user"
# The input_message will be a Google Content type if send_message was called, so we parse out the message
# text and role (which should be "user")
elif isinstance(input_message, google.genai.types.Content):
return input_message.parts[0].text, input_message.role
else:
return None, None

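# Illustrative call shapes for _parse_input_message (examples, not part of this
# change): generate_content passes the raw prompt string, while send_message
# passes a google.genai.types.Content:
#
#   _parse_input_message("What is 2 + 2?")
#   # -> ("What is 2 + 2?", "user")
#
#   from google.genai import types
#   _parse_input_message(types.Content(role="user", parts=[types.Part(text="Hi")]))
#   # -> ("Hi", "user")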

def _extract_generation_config(kwargs):
generation_config = kwargs.get("config")
if generation_config:
request_temperature = getattr(generation_config, "temperature", None)
request_max_tokens = getattr(generation_config, "max_output_tokens", None)
else:
request_temperature = None
request_max_tokens = None

return request_temperature, request_max_tokens

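# For reference (illustrative, not part of this change): the "config" kwarg is a
# google.genai.types.GenerateContentConfig, so the extraction above behaves as:
#
#   from google.genai import types
#   cfg = types.GenerateContentConfig(temperature=0.2, max_output_tokens=256)
#   _extract_generation_config({"config": cfg})  # -> (0.2, 256)
#   _extract_generation_config({})               # -> (None, None)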

def create_chat_completion_message_event(
transaction,
input_message,
input_message_content,
input_role,
chat_completion_id,
span_id,
trace_id,
response_model,
request_model,
llm_metadata,
output_message_list,
all_token_counts,
):
try:
settings = transaction.settings or global_settings()

if input_message:
# The input_message will be a string if generate_content was called directly. In this case, we don't have
# access to the role, so we default to user since this was an input message
if isinstance(input_message, str):
input_message_content = input_message
input_role = "user"
# The input_message will be a Google Content type if send_message was called, so we parse out the message
# text and role (which should be "user")
elif isinstance(input_message, google.genai.types.Content):
input_message_content = input_message.parts[0].text
input_role = input_message.role
# Set input data to NoneTypes to ensure token_count callback is not called
else:
input_message_content = None
input_role = None

if input_message_content:
message_id = str(uuid.uuid4())

chat_completion_input_message_dict = {
"id": message_id,
"span_id": span_id,
"trace_id": trace_id,
"token_count": (
settings.ai_monitoring.llm_token_count_callback(request_model, input_message_content)
if settings.ai_monitoring.llm_token_count_callback and input_message_content
else None
),
"role": input_role,
"completion_id": chat_completion_id,
# The input message will always be the first message in our request/response sequence so this will
@@ -507,6 +557,8 @@ def create_chat_completion_message_event(
"vendor": "gemini",
"ingest_source": "Python",
}
if all_token_counts:
chat_completion_input_message_dict["token_count"] = 0

if settings.ai_monitoring.record_content.enabled:
chat_completion_input_message_dict["content"] = input_message_content
@@ -523,7 +575,7 @@

# Add one to the index to account for the single input message so our sequence value is accurate for
# the output message
if input_message:
if input_message_content:
index += 1

message_id = str(uuid.uuid4())
@@ -532,11 +584,6 @@
"id": message_id,
"span_id": span_id,
"trace_id": trace_id,
"token_count": (
settings.ai_monitoring.llm_token_count_callback(response_model, message_content)
if settings.ai_monitoring.llm_token_count_callback
else None
),
"role": message.get("role"),
"completion_id": chat_completion_id,
"sequence": index,
@@ -546,6 +593,9 @@
"is_response": True,
}

if all_token_counts:
chat_completion_output_message_dict["token_count"] = 0

if settings.ai_monitoring.record_content.enabled:
chat_completion_output_message_dict["content"] = message_content

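Net effect of the hook changes: token counts now come from the user callback when one is registered and the relevant content is available, falling back to the SDK's `usage_metadata` otherwise, and per-message `token_count` is only zeroed out when summary-level counts exist. A condensed sketch of that precedence (function name and shape are illustrative, not taken from the hook):

def resolve_token_count(callback, model, content, response_count):
    # Prefer a registered counter when both it and the content exist;
    # otherwise fall back to what the Gemini response reported, which
    # may itself be None.
    if callback and content:
        return callback(model, content)
    return response_count
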
6 changes: 3 additions & 3 deletions tests/mlmodel_gemini/test_embeddings.py
@@ -15,7 +15,7 @@
import google.genai
from testing_support.fixtures import override_llm_token_callback_settings, reset_core_stats_engine, validate_attributes
from testing_support.ml_testing_utils import (
add_token_count_to_events,
add_token_count_to_embedding_events,
disabled_ai_monitoring_record_content_settings,
disabled_ai_monitoring_settings,
events_sans_content,
@@ -93,7 +93,7 @@ def test_gemini_embedding_sync_no_content(gemini_dev_client, set_trace_info):

@reset_core_stats_engine()
@override_llm_token_callback_settings(llm_token_count_callback)
@validate_custom_events(add_token_count_to_events(embedding_recorded_events))
@validate_custom_events(add_token_count_to_embedding_events(embedding_recorded_events))
@validate_custom_event_count(count=1)
@validate_transaction_metrics(
name="test_embeddings:test_gemini_embedding_sync_with_token_count",
@@ -177,7 +177,7 @@ def test_gemini_embedding_async_no_content(gemini_dev_client, loop, set_trace_in

@reset_core_stats_engine()
@override_llm_token_callback_settings(llm_token_count_callback)
@validate_custom_events(add_token_count_to_events(embedding_recorded_events))
@validate_custom_events(add_token_count_to_embedding_events(embedding_recorded_events))
@validate_custom_event_count(count=1)
@validate_transaction_metrics(
name="test_embeddings:test_gemini_embedding_async_with_token_count",
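The two test updates swap in an embedding-specific expected-events helper, matching the hook's move from a per-event `token_count` to a summary-level `response.usage.total_tokens` on LlmEmbedding events. The helper's source is not in this diff; a plausible shape — with the (descriptor, attributes) tuple structure and the 105 return value of the test callback both assumed — would be:

import copy

def add_token_count_to_embedding_events(expected_events):
    # Return a copy of the expected (event_descriptor, attributes) pairs with
    # the usage attribute the callback-enabled embedding tests should now see.
    events = copy.deepcopy(expected_events)
    for _descriptor, attributes in events:
        attributes["response.usage.total_tokens"] = 105  # assumed test-callback return value
    return events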