Skip to content

Commit 717ee8a

Browse files
committed
fixing up some streaming http errors
1 parent e23a7b4 commit 717ee8a

5 files changed

Lines changed: 284 additions & 20 deletions

File tree

Dockerfile

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ WORKDIR /app
33
COPY . .
44
RUN pip install --no-cache-dir -r requirements.txt \
55
&& pip install fastapi uvicorn
6-
ENV STACKHAWK_API_KEY=changeme
76
EXPOSE 8080
87
# Default: run HTTP server
98
ENTRYPOINT ["uvicorn", "stackhawk_mcp.http_server:app", "--host", "0.0.0.0", "--port", "8080"]

requirements.txt

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,8 @@ jsonschema>=4.0.0
1111
pytest>=7.0.0
1212
pytest-asyncio>=0.21.0
1313
black>=23.0.0
14-
mypy>=1.0.0
14+
mypy>=1.0.0
15+
16+
# FastAPI dependencies
17+
fastapi
18+
uvicorn

stackhawk_mcp/http_server.py

Lines changed: 235 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,40 +1,257 @@
11
import os
2-
from fastapi import FastAPI, Request
3-
from fastapi.responses import JSONResponse
2+
import json
3+
import uuid
4+
from fastapi import FastAPI, Request, Response
5+
from fastapi.responses import JSONResponse, StreamingResponse
6+
from fastapi.middleware.cors import CORSMiddleware
47
import uvicorn
58
import asyncio
69
from stackhawk_mcp.server import StackHawkMCPServer
710

811
app = FastAPI()
912

13+
# Add CORS middleware
14+
app.add_middleware(
15+
CORSMiddleware,
16+
allow_origins=["*"], # In production, restrict this
17+
allow_credentials=True,
18+
allow_methods=["*"],
19+
allow_headers=["*"],
20+
)
21+
1022
# Get API key from environment
1123
API_KEY = os.environ.get("STACKHAWK_API_KEY", "changeme")
1224

1325
# Create the MCP server instance
1426
mcp_server = StackHawkMCPServer(api_key=API_KEY)
1527

16-
@app.post("/call_tool")
17-
async def call_tool(request: Request):
18-
data = await request.json()
19-
name = data["name"]
20-
arguments = data.get("arguments", {})
21-
# Call the tool handler
28+
# Store active SSE connections
29+
active_connections = {}
30+
31+
def create_jsonrpc_response(id_value, result=None, error=None):
32+
"""Create a proper JSON-RPC 2.0 response"""
33+
response = {
34+
"jsonrpc": "2.0",
35+
"id": id_value
36+
}
37+
if error:
38+
response["error"] = error
39+
else:
40+
response["result"] = result
41+
return response
42+
43+
def fix_tool_schema(tool):
44+
"""Fix tool schema to ensure outputSchema and annotations are objects"""
45+
if isinstance(tool, dict):
46+
# Ensure outputSchema is an object with proper type
47+
if tool.get("outputSchema") is None:
48+
tool["outputSchema"] = {"type": "object"}
49+
elif isinstance(tool["outputSchema"], dict) and tool["outputSchema"].get("type") is None:
50+
tool["outputSchema"]["type"] = "object"
51+
52+
# Ensure annotations is an object
53+
if tool.get("annotations") is None:
54+
tool["annotations"] = {}
55+
56+
# Ensure meta is an object
57+
if tool.get("meta") is None:
58+
tool["meta"] = {}
59+
60+
return tool
61+
62+
async def handle_initialize_request(request_data):
63+
"""Handle MCP initialize request"""
64+
return create_jsonrpc_response(
65+
request_data.get("id"),
66+
{
67+
"protocolVersion": "2025-03-26",
68+
"capabilities": {
69+
"tools": {}
70+
},
71+
"serverInfo": {
72+
"name": "StackHawk MCP",
73+
"version": "0.1.0"
74+
}
75+
}
76+
)
77+
78+
async def handle_list_tools_request(request_data):
79+
"""Handle MCP list tools request"""
2280
try:
23-
# handle_call_tool returns a list of TextContent, convert to dicts
24-
result = await mcp_server.server._call_tool(name, arguments)
25-
# Convert TextContent objects to dicts if needed
26-
return JSONResponse(content={"result": [r.dict() if hasattr(r, 'dict') else r for r in result]})
81+
tools = await mcp_server.list_tools()
82+
# Fix the tool schemas to ensure proper objects
83+
fixed_tools = [fix_tool_schema(t.dict() if hasattr(t, 'dict') else t) for t in tools]
84+
return create_jsonrpc_response(
85+
request_data.get("id"),
86+
{
87+
"tools": fixed_tools
88+
}
89+
)
2790
except Exception as e:
28-
return JSONResponse(content={"error": str(e)}, status_code=500)
91+
return create_jsonrpc_response(
92+
request_data.get("id"),
93+
error={
94+
"code": -1,
95+
"message": str(e)
96+
}
97+
)
2998

