Skip to content

Commit aa99faa

Browse files
committed
Address review feedback on BQ plugin JSON structure, timestamps, linting
1 parent a3ace99 commit aa99faa

File tree

2 files changed

+31
-14
lines changed

2 files changed

+31
-14
lines changed

src/google/adk/plugins/bigquery_agent_analytics_plugin.py

Lines changed: 27 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,9 @@
2828
from google.api_core.gapic_v1 import client_info as gapic_client_info
2929
import google.auth
3030
from google.cloud import bigquery
31+
from google.cloud import bigquery_storage_v1
3132
from google.cloud.bigquery import schema as bq_schema
3233
from google.cloud.bigquery_storage_v1 import types as bq_storage_types
33-
from google.cloud.bigquery_storage_v1.services.big_query_write.async_client import BigQueryWriteAsyncClient
3434
from google.genai import types
3535
import pyarrow as pa
3636

@@ -221,7 +221,7 @@ class BigQueryLoggerConfig:
221221
event_allowlist: Optional[List[str]] = None
222222
event_denylist: Optional[List[str]] = None
223223
# Custom formatter is discouraged now that we use JSON, but kept for compat
224-
content_formatter: Optional[Callable[[Any], str]] = None
224+
content_formatter: Optional[Callable[[dict], dict]] = None
225225
shutdown_timeout: float = 5.0
226226
client_close_timeout: float = 2.0
227227
# Increased default limit to 50KB since we truncate per-field, not per-row
@@ -307,7 +307,11 @@ def __init__(
307307
)
308308
self._config = config if config else BigQueryLoggerConfig()
309309
self._bq_client: bigquery.Client | None = None
310-
self._write_client: BigQueryWriteAsyncClient | None = None
310+
# Type alias update: Use the class from the top-level package import
311+
self._write_client: (
312+
bigquery_storage_v1.services.big_query_write.async_client.BigQueryWriteAsyncClient
313+
| None
314+
) = None
311315
self._init_lock: asyncio.Lock | None = None
312316
self._arrow_schema: pa.Schema | None = None
313317
self._background_tasks: set[asyncio.Task] = set()
@@ -407,7 +411,8 @@ def create_resources():
407411

408412
await asyncio.to_thread(create_resources)
409413

410-
self._write_client = BigQueryWriteAsyncClient(
414+
# Fix: Use the top-level package import to avoid "cli" substring in path
415+
self._write_client = bigquery_storage_v1.services.big_query_write.async_client.BigQueryWriteAsyncClient(
411416
credentials=creds,
412417
client_info=client_info,
413418
)
@@ -446,7 +451,11 @@ async def _perform_write(self, row: dict):
446451
):
447452
if resp.error.code != 0:
448453
msg = resp.error.message
449-
if "schema mismatch" in msg.lower():
454+
if (
455+
"schema mismatch" in msg.lower()
456+
or "field" in msg.lower()
457+
or "type" in msg.lower()
458+
):
450459
logging.error(
451460
"BQ Plugin: Schema Mismatch. You may need to delete the"
452461
" existing table if you migrated from STRING content to JSON"
@@ -462,7 +471,7 @@ async def _perform_write(self, row: dict):
462471
except asyncio.CancelledError:
463472
if not self._is_shutting_down:
464473
logging.warning("BQ Plugin: Write task cancelled unexpectedly.")
465-
except Exception:
474+
except Exception as e:
466475
logging.error("BQ Plugin: Write Failed:", exc_info=True)
467476

468477
async def _log(self, data: dict, content_payload: Any = None):
@@ -657,6 +666,7 @@ async def on_event_callback(
657666
"invocation_id": invocation_context.invocation_id,
658667
"user_id": invocation_context.session.user_id,
659668
"error_message": event.error_message,
669+
"timestamp": datetime.fromtimestamp(event.timestamp, timezone.utc),
660670
},
661671
content_payload=payload,
662672
)
@@ -790,14 +800,14 @@ async def before_model_callback(
790800

791801
payload = {
792802
"model": llm_request.model or "default",
793-
"params": params,
803+
"params": params if params else None,
794804
"tools_available": (
795805
list(llm_request.tools_dict.keys())
796806
if llm_request.tools_dict
797-
else []
807+
else None
798808
),
799809
"system_instruction": system_instr,
800-
"prompt": prompt_history,
810+
"prompt": prompt_history if prompt_history else None,
801811
}
802812

803813
await self._log(
@@ -888,7 +898,6 @@ async def before_tool_callback(
888898
If individual string fields exceed `max_content_length`, they are truncated
889899
to preserve the valid JSON structure.
890900
"""
891-
892901
payload = {
893902
"tool_name": tool.name if tool.name else None,
894903
"description": tool.description if tool.description else None,
@@ -923,7 +932,10 @@ async def after_tool_callback(
923932
If individual string fields exceed `max_content_length`, they are truncated
924933
to preserve the valid JSON structure.
925934
"""
926-
payload = {"tool_name": tool.name if tool.name else None, "result": result if result else None}
935+
payload = {
936+
"tool_name": tool.name if tool.name else None,
937+
"result": result if result else None,
938+
}
927939
await self._log(
928940
{
929941
"event_type": "TOOL_COMPLETED",
@@ -977,7 +989,10 @@ async def on_tool_error_callback(
977989
If individual string fields exceed `max_content_length`, they are truncated
978990
to preserve the valid JSON structure.
979991
"""
980-
payload = {"tool_name": tool.name if tool.name else None, "arguments": tool_args if tool_args else None}
992+
payload = {
993+
"tool_name": tool.name if tool.name else None,
994+
"arguments": tool_args if tool_args else None,
995+
}
981996
await self._log(
982997
{
983998
"event_type": "TOOL_ERROR",

tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -123,8 +123,10 @@ def mock_bq_client():
123123

124124
@pytest.fixture
125125
def mock_write_client():
126-
with mock.patch.object(
127-
bigquery_agent_analytics_plugin, "BigQueryWriteAsyncClient", autospec=True
126+
# Updated patch path to match the new import structure in src
127+
with mock.patch(
128+
"google.cloud.bigquery_storage_v1.services.big_query_write.async_client.BigQueryWriteAsyncClient",
129+
autospec=True,
128130
) as mock_cls:
129131
mock_client = mock_cls.return_value
130132
mock_client.transport = mock.AsyncMock()

0 commit comments

Comments
 (0)