Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 18 additions & 6 deletions lib/ruby_llm/mcp/handlers/async_response.rb
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ module MCP
module Handlers
# Represents an async response for deferred completion
class AsyncResponse
attr_reader :elicitation_id, :state, :result, :error
attr_reader :elicitation_id

VALID_STATES = %i[pending completed rejected cancelled timed_out].freeze

Expand Down Expand Up @@ -95,34 +95,46 @@ def on_complete(&callback)

# Check if operation is pending
def pending?
@state == :pending
@mutex.synchronize { @state == :pending }
end

# Check if operation is completed
def completed?
@state == :completed
@mutex.synchronize { @state == :completed }
end

# Check if operation is rejected
def rejected?
@state == :rejected
@mutex.synchronize { @state == :rejected }
end

# Check if operation is cancelled
def cancelled?
@state == :cancelled
@mutex.synchronize { @state == :cancelled }
end

# Check if operation timed out
def timed_out?
@state == :timed_out
@mutex.synchronize { @state == :timed_out }
end

# Check if operation is finished (any terminal state)
def finished?
!pending?
end

def state
@mutex.synchronize { @state }
end

def result
@mutex.synchronize { @result }
end

def error
@mutex.synchronize { @error }
end

private

# Transition to new state (thread-safe)
Expand Down
28 changes: 19 additions & 9 deletions lib/ruby_llm/mcp/handlers/promise.rb
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@ module MCP
module Handlers
# Promise implementation for async operations
class Promise
attr_reader :state, :value, :reason

# Initialize a new promise
def initialize
@state = :pending
Expand Down Expand Up @@ -120,17 +118,17 @@ def reject(reason)

# Check if promise is pending
def pending?
@state == :pending
@mutex.synchronize { @state == :pending }
end

# Check if promise is fulfilled
def fulfilled?
@state == :fulfilled
@mutex.synchronize { @state == :fulfilled }
end

# Check if promise is rejected
def rejected?
@state == :rejected
@mutex.synchronize { @state == :rejected }
end

# Check if promise is settled (fulfilled or rejected)
Expand All @@ -146,7 +144,7 @@ def wait(timeout: nil)
# Wait until promise is settled
if timeout
deadline = Time.now + timeout
while pending?
while @state == :pending
remaining = deadline - Time.now
if remaining <= 0
raise Timeout::Error, "Promise timed out after #{timeout} seconds"
Expand All @@ -155,15 +153,27 @@ def wait(timeout: nil)
@condition.wait(@mutex, remaining)
end
else
@condition.wait(@mutex) while pending?
@condition.wait(@mutex) while @state == :pending
end

# Return value or raise error
return @value if fulfilled?
raise @reason if rejected?
return @value if @state == :fulfilled
raise @reason if @state == :rejected
end
end

def state
@mutex.synchronize { @state }
end

def value
@mutex.synchronize { @value }
end

def reason
@mutex.synchronize { @reason }
end

private

# Execute a callback safely
Expand Down
29 changes: 18 additions & 11 deletions lib/ruby_llm/mcp/native/cancellable_operation.rb
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ module Native
# Wraps server-initiated requests to support cancellation.
# The operation tracks terminal state so cancellation outcomes are explicit.
class CancellableOperation
attr_reader :request_id, :thread, :state
attr_reader :request_id

def initialize(request_id)
@request_id = request_id
Expand Down Expand Up @@ -61,22 +61,21 @@ def cancel
def execute(&)
return nil if cancelled?

@mutex.synchronize do
worker = @mutex.synchronize do
return nil if %i[cancelled cancelling].include?(@state)

@state = :running
end

@thread = Thread.new do
Thread.current.abort_on_exception = false
begin
@result = yield
rescue Errors::RequestCancelled, StandardError => e
@error = e
@thread = Thread.new do
Thread.current.abort_on_exception = false
begin
@result = yield
rescue Errors::RequestCancelled, StandardError => e
@error = e
end
end
end

@thread.join
worker.join
raise @error if @error && !@error.is_a?(Errors::RequestCancelled)

@result
Expand All @@ -88,6 +87,14 @@ def execute(&)
@thread = nil
end
end

def state
@mutex.synchronize { @state }
end

def thread
@mutex.synchronize { @thread }
end
end
end
end
Expand Down
1 change: 1 addition & 0 deletions lib/ruby_llm/mcp/native/transports/support/timeout.rb
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def with_timeout(seconds, request_id: nil)
result
else
worker.kill # stop the thread (can still have some risk if shared resources)
worker.join(0.1)
raise RubyLLM::MCP::Errors::TimeoutError.new(
message: "Request timed out after #{seconds} seconds",
request_id: request_id
Expand Down
26 changes: 15 additions & 11 deletions spec/ruby_llm/mcp/cancellation_integration_spec.rb
Original file line number Diff line number Diff line change
Expand Up @@ -43,15 +43,15 @@
end

# Track if sampling was actually cancelled (shouldn't complete)
sampling_completed = false
sampling_request_id = nil
sampling_completed = Queue.new
sampling_request_ids = Queue.new

# Set up a sampling callback that would take time if not cancelled
client.on_sampling do |sample|
sampling_request_id = sample.to_h[:id]
sampling_request_ids << sample.to_h[:id]
# This should be interrupted by cancellation
sleep 1.0
sampling_completed = true
sampling_completed << true
true
end

Expand All @@ -63,12 +63,12 @@
# Call the tool in a thread so we can cancel it
tool_thread = Thread.new do
tool.execute
rescue RubyLLM::MCP::Errors::TimeoutError
rescue RubyLLM::MCP::Errors::TimeoutError, RubyLLM::MCP::Errors::TransportError
# Cancellation can race with response delivery under slower runtimes.
end

# Wait for the sampling request to start
Timeout.timeout(10) { sleep 0.1 until sampling_request_id }
sampling_request_id = Timeout.timeout(10) { sampling_request_ids.pop }

# Send cancellation notification for the sampling request
notification = RubyLLM::MCP::Notification.new(
Expand All @@ -88,7 +88,7 @@
sleep 0.2

# Verify our sampling callback never completed
expect(sampling_completed).to be false
expect(sampling_completed).to be_empty

# Clean up
tool_thread.kill if tool_thread.alive?
Expand All @@ -104,7 +104,7 @@
config.sampling.preferred_model = "gpt-4o"
end

request_ids = []
request_ids = Queue.new

client.on_sampling do |sample|
request_ids << sample.to_h[:id]
Expand All @@ -121,17 +121,21 @@
threads = 3.times.map do
Thread.new do
tool.execute
rescue RubyLLM::MCP::Errors::TimeoutError
rescue RubyLLM::MCP::Errors::TimeoutError, RubyLLM::MCP::Errors::TransportError
# Cancellation can race with response delivery under slower runtimes.
end
end

# Wait for ALL requests to start
Timeout.timeout(10) { sleep 0.1 until request_ids.length >= 3 }
started_request_ids = Timeout.timeout(10) do
ids = []
ids << request_ids.pop until ids.length >= 3
ids
end

# Cancel each request as soon as we detect it
cancelled_ids = []
request_ids.each do |request_id|
started_request_ids.each do |request_id|
notification = RubyLLM::MCP::Notification.new(
{
"method" => "notifications/cancelled",
Expand Down