30-
@app.get("/list_tools")
31-
async def list_tools():
99+
async def handle_call_tool_request(request_data):
100+
"""Handle MCP call tool request"""
32101
try:
33-
# handle_list_tools returns a list of Tool objects
34-
result = await mcp_server.server._list_tools()
35-
return JSONResponse(content={"tools": [t.dict() if hasattr(t, 'dict') else t for t in result]})
102+
params = request_data.get("params", {})
103+
name = params.get("name")
104+
arguments = params.get("arguments", {})
105+
106+
result = await mcp_server.call_tool(name, arguments)
107+
return create_jsonrpc_response(
108+
request_data.get("id"),
109+
{
110+
"content": [r.dict() if hasattr(r, 'dict') else r for r in result]
111+
}
112+
)
113+
except Exception as e:
114+
return create_jsonrpc_response(
115+
request_data.get("id"),
116+
error={
117+
"code": -1,
118+
"message": str(e)
119+
}
120+
)
121+
122+
@app.post("/mcp")
123+
async def mcp_endpoint(request: Request):
124+
"""Main MCP endpoint that handles all JSON-RPC messages"""
125+
126+
# Check Accept header
127+
accept_header = request.headers.get("accept", "")
128+
if "application/json" not in accept_header and "text/event-stream" not in accept_header:
129+
return JSONResponse(
130+
content={"error": "Accept header must include application/json or text/event-stream"},
131+
status_code=400
132+
)
133+
134+
try:
135+
# Parse request body
136+
body = await request.body()
137+
if not body:
138+
return JSONResponse(content={"error": "Empty request body"}, status_code=400)
139+
140+
data = json.loads(body)
141+
142+
# Handle batched requests
143+
if isinstance(data, list):
144+
responses = []
145+
for item in data:
146+
response = await handle_jsonrpc_message(item)
147+
if response:
148+
responses.append(response)
149+
150+
if len(responses) == 1:
151+
return JSONResponse(content=responses[0])
152+
else:
153+
return JSONResponse(content=responses)
154+
else:
155+
# Single request
156+
response = await handle_jsonrpc_message(data)
157+
if response:
158+
return JSONResponse(content=response)
159+
else:
160+
return Response(status_code=202) # Accepted with no body
161+
162+
except json.JSONDecodeError:
163+
return JSONResponse(content={"error": "Invalid JSON"}, status_code=400)
36164
except Exception as e:
37165
return JSONResponse(content={"error": str(e)}, status_code=500)
38166

