diff --git a/litellm/proxy/guardrails/guardrail_hooks/presidio.py b/litellm/proxy/guardrails/guardrail_hooks/presidio.py index 23c73098c49e..0792785c2eb3 100644 --- a/litellm/proxy/guardrails/guardrail_hooks/presidio.py +++ b/litellm/proxy/guardrails/guardrail_hooks/presidio.py @@ -14,8 +14,10 @@ from typing import ( Any, AsyncGenerator, + Coroutine, Dict, List, + Literal, Optional, Tuple, Union, @@ -404,8 +406,12 @@ async def async_pre_call_hook( verbose_proxy_logger.debug("content_safety: %s", content_safety) presidio_config = self.get_presidio_settings_from_request_data(data) messages = data["messages"] - tasks = [] - for m in messages: + tasks: list[Coroutine[Any, Any, str]] = [] + targets: list[ + tuple[Literal["str"], int] | tuple[Literal["block"], int, int, str] + ] = [] # track where to write back each presidio result + + for msg_idx, m in enumerate(messages): content = m.get("content", None) if content is None: continue @@ -418,15 +424,36 @@ async def async_pre_call_hook( request_data=data, ) ) + + # string content -> write back to messages[msg_idx]['content'] + targets.append(("str", msg_idx)) + elif isinstance(content, list): + # handle only dict blocks with a string 'text' field + for b_idx, block in enumerate(content): + if isinstance(block, dict) and isinstance(block.get("text"), str): + tasks.append( + self.check_pii( + text=block["text"], + output_parse_pii=self.output_parse_pii, + presidio_config=presidio_config, + request_data=data, + ) + ) + + # block content -> write back to messages[msg_idx]['content'][b_idx]['text'] + targets.append(("block", msg_idx, b_idx, "text")) + responses = await asyncio.gather(*tasks) - for index, r in enumerate(responses): - content = messages[index].get("content", None) - if content is None: - continue - if isinstance(content, str): - messages[index][ - "content" - ] = r # replace content with redacted string + + # write results back to the exact targets collected above + for redacted, tgt in zip(responses, targets): + if tgt[0] == "str": + _, mi = tgt + messages[mi]["content"] = redacted + else: + _, mi, bi, key = tgt + messages[mi]["content"][bi][key] = redacted + verbose_proxy_logger.debug( f"Presidio PII Masking: Redacted pii message: {data['messages']}" ) @@ -672,4 +699,4 @@ def update_in_memory_litellm_params(self, litellm_params: LitellmParams) -> None """ super().update_in_memory_litellm_params(litellm_params) if litellm_params.pii_entities_config: - self.pii_entities_config = litellm_params.pii_entities_config + self.pii_entities_config = litellm_params.pii_entities_config \ No newline at end of file diff --git a/tests/guardrails_tests/test_presidio_pii.py b/tests/guardrails_tests/test_presidio_pii.py index e3f811ba7df3..826caf4b22ae 100644 --- a/tests/guardrails_tests/test_presidio_pii.py +++ b/tests/guardrails_tests/test_presidio_pii.py @@ -592,3 +592,48 @@ async def test_presidio_language_configuration_with_per_request_override(): # Verify the default language (German) is used assert analyze_request_default["language"] == "de" assert analyze_request_default["text"] == test_text + + +@pytest.mark.asyncio +async def test_presidio_pre_call_hook_with_content_list_single_block(): + """ + Ensure async_pre_call_hook redacts list-of-blocks content with a 'text' field. + """ + mock_redacted = { + "text": "email is ", + "items": [ + { + "start": 9, + "end": 24, + "entity_type": "EMAIL_ADDRESS", + "text": "", + "operator": "replace", + } + ], + } + + presidio_guardrail = _OPTIONAL_PresidioPIIMasking( + mock_testing=True, + mock_redacted_text=mock_redacted, + ) + + data = { + "messages": [ + { + "role": "user", + "content": [ + {"type": "text", "text": "email is test@example.com"} + ], + } + ] + } + + new_data = await presidio_guardrail.async_pre_call_hook( + user_api_key_dict=UserAPIKeyAuth(api_key="x"), + cache=DualCache(), + data=data, + call_type="completion", + ) + + assert isinstance(new_data["messages"][0]["content"], list) + assert new_data["messages"][0]["content"][0]["text"] == "email is "