@@ -143,7 +143,7 @@ async def fake_append_rows(requests, **kwargs):
143143
144144@pytest .fixture
145145def dummy_arrow_schema ():
146- # UPDATED: content is pa.string() because JSON is serialized to string before Arrow
146+ # content is pa.string() because JSON is serialized to string before Arrow
147147 return pa .schema ([
148148 pa .field ("timestamp" , pa .timestamp ("us" , tz = "UTC" ), nullable = False ),
149149 pa .field ("event_type" , pa .string (), nullable = True ),
@@ -259,6 +259,9 @@ async def test_plugin_disabled(
259259 invocation_context = invocation_context ,
260260 user_message = types .Content (parts = [types .Part (text = "Test" )]),
261261 )
262+ # Wait for background tasks
263+ await plugin .close ()
264+
262265 mock_auth_default .assert_not_called ()
263266 mock_bq_client .assert_not_called ()
264267 mock_write_client .append_rows .assert_not_called ()
@@ -289,15 +292,25 @@ async def test_event_allowlist(
289292 await plugin .before_model_callback (
290293 callback_context = callback_context , llm_request = llm_request
291294 )
292- await asyncio . sleep ( 0.01 ) # Allow background task to run
295+ await plugin . close ( ) # Wait for write
293296 mock_write_client .append_rows .assert_called_once ()
294297 mock_write_client .append_rows .reset_mock ()
295298
299+ # Re-init plugin logic since close() shuts it down, but for this test we want to test denial
300+ # However, close() cleans up clients. We should probably create a new plugin or just check that the task was not created.
301+ # But on_user_message_callback will try to log.
302+ # To keep it simple, let's just use a fresh plugin for the second part or assume close() resets state enough to re-run _ensure_init if needed,
303+ # but _ensure_init is called inside _perform_write.
304+ # Actually, close() sets _is_shutting_down to True, so further logs are ignored.
305+ # So we need a new plugin instance or reset _is_shutting_down.
306+ plugin ._is_shutting_down = False
307+
296308 user_message = types .Content (parts = [types .Part (text = "What is up?" )])
297309 await plugin .on_user_message_callback (
298310 invocation_context = invocation_context , user_message = user_message
299311 )
300- await asyncio .sleep (0.01 ) # Allow background task to run
312+ # Since it's denied, no task is created. close() would wait if there was one.
313+ await plugin .close ()
301314 mock_write_client .append_rows .assert_not_called ()
302315
303316 @pytest .mark .asyncio
@@ -322,11 +335,14 @@ async def test_event_denylist(
322335 await plugin .on_user_message_callback (
323336 invocation_context = invocation_context , user_message = user_message
324337 )
325- await asyncio . sleep ( 0.01 )
338+ await plugin . close ( )
326339 mock_write_client .append_rows .assert_not_called ()
327340
341+ # Reset for next call
342+ plugin ._is_shutting_down = False
343+
328344 await plugin .before_run_callback (invocation_context = invocation_context )
329- await asyncio . sleep ( 0.01 )
345+ await plugin . close ( )
330346 mock_write_client .append_rows .assert_called_once ()
331347
332348 @pytest .mark .asyncio
@@ -370,7 +386,7 @@ def mutate_payload(data):
370386 await plugin .before_model_callback (
371387 callback_context = callback_context , llm_request = llm_request
372388 )
373- await asyncio . sleep ( 0.01 )
389+ await plugin . close ( )
374390 log_entry = _get_captured_event_dict (mock_write_client , dummy_arrow_schema )
375391
376392 # Parse JSON
@@ -408,7 +424,7 @@ async def test_max_content_length_smart_truncation(
408424 await plugin .on_user_message_callback (
409425 invocation_context = invocation_context , user_message = user_message
410426 )
411- await asyncio . sleep ( 0.01 )
427+ await plugin . close ( )
412428
413429 log_entry = _get_captured_event_dict (mock_write_client , dummy_arrow_schema )
414430 content = json .loads (log_entry ["content" ])
@@ -450,7 +466,7 @@ async def test_max_content_length_tool_args(
450466 tool_args = {"param" : long_val },
451467 tool_context = tool_context ,
452468 )
453- await asyncio . sleep ( 0.01 )
469+ await plugin . close ( )
454470 log_entry = _get_captured_event_dict (mock_write_client , dummy_arrow_schema )
455471 content = json .loads (log_entry ["content" ])
456472
@@ -487,7 +503,7 @@ async def test_max_content_length_tool_result(
487503 tool_context = tool_context ,
488504 result = {"res" : long_res },
489505 )
490- await asyncio . sleep ( 0.01 )
506+ await plugin . close ( )
491507 log_entry = _get_captured_event_dict (mock_write_client , dummy_arrow_schema )
492508 content = json .loads (log_entry ["content" ])
493509
@@ -523,7 +539,7 @@ async def test_max_content_length_tool_error(
523539 tool_context = tool_context ,
524540 error = ValueError ("Oops" ),
525541 )
526- await asyncio . sleep ( 0.01 )
542+ await plugin . close ( )
527543 log_entry = _get_captured_event_dict (mock_write_client , dummy_arrow_schema )
528544 content = json .loads (log_entry ["content" ])
529545
@@ -541,10 +557,11 @@ async def test_on_user_message_callback_logs_correctly(
541557 await bq_plugin_inst .on_user_message_callback (
542558 invocation_context = invocation_context , user_message = user_message
543559 )
544- await asyncio . sleep ( 0.01 )
560+ await bq_plugin_inst . close ( )
545561 log_entry = _get_captured_event_dict (mock_write_client , dummy_arrow_schema )
546562 _assert_common_fields (log_entry , "USER_MESSAGE_RECEIVED" )
547563
564+ # UPDATED ASSERTION: Check JSON structure
548565 content = json .loads (log_entry ["content" ])
549566 assert content ["text" ] == "What is up?"
550567
@@ -567,7 +584,7 @@ async def test_on_event_callback_tool_call(
567584 await bq_plugin_inst .on_event_callback (
568585 invocation_context = invocation_context , event = event
569586 )
570- await asyncio . sleep ( 0.01 )
587+ await bq_plugin_inst . close ( )
571588 log_entry = _get_captured_event_dict (mock_write_client , dummy_arrow_schema )
572589 _assert_common_fields (log_entry , "TOOL_CALL" , agent = "MyTestAgent" )
573590
@@ -594,7 +611,7 @@ async def test_on_event_callback_model_response(
594611 await bq_plugin_inst .on_event_callback (
595612 invocation_context = invocation_context , event = event
596613 )
597- await asyncio . sleep ( 0.01 )
614+ await bq_plugin_inst . close ( )
598615 log_entry = _get_captured_event_dict (mock_write_client , dummy_arrow_schema )
599616 _assert_common_fields (log_entry , "MODEL_RESPONSE" , agent = "MyTestAgent" )
600617
@@ -625,7 +642,7 @@ async def test_before_model_callback_logs_structure(
625642 await bq_plugin_inst .before_model_callback (
626643 callback_context = callback_context , llm_request = llm_request
627644 )
628- await asyncio . sleep ( 0.01 )
645+ await bq_plugin_inst . close ( )
629646 log_entry = _get_captured_event_dict (mock_write_client , dummy_arrow_schema )
630647 _assert_common_fields (log_entry , "LLM_REQUEST" )
631648
@@ -657,10 +674,11 @@ async def test_after_model_callback_text_response(
657674 await bq_plugin_inst .after_model_callback (
658675 callback_context = callback_context , llm_response = llm_response
659676 )
660- await asyncio . sleep ( 0.01 )
677+ await bq_plugin_inst . close ( )
661678 log_entry = _get_captured_event_dict (mock_write_client , dummy_arrow_schema )
662679 _assert_common_fields (log_entry , "LLM_RESPONSE" )
663680
681+ # UPDATED ASSERTION: Check structured JSON
664682 content = json .loads (log_entry ["content" ])
665683 assert content ["response_content" ][0 ]["type" ] == "text"
666684 assert content ["response_content" ][0 ]["text" ] == "Model response"
@@ -685,7 +703,7 @@ async def test_after_model_callback_tool_call(
685703 await bq_plugin_inst .after_model_callback (
686704 callback_context = callback_context , llm_response = llm_response
687705 )
688- await asyncio . sleep ( 0.01 )
706+ await bq_plugin_inst . close ( )
689707 log_entry = _get_captured_event_dict (mock_write_client , dummy_arrow_schema )
690708 _assert_common_fields (log_entry , "LLM_RESPONSE" )
691709
@@ -707,10 +725,11 @@ async def test_before_tool_callback_logs_correctly(
707725 await bq_plugin_inst .before_tool_callback (
708726 tool = mock_tool , tool_args = {"param" : "value" }, tool_context = tool_context
709727 )
710- await asyncio . sleep ( 0.01 )
728+ await bq_plugin_inst . close ( )
711729 log_entry = _get_captured_event_dict (mock_write_client , dummy_arrow_schema )
712730 _assert_common_fields (log_entry , "TOOL_STARTING" )
713731
732+ # UPDATED ASSERTION: Check structured JSON
714733 content = json .loads (log_entry ["content" ])
715734 assert content ["tool_name" ] == "MyTool"
716735 assert content ["description" ] == "Description"
@@ -731,10 +750,11 @@ async def test_after_tool_callback_logs_correctly(
731750 tool_context = tool_context ,
732751 result = {"status" : "success" },
733752 )
734- await asyncio . sleep ( 0.01 )
753+ await bq_plugin_inst . close ( )
735754 log_entry = _get_captured_event_dict (mock_write_client , dummy_arrow_schema )
736755 _assert_common_fields (log_entry , "TOOL_COMPLETED" )
737756
757+ # UPDATED ASSERTION: Check structured JSON
738758 content = json .loads (log_entry ["content" ])
739759 assert content ["tool_name" ] == "MyTool"
740760 assert content ["result" ]["status" ] == "success"
@@ -755,7 +775,7 @@ async def test_on_model_error_callback_logs_correctly(
755775 await bq_plugin_inst .on_model_error_callback (
756776 callback_context = callback_context , llm_request = llm_request , error = error
757777 )
758- await asyncio . sleep ( 0.01 )
778+ await bq_plugin_inst . close ( )
759779 log_entry = _get_captured_event_dict (mock_write_client , dummy_arrow_schema )
760780 _assert_common_fields (log_entry , "LLM_ERROR" )
761781 assert log_entry ["content" ] is None
@@ -777,7 +797,7 @@ async def test_on_tool_error_callback_logs_correctly(
777797 tool_context = tool_context ,
778798 error = error ,
779799 )
780- await asyncio . sleep ( 0.01 )
800+ await bq_plugin_inst . close ( )
781801 log_entry = _get_captured_event_dict (mock_write_client , dummy_arrow_schema )
782802 _assert_common_fields (log_entry , "TOOL_ERROR" )
783803
@@ -809,7 +829,9 @@ async def test_bigquery_client_initialization_failure(
809829 invocation_context = invocation_context ,
810830 user_message = types .Content (parts = [types .Part (text = "Test" )]),
811831 )
812- await asyncio .sleep (0.01 )
832+ # Wait for the background task (which logs the error) to complete
833+ await plugin_with_fail .close ()
834+
813835 mock_log_error .assert_any_call ("BQ Plugin: Init Failed:" , exc_info = True )
814836 mock_write_client .append_rows .assert_not_called ()
815837
@@ -832,7 +854,7 @@ async def fake_append_rows_with_error(requests, **kwargs):
832854 invocation_context = invocation_context ,
833855 user_message = types .Content (parts = [types .Part (text = "Test" )]),
834856 )
835- await asyncio . sleep ( 0.01 )
857+ await bq_plugin_inst . close ( )
836858 mock_log_error .assert_called_with (
837859 "BQ Plugin: Write Error: %s" , "Test BQ Error"
838860 )
@@ -861,7 +883,7 @@ async def fake_append_rows_with_schema_error(requests, **kwargs):
861883 invocation_context = invocation_context ,
862884 user_message = types .Content (parts = [types .Part (text = "Test" )]),
863885 )
864- await asyncio . sleep ( 0.01 )
886+ await bq_plugin_inst . close ( )
865887 mock_log_error .assert_called_with (
866888 "BQ Plugin: Schema Mismatch. You may need to delete the existing"
867889 " table if you migrated from STRING content to JSON content."
@@ -889,7 +911,7 @@ async def test_before_run_callback_logs_correctly(
889911 await bq_plugin_inst .before_run_callback (
890912 invocation_context = invocation_context
891913 )
892- await asyncio . sleep ( 0.01 )
914+ await bq_plugin_inst . close ( )
893915 log_entry = _get_captured_event_dict (mock_write_client , dummy_arrow_schema )
894916 _assert_common_fields (log_entry , "INVOCATION_STARTING" )
895917 assert log_entry ["content" ] is None
@@ -905,7 +927,7 @@ async def test_after_run_callback_logs_correctly(
905927 await bq_plugin_inst .after_run_callback (
906928 invocation_context = invocation_context
907929 )
908- await asyncio . sleep ( 0.01 )
930+ await bq_plugin_inst . close ( )
909931 log_entry = _get_captured_event_dict (mock_write_client , dummy_arrow_schema )
910932 _assert_common_fields (log_entry , "INVOCATION_COMPLETED" )
911933 assert log_entry ["content" ] is None
@@ -922,7 +944,7 @@ async def test_before_agent_callback_logs_correctly(
922944 await bq_plugin_inst .before_agent_callback (
923945 agent = mock_agent , callback_context = callback_context
924946 )
925- await asyncio . sleep ( 0.01 )
947+ await bq_plugin_inst . close ( )
926948 log_entry = _get_captured_event_dict (mock_write_client , dummy_arrow_schema )
927949 _assert_common_fields (log_entry , "AGENT_STARTING" )
928950
@@ -941,7 +963,7 @@ async def test_after_agent_callback_logs_correctly(
941963 await bq_plugin_inst .after_agent_callback (
942964 agent = mock_agent , callback_context = callback_context
943965 )
944- await asyncio . sleep ( 0.01 )
966+ await bq_plugin_inst . close ( )
945967 log_entry = _get_captured_event_dict (mock_write_client , dummy_arrow_schema )
946968 _assert_common_fields (log_entry , "AGENT_COMPLETED" )
947969
0 commit comments