diff --git a/lib/ruby_llm/mcp/native/transports/streamable_http.rb b/lib/ruby_llm/mcp/native/transports/streamable_http.rb index 6f4c39d..d864090 100644 --- a/lib/ruby_llm/mcp/native/transports/streamable_http.rb +++ b/lib/ruby_llm/mcp/native/transports/streamable_http.rb @@ -119,10 +119,9 @@ def request(body, wait_for_response: true) # Extract the request ID from the body (if present) request_id = body.is_a?(Hash) ? (body["id"] || body[:id]) : nil - is_initialization = body.is_a?(Hash) && (body["method"] == "initialize" || body[:method] == :initialize) response_queue = setup_response_queue(request_id, wait_for_response) - result = send_http_request(body, request_id, is_initialization: is_initialization) + result = send_http_request(body, request_id) return result if result.is_a?(RubyLLM::MCP::Result) if wait_for_response && request_id @@ -307,7 +306,7 @@ def setup_response_queue(request_id, wait_for_response) response_queue end - def send_http_request(body, request_id, is_initialization: false) + def send_http_request(body, request_id) headers = build_common_headers headers["Content-Type"] = "application/json" headers["Accept"] = "application/json, text/event-stream" @@ -315,33 +314,66 @@ def send_http_request(body, request_id, is_initialization: false) json_body = JSON.generate(body) RubyLLM::MCP.logger.debug "Sending Request: #{json_body}" - request_client = nil + use_background = request_id && @pending_mutex.synchronize { @pending_requests.key?(request_id.to_s) } + + if use_background + send_request_in_background(body, request_id, headers) + else + send_request_synchronously(body, request_id, headers) + end + end + + def send_request_in_background(body, request_id, headers) + request_client = create_connection_with_streaming_callbacks(request_id, close_when_fulfilled: true) + Thread.new do + response = request_client.post(@url, json: body, headers: headers) + handle_response(response, request_id, body) + rescue Errors::BaseError => e + @pending_mutex.synchronize do + queue = @pending_requests.delete(request_id.to_s) + queue&.push(e) + end + rescue StandardError => e + RubyLLM::MCP.logger.error "Background request error: #{e.message}" + @pending_mutex.synchronize do + queue = @pending_requests.delete(request_id.to_s) + queue&.push(Errors::TransportError.new(message: e.message, code: nil)) + end + ensure + close_client(request_client) + end + nil + end + + def send_request_synchronously(body, request_id, headers) + request_client = create_connection_with_streaming_callbacks(request_id, close_when_fulfilled: false) begin - connection = if is_initialization - @connection - else - request_client = create_connection_with_streaming_callbacks(request_id) - request_client - end - - response = connection.post(@url, json: body, headers: headers) + response = request_client.post(@url, json: body, headers: headers) handle_response(response, request_id, body) ensure - @pending_mutex.synchronize { @pending_requests.delete(request_id.to_s) } if request_id - close_client(request_client) if request_client && !is_initialization + close_client(request_client) end end - def create_connection_with_streaming_callbacks(request_id) + def create_connection_with_streaming_callbacks(request_id, close_when_fulfilled: false) buffer = +"" client = Support::HTTPClient.connection.plugin(:callbacks) - client = client.on_response_body_chunk do |request, _response, chunk| + client = client.on_response_body_chunk do |request, response, chunk| next unless running? + if (session_id = response.headers["mcp-session-id"]) && !@session_id + @session_id = session_id + end + RubyLLM::MCP.logger.debug "Received chunk: #{chunk.bytesize} bytes for #{request.uri}" buffer << chunk process_sse_buffer_events(buffer, request_id&.to_s) + + if close_when_fulfilled && request_id + fulfilled = @pending_mutex.synchronize { !@pending_requests.key?(request_id.to_s) } + request.close if fulfilled + end end client = client.with( timeout: { @@ -402,7 +434,10 @@ def handle_success_response(response, request_id, _original_message) result = RubyLLM::MCP::Result.new(json_response, session_id: @session_id) if request_id - @pending_mutex.synchronize { @pending_requests.delete(request_id.to_s) } + @pending_mutex.synchronize do + queue = @pending_requests.delete(request_id.to_s) + queue&.push(result) + end end result @@ -571,7 +606,7 @@ def attempt_authentication_retry(www_authenticate, resource_metadata_url, reques if success RubyLLM::MCP.logger.info("Authentication challenge handled successfully, retrying request") - result = send_http_request(original_message, request_id, is_initialization: false) + result = send_http_request(original_message, request_id) @auth_retry_attempted = false return result end @@ -894,21 +929,30 @@ def parse_and_validate_http_response(response_body) end def wait_for_response_with_timeout(request_id, response_queue) - result = with_timeout(@request_timeout / 1000, request_id: request_id) do - response_queue.pop - end + timeout_seconds = @request_timeout / 1000.0 + deadline = Process.clock_gettime(Process::CLOCK_MONOTONIC) + timeout_seconds - # Check if we received a shutdown error sentinel - if result.is_a?(Errors::TransportError) - raise result - end + loop do + result = response_queue.pop(true) + raise result if result.is_a?(Exception) - result - rescue RubyLLM::MCP::Errors::TimeoutError => e - log_message = "StreamableHTTP request timeout (ID: #{request_id}) after #{@request_timeout / 1000} seconds" - RubyLLM::MCP.logger.error(log_message) + return result + rescue ThreadError + if Process.clock_gettime(Process::CLOCK_MONOTONIC) >= deadline + log_message = + "StreamableHTTP request timeout (ID: #{request_id}) after #{@request_timeout / 1000} seconds" + RubyLLM::MCP.logger.error(log_message) + @pending_mutex.synchronize { @pending_requests.delete(request_id.to_s) } + raise Errors::TimeoutError.new( + message: "Request timed out after #{@request_timeout / 1000} seconds", + request_id: request_id + ) + end + + sleep(0.05) + end + ensure @pending_mutex.synchronize { @pending_requests.delete(request_id.to_s) } - raise e end def cleanup_sse_resources diff --git a/spec/ruby_llm/mcp/native/transports/streamable_http_spec.rb b/spec/ruby_llm/mcp/native/transports/streamable_http_spec.rb index 6383b66..8970f0a 100644 --- a/spec/ruby_llm/mcp/native/transports/streamable_http_spec.rb +++ b/spec/ruby_llm/mcp/native/transports/streamable_http_spec.rb @@ -2171,6 +2171,191 @@ end end + describe "background thread routing for wait_for_response: true requests" do + before do + WebMock.enable! + end + + after do + WebMock.reset! + WebMock.enable! + end + + context "when queue is registered (wait_for_response: true)" do + it "delivers successful 200 JSON response via background thread queue" do + stub_request(:post, TestServerManager::HTTP_SERVER_URL) + .to_return( + status: 200, + headers: { "Content-Type" => "application/json" }, + body: { "jsonrpc" => "2.0", "id" => 1, "result" => { "tools" => [] } }.to_json + ) + + result = transport.request({ "method" => "tools/list", "id" => 1 }) + + expect(result).to be_a(RubyLLM::MCP::Result) + expect(result.result["tools"]).to eq([]) + end + + it "preserves AuthenticationRequiredError type from background thread without wrapping" do + stub_request(:post, TestServerManager::HTTP_SERVER_URL) + .to_return(status: 401) + + expect do + transport.request({ "method" => "tools/list", "id" => 2 }) + end.to raise_error( + RubyLLM::MCP::Errors::AuthenticationRequiredError, + /no OAuth provider configured/ + ) + end + end + end + + describe "inline SSE stream handling for POST responses" do + let(:request_id) { 1 } + let(:response_queue) { Queue.new } + let(:http_client) { RubyLLM::MCP::Native::Transports::Support::HTTPClient } + let(:fake_client_class) do + Class.new do + attr_reader :callback + + def plugin(_) + self + end + + def on_response_body_chunk(&block) + @callback = block + self + end + + def with(*) + self + end + end + end + let(:fake_client) { fake_client_class.new } + + before do + transport + + allow(mock_coordinator).to receive(:process_result) { |result| result } + allow(http_client).to receive(:connection).and_return(fake_client) + + transport.instance_variable_get(:@pending_mutex).synchronize do + transport.instance_variable_get(:@pending_requests)[request_id.to_s] = response_queue + end + end + + after do + allow(http_client).to receive(:connection).and_call_original + end + + it "pushes inline SSE result to queue and closes request when fulfilled" do + transport.send(:create_connection_with_streaming_callbacks, request_id, close_when_fulfilled: true) + + request = instance_double(HTTPX::Request, uri: "http://example.test", close: nil) + response = instance_double(HTTPX::Response, headers: { "mcp-session-id" => "session-abc" }) + payload = { "jsonrpc" => "2.0", "id" => request_id, "result" => { "ok" => true } }.to_json + + expect(request).to receive(:close) + + fake_client.callback.call(request, response, "data: #{payload}\n\n") + + result = response_queue.pop(true) + expect(result).to be_a(RubyLLM::MCP::Result) + expect(result.result["ok"]).to be(true) + expect(transport.instance_variable_get(:@session_id)).to eq("session-abc") + + pending_requests = transport.instance_variable_get(:@pending_requests) + expect(pending_requests).not_to have_key(request_id.to_s) + end + + it "does not close request when close_when_fulfilled is false" do + transport.send(:create_connection_with_streaming_callbacks, request_id, close_when_fulfilled: false) + + request = instance_double(HTTPX::Request, uri: "http://example.test", close: nil) + response = instance_double(HTTPX::Response, headers: {}) + payload = { "jsonrpc" => "2.0", "id" => request_id, "result" => { "ok" => true } }.to_json + + expect(request).not_to receive(:close) + + fake_client.callback.call(request, response, "data: #{payload}\n\n") + + result = response_queue.pop(true) + expect(result).to be_a(RubyLLM::MCP::Result) + end + end + + describe "exception sentinel re-raising in wait_for_response_with_timeout" do + let(:request_id) { "sentinel-rethrow-test" } + let(:response_queue) { Queue.new } + + before do + transport.instance_variable_get(:@pending_mutex).synchronize do + transport.instance_variable_get(:@pending_requests)[request_id] = response_queue + end + end + + it "re-raises SessionExpiredError pushed to queue by background thread" do + session_error = RubyLLM::MCP::Errors::SessionExpiredError.new( + message: "Session has expired" + ) + response_queue.push(session_error) + + expect do + transport.send(:wait_for_response_with_timeout, request_id, response_queue) + end.to raise_error(RubyLLM::MCP::Errors::SessionExpiredError, /Session has expired/) + end + + it "re-raises AuthenticationRequiredError pushed to queue by background thread" do + auth_error = RubyLLM::MCP::Errors::AuthenticationRequiredError.new( + message: "Authentication required" + ) + response_queue.push(auth_error) + + expect do + transport.send(:wait_for_response_with_timeout, request_id, response_queue) + end.to raise_error(RubyLLM::MCP::Errors::AuthenticationRequiredError, /Authentication required/) + end + + it "cleans up pending request via ensure block when exception sentinel is re-raised" do + session_error = RubyLLM::MCP::Errors::SessionExpiredError.new(message: "expired") + response_queue.push(session_error) + + begin + transport.send(:wait_for_response_with_timeout, request_id, response_queue) + rescue RubyLLM::MCP::Errors::SessionExpiredError + nil + end + + expect(transport.instance_variable_get(:@pending_requests)).not_to have_key(request_id) + end + end + + describe "polling timeout loop in wait_for_response_with_timeout" do + it "times out and cleans up when queue remains empty" do + short_timeout_transport = described_class.new( + url: TestServerManager::HTTP_SERVER_URL, + request_timeout: 50, + coordinator: mock_coordinator, + options: {} + ) + + request_id = "polling-timeout" + response_queue = Queue.new + + short_timeout_transport.instance_variable_get(:@pending_mutex).synchronize do + short_timeout_transport.instance_variable_get(:@pending_requests)[request_id] = response_queue + end + + expect do + short_timeout_transport.send(:wait_for_response_with_timeout, request_id, response_queue) + end.to raise_error(RubyLLM::MCP::Errors::TimeoutError, /Request timed out/) + + pending_requests = short_timeout_transport.instance_variable_get(:@pending_requests) + expect(pending_requests).not_to have_key(request_id) + end + end + describe "204 No Content response handling for session termination" do before do WebMock.enable!