Skip to content

Commit

Permalink
Add incremental updating of open streams count and closed_streams state
Browse files Browse the repository at this point in the history
  • Loading branch information
kahuang committed Feb 17, 2019
1 parent f96b4f5 commit a633092
Show file tree
Hide file tree
Showing 3 changed files with 109 additions and 18 deletions.
42 changes: 27 additions & 15 deletions h2/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,9 @@ def __init__(self, config=None):
self.encoder = Encoder()
self.decoder = Decoder()

self._open_outbound_stream_count = 0
self._open_inbound_stream_count = 0

# This won't always actually do anything: for versions of HPACK older
# than 2.3.0 it does nothing. However, we have to try!
self.decoder.max_header_list_size = self.DEFAULT_MAX_HEADER_LIST_SIZE
Expand Down Expand Up @@ -362,6 +365,8 @@ def __init__(self, config=None):
size_limit=self.MAX_CLOSED_STREAMS
)

self._streams_to_close = list()

# The flow control window manager for the connection.
self._inbound_flow_control_window_manager = WindowManager(
max_window_size=self.local_settings.initial_window_size
Expand All @@ -383,6 +388,15 @@ def __init__(self, config=None):
ExtensionFrame: self._receive_unknown_frame
}

def _increment_open_streams(self, stream_id, incr):
if stream_id % 2 == 0:
self._open_inbound_stream_count += incr
elif stream_id % 2 == 1:
self._open_outbound_stream_count += incr

def _close_stream(self, stream_id):
self._streams_to_close.append(stream_id)

def _prepare_for_sending(self, frames):
if not frames:
return
Expand All @@ -393,22 +407,18 @@ def _open_streams(self, remainder):
"""
A common method of counting number of open streams. Returns the number
of streams that are open *and* that have (stream ID % 2) == remainder.
While it iterates, also deletes any closed streams.
Also cleans up closed streams.
"""
count = 0
to_delete = []

for stream_id, stream in self.streams.items():
if stream.open and (stream_id % 2 == remainder):
count += 1
elif stream.closed:
to_delete.append(stream_id)

for stream_id in to_delete:
for stream_id in self._streams_to_close:
stream = self.streams.pop(stream_id)
self._closed_streams[stream_id] = stream.closed_by
self._streams_to_close = list()

return count
if remainder == 0:
return self._open_inbound_stream_count
elif remainder == 1:
return self._open_outbound_stream_count
return 0

@property
def open_outbound_streams(self):
Expand Down Expand Up @@ -467,7 +477,9 @@ def _begin_new_stream(self, stream_id, allowed_ids):
stream_id,
config=self.config,
inbound_window_size=self.local_settings.initial_window_size,
outbound_window_size=self.remote_settings.initial_window_size
outbound_window_size=self.remote_settings.initial_window_size,
increment_open_stream_count_callback=self._increment_open_streams,
close_stream_callback=self._close_stream,
)
self.config.logger.debug("Stream ID %d created", stream_id)
s.max_inbound_frame_size = self.max_inbound_frame_size
Expand Down Expand Up @@ -1542,8 +1554,8 @@ def _receive_headers_frame(self, frame):
max_open_streams = self.local_settings.max_concurrent_streams
if (self.open_inbound_streams + 1) > max_open_streams:
raise TooManyStreamsError(
"Max outbound streams is %d, %d open" %
(max_open_streams, self.open_outbound_streams)
"Max inbound streams is %d, %d open" %
(max_open_streams, self.open_inbound_streams)
)

# Let's decode the headers. We handle headers as bytes internally up
Expand Down
83 changes: 81 additions & 2 deletions h2/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -767,6 +767,56 @@ def send_alt_svc(self, previous_state):
(H2StreamStateMachine.send_on_closed_stream, StreamState.CLOSED),
}

"""
Wraps a stream state change function to ensure that we keep
the parent H2Connection's state in sync
"""
def sync_state_change(func):
def wrapper(self, *args, **kwargs):
# Collect state at the beginning.
start_state = self.state_machine.state
started_open = self.open
started_closed = not started_open

# Do the state change (if any).
result = func(self, *args, **kwargs)

# Collect state at the end.
end_state = self.state_machine.state
ended_open = self.open
ended_closed = not ended_open

if end_state == StreamState.CLOSED and start_state != end_state:
if self._close_stream_callback:
self._close_stream_callback(self.stream_id)
# Clear callback so we only call this once per stream
self._close_stream_callback = None

# If we were open, but are now closed, decrement
# the open stream count, and call the close callback.
if started_open and ended_closed:
if self._decrement_open_stream_count_callback:
self._decrement_open_stream_count_callback(self.stream_id,
-1,)
# Clear callback so we only call this once per stream
self._decrement_open_stream_count_callback = None

if self._close_stream_callback:
self._close_stream_callback(self.stream_id)
# Clear callback so we only call this once per stream
self._close_stream_callback = None

# If we were closed, but are now open, increment
# the open stream count.
elif started_closed and ended_open:
if self._increment_open_stream_count_callback:
self._increment_open_stream_count_callback(self.stream_id,
1,)
# Clear callback so we only call this once per stream
self._increment_open_stream_count_callback = None
return result
return wrapper


