diff --git a/cognite/client/data_classes/agents/agent_tools.py b/cognite/client/data_classes/agents/agent_tools.py index d19d95eff3..5d4671c185 100644 --- a/cognite/client/data_classes/agents/agent_tools.py +++ b/cognite/client/data_classes/agents/agent_tools.py @@ -14,6 +14,9 @@ WriteableCogniteResourceList, ) +# Constants +DEFAULT_QKG_VERSION = "v2" + @dataclass class AgentToolCore(WriteableCogniteResource["AgentToolUpsert"], ABC): @@ -166,10 +169,13 @@ class QueryKnowledgeGraphAgentToolConfiguration(WriteableCogniteResource): Args: data_models (Sequence[DataModelInfo] | None): The data models and views to query. instance_spaces (InstanceSpaces | None): The instance spaces to query. + version (Literal["v1", "v2"]): The version of the QKG tool to use. + Defaults to DEFAULT_QKG_VERSION ("v2"). """ data_models: Sequence[DataModelInfo] | None = None instance_spaces: InstanceSpaces | None = None + version: Literal["v1", "v2"] = DEFAULT_QKG_VERSION @classmethod def _load( @@ -183,9 +189,12 @@ def _load( if "instanceSpaces" in resource: instance_spaces = InstanceSpaces._load(resource["instanceSpaces"]) + version = resource.get("version", DEFAULT_QKG_VERSION) + return cls( data_models=data_models, instance_spaces=instance_spaces, + version=version, ) def dump(self, camel_case: bool = True) -> dict[str, Any]: @@ -196,6 +205,7 @@ def dump(self, camel_case: bool = True) -> dict[str, Any]: if self.instance_spaces: key = "instanceSpaces" if camel_case else "instance_spaces" result[key] = self.instance_spaces.dump(camel_case=camel_case) + result["version"] = self.version return result def as_write(self) -> QueryKnowledgeGraphAgentToolConfiguration: diff --git a/tests/tests_unit/test_data_classes/test_agents/test_agent_tools.py b/tests/tests_unit/test_data_classes/test_agents/test_agent_tools.py index 73297f732a..2665c79f7b 100644 --- a/tests/tests_unit/test_data_classes/test_agents/test_agent_tools.py +++ b/tests/tests_unit/test_data_classes/test_agents/test_agent_tools.py @@ -3,6 +3,7 @@ import pytest from cognite.client.data_classes.agents.agent_tools import ( + DEFAULT_QKG_VERSION, AgentTool, AskDocumentAgentTool, QueryKnowledgeGraphAgentTool, @@ -49,6 +50,45 @@ "configuration": {"key": "value"}, } +# Test QKG examples with different versions +qkg_example_with_v2 = { + "name": "qkgExampleWithV2", + "type": "queryKnowledgeGraph", + "description": "Query the knowledge graph with v2", + "configuration": { + "dataModels": [ + {"space": "cdf_cdm", "externalId": "CogniteCore", "version": "v1", "viewExternalIds": ["CogniteAsset"]} + ], + "instanceSpaces": {"type": "manual", "spaces": ["my_space"]}, + "version": DEFAULT_QKG_VERSION, + }, +} + +qkg_example_v1 = { + "name": "qkgExampleV1", + "type": "queryKnowledgeGraph", + "description": "Query the knowledge graph with v1", + "configuration": { + "dataModels": [ + {"space": "cdf_cdm", "externalId": "CogniteCore", "version": "v1", "viewExternalIds": ["CogniteAsset"]} + ], + "instanceSpaces": {"type": "manual", "spaces": ["my_space"]}, + "version": "v1", + }, +} + +qkg_example_no_version = { + "name": "qkgExampleNoVersion", + "type": "queryKnowledgeGraph", + "description": "Query the knowledge graph without version specified", + "configuration": { + "dataModels": [ + {"space": "cdf_cdm", "externalId": "CogniteCore", "version": "v1", "viewExternalIds": ["CogniteAsset"]} + ], + "instanceSpaces": {"type": "manual", "spaces": ["my_space"]}, + }, +} + class TestAgentToolLoad: @pytest.mark.parametrize( @@ -81,7 +121,11 @@ def test_agent_tool_load_returns_correct_subtype(self, tool_data: dict, expected # For QueryKnowledgeGraph, we expect a structured configuration object assert isinstance(loaded_tool.configuration, QueryKnowledgeGraphAgentToolConfiguration) # Compare by serializing the structured object back to dict - assert loaded_tool.configuration.dump(camel_case=True) == tool_data["configuration"] + # Version field is added automatically if not present, so we need to account for it + expected_config = tool_data["configuration"].copy() + if "version" not in expected_config: + expected_config["version"] = DEFAULT_QKG_VERSION # Default version + assert loaded_tool.configuration.dump(camel_case=True) == expected_config else: # For other tools (like UnknownAgentTool), configuration should be a dict assert loaded_tool.configuration == tool_data["configuration"] @@ -129,7 +173,7 @@ def test_agent_tool_dump_returns_correct_type_for_unknown_tool(self) -> None: assert dumped_tool["description"] == unknown_example["description"] assert dumped_tool["configuration"] == unknown_example["configuration"] - def test_agent_tool_dump_returns_correct_type_for_query_knowledge_graph_tool(self) -> None: + def test_agent_tool_dump_returns_correct_type_for_qkg_tool(self) -> None: """Test that AgentTool.dump() returns the correct type for query knowledge graph tools.""" loaded_tool = AgentTool._load(qkg_example) dumped_tool = loaded_tool.dump(camel_case=True) @@ -137,7 +181,11 @@ def test_agent_tool_dump_returns_correct_type_for_query_knowledge_graph_tool(sel assert dumped_tool["type"] == "queryKnowledgeGraph" assert dumped_tool["name"] == qkg_example["name"] assert dumped_tool["description"] == qkg_example["description"] - assert dumped_tool["configuration"] == qkg_example["configuration"] + + # Check configuration components individually since version is now added automatically + expected_config = qkg_example["configuration"].copy() + expected_config["version"] = DEFAULT_QKG_VERSION # Default version is added during load/dump + assert dumped_tool["configuration"] == expected_config class TestAgentToolUpsert: @@ -163,3 +211,55 @@ def test_agent_tool_upsert_returns_correct_type(self, tool_data: dict, expected_ assert dumped_tool["name"] == tool_data["name"] assert dumped_tool["description"] == tool_data["description"] + + +class TestQueryKnowledgeGraphAgentToolVersions: + """Test QKG tool version functionality.""" + + def test_qkg_tool_with_explicit_v2_version(self) -> None: + """Test QKG tool with explicit v2 version.""" + loaded_tool = AgentTool._load(qkg_example_with_v2) + + assert isinstance(loaded_tool, QueryKnowledgeGraphAgentTool) + assert loaded_tool.configuration is not None + assert loaded_tool.configuration.version == DEFAULT_QKG_VERSION + + # Test that it dumps correctly + dumped_tool = loaded_tool.dump(camel_case=True) + assert dumped_tool["configuration"]["version"] == DEFAULT_QKG_VERSION + + def test_qkg_tool_with_explicit_v1_version(self) -> None: + """Test QKG tool with explicit v1 version.""" + loaded_tool = AgentTool._load(qkg_example_v1) + + assert isinstance(loaded_tool, QueryKnowledgeGraphAgentTool) + assert loaded_tool.configuration is not None + assert loaded_tool.configuration.version == "v1" + + # Test that it dumps correctly + dumped_tool = loaded_tool.dump(camel_case=True) + assert dumped_tool["configuration"]["version"] == "v1" + + def test_qkg_tool_defaults_to_v2_when_no_version_specified(self) -> None: + """Test QKG tool defaults to v2 when no version is specified.""" + loaded_tool = AgentTool._load(qkg_example_no_version) + + assert isinstance(loaded_tool, QueryKnowledgeGraphAgentTool) + assert loaded_tool.configuration is not None + assert loaded_tool.configuration.version == DEFAULT_QKG_VERSION + + # Test that it dumps correctly with default version + dumped_tool = loaded_tool.dump(camel_case=True) + assert dumped_tool["configuration"]["version"] == DEFAULT_QKG_VERSION + + def test_qkg_tool_upsert_preserves_version(self) -> None: + """Test that QKG tool upsert preserves version information.""" + loaded_tool = AgentTool._load(qkg_example_v1) + upsert_tool = loaded_tool.as_write() + + assert upsert_tool.configuration is not None + assert upsert_tool.configuration.version == "v1" + + # Test that upsert dumps correctly + dumped_tool = upsert_tool.dump(camel_case=True) + assert dumped_tool["configuration"]["version"] == "v1"