11import os
2- from fastapi import FastAPI , Request
3- from fastapi .responses import JSONResponse
2+ import json
3+ import uuid
4+ from fastapi import FastAPI , Request , Response
5+ from fastapi .responses import JSONResponse , StreamingResponse
6+ from fastapi .middleware .cors import CORSMiddleware
47import uvicorn
58import asyncio
69from stackhawk_mcp .server import StackHawkMCPServer
710
811app = FastAPI ()
912
13+ # Add CORS middleware
14+ app .add_middleware (
15+ CORSMiddleware ,
16+ allow_origins = ["*" ], # In production, restrict this
17+ allow_credentials = True ,
18+ allow_methods = ["*" ],
19+ allow_headers = ["*" ],
20+ )
21+
1022# Get API key from environment
1123API_KEY = os .environ .get ("STACKHAWK_API_KEY" , "changeme" )
1224
1325# Create the MCP server instance
1426mcp_server = StackHawkMCPServer (api_key = API_KEY )
1527
16- @app .post ("/call_tool" )
17- async def call_tool (request : Request ):
18- data = await request .json ()
19- name = data ["name" ]
20- arguments = data .get ("arguments" , {})
21- # Call the tool handler
28+ # Store active SSE connections
29+ active_connections = {}
30+
31+ def create_jsonrpc_response (id_value , result = None , error = None ):
32+ """Create a proper JSON-RPC 2.0 response"""
33+ response = {
34+ "jsonrpc" : "2.0" ,
35+ "id" : id_value
36+ }
37+ if error :
38+ response ["error" ] = error
39+ else :
40+ response ["result" ] = result
41+ return response
42+
43+ def fix_tool_schema (tool ):
44+ """Fix tool schema to ensure outputSchema and annotations are objects"""
45+ if isinstance (tool , dict ):
46+ # Ensure outputSchema is an object with proper type
47+ if tool .get ("outputSchema" ) is None :
48+ tool ["outputSchema" ] = {"type" : "object" }
49+ elif isinstance (tool ["outputSchema" ], dict ) and tool ["outputSchema" ].get ("type" ) is None :
50+ tool ["outputSchema" ]["type" ] = "object"
51+
52+ # Ensure annotations is an object
53+ if tool .get ("annotations" ) is None :
54+ tool ["annotations" ] = {}
55+
56+ # Ensure meta is an object
57+ if tool .get ("meta" ) is None :
58+ tool ["meta" ] = {}
59+
60+ return tool
61+
62+ async def handle_initialize_request (request_data ):
63+ """Handle MCP initialize request"""
64+ return create_jsonrpc_response (
65+ request_data .get ("id" ),
66+ {
67+ "protocolVersion" : "2025-03-26" ,
68+ "capabilities" : {
69+ "tools" : {}
70+ },
71+ "serverInfo" : {
72+ "name" : "StackHawk MCP" ,
73+ "version" : "0.1.0"
74+ }
75+ }
76+ )
77+
78+ async def handle_list_tools_request (request_data ):
79+ """Handle MCP list tools request"""
2280 try :
23- # handle_call_tool returns a list of TextContent, convert to dicts
24- result = await mcp_server .server ._call_tool (name , arguments )
25- # Convert TextContent objects to dicts if needed
26- return JSONResponse (content = {"result" : [r .dict () if hasattr (r , 'dict' ) else r for r in result ]})
81+ tools = await mcp_server .list_tools ()
82+ # Fix the tool schemas to ensure proper objects
83+ fixed_tools = [fix_tool_schema (t .dict () if hasattr (t , 'dict' ) else t ) for t in tools ]
84+ return create_jsonrpc_response (
85+ request_data .get ("id" ),
86+ {
87+ "tools" : fixed_tools
88+ }
89+ )
2790 except Exception as e :
28- return JSONResponse (content = {"error" : str (e )}, status_code = 500 )
91+ return create_jsonrpc_response (
92+ request_data .get ("id" ),
93+ error = {
94+ "code" : - 1 ,
95+ "message" : str (e )
96+ }
97+ )
2998
30- @ app . get ( "/list_tools" )
31- async def list_tools ():
99+ async def handle_call_tool_request ( request_data ):
100+ """Handle MCP call tool request"""
32101 try :
33- # handle_list_tools returns a list of Tool objects
34- result = await mcp_server .server ._list_tools ()
35- return JSONResponse (content = {"tools" : [t .dict () if hasattr (t , 'dict' ) else t for t in result ]})
102+ params = request_data .get ("params" , {})
103+ name = params .get ("name" )
104+ arguments = params .get ("arguments" , {})
105+
106+ result = await mcp_server .call_tool (name , arguments )
107+ return create_jsonrpc_response (
108+ request_data .get ("id" ),
109+ {
110+ "content" : [r .dict () if hasattr (r , 'dict' ) else r for r in result ]
111+ }
112+ )
113+ except Exception as e :
114+ return create_jsonrpc_response (
115+ request_data .get ("id" ),
116+ error = {
117+ "code" : - 1 ,
118+ "message" : str (e )
119+ }
120+ )
121+
122+ @app .post ("/mcp" )
123+ async def mcp_endpoint (request : Request ):
124+ """Main MCP endpoint that handles all JSON-RPC messages"""
125+
126+ # Check Accept header
127+ accept_header = request .headers .get ("accept" , "" )
128+ if "application/json" not in accept_header and "text/event-stream" not in accept_header :
129+ return JSONResponse (
130+ content = {"error" : "Accept header must include application/json or text/event-stream" },
131+ status_code = 400
132+ )
133+
134+ try :
135+ # Parse request body
136+ body = await request .body ()
137+ if not body :
138+ return JSONResponse (content = {"error" : "Empty request body" }, status_code = 400 )
139+
140+ data = json .loads (body )
141+
142+ # Handle batched requests
143+ if isinstance (data , list ):
144+ responses = []
145+ for item in data :
146+ response = await handle_jsonrpc_message (item )
147+ if response :
148+ responses .append (response )
149+
150+ if len (responses ) == 1 :
151+ return JSONResponse (content = responses [0 ])
152+ else :
153+ return JSONResponse (content = responses )
154+ else :
155+ # Single request
156+ response = await handle_jsonrpc_message (data )
157+ if response :
158+ return JSONResponse (content = response )
159+ else :
160+ return Response (status_code = 202 ) # Accepted with no body
161+
162+ except json .JSONDecodeError :
163+ return JSONResponse (content = {"error" : "Invalid JSON" }, status_code = 400 )
36164 except Exception as e :
37165 return JSONResponse (content = {"error" : str (e )}, status_code = 500 )
38166
167+ async def handle_jsonrpc_message (message ):
168+ """Handle individual JSON-RPC message"""
169+ if not isinstance (message , dict ):
170+ return None
171+
172+ method = message .get ("method" )
173+ message_id = message .get ("id" )
174+
175+ if method == "initialize" :
176+ return await handle_initialize_request (message )
177+ elif method == "tools/list" :
178+ return await handle_list_tools_request (message )
179+ elif method == "tools/call" :
180+ return await handle_call_tool_request (message )
181+ elif method == "notifications/cancelled" :
182+ # Handle cancellation
183+ return None
184+ else :
185+ return create_jsonrpc_response (
186+ message_id ,
187+ error = {
188+ "code" : - 32601 ,
189+ "message" : f"Method not found: { method } "
190+ }
191+ )
192+
193+ @app .get ("/mcp" )
194+ async def mcp_sse_endpoint (request : Request ):
195+ """SSE endpoint for streaming responses"""
196+
197+ # Check Accept header
198+ accept_header = request .headers .get ("accept" , "" )
199+ if "text/event-stream" not in accept_header :
200+ return JSONResponse (
201+ content = {"error" : "Accept header must include text/event-stream" },
202+ status_code = 405
203+ )
204+
205+ # Generate connection ID
206+ connection_id = str (uuid .uuid4 ())
207+ active_connections [connection_id ] = True
208+
209+ async def event_stream ():
210+ try :
211+ while active_connections .get (connection_id , False ):
212+ # For now, just keep the connection alive
213+ # In a real implementation, you'd send actual events here
214+ await asyncio .sleep (1 )
215+ except Exception :
216+ pass
217+ finally :
218+ if connection_id in active_connections :
219+ del active_connections [connection_id ]
220+
221+ return StreamingResponse (
222+ event_stream (),
223+ media_type = "text/event-stream" ,
224+ headers = {
225+ "Cache-Control" : "no-cache" ,
226+ "Connection" : "keep-alive" ,
227+ "Content-Type" : "text/event-stream" ,
228+ }
229+ )
230+
231+ # Legacy endpoints for manual testing (keep these for now)
232+ @app .get ("/" )
233+ async def root ():
234+ return JSONResponse (content = {
235+ "jsonrpc" : "2.0" ,
236+ "id" : 1 ,
237+ "result" : {
238+ "serverName" : "StackHawk MCP" ,
239+ "serverVersion" : "0.1.0" ,
240+ "protocolVersion" : "v1"
241+ }
242+ })
243+
244+ @app .post ("/" )
245+ async def root_post ():
246+ return JSONResponse (content = {
247+ "jsonrpc" : "2.0" ,
248+ "id" : 1 ,
249+ "result" : {
250+ "serverName" : "StackHawk MCP" ,
251+ "serverVersion" : "0.1.0" ,
252+ "protocolVersion" : "v1"
253+ }
254+ })
255+
39256if __name__ == "__main__" :
40257 uvicorn .run ("stackhawk_mcp.http_server:app" , host = "0.0.0.0" , port = 8080 , reload = True )
0 commit comments