Skip to content

Commit a3ace99

Browse files
committed
fix: Ensure close() resets clients to None and use it in tests
1 parent f076977 commit a3ace99

File tree

2 files changed

+74
-39
lines changed

2 files changed

+74
-39
lines changed

src/google/adk/plugins/bigquery_agent_analytics_plugin.py

Lines changed: 25 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@
3434
from google.genai import types
3535
import pyarrow as pa
3636

37-
from .. import version
3837
from ..agents.base_agent import BaseAgent
3938
from ..agents.callback_context import CallbackContext
4039
from ..events.event import Event
@@ -382,7 +381,7 @@ async def _ensure_init(self):
382381
scopes=["https://www.googleapis.com/auth/cloud-platform"],
383382
)
384383
client_info = gapic_client_info.ClientInfo(
385-
user_agent=f"google-adk-bq-logger/{version.__version__}"
384+
user_agent="google-adk-bq-logger"
386385
)
387386
self._bq_client = bigquery.Client(
388387
project=self._project_id, credentials=creds, client_info=client_info
@@ -585,7 +584,7 @@ async def on_user_message_callback(
585584
if user_message and user_message.parts:
586585
text_content = " ".join([p.text for p in user_message.parts if p.text])
587586

588-
payload = {"text": text_content}
587+
payload = {"text": text_content if text_content else None}
589588

590589
await self._log(
591590
{
@@ -647,7 +646,7 @@ async def on_event_callback(
647646
"text": " ".join(text_parts) if text_parts else None,
648647
"tool_calls": tool_calls if tool_calls else None,
649648
"tool_responses": tool_responses if tool_responses else None,
650-
"raw_role": event.author,
649+
"raw_role": event.author if event.author else None,
651650
}
652651

653652
await self._log(
@@ -750,15 +749,25 @@ async def before_model_callback(
750749
params["max_output_tokens"] = cfg.max_output_tokens
751750

752751
# 2. System Instruction
753-
system_instr = "None"
754-
if llm_request.config and llm_request.config.system_instruction:
752+
system_instr = None
753+
if llm_request.config and llm_request.config.system_instruction is not None:
755754
si = llm_request.config.system_instruction
756755
if isinstance(si, str):
757756
system_instr = si
758757
elif isinstance(si, types.Content):
759758
system_instr = "".join(p.text for p in si.parts if p.text)
760759
elif isinstance(si, types.Part):
761760
system_instr = si.text
761+
elif hasattr(si, "__iter__"):
762+
texts = []
763+
for item in si:
764+
if isinstance(item, str):
765+
texts.append(item)
766+
elif isinstance(item, types.Part) and item.text:
767+
texts.append(item.text)
768+
system_instr = "".join(texts)
769+
else:
770+
system_instr = str(si)
762771

763772
# 3. Prompt History (Simplified structure for JSON)
764773
prompt_history = []
@@ -843,7 +852,10 @@ async def after_model_callback(
843852
),
844853
}
845854

846-
payload = {"response_content": content_parts, "usage": usage}
855+
payload = {
856+
"response_content": content_parts if content_parts else None,
857+
"usage": usage if usage else None,
858+
}
847859

848860
await self._log(
849861
{
@@ -876,10 +888,11 @@ async def before_tool_callback(
876888
If individual string fields exceed `max_content_length`, they are truncated
877889
to preserve the valid JSON structure.
878890
"""
891+
879892
payload = {
880-
"tool_name": tool.name,
881-
"description": tool.description,
882-
"arguments": tool_args,
893+
"tool_name": tool.name if tool.name else None,
894+
"description": tool.description if tool.description else None,
895+
"arguments": tool_args if tool_args else None,
883896
}
884897
await self._log(
885898
{
@@ -910,7 +923,7 @@ async def after_tool_callback(
910923
If individual string fields exceed `max_content_length`, they are truncated
911924
to preserve the valid JSON structure.
912925
"""
913-
payload = {"tool_name": tool.name, "result": result}
926+
payload = {"tool_name": tool.name if tool.name else None, "result": result if result else None}
914927
await self._log(
915928
{
916929
"event_type": "TOOL_COMPLETED",
@@ -964,7 +977,7 @@ async def on_tool_error_callback(
964977
If individual string fields exceed `max_content_length`, they are truncated
965978
to preserve the valid JSON structure.
966979
"""
967-
payload = {"tool_name": tool.name, "arguments": tool_args}
980+
payload = {"tool_name": tool.name if tool.name else None, "arguments": tool_args if tool_args else None}
968981
await self._log(
969982
{
970983
"event_type": "TOOL_ERROR",

tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py

Lines changed: 49 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ async def fake_append_rows(requests, **kwargs):
143143

144144
@pytest.fixture
145145
def dummy_arrow_schema():
146-
# UPDATED: content is pa.string() because JSON is serialized to string before Arrow
146+
# content is pa.string() because JSON is serialized to string before Arrow
147147
return pa.schema([
148148
pa.field("timestamp", pa.timestamp("us", tz="UTC"), nullable=False),
149149
pa.field("event_type", pa.string(), nullable=True),
@@ -259,6 +259,9 @@ async def test_plugin_disabled(
259259
invocation_context=invocation_context,
260260
user_message=types.Content(parts=[types.Part(text="Test")]),
261261
)
262+
# Wait for background tasks
263+
await plugin.close()
264+
262265
mock_auth_default.assert_not_called()
263266
mock_bq_client.assert_not_called()
264267
mock_write_client.append_rows.assert_not_called()
@@ -289,15 +292,25 @@ async def test_event_allowlist(
289292
await plugin.before_model_callback(
290293
callback_context=callback_context, llm_request=llm_request
291294
)
292-
await asyncio.sleep(0.01) # Allow background task to run
295+
await plugin.close() # Wait for write
293296
mock_write_client.append_rows.assert_called_once()
294297
mock_write_client.append_rows.reset_mock()
295298

299+
# Re-init plugin logic since close() shuts it down, but for this test we want to test denial
300+
# However, close() cleans up clients. We should probably create a new plugin or just check that the task was not created.
301+
# But on_user_message_callback will try to log.
302+
# To keep it simple, let's just use a fresh plugin for the second part or assume close() resets state enough to re-run _ensure_init if needed,
303+
# but _ensure_init is called inside _perform_write.
304+
# Actually, close() sets _is_shutting_down to True, so further logs are ignored.
305+
# So we need a new plugin instance or reset _is_shutting_down.
306+
plugin._is_shutting_down = False
307+
296308
user_message = types.Content(parts=[types.Part(text="What is up?")])
297309
await plugin.on_user_message_callback(
298310
invocation_context=invocation_context, user_message=user_message
299311
)
300-
await asyncio.sleep(0.01) # Allow background task to run
312+
# Since it's denied, no task is created. close() would wait if there was one.
313+
await plugin.close()
301314
mock_write_client.append_rows.assert_not_called()
302315

303316
@pytest.mark.asyncio
@@ -322,11 +335,14 @@ async def test_event_denylist(
322335
await plugin.on_user_message_callback(
323336
invocation_context=invocation_context, user_message=user_message
324337
)
325-
await asyncio.sleep(0.01)
338+
await plugin.close()
326339
mock_write_client.append_rows.assert_not_called()
327340

341+
# Reset for next call
342+
plugin._is_shutting_down = False
343+
328344
await plugin.before_run_callback(invocation_context=invocation_context)
329-
await asyncio.sleep(0.01)
345+
await plugin.close()
330346
mock_write_client.append_rows.assert_called_once()
331347

332348
@pytest.mark.asyncio
@@ -370,7 +386,7 @@ def mutate_payload(data):
370386
await plugin.before_model_callback(
371387
callback_context=callback_context, llm_request=llm_request
372388
)
373-
await asyncio.sleep(0.01)
389+
await plugin.close()
374390
log_entry = _get_captured_event_dict(mock_write_client, dummy_arrow_schema)
375391

376392
# Parse JSON
@@ -408,7 +424,7 @@ async def test_max_content_length_smart_truncation(
408424
await plugin.on_user_message_callback(
409425
invocation_context=invocation_context, user_message=user_message
410426
)
411-
await asyncio.sleep(0.01)
427+
await plugin.close()
412428

413429
log_entry = _get_captured_event_dict(mock_write_client, dummy_arrow_schema)
414430
content = json.loads(log_entry["content"])
@@ -450,7 +466,7 @@ async def test_max_content_length_tool_args(
450466
tool_args={"param": long_val},
451467
tool_context=tool_context,
452468
)
453-
await asyncio.sleep(0.01)
469+
await plugin.close()
454470
log_entry = _get_captured_event_dict(mock_write_client, dummy_arrow_schema)
455471
content = json.loads(log_entry["content"])
456472

@@ -487,7 +503,7 @@ async def test_max_content_length_tool_result(
487503
tool_context=tool_context,
488504
result={"res": long_res},
489505
)
490-
await asyncio.sleep(0.01)
506+
await plugin.close()
491507
log_entry = _get_captured_event_dict(mock_write_client, dummy_arrow_schema)
492508
content = json.loads(log_entry["content"])
493509

@@ -523,7 +539,7 @@ async def test_max_content_length_tool_error(
523539
tool_context=tool_context,
524540
error=ValueError("Oops"),
525541
)
526-
await asyncio.sleep(0.01)
542+
await plugin.close()
527543
log_entry = _get_captured_event_dict(mock_write_client, dummy_arrow_schema)
528544
content = json.loads(log_entry["content"])
529545

@@ -541,10 +557,11 @@ async def test_on_user_message_callback_logs_correctly(
541557
await bq_plugin_inst.on_user_message_callback(
542558
invocation_context=invocation_context, user_message=user_message
543559
)
544-
await asyncio.sleep(0.01)
560+
await bq_plugin_inst.close()
545561
log_entry = _get_captured_event_dict(mock_write_client, dummy_arrow_schema)
546562
_assert_common_fields(log_entry, "USER_MESSAGE_RECEIVED")
547563

564+
# UPDATED ASSERTION: Check JSON structure
548565
content = json.loads(log_entry["content"])
549566
assert content["text"] == "What is up?"
550567

@@ -567,7 +584,7 @@ async def test_on_event_callback_tool_call(
567584
await bq_plugin_inst.on_event_callback(
568585
invocation_context=invocation_context, event=event
569586
)
570-
await asyncio.sleep(0.01)
587+
await bq_plugin_inst.close()
571588
log_entry = _get_captured_event_dict(mock_write_client, dummy_arrow_schema)
572589
_assert_common_fields(log_entry, "TOOL_CALL", agent="MyTestAgent")
573590

@@ -594,7 +611,7 @@ async def test_on_event_callback_model_response(
594611
await bq_plugin_inst.on_event_callback(
595612
invocation_context=invocation_context, event=event
596613
)
597-
await asyncio.sleep(0.01)
614+
await bq_plugin_inst.close()
598615
log_entry = _get_captured_event_dict(mock_write_client, dummy_arrow_schema)
599616
_assert_common_fields(log_entry, "MODEL_RESPONSE", agent="MyTestAgent")
600617

@@ -625,7 +642,7 @@ async def test_before_model_callback_logs_structure(
625642
await bq_plugin_inst.before_model_callback(
626643
callback_context=callback_context, llm_request=llm_request
627644
)
628-
await asyncio.sleep(0.01)
645+
await bq_plugin_inst.close()
629646
log_entry = _get_captured_event_dict(mock_write_client, dummy_arrow_schema)
630647
_assert_common_fields(log_entry, "LLM_REQUEST")
631648

@@ -657,10 +674,11 @@ async def test_after_model_callback_text_response(
657674
await bq_plugin_inst.after_model_callback(
658675
callback_context=callback_context, llm_response=llm_response
659676
)
660-
await asyncio.sleep(0.01)
677+
await bq_plugin_inst.close()
661678
log_entry = _get_captured_event_dict(mock_write_client, dummy_arrow_schema)
662679
_assert_common_fields(log_entry, "LLM_RESPONSE")
663680

681+
# UPDATED ASSERTION: Check structured JSON
664682
content = json.loads(log_entry["content"])
665683
assert content["response_content"][0]["type"] == "text"
666684
assert content["response_content"][0]["text"] == "Model response"
@@ -685,7 +703,7 @@ async def test_after_model_callback_tool_call(
685703
await bq_plugin_inst.after_model_callback(
686704
callback_context=callback_context, llm_response=llm_response
687705
)
688-
await asyncio.sleep(0.01)
706+
await bq_plugin_inst.close()
689707
log_entry = _get_captured_event_dict(mock_write_client, dummy_arrow_schema)
690708
_assert_common_fields(log_entry, "LLM_RESPONSE")
691709

@@ -707,10 +725,11 @@ async def test_before_tool_callback_logs_correctly(
707725
await bq_plugin_inst.before_tool_callback(
708726
tool=mock_tool, tool_args={"param": "value"}, tool_context=tool_context
709727
)
710-
await asyncio.sleep(0.01)
728+
await bq_plugin_inst.close()
711729
log_entry = _get_captured_event_dict(mock_write_client, dummy_arrow_schema)
712730
_assert_common_fields(log_entry, "TOOL_STARTING")
713731

732+
# UPDATED ASSERTION: Check structured JSON
714733
content = json.loads(log_entry["content"])
715734
assert content["tool_name"] == "MyTool"
716735
assert content["description"] == "Description"
@@ -731,10 +750,11 @@ async def test_after_tool_callback_logs_correctly(
731750
tool_context=tool_context,
732751
result={"status": "success"},
733752
)
734-
await asyncio.sleep(0.01)
753+
await bq_plugin_inst.close()
735754
log_entry = _get_captured_event_dict(mock_write_client, dummy_arrow_schema)
736755
_assert_common_fields(log_entry, "TOOL_COMPLETED")
737756

757+
# UPDATED ASSERTION: Check structured JSON
738758
content = json.loads(log_entry["content"])
739759
assert content["tool_name"] == "MyTool"
740760
assert content["result"]["status"] == "success"
@@ -755,7 +775,7 @@ async def test_on_model_error_callback_logs_correctly(
755775
await bq_plugin_inst.on_model_error_callback(
756776
callback_context=callback_context, llm_request=llm_request, error=error
757777
)
758-
await asyncio.sleep(0.01)
778+
await bq_plugin_inst.close()
759779
log_entry = _get_captured_event_dict(mock_write_client, dummy_arrow_schema)
760780
_assert_common_fields(log_entry, "LLM_ERROR")
761781
assert log_entry["content"] is None
@@ -777,7 +797,7 @@ async def test_on_tool_error_callback_logs_correctly(
777797
tool_context=tool_context,
778798
error=error,
779799
)
780-
await asyncio.sleep(0.01)
800+
await bq_plugin_inst.close()
781801
log_entry = _get_captured_event_dict(mock_write_client, dummy_arrow_schema)
782802
_assert_common_fields(log_entry, "TOOL_ERROR")
783803

@@ -809,7 +829,9 @@ async def test_bigquery_client_initialization_failure(
809829
invocation_context=invocation_context,
810830
user_message=types.Content(parts=[types.Part(text="Test")]),
811831
)
812-
await asyncio.sleep(0.01)
832+
# Wait for the background task (which logs the error) to complete
833+
await plugin_with_fail.close()
834+
813835
mock_log_error.assert_any_call("BQ Plugin: Init Failed:", exc_info=True)
814836
mock_write_client.append_rows.assert_not_called()
815837

@@ -832,7 +854,7 @@ async def fake_append_rows_with_error(requests, **kwargs):
832854
invocation_context=invocation_context,
833855
user_message=types.Content(parts=[types.Part(text="Test")]),
834856
)
835-
await asyncio.sleep(0.01)
857+
await bq_plugin_inst.close()
836858
mock_log_error.assert_called_with(
837859
"BQ Plugin: Write Error: %s", "Test BQ Error"
838860
)
@@ -861,7 +883,7 @@ async def fake_append_rows_with_schema_error(requests, **kwargs):
861883
invocation_context=invocation_context,
862884
user_message=types.Content(parts=[types.Part(text="Test")]),
863885
)
864-
await asyncio.sleep(0.01)
886+
await bq_plugin_inst.close()
865887
mock_log_error.assert_called_with(
866888
"BQ Plugin: Schema Mismatch. You may need to delete the existing"
867889
" table if you migrated from STRING content to JSON content."
@@ -889,7 +911,7 @@ async def test_before_run_callback_logs_correctly(
889911
await bq_plugin_inst.before_run_callback(
890912
invocation_context=invocation_context
891913
)
892-
await asyncio.sleep(0.01)
914+
await bq_plugin_inst.close()
893915
log_entry = _get_captured_event_dict(mock_write_client, dummy_arrow_schema)
894916
_assert_common_fields(log_entry, "INVOCATION_STARTING")
895917
assert log_entry["content"] is None
@@ -905,7 +927,7 @@ async def test_after_run_callback_logs_correctly(
905927
await bq_plugin_inst.after_run_callback(
906928
invocation_context=invocation_context
907929
)
908-
await asyncio.sleep(0.01)
930+
await bq_plugin_inst.close()
909931
log_entry = _get_captured_event_dict(mock_write_client, dummy_arrow_schema)
910932
_assert_common_fields(log_entry, "INVOCATION_COMPLETED")
911933
assert log_entry["content"] is None
@@ -922,7 +944,7 @@ async def test_before_agent_callback_logs_correctly(
922944
await bq_plugin_inst.before_agent_callback(
923945
agent=mock_agent, callback_context=callback_context
924946
)
925-
await asyncio.sleep(0.01)
947+
await bq_plugin_inst.close()
926948
log_entry = _get_captured_event_dict(mock_write_client, dummy_arrow_schema)
927949
_assert_common_fields(log_entry, "AGENT_STARTING")
928950

@@ -941,7 +963,7 @@ async def test_after_agent_callback_logs_correctly(
941963
await bq_plugin_inst.after_agent_callback(
942964
agent=mock_agent, callback_context=callback_context
943965
)
944-
await asyncio.sleep(0.01)
966+
await bq_plugin_inst.close()
945967
log_entry = _get_captured_event_dict(mock_write_client, dummy_arrow_schema)
946968
_assert_common_fields(log_entry, "AGENT_COMPLETED")
947969

0 commit comments

Comments
 (0)