class H2Stream(object):
"""
Expand All @@ -782,18 +832,29 @@ def __init__(self,
stream_id,
config,
inbound_window_size,
outbound_window_size):
outbound_window_size,
increment_open_stream_count_callback,
close_stream_callback,):
self.state_machine = H2StreamStateMachine(stream_id)
self.stream_id = stream_id
self.max_outbound_frame_size = None
self.request_method = None

# The current value of the outbound stream flow control window
# The current value of the outbound stream flow control window.
self.outbound_flow_control_window = outbound_window_size

# The flow control manager.
self._inbound_window_manager = WindowManager(inbound_window_size)

# Callback to increment open stream count for the H2Connection.
self._increment_open_stream_count_callback = increment_open_stream_count_callback

# Callback to decrement open stream count for the H2Connection.
self._decrement_open_stream_count_callback = increment_open_stream_count_callback

# Callback to clean up state for the H2Connection once we're closed.
self._close_stream_callback = close_stream_callback

# The expected content length, if any.
self._expected_content_length = None

Expand Down Expand Up @@ -850,6 +911,7 @@ def closed_by(self):
"""
return self.state_machine.stream_closed_by

@sync_state_change
def upgrade(self, client_side):
"""
Called by the connection to indicate that this stream is the initial
Expand All @@ -868,6 +930,7 @@ def upgrade(self, client_side):
self.state_machine.process_input(input_)
return

@sync_state_change
def send_headers(self, headers, encoder, end_stream=False):
"""
Returns a list of HEADERS/CONTINUATION frames to emit as either headers
Expand Down Expand Up @@ -917,6 +980,7 @@ def send_headers(self, headers, encoder, end_stream=False):

return frames

@sync_state_change
def push_stream_in_band(self, related_stream_id, headers, encoder):
"""
Returns a list of PUSH_PROMISE/CONTINUATION frames to emit as a pushed
Expand All @@ -941,6 +1005,7 @@ def push_stream_in_band(self, related_stream_id, headers, encoder):

return frames

@sync_state_change
def locally_pushed(self):
"""
Mark this stream as one that was pushed by this peer. Must be called
Expand All @@ -954,6 +1019,7 @@ def locally_pushed(self):
assert not events
return []

@sync_state_change
def send_data(self, data, end_stream=False, pad_length=None):
"""
Prepare some data frames. Optionally end the stream.
Expand Down Expand Up @@ -981,6 +1047,7 @@ def send_data(self, data, end_stream=False, pad_length=None):

return [df]

@sync_state_change
def end_stream(self):
"""
End a stream without sending data.
Expand All @@ -992,6 +1059,7 @@ def end_stream(self):
df.flags.add('END_STREAM')
return [df]

@sync_state_change
def advertise_alternative_service(self, field_value):
"""
Advertise an RFC 7838 alternative service. The semantics of this are
Expand All @@ -1005,6 +1073,7 @@ def advertise_alternative_service(self, field_value):
asf.field = field_value
return [asf]

@sync_state_change
def increase_flow_control_window(self, increment):
"""
Increase the size of the flow control window for the remote side.
Expand All @@ -1020,6 +1089,7 @@ def increase_flow_control_window(self, increment):
wuf.window_increment = increment
return [wuf]

@sync_state_change
def receive_push_promise_in_band(self,
promised_stream_id,
headers,
Expand All @@ -1044,6 +1114,7 @@ def receive_push_promise_in_band(self,
)
return [], events

@sync_state_change
def remotely_pushed(self, pushed_headers):
"""
Mark this stream as one that was pushed by the remote peer. Must be
Expand All @@ -1057,6 +1128,7 @@ def remotely_pushed(self, pushed_headers):
self._authority = authority_from_headers(pushed_headers)
return [], events

@sync_state_change
def receive_headers(self, headers, end_stream, header_encoding):
"""
Receive a set of headers (or trailers).
Expand Down Expand Up @@ -1091,6 +1163,7 @@ def receive_headers(self, headers, end_stream, header_encoding):
)
return [], events

@sync_state_change
def receive_data(self, data, end_stream, flow_control_len):
"""
Receive some data.
Expand All @@ -1114,6 +1187,7 @@ def receive_data(self, data, end_stream, flow_control_len):
events[0].flow_controlled_length = flow_control_len
return [], events

@sync_state_change
def receive_window_update(self, increment):
"""
Handle a WINDOW_UPDATE increment.
Expand Down Expand Up @@ -1150,6 +1224,7 @@ def receive_window_update(self, increment):

return frames, events

@sync_state_change
def receive_continuation(self):
"""
A naked CONTINUATION frame has been received. This is always an error,
Expand All @@ -1162,6 +1237,7 @@ def receive_continuation(self):
)
assert False, "Should not be reachable"

@sync_state_change
def receive_alt_svc(self, frame):
"""
An Alternative Service frame was received on the stream. This frame
Expand Down Expand Up @@ -1189,6 +1265,7 @@ def receive_alt_svc(self, frame):

return [], events

@sync_state_change
def reset_stream(self, error_code=0):
"""
Close the stream locally. Reset the stream with an error code.
Expand All @@ -1202,6 +1279,7 @@ def reset_stream(self, error_code=0):
rsf.error_code = error_code
return [rsf]

@sync_state_change
def stream_reset(self, frame):
"""
Handle a stream being reset remotely.
Expand All @@ -1217,6 +1295,7 @@ def stream_reset(self, frame):

return [], events

@sync_state_change
def acknowledge_received_data(self, acknowledged_size):
"""
The user has informed us that they've processed some amount of data
Expand Down
2 changes: 1 addition & 1 deletion test/test_basic_logic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1851,7 +1851,7 @@ def test_stream_repr(self):
"""
Ensure stream string representation is appropriate.
"""
s = h2.stream.H2Stream(4, None, 12, 14)
s = h2.stream.H2Stream(4, None, 12, 14, None, None)
assert repr(s) == "<H2Stream id:4 state:<StreamState.IDLE: 0>>"


Expand Down

0 comments on commit a633092

Please sign in to comment.