diff --git a/src/google/adk/plugins/bigquery_agent_analytics_plugin.py b/src/google/adk/plugins/bigquery_agent_analytics_plugin.py index 63b95e57ea..ed2a6dd697 100644 --- a/src/google/adk/plugins/bigquery_agent_analytics_plugin.py +++ b/src/google/adk/plugins/bigquery_agent_analytics_plugin.py @@ -4,7 +4,7 @@ # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -28,9 +28,9 @@ from google.api_core.gapic_v1 import client_info as gapic_client_info import google.auth from google.cloud import bigquery +from google.cloud import bigquery_storage_v1 from google.cloud.bigquery import schema as bq_schema from google.cloud.bigquery_storage_v1 import types as bq_storage_types -from google.cloud.bigquery_storage_v1.services.big_query_write.async_client import BigQueryWriteAsyncClient from google.genai import types import pyarrow as pa @@ -85,7 +85,7 @@ def _pyarrow_timestamp(): "GEOGRAPHY": pa.string, "INT64": pa.int64, "INTEGER": pa.int64, - "JSON": pa.string, + "JSON": pa.string, # JSON is passed as string to Arrow "NUMERIC": _pyarrow_numeric, "BIGNUMERIC": _pyarrow_bignumeric, "STRING": pa.string, @@ -221,13 +221,39 @@ class BigQueryLoggerConfig: enabled: bool = True event_allowlist: Optional[List[str]] = None event_denylist: Optional[List[str]] = None - content_formatter: Optional[Callable[[Any], str]] = None + # Custom formatter is discouraged now that we use JSON, but kept for compat + content_formatter: Optional[Callable[[dict], dict]] = None shutdown_timeout: float = 5.0 client_close_timeout: float = 2.0 - max_content_length: int = 500 + # Increased default limit to 50KB since we truncate per-field, not per-row + max_content_length: int = 50000 + + +def _recursive_smart_truncate(obj: Any, max_len: int) -> Any: + """Recursively truncates string values within a dict or list.""" + if isinstance(obj, str): + if len(obj) > max_len: + return obj[:max_len] + "...[TRUNCATED]" + return obj + elif isinstance(obj, dict): + return {k: _recursive_smart_truncate(v, max_len) for k, v in obj.items()} + elif isinstance(obj, (list, tuple)): + return type(obj)(_recursive_smart_truncate(i, max_len) for i in obj) + else: + return obj + + +def _serialize_to_json_safe(content_obj: Any, max_len: int) -> str: + """Safely serializes an object to a JSON string with smart truncation.""" + try: + truncated_obj = _recursive_smart_truncate(content_obj, max_len) + # default=str handles datetime or other non-serializable types by converting to string + return json.dumps(truncated_obj, default=str) + except Exception as e: + logging.warning(f"JSON serialization failed: {e}") + return json.dumps({"error": "Serialization failed", "details": str(e)}) -# --- Helper Formatters --- def _get_event_type(event: Event) -> str: """Determines the event type from an Event object.""" if event.author == "user": @@ -243,66 +269,8 @@ def _get_event_type(event: Event) -> str: return "SYSTEM" -def _format_content( - content: Optional[types.Content], max_len: int = 500 -) -> tuple[str, bool]: - """Formats an Event content for logging. - - Args: - content: The Event content to format. - max_len: The maximum length of the text parts before truncation. - - Returns: - A tuple containing the formatted content string and a boolean indicating if - the content was truncated. - """ - if not content or not content.parts: - return "None", False - parts = [] - for p in content.parts: - if p.text: - parts.append( - f"text: '{p.text[:max_len]}...' " - if len(p.text) > max_len - else f"text: '{p.text}'" - ) - elif p.function_call: - parts.append(f"call: {p.function_call.name}") - elif p.function_response: - parts.append(f"resp: {p.function_response.name}") - else: - parts.append("other") - return " | ".join(parts), any( - len(p.text) > max_len for p in content.parts if p.text - ) - - -def _format_args( - args: dict[str, Any], *, max_len: int = 1000 -) -> tuple[str, bool]: - """Formats tool arguments or results for logging. - - Args: - args: The tool arguments or results dictionary to format. - max_len: The maximum length of the output string before truncation. - - Returns: - A tuple containing the JSON formatted string and a boolean indicating if - the content was truncated. - """ - if not args: - return "{}", False - try: - s = json.dumps(args) - except TypeError: - s = str(args) - if len(s) > max_len: - return s[:max_len] + "...", True - return s, False - - class BigQueryAgentAnalyticsPlugin(BasePlugin): - """A plugin that logs agent analytic events to Google BigQuery. + """A plugin that logs agent analytic events to Google BigQuery (Structured JSON). This plugin captures key events during an agent's lifecycle—such as user interactions, tool executions, LLM requests/responses, and errors—and @@ -340,11 +308,17 @@ def __init__( ) self._config = config if config else BigQueryLoggerConfig() self._bq_client: bigquery.Client | None = None - self._write_client: BigQueryWriteAsyncClient | None = None + # Type alias update: Use the class from the top-level package import + self._write_client: ( + bigquery_storage_v1.services.big_query_write.async_client.BigQueryWriteAsyncClient + | None + ) = None self._init_lock: asyncio.Lock | None = None self._arrow_schema: pa.Schema | None = None self._background_tasks: set[asyncio.Task] = set() self._is_shutting_down = False + + # --- Updated Schema: Content is now JSON --- self._schema = [ bigquery.SchemaField( "timestamp", @@ -356,90 +330,47 @@ def __init__( "event_type", "STRING", mode="NULLABLE", - description=( - "Indicates the type of event being logged (e.g., 'LLM_REQUEST'," - " 'TOOL_COMPLETED')." - ), + description="Indicates the type of event (e.g., 'LLM_REQUEST').", ), bigquery.SchemaField( "agent", "STRING", mode="NULLABLE", - description=( - "The name of the ADK agent or author associated with the event." - ), + description="The name of the ADK agent.", ), bigquery.SchemaField( "session_id", "STRING", mode="NULLABLE", - description=( - "A unique identifier to group events within a single" - " conversation or user session." - ), + description="Unique identifier for the session.", ), bigquery.SchemaField( "invocation_id", "STRING", mode="NULLABLE", - description=( - "A unique identifier for each individual agent execution or" - " turn within a session." - ), + description="Unique identifier for the invocation/turn.", ), bigquery.SchemaField( "user_id", "STRING", mode="NULLABLE", - description=( - "The identifier of the user associated with the current" - " session." - ), + description="The user identifier.", ), + # CHANGED: STRING -> JSON bigquery.SchemaField( "content", - "STRING", + "JSON", mode="NULLABLE", - description=( - "The event-specific data (payload). Format varies by" - " event_type." - ), + description="Structured event payload.", ), bigquery.SchemaField( "error_message", "STRING", mode="NULLABLE", - description=( - "Populated if an error occurs during the processing of the" - " event." - ), - ), - bigquery.SchemaField( - "is_truncated", - "BOOLEAN", - mode="NULLABLE", - description=( - "Indicates if the content field was truncated due to size" - " limits." - ), + description="Error details if applicable.", ), ] - def _format_content_safely( - self, content: Optional[types.Content] - ) -> tuple[str | None, bool]: - """Formats content using self._config.content_formatter or _format_content, catching errors.""" - if content is None: - return None, False - try: - if self._config.content_formatter: - # Custom formatter: we assume no truncation or we can't know. - return self._config.content_formatter(content), False - return _format_content(content, max_len=self._config.max_content_length) - except Exception as e: - logging.warning("Content formatter failed: %s", e) - return "[FORMATTING FAILED]", False - async def _ensure_init(self): """Ensures BigQuery clients are initialized.""" if self._write_client: @@ -461,7 +392,6 @@ async def _ensure_init(self): project=self._project_id, credentials=creds, client_info=client_info ) - # Ensure table exists (sync call in thread) def create_resources(): if self._bq_client: self._bq_client.create_dataset(self._dataset_id, exists_ok=True) @@ -482,21 +412,21 @@ def create_resources(): await asyncio.to_thread(create_resources) - self._write_client = BigQueryWriteAsyncClient( + # Fix: Use the top-level package import to avoid "cli" substring in path + self._write_client = bigquery_storage_v1.services.big_query_write.async_client.BigQueryWriteAsyncClient( credentials=creds, client_info=client_info, ) self._arrow_schema = to_arrow_schema(self._schema) if not self._arrow_schema: raise RuntimeError("Failed to convert BigQuery schema to Arrow.") - logging.info("BQ Plugin: Initialized successfully.") return True except Exception as e: logging.error("BQ Plugin: Init Failed:", exc_info=True) return False async def _perform_write(self, row: dict): - """Actual async write operation, intended to run as a background task.""" + """Actual async write operation.""" try: if ( not await self._ensure_init() @@ -505,7 +435,6 @@ async def _perform_write(self, row: dict): ): return - # Serialize pydict = {f.name: [row.get(f.name)] for f in self._arrow_schema} batch = pa.RecordBatch.from_pydict(pydict, schema=self._arrow_schema) req = bq_storage_types.AppendRowsRequest( @@ -518,22 +447,20 @@ async def _perform_write(self, row: dict): batch.serialize().to_pybytes() ) - # Write with protection against immediate cancellation async for resp in await asyncio.shield( self._write_client.append_rows(iter([req])) ): if resp.error.code != 0: msg = resp.error.message - # Check for common schema mismatch indicators if ( "schema mismatch" in msg.lower() or "field" in msg.lower() or "type" in msg.lower() ): logging.error( - "BQ Plugin: Schema Mismatch Error. The BigQuery table schema" - " may be incorrect or out of sync with the plugin. Please" - " verify the table definition. Details: %s", + "BQ Plugin: Schema Mismatch. You may need to delete the" + " existing table if you migrated from STRING content to JSON" + " content. Details: %s", msg, ) else: @@ -548,10 +475,16 @@ async def _perform_write(self, row: dict): except Exception as e: logging.error("BQ Plugin: Write Failed:", exc_info=True) - async def _log(self, data: dict): - """Schedules a log entry to be written in the background.""" + async def _log(self, data: dict, content_payload: Any = None): + """ + Schedules a log entry. + Args: + data: Metadata dict (event_type, agent, etc.) + content_payload: The structured data to be JSON serialized. + """ if not self._config.enabled: return + event_type = data.get("event_type") if ( self._config.event_denylist @@ -564,7 +497,24 @@ async def _log(self, data: dict): ): return - # Prepare row immediately (capture current state) + # If a custom formatter/redactor is provided, let it modify the payload + # BEFORE we truncate and serialize it. + if self._config.content_formatter and content_payload is not None: + try: + # The formatter now receives a Dict and should return a Dict + content_payload = self._config.content_formatter(content_payload) + except Exception as e: + logging.warning(f"Content formatter failed: {e}") + # Fallback: keep original payload but log the error + + # Prepare payload + content_json_str = None + if content_payload is not None: + # Use smart truncation to keep JSON valid but safe size + content_json_str = _serialize_to_json_safe( + content_payload, self._config.max_content_length + ) + row = { "timestamp": datetime.now(timezone.utc), "event_type": None, @@ -572,13 +522,11 @@ async def _log(self, data: dict): "session_id": None, "invocation_id": None, "user_id": None, - "content": None, + "content": content_json_str, # Injected here "error_message": None, - "is_truncated": False, } row.update(data) - # Fire and forget: Create task and track it task = asyncio.create_task(self._perform_write(row)) self._background_tasks.add(task) task.add_done_callback(self._background_tasks.discard) @@ -603,7 +551,6 @@ async def close(self): except Exception as e: logging.warning("BQ Plugin: Error flushing logs:", exc_info=True) - # Use getattr for safe access in case transport is not present. if self._write_client and getattr(self._write_client, "transport", None): try: logging.info("BQ Plugin: Closing write client.") @@ -613,6 +560,7 @@ async def close(self): ) except Exception as e: logging.warning("BQ Plugin: Error closing write client: %s", e) + pass if self._bq_client: try: self._bq_client.close() @@ -624,7 +572,8 @@ async def close(self): self._is_shutting_down = False logging.info("BQ Plugin: Shutdown complete.") - # --- Streamlined Callbacks --- + # --- Refactored Callbacks using Structured Data --- + async def on_user_message_callback( self, *, @@ -636,19 +585,27 @@ async def on_user_message_callback( Logs the user message details including: 1. User content (text) - The content is formatted as 'User Content: {content}'. - If the content length exceeds `max_content_length`, it is truncated. + The content is formatted as a structured JSON object containing the user text. + If individual string fields exceed `max_content_length`, they are truncated + to preserve the valid JSON structure. """ - content, truncated = self._format_content_safely(user_message) - await self._log({ - "event_type": "USER_MESSAGE_RECEIVED", - "agent": invocation_context.agent.name, - "session_id": invocation_context.session.id, - "invocation_id": invocation_context.invocation_id, - "user_id": invocation_context.session.user_id, - "content": f"User Content: {content}", - "is_truncated": truncated, - }) + # Extract text parts + text_content = "" + if user_message and user_message.parts: + text_content = " ".join([p.text for p in user_message.parts if p.text]) + + payload = {"text": text_content if text_content else None} + + await self._log( + { + "event_type": "USER_MESSAGE_RECEIVED", + "agent": invocation_context.agent.name, + "session_id": invocation_context.session.id, + "invocation_id": invocation_context.invocation_id, + "user_id": invocation_context.session.user_id, + }, + content_payload=payload, + ) async def before_run_callback( self, *, invocation_context: InvocationContext @@ -665,7 +622,7 @@ async def before_run_callback( "session_id": invocation_context.session.id, "invocation_id": invocation_context.invocation_id, "user_id": invocation_context.session.user_id, - }) + }) # No content payload needed async def on_event_callback( self, *, invocation_context: InvocationContext, event: Event @@ -676,22 +633,65 @@ async def on_event_callback( 1. Event type (determined from event properties) 2. Event content (text, function calls, or responses) 3. Error messages (if any) - - The content is formatted based on the event type. - If the content length exceeds `max_content_length`, it is truncated. """ - content, truncated = self._format_content_safely(event.content) - await self._log({ - "event_type": _get_event_type(event), - "agent": event.author, - "session_id": invocation_context.session.id, - "invocation_id": invocation_context.invocation_id, - "user_id": invocation_context.session.user_id, - "content": content, - "error_message": event.error_message, - "timestamp": datetime.fromtimestamp(event.timestamp, timezone.utc), - "is_truncated": truncated, - }) + # Rename 'text_parts' to 'content_parts' since it holds dicts now + content_parts = [] + + # tool_calls and tool_responses might still be useful as separate summaries, + # or you can rely entirely on content_parts. keeping them for now: + tool_calls = [] + tool_responses = [] + + if event.content and event.content.parts: + for p in event.content.parts: + if p.text: + content_parts.append({"type": "text", "text": p.text}) + elif p.function_call: + content_parts.append({ + "type": "function_call", + "name": p.function_call.name, + "args": dict(p.function_call.args), + }) + # Optional: keep filling this if you want the high-level summary list + tool_calls.append(p.function_call.name) + elif p.function_response: + content_parts.append( + {"type": "function_response", "name": p.function_response.name} + ) + # Optional: keep filling this if you want the high-level summary list + tool_responses.append(p.function_response.name) + elif p.inline_data: + content_parts.append({ + "type": "inline_data", + "mime_type": p.inline_data.mime_type, + }) + elif p.file_data: + content_parts.append({ + "type": "file_data", + "mime_type": p.file_data.mime_type, + "file_uri": p.file_data.file_uri, + }) + + payload = { + # CHANGED: Do not join. Store the list of dicts. + "content_parts": content_parts if content_parts else None, + "tool_calls": tool_calls if tool_calls else None, + "tool_responses": tool_responses if tool_responses else None, + "raw_role": event.author if event.author else None, + } + + await self._log( + { + "event_type": _get_event_type(event), + "agent": event.author, + "session_id": invocation_context.session.id, + "invocation_id": invocation_context.invocation_id, + "user_id": invocation_context.session.user_id, + "error_message": event.error_message, + "timestamp": datetime.fromtimestamp(event.timestamp, timezone.utc), + }, + content_payload=payload, + ) async def after_run_callback( self, *, invocation_context: InvocationContext @@ -719,14 +719,16 @@ async def before_agent_callback( Content includes: 1. Agent Name (from callback context) """ - await self._log({ - "event_type": "AGENT_STARTING", - "agent": agent.name, - "session_id": callback_context.session.id, - "invocation_id": callback_context.invocation_id, - "user_id": callback_context.session.user_id, - "content": f"Agent Name: {callback_context.agent_name}", - }) + await self._log( + { + "event_type": "AGENT_STARTING", + "agent": agent.name, + "session_id": callback_context.session.id, + "invocation_id": callback_context.invocation_id, + "user_id": callback_context.session.user_id, + }, + content_payload={"target_agent": callback_context.agent_name}, + ) async def after_agent_callback( self, *, agent: BaseAgent, callback_context: CallbackContext @@ -737,14 +739,16 @@ async def after_agent_callback( Content includes: 1. Agent Name (from callback context) """ - await self._log({ - "event_type": "AGENT_COMPLETED", - "agent": agent.name, - "session_id": callback_context.session.id, - "invocation_id": callback_context.invocation_id, - "user_id": callback_context.session.user_id, - "content": f"Agent Name: {callback_context.agent_name}", - }) + await self._log( + { + "event_type": "AGENT_COMPLETED", + "agent": agent.name, + "session_id": callback_context.session.id, + "invocation_id": callback_context.invocation_id, + "user_id": callback_context.session.user_id, + }, + content_payload={"target_agent": callback_context.agent_name}, + ) async def before_model_callback( self, *, callback_context: CallbackContext, llm_request: LlmRequest @@ -758,63 +762,34 @@ async def before_model_callback( 4. Prompt content (user/model messages) 5. System instructions - The content is formatted as a single string with fields separated by ' | '. - If the total length exceeds `max_content_length`, the string is truncated, - prioritizing the metadata (Model, Params, Tools) over the Prompt and System - Prompt. + The content is formatted as a structured JSON object. + If individual string fields exceed `max_content_length`, they are truncated + to preserve the valid JSON structure. """ - content_parts = [ - f"Model: {llm_request.model or 'default'}", - ] - is_truncated = False - # 1. Params + # 1. Config Params + params = {} if llm_request.config: - config = llm_request.config - params_to_log = {} - if hasattr(config, "temperature") and config.temperature is not None: - params_to_log["temperature"] = config.temperature - if hasattr(config, "top_p") and config.top_p is not None: - params_to_log["top_p"] = config.top_p - if hasattr(config, "top_k") and config.top_k is not None: - params_to_log["top_k"] = config.top_k - if ( - hasattr(config, "max_output_tokens") - and config.max_output_tokens is not None - ): - params_to_log["max_output_tokens"] = config.max_output_tokens - - if params_to_log: - params_str = ", ".join([f"{k}={v}" for k, v in params_to_log.items()]) - content_parts.append(f"Params: {{{params_str}}}") - - # 2. Tools - if llm_request.tools_dict: - content_parts.append( - f"Available Tools: {list(llm_request.tools_dict.keys())}" - ) - - # 3. Prompt - if contents := getattr(llm_request, "contents", None): - prompt_parts = [] - for c in contents: - c_str, c_trunc = self._format_content_safely(c) - prompt_parts.append(f"{c.role}: {c_str}") - if c_trunc: - is_truncated = True - prompt_str = " | ".join(prompt_parts) - content_parts.append(f"Prompt: {prompt_str}") - - # 4. System Prompt - system_instruction_text = "None" - if llm_request.config and llm_request.config.system_instruction: + cfg = llm_request.config + if getattr(cfg, "temperature", None) is not None: + params["temperature"] = cfg.temperature + if getattr(cfg, "top_p", None) is not None: + params["top_p"] = cfg.top_p + if getattr(cfg, "top_k", None) is not None: + params["top_k"] = cfg.top_k + if getattr(cfg, "max_output_tokens", None) is not None: + params["max_output_tokens"] = cfg.max_output_tokens + + # 2. System Instruction + system_instr = None + if llm_request.config and llm_request.config.system_instruction is not None: si = llm_request.config.system_instruction if isinstance(si, str): - system_instruction_text = si + system_instr = si elif isinstance(si, types.Content): - system_instruction_text = "".join(p.text for p in si.parts if p.text) + system_instr = "".join(p.text for p in si.parts if p.text) elif isinstance(si, types.Part): - system_instruction_text = si.text + system_instr = si.text elif hasattr(si, "__iter__"): texts = [] for item in si: @@ -822,28 +797,53 @@ async def before_model_callback( texts.append(item) elif isinstance(item, types.Part) and item.text: texts.append(item.text) - system_instruction_text = "".join(texts) + system_instr = "".join(texts) else: - system_instruction_text = str(si) - elif llm_request.config and not llm_request.config.system_instruction: - system_instruction_text = "Empty" - - content_parts.append(f"System Prompt: {system_instruction_text}") + system_instr = str(si) + + # 3. Prompt History (Simplified structure for JSON) + prompt_history = [] + if getattr(llm_request, "contents", None): + for c in llm_request.contents: + role = c.role + parts_list = [] + for p in c.parts: + if p.text: + parts_list.append({"type": "text", "text": p.text}) + elif p.function_call: + parts_list.append({ + "type": "function_call", + "name": p.function_call.name, + "args": dict(p.function_call.args), + }) + elif p.function_response: + parts_list.append( + {"type": "function_response", "name": p.function_response.name} + ) + prompt_history.append({"role": role, "parts": parts_list}) + + payload = { + "model": llm_request.model or "default", + "params": params if params else None, + "tools_available": ( + list(llm_request.tools_dict.keys()) + if llm_request.tools_dict + else None + ), + "system_instruction": system_instr, + "prompt": prompt_history if prompt_history else None, + } - final_content = " | ".join(content_parts) - max_len = self._config.max_content_length - if len(final_content) > max_len: - final_content = final_content[:max_len] + "..." - is_truncated = True - await self._log({ - "event_type": "LLM_REQUEST", - "agent": callback_context.agent_name, - "session_id": callback_context.session.id, - "invocation_id": callback_context.invocation_id, - "user_id": callback_context.session.user_id, - "content": final_content, - "is_truncated": is_truncated, - }) + await self._log( + { + "event_type": "LLM_REQUEST", + "agent": callback_context.agent_name, + "session_id": callback_context.session.id, + "invocation_id": callback_context.invocation_id, + "user_id": callback_context.session.user_id, + }, + content_payload=payload, + ) async def after_model_callback( self, *, callback_context: CallbackContext, llm_response: LlmResponse @@ -855,60 +855,53 @@ async def after_model_callback( 2. Text response (if no tool calls) 3. Token usage statistics (prompt, candidates, total) - The content is formatted as a single string with fields separated by ' | '. - If the content length exceeds `max_content_length`, it is truncated. + The content is formatted as a structured JSON object containing response parts + and usage statistics. + If individual string fields exceed `max_content_length`, they are truncated + to preserve the valid JSON structure. """ content_parts = [] - content = llm_response.content - is_tool_call = False - is_truncated = False - if content and content.parts: - is_tool_call = any(part.function_call for part in content.parts) - - if is_tool_call: - fc_names = [] - if content and content.parts: - fc_names = [ - part.function_call.name - for part in content.parts - if part.function_call - ] - content_parts.append(f"Tool Name: {', '.join(fc_names)}") - else: - text_content, truncated = self._format_content_safely( - llm_response.content - ) - content_parts.append(f"Tool Name: text_response, {text_content}") - if truncated: - is_truncated = True - + if llm_response.content and llm_response.content.parts: + for p in llm_response.content.parts: + if p.text: + content_parts.append({"type": "text", "text": p.text}) + elif p.function_call: + content_parts.append({ + "type": "function_call", + "name": p.function_call.name, + "args": dict(p.function_call.args), + }) + + usage = {} if llm_response.usage_metadata: - prompt_tokens = getattr( - llm_response.usage_metadata, "prompt_token_count", "N/A" - ) - candidates_tokens = getattr( - llm_response.usage_metadata, "candidates_token_count", "N/A" - ) - total_tokens = getattr( - llm_response.usage_metadata, "total_token_count", "N/A" - ) - token_usage_str = ( - f"Token Usage: {{prompt: {prompt_tokens}, candidates:" - f" {candidates_tokens}, total: {total_tokens}}}" - ) - content_parts.append(token_usage_str) + usage = { + "prompt_tokens": getattr( + llm_response.usage_metadata, "prompt_token_count", 0 + ), + "candidates_tokens": getattr( + llm_response.usage_metadata, "candidates_token_count", 0 + ), + "total_tokens": getattr( + llm_response.usage_metadata, "total_token_count", 0 + ), + } + + payload = { + "response_content": content_parts if content_parts else None, + "usage": usage if usage else None, + } - final_content = " | ".join(content_parts) - await self._log({ - "event_type": "LLM_RESPONSE", - "agent": callback_context.agent_name, - "session_id": callback_context.session.id, - "invocation_id": callback_context.invocation_id, - "user_id": callback_context.session.user_id, - "content": final_content, - "error_message": llm_response.error_message, - "is_truncated": is_truncated, - }) + await self._log( + { + "event_type": "LLM_RESPONSE", + "agent": callback_context.agent_name, + "session_id": callback_context.session.id, + "invocation_id": callback_context.invocation_id, + "user_id": callback_context.session.user_id, + "error_message": llm_response.error_message, + }, + content_payload=payload, + ) async def before_tool_callback( self, @@ -924,29 +917,26 @@ async def before_tool_callback( 2. Tool description 3. Tool arguments - The content is formatted as 'Tool Name: ..., Description: ..., Arguments: - ...'. - If the content length exceeds `max_content_length`, it is truncated. + The content is formatted as a structured JSON object containing tool name, + description, and arguments. + If individual string fields exceed `max_content_length`, they are truncated + to preserve the valid JSON structure. """ - args_str, truncated = _format_args( - tool_args, max_len=self._config.max_content_length - ) - content = ( - f"Tool Name: {tool.name}, Description: {tool.description}," - f" Arguments: {args_str}" + payload = { + "tool_name": tool.name if tool.name else None, + "description": tool.description if tool.description else None, + "arguments": tool_args if tool_args else None, + } + await self._log( + { + "event_type": "TOOL_STARTING", + "agent": tool_context.agent_name, + "session_id": tool_context.session.id, + "invocation_id": tool_context.invocation_id, + "user_id": tool_context.session.user_id, + }, + content_payload=payload, ) - if len(content) > self._config.max_content_length: - content = content[: self._config.max_content_length] + "..." - truncated = True - await self._log({ - "event_type": "TOOL_STARTING", - "agent": tool_context.agent_name, - "session_id": tool_context.session.id, - "invocation_id": tool_context.invocation_id, - "user_id": tool_context.session.user_id, - "content": content, - "is_truncated": truncated, - }) async def after_tool_callback( self, @@ -962,25 +952,24 @@ async def after_tool_callback( 1. Tool name 2. Tool result - The content is formatted as 'Tool Name: ..., Result: ...'. - If the content length exceeds `max_content_length`, it is truncated. + The content is formatted as a structured JSON object containing tool name and result. + If individual string fields exceed `max_content_length`, they are truncated + to preserve the valid JSON structure. """ - result_str, truncated = _format_args( - result, max_len=self._config.max_content_length + payload = { + "tool_name": tool.name if tool.name else None, + "result": result if result else None, + } + await self._log( + { + "event_type": "TOOL_COMPLETED", + "agent": tool_context.agent_name, + "session_id": tool_context.session.id, + "invocation_id": tool_context.invocation_id, + "user_id": tool_context.session.user_id, + }, + content_payload=payload, ) - content = f"Tool Name: {tool.name}, Result: {result_str}" - if len(content) > self._config.max_content_length: - content = content[: self._config.max_content_length] + "..." - truncated = True - await self._log({ - "event_type": "TOOL_COMPLETED", - "agent": tool_context.agent_name, - "session_id": tool_context.session.id, - "invocation_id": tool_context.invocation_id, - "user_id": tool_context.session.user_id, - "content": content, - "is_truncated": truncated, - }) async def on_model_error_callback( self, @@ -1019,23 +1008,23 @@ async def on_tool_error_callback( 1. Tool name 2. Tool arguments + The content is formatted as a structured JSON object containing tool name and arguments. The error message is captured in the `error_message` field. - If the content length exceeds `max_content_length`, it is truncated. + If individual string fields exceed `max_content_length`, they are truncated + to preserve the valid JSON structure. """ - args_str, truncated = _format_args( - tool_args, max_len=self._config.max_content_length + payload = { + "tool_name": tool.name if tool.name else None, + "arguments": tool_args if tool_args else None, + } + await self._log( + { + "event_type": "TOOL_ERROR", + "agent": tool_context.agent_name, + "session_id": tool_context.session.id, + "invocation_id": tool_context.invocation_id, + "user_id": tool_context.session.user_id, + "error_message": str(error), + }, + content_payload=payload, ) - content = f"Tool Name: {tool.name}, Arguments: {args_str}" - if len(content) > self._config.max_content_length: - content = content[: self._config.max_content_length] + "..." - truncated = True - await self._log({ - "event_type": "TOOL_ERROR", - "agent": tool_context.agent_name, - "session_id": tool_context.session.id, - "invocation_id": tool_context.invocation_id, - "user_id": tool_context.session.user_id, - "content": content, - "error_message": str(error), - "is_truncated": truncated, - }) diff --git a/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py b/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py index 6f0412dbbd..6790faf680 100644 --- a/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py +++ b/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py @@ -4,7 +4,7 @@ # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -34,7 +34,6 @@ from google.auth import exceptions as auth_exceptions import google.auth.credentials from google.cloud import bigquery -from google.cloud.bigquery_storage_v1 import types as bq_storage_types from google.genai import types import pyarrow as pa import pytest @@ -124,8 +123,10 @@ def mock_bq_client(): @pytest.fixture def mock_write_client(): - with mock.patch.object( - bigquery_agent_analytics_plugin, "BigQueryWriteAsyncClient", autospec=True + # Updated patch path to match the new import structure in src + with mock.patch( + "google.cloud.bigquery_storage_v1.services.big_query_write.async_client.BigQueryWriteAsyncClient", + autospec=True, ) as mock_cls: mock_client = mock_cls.return_value mock_client.transport = mock.AsyncMock() @@ -136,7 +137,6 @@ async def fake_append_rows(requests, **kwargs): mock_append_rows_response.row_errors = [] mock_append_rows_response.error = mock.MagicMock() mock_append_rows_response.error.code = 0 # OK status - # This a gen is what's returned *after* the await. return _async_gen(mock_append_rows_response) mock_client.append_rows.side_effect = fake_append_rows @@ -145,6 +145,7 @@ async def fake_append_rows(requests, **kwargs): @pytest.fixture def dummy_arrow_schema(): + # content is pa.string() because JSON is serialized to string before Arrow return pa.schema([ pa.field("timestamp", pa.timestamp("us", tz="UTC"), nullable=False), pa.field("event_type", pa.string(), nullable=True), @@ -154,7 +155,6 @@ def dummy_arrow_schema(): pa.field("user_id", pa.string(), nullable=True), pa.field("content", pa.string(), nullable=True), pa.field("error_message", pa.string(), nullable=True), - pa.field("is_truncated", pa.bool_(), nullable=True), ]) @@ -234,7 +234,6 @@ def _assert_common_fields(log_entry, event_type, agent="MyTestAgent"): assert log_entry["user_id"] == "user-456" assert "timestamp" in log_entry assert isinstance(log_entry["timestamp"], datetime.datetime) - assert "is_truncated" in log_entry # --- Test Class --- @@ -257,12 +256,14 @@ async def test_plugin_disabled( table_id=TABLE_ID, config=config, ) - # user_message = types.Content(parts=[types.Part(text="Test")]) await plugin.on_user_message_callback( invocation_context=invocation_context, user_message=types.Content(parts=[types.Part(text="Test")]), ) + # Wait for background tasks + await plugin.close() + mock_auth_default.assert_not_called() mock_bq_client.assert_not_called() mock_write_client.append_rows.assert_not_called() @@ -293,15 +294,27 @@ async def test_event_allowlist( await plugin.before_model_callback( callback_context=callback_context, llm_request=llm_request ) - await asyncio.sleep(0.01) # Allow background task to run + await plugin.close() # Wait for write mock_write_client.append_rows.assert_called_once() mock_write_client.append_rows.reset_mock() + # REFACTOR: Use a fresh plugin instance for the denied case + plugin_denied = ( + bigquery_agent_analytics_plugin.BigQueryAgentAnalyticsPlugin( + PROJECT_ID, DATASET_ID, TABLE_ID, config + ) + ) + await plugin_denied._ensure_init() + # Inject the same mock_write_client + plugin_denied._write_client = mock_write_client + plugin_denied._arrow_schema = plugin._arrow_schema + user_message = types.Content(parts=[types.Part(text="What is up?")]) - await plugin.on_user_message_callback( + await plugin_denied.on_user_message_callback( invocation_context=invocation_context, user_message=user_message ) - await asyncio.sleep(0.01) # Allow background task to run + # Since it's denied, no task is created. close() would wait if there was one. + await plugin_denied.close() mock_write_client.append_rows.assert_not_called() @pytest.mark.asyncio @@ -326,45 +339,80 @@ async def test_event_denylist( await plugin.on_user_message_callback( invocation_context=invocation_context, user_message=user_message ) - await asyncio.sleep(0.01) + await plugin.close() mock_write_client.append_rows.assert_not_called() - await plugin.before_run_callback(invocation_context=invocation_context) - await asyncio.sleep(0.01) + # REFACTOR: Use a fresh plugin instance for the allowed case + plugin_allowed = ( + bigquery_agent_analytics_plugin.BigQueryAgentAnalyticsPlugin( + PROJECT_ID, DATASET_ID, TABLE_ID, config + ) + ) + await plugin_allowed._ensure_init() + # Inject the same mock_write_client + plugin_allowed._write_client = mock_write_client + plugin_allowed._arrow_schema = plugin._arrow_schema + + await plugin_allowed.before_run_callback( + invocation_context=invocation_context + ) + await plugin_allowed.close() mock_write_client.append_rows.assert_called_once() @pytest.mark.asyncio - async def test_content_formatter( + async def test_content_formatter_payload_mutation( self, mock_write_client, - invocation_context, + callback_context, mock_auth_default, mock_bq_client, mock_to_arrow_schema, dummy_arrow_schema, mock_asyncio_to_thread, ): - def redact_content(content): - return "[REDACTED]" - - config = BigQueryLoggerConfig(content_formatter=redact_content) + """Tests a formatter that modifies the JSON structure (Pruning & Normalization).""" + + def mutate_payload(data): + if isinstance(data, dict): + # 1. Pruning: Remove system_instruction + if "system_instruction" in data: + del data["system_instruction"] + # 2. Normalization: Uppercase model name + if "model" in data and isinstance(data["model"], str): + data["model"] = data["model"].upper() + return data + + config = BigQueryLoggerConfig(content_formatter=mutate_payload) plugin = bigquery_agent_analytics_plugin.BigQueryAgentAnalyticsPlugin( PROJECT_ID, DATASET_ID, TABLE_ID, config ) await plugin._ensure_init() mock_write_client.append_rows.reset_mock() - user_message = types.Content(parts=[types.Part(text="Secret message")]) - await plugin.on_user_message_callback( - invocation_context=invocation_context, user_message=user_message + llm_request = llm_request_lib.LlmRequest( + model="gemini-pro", + config=types.GenerateContentConfig( + system_instruction=types.Content(parts=[types.Part(text="Sys")]) + ), + contents=[types.Content(role="user", parts=[types.Part(text="User")])], ) - await asyncio.sleep(0.01) - mock_write_client.append_rows.assert_called_once() + + await plugin.before_model_callback( + callback_context=callback_context, llm_request=llm_request + ) + await plugin.close() log_entry = _get_captured_event_dict(mock_write_client, dummy_arrow_schema) - assert log_entry["content"] == "User Content: [REDACTED]" + + # Parse JSON + content = json.loads(log_entry["content"]) + + # Verify mutation + assert "system_instruction" not in content + assert content["model"] == "GEMINI-PRO" + assert content["prompt"][0]["role"] == "user" @pytest.mark.asyncio - async def test_content_formatter_error( + async def test_content_formatter_error_fallback( self, mock_write_client, invocation_context, @@ -374,7 +422,9 @@ async def test_content_formatter_error( dummy_arrow_schema, mock_asyncio_to_thread, ): - def error_formatter(content): + """Tests that if content_formatter fails, the original payload is used.""" + + def error_formatter(data): raise ValueError("Formatter failed") config = BigQueryLoggerConfig(content_formatter=error_formatter) @@ -384,17 +434,23 @@ def error_formatter(content): await plugin._ensure_init() mock_write_client.append_rows.reset_mock() - user_message = types.Content(parts=[types.Part(text="Secret message")]) + user_message = types.Content(parts=[types.Part(text="Original message")]) + + # This triggers the log. Internal logic catches exception and proceeds. await plugin.on_user_message_callback( invocation_context=invocation_context, user_message=user_message ) - await asyncio.sleep(0.01) + await plugin.close() + mock_write_client.append_rows.assert_called_once() log_entry = _get_captured_event_dict(mock_write_client, dummy_arrow_schema) - assert log_entry["content"] == "User Content: [FORMATTING FAILED]" + + # Verify that despite the error, we still got the original data + content = json.loads(log_entry["content"]) + assert content["text"] == "Original message" @pytest.mark.asyncio - async def test_max_content_length( + async def test_max_content_length_smart_truncation( self, mock_write_client, invocation_context, @@ -405,7 +461,8 @@ async def test_max_content_length( dummy_arrow_schema, mock_asyncio_to_thread, ): - config = BigQueryLoggerConfig(max_content_length=40) + # Config limit to 10 chars + config = BigQueryLoggerConfig(max_content_length=10) plugin = bigquery_agent_analytics_plugin.BigQueryAgentAnalyticsPlugin( PROJECT_ID, DATASET_ID, TABLE_ID, config ) @@ -413,45 +470,21 @@ async def test_max_content_length( mock_write_client.append_rows.reset_mock() # Test User Message Truncation - user_message = types.Content( - parts=[types.Part(text="12345678901234567890123456789012345678901")] - ) # 41 chars + long_text = "123456789012345" # 15 chars + user_message = types.Content(parts=[types.Part(text=long_text)]) + await plugin.on_user_message_callback( invocation_context=invocation_context, user_message=user_message ) - await asyncio.sleep(0.01) - mock_write_client.append_rows.assert_called_once() - log_entry = _get_captured_event_dict(mock_write_client, dummy_arrow_schema) - assert ( - log_entry["content"] - == "User Content: text: '1234567890123456789012345678901234567890...' " - ) - assert log_entry["is_truncated"] - mock_write_client.append_rows.reset_mock() + await plugin.close() - # Test before_model_callback full content truncation - llm_request = llm_request_lib.LlmRequest( - model="gemini-pro", - config=types.GenerateContentConfig( - system_instruction=types.Content( - parts=[types.Part(text="System Instruction")] - ) - ), - contents=[ - types.Content(role="user", parts=[types.Part(text="Prompt")]) - ], - ) - await plugin.before_model_callback( - callback_context=callback_context, llm_request=llm_request - ) - await asyncio.sleep(0.01) - mock_write_client.append_rows.assert_called_once() log_entry = _get_captured_event_dict(mock_write_client, dummy_arrow_schema) - # Full content: "Model: gemini-pro | Prompt: user: text: 'Prompt' | System Prompt: System Instruction" - # Truncated to 40 chars + ...: - expected_content = "Model: gemini-pro | Prompt: user: text: ..." - assert log_entry["content"] == expected_content - assert log_entry["is_truncated"] + content = json.loads(log_entry["content"]) + + # Verify "1234567890...[TRUNCATED]" + assert content["text"] == "1234567890...[TRUNCATED]" + # Verify it is still valid JSON + assert isinstance(content, dict) @pytest.mark.asyncio async def test_max_content_length_tool_args( @@ -464,7 +497,8 @@ async def test_max_content_length_tool_args( dummy_arrow_schema, mock_asyncio_to_thread, ): - config = BigQueryLoggerConfig(max_content_length=80) + # Limit 10 chars + config = BigQueryLoggerConfig(max_content_length=10) plugin = bigquery_agent_analytics_plugin.BigQueryAgentAnalyticsPlugin( PROJECT_ID, DATASET_ID, TABLE_ID, config ) @@ -475,24 +509,21 @@ async def test_max_content_length_tool_args( base_tool_lib.BaseTool, instance=True, spec_set=True ) type(mock_tool).name = mock.PropertyMock(return_value="MyTool") - type(mock_tool).description = mock.PropertyMock(return_value="Description") + type(mock_tool).description = mock.PropertyMock(return_value="Desc") - # Args length > 80 - # {"param": "A" * 50} is ~60 chars. - # Prefix is ~57 chars. Total ~117 chars. + # Args contain a long string + long_val = "A" * 20 await plugin.before_tool_callback( tool=mock_tool, - tool_args={"param": "A" * 50}, + tool_args={"param": long_val}, tool_context=tool_context, ) - await asyncio.sleep(0.01) - mock_write_client.append_rows.assert_called_once() + await plugin.close() log_entry = _get_captured_event_dict(mock_write_client, dummy_arrow_schema) + content = json.loads(log_entry["content"]) - assert 'Arguments: {"param": "AAAAA' in log_entry["content"] - assert log_entry["content"].endswith("...") - assert len(log_entry["content"]) == 83 # 80 + 3 dots - assert log_entry["is_truncated"] + # Verify truncation happened inside the JSON structure + assert content["arguments"]["param"] == "AAAAAAAAAA...[TRUNCATED]" @pytest.mark.asyncio async def test_max_content_length_tool_result( @@ -505,7 +536,7 @@ async def test_max_content_length_tool_result( dummy_arrow_schema, mock_asyncio_to_thread, ): - config = BigQueryLoggerConfig(max_content_length=80) + config = BigQueryLoggerConfig(max_content_length=10) plugin = bigquery_agent_analytics_plugin.BigQueryAgentAnalyticsPlugin( PROJECT_ID, DATASET_ID, TABLE_ID, config ) @@ -517,23 +548,18 @@ async def test_max_content_length_tool_result( ) type(mock_tool).name = mock.PropertyMock(return_value="MyTool") - # Result length > 80 - # {"res": "A" * 60} is ~70 chars. - # Prefix is ~27 chars. Total ~97 chars. + long_res = "A" * 20 await plugin.after_tool_callback( tool=mock_tool, tool_args={}, tool_context=tool_context, - result={"res": "A" * 60}, + result={"res": long_res}, ) - await asyncio.sleep(0.01) - mock_write_client.append_rows.assert_called_once() + await plugin.close() log_entry = _get_captured_event_dict(mock_write_client, dummy_arrow_schema) + content = json.loads(log_entry["content"]) - assert 'Result: {"res": "AAAAA' in log_entry["content"] - assert log_entry["content"].endswith("...") - assert len(log_entry["content"]) == 83 # 80 + 3 dots - assert log_entry["is_truncated"] + assert content["result"]["res"] == "AAAAAAAAAA...[TRUNCATED]" @pytest.mark.asyncio async def test_max_content_length_tool_error( @@ -546,7 +572,7 @@ async def test_max_content_length_tool_error( dummy_arrow_schema, mock_asyncio_to_thread, ): - config = BigQueryLoggerConfig(max_content_length=80) + config = BigQueryLoggerConfig(max_content_length=10) plugin = bigquery_agent_analytics_plugin.BigQueryAgentAnalyticsPlugin( PROJECT_ID, DATASET_ID, TABLE_ID, config ) @@ -558,23 +584,18 @@ async def test_max_content_length_tool_error( ) type(mock_tool).name = mock.PropertyMock(return_value="MyTool") - # Args length > 80 - # {"arg": "A" * 60} is ~70 chars. - # Prefix is ~28 chars. Total ~98 chars. + long_arg = "A" * 20 await plugin.on_tool_error_callback( tool=mock_tool, - tool_args={"arg": "A" * 60}, + tool_args={"arg": long_arg}, tool_context=tool_context, error=ValueError("Oops"), ) - await asyncio.sleep(0.01) - mock_write_client.append_rows.assert_called_once() + await plugin.close() log_entry = _get_captured_event_dict(mock_write_client, dummy_arrow_schema) + content = json.loads(log_entry["content"]) - assert 'Arguments: {"arg": "AAAAA' in log_entry["content"] - assert log_entry["content"].endswith("...") - assert len(log_entry["content"]) == 83 # 80 + 3 dots - assert log_entry["is_truncated"] + assert content["arguments"]["arg"] == "AAAAAAAAAA...[TRUNCATED]" @pytest.mark.asyncio async def test_on_user_message_callback_logs_correctly( @@ -588,11 +609,13 @@ async def test_on_user_message_callback_logs_correctly( await bq_plugin_inst.on_user_message_callback( invocation_context=invocation_context, user_message=user_message ) - await asyncio.sleep(0.01) + await bq_plugin_inst.close() log_entry = _get_captured_event_dict(mock_write_client, dummy_arrow_schema) _assert_common_fields(log_entry, "USER_MESSAGE_RECEIVED") - assert log_entry["content"] == "User Content: text: 'What is up?'" - assert not log_entry["is_truncated"] + + # UPDATED ASSERTION: Check JSON structure + content = json.loads(log_entry["content"]) + assert content["text"] == "What is up?" @pytest.mark.asyncio async def test_on_event_callback_tool_call( @@ -613,10 +636,14 @@ async def test_on_event_callback_tool_call( await bq_plugin_inst.on_event_callback( invocation_context=invocation_context, event=event ) - await asyncio.sleep(0.01) + await bq_plugin_inst.close() log_entry = _get_captured_event_dict(mock_write_client, dummy_arrow_schema) _assert_common_fields(log_entry, "TOOL_CALL", agent="MyTestAgent") - assert "call: get_weather" in log_entry["content"] + + # Verify Generic Event JSON structure + content = json.loads(log_entry["content"]) + assert content["raw_role"] == "MyTestAgent" + assert content["tool_calls"] == ["get_weather"] assert log_entry["timestamp"] == datetime.datetime( 2025, 10, 22, 10, 0, 0, tzinfo=datetime.timezone.utc ) @@ -639,14 +666,208 @@ async def test_on_event_callback_model_response( await bq_plugin_inst.on_event_callback( invocation_context=invocation_context, event=event ) - await asyncio.sleep(0.01) + await bq_plugin_inst.close() log_entry = _get_captured_event_dict(mock_write_client, dummy_arrow_schema) _assert_common_fields(log_entry, "MODEL_RESPONSE", agent="MyTestAgent") - assert "text: 'Hello there!'" in log_entry["content"] + + content = json.loads(log_entry["content"]) + assert content["content_parts"][0]["type"] == "text" + assert content["content_parts"][0]["text"] == "Hello there!" + assert log_entry["timestamp"] == datetime.datetime( 2025, 10, 22, 11, 0, 0, tzinfo=datetime.timezone.utc ) + @pytest.mark.asyncio + async def test_before_model_callback_logs_structure( + self, + bq_plugin_inst, + mock_write_client, + callback_context, + dummy_arrow_schema, + ): + """Covers combined logic of params and tools in one structured test.""" + llm_request = llm_request_lib.LlmRequest( + model="gemini-pro", + config=types.GenerateContentConfig( + temperature=0.5, + top_p=0.9, + system_instruction=types.Content(parts=[types.Part(text="Sys")]), + ), + contents=[types.Content(role="user", parts=[types.Part(text="User")])], + ) + # Manually set tools_dict + llm_request.tools_dict = {"tool1": "func1"} + + await bq_plugin_inst.before_model_callback( + callback_context=callback_context, llm_request=llm_request + ) + await bq_plugin_inst.close() + log_entry = _get_captured_event_dict(mock_write_client, dummy_arrow_schema) + _assert_common_fields(log_entry, "LLM_REQUEST") + + # Verify structured JSON + content = json.loads(log_entry["content"]) + assert content["model"] == "gemini-pro" + assert content["params"]["temperature"] == 0.5 + assert content["params"]["top_p"] == 0.9 + assert "tool1" in content["tools_available"] + assert content["system_instruction"] == "Sys" + assert content["prompt"][0]["role"] == "user" + assert content["prompt"][0]["parts"][0]["type"] == "text" + assert content["prompt"][0]["parts"][0]["text"] == "User" + + @pytest.mark.asyncio + async def test_after_model_callback_text_response( + self, + bq_plugin_inst, + mock_write_client, + callback_context, + dummy_arrow_schema, + ): + llm_response = llm_response_lib.LlmResponse( + content=types.Content(parts=[types.Part(text="Model response")]), + usage_metadata=types.UsageMetadata( + prompt_token_count=10, total_token_count=15 + ), + ) + await bq_plugin_inst.after_model_callback( + callback_context=callback_context, llm_response=llm_response + ) + await bq_plugin_inst.close() + log_entry = _get_captured_event_dict(mock_write_client, dummy_arrow_schema) + _assert_common_fields(log_entry, "LLM_RESPONSE") + + # UPDATED ASSERTION: Check structured JSON + content = json.loads(log_entry["content"]) + assert content["response_content"][0]["type"] == "text" + assert content["response_content"][0]["text"] == "Model response" + assert content["usage"]["prompt_tokens"] == 10 + assert content["usage"]["total_tokens"] == 15 + + @pytest.mark.asyncio + async def test_after_model_callback_tool_call( + self, + bq_plugin_inst, + mock_write_client, + callback_context, + dummy_arrow_schema, + ): + tool_fc = types.FunctionCall(name="get_weather", args={"location": "Paris"}) + llm_response = llm_response_lib.LlmResponse( + content=types.Content(parts=[types.Part(function_call=tool_fc)]), + usage_metadata=types.UsageMetadata( + prompt_token_count=10, total_token_count=15 + ), + ) + await bq_plugin_inst.after_model_callback( + callback_context=callback_context, llm_response=llm_response + ) + await bq_plugin_inst.close() + log_entry = _get_captured_event_dict(mock_write_client, dummy_arrow_schema) + _assert_common_fields(log_entry, "LLM_RESPONSE") + + content = json.loads(log_entry["content"]) + # Verify Tool Call structure + assert content["response_content"][0]["type"] == "function_call" + assert content["response_content"][0]["name"] == "get_weather" + assert content["response_content"][0]["args"]["location"] == "Paris" + + @pytest.mark.asyncio + async def test_before_tool_callback_logs_correctly( + self, bq_plugin_inst, mock_write_client, tool_context, dummy_arrow_schema + ): + mock_tool = mock.create_autospec( + base_tool_lib.BaseTool, instance=True, spec_set=True + ) + type(mock_tool).name = mock.PropertyMock(return_value="MyTool") + type(mock_tool).description = mock.PropertyMock(return_value="Description") + await bq_plugin_inst.before_tool_callback( + tool=mock_tool, + tool_args={"param": "value"}, + tool_context=tool_context, + ) + await bq_plugin_inst.close() + log_entry = _get_captured_event_dict(mock_write_client, dummy_arrow_schema) + _assert_common_fields(log_entry, "TOOL_STARTING") + + # UPDATED ASSERTION: Check structured JSON + content = json.loads(log_entry["content"]) + assert content["tool_name"] == "MyTool" + assert content["description"] == "Description" + assert content["arguments"]["param"] == "value" + + @pytest.mark.asyncio + async def test_after_tool_callback_logs_correctly( + self, bq_plugin_inst, mock_write_client, tool_context, dummy_arrow_schema + ): + mock_tool = mock.create_autospec( + base_tool_lib.BaseTool, instance=True, spec_set=True + ) + type(mock_tool).name = mock.PropertyMock(return_value="MyTool") + type(mock_tool).description = mock.PropertyMock(return_value="Description") + await bq_plugin_inst.after_tool_callback( + tool=mock_tool, + tool_args={}, + tool_context=tool_context, + result={"status": "success"}, + ) + await bq_plugin_inst.close() + log_entry = _get_captured_event_dict(mock_write_client, dummy_arrow_schema) + _assert_common_fields(log_entry, "TOOL_COMPLETED") + + # UPDATED ASSERTION: Check structured JSON + content = json.loads(log_entry["content"]) + assert content["tool_name"] == "MyTool" + assert content["result"]["status"] == "success" + + @pytest.mark.asyncio + async def test_on_model_error_callback_logs_correctly( + self, + bq_plugin_inst, + mock_write_client, + callback_context, + dummy_arrow_schema, + ): + llm_request = llm_request_lib.LlmRequest( + model="gemini-pro", + contents=[types.Content(parts=[types.Part(text="Prompt")])], + ) + error = ValueError("LLM failed") + await bq_plugin_inst.on_model_error_callback( + callback_context=callback_context, llm_request=llm_request, error=error + ) + await bq_plugin_inst.close() + log_entry = _get_captured_event_dict(mock_write_client, dummy_arrow_schema) + _assert_common_fields(log_entry, "LLM_ERROR") + assert log_entry["content"] is None + assert log_entry["error_message"] == "LLM failed" + + @pytest.mark.asyncio + async def test_on_tool_error_callback_logs_correctly( + self, bq_plugin_inst, mock_write_client, tool_context, dummy_arrow_schema + ): + mock_tool = mock.create_autospec( + base_tool_lib.BaseTool, instance=True, spec_set=True + ) + type(mock_tool).name = mock.PropertyMock(return_value="MyTool") + type(mock_tool).description = mock.PropertyMock(return_value="Description") + error = TimeoutError("Tool timed out") + await bq_plugin_inst.on_tool_error_callback( + tool=mock_tool, + tool_args={"param": "value"}, + tool_context=tool_context, + error=error, + ) + await bq_plugin_inst.close() + log_entry = _get_captured_event_dict(mock_write_client, dummy_arrow_schema) + _assert_common_fields(log_entry, "TOOL_ERROR") + + content = json.loads(log_entry["content"]) + assert content["tool_name"] == "MyTool" + assert content["arguments"]["param"] == "value" + assert log_entry["error_message"] == "Tool timed out" + @pytest.mark.asyncio async def test_bigquery_client_initialization_failure( self, @@ -670,7 +891,9 @@ async def test_bigquery_client_initialization_failure( invocation_context=invocation_context, user_message=types.Content(parts=[types.Part(text="Test")]), ) - await asyncio.sleep(0.01) + # Wait for the background task (which logs the error) to complete + await plugin_with_fail.close() + mock_log_error.assert_any_call("BQ Plugin: Init Failed:", exc_info=True) mock_write_client.append_rows.assert_not_called() @@ -678,7 +901,6 @@ async def test_bigquery_client_initialization_failure( async def test_bigquery_insert_error_does_not_raise( self, bq_plugin_inst, mock_write_client, invocation_context ): - async def fake_append_rows_with_error(requests, **kwargs): mock_append_rows_response = mock.MagicMock() mock_append_rows_response.row_errors = [] # No row errors @@ -694,7 +916,7 @@ async def fake_append_rows_with_error(requests, **kwargs): invocation_context=invocation_context, user_message=types.Content(parts=[types.Part(text="Test")]), ) - await asyncio.sleep(0.01) + await bq_plugin_inst.close() mock_log_error.assert_called_with( "BQ Plugin: Write Error: %s", "Test BQ Error" ) @@ -723,11 +945,11 @@ async def fake_append_rows_with_schema_error(requests, **kwargs): invocation_context=invocation_context, user_message=types.Content(parts=[types.Part(text="Test")]), ) - await asyncio.sleep(0.01) + await bq_plugin_inst.close() mock_log_error.assert_called_with( - "BQ Plugin: Schema Mismatch Error. The BigQuery table schema may be" - " incorrect or out of sync with the plugin. Please verify the table" - " definition. Details: %s", + "BQ Plugin: Schema Mismatch. You may need to delete the existing" + " table if you migrated from STRING content to JSON content." + " Details: %s", "Schema mismatch: Field 'new_field' not found in table.", ) @@ -735,8 +957,6 @@ async def fake_append_rows_with_schema_error(requests, **kwargs): async def test_close(self, bq_plugin_inst, mock_bq_client, mock_write_client): await bq_plugin_inst.close() mock_write_client.transport.close.assert_called_once() - # bq_client might not be closed if it wasn't created or if close() failed, - # but here it should be. # in the new implementation we verify attributes are reset assert bq_plugin_inst._write_client is None assert bq_plugin_inst._bq_client is None @@ -753,7 +973,7 @@ async def test_before_run_callback_logs_correctly( await bq_plugin_inst.before_run_callback( invocation_context=invocation_context ) - await asyncio.sleep(0.01) + await bq_plugin_inst.close() log_entry = _get_captured_event_dict(mock_write_client, dummy_arrow_schema) _assert_common_fields(log_entry, "INVOCATION_STARTING") assert log_entry["content"] is None @@ -769,7 +989,7 @@ async def test_after_run_callback_logs_correctly( await bq_plugin_inst.after_run_callback( invocation_context=invocation_context ) - await asyncio.sleep(0.01) + await bq_plugin_inst.close() log_entry = _get_captured_event_dict(mock_write_client, dummy_arrow_schema) _assert_common_fields(log_entry, "INVOCATION_COMPLETED") assert log_entry["content"] is None @@ -786,10 +1006,12 @@ async def test_before_agent_callback_logs_correctly( await bq_plugin_inst.before_agent_callback( agent=mock_agent, callback_context=callback_context ) - await asyncio.sleep(0.01) + await bq_plugin_inst.close() log_entry = _get_captured_event_dict(mock_write_client, dummy_arrow_schema) _assert_common_fields(log_entry, "AGENT_STARTING") - assert log_entry["content"] == "Agent Name: MyTestAgent" + + content = json.loads(log_entry["content"]) + assert content["target_agent"] == "MyTestAgent" @pytest.mark.asyncio async def test_after_agent_callback_logs_correctly( @@ -803,219 +1025,12 @@ async def test_after_agent_callback_logs_correctly( await bq_plugin_inst.after_agent_callback( agent=mock_agent, callback_context=callback_context ) - await asyncio.sleep(0.01) + await bq_plugin_inst.close() log_entry = _get_captured_event_dict(mock_write_client, dummy_arrow_schema) _assert_common_fields(log_entry, "AGENT_COMPLETED") - assert log_entry["content"] == "Agent Name: MyTestAgent" - - @pytest.mark.asyncio - async def test_before_model_callback_logs_correctly( - self, - bq_plugin_inst, - mock_write_client, - callback_context, - dummy_arrow_schema, - ): - llm_request = llm_request_lib.LlmRequest( - model="gemini-pro", - contents=[ - types.Content(role="user", parts=[types.Part(text="Prompt")]) - ], - ) - await bq_plugin_inst.before_model_callback( - callback_context=callback_context, llm_request=llm_request - ) - await asyncio.sleep(0.01) - log_entry = _get_captured_event_dict(mock_write_client, dummy_arrow_schema) - _assert_common_fields(log_entry, "LLM_REQUEST") - assert ( - log_entry["content"] - == "Model: gemini-pro | Prompt: user: text: 'Prompt' | System Prompt:" - " Empty" - ) - - @pytest.mark.asyncio - async def test_before_model_callback_with_params_and_tools( - self, - bq_plugin_inst, - mock_write_client, - callback_context, - dummy_arrow_schema, - ): - llm_request = llm_request_lib.LlmRequest( - model="gemini-pro", - config=types.GenerateContentConfig( - temperature=0.5, - top_p=0.9, - system_instruction=types.Content(parts=[types.Part(text="Sys")]), - ), - contents=[types.Content(role="user", parts=[types.Part(text="User")])], - ) - # Manually set tools_dict as it is excluded from init - llm_request.tools_dict = {"tool1": "func1", "tool2": "func2"} - - await bq_plugin_inst.before_model_callback( - callback_context=callback_context, llm_request=llm_request - ) - await asyncio.sleep(0.01) - log_entry = _get_captured_event_dict(mock_write_client, dummy_arrow_schema) - _assert_common_fields(log_entry, "LLM_REQUEST") - # Order: Model | Params | Tools | Prompt | System Prompt - # Note: Params order depends on dict iteration but here we construct it deterministically in code? - # The code does: params_to_log["temperature"] = ... then "top_p" = ... - # So order should be temperature, top_p. - assert "Model: gemini-pro" in log_entry["content"] - assert "Params: {temperature=0.5, top_p=0.9}" in log_entry["content"] - assert "Available Tools: ['tool1', 'tool2']" in log_entry["content"] - assert "Prompt: user: text: 'User'" in log_entry["content"] - assert "System Prompt: Sys" in log_entry["content"] - - @pytest.mark.asyncio - async def test_after_model_callback_text_response( - self, - bq_plugin_inst, - mock_write_client, - callback_context, - dummy_arrow_schema, - ): - llm_response = llm_response_lib.LlmResponse( - content=types.Content(parts=[types.Part(text="Model response")]), - usage_metadata=types.UsageMetadata( - prompt_token_count=10, total_token_count=15 - ), - ) - await bq_plugin_inst.after_model_callback( - callback_context=callback_context, llm_response=llm_response - ) - await asyncio.sleep(0.01) - log_entry = _get_captured_event_dict(mock_write_client, dummy_arrow_schema) - _assert_common_fields(log_entry, "LLM_RESPONSE") - assert ( - "Tool Name: text_response, text: 'Model response'" - in log_entry["content"] - ) - assert "Token Usage:" in log_entry["content"] - assert "prompt: 10" in log_entry["content"] - assert "total: 15" in log_entry["content"] - assert log_entry["error_message"] is None - - @pytest.mark.asyncio - async def test_after_model_callback_tool_call( - self, - bq_plugin_inst, - mock_write_client, - callback_context, - dummy_arrow_schema, - ): - tool_fc = types.FunctionCall(name="get_weather", args={"location": "Paris"}) - llm_response = llm_response_lib.LlmResponse( - content=types.Content(parts=[types.Part(function_call=tool_fc)]), - usage_metadata=types.UsageMetadata( - prompt_token_count=10, total_token_count=15 - ), - ) - await bq_plugin_inst.after_model_callback( - callback_context=callback_context, llm_response=llm_response - ) - await asyncio.sleep(0.01) - log_entry = _get_captured_event_dict(mock_write_client, dummy_arrow_schema) - _assert_common_fields(log_entry, "LLM_RESPONSE") - assert "Tool Name: get_weather" in log_entry["content"] - assert "Token Usage:" in log_entry["content"] - assert "prompt: 10" in log_entry["content"] - assert "total: 15" in log_entry["content"] - assert log_entry["error_message"] is None - @pytest.mark.asyncio - async def test_before_tool_callback_logs_correctly( - self, bq_plugin_inst, mock_write_client, tool_context, dummy_arrow_schema - ): - mock_tool = mock.create_autospec( - base_tool_lib.BaseTool, instance=True, spec_set=True - ) - type(mock_tool).name = mock.PropertyMock(return_value="MyTool") - type(mock_tool).description = mock.PropertyMock(return_value="Description") - await bq_plugin_inst.before_tool_callback( - tool=mock_tool, tool_args={"param": "value"}, tool_context=tool_context - ) - await asyncio.sleep(0.01) - log_entry = _get_captured_event_dict(mock_write_client, dummy_arrow_schema) - _assert_common_fields(log_entry, "TOOL_STARTING") - assert ( - log_entry["content"] - == 'Tool Name: MyTool, Description: Description, Arguments: {"param":' - ' "value"}' - ) - - @pytest.mark.asyncio - async def test_after_tool_callback_logs_correctly( - self, bq_plugin_inst, mock_write_client, tool_context, dummy_arrow_schema - ): - mock_tool = mock.create_autospec( - base_tool_lib.BaseTool, instance=True, spec_set=True - ) - type(mock_tool).name = mock.PropertyMock(return_value="MyTool") - type(mock_tool).description = mock.PropertyMock(return_value="Description") - await bq_plugin_inst.after_tool_callback( - tool=mock_tool, - tool_args={}, - tool_context=tool_context, - result={"status": "success"}, - ) - await asyncio.sleep(0.01) - log_entry = _get_captured_event_dict(mock_write_client, dummy_arrow_schema) - _assert_common_fields(log_entry, "TOOL_COMPLETED") - assert ( - log_entry["content"] - == 'Tool Name: MyTool, Result: {"status": "success"}' - ) - - @pytest.mark.asyncio - async def test_on_model_error_callback_logs_correctly( - self, - bq_plugin_inst, - mock_write_client, - callback_context, - dummy_arrow_schema, - ): - llm_request = llm_request_lib.LlmRequest( - model="gemini-pro", - contents=[types.Content(parts=[types.Part(text="Prompt")])], - ) - error = ValueError("LLM failed") - await bq_plugin_inst.on_model_error_callback( - callback_context=callback_context, llm_request=llm_request, error=error - ) - await asyncio.sleep(0.01) - log_entry = _get_captured_event_dict(mock_write_client, dummy_arrow_schema) - _assert_common_fields(log_entry, "LLM_ERROR") - assert log_entry["content"] is None - assert log_entry["error_message"] == "LLM failed" - - @pytest.mark.asyncio - async def test_on_tool_error_callback_logs_correctly( - self, bq_plugin_inst, mock_write_client, tool_context, dummy_arrow_schema - ): - mock_tool = mock.create_autospec( - base_tool_lib.BaseTool, instance=True, spec_set=True - ) - type(mock_tool).name = mock.PropertyMock(return_value="MyTool") - type(mock_tool).description = mock.PropertyMock(return_value="Description") - error = TimeoutError("Tool timed out") - await bq_plugin_inst.on_tool_error_callback( - tool=mock_tool, - tool_args={"param": "value"}, - tool_context=tool_context, - error=error, - ) - await asyncio.sleep(0.01) - log_entry = _get_captured_event_dict(mock_write_client, dummy_arrow_schema) - _assert_common_fields(log_entry, "TOOL_ERROR") - assert ( - log_entry["content"] - == 'Tool Name: MyTool, Arguments: {"param": "value"}' - ) - assert log_entry["error_message"] == "Tool timed out" + content = json.loads(log_entry["content"]) + assert content["target_agent"] == "MyTestAgent" @pytest.mark.asyncio async def test_table_creation_options( @@ -1039,9 +1054,7 @@ async def test_table_creation_options( assert table_arg.time_partitioning.type_ == "DAY" assert table_arg.time_partitioning.field == "timestamp" assert table_arg.clustering_fields == ["event_type", "agent", "user_id"] - # Verify schema descriptions are present (spot check) - timestamp_field = next(f for f in table_arg.schema if f.name == "timestamp") - assert ( - timestamp_field.description - == "The UTC time at which the event was logged." - ) + + # Verify schema type for content is JSON + content_field = next(f for f in table_arg.schema if f.name == "content") + assert content_field.field_type == "JSON"