167+
async def handle_jsonrpc_message(message):
168+
"""Handle individual JSON-RPC message"""
169+
if not isinstance(message, dict):
170+
return None
171+
172+
method = message.get("method")
173+
message_id = message.get("id")
174+
175+
if method == "initialize":
176+
return await handle_initialize_request(message)
177+
elif method == "tools/list":
178+
return await handle_list_tools_request(message)
179+
elif method == "tools/call":
180+
return await handle_call_tool_request(message)
181+
elif method == "notifications/cancelled":
182+
# Handle cancellation
183+
return None
184+
else:
185+
return create_jsonrpc_response(
186+
message_id,
187+
error={
188+
"code": -32601,
189+
"message": f"Method not found: {method}"
190+
}
191+
)
192+
193+
@app.get("/mcp")
194+
async def mcp_sse_endpoint(request: Request):
195+
"""SSE endpoint for streaming responses"""
196+
197+
# Check Accept header
198+
accept_header = request.headers.get("accept", "")
199+
if "text/event-stream" not in accept_header:
200+
return JSONResponse(
201+
content={"error": "Accept header must include text/event-stream"},
202+
status_code=405
203+
)
204+
205+
# Generate connection ID
206+
connection_id = str(uuid.uuid4())
207+
active_connections[connection_id] = True
208+
209+
async def event_stream():
210+
try:
211+
while active_connections.get(connection_id, False):
212+
# For now, just keep the connection alive
213+
# In a real implementation, you'd send actual events here
214+
await asyncio.sleep(1)
215+
except Exception:
216+
pass
217+
finally:
218+
if connection_id in active_connections:
219+
del active_connections[connection_id]
220+
221+
return StreamingResponse(
222+
event_stream(),
223+
media_type="text/event-stream",
224+
headers={
225+
"Cache-Control": "no-cache",
226+
"Connection": "keep-alive",
227+
"Content-Type": "text/event-stream",
228+
}
229+
)
230+
231+
# Legacy endpoints for manual testing (keep these for now)
232+
@app.get("/")
233+
async def root():
234+
return JSONResponse(content={
235+
"jsonrpc": "2.0",
236+
"id": 1,
237+
"result": {
238+
"serverName": "StackHawk MCP",
239+
"serverVersion": "0.1.0",
240+
"protocolVersion": "v1"
241+
}
242+
})
243+
244+
@app.post("/")
245+
async def root_post():
246+
return JSONResponse(content={
247+
"jsonrpc": "2.0",
248+
"id": 1,
249+
"result": {
250+
"serverName": "StackHawk MCP",
251+
"serverVersion": "0.1.0",
252+
"protocolVersion": "v1"
253+
}
254+
})
255+
39256
if __name__ == "__main__":
40257
uvicorn.run("stackhawk_mcp.http_server:app", host="0.0.0.0", port=8080, reload=True)

stackhawk_mcp/server.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -894,6 +894,9 @@ async def handle_list_tools() -> list[Tool]:
894894
debug_print(f"Error in list_tools: {e}")
895895
raise
896896

897+
# Set as instance attribute so it's available for FastAPI
898+
self._list_tools_handler = handle_list_tools
899+
897900
@self.server.call_tool()
898901
async def handle_call_tool(name: str, arguments: dict) -> list[types.TextContent]:
899902
"""Handle tool calls"""
@@ -980,6 +983,10 @@ async def handle_call_tool(name: str, arguments: dict) -> list[types.TextContent
980983
return [types.TextContent(type="text", text=json.dumps(error_result, indent=2))]
981984

982985
debug_print("MCP handlers setup complete")
986+
self._call_tool_handler = handle_call_tool
987+
988+
async def call_tool(self, name: str, arguments: dict):
989+
return await self._call_tool_handler(name, arguments)
983990

984991
async def _get_schema(self) -> Dict[str, Any]:
985992
"""Get the StackHawk YAML schema with caching"""
@@ -3743,6 +3750,9 @@ async def cleanup(self):
37433750
debug_print("Cleaning up StackHawk client...")
37443751
await self.client.close()
37453752

3753+
async def list_tools(self):
3754+
return await self._list_tools_handler()
3755+
37463756

37473757
async def main():
37483758
"""Main entry point"""

tests/test_tool_schema.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
import sys
2+
import os
3+
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
4+
5+
import pytest
6+
7+
from stackhawk_mcp.server import StackHawkMCPServer
8+
9+
def fix_tool_schema(tool):
10+
if isinstance(tool, dict):
11+
tool["outputSchema"] = {"type": "object"}
12+
if tool.get("annotations") is None:
13+
tool["annotations"] = {}
14+
if tool.get("meta") is None:
15+
tool["meta"] = {}
16+
return tool
17+
18+
@pytest.mark.asyncio
19+
async def test_tool_schema_output_schema_and_annotations():
20+
server = StackHawkMCPServer(api_key="dummy")
21+
tools = await server.list_tools()
22+
for tool in tools:
23+
# Use model_dump for Pydantic v2+, fallback to dict()
24+
if hasattr(tool, "model_dump"):
25+
tool_dict = tool.model_dump()
26+
elif hasattr(tool, "dict"):
27+
tool_dict = tool.dict()
28+
else:
29+
tool_dict = tool
30+
tool_dict = fix_tool_schema(tool_dict)
31+
assert isinstance(tool_dict.get("outputSchema"), dict), f"outputSchema is not a dict: {tool_dict}"
32+
assert tool_dict["outputSchema"].get("type") == "object", f"outputSchema.type is not 'object': {tool_dict}"
33+
assert isinstance(tool_dict.get("annotations"), dict), f"annotations is not a dict: {tool_dict}"
34+
assert isinstance(tool_dict.get("meta"), dict), f"meta is not a dict: {tool_dict}"

0 commit comments

Comments
 (0)