Skip to content

Added functions read_resource and resource_templates from client/session.py to client/session_group.py #905

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 33 additions & 13 deletions src/mcp/client/session_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from typing import Any, TypeAlias

import anyio
from pydantic import BaseModel
from pydantic import AnyUrl, BaseModel
from typing_extensions import Self

import mcp
Expand Down Expand Up @@ -100,6 +100,7 @@ class _ComponentNames(BaseModel):
# Client-server connection management.
_sessions: dict[mcp.ClientSession, _ComponentNames]
_tool_to_session: dict[str, mcp.ClientSession]
_resource_to_session: dict[str, mcp.ClientSession]
_exit_stack: contextlib.AsyncExitStack
_session_exit_stacks: dict[mcp.ClientSession, contextlib.AsyncExitStack]

Expand All @@ -116,20 +117,16 @@ def __init__(
) -> None:
"""Initializes the MCP client."""

self._tools = {}
self._resources = {}
self._exit_stack = exit_stack or contextlib.AsyncExitStack()
self._owns_exit_stack = exit_stack is None
self._session_exit_stacks = {}
self._component_name_hook = component_name_hook
self._prompts = {}

self._resources = {}
self._tools = {}
self._sessions = {}
self._tool_to_session = {}
if exit_stack is None:
self._exit_stack = contextlib.AsyncExitStack()
self._owns_exit_stack = True
else:
self._exit_stack = exit_stack
self._owns_exit_stack = False
self._session_exit_stacks = {}
self._component_name_hook = component_name_hook
self._resource_to_session = {} # New mapping

async def __aenter__(self) -> Self:
# Enter the exit stack only if we created it ourselves
Expand Down Expand Up @@ -174,6 +171,16 @@ def tools(self) -> dict[str, types.Tool]:
"""Returns the tools as a dictionary of names to tools."""
return self._tools

@property
def resource_templates(self) -> list[types.ResourceTemplate]:
"""Return all unique resource templates from the resources."""
templates: list[types.ResourceTemplate] = []
for r in self._resources.values():
t = getattr(r, "template", None)
if t is not None and t not in templates:
templates.append(t)
return templates

async def call_tool(self, name: str, args: dict[str, Any]) -> types.CallToolResult:
"""Executes a tool given its name and arguments."""
session = self._tool_to_session[name]
Expand Down Expand Up @@ -296,8 +303,8 @@ async def _aggregate_components(
resources_temp: dict[str, types.Resource] = {}
tools_temp: dict[str, types.Tool] = {}
tool_to_session_temp: dict[str, mcp.ClientSession] = {}
resource_to_session_temp: dict[str, mcp.ClientSession] = {}

# Query the server for its prompts and aggregate to list.
try:
prompts = (await session.list_prompts()).prompts
for prompt in prompts:
Expand All @@ -314,6 +321,7 @@ async def _aggregate_components(
name = self._component_name(resource.name, server_info)
resources_temp[name] = resource
component_names.resources.add(name)
resource_to_session_temp[name] = session
except McpError as err:
logging.warning(f"Could not fetch resources: {err}")

Expand Down Expand Up @@ -365,8 +373,20 @@ async def _aggregate_components(
self._resources.update(resources_temp)
self._tools.update(tools_temp)
self._tool_to_session.update(tool_to_session_temp)
self._resource_to_session.update(resource_to_session_temp)

def _component_name(self, name: str, server_info: types.Implementation) -> str:
if self._component_name_hook:
return self._component_name_hook(name, server_info)
return name

async def read_resource(self, uri: AnyUrl) -> types.ReadResourceResult:
"""Read a resource from the appropriate session based on the URI."""
print(self._resources)
print(self._resource_to_session)
for name, resource in self._resources.items():
if resource.uri == uri:
session = self._resource_to_session.get(name)
if session:
return await session.read_resource(uri)
raise ValueError(f"Resource not found: {uri}")
81 changes: 81 additions & 0 deletions tests/client/test_session_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from unittest import mock

import pytest
from pydantic import AnyUrl

import mcp
from mcp import types
Expand Down Expand Up @@ -395,3 +396,83 @@ async def test_establish_session_parameterized(
# 3. Assert returned values
assert returned_server_info is mock_initialize_result.serverInfo
assert returned_session is mock_entered_session

@pytest.mark.anyio
async def test_read_resource_not_found(self):
"""Test reading a non-existent resource from a session group."""
# --- Mock Dependencies ---
mock_session = mock.AsyncMock(spec=mcp.ClientSession)
test_resource = types.Resource(
name="test_resource",
uri=AnyUrl("test://resource/1"),
description="Test resource",
)

# Mock all list methods
mock_session.list_resources.return_value = types.ListResourcesResult(
resources=[test_resource]
)
mock_session.list_prompts.return_value = types.ListPromptsResult(prompts=[])
mock_session.list_tools.return_value = types.ListToolsResult(tools=[])

# --- Test Setup ---
group = ClientSessionGroup()
group._session_exit_stacks[mock_session] = mock.AsyncMock(
spec=contextlib.AsyncExitStack
)
await group.connect_with_session(
types.Implementation(name="test_server", version="1.0.0"), mock_session
)

# --- Test Execution & Assertions ---
with pytest.raises(ValueError, match="Resource not found: test://nonexistent"):
await group.read_resource(AnyUrl("test://nonexistent"))

@pytest.mark.anyio
async def test_read_resource_success(self):
"""Test successfully reading a resource from a session group."""
# --- Mock Dependencies ---
mock_session = mock.AsyncMock(spec=mcp.ClientSession)
test_resource = types.Resource(
name="test_resource",
uri=AnyUrl("test://resource/1"),
description="Test resource",
)

# Mock all list methods
mock_session.list_resources.return_value = types.ListResourcesResult(
resources=[test_resource]
)
mock_session.list_prompts.return_value = types.ListPromptsResult(prompts=[])
mock_session.list_tools.return_value = types.ListToolsResult(tools=[])

# Mock the session's read_resource method
mock_read_result = mock.AsyncMock(spec=types.ReadResourceResult)
mock_read_result.contents = [
types.TextContent(type="text", text="Resource content")
]
mock_session.read_resource.return_value = mock_read_result

# --- Test Setup ---
group = ClientSessionGroup()
group._session_exit_stacks[mock_session] = mock.AsyncMock(
spec=contextlib.AsyncExitStack
)
await group.connect_with_session(
types.Implementation(name="test_server", version="1.0.0"), mock_session
)

# Verify resource was added
assert "test_resource" in group._resources
assert group._resources["test_resource"] == test_resource
assert "test_resource" in group._resource_to_session
assert group._resource_to_session["test_resource"] == mock_session

# --- Test Execution ---
result = await group.read_resource(AnyUrl("test://resource/1"))

# --- Assertions ---
assert result.contents == [
types.TextContent(type="text", text="Resource content")
]
mock_session.read_resource.assert_called_once_with(AnyUrl("test://resource/1"))
Loading