Skip to content

Commit a173232

Browse files
authored
Fix MLFlow tags - split request_tags into (key, val) if request_tag has colon (#15914)
* Fix mlflow tags - split request_tags into (key, val) if request_tag has colon * Redundant name: tag_dict -> tags
1 parent 8b14241 commit a173232

File tree

2 files changed

+30
-6
lines changed

2 files changed

+30
-6
lines changed

litellm/integrations/mlflow.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,10 @@ def _extract_and_set_chat_attributes(self, span, kwargs, response_obj):
6060

6161
inputs = self._construct_input(kwargs)
6262
input_messages = inputs.get("messages", [])
63-
output_messages = [c.message.model_dump(exclude_none=True) for c in getattr(response_obj, "choices", [])]
63+
output_messages = [
64+
c.message.model_dump(exclude_none=True)
65+
for c in getattr(response_obj, "choices", [])
66+
]
6467
if messages := [*input_messages, *output_messages]:
6568
set_span_chat_messages(span, messages)
6669
if tools := inputs.get("tools"):
@@ -184,7 +187,9 @@ def _extract_attributes(self, kwargs):
184187
"call_type": kwargs.get("call_type"),
185188
"model": kwargs.get("model"),
186189
}
187-
standard_obj: Optional[StandardLoggingPayload] = kwargs.get("standard_logging_object")
190+
standard_obj: Optional[StandardLoggingPayload] = kwargs.get(
191+
"standard_logging_object"
192+
)
188193
if standard_obj:
189194
attributes.update(
190195
{
@@ -257,12 +262,25 @@ def _start_span_or_trace(self, kwargs, start_time):
257262
span_type=span_type,
258263
inputs=inputs,
259264
attributes=attributes,
260-
tags=self._transform_tag_list_to_dict(attributes.get("request_tags", [])),
265+
tags=self._transform_tag_list_to_dict(
266+
attributes.get("request_tags", [])
267+
),
261268
start_time_ns=start_time_ns,
262269
)
263270

264271
def _transform_tag_list_to_dict(self, tag_list: list) -> dict:
265-
return {tag: "" for tag in tag_list}
272+
"""
273+
Transform a list of colon-separated tags into a dictionary.
274+
Tags without colons are stored with empty string as the value.
275+
"""
276+
tags = {}
277+
for tag in tag_list:
278+
if ":" in tag:
279+
k, v = tag.split(":", 1)
280+
tags[k.strip()] = v.strip()
281+
else:
282+
tags[tag.strip()] = ""
283+
return tags
266284

267285
def _end_span_or_trace(self, span, outputs, end_time_ns, status):
268286
"""End an MLflow span or a trace."""

tests/test_litellm/integrations/test_mlflow.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ async def test_mlflow_logging_functionality():
5757
messages=[{"role": "user", "content": "test message"}],
5858
prediction=test_prediction,
5959
mock_response="test response",
60-
metadata={"tags": ["tag1", "tag2", "production"]},
60+
metadata={"tags": ["tag1", "tag2", "production", "jobID:214590dsff09fds", "taskName:run_page_classification"]},
6161
)
6262

6363
# Allow time for async processing
@@ -72,7 +72,13 @@ async def test_mlflow_logging_functionality():
7272

7373
# Check that tags parameter was included and properly transformed
7474
tags_param = call_args.kwargs.get("tags", {})
75-
expected_tags = {"tag1": "", "tag2": "", "production": ""}
75+
expected_tags = {
76+
"tag1": "",
77+
"tag2": "",
78+
"production": "",
79+
"jobID": "214590dsff09fds",
80+
"taskName": "run_page_classification",
81+
}
7682
assert tags_param == expected_tags, f"Expected tags {expected_tags}, got {tags_param}"
7783

7884
# Check that prediction parameter was included in inputs

0 commit comments

Comments
 (0)