diff --git a/pyproject.toml b/pyproject.toml index b3f1904f..77cd5254 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,8 +20,8 @@ maintainers = [ requires-python = ">=3.9" dependencies = [ - "hyperframe>=6.0,<7", - "hpack>=4.0,<5", + "hyperframe>=6.1,<7", + "hpack>=4.1,<5", ] dynamic = ["version"] @@ -66,6 +66,7 @@ testing = [ linting = [ "ruff>=0.8.0,<1", "mypy>=1.13.0,<2", + "typing_extensions>=4.12.2", ] packaging = [ @@ -90,8 +91,58 @@ h2 = [ "py.typed" ] version = { attr = "h2.__version__" } [tool.ruff] -line-length = 140 +line-length = 150 target-version = "py39" +format.preview = true +format.docstring-code-line-length = 100 +format.docstring-code-format = true +lint.select = [ + "ALL", +] +lint.ignore = [ + "PYI034", # PEP 673 not yet available in Python 3.9 - only in 3.11+ + "ANN001", # args with typing.Any + "ANN002", # args with typing.Any + "ANN003", # kwargs with typing.Any + "ANN401", # kwargs with typing.Any + "SLF001", # implementation detail + "CPY", # not required + "D101", # docs readability + "D102", # docs readability + "D105", # docs readability + "D107", # docs readability + "D200", # docs readability + "D205", # docs readability + "D205", # docs readability + "D203", # docs readability + "D212", # docs readability + "D400", # docs readability + "D401", # docs readability + "D415", # docs readability + "PLR2004", # readability + "SIM108", # readability + "RUF012", # readability + "FBT001", # readability + "FBT002", # readability + "PGH003", # readability + "PGH004", # readability + "FIX001", # readability + "FIX002", # readability + "TD001", # readability + "TD002", # readability + "TD003", # readability + "S101", # readability + "PD901", # readability + "ERA001", # readability + "ARG001", # readability + "ARG002", # readability + "PLR0913", # readability +] +lint.isort.required-imports = [ "from __future__ import annotations" ] + +[tool.mypy] +show_error_codes = true +strict = true [tool.pytest.ini_options] testpaths = [ "tests" ] @@ -150,7 +201,7 @@ commands = [ dependency_groups = ["linting"] commands = [ ["ruff", "check", "src/"], - # TODO: ["mypy", "src/"], + ["mypy", "src/"], ] [tool.tox.env.docs] diff --git a/src/h2/__init__.py b/src/h2/__init__.py index d3fd0cb6..edb8be5b 100644 --- a/src/h2/__init__.py +++ b/src/h2/__init__.py @@ -1,8 +1,9 @@ -# -*- coding: utf-8 -*- """ h2 ~~ A HTTP/2 implementation. """ -__version__ = '4.1.0' +from __future__ import annotations + +__version__ = "4.1.0" diff --git a/src/h2/config.py b/src/h2/config.py index df5de453..cbc3b1ea 100644 --- a/src/h2/config.py +++ b/src/h2/config.py @@ -1,12 +1,13 @@ -# -*- coding: utf-8 -*- """ h2/config ~~~~~~~~~ Objects for controlling the configuration of the HTTP/2 stack. """ +from __future__ import annotations import sys +from typing import Any class _BooleanConfigOption: @@ -14,16 +15,18 @@ class _BooleanConfigOption: Descriptor for handling a boolean config option. This will block attempts to set boolean config options to non-bools. """ - def __init__(self, name): + + def __init__(self, name: str) -> None: self.name = name - self.attr_name = '_%s' % self.name + self.attr_name = f"_{self.name}" - def __get__(self, instance, owner): - return getattr(instance, self.attr_name) + def __get__(self, instance: Any, owner: Any) -> bool: + return getattr(instance, self.attr_name) # type: ignore - def __set__(self, instance, value): + def __set__(self, instance: Any, value: bool) -> None: if not isinstance(value, bool): - raise ValueError("%s must be a bool" % self.name) + msg = f"{self.name} must be a bool" + raise ValueError(msg) # noqa: TRY004 setattr(instance, self.attr_name, value) @@ -35,20 +38,19 @@ class DummyLogger: conditionals being sprinkled throughout the h2 code for calls to logging functions when no logger is passed into the corresponding object. """ - def __init__(self, *vargs): + + def __init__(self, *vargs) -> None: # type: ignore pass - def debug(self, *vargs, **kwargs): + def debug(self, *vargs, **kwargs) -> None: # type: ignore """ No-op logging. Only level needed for now. """ - pass - def trace(self, *vargs, **kwargs): + def trace(self, *vargs, **kwargs) -> None: # type: ignore """ No-op logging. Only level needed for now. """ - pass class OutputLogger: @@ -61,15 +63,16 @@ class OutputLogger: Defaults to ``sys.stderr``. :param trace: Enables trace-level output. Defaults to ``False``. """ - def __init__(self, file=None, trace_level=False): + + def __init__(self, file=None, trace_level=False) -> None: # type: ignore super().__init__() self.file = file or sys.stderr self.trace_level = trace_level - def debug(self, fmtstr, *args): + def debug(self, fmtstr, *args) -> None: # type: ignore print(f"h2 (debug): {fmtstr % args}", file=self.file) - def trace(self, fmtstr, *args): + def trace(self, fmtstr, *args) -> None: # type: ignore if self.trace_level: print(f"h2 (trace): {fmtstr % args}", file=self.file) @@ -147,32 +150,33 @@ class H2Configuration: :type logger: ``logging.Logger`` """ - client_side = _BooleanConfigOption('client_side') + + client_side = _BooleanConfigOption("client_side") validate_outbound_headers = _BooleanConfigOption( - 'validate_outbound_headers' + "validate_outbound_headers", ) normalize_outbound_headers = _BooleanConfigOption( - 'normalize_outbound_headers' + "normalize_outbound_headers", ) split_outbound_cookies = _BooleanConfigOption( - 'split_outbound_cookies' + "split_outbound_cookies", ) validate_inbound_headers = _BooleanConfigOption( - 'validate_inbound_headers' + "validate_inbound_headers", ) normalize_inbound_headers = _BooleanConfigOption( - 'normalize_inbound_headers' + "normalize_inbound_headers", ) def __init__(self, - client_side=True, - header_encoding=None, - validate_outbound_headers=True, - normalize_outbound_headers=True, - split_outbound_cookies=False, - validate_inbound_headers=True, - normalize_inbound_headers=True, - logger=None): + client_side: bool = True, + header_encoding: bool | str | None = None, + validate_outbound_headers: bool = True, + normalize_outbound_headers: bool = True, + split_outbound_cookies: bool = False, + validate_inbound_headers: bool = True, + normalize_inbound_headers: bool = True, + logger: DummyLogger | OutputLogger | None = None) -> None: self.client_side = client_side self.header_encoding = header_encoding self.validate_outbound_headers = validate_outbound_headers @@ -183,7 +187,7 @@ def __init__(self, self.logger = logger or DummyLogger(__name__) @property - def header_encoding(self): + def header_encoding(self) -> bool | str | None: """ Controls whether the headers emitted by this object in events are transparently decoded to ``unicode`` strings, and what encoding is used @@ -195,12 +199,14 @@ def header_encoding(self): return self._header_encoding @header_encoding.setter - def header_encoding(self, value): + def header_encoding(self, value: bool | str | None) -> None: """ Enforces constraints on the value of header encoding. """ if not isinstance(value, (bool, str, type(None))): - raise ValueError("header_encoding must be bool, string, or None") + msg = "header_encoding must be bool, string, or None" + raise ValueError(msg) # noqa: TRY004 if value is True: - raise ValueError("header_encoding cannot be True") + msg = "header_encoding cannot be True" + raise ValueError(msg) self._header_encoding = value diff --git a/src/h2/connection.py b/src/h2/connection.py index ca2b3832..28be9fca 100644 --- a/src/h2/connection.py +++ b/src/h2/connection.py @@ -1,41 +1,75 @@ -# -*- coding: utf-8 -*- """ h2/connection ~~~~~~~~~~~~~ An implementation of a HTTP/2 connection. """ -import base64 +from __future__ import annotations +import base64 from enum import Enum, IntEnum +from typing import TYPE_CHECKING, Any, Callable +from hpack.exceptions import HPACKError, OversizedHeaderListError +from hpack.hpack import Decoder, Encoder from hyperframe.exceptions import InvalidPaddingError from hyperframe.frame import ( - GoAwayFrame, WindowUpdateFrame, HeadersFrame, DataFrame, PingFrame, - PushPromiseFrame, SettingsFrame, RstStreamFrame, PriorityFrame, - ContinuationFrame, AltSvcFrame, ExtensionFrame + AltSvcFrame, + ContinuationFrame, + DataFrame, + ExtensionFrame, + Frame, + GoAwayFrame, + HeadersFrame, + PingFrame, + PriorityFrame, + PushPromiseFrame, + RstStreamFrame, + SettingsFrame, + WindowUpdateFrame, ) -from hpack.hpack import Encoder, Decoder -from hpack.exceptions import HPACKError, OversizedHeaderListError from .config import H2Configuration from .errors import ErrorCodes, _error_code_from_int from .events import ( - WindowUpdated, RemoteSettingsChanged, PingReceived, PingAckReceived, - SettingsAcknowledged, ConnectionTerminated, PriorityUpdated, - AlternativeServiceAvailable, UnknownFrameReceived + AlternativeServiceAvailable, + ConnectionTerminated, + Event, + InformationalResponseReceived, + PingAckReceived, + PingReceived, + PriorityUpdated, + RemoteSettingsChanged, + RequestReceived, + ResponseReceived, + SettingsAcknowledged, + TrailersReceived, + UnknownFrameReceived, + WindowUpdated, ) from .exceptions import ( - ProtocolError, NoSuchStreamError, FlowControlError, FrameTooLargeError, - TooManyStreamsError, StreamClosedError, StreamIDTooLowError, - NoAvailableStreamIDError, RFC1122Error, DenialOfServiceError + DenialOfServiceError, + FlowControlError, + FrameTooLargeError, + NoAvailableStreamIDError, + NoSuchStreamError, + ProtocolError, + RFC1122Error, + StreamClosedError, + StreamIDTooLowError, + TooManyStreamsError, ) from .frame_buffer import FrameBuffer -from .settings import Settings, SettingCodes +from .settings import ChangedSetting, SettingCodes, Settings from .stream import H2Stream, StreamClosedBy from .utilities import SizeLimitDict, guard_increment_window from .windows import WindowManager +if TYPE_CHECKING: # pragma: no cover + from collections.abc import Iterable + + from hpack.struct import Header, HeaderWeaklyTyped + class ConnectionState(Enum): IDLE = 0 @@ -81,6 +115,7 @@ class H2ConnectionStateMachine: maintains very little state directly, instead focusing entirely on managing state transitions. """ + # For the purposes of this state machine we treat HEADERS and their # associated CONTINUATION frames as a single jumbo frame. The protocol # allows/requires this by preventing other frames from being interleved in @@ -210,24 +245,24 @@ class H2ConnectionStateMachine: (None, ConnectionState.CLOSED), } - def __init__(self): + def __init__(self) -> None: self.state = ConnectionState.IDLE - def process_input(self, input_): + def process_input(self, input_: ConnectionInputs) -> list[Event]: """ Process a specific input in the state machine. """ if not isinstance(input_, ConnectionInputs): - raise ValueError("Input must be an instance of ConnectionInputs") + msg = "Input must be an instance of ConnectionInputs" + raise ValueError(msg) # noqa: TRY004 try: func, target_state = self._transitions[(self.state, input_)] - except KeyError: + except KeyError as e: old_state = self.state self.state = ConnectionState.CLOSED - raise ProtocolError( - "Invalid input %s in state %s" % (input_, old_state) - ) + msg = f"Invalid input {input_} in state {old_state}" + raise ProtocolError(msg) from e else: self.state = target_state if func is not None: # pragma: no cover @@ -264,6 +299,7 @@ class H2Connection: :type config: :class:`H2Configuration ` """ + # The initial maximum outbound frame size. This can be changed by receiving # a settings frame. DEFAULT_MAX_OUTBOUND_FRAME_SIZE = 65535 @@ -284,9 +320,9 @@ class H2Connection: # Keep in memory limited amount of results for streams closes MAX_CLOSED_STREAMS = 2**16 - def __init__(self, config=None): + def __init__(self, config: H2Configuration | None = None) -> None: self.state_machine = H2ConnectionStateMachine() - self.streams = {} + self.streams: dict[int, H2Stream] = {} self.highest_inbound_stream_id = 0 self.highest_outbound_stream_id = 0 self.encoder = Encoder() @@ -299,11 +335,7 @@ def __init__(self, config=None): #: The configuration for this HTTP/2 connection object. #: #: .. versionadded:: 2.5.0 - self.config = config - if self.config is None: - self.config = H2Configuration( - client_side=True, - ) + self.config = config or H2Configuration(client_side=True) # Objects that store settings, including defaults. # @@ -324,7 +356,7 @@ def __init__(self, config=None): SettingCodes.MAX_CONCURRENT_STREAMS: 100, SettingCodes.MAX_HEADER_LIST_SIZE: self.DEFAULT_MAX_HEADER_LIST_SIZE, - } + }, ) self.remote_settings = Settings(client=not self.config.client_side) @@ -347,7 +379,7 @@ def __init__(self, config=None): # A private variable to store a sequence of received header frames # until completion. - self._header_frames = [] + self._header_frames: list[Frame] = [] # Data that needs to be sent. self._data_to_send = bytearray() @@ -358,17 +390,17 @@ def __init__(self, config=None): # Also used to determine whether we should consider a frame received # while a stream is closed as either a stream error or a connection # error. - self._closed_streams = SizeLimitDict( - size_limit=self.MAX_CLOSED_STREAMS + self._closed_streams: dict[int, StreamClosedBy | None] = SizeLimitDict( + size_limit=self.MAX_CLOSED_STREAMS, ) # The flow control window manager for the connection. self._inbound_flow_control_window_manager = WindowManager( - max_window_size=self.local_settings.initial_window_size + max_window_size=self.local_settings.initial_window_size, ) # When in doubt use dict-dispatch. - self._frame_dispatch_table = { + self._frame_dispatch_table: dict[type[Frame], Callable] = { # type: ignore HeadersFrame: self._receive_headers_frame, PushPromiseFrame: self._receive_push_promise_frame, SettingsFrame: self._receive_settings_frame, @@ -380,16 +412,16 @@ def __init__(self, config=None): GoAwayFrame: self._receive_goaway_frame, ContinuationFrame: self._receive_naked_continuation, AltSvcFrame: self._receive_alt_svc_frame, - ExtensionFrame: self._receive_unknown_frame + ExtensionFrame: self._receive_unknown_frame, } - def _prepare_for_sending(self, frames): + def _prepare_for_sending(self, frames: list[Frame]) -> None: if not frames: return - self._data_to_send += b''.join(f.serialize() for f in frames) + self._data_to_send += b"".join(f.serialize() for f in frames) assert all(f.body_len <= self.max_outbound_frame_size for f in frames) - def _open_streams(self, remainder): + def _open_streams(self, remainder: int) -> int: """ A common method of counting number of open streams. Returns the number of streams that are open *and* that have (stream ID % 2) == remainder. @@ -411,7 +443,7 @@ def _open_streams(self, remainder): return count @property - def open_outbound_streams(self): + def open_outbound_streams(self) -> int: """ The current number of open outbound streams. """ @@ -419,7 +451,7 @@ def open_outbound_streams(self): return self._open_streams(outbound_numbers) @property - def open_inbound_streams(self): + def open_inbound_streams(self) -> int: """ The current number of open inbound streams. """ @@ -427,7 +459,7 @@ def open_inbound_streams(self): return self._open_streams(inbound_numbers) @property - def inbound_flow_control_window(self): + def inbound_flow_control_window(self) -> int: """ The size of the inbound flow control window for the connection. This is rarely publicly useful: instead, use :meth:`remote_flow_control_window @@ -436,7 +468,7 @@ def inbound_flow_control_window(self): """ return self._inbound_flow_control_window_manager.current_window_size - def _begin_new_stream(self, stream_id, allowed_ids): + def _begin_new_stream(self, stream_id: int, allowed_ids: AllowedStreamIDs) -> H2Stream: """ Initiate a new stream. @@ -447,7 +479,7 @@ def _begin_new_stream(self, stream_id, allowed_ids): :param allowed_ids: What kind of stream ID is allowed. """ self.config.logger.debug( - "Attempting to initiate stream ID %d", stream_id + "Attempting to initiate stream ID %d", stream_id, ) outbound = self._stream_id_is_outbound(stream_id) highest_stream_id = ( @@ -459,15 +491,14 @@ def _begin_new_stream(self, stream_id, allowed_ids): raise StreamIDTooLowError(stream_id, highest_stream_id) if (stream_id % 2) != int(allowed_ids): - raise ProtocolError( - "Invalid stream ID for peer." - ) + msg = "Invalid stream ID for peer." + raise ProtocolError(msg) s = H2Stream( stream_id, config=self.config, inbound_window_size=self.local_settings.initial_window_size, - outbound_window_size=self.remote_settings.initial_window_size + outbound_window_size=self.remote_settings.initial_window_size, ) self.config.logger.debug("Stream ID %d created", stream_id) s.max_outbound_frame_size = self.max_outbound_frame_size @@ -482,7 +513,7 @@ def _begin_new_stream(self, stream_id, allowed_ids): return s - def initiate_connection(self): + def initiate_connection(self) -> None: """ Provides any data that needs to be sent at the start of the connection. Must be called for both clients and servers. @@ -490,20 +521,20 @@ def initiate_connection(self): self.config.logger.debug("Initializing connection") self.state_machine.process_input(ConnectionInputs.SEND_SETTINGS) if self.config.client_side: - preamble = b'PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n' + preamble = b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n" else: - preamble = b'' + preamble = b"" f = SettingsFrame(0) for setting, value in self.local_settings.items(): f.settings[setting] = value self.config.logger.debug( - "Send Settings frame: %s", self.local_settings + "Send Settings frame: %s", self.local_settings, ) self._data_to_send += preamble + f.serialize() - def initiate_upgrade_connection(self, settings_header=None): + def initiate_upgrade_connection(self, settings_header: bytes | None = None) -> bytes | None: """ Call to initialise the connection object for use with an upgraded HTTP/2 connection (i.e. a connection negotiated using the @@ -537,7 +568,7 @@ def initiate_upgrade_connection(self, settings_header=None): :rtype: ``bytes`` or ``None`` """ self.config.logger.debug( - "Upgrade connection. Current settings: %s", self.local_settings + "Upgrade connection. Current settings: %s", self.local_settings, ) frame_data = None @@ -558,7 +589,7 @@ def initiate_upgrade_connection(self, settings_header=None): # the state machine, but ignoring the return value. settings_header = base64.urlsafe_b64decode(settings_header) f = SettingsFrame(0) - f.parse_body(settings_header) + f.parse_body(memoryview(settings_header)) self._receive_settings_frame(f) # Set up appropriate state. Stream 1 in a half-closed state: @@ -576,7 +607,7 @@ def initiate_upgrade_connection(self, settings_header=None): self.streams[1].upgrade(self.config.client_side) return frame_data - def _get_or_create_stream(self, stream_id, allowed_ids): + def _get_or_create_stream(self, stream_id: int, allowed_ids: AllowedStreamIDs) -> H2Stream: """ Gets a stream by its stream ID. Will create one if one does not already exist. Use allowed_ids to circumvent the usual stream ID rules for @@ -590,7 +621,7 @@ def _get_or_create_stream(self, stream_id, allowed_ids): except KeyError: return self._begin_new_stream(stream_id, allowed_ids) - def _get_stream_by_id(self, stream_id): + def _get_stream_by_id(self, stream_id: int | None) -> H2Stream: """ Gets a stream by its stream ID. Raises NoSuchStreamError if the stream ID does not correspond to a known stream and is higher than the current @@ -599,9 +630,11 @@ def _get_stream_by_id(self, stream_id): .. versionchanged:: 2.0.0 Removed this function from the public API. """ + if not stream_id: + raise NoSuchStreamError(-1) # pragma: no cover try: return self.streams[stream_id] - except KeyError: + except KeyError as e: outbound = self._stream_id_is_outbound(stream_id) highest_stream_id = ( self.highest_outbound_stream_id if outbound else @@ -609,11 +642,10 @@ def _get_stream_by_id(self, stream_id): ) if stream_id > highest_stream_id: - raise NoSuchStreamError(stream_id) - else: - raise StreamClosedError(stream_id) + raise NoSuchStreamError(stream_id) from e + raise StreamClosedError(stream_id) from e - def get_next_available_stream_id(self): + def get_next_available_stream_id(self) -> int: """ Returns an integer suitable for use as the stream ID for the next stream created by this endpoint. For server endpoints, this stream ID @@ -642,16 +674,21 @@ def get_next_available_stream_id(self): else: next_stream_id = self.highest_outbound_stream_id + 2 self.config.logger.debug( - "Next available stream ID %d", next_stream_id + "Next available stream ID %d", next_stream_id, ) if next_stream_id > self.HIGHEST_ALLOWED_STREAM_ID: - raise NoAvailableStreamIDError("Exhausted allowed stream IDs") + msg = "Exhausted allowed stream IDs" + raise NoAvailableStreamIDError(msg) return next_stream_id - def send_headers(self, stream_id, headers, end_stream=False, - priority_weight=None, priority_depends_on=None, - priority_exclusive=None): + def send_headers(self, + stream_id: int, + headers: Iterable[HeaderWeaklyTyped], + end_stream: bool = False, + priority_weight: int | None = None, + priority_depends_on: int | None = None, + priority_exclusive: bool | None = None) -> None: """ Send headers on a given stream. @@ -750,26 +787,26 @@ def send_headers(self, stream_id, headers, end_stream=False, :returns: Nothing """ self.config.logger.debug( - "Send headers on stream ID %d", stream_id + "Send headers on stream ID %d", stream_id, ) # Check we can open the stream. if stream_id not in self.streams: max_open_streams = self.remote_settings.max_concurrent_streams if (self.open_outbound_streams + 1) > max_open_streams: - raise TooManyStreamsError( - "Max outbound streams is %d, %d open" % - (max_open_streams, self.open_outbound_streams) - ) + msg = f"Max outbound streams is {max_open_streams}, {self.open_outbound_streams} open" + raise TooManyStreamsError(msg) self.state_machine.process_input(ConnectionInputs.SEND_HEADERS) stream = self._get_or_create_stream( - stream_id, AllowedStreamIDs(self.config.client_side) - ) - frames = stream.send_headers( - headers, self.encoder, end_stream + stream_id, AllowedStreamIDs(self.config.client_side), ) + frames: list[Frame] = [] + frames.extend(stream.send_headers( + headers, self.encoder, end_stream, + )) + # We may need to send priority information. priority_present = ( (priority_weight is not None) or @@ -779,20 +816,27 @@ def send_headers(self, stream_id, headers, end_stream=False, if priority_present: if not self.config.client_side: - raise RFC1122Error("Servers SHOULD NOT prioritize streams.") + msg = "Servers SHOULD NOT prioritize streams." + raise RFC1122Error(msg) headers_frame = frames[0] - headers_frame.flags.add('PRIORITY') + assert isinstance(headers_frame, HeadersFrame) + + headers_frame.flags.add("PRIORITY") frames[0] = _add_frame_priority( headers_frame, priority_weight, priority_depends_on, - priority_exclusive + priority_exclusive, ) self._prepare_for_sending(frames) - def send_data(self, stream_id, data, end_stream=False, pad_length=None): + def send_data(self, + stream_id: int, + data: bytes | memoryview, + end_stream: bool = False, + pad_length: Any = None) -> None: """ Send data on a given stream. @@ -827,34 +871,32 @@ def send_data(self, stream_id, data, end_stream=False, pad_length=None): :returns: Nothing """ self.config.logger.debug( - "Send data on stream ID %d with len %d", stream_id, len(data) + "Send data on stream ID %d with len %d", stream_id, len(data), ) frame_size = len(data) if pad_length is not None: if not isinstance(pad_length, int): - raise TypeError("pad_length must be an int") + msg = "pad_length must be an int" + raise TypeError(msg) if pad_length < 0 or pad_length > 255: - raise ValueError("pad_length must be within range: [0, 255]") + msg = "pad_length must be within range: [0, 255]" + raise ValueError(msg) # Account for padding bytes plus the 1-byte padding length field. frame_size += pad_length + 1 self.config.logger.debug( - "Frame size on stream ID %d is %d", stream_id, frame_size + "Frame size on stream ID %d is %d", stream_id, frame_size, ) if frame_size > self.local_flow_control_window(stream_id): - raise FlowControlError( - "Cannot send %d bytes, flow control window is %d." % - (frame_size, self.local_flow_control_window(stream_id)) - ) - elif frame_size > self.max_outbound_frame_size: - raise FrameTooLargeError( - "Cannot send frame size %d, max frame size is %d" % - (frame_size, self.max_outbound_frame_size) - ) + msg = f"Cannot send {frame_size} bytes, flow control window is {self.local_flow_control_window(stream_id)}" + raise FlowControlError(msg) + if frame_size > self.max_outbound_frame_size: + msg = f"Cannot send frame size {frame_size}, max frame size is {self.max_outbound_frame_size}" + raise FrameTooLargeError(msg) self.state_machine.process_input(ConnectionInputs.SEND_DATA) frames = self.streams[stream_id].send_data( - data, end_stream, pad_length=pad_length + data, end_stream, pad_length=pad_length, ) self._prepare_for_sending(frames) @@ -862,11 +904,11 @@ def send_data(self, stream_id, data, end_stream=False, pad_length=None): self.outbound_flow_control_window -= frame_size self.config.logger.debug( "Outbound flow control window size is %d", - self.outbound_flow_control_window + self.outbound_flow_control_window, ) assert self.outbound_flow_control_window >= 0 - def end_stream(self, stream_id): + def end_stream(self, stream_id: int) -> None: """ Cleanly end a given stream. @@ -882,7 +924,7 @@ def end_stream(self, stream_id): frames = self.streams[stream_id].end_stream() self._prepare_for_sending(frames) - def increment_flow_control_window(self, increment, stream_id=None): + def increment_flow_control_window(self, increment: int, stream_id: int | None = None) -> None: """ Increment a flow control window, optionally for a single stream. Allows the remote peer to send more data. @@ -901,22 +943,20 @@ def increment_flow_control_window(self, increment, stream_id=None): :raises: ``ValueError`` """ if not (1 <= increment <= self.MAX_WINDOW_INCREMENT): - raise ValueError( - "Flow control increment must be between 1 and %d" % - self.MAX_WINDOW_INCREMENT - ) + msg = f"Flow control increment must be between 1 and {self.MAX_WINDOW_INCREMENT}" + raise ValueError(msg) self.state_machine.process_input(ConnectionInputs.SEND_WINDOW_UPDATE) if stream_id is not None: stream = self.streams[stream_id] frames = stream.increase_flow_control_window( - increment + increment, ) self.config.logger.debug( "Increase stream ID %d flow control window by %d", - stream_id, increment + stream_id, increment, ) else: self._inbound_flow_control_window_manager.window_opened(increment) @@ -925,12 +965,15 @@ def increment_flow_control_window(self, increment, stream_id=None): frames = [f] self.config.logger.debug( - "Increase connection flow control window by %d", increment + "Increase connection flow control window by %d", increment, ) self._prepare_for_sending(frames) - def push_stream(self, stream_id, promised_stream_id, request_headers): + def push_stream(self, + stream_id: int, + promised_stream_id: int, + request_headers: Iterable[HeaderWeaklyTyped]) -> None: """ Push a response to the client by sending a PUSH_PROMISE frame. @@ -953,11 +996,12 @@ def push_stream(self, stream_id, promised_stream_id, request_headers): :returns: Nothing """ self.config.logger.debug( - "Send Push Promise frame on stream ID %d", stream_id + "Send Push Promise frame on stream ID %d", stream_id, ) if not self.remote_settings.enable_push: - raise ProtocolError("Remote peer has disabled stream push") + msg = "Remote peer has disabled stream push" + raise ProtocolError(msg) self.state_machine.process_input(ConnectionInputs.SEND_PUSH_PROMISE) stream = self._get_stream_by_id(stream_id) @@ -968,20 +1012,21 @@ def push_stream(self, stream_id, promised_stream_id, request_headers): # this shortcut works because only servers can push and the state # machine will enforce this. if (stream_id % 2) == 0: - raise ProtocolError("Cannot recursively push streams.") + msg = "Cannot recursively push streams." + raise ProtocolError(msg) new_stream = self._begin_new_stream( - promised_stream_id, AllowedStreamIDs.EVEN + promised_stream_id, AllowedStreamIDs.EVEN, ) self.streams[promised_stream_id] = new_stream frames = stream.push_stream_in_band( - promised_stream_id, request_headers, self.encoder + promised_stream_id, request_headers, self.encoder, ) new_frames = new_stream.locally_pushed() self._prepare_for_sending(frames + new_frames) - def ping(self, opaque_data): + def ping(self, opaque_data: bytes | str) -> None: """ Send a PING frame. @@ -992,14 +1037,15 @@ def ping(self, opaque_data): self.config.logger.debug("Send Ping frame") if not isinstance(opaque_data, bytes) or len(opaque_data) != 8: - raise ValueError("Invalid value for ping data: %r" % opaque_data) + msg = f"Invalid value for ping data: {opaque_data!r}" + raise ValueError(msg) self.state_machine.process_input(ConnectionInputs.SEND_PING) f = PingFrame(0) f.opaque_data = opaque_data self._prepare_for_sending([f]) - def reset_stream(self, stream_id, error_code=0): + def reset_stream(self, stream_id: int, error_code: ErrorCodes | int = 0) -> None: """ Reset a stream. @@ -1022,9 +1068,10 @@ def reset_stream(self, stream_id, error_code=0): frames = stream.reset_stream(error_code) self._prepare_for_sending(frames) - def close_connection(self, error_code=0, additional_data=None, - last_stream_id=None): - + def close_connection(self, + error_code: ErrorCodes | int = 0, + additional_data: bytes | None = None, + last_stream_id: int | None = None) -> None: """ Close a connection, emitting a GOAWAY frame. @@ -1053,11 +1100,11 @@ def close_connection(self, error_code=0, additional_data=None, stream_id=0, last_stream_id=last_stream_id, error_code=error_code, - additional_data=(additional_data or b'') + additional_data=(additional_data or b""), ) self._prepare_for_sending([f]) - def update_settings(self, new_settings): + def update_settings(self, new_settings: dict[SettingCodes | int, int]) -> None: """ Update the local settings. This will prepare and emit the appropriate SETTINGS frame. @@ -1065,7 +1112,7 @@ def update_settings(self, new_settings): :param new_settings: A dictionary of {setting: new value} """ self.config.logger.debug( - "Update connection settings to %s", new_settings + "Update connection settings to %s", new_settings, ) self.state_machine.process_input(ConnectionInputs.SEND_SETTINGS) self.local_settings.update(new_settings) @@ -1074,9 +1121,9 @@ def update_settings(self, new_settings): self._prepare_for_sending([s]) def advertise_alternative_service(self, - field_value, - origin=None, - stream_id=None): + field_value: bytes | str, + origin: bytes | None = None, + stream_id: int | None = None) -> None: """ Notify a client about an available Alternative Service. @@ -1131,13 +1178,15 @@ def advertise_alternative_service(self, :returns: Nothing. """ if not isinstance(field_value, bytes): - raise ValueError("Field must be bytestring.") + msg = "Field must be bytestring." + raise ValueError(msg) # noqa: TRY004 if origin is not None and stream_id is not None: - raise ValueError("Must not provide both origin and stream_id") + msg = "Must not provide both origin and stream_id" + raise ValueError(msg) self.state_machine.process_input( - ConnectionInputs.SEND_ALTERNATIVE_SERVICE + ConnectionInputs.SEND_ALTERNATIVE_SERVICE, ) if origin is not None: @@ -1145,15 +1194,18 @@ def advertise_alternative_service(self, f = AltSvcFrame(stream_id=0) f.origin = origin f.field = field_value - frames = [f] + frames: list[Frame] = [f] else: stream = self._get_stream_by_id(stream_id) frames = stream.advertise_alternative_service(field_value) self._prepare_for_sending(frames) - def prioritize(self, stream_id, weight=None, depends_on=None, - exclusive=None): + def prioritize(self, + stream_id: int, + weight: int | None = None, + depends_on: int | None = None, + exclusive: bool | None = None) -> None: """ Notify a server about the priority of a stream. @@ -1217,18 +1269,19 @@ def prioritize(self, stream_id, weight=None, depends_on=None, :type exclusive: ``bool`` """ if not self.config.client_side: - raise RFC1122Error("Servers SHOULD NOT prioritize streams.") + msg = "Servers SHOULD NOT prioritize streams." + raise RFC1122Error(msg) self.state_machine.process_input( - ConnectionInputs.SEND_PRIORITY + ConnectionInputs.SEND_PRIORITY, ) frame = PriorityFrame(stream_id) - frame = _add_frame_priority(frame, weight, depends_on, exclusive) + frame_prio = _add_frame_priority(frame, weight, depends_on, exclusive) - self._prepare_for_sending([frame]) + self._prepare_for_sending([frame_prio]) - def local_flow_control_window(self, stream_id): + def local_flow_control_window(self, stream_id: int) -> int: """ Returns the maximum amount of data that can be sent on stream ``stream_id``. @@ -1252,10 +1305,10 @@ def local_flow_control_window(self, stream_id): stream = self._get_stream_by_id(stream_id) return min( self.outbound_flow_control_window, - stream.outbound_flow_control_window + stream.outbound_flow_control_window, ) - def remote_flow_control_window(self, stream_id): + def remote_flow_control_window(self, stream_id: int) -> int: """ Returns the maximum amount of data the remote peer can send on stream ``stream_id``. @@ -1279,10 +1332,10 @@ def remote_flow_control_window(self, stream_id): stream = self._get_stream_by_id(stream_id) return min( self.inbound_flow_control_window, - stream.inbound_flow_control_window + stream.inbound_flow_control_window, ) - def acknowledge_received_data(self, acknowledged_size, stream_id): + def acknowledge_received_data(self, acknowledged_size: int, stream_id: int) -> None: """ Inform the :class:`H2Connection ` that a certain number of flow-controlled bytes have been processed, and that @@ -1302,17 +1355,16 @@ def acknowledge_received_data(self, acknowledged_size, stream_id): """ self.config.logger.debug( "Ack received data on stream ID %d with size %d", - stream_id, acknowledged_size + stream_id, acknowledged_size, ) if stream_id <= 0: - raise ValueError( - "Stream ID %d is not valid for acknowledge_received_data" % - stream_id - ) + msg = f"Stream ID {stream_id} is not valid for acknowledge_received_data" + raise ValueError(msg) if acknowledged_size < 0: - raise ValueError("Cannot acknowledge negative data") + msg = "Cannot acknowledge negative data" + raise ValueError(msg) - frames = [] + frames: list[Frame] = [] conn_manager = self._inbound_flow_control_window_manager conn_increment = conn_manager.process_bytes(acknowledged_size) @@ -1331,12 +1383,12 @@ def acknowledge_received_data(self, acknowledged_size, stream_id): # No point incrementing the windows of closed streams. if stream.open: frames.extend( - stream.acknowledge_received_data(acknowledged_size) + stream.acknowledge_received_data(acknowledged_size), ) self._prepare_for_sending(frames) - def data_to_send(self, amount=None): + def data_to_send(self, amount: int | None = None) -> bytes: """ Returns some data for sending out of the internal data buffer. @@ -1355,12 +1407,11 @@ def data_to_send(self, amount=None): data = bytes(self._data_to_send) self._data_to_send = bytearray() return data - else: - data = bytes(self._data_to_send[:amount]) - self._data_to_send = self._data_to_send[amount:] - return data + data = bytes(self._data_to_send[:amount]) + self._data_to_send = self._data_to_send[amount:] + return data - def clear_outbound_data_buffer(self): + def clear_outbound_data_buffer(self) -> None: """ Clears the outbound data buffer, such that if this call was immediately followed by a call to @@ -1372,7 +1423,7 @@ def clear_outbound_data_buffer(self): """ self._data_to_send = bytearray() - def _acknowledge_settings(self): + def _acknowledge_settings(self) -> list[Frame]: """ Acknowledge settings that have been received. @@ -1406,10 +1457,10 @@ def _acknowledge_settings(self): stream.max_outbound_frame_size = setting.new_value f = SettingsFrame(0) - f.flags.add('ACK') + f.flags.add("ACK") return [f] - def _flow_control_change_from_settings(self, old_value, new_value): + def _flow_control_change_from_settings(self, old_value: int | None, new_value: int) -> None: """ Update flow control windows in response to a change in the value of SETTINGS_INITIAL_WINDOW_SIZE. @@ -1419,15 +1470,15 @@ def _flow_control_change_from_settings(self, old_value, new_value): increment the *connection* flow control window, per section 6.9.2 of RFC 7540. """ - delta = new_value - old_value + delta = new_value - (old_value or 0) for stream in self.streams.values(): stream.outbound_flow_control_window = guard_increment_window( stream.outbound_flow_control_window, - delta + delta, ) - def _inbound_flow_control_change_from_settings(self, old_value, new_value): + def _inbound_flow_control_change_from_settings(self, old_value: int | None, new_value: int) -> None: """ Update remote flow control windows in response to a change in the value of SETTINGS_INITIAL_WINDOW_SIZE. @@ -1435,12 +1486,12 @@ def _inbound_flow_control_change_from_settings(self, old_value, new_value): When this setting is changed, it automatically updates all remote flow control windows by the delta in the settings values. """ - delta = new_value - old_value + delta = new_value - (old_value or 0) for stream in self.streams.values(): stream._inbound_flow_control_change_from_settings(delta) - def receive_data(self, data): + def receive_data(self, data: bytes) -> list[Event]: """ Pass some received HTTP/2 data to the connection for handling. @@ -1450,19 +1501,20 @@ def receive_data(self, data): this data. """ self.config.logger.trace( - "Process received data on connection. Received data: %r", data + "Process received data on connection. Received data: %r", data, ) - events = [] + events: list[Event] = [] self.incoming_buffer.add_data(data) self.incoming_buffer.max_frame_size = self.max_inbound_frame_size try: for frame in self.incoming_buffer: events.extend(self._receive_frame(frame)) - except InvalidPaddingError: + except InvalidPaddingError as e: self._terminate_connection(ErrorCodes.PROTOCOL_ERROR) - raise ProtocolError("Received frame with invalid padding.") + msg = "Received frame with invalid padding." + raise ProtocolError(msg) from e except ProtocolError as e: # For whatever reason, receiving the frame caused a protocol error. # We should prepare to emit a GoAway frame before throwing the @@ -1473,13 +1525,14 @@ def receive_data(self, data): return events - def _receive_frame(self, frame): + def _receive_frame(self, frame: Frame) -> list[Event]: """ Handle a frame received on the connection. .. versionchanged:: 2.0.0 Removed from the public API. """ + events: list[Event] self.config.logger.trace("Received frame: %s", repr(frame)) try: # I don't love using __class__ here, maybe reconsider it. @@ -1511,7 +1564,7 @@ def _receive_frame(self, frame): events = [] elif self._stream_is_closed_by_end(e.stream_id): # Closed by END_STREAM is a connection error. - raise StreamClosedError(e.stream_id) + raise StreamClosedError(e.stream_id) from e else: # Closed implicitly, also a connection error, but of type # PROTOCOL_ERROR. @@ -1521,7 +1574,7 @@ def _receive_frame(self, frame): return events - def _terminate_connection(self, error_code): + def _terminate_connection(self, error_code: ErrorCodes) -> None: """ Terminate the connection early. Used in error handling blocks to send GOAWAY frames. @@ -1532,7 +1585,7 @@ def _terminate_connection(self, error_code): self.state_machine.process_input(ConnectionInputs.SEND_GOAWAY) self._prepare_for_sending([f]) - def _receive_headers_frame(self, frame): + def _receive_headers_frame(self, frame: HeadersFrame) -> tuple[list[Frame], list[Event]]: """ Receive a headers frame on the connection. """ @@ -1541,10 +1594,8 @@ def _receive_headers_frame(self, frame): if frame.stream_id not in self.streams: max_open_streams = self.local_settings.max_concurrent_streams if (self.open_inbound_streams + 1) > max_open_streams: - raise TooManyStreamsError( - "Max outbound streams is %d, %d open" % - (max_open_streams, self.open_outbound_streams) - ) + msg = f"Max outbound streams is {max_open_streams}, {self.open_outbound_streams} open" + raise TooManyStreamsError(msg) # Let's decode the headers. We handle headers as bytes internally up # until we hang them off the event, at which point we may optionally @@ -1552,41 +1603,45 @@ def _receive_headers_frame(self, frame): headers = _decode_headers(self.decoder, frame.data) events = self.state_machine.process_input( - ConnectionInputs.RECV_HEADERS + ConnectionInputs.RECV_HEADERS, ) stream = self._get_or_create_stream( - frame.stream_id, AllowedStreamIDs(not self.config.client_side) + frame.stream_id, AllowedStreamIDs(not self.config.client_side), ) frames, stream_events = stream.receive_headers( headers, - 'END_STREAM' in frame.flags, - self.config.header_encoding + "END_STREAM" in frame.flags, + self.config.header_encoding, ) - if 'PRIORITY' in frame.flags: + if "PRIORITY" in frame.flags: p_frames, p_events = self._receive_priority_frame(frame) + expected_frame_types = (RequestReceived, ResponseReceived, TrailersReceived, InformationalResponseReceived) + assert isinstance(stream_events[0], expected_frame_types) + assert isinstance(p_events[0], PriorityUpdated) stream_events[0].priority_updated = p_events[0] stream_events.extend(p_events) assert not p_frames return frames, events + stream_events - def _receive_push_promise_frame(self, frame): + def _receive_push_promise_frame(self, frame: PushPromiseFrame) -> tuple[list[Frame], list[Event]]: """ Receive a push-promise frame on the connection. """ if not self.local_settings.enable_push: - raise ProtocolError("Received pushed stream") + msg = "Received pushed stream" + raise ProtocolError(msg) pushed_headers = _decode_headers(self.decoder, frame.data) events = self.state_machine.process_input( - ConnectionInputs.RECV_PUSH_PROMISE + ConnectionInputs.RECV_PUSH_PROMISE, ) try: stream = self._get_stream_by_id(frame.stream_id) - except NoSuchStreamError: + except NoSuchStreamError as e: # We need to check if the parent stream was reset by us. If it was # then we presume that the PUSH_PROMISE was in flight when we reset # the parent stream. Rather than accept the new stream, just reset @@ -1602,7 +1657,8 @@ def _receive_push_promise_frame(self, frame): f.error_code = ErrorCodes.REFUSED_STREAM return [f], events - raise ProtocolError("Attempted to push on closed stream.") + msg = "Attempted to push on closed stream." + raise ProtocolError(msg) from e # We need to prevent peers pushing streams in response to streams that # they themselves have already pushed: see #163 and RFC 7540 ยง 6.6. The @@ -1610,7 +1666,8 @@ def _receive_push_promise_frame(self, frame): # this shortcut works because only servers can push and the state # machine will enforce this. if (frame.stream_id % 2) == 0: - raise ProtocolError("Cannot recursively push streams.") + msg = "Cannot recursively push streams." + raise ProtocolError(msg) try: frames, stream_events = stream.receive_push_promise_in_band( @@ -1627,61 +1684,66 @@ def _receive_push_promise_frame(self, frame): return [f], events new_stream = self._begin_new_stream( - frame.promised_stream_id, AllowedStreamIDs.EVEN + frame.promised_stream_id, AllowedStreamIDs.EVEN, ) self.streams[frame.promised_stream_id] = new_stream new_stream.remotely_pushed(pushed_headers) return frames, events + stream_events - def _handle_data_on_closed_stream(self, events, exc, frame): + def _handle_data_on_closed_stream(self, + events: list[Event], + exc: StreamClosedError, + frame: DataFrame) -> tuple[list[Frame], list[Event]]: # This stream is already closed - and yet we received a DATA frame. # The received DATA frame counts towards the connection flow window. # We need to manually to acknowledge the DATA frame to update the flow # window of the connection. Otherwise the whole connection stalls due # the inbound flow window being 0. - frames = [] + frames: list[Frame] = [] conn_manager = self._inbound_flow_control_window_manager conn_increment = conn_manager.process_bytes( - frame.flow_controlled_length + frame.flow_controlled_length, ) + if conn_increment: - f = WindowUpdateFrame(0) - f.window_increment = conn_increment - frames.append(f) + window_update_frame = WindowUpdateFrame(0) + window_update_frame.window_increment = conn_increment + frames.append(window_update_frame) self.config.logger.debug( "Received DATA frame on closed stream %d - " "auto-emitted a WINDOW_UPDATE by %d", - frame.stream_id, conn_increment + frame.stream_id, conn_increment, ) - f = RstStreamFrame(exc.stream_id) - f.error_code = exc.error_code - frames.append(f) + + rst_stream_frame = RstStreamFrame(exc.stream_id) + rst_stream_frame.error_code = exc.error_code + frames.append(rst_stream_frame) self.config.logger.debug( - "Stream %d already CLOSED or cleaned up - " - "auto-emitted a RST_FRAME" % frame.stream_id + "Stream %s already CLOSED or cleaned up - auto-emitted a RST_FRAME", + frame.stream_id, ) return frames, events + exc._events - def _receive_data_frame(self, frame): + def _receive_data_frame(self, frame: DataFrame) -> tuple[list[Frame], list[Event]]: """ Receive a data frame on the connection. """ flow_controlled_length = frame.flow_controlled_length events = self.state_machine.process_input( - ConnectionInputs.RECV_DATA + ConnectionInputs.RECV_DATA, ) self._inbound_flow_control_window_manager.window_consumed( - flow_controlled_length + flow_controlled_length, ) try: stream = self._get_stream_by_id(frame.stream_id) frames, stream_events = stream.receive_data( frame.data, - 'END_STREAM' in frame.flags, - flow_controlled_length + "END_STREAM" in frame.flags, + flow_controlled_length, ) except StreamClosedError as e: # This stream is either marked as CLOSED or already gone from our @@ -1690,16 +1752,16 @@ def _receive_data_frame(self, frame): return frames, events + stream_events - def _receive_settings_frame(self, frame): + def _receive_settings_frame(self, frame: SettingsFrame) -> tuple[list[Frame], list[Event]]: """ Receive a SETTINGS frame on the connection. """ events = self.state_machine.process_input( - ConnectionInputs.RECV_SETTINGS + ConnectionInputs.RECV_SETTINGS, ) # This is an ack of the local settings. - if 'ACK' in frame.flags: + if "ACK" in frame.flags: changed_settings = self._local_settings_acked() ack_event = SettingsAcknowledged() ack_event.changed_settings = changed_settings @@ -1710,14 +1772,14 @@ def _receive_settings_frame(self, frame): self.remote_settings.update(frame.settings) events.append( RemoteSettingsChanged.from_settings( - self.remote_settings, frame.settings - ) + self.remote_settings, frame.settings, + ), ) frames = self._acknowledge_settings() return frames, events - def _receive_window_update_frame(self, frame): + def _receive_window_update_frame(self, frame: WindowUpdateFrame) -> tuple[list[Frame], list[Event]]: """ Receive a WINDOW_UPDATE frame on the connection. """ @@ -1725,14 +1787,14 @@ def _receive_window_update_frame(self, frame): # If we reach in here, we can assume a valid value. events = self.state_machine.process_input( - ConnectionInputs.RECV_WINDOW_UPDATE + ConnectionInputs.RECV_WINDOW_UPDATE, ) if frame.stream_id: try: stream = self._get_stream_by_id(frame.stream_id) frames, stream_events = stream.receive_window_update( - frame.window_increment + frame.window_increment, ) except StreamClosedError: return [], events @@ -1740,7 +1802,7 @@ def _receive_window_update_frame(self, frame): # Increment our local flow control window. self.outbound_flow_control_window = guard_increment_window( self.outbound_flow_control_window, - frame.window_increment + frame.window_increment, ) # FIXME: Should we split this into one event per active stream? @@ -1752,55 +1814,56 @@ def _receive_window_update_frame(self, frame): return frames, events + stream_events - def _receive_ping_frame(self, frame): + def _receive_ping_frame(self, frame: PingFrame) -> tuple[list[Frame], list[Event]]: """ Receive a PING frame on the connection. """ events = self.state_machine.process_input( - ConnectionInputs.RECV_PING + ConnectionInputs.RECV_PING, ) - flags = [] + frames: list[Frame] = [] - if 'ACK' in frame.flags: + evt: PingReceived | PingAckReceived + if "ACK" in frame.flags: evt = PingAckReceived() else: evt = PingReceived() # automatically ACK the PING with the same 'opaque data' f = PingFrame(0) - f.flags = {'ACK'} + f.flags.add("ACK") f.opaque_data = frame.opaque_data - flags.append(f) + frames.append(f) evt.ping_data = frame.opaque_data events.append(evt) - return flags, events + return frames, events - def _receive_rst_stream_frame(self, frame): + def _receive_rst_stream_frame(self, frame: RstStreamFrame) -> tuple[list[Frame], list[Event]]: """ Receive a RST_STREAM frame on the connection. """ events = self.state_machine.process_input( - ConnectionInputs.RECV_RST_STREAM + ConnectionInputs.RECV_RST_STREAM, ) try: stream = self._get_stream_by_id(frame.stream_id) except NoSuchStreamError: # The stream is missing. That's ok, we just do nothing here. - stream_frames = [] - stream_events = [] + stream_frames: list[Frame] = [] + stream_events: list[Event] = [] else: stream_frames, stream_events = stream.stream_reset(frame) return stream_frames, events + stream_events - def _receive_priority_frame(self, frame): + def _receive_priority_frame(self, frame: HeadersFrame | PriorityFrame) -> tuple[list[Frame], list[Event]]: """ Receive a PRIORITY frame on the connection. """ events = self.state_machine.process_input( - ConnectionInputs.RECV_PRIORITY + ConnectionInputs.RECV_PRIORITY, ) event = PriorityUpdated() @@ -1814,19 +1877,18 @@ def _receive_priority_frame(self, frame): # A stream may not depend on itself. if event.depends_on == frame.stream_id: - raise ProtocolError( - "Stream %d may not depend on itself" % frame.stream_id - ) + msg = f"Stream {frame.stream_id} may not depend on itself" + raise ProtocolError(msg) events.append(event) return [], events - def _receive_goaway_frame(self, frame): + def _receive_goaway_frame(self, frame: GoAwayFrame) -> tuple[list[Frame], list[Event]]: """ Receive a GOAWAY frame on the connection. """ events = self.state_machine.process_input( - ConnectionInputs.RECV_GOAWAY + ConnectionInputs.RECV_GOAWAY, ) # Clear the outbound data buffer: we cannot send further data now. @@ -1842,7 +1904,7 @@ def _receive_goaway_frame(self, frame): return [], events - def _receive_naked_continuation(self, frame): + def _receive_naked_continuation(self, frame: ContinuationFrame) -> None: """ A naked CONTINUATION frame has been received. This is always an error, but the type of error it is depends on the state of the stream and must @@ -1851,9 +1913,10 @@ def _receive_naked_continuation(self, frame): """ stream = self._get_stream_by_id(frame.stream_id) stream.receive_continuation() - assert False, "Should not be reachable" + msg = "Should not be reachable" # pragma: no cover + raise AssertionError(msg) # pragma: no cover - def _receive_alt_svc_frame(self, frame): + def _receive_alt_svc_frame(self, frame: AltSvcFrame) -> tuple[list[Frame], list[Event]]: """ An ALTSVC frame has been received. This frame, specified in RFC 7838, is used to advertise alternative places where the same service can be @@ -1863,7 +1926,7 @@ def _receive_alt_svc_frame(self, frame): 0, and its semantics are different in each case. """ events = self.state_machine.process_input( - ConnectionInputs.RECV_ALTERNATIVE_SERVICE + ConnectionInputs.RECV_ALTERNATIVE_SERVICE, ) frames = [] @@ -1898,7 +1961,7 @@ def _receive_alt_svc_frame(self, frame): return frames, events - def _receive_unknown_frame(self, frame): + def _receive_unknown_frame(self, frame: ExtensionFrame) -> tuple[list[Frame], list[Event]]: """ We have received a frame that we do not understand. This is almost certainly an extension frame, though it's impossible to be entirely @@ -1909,13 +1972,13 @@ def _receive_unknown_frame(self, frame): """ # All we do here is log. self.config.logger.debug( - "Received unknown extension frame (ID %d)", frame.stream_id + "Received unknown extension frame (ID %d)", frame.stream_id, ) event = UnknownFrameReceived() event.frame = frame return [], [event] - def _local_settings_acked(self): + def _local_settings_acked(self) -> dict[SettingCodes | int, ChangedSetting]: """ Handle the local settings being ACKed, update internal state. """ @@ -1944,14 +2007,14 @@ def _local_settings_acked(self): return changes - def _stream_id_is_outbound(self, stream_id): + def _stream_id_is_outbound(self, stream_id: int) -> bool: """ Returns ``True`` if the stream ID corresponds to an outbound stream (one initiated by this peer), returns ``False`` otherwise. """ return (stream_id % 2 == int(self.config.client_side)) - def _stream_closed_by(self, stream_id): + def _stream_closed_by(self, stream_id: int) -> StreamClosedBy | None: """ Returns how the stream was closed. @@ -1966,27 +2029,30 @@ def _stream_closed_by(self, stream_id): return self._closed_streams[stream_id] return None - def _stream_is_closed_by_reset(self, stream_id): + def _stream_is_closed_by_reset(self, stream_id: int) -> bool: """ Returns ``True`` if the stream was closed by sending or receiving a RST_STREAM frame. Returns ``False`` otherwise. """ return self._stream_closed_by(stream_id) in ( - StreamClosedBy.RECV_RST_STREAM, StreamClosedBy.SEND_RST_STREAM + StreamClosedBy.RECV_RST_STREAM, StreamClosedBy.SEND_RST_STREAM, ) - def _stream_is_closed_by_end(self, stream_id): + def _stream_is_closed_by_end(self, stream_id: int) -> bool: """ Returns ``True`` if the stream was closed by sending or receiving an END_STREAM flag in a HEADERS or DATA frame. Returns ``False`` otherwise. """ return self._stream_closed_by(stream_id) in ( - StreamClosedBy.RECV_END_STREAM, StreamClosedBy.SEND_END_STREAM + StreamClosedBy.RECV_END_STREAM, StreamClosedBy.SEND_END_STREAM, ) -def _add_frame_priority(frame, weight=None, depends_on=None, exclusive=None): +def _add_frame_priority(frame: PriorityFrame | HeadersFrame, + weight: int | None = None, + depends_on: int | None = None, + exclusive: bool | None = None) -> PriorityFrame | HeadersFrame: """ Adds priority data to a given frame. Does not change any flags set on that frame: if the caller is adding priority information to a HEADERS frame they @@ -1998,20 +2064,17 @@ def _add_frame_priority(frame, weight=None, depends_on=None, exclusive=None): """ # A stream may not depend on itself. if depends_on == frame.stream_id: - raise ProtocolError( - "Stream %d may not depend on itself" % frame.stream_id - ) + msg = f"Stream {frame.stream_id} may not depend on itself" + raise ProtocolError(msg) # Weight must be between 1 and 256. if weight is not None: if weight > 256 or weight < 1: - raise ProtocolError( - "Weight must be between 1 and 256, not %d" % weight - ) - else: - # Weight is an integer between 1 and 256, but the byte only allows - # 0 to 255: subtract one. - weight -= 1 + msg = f"Weight must be between 1 and 256, not {weight}" + raise ProtocolError(msg) + # Weight is an integer between 1 and 256, but the byte only allows + # 0 to 255: subtract one. + weight -= 1 # Set defaults for anything not provided. weight = weight if weight is not None else 15 @@ -2025,7 +2088,7 @@ def _add_frame_priority(frame, weight=None, depends_on=None, exclusive=None): return frame -def _decode_headers(decoder, encoded_header_block): +def _decode_headers(decoder: Decoder, encoded_header_block: bytes) -> Iterable[Header]: """ Decode a HPACK-encoded header block, translating HPACK exceptions into sensible h2 errors. @@ -2039,9 +2102,11 @@ def _decode_headers(decoder, encoded_header_block): # This is a symptom of a HPACK bomb attack: the user has # disregarded our requirements on how large a header block we'll # accept. - raise DenialOfServiceError("Oversized header block: %s" % e) + msg = f"Oversized header block: {e}" + raise DenialOfServiceError(msg) from e except (HPACKError, IndexError, TypeError, UnicodeDecodeError) as e: # We should only need HPACKError here, but versions of HPACK older # than 2.1.0 throw all three others as well. For maximum # compatibility, catch all of them. - raise ProtocolError("Error decoding header block: %s" % e) + msg = f"Error decoding header block: {e}" + raise ProtocolError(msg) from e diff --git a/src/h2/errors.py b/src/h2/errors.py index 303df597..24ebe00f 100644 --- a/src/h2/errors.py +++ b/src/h2/errors.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ h2/errors ~~~~~~~~~ @@ -8,6 +7,8 @@ The current registry is available at: https://tools.ietf.org/html/rfc7540#section-11.4 """ +from __future__ import annotations + import enum @@ -17,6 +18,7 @@ class ErrorCodes(enum.IntEnum): .. versionadded:: 2.5.0 """ + #: Graceful shutdown. NO_ERROR = 0x0 @@ -60,7 +62,7 @@ class ErrorCodes(enum.IntEnum): HTTP_1_1_REQUIRED = 0xd -def _error_code_from_int(code): +def _error_code_from_int(code: int) -> ErrorCodes | int: """ Given an integer error code, returns either one of :class:`ErrorCodes ` or, if not present in the known set of codes, @@ -72,4 +74,4 @@ def _error_code_from_int(code): return code -__all__ = ['ErrorCodes'] +__all__ = ["ErrorCodes"] diff --git a/src/h2/events.py b/src/h2/events.py index 83080614..b81fd1a6 100644 --- a/src/h2/events.py +++ b/src/h2/events.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ h2/events ~~~~~~~~~ @@ -9,16 +8,25 @@ track of events triggered by receiving data. Each time data is provided to the H2 state machine it processes the data and returns a list of Event objects. """ +from __future__ import annotations + import binascii +from typing import TYPE_CHECKING + +from .settings import ChangedSetting, SettingCodes, Settings, _setting_code_from_int + +if TYPE_CHECKING: # pragma: no cover + from hpack import HeaderTuple + from hyperframe.frame import Frame -from .settings import ChangedSetting, _setting_code_from_int + from .errors import ErrorCodes class Event: """ Base class for h2 events. """ - pass + class RequestReceived(Event): @@ -26,7 +34,7 @@ class RequestReceived(Event): The RequestReceived event is fired whenever all of a request's headers are received. This event carries the HTTP headers for the given request and the stream ID of the new stream. - + In HTTP/2, headers may be sent as a HEADERS frame followed by zero or more CONTINUATION frames with the final frame setting the END_HEADERS flag. This event is fired after the entire sequence is received. @@ -38,31 +46,30 @@ class RequestReceived(Event): .. versionchanged:: 2.4.0 Added ``stream_ended`` and ``priority_updated`` properties. """ - def __init__(self): + + def __init__(self) -> None: #: The Stream ID for the stream this request was made on. - self.stream_id = None + self.stream_id: int | None = None #: The request headers. - self.headers = None + self.headers: list[HeaderTuple] | None = None #: If this request also ended the stream, the associated #: :class:`StreamEnded ` event will be available #: here. #: #: .. versionadded:: 2.4.0 - self.stream_ended = None + self.stream_ended: StreamEnded | None = None #: If this request also had associated priority information, the #: associated :class:`PriorityUpdated ` #: event will be available here. #: #: .. versionadded:: 2.4.0 - self.priority_updated = None + self.priority_updated: PriorityUpdated | None = None - def __repr__(self): - return "" % ( - self.stream_id, self.headers - ) + def __repr__(self) -> str: + return f"" class ResponseReceived(Event): @@ -75,34 +82,33 @@ class ResponseReceived(Event): Changed the type of ``headers`` to :class:`HeaderTuple `. This has no effect on current users. - .. versionchanged:: 2.4.0 + .. versionchanged:: 2.4.0 Added ``stream_ended`` and ``priority_updated`` properties. """ - def __init__(self): + + def __init__(self) -> None: #: The Stream ID for the stream this response was made on. - self.stream_id = None + self.stream_id: int | None = None #: The response headers. - self.headers = None + self.headers: list[HeaderTuple] | None = None #: If this response also ended the stream, the associated #: :class:`StreamEnded ` event will be available #: here. #: #: .. versionadded:: 2.4.0 - self.stream_ended = None + self.stream_ended: StreamEnded | None = None #: If this response also had associated priority information, the #: associated :class:`PriorityUpdated ` #: event will be available here. #: #: .. versionadded:: 2.4.0 - self.priority_updated = None + self.priority_updated: PriorityUpdated | None = None - def __repr__(self): - return "" % ( - self.stream_id, self.headers - ) + def __repr__(self) -> str: + return f"" class TrailersReceived(Event): @@ -121,30 +127,29 @@ class TrailersReceived(Event): .. versionchanged:: 2.4.0 Added ``stream_ended`` and ``priority_updated`` properties. """ - def __init__(self): + + def __init__(self) -> None: #: The Stream ID for the stream on which these trailers were received. - self.stream_id = None + self.stream_id: int | None = None #: The trailers themselves. - self.headers = None + self.headers: list[HeaderTuple] | None = None #: Trailers always end streams. This property has the associated #: :class:`StreamEnded ` in it. #: #: .. versionadded:: 2.4.0 - self.stream_ended = None + self.stream_ended: StreamEnded | None = None #: If the trailers also set associated priority information, the #: associated :class:`PriorityUpdated ` #: event will be available here. #: #: .. versionadded:: 2.4.0 - self.priority_updated = None + self.priority_updated: PriorityUpdated | None = None - def __repr__(self): - return "" % ( - self.stream_id, self.headers - ) + def __repr__(self) -> str: + return f"" class _HeadersSent(Event): @@ -154,7 +159,7 @@ class _HeadersSent(Event): This is an internal event, used to determine validation steps on outgoing header blocks. """ - pass + class _ResponseSent(_HeadersSent): @@ -165,7 +170,7 @@ class _ResponseSent(_HeadersSent): This is an internal event, used to determine validation steps on outgoing header blocks. """ - pass + class _RequestSent(_HeadersSent): @@ -176,7 +181,7 @@ class _RequestSent(_HeadersSent): This is an internal event, used to determine validation steps on outgoing header blocks. """ - pass + class _TrailersSent(_HeadersSent): @@ -189,7 +194,7 @@ class _TrailersSent(_HeadersSent): This is an internal event, used to determine validation steps on outgoing header blocks. """ - pass + class _PushedRequestSent(_HeadersSent): @@ -200,7 +205,7 @@ class _PushedRequestSent(_HeadersSent): This is an internal event, used to determine validation steps on outgoing header blocks. """ - pass + class InformationalResponseReceived(Event): @@ -225,25 +230,24 @@ class InformationalResponseReceived(Event): .. versionchanged:: 2.4.0 Added ``priority_updated`` property. """ - def __init__(self): + + def __init__(self) -> None: #: The Stream ID for the stream this informational response was made #: on. - self.stream_id = None + self.stream_id: int | None = None #: The headers for this informational response. - self.headers = None + self.headers: list[HeaderTuple] | None = None #: If this response also had associated priority information, the #: associated :class:`PriorityUpdated ` #: event will be available here. #: #: .. versionadded:: 2.4.0 - self.priority_updated = None + self.priority_updated: PriorityUpdated | None = None - def __repr__(self): - return "" % ( - self.stream_id, self.headers - ) + def __repr__(self) -> str: + return f"" class DataReceived(Event): @@ -255,34 +259,35 @@ class DataReceived(Event): .. versionchanged:: 2.4.0 Added ``stream_ended`` property. """ - def __init__(self): + + def __init__(self) -> None: #: The Stream ID for the stream this data was received on. - self.stream_id = None + self.stream_id: int | None = None #: The data itself. - self.data = None + self.data: bytes | None = None #: The amount of data received that counts against the flow control #: window. Note that padding counts against the flow control window, so #: when adjusting flow control you should always use this field rather #: than ``len(data)``. - self.flow_controlled_length = None + self.flow_controlled_length: int | None = None #: If this data chunk also completed the stream, the associated #: :class:`StreamEnded ` event will be available #: here. #: #: .. versionadded:: 2.4.0 - self.stream_ended = None + self.stream_ended: StreamEnded | None = None - def __repr__(self): + def __repr__(self) -> str: return ( - "" % ( + "".format( self.stream_id, self.flow_controlled_length, - _bytes_representation(self.data[:20]), + _bytes_representation(self.data[:20]) if self.data else "", ) ) @@ -295,18 +300,17 @@ class WindowUpdated(Event): the stream to which it applies (set to zero if the window update applies to the connection), and the delta in the window size. """ - def __init__(self): + + def __init__(self) -> None: #: The Stream ID of the stream whose flow control window was changed. #: May be ``0`` if the connection window was changed. - self.stream_id = None + self.stream_id: int | None = None #: The window delta. - self.delta = None + self.delta: int | None = None - def __repr__(self): - return "" % ( - self.stream_id, self.delta - ) + def __repr__(self) -> str: + return f"" class RemoteSettingsChanged(Event): @@ -329,14 +333,17 @@ class RemoteSettingsChanged(Event): This is no longer the case: h2 now automatically acknowledges them. """ - def __init__(self): + + def __init__(self) -> None: #: A dictionary of setting byte to #: :class:`ChangedSetting `, representing #: the changed settings. - self.changed_settings = {} + self.changed_settings: dict[int, ChangedSetting] = {} @classmethod - def from_settings(cls, old_settings, new_settings): + def from_settings(cls, + old_settings: Settings | dict[int, int], + new_settings: dict[int, int]) -> RemoteSettingsChanged: """ Build a RemoteSettingsChanged event from a set of changed settings. @@ -347,15 +354,15 @@ def from_settings(cls, old_settings, new_settings): """ e = cls() for setting, new_value in new_settings.items(): - setting = _setting_code_from_int(setting) - original_value = old_settings.get(setting) - change = ChangedSetting(setting, original_value, new_value) - e.changed_settings[setting] = change + s = _setting_code_from_int(setting) + original_value = old_settings.get(s) + change = ChangedSetting(s, original_value, new_value) + e.changed_settings[s] = change return e - def __repr__(self): - return "" % ( + def __repr__(self) -> str: + return "".format( ", ".join(repr(cs) for cs in self.changed_settings.values()), ) @@ -368,14 +375,13 @@ class PingReceived(Event): .. versionadded:: 3.1.0 """ - def __init__(self): + + def __init__(self) -> None: #: The data included on the ping. - self.ping_data = None + self.ping_data: bytes | None = None - def __repr__(self): - return "" % ( - _bytes_representation(self.ping_data), - ) + def __repr__(self) -> str: + return f"" class PingAckReceived(Event): @@ -389,14 +395,13 @@ class PingAckReceived(Event): .. versionchanged:: 4.0.0 Removed deprecated but equivalent ``PingAcknowledged``. """ - def __init__(self): + + def __init__(self) -> None: #: The data included on the ping. - self.ping_data = None + self.ping_data: bytes | None = None - def __repr__(self): - return "" % ( - _bytes_representation(self.ping_data), - ) + def __repr__(self) -> str: + return f"" class StreamEnded(Event): @@ -405,12 +410,13 @@ class StreamEnded(Event): party. The stream may not be fully closed if it has not been closed locally, but no further data or headers should be expected on that stream. """ - def __init__(self): + + def __init__(self) -> None: #: The Stream ID of the stream that was closed. - self.stream_id = None + self.stream_id: int | None = None - def __repr__(self): - return "" % self.stream_id + def __repr__(self) -> str: + return f"" class StreamReset(Event): @@ -423,21 +429,20 @@ class StreamReset(Event): .. versionchanged:: 2.0.0 This event is now fired when h2 automatically resets a stream. """ - def __init__(self): + + def __init__(self) -> None: #: The Stream ID of the stream that was reset. - self.stream_id = None + self.stream_id: int | None = None #: The error code given. Either one of :class:`ErrorCodes #: ` or ``int`` - self.error_code = None + self.error_code: ErrorCodes | None = None #: Whether the remote peer sent a RST_STREAM or we did. self.remote_reset = True - def __repr__(self): - return "" % ( - self.stream_id, self.error_code, self.remote_reset - ) + def __repr__(self) -> str: + return f"" class PushedStreamReceived(Event): @@ -446,24 +451,21 @@ class PushedStreamReceived(Event): received from a remote peer. The event carries on it the new stream ID, the ID of the parent stream, and the request headers pushed by the remote peer. """ - def __init__(self): + + def __init__(self) -> None: #: The Stream ID of the stream created by the push. - self.pushed_stream_id = None + self.pushed_stream_id: int | None = None #: The Stream ID of the stream that the push is related to. - self.parent_stream_id = None + self.parent_stream_id: int | None = None #: The request headers, sent by the remote party in the push. - self.headers = None + self.headers: list[HeaderTuple] | None = None - def __repr__(self): + def __repr__(self) -> str: return ( - "" % ( - self.pushed_stream_id, - self.parent_stream_id, - self.headers, - ) + f"" ) @@ -474,16 +476,16 @@ class SettingsAcknowledged(Event): acknowedged, in the same format as :class:`h2.events.RemoteSettingsChanged`. """ - def __init__(self): + + def __init__(self) -> None: #: A dictionary of setting byte to #: :class:`ChangedSetting `, representing #: the changed settings. - self.changed_settings = {} + self.changed_settings: dict[SettingCodes | int, ChangedSetting] = {} - def __repr__(self): - return "" % ( - ", ".join(repr(cs) for cs in self.changed_settings.values()), - ) + def __repr__(self) -> str: + s = ", ".join(repr(cs) for cs in self.changed_settings.values()) + return f"" class PriorityUpdated(Event): @@ -496,31 +498,27 @@ class PriorityUpdated(Event): .. versionadded:: 2.0.0 """ - def __init__(self): + + def __init__(self) -> None: #: The ID of the stream whose priority information is being updated. - self.stream_id = None + self.stream_id: int | None = None #: The new stream weight. May be the same as the original stream #: weight. An integer between 1 and 256. - self.weight = None + self.weight: int | None = None #: The stream ID this stream now depends on. May be ``0``. - self.depends_on = None + self.depends_on: int | None = None #: Whether the stream *exclusively* depends on the parent stream. If it #: does, this stream should inherit the current children of its new #: parent. - self.exclusive = None + self.exclusive: bool | None = None - def __repr__(self): + def __repr__(self) -> str: return ( - "" % ( - self.stream_id, - self.weight, - self.depends_on, - self.exclusive - ) + f"" ) @@ -530,29 +528,30 @@ class ConnectionTerminated(Event): the remote peer using a GOAWAY frame. Once received, no further action may be taken on the connection: a new connection must be established. """ - def __init__(self): + + def __init__(self) -> None: #: The error code cited when tearing down the connection. Should be #: one of :class:`ErrorCodes `, but may not be if #: unknown HTTP/2 extensions are being used. - self.error_code = None + self.error_code: ErrorCodes | int | None = None #: The stream ID of the last stream the remote peer saw. This can #: provide an indication of what data, if any, never reached the remote #: peer and so can safely be resent. - self.last_stream_id = None + self.last_stream_id: int | None = None #: Additional debug data that can be appended to GOAWAY frame. - self.additional_data = None + self.additional_data: bytes | None = None - def __repr__(self): + def __repr__(self) -> str: return ( - "" % ( + "".format( self.error_code, self.last_stream_id, _bytes_representation( self.additional_data[:20] - if self.additional_data else None) + if self.additional_data else None), ) ) @@ -577,26 +576,27 @@ class AlternativeServiceAvailable(Event): .. versionadded:: 2.3.0 """ - def __init__(self): + + def __init__(self) -> None: #: The origin to which the alternative service field value applies. #: This field is either supplied by the server directly, or inferred by #: h2 from the ``:authority`` pseudo-header field that was sent #: by the user when initiating the stream on which the frame was #: received. - self.origin = None + self.origin: bytes | None = None #: The ALTSVC field value. This contains information about the HTTP #: alternative service being advertised by the server. h2 does #: not parse this field: it is left exactly as sent by the server. The #: structure of the data in this field is given by `RFC 7838 Section 3 #: `_. - self.field_value = None + self.field_value: bytes | None = None - def __repr__(self): + def __repr__(self) -> str: return ( - "" % ( - (self.origin or b'').decode('utf-8', 'ignore'), - (self.field_value or b'').decode('utf-8', 'ignore'), + "".format( + (self.origin or b"").decode("utf-8", "ignore"), + (self.field_value or b"").decode("utf-8", "ignore"), ) ) @@ -615,15 +615,16 @@ class UnknownFrameReceived(Event): .. versionadded:: 2.7.0 """ - def __init__(self): + + def __init__(self) -> None: #: The hyperframe Frame object that encapsulates the received frame. - self.frame = None + self.frame: Frame | None = None - def __repr__(self): + def __repr__(self) -> str: return "" -def _bytes_representation(data): +def _bytes_representation(data: bytes | None) -> str | None: """ Converts a bytestring into something that is safe to print on all Python platforms. @@ -635,4 +636,4 @@ def _bytes_representation(data): if data is None: return None - return binascii.hexlify(data).decode('ascii') + return binascii.hexlify(data).decode("ascii") diff --git a/src/h2/exceptions.py b/src/h2/exceptions.py index e22bebc0..e4776795 100644 --- a/src/h2/exceptions.py +++ b/src/h2/exceptions.py @@ -1,11 +1,12 @@ -# -*- coding: utf-8 -*- """ h2/exceptions ~~~~~~~~~~~~~ Exceptions for the HTTP/2 module. """ -import h2.errors +from __future__ import annotations + +from .errors import ErrorCodes class H2Error(Exception): @@ -18,16 +19,18 @@ class ProtocolError(H2Error): """ An action was attempted in violation of the HTTP/2 protocol. """ + #: The error code corresponds to this kind of Protocol Error. - error_code = h2.errors.ErrorCodes.PROTOCOL_ERROR + error_code = ErrorCodes.PROTOCOL_ERROR class FrameTooLargeError(ProtocolError): """ The frame that we tried to send or that we received was too large. """ + #: The error code corresponds to this kind of Protocol Error. - error_code = h2.errors.ErrorCodes.FRAME_SIZE_ERROR + error_code = ErrorCodes.FRAME_SIZE_ERROR class FrameDataMissingError(ProtocolError): @@ -36,8 +39,9 @@ class FrameDataMissingError(ProtocolError): .. versionadded:: 2.0.0 """ + #: The error code corresponds to this kind of Protocol Error. - error_code = h2.errors.ErrorCodes.FRAME_SIZE_ERROR + error_code = ErrorCodes.FRAME_SIZE_ERROR class TooManyStreamsError(ProtocolError): @@ -45,15 +49,16 @@ class TooManyStreamsError(ProtocolError): An attempt was made to open a stream that would lead to too many concurrent streams. """ - pass + class FlowControlError(ProtocolError): """ An attempted action violates flow control constraints. """ + #: The error code corresponds to this kind of Protocol Error. - error_code = h2.errors.ErrorCodes.FLOW_CONTROL_ERROR + error_code = ErrorCodes.FLOW_CONTROL_ERROR class StreamIDTooLowError(ProtocolError): @@ -61,17 +66,16 @@ class StreamIDTooLowError(ProtocolError): An attempt was made to open a stream that had an ID that is lower than the highest ID we have seen on this connection. """ - def __init__(self, stream_id, max_stream_id): + + def __init__(self, stream_id: int, max_stream_id: int) -> None: #: The ID of the stream that we attempted to open. self.stream_id = stream_id #: The current highest-seen stream ID. self.max_stream_id = max_stream_id - def __str__(self): - return "StreamIDTooLowError: %d is lower than %d" % ( - self.stream_id, self.max_stream_id - ) + def __str__(self) -> str: + return f"StreamIDTooLowError: {self.stream_id} is lower than {self.max_stream_id}" class NoAvailableStreamIDError(ProtocolError): @@ -81,7 +85,7 @@ class NoAvailableStreamIDError(ProtocolError): .. versionadded:: 2.0.0 """ - pass + class NoSuchStreamError(ProtocolError): @@ -92,7 +96,8 @@ class NoSuchStreamError(ProtocolError): Became a subclass of :class:`ProtocolError ` """ - def __init__(self, stream_id): + + def __init__(self, stream_id: int) -> None: #: The stream ID corresponds to the non-existent stream. self.stream_id = stream_id @@ -104,16 +109,17 @@ class StreamClosedError(NoSuchStreamError): that the stream has since been closed, and that all state relating to that stream has been removed. """ - def __init__(self, stream_id): + + def __init__(self, stream_id: int) -> None: #: The stream ID corresponds to the nonexistent stream. self.stream_id = stream_id #: The relevant HTTP/2 error code. - self.error_code = h2.errors.ErrorCodes.STREAM_CLOSED + self.error_code = ErrorCodes.STREAM_CLOSED # Any events that internal code may need to fire. Not relevant to # external users that may receive a StreamClosedError. - self._events = [] + self._events = [] # type: ignore class InvalidSettingsValueError(ProtocolError, ValueError): @@ -122,8 +128,9 @@ class InvalidSettingsValueError(ProtocolError, ValueError): .. versionadded:: 2.0.0 """ - def __init__(self, msg, error_code): - super(InvalidSettingsValueError, self).__init__(msg) + + def __init__(self, msg: str, error_code: ErrorCodes) -> None: + super().__init__(msg) self.error_code = error_code @@ -134,14 +141,13 @@ class InvalidBodyLengthError(ProtocolError): .. versionadded:: 2.0.0 """ - def __init__(self, expected, actual): + + def __init__(self, expected: int, actual: int) -> None: self.expected_length = expected self.actual_length = actual - def __str__(self): - return "InvalidBodyLengthError: Expected %d bytes, received %d" % ( - self.expected_length, self.actual_length - ) + def __str__(self) -> str: + return f"InvalidBodyLengthError: Expected {self.expected_length} bytes, received {self.actual_length}" class UnsupportedFrameError(ProtocolError): @@ -153,7 +159,7 @@ class UnsupportedFrameError(ProtocolError): .. versionchanged:: 4.0.0 Removed deprecated KeyError parent class. """ - pass + class RFC1122Error(H2Error): @@ -168,9 +174,9 @@ class RFC1122Error(H2Error): .. versionadded:: 2.4.0 """ + # shazow says I'm going to regret naming the exception this way. If that # turns out to be true, TELL HIM NOTHING. - pass class DenialOfServiceError(ProtocolError): @@ -182,6 +188,7 @@ class DenialOfServiceError(ProtocolError): .. versionadded:: 2.5.0 """ + #: The error code corresponds to this kind of #: :class:`ProtocolError ` - error_code = h2.errors.ErrorCodes.ENHANCE_YOUR_CALM + error_code = ErrorCodes.ENHANCE_YOUR_CALM diff --git a/src/h2/frame_buffer.py b/src/h2/frame_buffer.py index 785775eb..30d96e81 100644 --- a/src/h2/frame_buffer.py +++ b/src/h2/frame_buffer.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ h2/frame_buffer ~~~~~~~~~~~~~~~ @@ -6,14 +5,12 @@ A data structure that provides a way to iterate over a byte buffer in terms of frames. """ -from hyperframe.exceptions import InvalidFrameError, InvalidDataError -from hyperframe.frame import ( - Frame, HeadersFrame, ContinuationFrame, PushPromiseFrame -) +from __future__ import annotations -from .exceptions import ( - ProtocolError, FrameTooLargeError, FrameDataMissingError -) +from hyperframe.exceptions import InvalidDataError, InvalidFrameError +from hyperframe.frame import ContinuationFrame, Frame, HeadersFrame, PushPromiseFrame + +from .exceptions import FrameDataMissingError, FrameTooLargeError, ProtocolError # To avoid a DOS attack based on sending loads of continuation frames, we limit # the maximum number we're perpared to receive. In this case, we'll set the @@ -28,17 +25,18 @@ class FrameBuffer: """ - This is a data structure that expects to act as a buffer for HTTP/2 data - that allows iteraton in terms of H2 frames. + A buffer data structure for HTTP/2 data that allows iteraton in terms of + H2 frames. """ - def __init__(self, server=False): - self.data = b'' + + def __init__(self, server: bool = False) -> None: + self.data = b"" self.max_frame_size = 0 - self._preamble = b'PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n' if server else b'' + self._preamble = b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n" if server else b"" self._preamble_len = len(self._preamble) - self._headers_buffer = [] + self._headers_buffer: list[HeadersFrame | ContinuationFrame | PushPromiseFrame] = [] - def add_data(self, data): + def add_data(self, data: bytes) -> None: """ Add more data to the frame buffer. @@ -49,7 +47,8 @@ def add_data(self, data): of_which_preamble = min(self._preamble_len, data_len) if self._preamble[:of_which_preamble] != data[:of_which_preamble]: - raise ProtocolError("Invalid HTTP/2 preamble.") + msg = "Invalid HTTP/2 preamble." + raise ProtocolError(msg) data = data[of_which_preamble:] self._preamble_len -= of_which_preamble @@ -57,17 +56,15 @@ def add_data(self, data): self.data += data - def _validate_frame_length(self, length): + def _validate_frame_length(self, length: int) -> None: """ Confirm that the frame is an appropriate length. """ if length > self.max_frame_size: - raise FrameTooLargeError( - "Received overlong frame: length %d, max %d" % - (length, self.max_frame_size) - ) + msg = f"Received overlong frame: length {length}, max {self.max_frame_size}" + raise FrameTooLargeError(msg) - def _update_header_buffer(self, f): + def _update_header_buffer(self, f: Frame | None) -> Frame | None: """ Updates the internal header buffer. Returns a frame that should replace the current one. May throw exceptions if this frame is invalid. @@ -85,26 +82,29 @@ def _update_header_buffer(self, f): f.stream_id == stream_id ) if not valid_frame: - raise ProtocolError("Invalid frame during header block.") + msg = "Invalid frame during header block." + raise ProtocolError(msg) + assert isinstance(f, ContinuationFrame) # Append the frame to the buffer. self._headers_buffer.append(f) if len(self._headers_buffer) > CONTINUATION_BACKLOG: - raise ProtocolError("Too many continuation frames received.") + msg = "Too many continuation frames received." + raise ProtocolError(msg) # If this is the end of the header block, then we want to build a # mutant HEADERS frame that's massive. Use the original one we got, # then set END_HEADERS and set its data appopriately. If it's not # the end of the block, lose the current frame: we can't yield it. - if 'END_HEADERS' in f.flags: + if "END_HEADERS" in f.flags: f = self._headers_buffer[0] - f.flags.add('END_HEADERS') - f.data = b''.join(x.data for x in self._headers_buffer) + f.flags.add("END_HEADERS") + f.data = b"".join(x.data for x in self._headers_buffer) self._headers_buffer = [] else: f = None elif (isinstance(f, (HeadersFrame, PushPromiseFrame)) and - 'END_HEADERS' not in f.flags): + "END_HEADERS" not in f.flags): # This is the start of a headers block! Save the frame off and then # act like we didn't receive one. self._headers_buffer.append(f) @@ -113,26 +113,25 @@ def _update_header_buffer(self, f): return f # The methods below support the iterator protocol. - def __iter__(self): + def __iter__(self) -> FrameBuffer: return self - def __next__(self): + def __next__(self) -> Frame: # First, check that we have enough data to successfully parse the # next frame header. If not, bail. Otherwise, parse it. if len(self.data) < 9: - raise StopIteration() + raise StopIteration try: - f, length = Frame.parse_frame_header(self.data[:9]) - except (InvalidDataError, InvalidFrameError) as e: # pragma: no cover - raise ProtocolError( - "Received frame with invalid header: %s" % str(e) - ) + f, length = Frame.parse_frame_header(memoryview(self.data[:9])) + except (InvalidDataError, InvalidFrameError) as err: # pragma: no cover + msg = f"Received frame with invalid header: {err!s}" + raise ProtocolError(msg) from err # Next, check that we have enough length to parse the frame body. If # not, bail, leaving the frame header data in the buffer for next time. if len(self.data) < length + 9: - raise StopIteration() + raise StopIteration # Confirm the frame has an appropriate length. self._validate_frame_length(length) @@ -140,21 +139,23 @@ def __next__(self): # Try to parse the frame body try: f.parse_body(memoryview(self.data[9:9+length])) - except InvalidDataError: - raise ProtocolError("Received frame with non-compliant data") - except InvalidFrameError: - raise FrameDataMissingError("Frame data missing or invalid") + except InvalidDataError as err: + msg = "Received frame with non-compliant data" + raise ProtocolError(msg) from err + except InvalidFrameError as err: + msg = "Frame data missing or invalid" + raise FrameDataMissingError(msg) from err # At this point, as we know we'll use or discard the entire frame, we # can update the data. self.data = self.data[9+length:] # Pass the frame through the header buffer. - f = self._update_header_buffer(f) + new_frame = self._update_header_buffer(f) # If we got a frame we didn't understand or shouldn't yield, rather # than return None it'd be better if we just tried to get the next # frame in the sequence instead. Recurse back into ourselves to do # that. This is safe because the amount of work we have to do here is # strictly bounded by the length of the buffer. - return f if f is not None else self.__next__() + return new_frame if new_frame is not None else self.__next__() diff --git a/src/h2/py.typed b/src/h2/py.typed new file mode 100644 index 00000000..e69de29b diff --git a/src/h2/settings.py b/src/h2/settings.py index 969a162e..c1be953b 100644 --- a/src/h2/settings.py +++ b/src/h2/settings.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ h2/settings ~~~~~~~~~~~ @@ -7,14 +6,17 @@ API for manipulating HTTP/2 settings, keeping track of both the current active state of the settings and the unacknowledged future values of the settings. """ +from __future__ import annotations + import collections -from collections.abc import MutableMapping import enum +from collections.abc import Iterator, MutableMapping +from typing import Union from hyperframe.frame import SettingsFrame -from h2.errors import ErrorCodes -from h2.exceptions import InvalidSettingsValueError +from .errors import ErrorCodes +from .exceptions import InvalidSettingsValueError class SettingCodes(enum.IntEnum): @@ -55,7 +57,7 @@ class SettingCodes(enum.IntEnum): ENABLE_CONNECT_PROTOCOL = SettingsFrame.ENABLE_CONNECT_PROTOCOL -def _setting_code_from_int(code): +def _setting_code_from_int(code: int) -> SettingCodes | int: """ Given an integer setting code, returns either one of :class:`SettingCodes ` or, if not present in the known set of codes, @@ -69,7 +71,7 @@ def _setting_code_from_int(code): class ChangedSetting: - def __init__(self, setting, original_value, new_value): + def __init__(self, setting: SettingCodes | int, original_value: int | None, new_value: int) -> None: #: The setting code given. Either one of :class:`SettingCodes #: ` or ``int`` #: @@ -82,18 +84,13 @@ def __init__(self, setting, original_value, new_value): #: The new value after being changed. self.new_value = new_value - def __repr__(self): + def __repr__(self) -> str: return ( - "ChangedSetting(setting=%s, original_value=%s, " - "new_value=%s)" - ) % ( - self.setting, - self.original_value, - self.new_value + f"ChangedSetting(setting={self.setting!s}, original_value={self.original_value}, new_value={self.new_value})" ) -class Settings(MutableMapping): +class Settings(MutableMapping[Union[SettingCodes, int], int]): """ An object that encapsulates HTTP/2 settings state. @@ -128,14 +125,15 @@ class Settings(MutableMapping): set, rather than RFC 7540's defaults. :type initial_vales: ``MutableMapping`` """ - def __init__(self, client=True, initial_values=None): + + def __init__(self, client: bool = True, initial_values: dict[SettingCodes, int] | None = None) -> None: # Backing object for the settings. This is a dictionary of # (setting: [list of values]), where the first value in the list is the # current value of the setting. Strictly this doesn't use lists but # instead uses collections.deque to avoid repeated memory allocations. # # This contains the default values for HTTP/2. - self._settings = { + self._settings: dict[SettingCodes | int, collections.deque[int]] = { SettingCodes.HEADER_TABLE_SIZE: collections.deque([4096]), SettingCodes.ENABLE_PUSH: collections.deque([int(client)]), SettingCodes.INITIAL_WINDOW_SIZE: collections.deque([65535]), @@ -146,20 +144,21 @@ def __init__(self, client=True, initial_values=None): for key, value in initial_values.items(): invalid = _validate_setting(key, value) if invalid: + msg = f"Setting {key} has invalid value {value}" raise InvalidSettingsValueError( - "Setting %d has invalid value %d" % (key, value), - error_code=invalid + msg, + error_code=invalid, ) self._settings[key] = collections.deque([value]) - def acknowledge(self): + def acknowledge(self) -> dict[SettingCodes | int, ChangedSetting]: """ The settings have been acknowledged, either by the user (remote settings) or by the remote peer (local settings). :returns: A dict of {setting: ChangedSetting} that were applied. """ - changed_settings = {} + changed_settings: dict[SettingCodes | int, ChangedSetting] = {} # If there is more than one setting in the list, we have a setting # value outstanding. Update them. @@ -168,14 +167,14 @@ def acknowledge(self): old_setting = v.popleft() new_setting = v[0] changed_settings[k] = ChangedSetting( - k, old_setting, new_setting + k, old_setting, new_setting, ) return changed_settings # Provide easy-access to well known settings. @property - def header_table_size(self): + def header_table_size(self) -> int: """ The current value of the :data:`HEADER_TABLE_SIZE ` setting. @@ -183,11 +182,11 @@ def header_table_size(self): return self[SettingCodes.HEADER_TABLE_SIZE] @header_table_size.setter - def header_table_size(self, value): + def header_table_size(self, value: int) -> None: self[SettingCodes.HEADER_TABLE_SIZE] = value @property - def enable_push(self): + def enable_push(self) -> int: """ The current value of the :data:`ENABLE_PUSH ` setting. @@ -195,11 +194,11 @@ def enable_push(self): return self[SettingCodes.ENABLE_PUSH] @enable_push.setter - def enable_push(self, value): + def enable_push(self, value: int) -> None: self[SettingCodes.ENABLE_PUSH] = value @property - def initial_window_size(self): + def initial_window_size(self) -> int: """ The current value of the :data:`INITIAL_WINDOW_SIZE ` setting. @@ -207,11 +206,11 @@ def initial_window_size(self): return self[SettingCodes.INITIAL_WINDOW_SIZE] @initial_window_size.setter - def initial_window_size(self, value): + def initial_window_size(self, value: int) -> None: self[SettingCodes.INITIAL_WINDOW_SIZE] = value @property - def max_frame_size(self): + def max_frame_size(self) -> int: """ The current value of the :data:`MAX_FRAME_SIZE ` setting. @@ -219,11 +218,11 @@ def max_frame_size(self): return self[SettingCodes.MAX_FRAME_SIZE] @max_frame_size.setter - def max_frame_size(self, value): + def max_frame_size(self, value: int) -> None: self[SettingCodes.MAX_FRAME_SIZE] = value @property - def max_concurrent_streams(self): + def max_concurrent_streams(self) -> int: """ The current value of the :data:`MAX_CONCURRENT_STREAMS ` setting. @@ -231,11 +230,11 @@ def max_concurrent_streams(self): return self.get(SettingCodes.MAX_CONCURRENT_STREAMS, 2**32+1) @max_concurrent_streams.setter - def max_concurrent_streams(self, value): + def max_concurrent_streams(self, value: int) -> None: self[SettingCodes.MAX_CONCURRENT_STREAMS] = value @property - def max_header_list_size(self): + def max_header_list_size(self) -> int | None: """ The current value of the :data:`MAX_HEADER_LIST_SIZE ` setting. If not set, @@ -246,11 +245,11 @@ def max_header_list_size(self): return self.get(SettingCodes.MAX_HEADER_LIST_SIZE, None) @max_header_list_size.setter - def max_header_list_size(self, value): + def max_header_list_size(self, value: int) -> None: self[SettingCodes.MAX_HEADER_LIST_SIZE] = value @property - def enable_connect_protocol(self): + def enable_connect_protocol(self) -> int: """ The current value of the :data:`ENABLE_CONNECT_PROTOCOL ` setting. @@ -258,11 +257,11 @@ def enable_connect_protocol(self): return self[SettingCodes.ENABLE_CONNECT_PROTOCOL] @enable_connect_protocol.setter - def enable_connect_protocol(self, value): + def enable_connect_protocol(self, value: int) -> None: self[SettingCodes.ENABLE_CONNECT_PROTOCOL] = value # Implement the MutableMapping API. - def __getitem__(self, key): + def __getitem__(self, key: SettingCodes | int) -> int: val = self._settings[key][0] # Things that were created when a setting was received should stay @@ -272,45 +271,44 @@ def __getitem__(self, key): return val - def __setitem__(self, key, value): + def __setitem__(self, key: SettingCodes | int, value: int) -> None: invalid = _validate_setting(key, value) if invalid: + msg = f"Setting {key} has invalid value {value}" raise InvalidSettingsValueError( - "Setting %d has invalid value %d" % (key, value), - error_code=invalid + msg, + error_code=invalid, ) try: items = self._settings[key] except KeyError: - items = collections.deque([None]) + items = collections.deque([None]) # type: ignore self._settings[key] = items items.append(value) - def __delitem__(self, key): + def __delitem__(self, key: SettingCodes | int) -> None: del self._settings[key] - def __iter__(self): + def __iter__(self) -> Iterator[SettingCodes | int]: return self._settings.__iter__() - def __len__(self): + def __len__(self) -> int: return len(self._settings) - def __eq__(self, other): + def __eq__(self, other: object) -> bool: if isinstance(other, Settings): return self._settings == other._settings - else: - return NotImplemented + return NotImplemented - def __ne__(self, other): + def __ne__(self, other: object) -> bool: if isinstance(other, Settings): return not self == other - else: - return NotImplemented + return NotImplemented -def _validate_setting(setting, value): # noqa: C901 +def _validate_setting(setting: SettingCodes | int, value: int) -> ErrorCodes: """ Confirms that a specific setting has a well-formed value. If the setting is invalid, returns an error code. Otherwise, returns 0 (NO_ERROR). @@ -327,8 +325,7 @@ def _validate_setting(setting, value): # noqa: C901 elif setting == SettingCodes.MAX_HEADER_LIST_SIZE: if value < 0: return ErrorCodes.PROTOCOL_ERROR - elif setting == SettingCodes.ENABLE_CONNECT_PROTOCOL: - if value not in (0, 1): - return ErrorCodes.PROTOCOL_ERROR + elif setting == SettingCodes.ENABLE_CONNECT_PROTOCOL and value not in (0, 1): + return ErrorCodes.PROTOCOL_ERROR - return 0 + return ErrorCodes.NO_ERROR diff --git a/src/h2/stream.py b/src/h2/stream.py index a6b77289..7d4a12e3 100644 --- a/src/h2/stream.py +++ b/src/h2/stream.py @@ -1,35 +1,58 @@ -# -*- coding: utf-8 -*- """ h2/stream ~~~~~~~~~ An implementation of a HTTP/2 stream. """ +from __future__ import annotations + from enum import Enum, IntEnum +from typing import TYPE_CHECKING, Any + from hpack import HeaderTuple -from hyperframe.frame import ( - HeadersFrame, ContinuationFrame, DataFrame, WindowUpdateFrame, - RstStreamFrame, PushPromiseFrame, AltSvcFrame -) +from hyperframe.frame import AltSvcFrame, ContinuationFrame, DataFrame, Frame, HeadersFrame, PushPromiseFrame, RstStreamFrame, WindowUpdateFrame from .errors import ErrorCodes, _error_code_from_int from .events import ( - RequestReceived, ResponseReceived, DataReceived, WindowUpdated, - StreamEnded, PushedStreamReceived, StreamReset, TrailersReceived, - InformationalResponseReceived, AlternativeServiceAvailable, - _ResponseSent, _RequestSent, _TrailersSent, _PushedRequestSent -) -from .exceptions import ( - ProtocolError, StreamClosedError, InvalidBodyLengthError, FlowControlError + AlternativeServiceAvailable, + DataReceived, + Event, + InformationalResponseReceived, + PushedStreamReceived, + RequestReceived, + ResponseReceived, + StreamEnded, + StreamReset, + TrailersReceived, + WindowUpdated, + _PushedRequestSent, + _RequestSent, + _ResponseSent, + _TrailersSent, ) +from .exceptions import FlowControlError, InvalidBodyLengthError, ProtocolError, StreamClosedError from .utilities import ( - guard_increment_window, is_informational_response, authority_from_headers, - validate_headers, validate_outbound_headers, normalize_outbound_headers, - HeaderValidationFlags, extract_method_header, normalize_inbound_headers, - utf8_encode_headers + HeaderValidationFlags, + authority_from_headers, + extract_method_header, + guard_increment_window, + is_informational_response, + normalize_inbound_headers, + normalize_outbound_headers, + utf8_encode_headers, + validate_headers, + validate_outbound_headers, ) from .windows import WindowManager +if TYPE_CHECKING: # pragma: no cover + from collections.abc import Generator, Iterable + + from hpack.hpack import Encoder + from hpack.struct import Header, HeaderWeaklyTyped + + from .config import H2Configuration + class StreamState(IntEnum): IDLE = 0 @@ -75,7 +98,7 @@ class StreamClosedBy(Enum): # this is that we potentially check whether a stream in a given state is open # quite frequently: given that we check so often, we should do so in the # fastest and most performant way possible. -STREAM_OPEN = [False for _ in range(0, len(StreamState))] +STREAM_OPEN = [False for _ in range(len(StreamState))] STREAM_OPEN[StreamState.OPEN] = True STREAM_OPEN[StreamState.HALF_CLOSED_LOCAL] = True STREAM_OPEN[StreamState.HALF_CLOSED_REMOTE] = True @@ -91,37 +114,38 @@ class H2StreamStateMachine: :param stream_id: The stream ID of this stream. This is stored primarily for logging purposes. """ - def __init__(self, stream_id): + + def __init__(self, stream_id: int) -> None: self.state = StreamState.IDLE self.stream_id = stream_id #: Whether this peer is the client side of this stream. - self.client = None + self.client: bool | None = None # Whether trailers have been sent/received on this stream or not. - self.headers_sent = None - self.trailers_sent = None - self.headers_received = None - self.trailers_received = None + self.headers_sent: bool | None = None + self.trailers_sent: bool | None = None + self.headers_received: bool | None = None + self.trailers_received: bool | None = None # How the stream was closed. One of StreamClosedBy. - self.stream_closed_by = None + self.stream_closed_by: StreamClosedBy | None = None - def process_input(self, input_): + def process_input(self, input_: StreamInputs) -> Any: """ Process a specific input in the state machine. """ if not isinstance(input_, StreamInputs): - raise ValueError("Input must be an instance of StreamInputs") + msg = "Input must be an instance of StreamInputs" + raise ValueError(msg) # noqa: TRY004 try: func, target_state = _transitions[(self.state, input_)] - except KeyError: + except KeyError as err: old_state = self.state self.state = StreamState.CLOSED - raise ProtocolError( - "Invalid input %s in state %s" % (input_, old_state) - ) + msg = f"Invalid input {input_} in state {old_state}" + raise ProtocolError(msg) from err else: previous_state = self.state self.state = target_state @@ -131,13 +155,13 @@ def process_input(self, input_): except ProtocolError: self.state = StreamState.CLOSED raise - except AssertionError as e: # pragma: no cover + except AssertionError as err: # pragma: no cover self.state = StreamState.CLOSED - raise ProtocolError(e) + raise ProtocolError(err) from err return [] - def request_sent(self, previous_state): + def request_sent(self, previous_state: StreamState) -> list[Event]: """ Fires when a request is sent. """ @@ -147,24 +171,22 @@ def request_sent(self, previous_state): return [event] - def response_sent(self, previous_state): + def response_sent(self, previous_state: StreamState) -> list[Event]: """ Fires when something that should be a response is sent. This 'response' may actually be trailers. """ if not self.headers_sent: if self.client is True or self.client is None: - raise ProtocolError("Client cannot send responses.") + msg = "Client cannot send responses." + raise ProtocolError(msg) self.headers_sent = True - event = _ResponseSent() - else: - assert not self.trailers_sent - self.trailers_sent = True - event = _TrailersSent() - - return [event] + return [_ResponseSent()] + assert not self.trailers_sent + self.trailers_sent = True + return [_TrailersSent()] - def request_received(self, previous_state): + def request_received(self, previous_state: StreamState) -> list[Event]: """ Fires when a request is received. """ @@ -174,15 +196,15 @@ def request_received(self, previous_state): self.client = False self.headers_received = True event = RequestReceived() - event.stream_id = self.stream_id return [event] - def response_received(self, previous_state): + def response_received(self, previous_state: StreamState) -> list[Event]: """ Fires when a response is received. Also disambiguates between responses and trailers. """ + event: ResponseReceived | TrailersReceived if not self.headers_received: assert self.client is True self.headers_received = True @@ -195,17 +217,18 @@ def response_received(self, previous_state): event.stream_id = self.stream_id return [event] - def data_received(self, previous_state): + def data_received(self, previous_state: StreamState) -> list[Event]: """ Fires when data is received. """ if not self.headers_received: - raise ProtocolError("cannot receive data before headers") + msg = "cannot receive data before headers" + raise ProtocolError(msg) event = DataReceived() event.stream_id = self.stream_id return [event] - def window_updated(self, previous_state): + def window_updated(self, previous_state: StreamState) -> list[Event]: """ Fires when a window update frame is received. """ @@ -213,7 +236,7 @@ def window_updated(self, previous_state): event.stream_id = self.stream_id return [event] - def stream_half_closed(self, previous_state): + def stream_half_closed(self, previous_state: StreamState) -> list[Event]: """ Fires when an END_STREAM flag is received in the OPEN state, transitioning this stream to a HALF_CLOSED_REMOTE state. @@ -222,7 +245,7 @@ def stream_half_closed(self, previous_state): event.stream_id = self.stream_id return [event] - def stream_ended(self, previous_state): + def stream_ended(self, previous_state: StreamState) -> list[Event]: """ Fires when a stream is cleanly ended. """ @@ -231,7 +254,7 @@ def stream_ended(self, previous_state): event.stream_id = self.stream_id return [event] - def stream_reset(self, previous_state): + def stream_reset(self, previous_state: StreamState) -> list[Event]: """ Fired when a stream is forcefully reset. """ @@ -240,7 +263,7 @@ def stream_reset(self, previous_state): event.stream_id = self.stream_id return [event] - def send_new_pushed_stream(self, previous_state): + def send_new_pushed_stream(self, previous_state: StreamState) -> list[Event]: """ Fires on the newly pushed stream, when pushed by the local peer. @@ -251,7 +274,7 @@ def send_new_pushed_stream(self, previous_state): self.headers_received = True return [] - def recv_new_pushed_stream(self, previous_state): + def recv_new_pushed_stream(self, previous_state: StreamState) -> list[Event]: """ Fires on the newly pushed stream, when pushed by the remote peer. @@ -262,18 +285,19 @@ def recv_new_pushed_stream(self, previous_state): self.headers_sent = True return [] - def send_push_promise(self, previous_state): + def send_push_promise(self, previous_state: StreamState) -> list[Event]: """ Fires on the already-existing stream when a PUSH_PROMISE frame is sent. We may only send PUSH_PROMISE frames if we're a server. """ if self.client is True: - raise ProtocolError("Cannot push streams from client peers.") + msg = "Cannot push streams from client peers." + raise ProtocolError(msg) event = _PushedRequestSent() return [event] - def recv_push_promise(self, previous_state): + def recv_push_promise(self, previous_state: StreamState) -> list[Event]: """ Fires on the already-existing stream when a PUSH_PROMISE frame is received. We may only receive PUSH_PROMISE frames if we're a client. @@ -291,21 +315,21 @@ def recv_push_promise(self, previous_state): event.parent_stream_id = self.stream_id return [event] - def send_end_stream(self, previous_state): + def send_end_stream(self, previous_state: StreamState) -> None: """ Called when an attempt is made to send END_STREAM in the HALF_CLOSED_REMOTE state. """ self.stream_closed_by = StreamClosedBy.SEND_END_STREAM - def send_reset_stream(self, previous_state): + def send_reset_stream(self, previous_state: StreamState) -> None: """ Called when an attempt is made to send RST_STREAM in a non-closed stream state. """ self.stream_closed_by = StreamClosedBy.SEND_RST_STREAM - def reset_stream_on_error(self, previous_state): + def reset_stream_on_error(self, previous_state: StreamState) -> None: """ Called when we need to forcefully emit another RST_STREAM frame on behalf of the state machine. @@ -326,7 +350,7 @@ def reset_stream_on_error(self, previous_state): error._events = [event] raise error - def recv_on_closed_stream(self, previous_state): + def recv_on_closed_stream(self, previous_state: StreamState) -> None: """ Called when an unexpected frame is received on an already-closed stream. @@ -338,7 +362,7 @@ def recv_on_closed_stream(self, previous_state): """ raise StreamClosedError(self.stream_id) - def send_on_closed_stream(self, previous_state): + def send_on_closed_stream(self, previous_state: StreamState) -> None: """ Called when an attempt is made to send data on an already-closed stream. @@ -350,7 +374,7 @@ def send_on_closed_stream(self, previous_state): """ raise StreamClosedError(self.stream_id) - def recv_push_on_closed_stream(self, previous_state): + def recv_push_on_closed_stream(self, previous_state: StreamState) -> None: """ Called when a PUSH_PROMISE frame is received on a full stop stream. @@ -366,10 +390,10 @@ def recv_push_on_closed_stream(self, previous_state): if self.stream_closed_by == StreamClosedBy.SEND_RST_STREAM: raise StreamClosedError(self.stream_id) - else: - raise ProtocolError("Attempted to push on closed stream.") + msg = "Attempted to push on closed stream." + raise ProtocolError(msg) - def send_push_on_closed_stream(self, previous_state): + def send_push_on_closed_stream(self, previous_state: StreamState) -> None: """ Called when an attempt is made to push on an already-closed stream. @@ -379,9 +403,10 @@ def send_push_on_closed_stream(self, previous_state): allowed to be there. The only recourse is to tear the whole connection down. """ - raise ProtocolError("Attempted to push on closed stream.") + msg = "Attempted to push on closed stream." + raise ProtocolError(msg) - def send_informational_response(self, previous_state): + def send_informational_response(self, previous_state: StreamState) -> list[Event]: """ Called when an informational header block is sent (that is, a block where the :status header has a 1XX value). @@ -389,24 +414,26 @@ def send_informational_response(self, previous_state): Only enforces that these are sent *before* final headers are sent. """ if self.headers_sent: - raise ProtocolError("Information response after final response") + msg = "Information response after final response" + raise ProtocolError(msg) event = _ResponseSent() return [event] - def recv_informational_response(self, previous_state): + def recv_informational_response(self, previous_state: StreamState) -> list[Event]: """ Called when an informational header block is received (that is, a block where the :status header has a 1XX value). """ if self.headers_received: - raise ProtocolError("Informational response after final response") + msg = "Informational response after final response" + raise ProtocolError(msg) event = InformationalResponseReceived() event.stream_id = self.stream_id return [event] - def recv_alt_svc(self, previous_state): + def recv_alt_svc(self, previous_state: StreamState) -> list[Event]: """ Called when receiving an ALTSVC frame. @@ -446,7 +473,7 @@ def recv_alt_svc(self, previous_state): # the event and let it get populated. return [AlternativeServiceAvailable()] - def send_alt_svc(self, previous_state): + def send_alt_svc(self, previous_state: StreamState) -> None: """ Called when sending an ALTSVC frame on this stream. @@ -460,11 +487,9 @@ def send_alt_svc(self, previous_state): # We should not send ALTSVC after we've sent response headers, as the # client may have disposed of its state. if self.headers_sent: - raise ProtocolError( - "Cannot send ALTSVC after sending response headers." - ) + msg = "Cannot send ALTSVC after sending response headers." + raise ProtocolError(msg) - return # STATE MACHINE @@ -747,15 +772,16 @@ class H2Stream: Attempts to create frames that cannot be sent will raise a ``ProtocolError``. """ + def __init__(self, - stream_id, - config, - inbound_window_size, - outbound_window_size): + stream_id: int, + config: H2Configuration, + inbound_window_size: int, + outbound_window_size: int) -> None: self.state_machine = H2StreamStateMachine(stream_id) self.stream_id = stream_id - self.max_outbound_frame_size = None - self.request_method = None + self.max_outbound_frame_size: int | None = None + self.request_method: bytes | None = None # The current value of the outbound stream flow control window self.outbound_flow_control_window = outbound_window_size @@ -764,26 +790,22 @@ def __init__(self, self._inbound_window_manager = WindowManager(inbound_window_size) # The expected content length, if any. - self._expected_content_length = None + self._expected_content_length: int | None = None # The actual received content length. Always tracked. self._actual_content_length = 0 # The authority we believe this stream belongs to. - self._authority = None + self._authority: bytes | None = None # The configuration for this stream. self.config = config - def __repr__(self): - return "<%s id:%d state:%r>" % ( - type(self).__name__, - self.stream_id, - self.state_machine.state - ) + def __repr__(self) -> str: + return f"<{type(self).__name__} id:{self.stream_id} state:{self.state_machine.state!r}>" @property - def inbound_flow_control_window(self): + def inbound_flow_control_window(self) -> int: """ The size of the inbound flow control window for the stream. This is rarely publicly useful: instead, use :meth:`remote_flow_control_window @@ -793,7 +815,7 @@ def inbound_flow_control_window(self): return self._inbound_window_manager.current_window_size @property - def open(self): + def open(self) -> bool: """ Whether the stream is 'open' in any sense: that is, whether it counts against the number of concurrent streams. @@ -806,20 +828,20 @@ def open(self): return STREAM_OPEN[self.state_machine.state] @property - def closed(self): + def closed(self) -> bool: """ Whether the stream is closed. """ return self.state_machine.state == StreamState.CLOSED @property - def closed_by(self): + def closed_by(self) -> StreamClosedBy | None: """ Returns how the stream was closed, as one of StreamClosedBy. """ return self.state_machine.stream_closed_by - def upgrade(self, client_side): + def upgrade(self, client_side: bool) -> None: """ Called by the connection to indicate that this stream is the initial request/response of an upgraded connection. Places the stream into an @@ -835,9 +857,11 @@ def upgrade(self, client_side): # This may return events, we deliberately don't want them. self.state_machine.process_input(input_) - return - def send_headers(self, headers, encoder, end_stream=False): + def send_headers(self, + headers: Iterable[HeaderWeaklyTyped], + encoder: Encoder, + end_stream: bool = False) -> list[HeadersFrame | ContinuationFrame | PushPromiseFrame]: """ Returns a list of HEADERS/CONTINUATION frames to emit as either headers or trailers. @@ -853,14 +877,13 @@ def send_headers(self, headers, encoder, end_stream=False): # response. input_ = StreamInputs.SEND_HEADERS - headers = utf8_encode_headers(headers) + bytes_headers = utf8_encode_headers(headers) if ((not self.state_machine.client) and - is_informational_response(headers)): + is_informational_response(bytes_headers)): if end_stream: - raise ProtocolError( - "Cannot set END_STREAM on informational responses." - ) + msg = "Cannot set END_STREAM on informational responses." + raise ProtocolError(msg) input_ = StreamInputs.SEND_INFORMATIONAL_HEADERS @@ -869,27 +892,31 @@ def send_headers(self, headers, encoder, end_stream=False): hf = HeadersFrame(self.stream_id) hdr_validation_flags = self._build_hdr_validation_flags(events) frames = self._build_headers_frames( - headers, encoder, hf, hdr_validation_flags + bytes_headers, encoder, hf, hdr_validation_flags, ) if end_stream: # Not a bug: the END_STREAM flag is valid on the initial HEADERS # frame, not the CONTINUATION frames that follow. self.state_machine.process_input(StreamInputs.SEND_END_STREAM) - frames[0].flags.add('END_STREAM') + frames[0].flags.add("END_STREAM") if self.state_machine.trailers_sent and not end_stream: - raise ProtocolError("Trailers must have END_STREAM set.") + msg = "Trailers must have END_STREAM set." + raise ProtocolError(msg) if self.state_machine.client and self._authority is None: - self._authority = authority_from_headers(headers) + self._authority = authority_from_headers(bytes_headers) # store request method for _initialize_content_length - self.request_method = extract_method_header(headers) + self.request_method = extract_method_header(bytes_headers) return frames - def push_stream_in_band(self, related_stream_id, headers, encoder): + def push_stream_in_band(self, + related_stream_id: int, + headers: Iterable[HeaderWeaklyTyped], + encoder: Encoder) -> list[HeadersFrame | ContinuationFrame | PushPromiseFrame]: """ Returns a list of PUSH_PROMISE/CONTINUATION frames to emit as a pushed stream header. Called on the stream that has the PUSH_PROMISE frame @@ -901,19 +928,21 @@ def push_stream_in_band(self, related_stream_id, headers, encoder): # compression context, we make the state transition *first*. events = self.state_machine.process_input( - StreamInputs.SEND_PUSH_PROMISE + StreamInputs.SEND_PUSH_PROMISE, ) ppf = PushPromiseFrame(self.stream_id) ppf.promised_stream_id = related_stream_id hdr_validation_flags = self._build_hdr_validation_flags(events) - frames = self._build_headers_frames( - headers, encoder, ppf, hdr_validation_flags + + bytes_headers = utf8_encode_headers(headers) + + return self._build_headers_frames( + bytes_headers, encoder, ppf, hdr_validation_flags, ) - return frames - def locally_pushed(self): + def locally_pushed(self) -> list[Frame]: """ Mark this stream as one that was pushed by this peer. Must be called immediately after initialization. Sends no frames, simply updates the @@ -921,19 +950,22 @@ def locally_pushed(self): """ # This does not trigger any events. events = self.state_machine.process_input( - StreamInputs.SEND_PUSH_PROMISE + StreamInputs.SEND_PUSH_PROMISE, ) assert not events return [] - def send_data(self, data, end_stream=False, pad_length=None): + def send_data(self, + data: bytes | memoryview, + end_stream: bool = False, + pad_length: int | None = None) -> list[Frame]: """ Prepare some data frames. Optionally end the stream. .. warning:: Does not perform flow control checks. """ self.config.logger.debug( - "Send data on %r with end stream set to %s", self, end_stream + "Send data on %r with end stream set to %s", self, end_stream, ) self.state_machine.process_input(StreamInputs.SEND_DATA) @@ -942,9 +974,9 @@ def send_data(self, data, end_stream=False, pad_length=None): df.data = data if end_stream: self.state_machine.process_input(StreamInputs.SEND_END_STREAM) - df.flags.add('END_STREAM') + df.flags.add("END_STREAM") if pad_length is not None: - df.flags.add('PADDED') + df.flags.add("PADDED") df.pad_length = pad_length # Subtract flow_controlled_length to account for possible padding @@ -953,7 +985,7 @@ def send_data(self, data, end_stream=False, pad_length=None): return [df] - def end_stream(self): + def end_stream(self) -> list[Frame]: """ End a stream without sending data. """ @@ -961,29 +993,29 @@ def end_stream(self): self.state_machine.process_input(StreamInputs.SEND_END_STREAM) df = DataFrame(self.stream_id) - df.flags.add('END_STREAM') + df.flags.add("END_STREAM") return [df] - def advertise_alternative_service(self, field_value): + def advertise_alternative_service(self, field_value: bytes) -> list[Frame]: """ Advertise an RFC 7838 alternative service. The semantics of this are better documented in the ``H2Connection`` class. """ self.config.logger.debug( - "Advertise alternative service of %r for %r", field_value, self + "Advertise alternative service of %r for %r", field_value, self, ) self.state_machine.process_input(StreamInputs.SEND_ALTERNATIVE_SERVICE) asf = AltSvcFrame(self.stream_id) asf.field = field_value return [asf] - def increase_flow_control_window(self, increment): + def increase_flow_control_window(self, increment: int) -> list[Frame]: """ Increase the size of the flow control window for the remote side. """ self.config.logger.debug( "Increase flow control window for %r by %d", - self, increment + self, increment, ) self.state_machine.process_input(StreamInputs.SEND_WINDOW_UPDATE) self._inbound_window_manager.window_opened(increment) @@ -993,9 +1025,9 @@ def increase_flow_control_window(self, increment): return [wuf] def receive_push_promise_in_band(self, - promised_stream_id, - headers, - header_encoding): + promised_stream_id: int, + headers: Iterable[Header], + header_encoding: bool | str | None) -> tuple[list[Frame], list[Event]]: """ Receives a push promise frame sent on this stream, pushing a remote stream. This is called on the stream that has the PUSH_PROMISE sent @@ -1003,20 +1035,20 @@ def receive_push_promise_in_band(self, """ self.config.logger.debug( "Receive Push Promise on %r for remote stream %d", - self, promised_stream_id + self, promised_stream_id, ) events = self.state_machine.process_input( - StreamInputs.RECV_PUSH_PROMISE + StreamInputs.RECV_PUSH_PROMISE, ) events[0].pushed_stream_id = promised_stream_id hdr_validation_flags = self._build_hdr_validation_flags(events) events[0].headers = self._process_received_headers( - headers, hdr_validation_flags, header_encoding + headers, hdr_validation_flags, header_encoding, ) return [], events - def remotely_pushed(self, pushed_headers): + def remotely_pushed(self, pushed_headers: Iterable[Header]) -> tuple[list[Frame], list[Event]]: """ Mark this stream as one that was pushed by the remote peer. Must be called immediately after initialization. Sends no frames, simply @@ -1024,20 +1056,22 @@ def remotely_pushed(self, pushed_headers): """ self.config.logger.debug("%r pushed by remote peer", self) events = self.state_machine.process_input( - StreamInputs.RECV_PUSH_PROMISE + StreamInputs.RECV_PUSH_PROMISE, ) self._authority = authority_from_headers(pushed_headers) return [], events - def receive_headers(self, headers, end_stream, header_encoding): + def receive_headers(self, + headers: Iterable[Header], + end_stream: bool, + header_encoding: bool | str | None) -> tuple[list[Frame], list[Event]]: """ Receive a set of headers (or trailers). """ if is_informational_response(headers): if end_stream: - raise ProtocolError( - "Cannot set END_STREAM on informational responses" - ) + msg = "Cannot set END_STREAM on informational responses" + raise ProtocolError(msg) input_ = StreamInputs.RECV_INFORMATIONAL_HEADERS else: input_ = StreamInputs.RECV_HEADERS @@ -1046,30 +1080,30 @@ def receive_headers(self, headers, end_stream, header_encoding): if end_stream: es_events = self.state_machine.process_input( - StreamInputs.RECV_END_STREAM + StreamInputs.RECV_END_STREAM, ) events[0].stream_ended = es_events[0] events += es_events self._initialize_content_length(headers) - if isinstance(events[0], TrailersReceived): - if not end_stream: - raise ProtocolError("Trailers must have END_STREAM set") + if isinstance(events[0], TrailersReceived) and not end_stream: + msg = "Trailers must have END_STREAM set" + raise ProtocolError(msg) hdr_validation_flags = self._build_hdr_validation_flags(events) events[0].headers = self._process_received_headers( - headers, hdr_validation_flags, header_encoding + headers, hdr_validation_flags, header_encoding, ) return [], events - def receive_data(self, data, end_stream, flow_control_len): + def receive_data(self, data: bytes, end_stream: bool, flow_control_len: int) -> tuple[list[Frame], list[Event]]: """ Receive some data. """ self.config.logger.debug( "Receive data on %r with end stream %s and flow control length " - "set to %d", self, end_stream, flow_control_len + "set to %d", self, end_stream, flow_control_len, ) events = self.state_machine.process_input(StreamInputs.RECV_DATA) self._inbound_window_manager.window_consumed(flow_control_len) @@ -1077,7 +1111,7 @@ def receive_data(self, data, end_stream, flow_control_len): if end_stream: es_events = self.state_machine.process_input( - StreamInputs.RECV_END_STREAM + StreamInputs.RECV_END_STREAM, ) events[0].stream_ended = es_events[0] events.extend(es_events) @@ -1086,16 +1120,16 @@ def receive_data(self, data, end_stream, flow_control_len): events[0].flow_controlled_length = flow_control_len return [], events - def receive_window_update(self, increment): + def receive_window_update(self, increment: int) -> tuple[list[Frame], list[Event]]: """ Handle a WINDOW_UPDATE increment. """ self.config.logger.debug( "Receive Window Update on %r for increment of %d", - self, increment + self, increment, ) events = self.state_machine.process_input( - StreamInputs.RECV_WINDOW_UPDATE + StreamInputs.RECV_WINDOW_UPDATE, ) frames = [] @@ -1107,7 +1141,7 @@ def receive_window_update(self, increment): try: self.outbound_flow_control_window = guard_increment_window( self.outbound_flow_control_window, - increment + increment, ) except FlowControlError: # Ok, this is bad. We're going to need to perform a local @@ -1122,7 +1156,7 @@ def receive_window_update(self, increment): return frames, events - def receive_continuation(self): + def receive_continuation(self) -> None: """ A naked CONTINUATION frame has been received. This is always an error, but the type of error it is depends on the state of the stream and must @@ -1130,17 +1164,18 @@ def receive_continuation(self): """ self.config.logger.debug("Receive Continuation frame on %r", self) self.state_machine.process_input( - StreamInputs.RECV_CONTINUATION + StreamInputs.RECV_CONTINUATION, ) - assert False, "Should not be reachable" + msg = "Should not be reachable" # pragma: no cover + raise AssertionError(msg) # pragma: no cover - def receive_alt_svc(self, frame): + def receive_alt_svc(self, frame: AltSvcFrame) -> tuple[list[Frame], list[Event]]: """ An Alternative Service frame was received on the stream. This frame inherits the origin associated with this stream. """ self.config.logger.debug( - "Receive Alternative Service frame on stream %r", self + "Receive Alternative Service frame on stream %r", self, ) # If the origin is present, RFC 7838 says we have to ignore it. @@ -1148,7 +1183,7 @@ def receive_alt_svc(self, frame): return [], [] events = self.state_machine.process_input( - StreamInputs.RECV_ALTERNATIVE_SERVICE + StreamInputs.RECV_ALTERNATIVE_SERVICE, ) # There are lots of situations where we want to ignore the ALTSVC @@ -1161,12 +1196,12 @@ def receive_alt_svc(self, frame): return [], events - def reset_stream(self, error_code=0): + def reset_stream(self, error_code: ErrorCodes | int = 0) -> list[Frame]: """ Close the stream locally. Reset the stream with an error code. """ self.config.logger.debug( - "Local reset %r with error code: %d", self, error_code + "Local reset %r with error code: %d", self, error_code, ) self.state_machine.process_input(StreamInputs.SEND_RST_STREAM) @@ -1174,12 +1209,12 @@ def reset_stream(self, error_code=0): rsf.error_code = error_code return [rsf] - def stream_reset(self, frame): + def stream_reset(self, frame: RstStreamFrame) -> tuple[list[Frame], list[Event]]: """ Handle a stream being reset remotely. """ self.config.logger.debug( - "Remote reset %r with error code: %d", self, frame.error_code + "Remote reset %r with error code: %d", self, frame.error_code, ) events = self.state_machine.process_input(StreamInputs.RECV_RST_STREAM) @@ -1189,7 +1224,7 @@ def stream_reset(self, frame): return [], events - def acknowledge_received_data(self, acknowledged_size): + def acknowledge_received_data(self, acknowledged_size: int) -> list[Frame]: """ The user has informed us that they've processed some amount of data that was received on this stream. Pass that to the window manager and @@ -1197,10 +1232,10 @@ def acknowledge_received_data(self, acknowledged_size): """ self.config.logger.debug( "Acknowledge received data with size %d on %r", - acknowledged_size, self + acknowledged_size, self, ) increment = self._inbound_window_manager.process_bytes( - acknowledged_size + acknowledged_size, ) if increment: f = WindowUpdateFrame(self.stream_id) @@ -1209,44 +1244,42 @@ def acknowledge_received_data(self, acknowledged_size): return [] - def _build_hdr_validation_flags(self, events): + def _build_hdr_validation_flags(self, events: Any) -> HeaderValidationFlags: """ Constructs a set of header validation flags for use when normalizing and validating header blocks. """ is_trailer = isinstance( - events[0], (_TrailersSent, TrailersReceived) + events[0], (_TrailersSent, TrailersReceived), ) is_response_header = isinstance( events[0], ( _ResponseSent, ResponseReceived, - InformationalResponseReceived - ) + InformationalResponseReceived, + ), ) is_push_promise = isinstance( - events[0], (PushedStreamReceived, _PushedRequestSent) + events[0], (PushedStreamReceived, _PushedRequestSent), ) return HeaderValidationFlags( - is_client=self.state_machine.client, + is_client=self.state_machine.client or False, is_trailer=is_trailer, is_response_header=is_response_header, is_push_promise=is_push_promise, ) def _build_headers_frames(self, - headers, - encoder, - first_frame, - hdr_validation_flags): + headers: Iterable[Header], + encoder: Encoder, + first_frame: HeadersFrame | PushPromiseFrame, + hdr_validation_flags: HeaderValidationFlags) \ + -> list[HeadersFrame | ContinuationFrame | PushPromiseFrame]: """ Helper method to build headers or push promise frames. """ - - headers = utf8_encode_headers(headers) - # We need to lowercase the header names, and to ensure that secure # header fields are kept out of compression contexts. if self.config.normalize_outbound_headers: @@ -1255,11 +1288,11 @@ def _build_headers_frames(self, should_split_outbound_cookies = self.config.split_outbound_cookies headers = normalize_outbound_headers( - headers, hdr_validation_flags, should_split_outbound_cookies + headers, hdr_validation_flags, should_split_outbound_cookies, ) if self.config.validate_outbound_headers: headers = validate_outbound_headers( - headers, hdr_validation_flags + headers, hdr_validation_flags, ) encoded_headers = encoder.encode(headers) @@ -1268,13 +1301,13 @@ def _build_headers_frames(self, # it only works right because we never send padded frames or priority # information on the frames. Revisit this if we do. header_blocks = [ - encoded_headers[i:i+self.max_outbound_frame_size] + encoded_headers[i:i+(self.max_outbound_frame_size or 0)] for i in range( - 0, len(encoded_headers), self.max_outbound_frame_size + 0, len(encoded_headers), (self.max_outbound_frame_size or 0), ) ] - frames = [] + frames: list[HeadersFrame | ContinuationFrame | PushPromiseFrame] = [] first_frame.data = header_blocks[0] frames.append(first_frame) @@ -1283,13 +1316,13 @@ def _build_headers_frames(self, cf.data = block frames.append(cf) - frames[-1].flags.add('END_HEADERS') + frames[-1].flags.add("END_HEADERS") return frames def _process_received_headers(self, - headers, - header_validation_flags, - header_encoding): + headers: Iterable[Header], + header_validation_flags: HeaderValidationFlags, + header_encoding: bool | str | None) -> Iterable[Header]: """ When headers have been received from the remote peer, run a processing pipeline on them to transform them into the appropriate form for @@ -1297,41 +1330,40 @@ def _process_received_headers(self, """ if self.config.normalize_inbound_headers: headers = normalize_inbound_headers( - headers, header_validation_flags + headers, header_validation_flags, ) if self.config.validate_inbound_headers: headers = validate_headers(headers, header_validation_flags) - if header_encoding: + if isinstance(header_encoding, str): headers = _decode_headers(headers, header_encoding) # The above steps are all generators, so we need to concretize the # headers now. return list(headers) - def _initialize_content_length(self, headers): + def _initialize_content_length(self, headers: Iterable[Header]) -> None: """ Checks the headers for a content-length header and initializes the _expected_content_length field from it. It's not an error for no Content-Length header to be present. """ - if self.request_method == b'HEAD': + if self.request_method == b"HEAD": self._expected_content_length = 0 return for n, v in headers: - if n == b'content-length': + if n == b"content-length": try: self._expected_content_length = int(v, 10) - except ValueError: - raise ProtocolError( - f"Invalid content-length header: {repr(v)}" - ) + except ValueError as err: + msg = f"Invalid content-length header: {v!r}" + raise ProtocolError(msg) from err return - def _track_content_length(self, length, end_stream): + def _track_content_length(self, length: int, end_stream: bool) -> None: """ Update the expected content length in response to data being received. Validates that the appropriate amount of data is sent. Always updates @@ -1352,7 +1384,7 @@ def _track_content_length(self, length, end_stream): if end_stream and expected != actual: raise InvalidBodyLengthError(expected, actual) - def _inbound_flow_control_change_from_settings(self, delta): + def _inbound_flow_control_change_from_settings(self, delta: int) -> None: """ We changed SETTINGS_INITIAL_WINDOW_SIZE, which means we need to update the target window size for flow control. For our flow control @@ -1365,7 +1397,7 @@ def _inbound_flow_control_change_from_settings(self, delta): self._inbound_window_manager.max_window_size = new_max_size -def _decode_headers(headers, encoding): +def _decode_headers(headers: Iterable[HeaderWeaklyTyped], encoding: str) -> Generator[HeaderTuple, None, None]: """ Given an iterable of header two-tuples and an encoding, decodes those headers using that encoding while preserving the type of the header tuple. @@ -1377,6 +1409,9 @@ def _decode_headers(headers, encoding): assert isinstance(header, HeaderTuple) name, value = header - name = name.decode(encoding) - value = value.decode(encoding) - yield header.__class__(name, value) + assert isinstance(name, bytes) + assert isinstance(value, bytes) + + n = name.decode(encoding) + v = value.decode(encoding) + yield header.__class__(n, v) diff --git a/src/h2/utilities.py b/src/h2/utilities.py index 1d60f1bd..8cafdbd5 100644 --- a/src/h2/utilities.py +++ b/src/h2/utilities.py @@ -1,73 +1,80 @@ -# -*- coding: utf-8 -*- """ h2/utilities ~~~~~~~~~~~~ Utility functions that do not belong in a separate module. """ +from __future__ import annotations + import collections import re from string import whitespace +from typing import TYPE_CHECKING, Any, NamedTuple + +from hpack.struct import HeaderTuple, NeverIndexedHeaderTuple -from hpack import HeaderTuple, NeverIndexedHeaderTuple +from .exceptions import FlowControlError, ProtocolError -from .exceptions import ProtocolError, FlowControlError +if TYPE_CHECKING: # pragma: no cover + from collections.abc import Generator, Iterable + from hpack.struct import Header, HeaderWeaklyTyped UPPER_RE = re.compile(b"[A-Z]") -SIGIL = ord(b':') -INFORMATIONAL_START = ord(b'1') +SIGIL = ord(b":") +INFORMATIONAL_START = ord(b"1") # A set of headers that are hop-by-hop or connection-specific and thus # forbidden in HTTP/2. This list comes from RFC 7540 ยง 8.1.2.2. CONNECTION_HEADERS = frozenset([ - b'connection', - b'proxy-connection', - b'keep-alive', - b'transfer-encoding', - b'upgrade', + b"connection", + b"proxy-connection", + b"keep-alive", + b"transfer-encoding", + b"upgrade", ]) _ALLOWED_PSEUDO_HEADER_FIELDS = frozenset([ - b':method', - b':scheme', - b':authority', - b':path', - b':status', - b':protocol', + b":method", + b":scheme", + b":authority", + b":path", + b":status", + b":protocol", ]) _SECURE_HEADERS = frozenset([ # May have basic credentials which are vulnerable to dictionary attacks. - b'authorization', - b'proxy-authorization', + b"authorization", + b"proxy-authorization", ]) _REQUEST_ONLY_HEADERS = frozenset([ - b':scheme', - b':path', - b':authority', - b':method', - b':protocol', + b":scheme", + b":path", + b":authority", + b":method", + b":protocol", ]) -_RESPONSE_ONLY_HEADERS = frozenset([b':status']) +_RESPONSE_ONLY_HEADERS = frozenset([b":status"]) # A Set of pseudo headers that are only valid if the method is # CONNECT, see RFC 8441 ยง 5 -_CONNECT_REQUEST_ONLY_HEADERS = frozenset([b':protocol']) +_CONNECT_REQUEST_ONLY_HEADERS = frozenset([b":protocol"]) _WHITESPACE = frozenset(map(ord, whitespace)) -def _secure_headers(headers, hdr_validation_flags): +def _secure_headers(headers: Iterable[Header], + hdr_validation_flags: HeaderValidationFlags | None) -> Generator[Header, None, None]: """ Certain headers are at risk of being attacked during the header compression phase, and so need to be kept out of header compression contexts. This @@ -86,24 +93,26 @@ def _secure_headers(headers, hdr_validation_flags): and nghttp2. """ for header in headers: - if header[0] in _SECURE_HEADERS: - yield NeverIndexedHeaderTuple(*header) - elif header[0] == b'cookie' and len(header[1]) < 20: - yield NeverIndexedHeaderTuple(*header) + assert isinstance(header[0], bytes) + if header[0] in _SECURE_HEADERS or (header[0] in b"cookie" and len(header[1]) < 20): + yield NeverIndexedHeaderTuple(header[0], header[1]) else: yield header -def extract_method_header(headers): +def extract_method_header(headers: Iterable[Header]) -> bytes | None: """ Extracts the request method from the headers list. """ for k, v in headers: - if k == b':method': + if isinstance(v, bytes) and k == b":method": return v + if isinstance(v, str) and k == ":method": + return v.encode("utf-8") # pragma: no cover + return None -def is_informational_response(headers): +def is_informational_response(headers: Iterable[Header]) -> bool: """ Searches headers list for a :status header to confirm that a given collection of headers are an informational response. Assumes the header @@ -115,19 +124,17 @@ def is_informational_response(headers): :returns: A boolean indicating if this is an informational response. """ for n, v in headers: - if not isinstance(n, bytes) or not isinstance(v, bytes): - raise ProtocolError(f"header not bytes: {n=:r}, {v=:r}") # pragma: no cover - - if not n.startswith(b':'): + if not n.startswith(b":"): return False - if n != b':status': + if n != b":status": # If we find a non-special header, we're done here: stop looping. continue # If the first digit is a 1, we've got informational headers. - return v.startswith(b'1') + return v.startswith(b"1") + return False -def guard_increment_window(current, increment): +def guard_increment_window(current: int, increment: int) -> int: """ Increments a flow control window, guarding against that window becoming too large. @@ -138,20 +145,18 @@ def guard_increment_window(current, increment): :raises: ``FlowControlError`` """ # The largest value the flow control window may take. - LARGEST_FLOW_CONTROL_WINDOW = 2**31 - 1 + LARGEST_FLOW_CONTROL_WINDOW = 2**31 - 1 # noqa: N806 new_size = current + increment if new_size > LARGEST_FLOW_CONTROL_WINDOW: - raise FlowControlError( - "May not increment flow control window past %d" % - LARGEST_FLOW_CONTROL_WINDOW - ) + msg = f"May not increment flow control window past {LARGEST_FLOW_CONTROL_WINDOW}" + raise FlowControlError(msg) return new_size -def authority_from_headers(headers): +def authority_from_headers(headers: Iterable[Header]) -> bytes | None: """ Given a header set, searches for the authority header and returns the value. @@ -165,7 +170,7 @@ def authority_from_headers(headers): :rtype: ``bytes`` or ``None``. """ for n, v in headers: - if n == b':authority': + if n == b":authority": return v return None @@ -173,13 +178,14 @@ def authority_from_headers(headers): # Flags used by the validate_headers pipeline to determine which checks # should be applied to a given set of headers. -HeaderValidationFlags = collections.namedtuple( - 'HeaderValidationFlags', - ['is_client', 'is_trailer', 'is_response_header', 'is_push_promise'] -) +class HeaderValidationFlags(NamedTuple): + is_client: bool + is_trailer: bool + is_response_header: bool + is_push_promise: bool -def validate_headers(headers, hdr_validation_flags): +def validate_headers(headers: Iterable[Header], hdr_validation_flags: HeaderValidationFlags) -> Iterable[Header]: """ Validates a header sequence against a set of constraints from RFC 7540. @@ -196,32 +202,32 @@ def validate_headers(headers, hdr_validation_flags): # fixed cost that we don't want to spend, instead indexing into the header # tuples. headers = _reject_empty_header_names( - headers, hdr_validation_flags + headers, hdr_validation_flags, ) headers = _reject_uppercase_header_fields( - headers, hdr_validation_flags + headers, hdr_validation_flags, ) headers = _reject_surrounding_whitespace( - headers, hdr_validation_flags + headers, hdr_validation_flags, ) headers = _reject_te( - headers, hdr_validation_flags + headers, hdr_validation_flags, ) headers = _reject_connection_header( - headers, hdr_validation_flags + headers, hdr_validation_flags, ) headers = _reject_pseudo_header_fields( - headers, hdr_validation_flags + headers, hdr_validation_flags, ) headers = _check_host_authority_header( - headers, hdr_validation_flags + headers, hdr_validation_flags, ) - headers = _check_path_header(headers, hdr_validation_flags) + return _check_path_header(headers, hdr_validation_flags) - return headers -def _reject_empty_header_names(headers, hdr_validation_flags): +def _reject_empty_header_names(headers: Iterable[Header], + hdr_validation_flags: HeaderValidationFlags) -> Generator[Header, None, None]: """ Raises a ProtocolError if any header names are empty (length 0). While hpack decodes such headers without errors, they are semantically @@ -230,24 +236,26 @@ def _reject_empty_header_names(headers, hdr_validation_flags): """ for header in headers: if len(header[0]) == 0: - raise ProtocolError("Received header name with zero length.") + msg = "Received header name with zero length." + raise ProtocolError(msg) yield header -def _reject_uppercase_header_fields(headers, hdr_validation_flags): +def _reject_uppercase_header_fields(headers: Iterable[Header], + hdr_validation_flags: HeaderValidationFlags) -> Generator[Header, None, None]: """ Raises a ProtocolError if any uppercase character is found in a header block. """ for header in headers: if UPPER_RE.search(header[0]): - raise ProtocolError( - f"Received uppercase header name {repr(header[0])}." - ) + msg = f"Received uppercase header name {header[0]!r}." + raise ProtocolError(msg) yield header -def _reject_surrounding_whitespace(headers, hdr_validation_flags): +def _reject_surrounding_whitespace(headers: Iterable[Header], + hdr_validation_flags: HeaderValidationFlags) -> Generator[Header, None, None]: """ Raises a ProtocolError if any header name or value is surrounded by whitespace characters. @@ -259,58 +267,55 @@ def _reject_surrounding_whitespace(headers, hdr_validation_flags): # doesn't. for header in headers: if header[0][0] in _WHITESPACE or header[0][-1] in _WHITESPACE: - raise ProtocolError( - "Received header name surrounded by whitespace %r" % header[0]) + msg = f"Received header name surrounded by whitespace {header[0]!r}" + raise ProtocolError(msg) if header[1] and ((header[1][0] in _WHITESPACE) or (header[1][-1] in _WHITESPACE)): - raise ProtocolError( - "Received header value surrounded by whitespace %r" % header[1] - ) + msg = f"Received header value surrounded by whitespace {header[1]!r}" + raise ProtocolError(msg) yield header -def _reject_te(headers, hdr_validation_flags): +def _reject_te(headers: Iterable[Header], hdr_validation_flags: HeaderValidationFlags) -> Generator[Header, None, None]: """ Raises a ProtocolError if the TE header is present in a header block and its value is anything other than "trailers". """ for header in headers: - if header[0] == b'te': - if header[1].lower() != b'trailers': - raise ProtocolError( - f"Invalid value for TE header: {repr(header[1])}" - ) + if header[0] == b"te" and header[1].lower() != b"trailers": + msg = f"Invalid value for TE header: {header[1]!r}" + raise ProtocolError(msg) yield header -def _reject_connection_header(headers, hdr_validation_flags): +def _reject_connection_header(headers: Iterable[Header], hdr_validation_flags: HeaderValidationFlags) -> Generator[Header, None, None]: """ Raises a ProtocolError if the Connection header is present in a header block. """ for header in headers: if header[0] in CONNECTION_HEADERS: - raise ProtocolError( - f"Connection-specific header field present: {repr(header[0])}." - ) + msg = f"Connection-specific header field present: {header[0]!r}." + raise ProtocolError(msg) yield header -def _assert_header_in_set(bytes_header, header_set): +def _assert_header_in_set(bytes_header: bytes, + header_set: set[bytes | str] | set[bytes] | set[str]) -> None: """ Given a set of header names, checks whether the string or byte version of the header name is present. Raises a Protocol error with the appropriate error if it's missing. """ if bytes_header not in header_set: - raise ProtocolError( - f"Header block missing mandatory {repr(bytes_header)} header" - ) + msg = f"Header block missing mandatory {bytes_header!r} header" + raise ProtocolError(msg) -def _reject_pseudo_header_fields(headers, hdr_validation_flags): +def _reject_pseudo_header_fields(headers: Iterable[Header], + hdr_validation_flags: HeaderValidationFlags) -> Generator[Header, None, None]: """ Raises a ProtocolError if duplicate pseudo-header fields are found in a header block or if a pseudo-header field appears in a block after an @@ -325,23 +330,20 @@ def _reject_pseudo_header_fields(headers, hdr_validation_flags): for header in headers: if header[0][0] == SIGIL: if header[0] in seen_pseudo_header_fields: - raise ProtocolError( - f"Received duplicate pseudo-header field {repr(header[0])}" - ) + msg = f"Received duplicate pseudo-header field {header[0]!r}" + raise ProtocolError(msg) seen_pseudo_header_fields.add(header[0]) if seen_regular_header: - raise ProtocolError( - f"Received pseudo-header field out of sequence: {repr(header[0])}" - ) + msg = f"Received pseudo-header field out of sequence: {header[0]!r}" + raise ProtocolError(msg) if header[0] not in _ALLOWED_PSEUDO_HEADER_FIELDS: - raise ProtocolError( - f"Received custom pseudo-header field {repr(header[0])}" - ) + msg = f"Received custom pseudo-header field {header[0]!r}" + raise ProtocolError(msg) - if header[0] in b':method': + if header[0] in b":method": method = header[1] else: @@ -351,22 +353,21 @@ def _reject_pseudo_header_fields(headers, hdr_validation_flags): # Check the pseudo-headers we got to confirm they're acceptable. _check_pseudo_header_field_acceptability( - seen_pseudo_header_fields, method, hdr_validation_flags + seen_pseudo_header_fields, method, hdr_validation_flags, ) -def _check_pseudo_header_field_acceptability(pseudo_headers, - method, - hdr_validation_flags): +def _check_pseudo_header_field_acceptability(pseudo_headers: set[bytes | str] | set[bytes] | set[str], + method: bytes | None, + hdr_validation_flags: HeaderValidationFlags) -> None: """ Given the set of pseudo-headers present in a header block and the validation flags, confirms that RFC 7540 allows them. """ # Pseudo-header fields MUST NOT appear in trailers - RFC 7540 ยง 8.1.2.1 if hdr_validation_flags.is_trailer and pseudo_headers: - raise ProtocolError( - "Received pseudo-header in trailer %s" % pseudo_headers - ) + msg = f"Received pseudo-header in trailer {pseudo_headers}" + raise ProtocolError(msg) # If ':status' pseudo-header is not there in a response header, reject it. # Similarly, if ':path', ':method', or ':scheme' are not there in a request @@ -375,35 +376,30 @@ def _check_pseudo_header_field_acceptability(pseudo_headers, # Relevant RFC section: RFC 7540 ยง 8.1.2.4 # https://tools.ietf.org/html/rfc7540#section-8.1.2.4 if hdr_validation_flags.is_response_header: - _assert_header_in_set(b':status', pseudo_headers) + _assert_header_in_set(b":status", pseudo_headers) invalid_response_headers = pseudo_headers & _REQUEST_ONLY_HEADERS if invalid_response_headers: - raise ProtocolError( - "Encountered request-only headers %s" % - invalid_response_headers - ) + msg = f"Encountered request-only headers {invalid_response_headers}" + raise ProtocolError(msg) elif (not hdr_validation_flags.is_response_header and not hdr_validation_flags.is_trailer): # This is a request, so we need to have seen :path, :method, and # :scheme. - _assert_header_in_set(b':path', pseudo_headers) - _assert_header_in_set(b':method', pseudo_headers) - _assert_header_in_set(b':scheme', pseudo_headers) + _assert_header_in_set(b":path", pseudo_headers) + _assert_header_in_set(b":method", pseudo_headers) + _assert_header_in_set(b":scheme", pseudo_headers) invalid_request_headers = pseudo_headers & _RESPONSE_ONLY_HEADERS if invalid_request_headers: - raise ProtocolError( - "Encountered response-only headers %s" % - invalid_request_headers - ) - if method != b'CONNECT': + msg = f"Encountered response-only headers {invalid_request_headers}" + raise ProtocolError(msg) + if method != b"CONNECT": invalid_headers = pseudo_headers & _CONNECT_REQUEST_ONLY_HEADERS if invalid_headers: - raise ProtocolError( - f"Encountered connect-request-only headers {repr(invalid_headers)}" - ) + msg = f"Encountered connect-request-only headers {invalid_headers!r}" + raise ProtocolError(msg) -def _validate_host_authority_header(headers): +def _validate_host_authority_header(headers: Iterable[Header]) -> Generator[Header, None, None]: """ Given the :authority and Host headers from a request block that isn't a trailer, check that: @@ -424,9 +420,9 @@ def _validate_host_authority_header(headers): host_header_val = None for header in headers: - if header[0] == b':authority': + if header[0] == b":authority": authority_header_val = header[1] - elif header[0] == b'host': + elif header[0] == b"host": host_header_val = header[1] yield header @@ -439,21 +435,20 @@ def _validate_host_authority_header(headers): # It is an error for a request header block to contain neither # an :authority header nor a Host header. if not authority_present and not host_present: - raise ProtocolError( - "Request header block does not have an :authority or Host header." - ) + msg = "Request header block does not have an :authority or Host header." + raise ProtocolError(msg) # If we receive both headers, they should definitely match. - if authority_present and host_present: - if authority_header_val != host_header_val: - raise ProtocolError( - "Request header block has mismatched :authority and " - "Host headers: %r / %r" - % (authority_header_val, host_header_val) - ) + if authority_present and host_present and authority_header_val != host_header_val: + msg = ( + "Request header block has mismatched :authority and " + f"Host headers: {authority_header_val!r} / {host_header_val!r}" + ) + raise ProtocolError(msg) -def _check_host_authority_header(headers, hdr_validation_flags): +def _check_host_authority_header(headers: Iterable[Header], + hdr_validation_flags: HeaderValidationFlags) -> Generator[Header, None, None]: """ Raises a ProtocolError if a header block arrives that does not contain an :authority or a Host header, or if a header block contains both fields, @@ -467,21 +462,22 @@ def _check_host_authority_header(headers, hdr_validation_flags): hdr_validation_flags.is_trailer ) if skip_validation: - return headers + return (h for h in headers) return _validate_host_authority_header(headers) -def _check_path_header(headers, hdr_validation_flags): +def _check_path_header(headers: Iterable[Header], + hdr_validation_flags: HeaderValidationFlags) -> Generator[Header, None, None]: """ Raise a ProtocolError if a header block arrives or is sent that contains an empty :path header. """ - def inner(): + def inner() -> Generator[Header, None, None]: for header in headers: - if header[0] == b':path': - if not header[1]: - raise ProtocolError("An empty :path header is forbidden") + if header[0] == b":path" and not header[1]: + msg = "An empty :path header is forbidden" + raise ProtocolError(msg) yield header @@ -493,28 +489,27 @@ def inner(): hdr_validation_flags.is_trailer ) if skip_validation: - return headers - else: - return inner() + return (h for h in headers) + return inner() -def _to_bytes(v): +def _to_bytes(v: bytes | str) -> bytes: """ Given an assumed `str` (or anything that supports `.encode()`), encodes it using utf-8 into bytes. Returns the unmodified object if it is already a `bytes` object. """ - return v if isinstance(v, bytes) else v.encode('utf-8') + return v if isinstance(v, bytes) else v.encode("utf-8") -def utf8_encode_headers(headers): +def utf8_encode_headers(headers: Iterable[HeaderWeaklyTyped]) -> list[Header]: """ Given an iterable of header two-tuples, rebuilds that as a list with the header names and values encoded as utf-8 bytes. This function produces tuples that preserve the original type of the header tuple for tuple and any ``HeaderTuple``. """ - encoded_headers = [] + encoded_headers: list[Header] = [] for header in headers: h = (_to_bytes(header[0]), _to_bytes(header[1])) if isinstance(header, HeaderTuple): @@ -524,7 +519,8 @@ def utf8_encode_headers(headers): return encoded_headers -def _lowercase_header_names(headers, hdr_validation_flags): +def _lowercase_header_names(headers: Iterable[Header], + hdr_validation_flags: HeaderValidationFlags | None) -> Generator[Header, None, None]: """ Given an iterable of header two-tuples, rebuilds that iterable with the header names lowercased. This generator produces tuples that preserve the @@ -537,7 +533,8 @@ def _lowercase_header_names(headers, hdr_validation_flags): yield (header[0].lower(), header[1]) -def _strip_surrounding_whitespace(headers, hdr_validation_flags): +def _strip_surrounding_whitespace(headers: Iterable[Header], + hdr_validation_flags: HeaderValidationFlags | None) -> Generator[Header, None, None]: """ Given an iterable of header two-tuples, strip both leading and trailing whitespace from both header names and header values. This generator @@ -551,7 +548,8 @@ def _strip_surrounding_whitespace(headers, hdr_validation_flags): yield (header[0].strip(), header[1].strip()) -def _strip_connection_headers(headers, hdr_validation_flags): +def _strip_connection_headers(headers: Iterable[Header], + hdr_validation_flags: HeaderValidationFlags | None) -> Generator[Header, None, None]: """ Strip any connection headers as per RFC7540 ยง 8.1.2.2. """ @@ -560,7 +558,8 @@ def _strip_connection_headers(headers, hdr_validation_flags): yield header -def _check_sent_host_authority_header(headers, hdr_validation_flags): +def _check_sent_host_authority_header(headers: Iterable[Header], + hdr_validation_flags: HeaderValidationFlags) -> Generator[Header, None, None]: """ Raises an InvalidHeaderBlockError if we try to send a header block that does not contain an :authority or a Host header, or if @@ -574,12 +573,12 @@ def _check_sent_host_authority_header(headers, hdr_validation_flags): hdr_validation_flags.is_trailer ) if skip_validation: - return headers + return (h for h in headers) return _validate_host_authority_header(headers) -def _combine_cookie_fields(headers, hdr_validation_flags): +def _combine_cookie_fields(headers: Iterable[Header], hdr_validation_flags: HeaderValidationFlags) -> Generator[Header, None, None]: """ RFC 7540 ยง 8.1.2.5 allows HTTP/2 clients to split the Cookie header field, which must normally appear only once, into multiple fields for better @@ -591,18 +590,19 @@ def _combine_cookie_fields(headers, hdr_validation_flags): # possible that all these cookies are sent with different header indexing # values. At this point it shouldn't matter too much, so we apply our own # logic and make them never-indexed. - cookies = [] + cookies: list[bytes] = [] for header in headers: - if header[0] == b'cookie': + if header[0] == b"cookie": cookies.append(header[1]) else: yield header if cookies: - cookie_val = b'; '.join(cookies) - yield NeverIndexedHeaderTuple(b'cookie', cookie_val) + cookie_val = b"; ".join(cookies) + yield NeverIndexedHeaderTuple(b"cookie", cookie_val) -def _split_outbound_cookie_fields(headers, hdr_validation_flags): +def _split_outbound_cookie_fields(headers: Iterable[Header], + hdr_validation_flags: HeaderValidationFlags | None) -> Generator[Header, None, None]: """ RFC 7540 ยง 8.1.2.5 allows for better compression efficiency, to split the Cookie header field into separate header fields @@ -611,8 +611,10 @@ def _split_outbound_cookie_fields(headers, hdr_validation_flags): inbound. """ for header in headers: - if header[0] == b'cookie': - for cookie_val in header[1].split(b'; '): + assert isinstance(header[0], bytes) + assert isinstance(header[1], bytes) + if header[0] == b"cookie": + for cookie_val in header[1].split(b"; "): if isinstance(header, HeaderTuple): yield header.__class__(header[0], cookie_val) else: @@ -621,7 +623,9 @@ def _split_outbound_cookie_fields(headers, hdr_validation_flags): yield header -def normalize_outbound_headers(headers, hdr_validation_flags, should_split_outbound_cookies): +def normalize_outbound_headers(headers: Iterable[Header], + hdr_validation_flags: HeaderValidationFlags | None, + should_split_outbound_cookies: bool=False) -> Generator[Header, None, None]: """ Normalizes a header sequence that we are about to send. @@ -634,23 +638,23 @@ def normalize_outbound_headers(headers, hdr_validation_flags, should_split_outbo headers = _split_outbound_cookie_fields(headers, hdr_validation_flags) headers = _strip_surrounding_whitespace(headers, hdr_validation_flags) headers = _strip_connection_headers(headers, hdr_validation_flags) - headers = _secure_headers(headers, hdr_validation_flags) + return _secure_headers(headers, hdr_validation_flags) - return headers -def normalize_inbound_headers(headers, hdr_validation_flags): +def normalize_inbound_headers(headers: Iterable[Header], + hdr_validation_flags: HeaderValidationFlags) -> Generator[Header, None, None]: """ Normalizes a header sequence that we have received. :param headers: The HTTP header set. :param hdr_validation_flags: An instance of HeaderValidationFlags """ - headers = _combine_cookie_fields(headers, hdr_validation_flags) - return headers + return _combine_cookie_fields(headers, hdr_validation_flags) -def validate_outbound_headers(headers, hdr_validation_flags): +def validate_outbound_headers(headers: Iterable[Header], + hdr_validation_flags: HeaderValidationFlags) -> Generator[Header, None, None]: """ Validates and normalizes a header sequence that we are about to send. @@ -658,36 +662,35 @@ def validate_outbound_headers(headers, hdr_validation_flags): :param hdr_validation_flags: An instance of HeaderValidationFlags. """ headers = _reject_te( - headers, hdr_validation_flags + headers, hdr_validation_flags, ) headers = _reject_connection_header( - headers, hdr_validation_flags + headers, hdr_validation_flags, ) headers = _reject_pseudo_header_fields( - headers, hdr_validation_flags + headers, hdr_validation_flags, ) headers = _check_sent_host_authority_header( - headers, hdr_validation_flags + headers, hdr_validation_flags, ) - headers = _check_path_header(headers, hdr_validation_flags) + return _check_path_header(headers, hdr_validation_flags) - return headers -class SizeLimitDict(collections.OrderedDict): +class SizeLimitDict(collections.OrderedDict[int, Any]): - def __init__(self, *args, **kwargs): + def __init__(self, *args: dict[int, int], **kwargs: Any) -> None: self._size_limit = kwargs.pop("size_limit", None) - super(SizeLimitDict, self).__init__(*args, **kwargs) + super().__init__(*args, **kwargs) self._check_size_limit() - def __setitem__(self, key, value): - super(SizeLimitDict, self).__setitem__(key, value) + def __setitem__(self, key: int, value: Any | int) -> None: + super().__setitem__(key, value) self._check_size_limit() - def _check_size_limit(self): + def _check_size_limit(self) -> None: if self._size_limit is not None: while len(self) > self._size_limit: self.popitem(last=False) diff --git a/src/h2/windows.py b/src/h2/windows.py index be4eb438..0efdd9fe 100644 --- a/src/h2/windows.py +++ b/src/h2/windows.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ h2/windows ~~~~~~~~~~ @@ -12,11 +11,10 @@ to manage the flow control window without user input, trying to ensure that it does not emit too many WINDOW_UPDATE frames. """ -from __future__ import division +from __future__ import annotations from .exceptions import FlowControlError - # The largest acceptable value for a HTTP/2 flow control window. LARGEST_FLOW_CONTROL_WINDOW = 2**31 - 1 @@ -28,13 +26,14 @@ class WindowManager: :param max_window_size: The maximum size of the flow control window. :type max_window_size: ``int`` """ - def __init__(self, max_window_size): + + def __init__(self, max_window_size: int) -> None: assert max_window_size <= LARGEST_FLOW_CONTROL_WINDOW self.max_window_size = max_window_size self.current_window_size = max_window_size self._bytes_processed = 0 - def window_consumed(self, size): + def window_consumed(self, size: int) -> None: """ We have received a certain number of bytes from the remote peer. This necessarily shrinks the flow control window! @@ -47,9 +46,10 @@ def window_consumed(self, size): """ self.current_window_size -= size if self.current_window_size < 0: - raise FlowControlError("Flow control window shrunk below 0") + msg = "Flow control window shrunk below 0" + raise FlowControlError(msg) - def window_opened(self, size): + def window_opened(self, size: int) -> None: """ The flow control window has been incremented, either because of manual flow control management or because of the user changing the flow @@ -67,15 +67,12 @@ def window_opened(self, size): self.current_window_size += size if self.current_window_size > LARGEST_FLOW_CONTROL_WINDOW: - raise FlowControlError( - "Flow control window mustn't exceed %d" % - LARGEST_FLOW_CONTROL_WINDOW - ) + msg = f"Flow control window mustn't exceed {LARGEST_FLOW_CONTROL_WINDOW}" + raise FlowControlError(msg) - if self.current_window_size > self.max_window_size: - self.max_window_size = self.current_window_size + self.max_window_size = max(self.current_window_size, self.max_window_size) - def process_bytes(self, size): + def process_bytes(self, size: int) -> int | None: """ The application has informed us that it has processed a certain number of bytes. This may cause us to want to emit a window update frame. If @@ -92,7 +89,7 @@ def process_bytes(self, size): self._bytes_processed += size return self._maybe_update_window() - def _maybe_update_window(self): + def _maybe_update_window(self) -> int | None: """ Run the algorithm. @@ -127,11 +124,8 @@ def _maybe_update_window(self): # Note that, even though we may increment less than _bytes_processed, # we still want to set it to zero whenever we emit an increment. This # is because we'll always increment up to the maximum we can. - if (self.current_window_size == 0) and ( - self._bytes_processed > min(1024, self.max_window_size // 4)): - increment = min(self._bytes_processed, max_increment) - self._bytes_processed = 0 - elif self._bytes_processed >= (self.max_window_size // 2): + if ((self.current_window_size == 0) and ( + self._bytes_processed > min(1024, self.max_window_size // 4))) or self._bytes_processed >= (self.max_window_size // 2): increment = min(self._bytes_processed, max_increment) self._bytes_processed = 0 diff --git a/tests/conftest.py b/tests/conftest.py index b29c28ce..acb395c4 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,4 +1,5 @@ -# -*- coding: utf-8 -*- +from __future__ import annotations + import pytest from . import helpers diff --git a/tests/coroutine_tests.py b/tests/coroutine_tests.py index 0f48c02d..42a10d73 100644 --- a/tests/coroutine_tests.py +++ b/tests/coroutine_tests.py @@ -1,8 +1,4 @@ -# -*- coding: utf-8 -*- """ -coroutine_tests -~~~~~~~~~~~~~~~ - This file gives access to a coroutine-based test class. This allows each test case to be defined as a pair of interacting coroutines, sending data to each other by yielding the flow of control. @@ -13,13 +9,15 @@ makes them behave identically on all platforms, as well as ensuring they both succeed and fail quickly. """ -import itertools +from __future__ import annotations + import functools +import itertools import pytest -class CoroutineTestCase(object): +class CoroutineTestCase: """ A base class for tests that use interacting coroutines. @@ -29,7 +27,8 @@ class CoroutineTestCase(object): its first action is to receive data), the calling code should prime it by using the 'server' decorator on this class. """ - def run_until_complete(self, *coroutines): + + def run_until_complete(self, *coroutines) -> None: """ Executes a set of coroutines that communicate between each other. Each one is, in order, passed the output of the previous coroutine until @@ -56,7 +55,7 @@ def run_until_complete(self, *coroutines): except StopIteration: continue else: - pytest.fail("Coroutine %s not exhausted" % coro) + pytest.fail(f"Coroutine {coro} not exhausted") def server(self, func): """ diff --git a/tests/h2spectest.sh b/tests/h2spectest.sh index 02e38d0c..37e7dfa9 100755 --- a/tests/h2spectest.sh +++ b/tests/h2spectest.sh @@ -1,4 +1,5 @@ #!/usr/bin/env bash + # A test script that runs the example Python Twisted server and then runs # h2spec against it. Prints the output of h2spec. This script does not expect # to be run directly, but instead via `tox -e h2spec`. diff --git a/tests/helpers.py b/tests/helpers.py index 2a4e9093..a23a79ec 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -1,17 +1,22 @@ -# -*- coding: utf-8 -*- """ -helpers -~~~~~~~ - -This module contains helpers for the h2 tests. +Helper module for the h2 tests. """ +from __future__ import annotations + +from hpack.hpack import Encoder from hyperframe.frame import ( - HeadersFrame, DataFrame, SettingsFrame, WindowUpdateFrame, PingFrame, - GoAwayFrame, RstStreamFrame, PushPromiseFrame, PriorityFrame, - ContinuationFrame, AltSvcFrame + AltSvcFrame, + ContinuationFrame, + DataFrame, + GoAwayFrame, + HeadersFrame, + PingFrame, + PriorityFrame, + PushPromiseFrame, + RstStreamFrame, + SettingsFrame, + WindowUpdateFrame, ) -from hpack.hpack import Encoder - SAMPLE_SETTINGS = { SettingsFrame.HEADER_TABLE_SIZE: 4096, @@ -20,32 +25,35 @@ } -class FrameFactory(object): +class FrameFactory: """ A class containing lots of helper methods and state to build frames. This allows test cases to easily build correct HTTP/2 frames to feed to hyper-h2. """ - def __init__(self): + + def __init__(self) -> None: self.encoder = Encoder() - def refresh_encoder(self): + def refresh_encoder(self) -> None: self.encoder = Encoder() - def preamble(self): - return b'PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n' + def preamble(self) -> bytes: + return b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n" def build_headers_frame(self, headers, - flags=[], + flags=None, stream_id=1, **priority_kwargs): """ Builds a single valid headers frame out of the contained headers. """ + if flags is None: + flags = [] f = HeadersFrame(stream_id) f.data = self.encoder.encode(headers) - f.flags.add('END_HEADERS') + f.flags.add("END_HEADERS") for flag in flags: f.flags.add(flag) @@ -54,10 +62,12 @@ def build_headers_frame(self, return f - def build_continuation_frame(self, header_block, flags=[], stream_id=1): + def build_continuation_frame(self, header_block, flags=None, stream_id=1): """ Builds a single continuation frame out of the binary header block. """ + if flags is None: + flags = [] f = ContinuationFrame(stream_id) f.data = header_block f.flags = set(flags) @@ -74,7 +84,7 @@ def build_data_frame(self, data, flags=None, stream_id=1, padding_len=0): f.flags = flags if padding_len: - flags.add('PADDED') + flags.add("PADDED") f.pad_length = padding_len return f @@ -85,7 +95,7 @@ def build_settings_frame(self, settings, ack=False): """ f = SettingsFrame(0) if ack: - f.flags.add('ACK') + f.flags.add("ACK") f.settings = settings return f @@ -112,7 +122,7 @@ def build_ping_frame(self, ping_data, flags=None): def build_goaway_frame(self, last_stream_id, error_code=0, - additional_data=b''): + additional_data=b""): """ Builds a single GOAWAY frame. """ @@ -134,15 +144,17 @@ def build_push_promise_frame(self, stream_id, promised_stream_id, headers, - flags=[]): + flags=None): """ Builds a single PUSH_PROMISE frame. """ + if flags is None: + flags = [] f = PushPromiseFrame(stream_id) f.promised_stream_id = promised_stream_id f.data = self.encoder.encode(headers) f.flags = set(flags) - f.flags.add('END_HEADERS') + f.flags.add("END_HEADERS") return f def build_priority_frame(self, @@ -168,7 +180,7 @@ def build_alt_svc_frame(self, stream_id, origin, field): f.field = field return f - def change_table_size(self, new_size): + def change_table_size(self, new_size) -> None: """ Causes the encoder to send a dynamic size update in the next header block it sends. diff --git a/tests/test_basic_logic.py b/tests/test_basic_logic.py index 61f57a24..0ea7e47b 100644 --- a/tests/test_basic_logic.py +++ b/tests/test_basic_logic.py @@ -1,15 +1,15 @@ -# -*- coding: utf-8 -*- """ -test_basic_logic -~~~~~~~~~~~~~~~~ - Test the basic logic of the h2 state machines. """ +from __future__ import annotations + import random import hyperframe import pytest from hpack import HeaderTuple +from hypothesis import HealthCheck, given, settings +from hypothesis.strategies import integers import h2.config import h2.connection @@ -22,52 +22,50 @@ from . import helpers -from hypothesis import given, settings, HealthCheck -from hypothesis.strategies import integers - -class TestBasicClient(object): +class TestBasicClient: """ Basic client-side tests. """ + example_request_headers = [ - (u':authority', u'example.com'), - (u':path', u'/'), - (u':scheme', u'https'), - (u':method', u'GET'), + (":authority", "example.com"), + (":path", "/"), + (":scheme", "https"), + (":method", "GET"), ] bytes_example_request_headers = [ - (b':authority', b'example.com'), - (b':path', b'/'), - (b':scheme', b'https'), - (b':method', b'GET'), + (b":authority", b"example.com"), + (b":path", b"/"), + (b":scheme", b"https"), + (b":method", b"GET"), ] example_response_headers = [ - (u':status', u'200'), - (u'server', u'fake-serv/0.1.0') + (":status", "200"), + ("server", "fake-serv/0.1.0"), ] bytes_example_response_headers = [ - (b':status', b'200'), - (b'server', b'fake-serv/0.1.0') + (b":status", b"200"), + (b"server", b"fake-serv/0.1.0"), ] - def test_begin_connection(self, frame_factory): + def test_begin_connection(self, frame_factory) -> None: """ Client connections emit the HTTP/2 preamble. """ c = h2.connection.H2Connection() expected_settings = frame_factory.build_settings_frame( - c.local_settings + c.local_settings, ) expected_data = ( - b'PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n' + expected_settings.serialize() + b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n" + expected_settings.serialize() ) events = c.initiate_connection() assert not events assert c.data_to_send() == expected_data - def test_sending_headers(self): + def test_sending_headers(self) -> None: """ Single headers frames are correctly encoded. """ @@ -79,11 +77,11 @@ def test_sending_headers(self): events = c.send_headers(1, self.example_request_headers) assert not events assert c.data_to_send() == ( - b'\x00\x00\r\x01\x04\x00\x00\x00\x01' - b'A\x88/\x91\xd3]\x05\\\x87\xa7\x84\x87\x82' + b"\x00\x00\r\x01\x04\x00\x00\x00\x01" + b"A\x88/\x91\xd3]\x05\\\x87\xa7\x84\x87\x82" ) - def test_sending_data(self): + def test_sending_data(self) -> None: """ Single data frames are encoded correctly. """ @@ -93,25 +91,25 @@ def test_sending_data(self): # Clear the data, then send some data. c.clear_outbound_data_buffer() - events = c.send_data(1, b'some data') + events = c.send_data(1, b"some data") assert not events data_to_send = c.data_to_send() assert ( - data_to_send == b'\x00\x00\t\x00\x00\x00\x00\x00\x01some data' + data_to_send == b"\x00\x00\t\x00\x00\x00\x00\x00\x01some data" ) buffer = h2.frame_buffer.FrameBuffer(server=False) buffer.max_frame_size = 65535 buffer.add_data(data_to_send) - data_frame = list(buffer)[0] + data_frame = next(iter(buffer)) sanity_check_data_frame( data_frame=data_frame, - expected_flow_controlled_length=len(b'some data'), + expected_flow_controlled_length=len(b"some data"), expect_padded_flag=False, - expected_data_frame_pad_length=0 + expected_data_frame_pad_length=0, ) - def test_sending_data_in_memoryview(self): + def test_sending_data_in_memoryview(self) -> None: """ Support memoryview for sending data. """ @@ -121,14 +119,14 @@ def test_sending_data_in_memoryview(self): # Clear the data, then send some data. c.clear_outbound_data_buffer() - events = c.send_data(1, memoryview(b'some data')) + events = c.send_data(1, memoryview(b"some data")) assert not events data_to_send = c.data_to_send() assert ( - data_to_send == b'\x00\x00\t\x00\x00\x00\x00\x00\x01some data' + data_to_send == b"\x00\x00\t\x00\x00\x00\x00\x00\x01some data" ) - def test_sending_data_with_padding(self): + def test_sending_data_with_padding(self) -> None: """ Single data frames with padding are encoded correctly. """ @@ -138,26 +136,26 @@ def test_sending_data_with_padding(self): # Clear the data, then send some data. c.clear_outbound_data_buffer() - events = c.send_data(1, b'some data', pad_length=5) + events = c.send_data(1, b"some data", pad_length=5) assert not events data_to_send = c.data_to_send() assert data_to_send == ( - b'\x00\x00\x0f\x00\x08\x00\x00\x00\x01' - b'\x05some data\x00\x00\x00\x00\x00' + b"\x00\x00\x0f\x00\x08\x00\x00\x00\x01" + b"\x05some data\x00\x00\x00\x00\x00" ) buffer = h2.frame_buffer.FrameBuffer(server=False) buffer.max_frame_size = 65535 buffer.add_data(data_to_send) - data_frame = list(buffer)[0] + data_frame = next(iter(buffer)) sanity_check_data_frame( data_frame=data_frame, - expected_flow_controlled_length=len(b'some data') + 1 + 5, + expected_flow_controlled_length=len(b"some data") + 1 + 5, expect_padded_flag=True, - expected_data_frame_pad_length=5 + expected_data_frame_pad_length=5, ) - def test_sending_data_with_zero_length_padding(self): + def test_sending_data_with_zero_length_padding(self) -> None: """ Single data frames with zero-length padding are encoded correctly. @@ -168,36 +166,36 @@ def test_sending_data_with_zero_length_padding(self): # Clear the data, then send some data. c.clear_outbound_data_buffer() - events = c.send_data(1, b'some data', pad_length=0) + events = c.send_data(1, b"some data", pad_length=0) assert not events data_to_send = c.data_to_send() assert data_to_send == ( - b'\x00\x00\x0a\x00\x08\x00\x00\x00\x01' - b'\x00some data' + b"\x00\x00\x0a\x00\x08\x00\x00\x00\x01" + b"\x00some data" ) buffer = h2.frame_buffer.FrameBuffer(server=False) buffer.max_frame_size = 65535 buffer.add_data(data_to_send) - data_frame = list(buffer)[0] + data_frame = next(iter(buffer)) sanity_check_data_frame( data_frame=data_frame, - expected_flow_controlled_length=len(b'some data') + 1, + expected_flow_controlled_length=len(b"some data") + 1, expect_padded_flag=True, - expected_data_frame_pad_length=0 + expected_data_frame_pad_length=0, ) - @pytest.mark.parametrize("expected_error,pad_length", [ + @pytest.mark.parametrize(("expected_error", "pad_length"), [ (None, 0), (None, 255), (None, None), (ValueError, -1), (ValueError, 256), - (TypeError, 'invalid'), - (TypeError, ''), - (TypeError, '10'), + (TypeError, "invalid"), + (TypeError, ""), + (TypeError, "10"), (TypeError, {}), - (TypeError, ['1', '2', '3']), + (TypeError, ["1", "2", "3"]), (TypeError, []), (TypeError, 1.5), (TypeError, 1.0), @@ -205,7 +203,7 @@ def test_sending_data_with_zero_length_padding(self): ]) def test_sending_data_with_invalid_padding_length(self, expected_error, - pad_length): + pad_length) -> None: """ ``send_data`` with a ``pad_length`` parameter that is an integer outside the range of [0, 255] throws a ``ValueError``, and a @@ -219,11 +217,11 @@ def test_sending_data_with_invalid_padding_length(self, c.clear_outbound_data_buffer() if expected_error is not None: with pytest.raises(expected_error): - c.send_data(1, b'some data', pad_length=pad_length) + c.send_data(1, b"some data", pad_length=pad_length) else: - c.send_data(1, b'some data', pad_length=pad_length) + c.send_data(1, b"some data", pad_length=pad_length) - def test_closing_stream_sending_data(self, frame_factory): + def test_closing_stream_sending_data(self, frame_factory) -> None: """ We can close a stream with a data frame. """ @@ -232,28 +230,28 @@ def test_closing_stream_sending_data(self, frame_factory): c.send_headers(1, self.example_request_headers) f = frame_factory.build_data_frame( - data=b'some data', - flags=['END_STREAM'], + data=b"some data", + flags=["END_STREAM"], ) # Clear the data, then send some data. c.clear_outbound_data_buffer() - events = c.send_data(1, b'some data', end_stream=True) + events = c.send_data(1, b"some data", end_stream=True) assert not events assert c.data_to_send() == f.serialize() - def test_receiving_a_response(self, frame_factory): + def test_receiving_a_response(self, frame_factory) -> None: """ When receiving a response, the ResponseReceived event fires. """ - config = h2.config.H2Configuration(header_encoding='utf-8') + config = h2.config.H2Configuration(header_encoding="utf-8") c = h2.connection.H2Connection(config=config) c.initiate_connection() c.send_headers(1, self.example_request_headers, end_stream=True) # Clear the data f = frame_factory.build_headers_frame( - self.example_response_headers + self.example_response_headers, ) events = c.receive_data(f.serialize()) @@ -264,7 +262,7 @@ def test_receiving_a_response(self, frame_factory): assert event.stream_id == 1 assert event.headers == self.example_response_headers - def test_receiving_a_response_bytes(self, frame_factory): + def test_receiving_a_response_bytes(self, frame_factory) -> None: """ When receiving a response, the ResponseReceived event fires with bytes headers if the encoding is set appropriately. @@ -276,7 +274,7 @@ def test_receiving_a_response_bytes(self, frame_factory): # Clear the data f = frame_factory.build_headers_frame( - self.example_response_headers + self.example_response_headers, ) events = c.receive_data(f.serialize()) @@ -287,7 +285,7 @@ def test_receiving_a_response_bytes(self, frame_factory): assert event.stream_id == 1 assert event.headers == self.bytes_example_response_headers - def test_receiving_a_response_change_encoding(self, frame_factory): + def test_receiving_a_response_change_encoding(self, frame_factory) -> None: """ When receiving a response, the ResponseReceived event fires with bytes headers if the encoding is set appropriately, but if this changes then @@ -299,7 +297,7 @@ def test_receiving_a_response_change_encoding(self, frame_factory): c.send_headers(1, self.example_request_headers, end_stream=True) f = frame_factory.build_headers_frame( - self.example_response_headers + self.example_response_headers, ) events = c.receive_data(f.serialize()) @@ -311,7 +309,7 @@ def test_receiving_a_response_change_encoding(self, frame_factory): assert event.headers == self.bytes_example_response_headers c.send_headers(3, self.example_request_headers, end_stream=True) - c.config.header_encoding = 'utf-8' + c.config.header_encoding = "utf-8" f = frame_factory.build_headers_frame( self.example_response_headers, stream_id=3, @@ -325,7 +323,7 @@ def test_receiving_a_response_change_encoding(self, frame_factory): assert event.stream_id == 3 assert event.headers == self.example_response_headers - def test_end_stream_without_data(self, frame_factory): + def test_end_stream_without_data(self, frame_factory) -> None: """ Ending a stream without data emits a zero-length DATA frame with END_STREAM set. @@ -336,13 +334,13 @@ def test_end_stream_without_data(self, frame_factory): # Clear the data c.clear_outbound_data_buffer() - f = frame_factory.build_data_frame(b'', flags=['END_STREAM']) + f = frame_factory.build_data_frame(b"", flags=["END_STREAM"]) events = c.end_stream(1) assert not events assert c.data_to_send() == f.serialize() - def test_cannot_send_headers_on_lower_stream_id(self): + def test_cannot_send_headers_on_lower_stream_id(self) -> None: """ Once stream ID x has been used, cannot use stream ID y where y < x. """ @@ -356,30 +354,30 @@ def test_cannot_send_headers_on_lower_stream_id(self): assert e.value.stream_id == 1 assert e.value.max_stream_id == 3 - def test_receiving_pushed_stream(self, frame_factory): + def test_receiving_pushed_stream(self, frame_factory) -> None: """ Pushed streams fire a PushedStreamReceived event, followed by ResponseReceived when the response headers are received. """ - config = h2.config.H2Configuration(header_encoding='utf-8') + config = h2.config.H2Configuration(header_encoding="utf-8") c = h2.connection.H2Connection(config=config) c.initiate_connection() c.send_headers(1, self.example_request_headers, end_stream=False) f1 = frame_factory.build_headers_frame( - self.example_response_headers + self.example_response_headers, ) f2 = frame_factory.build_push_promise_frame( stream_id=1, promised_stream_id=2, headers=self.example_request_headers, - flags=['END_HEADERS'], + flags=["END_HEADERS"], ) f3 = frame_factory.build_headers_frame( self.example_response_headers, stream_id=2, ) - data = b''.join(x.serialize() for x in [f1, f2, f3]) + data = b"".join(x.serialize() for x in [f1, f2, f3]) events = c.receive_data(data) @@ -397,7 +395,7 @@ def test_receiving_pushed_stream(self, frame_factory): assert response_event.stream_id == 2 assert response_event.headers == self.example_response_headers - def test_receiving_pushed_stream_bytes(self, frame_factory): + def test_receiving_pushed_stream_bytes(self, frame_factory) -> None: """ Pushed headers are not decoded if the header encoding is set to False. """ @@ -407,19 +405,19 @@ def test_receiving_pushed_stream_bytes(self, frame_factory): c.send_headers(1, self.example_request_headers, end_stream=False) f1 = frame_factory.build_headers_frame( - self.example_response_headers + self.example_response_headers, ) f2 = frame_factory.build_push_promise_frame( stream_id=1, promised_stream_id=2, headers=self.example_request_headers, - flags=['END_HEADERS'], + flags=["END_HEADERS"], ) f3 = frame_factory.build_headers_frame( self.example_response_headers, stream_id=2, ) - data = b''.join(x.serialize() for x in [f1, f2, f3]) + data = b"".join(x.serialize() for x in [f1, f2, f3]) events = c.receive_data(data) @@ -438,7 +436,7 @@ def test_receiving_pushed_stream_bytes(self, frame_factory): assert response_event.headers == self.bytes_example_response_headers def test_cannot_receive_pushed_stream_when_enable_push_is_0(self, - frame_factory): + frame_factory) -> None: """ If we have set SETTINGS_ENABLE_PUSH to 0, receiving PUSH_PROMISE frames triggers the connection to be closed. @@ -450,13 +448,13 @@ def test_cannot_receive_pushed_stream_when_enable_push_is_0(self, f1 = frame_factory.build_settings_frame({}, ack=True) f2 = frame_factory.build_headers_frame( - self.example_response_headers + self.example_response_headers, ) f3 = frame_factory.build_push_promise_frame( stream_id=1, promised_stream_id=2, headers=self.example_request_headers, - flags=['END_HEADERS'], + flags=["END_HEADERS"], ) c.receive_data(f1.serialize()) c.receive_data(f2.serialize()) @@ -466,11 +464,11 @@ def test_cannot_receive_pushed_stream_when_enable_push_is_0(self, c.receive_data(f3.serialize()) expected_frame = frame_factory.build_goaway_frame( - 0, h2.errors.ErrorCodes.PROTOCOL_ERROR + 0, h2.errors.ErrorCodes.PROTOCOL_ERROR, ) assert c.data_to_send() == expected_frame.serialize() - def test_receiving_response_no_body(self, frame_factory): + def test_receiving_response_no_body(self, frame_factory) -> None: """ Receiving a response without a body fires two events, ResponseReceived and StreamEnded. @@ -481,7 +479,7 @@ def test_receiving_response_no_body(self, frame_factory): f = frame_factory.build_headers_frame( self.example_response_headers, - flags=['END_STREAM'] + flags=["END_STREAM"], ) events = c.receive_data(f.serialize()) @@ -492,22 +490,22 @@ def test_receiving_response_no_body(self, frame_factory): assert isinstance(response_event, h2.events.ResponseReceived) assert isinstance(end_stream, h2.events.StreamEnded) - def test_oversize_headers(self): + def test_oversize_headers(self) -> None: """ Sending headers that are oversized generates a stream of CONTINUATION frames. """ - all_bytes = [chr(x).encode('latin1') for x in range(0, 256)] + all_bytes = [chr(x).encode("latin1") for x in range(256)] - large_binary_string = b''.join( - random.choice(all_bytes) for _ in range(0, 256) + large_binary_string = b"".join( + random.choice(all_bytes) for _ in range(256) ) test_headers = [ - (':authority', 'example.com'), - (':path', '/'), - (':method', 'GET'), - (':scheme', 'https'), - ('key', large_binary_string) + (":authority", "example.com"), + (":path", "/"), + (":method", "GET"), + (":scheme", "https"), + ("key", large_binary_string), ] c = h2.connection.H2Connection() @@ -538,21 +536,19 @@ def test_oversize_headers(self): assert isinstance(headers_frame, hyperframe.frame.HeadersFrame) assert all( - map( - lambda f: isinstance(f, hyperframe.frame.ContinuationFrame), - continuation_frames) + (isinstance(f, hyperframe.frame.ContinuationFrame) for f in continuation_frames), ) assert all( - map(lambda f: len(f.data) <= c.max_outbound_frame_size, frames) + (len(f.data) <= c.max_outbound_frame_size for f in frames), ) - assert frames[0].flags == {'END_STREAM'} + assert frames[0].flags == {"END_STREAM"} buffer.add_data(data[-1:]) - headers = list(buffer)[0] + headers = next(iter(buffer)) assert isinstance(headers, hyperframe.frame.HeadersFrame) - def test_handle_stream_reset(self, frame_factory): + def test_handle_stream_reset(self, frame_factory) -> None: """ Streams being remotely reset fires a StreamReset event. """ @@ -576,7 +572,7 @@ def test_handle_stream_reset(self, frame_factory): assert isinstance(event.error_code, h2.errors.ErrorCodes) assert event.remote_reset - def test_handle_stream_reset_with_unknown_erorr_code(self, frame_factory): + def test_handle_stream_reset_with_unknown_erorr_code(self, frame_factory) -> None: """ Streams being remotely reset with unknown error codes behave exactly as they do with known error codes, but the error code on the event is an @@ -599,7 +595,7 @@ def test_handle_stream_reset_with_unknown_erorr_code(self, frame_factory): assert not isinstance(event.error_code, h2.errors.ErrorCodes) assert event.remote_reset - def test_can_consume_partial_data_from_connection(self): + def test_can_consume_partial_data_from_connection(self) -> None: """ We can do partial reads from the connection. """ @@ -612,7 +608,7 @@ def test_can_consume_partial_data_from_connection(self): assert len(c.data_to_send(10)) == 0 assert len(c.data_to_send()) == 0 - def test_we_can_update_settings(self, frame_factory): + def test_we_can_update_settings(self, frame_factory) -> None: """ Updating the settings emits a SETTINGS frame. """ @@ -630,7 +626,7 @@ def test_we_can_update_settings(self, frame_factory): f = frame_factory.build_settings_frame(new_settings) assert c.data_to_send() == f.serialize() - def test_settings_get_acked_correctly(self, frame_factory): + def test_settings_get_acked_correctly(self, frame_factory) -> None: """ When settings changes are ACKed, they contain the changed settings. """ @@ -654,7 +650,7 @@ def test_settings_get_acked_correctly(self, frame_factory): for setting, value in new_settings.items(): assert event.changed_settings[setting].new_value == value - def test_cannot_create_new_outbound_stream_over_limit(self, frame_factory): + def test_cannot_create_new_outbound_stream_over_limit(self, frame_factory) -> None: """ When the number of outbound streams exceeds the remote peer's MAX_CONCURRENT_STREAMS setting, attempting to open new streams fails. @@ -663,7 +659,7 @@ def test_cannot_create_new_outbound_stream_over_limit(self, frame_factory): c.initiate_connection() f = frame_factory.build_settings_frame( - {h2.settings.SettingCodes.MAX_CONCURRENT_STREAMS: 1} + {h2.settings.SettingCodes.MAX_CONCURRENT_STREAMS: 1}, ) c.receive_data(f.serialize())[0] @@ -672,12 +668,12 @@ def test_cannot_create_new_outbound_stream_over_limit(self, frame_factory): with pytest.raises(h2.exceptions.TooManyStreamsError): c.send_headers(3, self.example_request_headers) - def test_can_receive_trailers(self, frame_factory): + def test_can_receive_trailers(self, frame_factory) -> None: """ When two HEADERS blocks are received in the same stream from a server, the second set are trailers. """ - config = h2.config.H2Configuration(header_encoding='utf-8') + config = h2.config.H2Configuration(header_encoding="utf-8") c = h2.connection.H2Connection(config=config) c.initiate_connection() c.send_headers(1, self.example_request_headers) @@ -685,10 +681,10 @@ def test_can_receive_trailers(self, frame_factory): c.receive_data(f.serialize()) # Send in trailers. - trailers = [('content-length', '0')] + trailers = [("content-length", "0")] f = frame_factory.build_headers_frame( trailers, - flags=['END_STREAM'], + flags=["END_STREAM"], ) events = c.receive_data(f.serialize()) assert len(events) == 2 @@ -698,7 +694,7 @@ def test_can_receive_trailers(self, frame_factory): assert event.headers == trailers assert event.stream_id == 1 - def test_reject_trailers_not_ending_stream(self, frame_factory): + def test_reject_trailers_not_ending_stream(self, frame_factory) -> None: """ When trailers are received without the END_STREAM flag being present, this is a ProtocolError. @@ -711,7 +707,7 @@ def test_reject_trailers_not_ending_stream(self, frame_factory): # Send in trailers. c.clear_outbound_data_buffer() - trailers = [('content-length', '0')] + trailers = [("content-length", "0")] f = frame_factory.build_headers_frame( trailers, flags=[], @@ -725,7 +721,7 @@ def test_reject_trailers_not_ending_stream(self, frame_factory): ) assert c.data_to_send() == expected_frame.serialize() - def test_can_send_trailers(self, frame_factory): + def test_can_send_trailers(self, frame_factory) -> None: """ When a second set of headers are sent, they are properly trailers. """ @@ -735,7 +731,7 @@ def test_can_send_trailers(self, frame_factory): c.send_headers(1, self.example_request_headers) # Now send trailers. - trailers = [('content-length', '0')] + trailers = [("content-length", "0")] c.send_headers(1, trailers, end_stream=True) frame_factory.refresh_encoder() @@ -744,11 +740,11 @@ def test_can_send_trailers(self, frame_factory): ) f2 = frame_factory.build_headers_frame( trailers, - flags=['END_STREAM'], + flags=["END_STREAM"], ) assert c.data_to_send() == f1.serialize() + f2.serialize() - def test_trailers_must_have_end_stream(self, frame_factory): + def test_trailers_must_have_end_stream(self, frame_factory) -> None: """ A set of trailers must carry the END_STREAM flag. """ @@ -759,25 +755,17 @@ def test_trailers_must_have_end_stream(self, frame_factory): c.send_headers(1, self.example_request_headers) # Now send trailers. - trailers = [('content-length', '0')] + trailers = [("content-length", "0")] with pytest.raises(h2.exceptions.ProtocolError): c.send_headers(1, trailers) - def test_headers_are_lowercase(self, frame_factory): + def test_headers_are_lowercase(self, frame_factory) -> None: """ When headers are sent, they are forced to lower-case. """ - weird_headers = self.example_request_headers + [ - ('ChAnGiNg-CaSe', 'AlsoHere'), - ('alllowercase', 'alllowercase'), - ('ALLCAPS', 'ALLCAPS'), - ] - expected_headers = self.example_request_headers + [ - ('changing-case', 'AlsoHere'), - ('alllowercase', 'alllowercase'), - ('allcaps', 'ALLCAPS'), - ] + weird_headers = [*self.example_request_headers, ("ChAnGiNg-CaSe", "AlsoHere"), ("alllowercase", "alllowercase"), ("ALLCAPS", "ALLCAPS")] + expected_headers = [*self.example_request_headers, ("changing-case", "AlsoHere"), ("alllowercase", "alllowercase"), ("allcaps", "ALLCAPS")] c = h2.connection.H2Connection() c.initiate_connection() @@ -785,40 +773,40 @@ def test_headers_are_lowercase(self, frame_factory): c.send_headers(1, weird_headers) expected_frame = frame_factory.build_headers_frame( - headers=expected_headers + headers=expected_headers, ) assert c.data_to_send() == expected_frame.serialize() - def test_outbound_cookie_headers_are_split(self): + def test_outbound_cookie_headers_are_split(self) -> None: """ We should split outbound cookie headers according to RFC 7540 - 8.1.2.5 """ cookie_headers = [ - HeaderTuple('cookie', - 'username=John Doe; expires=Thu, 18 Dec 2013 12:00:00 UTC'), - ('cookie', 'path=1'), - ('cookie', 'test1=val1; test2=val2') + HeaderTuple("cookie", + "username=John Doe; expires=Thu, 18 Dec 2013 12:00:00 UTC"), + ("cookie", "path=1"), + ("cookie", "test1=val1; test2=val2"), ] expected_cookie_headers = [ - HeaderTuple('cookie', 'username=John Doe'), - HeaderTuple('cookie', 'expires=Thu, 18 Dec 2013 12:00:00 UTC'), - ('cookie', 'path=1'), - ('cookie', 'test1=val1'), - ('cookie', 'test2=val2'), + HeaderTuple("cookie", "username=John Doe"), + HeaderTuple("cookie", "expires=Thu, 18 Dec 2013 12:00:00 UTC"), + ("cookie", "path=1"), + ("cookie", "test1=val1"), + ("cookie", "test2=val2"), ] client_config = h2.config.H2Configuration( client_side=True, - header_encoding='utf-8', - split_outbound_cookies=True + header_encoding="utf-8", + split_outbound_cookies=True, ) server_config = h2.config.H2Configuration( client_side=False, normalize_inbound_headers=False, - header_encoding='utf-8' + header_encoding="utf-8", ) client = h2.connection.H2Connection(config=client_config) server = h2.connection.H2Connection(config=server_config) @@ -828,13 +816,13 @@ def test_outbound_cookie_headers_are_split(self): e = server.receive_data(client.data_to_send()) - cookie_fields = [(n, v) for n, v in e[1].headers if n == 'cookie'] + cookie_fields = [(n, v) for n, v in e[1].headers if n == "cookie"] assert cookie_fields == expected_cookie_headers @given(frame_size=integers(min_value=2**14, max_value=(2**24 - 1))) @settings(suppress_health_check=[HealthCheck.function_scoped_fixture]) - def test_changing_max_frame_size(self, frame_factory, frame_size): + def test_changing_max_frame_size(self, frame_factory, frame_size) -> None: """ When the user changes the max frame size and the change is ACKed, the remote peer is now bound by the new frame size. @@ -857,7 +845,7 @@ def test_changing_max_frame_size(self, frame_factory, frame_size): # Change the max frame size. c.update_settings( - {h2.settings.SettingCodes.MAX_FRAME_SIZE: frame_size} + {h2.settings.SettingCodes.MAX_FRAME_SIZE: frame_size}, ) settings_ack = frame_factory.build_settings_frame({}, ack=True) c.receive_data(settings_ack.serialize()) @@ -866,13 +854,13 @@ def test_changing_max_frame_size(self, frame_factory, frame_size): # flow control today. c.increment_flow_control_window(increment=(2 * frame_size) + 1) c.increment_flow_control_window( - increment=(2 * frame_size) + 1, stream_id=1 + increment=(2 * frame_size) + 1, stream_id=1, ) # Send one DATA frame that is exactly the max frame size: confirm it's # fine. data = frame_factory.build_data_frame( - data=(b'\x00' * frame_size), + data=(b"\x00" * frame_size), ) events = c.receive_data(data.serialize()) assert len(events) == 1 @@ -881,11 +869,11 @@ def test_changing_max_frame_size(self, frame_factory, frame_size): # Send one that is one byte too large: confirm a protocol error is # raised. - data.data += b'\x00' + data.data += b"\x00" with pytest.raises(h2.exceptions.ProtocolError): c.receive_data(data.serialize()) - def test_cookies_are_joined_on_push(self, frame_factory): + def test_cookies_are_joined_on_push(self, frame_factory) -> None: """ RFC 7540 Section 8.1.2.5 requires that we join multiple Cookie headers in a header block together when they're received on a push. @@ -893,17 +881,17 @@ def test_cookies_are_joined_on_push(self, frame_factory): # This is a moderately varied set of cookie headers: some combined, # some split. cookie_headers = [ - ('cookie', - 'username=John Doe; expires=Thu, 18 Dec 2013 12:00:00 UTC'), - ('cookie', 'path=1'), - ('cookie', 'test1=val1; test2=val2') + ("cookie", + "username=John Doe; expires=Thu, 18 Dec 2013 12:00:00 UTC"), + ("cookie", "path=1"), + ("cookie", "test1=val1; test2=val2"), ] expected = ( - 'username=John Doe; expires=Thu, 18 Dec 2013 12:00:00 UTC; ' - 'path=1; test1=val1; test2=val2' + "username=John Doe; expires=Thu, 18 Dec 2013 12:00:00 UTC; " + "path=1; test1=val1; test2=val2" ) - config = h2.config.H2Configuration(header_encoding='utf-8') + config = h2.config.H2Configuration(header_encoding="utf-8") c = h2.connection.H2Connection(config=config) c.initiate_connection() c.send_headers(1, self.example_request_headers, end_stream=True) @@ -911,20 +899,20 @@ def test_cookies_are_joined_on_push(self, frame_factory): f = frame_factory.build_push_promise_frame( stream_id=1, promised_stream_id=2, - headers=self.example_request_headers + cookie_headers + headers=self.example_request_headers + cookie_headers, ) events = c.receive_data(f.serialize()) assert len(events) == 1 e = events[0] - cookie_fields = [(n, v) for n, v in e.headers if n == 'cookie'] + cookie_fields = [(n, v) for n, v in e.headers if n == "cookie"] assert len(cookie_fields) == 1 _, v = cookie_fields[0] assert v == expected - def test_cookies_arent_joined_without_normalization(self, frame_factory): + def test_cookies_arent_joined_without_normalization(self, frame_factory) -> None: """ If inbound header normalization is disabled, cookie headers aren't joined. @@ -932,16 +920,16 @@ def test_cookies_arent_joined_without_normalization(self, frame_factory): # This is a moderately varied set of cookie headers: some combined, # some split. cookie_headers = [ - ('cookie', - 'username=John Doe; expires=Thu, 18 Dec 2013 12:00:00 UTC'), - ('cookie', 'path=1'), - ('cookie', 'test1=val1; test2=val2') + ("cookie", + "username=John Doe; expires=Thu, 18 Dec 2013 12:00:00 UTC"), + ("cookie", "path=1"), + ("cookie", "test1=val1; test2=val2"), ] config = h2.config.H2Configuration( client_side=True, normalize_inbound_headers=False, - header_encoding='utf-8' + header_encoding="utf-8", ) c = h2.connection.H2Connection(config=config) c.initiate_connection() @@ -950,60 +938,61 @@ def test_cookies_arent_joined_without_normalization(self, frame_factory): f = frame_factory.build_push_promise_frame( stream_id=1, promised_stream_id=2, - headers=self.example_request_headers + cookie_headers + headers=self.example_request_headers + cookie_headers, ) events = c.receive_data(f.serialize()) assert len(events) == 1 e = events[0] - received_cookies = [(n, v) for n, v in e.headers if n == 'cookie'] + received_cookies = [(n, v) for n, v in e.headers if n == "cookie"] assert len(received_cookies) == 3 assert cookie_headers == received_cookies -class TestBasicServer(object): +class TestBasicServer: """ Basic server-side tests. """ + example_request_headers = [ - (u':authority', u'example.com'), - (u':path', u'/'), - (u':scheme', u'https'), - (u':method', u'GET'), + (":authority", "example.com"), + (":path", "/"), + (":scheme", "https"), + (":method", "GET"), ] bytes_example_request_headers = [ - (b':authority', b'example.com'), - (b':path', b'/'), - (b':scheme', b'https'), - (b':method', b'GET'), + (b":authority", b"example.com"), + (b":path", b"/"), + (b":scheme", b"https"), + (b":method", b"GET"), ] example_response_headers = [ - (':status', '200'), - ('server', 'hyper-h2/0.1.0') + (":status", "200"), + ("server", "hyper-h2/0.1.0"), ] server_config = h2.config.H2Configuration( - client_side=False, header_encoding='utf-8' + client_side=False, header_encoding="utf-8", ) - def test_ignores_preamble(self): + def test_ignores_preamble(self) -> None: """ The preamble does not cause any events or frames to be written. """ c = h2.connection.H2Connection(config=self.server_config) - preamble = b'PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n' + preamble = b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n" events = c.receive_data(preamble) assert not events assert not c.data_to_send() @pytest.mark.parametrize("chunk_size", range(1, 24)) - def test_drip_feed_preamble(self, chunk_size): + def test_drip_feed_preamble(self, chunk_size) -> None: """ The preamble can be sent in in less than a single buffer. """ c = h2.connection.H2Connection(config=self.server_config) - preamble = b'PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n' + preamble = b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n" events = [] for i in range(0, len(preamble), chunk_size): @@ -1012,14 +1001,14 @@ def test_drip_feed_preamble(self, chunk_size): assert not events assert not c.data_to_send() - def test_initiate_connection_sends_server_preamble(self, frame_factory): + def test_initiate_connection_sends_server_preamble(self, frame_factory) -> None: """ For server-side connections, initiate_connection sends a server preamble. """ c = h2.connection.H2Connection(config=self.server_config) expected_settings = frame_factory.build_settings_frame( - c.local_settings + c.local_settings, ) expected_data = expected_settings.serialize() @@ -1027,7 +1016,7 @@ def test_initiate_connection_sends_server_preamble(self, frame_factory): assert not events assert c.data_to_send() == expected_data - def test_headers_event(self, frame_factory): + def test_headers_event(self, frame_factory) -> None: """ When a headers frame is received a RequestReceived event fires. """ @@ -1045,13 +1034,13 @@ def test_headers_event(self, frame_factory): assert event.stream_id == 1 assert event.headers == self.example_request_headers - def test_headers_event_bytes(self, frame_factory): + def test_headers_event_bytes(self, frame_factory) -> None: """ When a headers frame is received a RequestReceived event fires with bytes headers if the encoding is set appropriately. """ config = h2.config.H2Configuration( - client_side=False, header_encoding=False + client_side=False, header_encoding=False, ) c = h2.connection.H2Connection(config=config) c.receive_data(frame_factory.preamble()) @@ -1067,7 +1056,7 @@ def test_headers_event_bytes(self, frame_factory): assert event.stream_id == 1 assert event.headers == self.bytes_example_request_headers - def test_data_event(self, frame_factory): + def test_data_event(self, frame_factory) -> None: """ Test that data received on a stream fires a DataReceived event. """ @@ -1075,13 +1064,13 @@ def test_data_event(self, frame_factory): c.receive_data(frame_factory.preamble()) f1 = frame_factory.build_headers_frame( - self.example_request_headers, stream_id=3 + self.example_request_headers, stream_id=3, ) f2 = frame_factory.build_data_frame( - b'some request data', + b"some request data", stream_id=3, ) - data = b''.join(map(lambda f: f.serialize(), [f1, f2])) + data = b"".join(f.serialize() for f in [f1, f2]) events = c.receive_data(data) assert len(events) == 2 @@ -1089,10 +1078,10 @@ def test_data_event(self, frame_factory): assert isinstance(event, h2.events.DataReceived) assert event.stream_id == 3 - assert event.data == b'some request data' + assert event.data == b"some request data" assert event.flow_controlled_length == 17 - def test_data_event_with_padding(self, frame_factory): + def test_data_event_with_padding(self, frame_factory) -> None: """ Test that data received on a stream fires a DataReceived event that accounts for padding. @@ -1101,14 +1090,14 @@ def test_data_event_with_padding(self, frame_factory): c.receive_data(frame_factory.preamble()) f1 = frame_factory.build_headers_frame( - self.example_request_headers, stream_id=3 + self.example_request_headers, stream_id=3, ) f2 = frame_factory.build_data_frame( - b'some request data', + b"some request data", stream_id=3, - padding_len=20 + padding_len=20, ) - data = b''.join(map(lambda f: f.serialize(), [f1, f2])) + data = b"".join(f.serialize() for f in [f1, f2]) events = c.receive_data(data) assert len(events) == 2 @@ -1116,20 +1105,20 @@ def test_data_event_with_padding(self, frame_factory): assert isinstance(event, h2.events.DataReceived) assert event.stream_id == 3 - assert event.data == b'some request data' + assert event.data == b"some request data" assert event.flow_controlled_length == 17 + 20 + 1 - def test_receiving_ping_frame(self, frame_factory): + def test_receiving_ping_frame(self, frame_factory) -> None: """ Ping frames should be immediately ACKed. """ c = h2.connection.H2Connection(config=self.server_config) c.receive_data(frame_factory.preamble()) - ping_data = b'\x01' * 8 + ping_data = b"\x01" * 8 sent_frame = frame_factory.build_ping_frame(ping_data) expected_frame = frame_factory.build_ping_frame( - ping_data, flags=["ACK"] + ping_data, flags=["ACK"], ) expected_data = expected_frame.serialize() @@ -1144,7 +1133,7 @@ def test_receiving_ping_frame(self, frame_factory): assert c.data_to_send() == expected_data - def test_receiving_settings_frame_event(self, frame_factory): + def test_receiving_settings_frame_event(self, frame_factory) -> None: """ Settings frames should cause a RemoteSettingsChanged event to fire. """ @@ -1152,7 +1141,7 @@ def test_receiving_settings_frame_event(self, frame_factory): c.receive_data(frame_factory.preamble()) f = frame_factory.build_settings_frame( - settings=helpers.SAMPLE_SETTINGS + settings=helpers.SAMPLE_SETTINGS, ) events = c.receive_data(f.serialize()) @@ -1162,7 +1151,7 @@ def test_receiving_settings_frame_event(self, frame_factory): assert isinstance(event, h2.events.RemoteSettingsChanged) assert len(event.changed_settings) == len(helpers.SAMPLE_SETTINGS) - def test_acknowledging_settings(self, frame_factory): + def test_acknowledging_settings(self, frame_factory) -> None: """ Acknowledging settings causes appropriate Settings frame to be emitted. """ @@ -1170,10 +1159,10 @@ def test_acknowledging_settings(self, frame_factory): c.receive_data(frame_factory.preamble()) received_frame = frame_factory.build_settings_frame( - settings=helpers.SAMPLE_SETTINGS + settings=helpers.SAMPLE_SETTINGS, ) expected_frame = frame_factory.build_settings_frame( - settings={}, ack=True + settings={}, ack=True, ) expected_data = expected_frame.serialize() @@ -1183,7 +1172,7 @@ def test_acknowledging_settings(self, frame_factory): assert len(events) == 1 assert c.data_to_send() == expected_data - def test_close_connection(self, frame_factory): + def test_close_connection(self, frame_factory) -> None: """ Closing the connection with no error code emits a GOAWAY frame with error code 0. @@ -1200,7 +1189,7 @@ def test_close_connection(self, frame_factory): assert c.data_to_send() == expected_data @pytest.mark.parametrize("error_code", h2.errors.ErrorCodes) - def test_close_connection_with_error_code(self, frame_factory, error_code): + def test_close_connection_with_error_code(self, frame_factory, error_code) -> None: """ Closing the connection with an error code emits a GOAWAY frame with that error code. @@ -1208,7 +1197,7 @@ def test_close_connection_with_error_code(self, frame_factory, error_code): c = h2.connection.H2Connection(config=self.server_config) c.receive_data(frame_factory.preamble()) f = frame_factory.build_goaway_frame( - error_code=error_code, last_stream_id=0 + error_code=error_code, last_stream_id=0, ) expected_data = f.serialize() @@ -1218,13 +1207,13 @@ def test_close_connection_with_error_code(self, frame_factory, error_code): assert not events assert c.data_to_send() == expected_data - @pytest.mark.parametrize("last_stream_id,output", [ + @pytest.mark.parametrize(("last_stream_id", "output"), [ (None, 23), (0, 0), - (42, 42) + (42, 42), ]) def test_close_connection_with_last_stream_id(self, frame_factory, - last_stream_id, output): + last_stream_id, output) -> None: """ Closing the connection with last_stream_id set emits a GOAWAY frame with that value. @@ -1233,16 +1222,16 @@ def test_close_connection_with_last_stream_id(self, frame_factory, c.receive_data(frame_factory.preamble()) headers_frame = frame_factory.build_headers_frame( [ - (':authority', 'example.com'), - (':path', '/'), - (':scheme', 'https'), - (':method', 'GET'), + (":authority", "example.com"), + (":path", "/"), + (":scheme", "https"), + (":method", "GET"), ], stream_id=23) c.receive_data(headers_frame.serialize()) f = frame_factory.build_goaway_frame( - last_stream_id=output + last_stream_id=output, ) expected_data = f.serialize() @@ -1252,13 +1241,13 @@ def test_close_connection_with_last_stream_id(self, frame_factory, assert not events assert c.data_to_send() == expected_data - @pytest.mark.parametrize("additional_data,output", [ - (None, b''), - (b'', b''), - (b'foobar', b'foobar') + @pytest.mark.parametrize(("additional_data", "output"), [ + (None, b""), + (b"", b""), + (b"foobar", b"foobar"), ]) def test_close_connection_with_additional_data(self, frame_factory, - additional_data, output): + additional_data, output) -> None: """ Closing the connection with additional debug data emits a GOAWAY frame with that data attached. @@ -1266,7 +1255,7 @@ def test_close_connection_with_additional_data(self, frame_factory, c = h2.connection.H2Connection(config=self.server_config) c.receive_data(frame_factory.preamble()) f = frame_factory.build_goaway_frame( - last_stream_id=0, additional_data=output + last_stream_id=0, additional_data=output, ) expected_data = f.serialize() @@ -1276,7 +1265,7 @@ def test_close_connection_with_additional_data(self, frame_factory, assert not events assert c.data_to_send() == expected_data - def test_reset_stream(self, frame_factory): + def test_reset_stream(self, frame_factory) -> None: """ Resetting a stream with no error code emits a RST_STREAM frame with error code 0. @@ -1296,7 +1285,7 @@ def test_reset_stream(self, frame_factory): assert c.data_to_send() == expected_data @pytest.mark.parametrize("error_code", h2.errors.ErrorCodes) - def test_reset_stream_with_error_code(self, frame_factory, error_code): + def test_reset_stream_with_error_code(self, frame_factory, error_code) -> None: """ Resetting a stream with an error code emits a RST_STREAM frame with that error code. @@ -1305,13 +1294,13 @@ def test_reset_stream_with_error_code(self, frame_factory, error_code): c.receive_data(frame_factory.preamble()) f = frame_factory.build_headers_frame( self.example_request_headers, - stream_id=3 + stream_id=3, ) c.receive_data(f.serialize()) c.clear_outbound_data_buffer() expected_frame = frame_factory.build_rst_stream_frame( - stream_id=3, error_code=error_code + stream_id=3, error_code=error_code, ) expected_data = expected_frame.serialize() @@ -1320,7 +1309,7 @@ def test_reset_stream_with_error_code(self, frame_factory, error_code): assert not events assert c.data_to_send() == expected_data - def test_cannot_reset_nonexistent_stream(self, frame_factory): + def test_cannot_reset_nonexistent_stream(self, frame_factory) -> None: """ Resetting nonexistent streams raises NoSuchStreamError. """ @@ -1328,7 +1317,7 @@ def test_cannot_reset_nonexistent_stream(self, frame_factory): c.receive_data(frame_factory.preamble()) f = frame_factory.build_headers_frame( self.example_request_headers, - stream_id=3 + stream_id=3, ) c.receive_data(f.serialize()) @@ -1342,7 +1331,7 @@ def test_cannot_reset_nonexistent_stream(self, frame_factory): assert e.value.stream_id == 5 - def test_basic_sending_ping_frame_logic(self, frame_factory): + def test_basic_sending_ping_frame_logic(self, frame_factory) -> None: """ Sending ping frames serializes a ping frame on stream 0 with appropriate opaque data. @@ -1351,7 +1340,7 @@ def test_basic_sending_ping_frame_logic(self, frame_factory): c.receive_data(frame_factory.preamble()) c.clear_outbound_data_buffer() - ping_data = b'\x01\x02\x03\x04\x05\x06\x07\x08' + ping_data = b"\x01\x02\x03\x04\x05\x06\x07\x08" expected_frame = frame_factory.build_ping_frame(ping_data) expected_data = expected_frame.serialize() @@ -1362,17 +1351,17 @@ def test_basic_sending_ping_frame_logic(self, frame_factory): assert c.data_to_send() == expected_data @pytest.mark.parametrize( - 'opaque_data', + "opaque_data", [ - b'', - b'\x01\x02\x03\x04\x05\x06\x07', - u'abcdefgh', - b'too many bytes', - ] + b"", + b"\x01\x02\x03\x04\x05\x06\x07", + "abcdefgh", + b"too many bytes", + ], ) def test_ping_frame_opaque_data_must_be_length_8_bytestring(self, frame_factory, - opaque_data): + opaque_data) -> None: """ Sending a ping frame only works with 8-byte bytestrings. """ @@ -1382,17 +1371,17 @@ def test_ping_frame_opaque_data_must_be_length_8_bytestring(self, with pytest.raises(ValueError): c.ping(opaque_data) - def test_receiving_ping_acknowledgement(self, frame_factory): + def test_receiving_ping_acknowledgement(self, frame_factory) -> None: """ Receiving a PING acknowledgement fires a PingAckReceived event. """ c = h2.connection.H2Connection(config=self.server_config) c.receive_data(frame_factory.preamble()) - ping_data = b'\x01\x02\x03\x04\x05\x06\x07\x08' + ping_data = b"\x01\x02\x03\x04\x05\x06\x07\x08" f = frame_factory.build_ping_frame( - ping_data, flags=['ACK'] + ping_data, flags=["ACK"], ) events = c.receive_data(f.serialize()) @@ -1402,7 +1391,7 @@ def test_receiving_ping_acknowledgement(self, frame_factory): assert isinstance(event, h2.events.PingAckReceived) assert event.ping_data == ping_data - def test_stream_ended_remotely(self, frame_factory): + def test_stream_ended_remotely(self, frame_factory) -> None: """ When the remote stream ends with a non-empty data frame a DataReceived event and a StreamEnded event are fired. @@ -1411,14 +1400,14 @@ def test_stream_ended_remotely(self, frame_factory): c.receive_data(frame_factory.preamble()) f1 = frame_factory.build_headers_frame( - self.example_request_headers, stream_id=3 + self.example_request_headers, stream_id=3, ) f2 = frame_factory.build_data_frame( - b'some request data', - flags=['END_STREAM'], + b"some request data", + flags=["END_STREAM"], stream_id=3, ) - data = b''.join(map(lambda f: f.serialize(), [f1, f2])) + data = b"".join(f.serialize() for f in [f1, f2]) events = c.receive_data(data) assert len(events) == 3 @@ -1429,14 +1418,14 @@ def test_stream_ended_remotely(self, frame_factory): assert isinstance(stream_ended_event, h2.events.StreamEnded) stream_ended_event.stream_id == 3 - def test_can_push_stream(self, frame_factory): + def test_can_push_stream(self, frame_factory) -> None: """ Pushing a stream causes a PUSH_PROMISE frame to be emitted. """ c = h2.connection.H2Connection(config=self.server_config) c.receive_data(frame_factory.preamble()) f = frame_factory.build_headers_frame( - self.example_request_headers + self.example_request_headers, ) c.receive_data(f.serialize()) @@ -1445,31 +1434,31 @@ def test_can_push_stream(self, frame_factory): stream_id=1, promised_stream_id=2, headers=self.example_request_headers, - flags=['END_HEADERS'], + flags=["END_HEADERS"], ) c.clear_outbound_data_buffer() c.push_stream( stream_id=1, promised_stream_id=2, - request_headers=self.example_request_headers + request_headers=self.example_request_headers, ) assert c.data_to_send() == expected_frame.serialize() - def test_cannot_push_streams_when_disabled(self, frame_factory): + def test_cannot_push_streams_when_disabled(self, frame_factory) -> None: """ When the remote peer has disabled stream pushing, we should fail. """ c = h2.connection.H2Connection(config=self.server_config) c.receive_data(frame_factory.preamble()) f = frame_factory.build_settings_frame( - {h2.settings.SettingCodes.ENABLE_PUSH: 0} + {h2.settings.SettingCodes.ENABLE_PUSH: 0}, ) c.receive_data(f.serialize()) f = frame_factory.build_headers_frame( - self.example_request_headers + self.example_request_headers, ) c.receive_data(f.serialize()) @@ -1477,10 +1466,10 @@ def test_cannot_push_streams_when_disabled(self, frame_factory): c.push_stream( stream_id=1, promised_stream_id=2, - request_headers=self.example_request_headers + request_headers=self.example_request_headers, ) - def test_settings_remote_change_header_table_size(self, frame_factory): + def test_settings_remote_change_header_table_size(self, frame_factory) -> None: """ Acknowledging a remote HEADER_TABLE_SIZE settings change causes us to change the header table size of our encoder. @@ -1491,13 +1480,13 @@ def test_settings_remote_change_header_table_size(self, frame_factory): assert c.encoder.header_table_size == 4096 received_frame = frame_factory.build_settings_frame( - {h2.settings.SettingCodes.HEADER_TABLE_SIZE: 80} + {h2.settings.SettingCodes.HEADER_TABLE_SIZE: 80}, ) c.receive_data(received_frame.serialize())[0] assert c.encoder.header_table_size == 80 - def test_settings_local_change_header_table_size(self, frame_factory): + def test_settings_local_change_header_table_size(self, frame_factory) -> None: """ The remote peer acknowledging a local HEADER_TABLE_SIZE settings change does not cause us to change the header table size of our decoder. @@ -1511,14 +1500,14 @@ def test_settings_local_change_header_table_size(self, frame_factory): expected_frame = frame_factory.build_settings_frame({}, ack=True) c.update_settings( - {h2.settings.SettingCodes.HEADER_TABLE_SIZE: 80} + {h2.settings.SettingCodes.HEADER_TABLE_SIZE: 80}, ) c.receive_data(expected_frame.serialize()) c.clear_outbound_data_buffer() assert c.decoder.header_table_size == 4096 - def test_restricting_outbound_frame_size_by_settings(self, frame_factory): + def test_restricting_outbound_frame_size_by_settings(self, frame_factory) -> None: """ The remote peer can shrink the maximum outbound frame size using settings. @@ -1532,17 +1521,17 @@ def test_restricting_outbound_frame_size_by_settings(self, frame_factory): c.clear_outbound_data_buffer() with pytest.raises(h2.exceptions.FrameTooLargeError): - c.send_data(1, b'\x01' * 17000) + c.send_data(1, b"\x01" * 17000) received_frame = frame_factory.build_settings_frame( - {h2.settings.SettingCodes.MAX_FRAME_SIZE: 17001} + {h2.settings.SettingCodes.MAX_FRAME_SIZE: 17001}, ) c.receive_data(received_frame.serialize()) - c.send_data(1, b'\x01' * 17000) + c.send_data(1, b"\x01" * 17000) assert c.data_to_send() - def test_restricting_inbound_frame_size_by_settings(self, frame_factory): + def test_restricting_inbound_frame_size_by_settings(self, frame_factory) -> None: """ We throw ProtocolErrors and tear down connections if oversize frames are received. @@ -1553,17 +1542,17 @@ def test_restricting_inbound_frame_size_by_settings(self, frame_factory): c.receive_data(h.serialize()) c.clear_outbound_data_buffer() - data_frame = frame_factory.build_data_frame(b'\x01' * 17000) + data_frame = frame_factory.build_data_frame(b"\x01" * 17000) with pytest.raises(h2.exceptions.ProtocolError): c.receive_data(data_frame.serialize()) expected_frame = frame_factory.build_goaway_frame( - last_stream_id=1, error_code=h2.errors.ErrorCodes.FRAME_SIZE_ERROR + last_stream_id=1, error_code=h2.errors.ErrorCodes.FRAME_SIZE_ERROR, ) assert c.data_to_send() == expected_frame.serialize() - def test_cannot_receive_new_streams_over_limit(self, frame_factory): + def test_cannot_receive_new_streams_over_limit(self, frame_factory) -> None: """ When the number of inbound streams exceeds our MAX_CONCURRENT_STREAMS setting, their attempt to open new streams fails. @@ -1571,7 +1560,7 @@ def test_cannot_receive_new_streams_over_limit(self, frame_factory): c = h2.connection.H2Connection(config=self.server_config) c.receive_data(frame_factory.preamble()) c.update_settings( - {h2.settings.SettingCodes.MAX_CONCURRENT_STREAMS: 1} + {h2.settings.SettingCodes.MAX_CONCURRENT_STREAMS: 1}, ) f = frame_factory.build_settings_frame({}, ack=True) c.receive_data(f.serialize()) @@ -1595,7 +1584,7 @@ def test_cannot_receive_new_streams_over_limit(self, frame_factory): ) assert c.data_to_send() == expected_frame.serialize() - def test_can_receive_trailers(self, frame_factory): + def test_can_receive_trailers(self, frame_factory) -> None: """ When two HEADERS blocks are received in the same stream from a client, the second set are trailers. @@ -1606,10 +1595,10 @@ def test_can_receive_trailers(self, frame_factory): c.receive_data(f.serialize()) # Send in trailers. - trailers = [('content-length', '0')] + trailers = [("content-length", "0")] f = frame_factory.build_headers_frame( trailers, - flags=['END_STREAM'], + flags=["END_STREAM"], ) events = c.receive_data(f.serialize()) assert len(events) == 2 @@ -1619,7 +1608,7 @@ def test_can_receive_trailers(self, frame_factory): assert event.headers == trailers assert event.stream_id == 1 - def test_reject_trailers_not_ending_stream(self, frame_factory): + def test_reject_trailers_not_ending_stream(self, frame_factory) -> None: """ When trailers are received without the END_STREAM flag being present, this is a ProtocolError. @@ -1631,7 +1620,7 @@ def test_reject_trailers_not_ending_stream(self, frame_factory): # Send in trailers. c.clear_outbound_data_buffer() - trailers = [('content-length', '0')] + trailers = [("content-length", "0")] f = frame_factory.build_headers_frame( trailers, flags=[], @@ -1645,7 +1634,7 @@ def test_reject_trailers_not_ending_stream(self, frame_factory): ) assert c.data_to_send() == expected_frame.serialize() - def test_can_send_trailers(self, frame_factory): + def test_can_send_trailers(self, frame_factory) -> None: """ When a second set of headers are sent, they are properly trailers. """ @@ -1659,7 +1648,7 @@ def test_can_send_trailers(self, frame_factory): c.send_headers(1, self.example_response_headers) # Now send trailers. - trailers = [('content-length', '0')] + trailers = [("content-length", "0")] c.send_headers(1, trailers, end_stream=True) frame_factory.refresh_encoder() @@ -1668,11 +1657,11 @@ def test_can_send_trailers(self, frame_factory): ) f2 = frame_factory.build_headers_frame( trailers, - flags=['END_STREAM'], + flags=["END_STREAM"], ) assert c.data_to_send() == f1.serialize() + f2.serialize() - def test_trailers_must_have_end_stream(self, frame_factory): + def test_trailers_must_have_end_stream(self, frame_factory) -> None: """ A set of trailers must carry the END_STREAM flag. """ @@ -1685,18 +1674,18 @@ def test_trailers_must_have_end_stream(self, frame_factory): c.send_headers(1, self.example_response_headers) # Now send trailers. - trailers = [('content-length', '0')] + trailers = [("content-length", "0")] with pytest.raises(h2.exceptions.ProtocolError): c.send_headers(1, trailers) @pytest.mark.parametrize("frame_id", range(12, 256)) - def test_unknown_frames_are_ignored(self, frame_factory, frame_id): + def test_unknown_frames_are_ignored(self, frame_factory, frame_id) -> None: c = h2.connection.H2Connection(config=self.server_config) c.receive_data(frame_factory.preamble()) c.clear_outbound_data_buffer() - f = frame_factory.build_data_frame(data=b'abcdefghtdst') + f = frame_factory.build_data_frame(data=b"abcdefghtdst") f.type = frame_id events = c.receive_data(f.serialize()) @@ -1705,7 +1694,7 @@ def test_unknown_frames_are_ignored(self, frame_factory, frame_id): assert isinstance(events[0], h2.events.UnknownFrameReceived) assert isinstance(events[0].frame, hyperframe.frame.ExtensionFrame) - def test_can_send_goaway_repeatedly(self, frame_factory): + def test_can_send_goaway_repeatedly(self, frame_factory) -> None: """ We can send a GOAWAY frame as many times as we like. """ @@ -1721,7 +1710,7 @@ def test_can_send_goaway_repeatedly(self, frame_factory): assert c.data_to_send() == (f.serialize() * 3) - def test_receiving_goaway_frame(self, frame_factory): + def test_receiving_goaway_frame(self, frame_factory) -> None: """ Receiving a GOAWAY frame causes a ConnectionTerminated event to be fired and transitions the connection to the CLOSED state, and clears @@ -1732,7 +1721,7 @@ def test_receiving_goaway_frame(self, frame_factory): c.receive_data(frame_factory.preamble()) f = frame_factory.build_goaway_frame( - last_stream_id=5, error_code=h2.errors.ErrorCodes.SETTINGS_TIMEOUT + last_stream_id=5, error_code=h2.errors.ErrorCodes.SETTINGS_TIMEOUT, ) events = c.receive_data(f.serialize()) @@ -1748,7 +1737,7 @@ def test_receiving_goaway_frame(self, frame_factory): assert not c.data_to_send() - def test_receiving_multiple_goaway_frames(self, frame_factory): + def test_receiving_multiple_goaway_frames(self, frame_factory) -> None: """ Multiple GOAWAY frames can be received at once, and are allowed. Each one fires a ConnectionTerminated event. @@ -1766,7 +1755,7 @@ def test_receiving_multiple_goaway_frames(self, frame_factory): for event in events ) - def test_receiving_goaway_frame_with_additional_data(self, frame_factory): + def test_receiving_goaway_frame_with_additional_data(self, frame_factory) -> None: """ GOAWAY frame can contain additional data, it should be available via ConnectionTerminated event. @@ -1775,7 +1764,7 @@ def test_receiving_goaway_frame_with_additional_data(self, frame_factory): c.initiate_connection() c.receive_data(frame_factory.preamble()) - additional_data = b'debug data' + additional_data = b"debug data" f = frame_factory.build_goaway_frame(last_stream_id=0, additional_data=additional_data) events = c.receive_data(f.serialize()) @@ -1786,7 +1775,7 @@ def test_receiving_goaway_frame_with_additional_data(self, frame_factory): assert isinstance(event, h2.events.ConnectionTerminated) assert event.additional_data == additional_data - def test_receiving_goaway_frame_with_unknown_error(self, frame_factory): + def test_receiving_goaway_frame_with_unknown_error(self, frame_factory) -> None: """ Receiving a GOAWAY frame with an unknown error code behaves exactly the same as receiving one we know about, but the code is reported as an @@ -1797,7 +1786,7 @@ def test_receiving_goaway_frame_with_unknown_error(self, frame_factory): c.receive_data(frame_factory.preamble()) f = frame_factory.build_goaway_frame( - last_stream_id=5, error_code=0xFA + last_stream_id=5, error_code=0xFA, ) events = c.receive_data(f.serialize()) @@ -1813,7 +1802,7 @@ def test_receiving_goaway_frame_with_unknown_error(self, frame_factory): assert not c.data_to_send() - def test_cookies_are_joined(self, frame_factory): + def test_cookies_are_joined(self, frame_factory) -> None: """ RFC 7540 Section 8.1.2.5 requires that we join multiple Cookie headers in a header block together. @@ -1821,14 +1810,14 @@ def test_cookies_are_joined(self, frame_factory): # This is a moderately varied set of cookie headers: some combined, # some split. cookie_headers = [ - ('cookie', - 'username=John Doe; expires=Thu, 18 Dec 2013 12:00:00 UTC'), - ('cookie', 'path=1'), - ('cookie', 'test1=val1; test2=val2') + ("cookie", + "username=John Doe; expires=Thu, 18 Dec 2013 12:00:00 UTC"), + ("cookie", "path=1"), + ("cookie", "test1=val1; test2=val2"), ] expected = ( - 'username=John Doe; expires=Thu, 18 Dec 2013 12:00:00 UTC; ' - 'path=1; test1=val1; test2=val2' + "username=John Doe; expires=Thu, 18 Dec 2013 12:00:00 UTC; " + "path=1; test1=val1; test2=val2" ) c = h2.connection.H2Connection(config=self.server_config) @@ -1836,20 +1825,20 @@ def test_cookies_are_joined(self, frame_factory): c.receive_data(frame_factory.preamble()) f = frame_factory.build_headers_frame( - self.example_request_headers + cookie_headers + self.example_request_headers + cookie_headers, ) events = c.receive_data(f.serialize()) assert len(events) == 1 e = events[0] - cookie_fields = [(n, v) for n, v in e.headers if n == 'cookie'] + cookie_fields = [(n, v) for n, v in e.headers if n == "cookie"] assert len(cookie_fields) == 1 _, v = cookie_fields[0] assert v == expected - def test_cookies_arent_joined_without_normalization(self, frame_factory): + def test_cookies_arent_joined_without_normalization(self, frame_factory) -> None: """ If inbound header normalization is disabled, cookie headers aren't joined. @@ -1857,34 +1846,34 @@ def test_cookies_arent_joined_without_normalization(self, frame_factory): # This is a moderately varied set of cookie headers: some combined, # some split. cookie_headers = [ - ('cookie', - 'username=John Doe; expires=Thu, 18 Dec 2013 12:00:00 UTC'), - ('cookie', 'path=1'), - ('cookie', 'test1=val1; test2=val2') + ("cookie", + "username=John Doe; expires=Thu, 18 Dec 2013 12:00:00 UTC"), + ("cookie", "path=1"), + ("cookie", "test1=val1; test2=val2"), ] config = h2.config.H2Configuration( client_side=False, normalize_inbound_headers=False, - header_encoding='utf-8' + header_encoding="utf-8", ) c = h2.connection.H2Connection(config=config) c.initiate_connection() c.receive_data(frame_factory.preamble()) f = frame_factory.build_headers_frame( - self.example_request_headers + cookie_headers + self.example_request_headers + cookie_headers, ) events = c.receive_data(f.serialize()) assert len(events) == 1 e = events[0] - received_cookies = [(n, v) for n, v in e.headers if n == 'cookie'] + received_cookies = [(n, v) for n, v in e.headers if n == "cookie"] assert len(received_cookies) == 3 assert cookie_headers == received_cookies - def test_stream_repr(self): + def test_stream_repr(self) -> None: """ Ensure stream string representation is appropriate. """ @@ -1895,20 +1884,19 @@ def test_stream_repr(self): def sanity_check_data_frame(data_frame, expected_flow_controlled_length, expect_padded_flag, - expected_data_frame_pad_length): + expected_data_frame_pad_length) -> None: """ ``data_frame`` is a frame of type ``hyperframe.frame.DataFrame``, and the ``flags`` and ``flow_controlled_length`` of ``data_frame`` match expectations. """ - assert isinstance(data_frame, hyperframe.frame.DataFrame) assert data_frame.flow_controlled_length == expected_flow_controlled_length if expect_padded_flag: - assert 'PADDED' in data_frame.flags + assert "PADDED" in data_frame.flags else: - assert 'PADDED' not in data_frame.flags + assert "PADDED" not in data_frame.flags assert data_frame.pad_length == expected_data_frame_pad_length diff --git a/tests/test_closed_streams.py b/tests/test_closed_streams.py index ef88d8e4..aaf1f5d3 100644 --- a/tests/test_closed_streams.py +++ b/tests/test_closed_streams.py @@ -1,10 +1,8 @@ -# -*- coding: utf-8 -*- """ -test_closed_streams -~~~~~~~~~~~~~~~~~~~ - Tests that we handle closed streams correctly. """ +from __future__ import annotations + import pytest import h2.config @@ -14,20 +12,20 @@ import h2.exceptions -class TestClosedStreams(object): +class TestClosedStreams: example_request_headers = [ - (':authority', 'example.com'), - (':path', '/'), - (':scheme', 'https'), - (':method', 'GET'), + (":authority", "example.com"), + (":path", "/"), + (":scheme", "https"), + (":method", "GET"), ] example_response_headers = [ - (':status', '200'), - ('server', 'fake-serv/0.1.0') + (":status", "200"), + ("server", "fake-serv/0.1.0"), ] server_config = h2.config.H2Configuration(client_side=False) - def test_can_receive_multiple_rst_stream_frames(self, frame_factory): + def test_can_receive_multiple_rst_stream_frames(self, frame_factory) -> None: """ Multiple RST_STREAM frames can be received, either at once or well after one another. Only the first fires an event. @@ -50,7 +48,7 @@ def test_can_receive_multiple_rst_stream_frames(self, frame_factory): assert isinstance(event, h2.events.StreamReset) - def test_receiving_low_stream_id_causes_goaway(self, frame_factory): + def test_receiving_low_stream_id_causes_goaway(self, frame_factory) -> None: """ The remote peer creating a stream with a lower ID than one we've seen causes a GOAWAY frame. @@ -83,7 +81,7 @@ def test_receiving_low_stream_id_causes_goaway(self, frame_factory): ) assert c.data_to_send() == f.serialize() - def test_closed_stream_not_present_in_streams_dict(self, frame_factory): + def test_closed_stream_not_present_in_streams_dict(self, frame_factory) -> None: """ When streams have been closed, they get removed from the streams dictionary the next time we count the open streams. @@ -107,7 +105,7 @@ def test_closed_stream_not_present_in_streams_dict(self, frame_factory): # The streams dictionary should be empty. assert not c.streams - def test_receive_rst_stream_on_closed_stream(self, frame_factory): + def test_receive_rst_stream_on_closed_stream(self, frame_factory) -> None: """ RST_STREAM frame should be ignored if stream is in a closed state. See RFC 7540 Section 5.1 (closed state) @@ -120,15 +118,15 @@ def test_receive_rst_stream_on_closed_stream(self, frame_factory): # Some time passes and client sends DATA frame and closes stream, # so it is in a half-closed state - c.send_data(1, b'some data', end_stream=True) + c.send_data(1, b"some data", end_stream=True) # Server received HEADERS frame but DATA frame is still on the way. # Stream is in open state on the server-side. In this state server is # allowed to end stream and reset it - this trick helps immediately # close stream on the server-side. headers_frame = frame_factory.build_headers_frame( - [(':status', '200')], - flags=['END_STREAM'], + [(":status", "200")], + flags=["END_STREAM"], stream_id=1, ) events = c.receive_data(headers_frame.serialize()) @@ -141,7 +139,7 @@ def test_receive_rst_stream_on_closed_stream(self, frame_factory): events = c.receive_data(rst_stream_frame.serialize()) assert not events - def test_receive_window_update_on_closed_stream(self, frame_factory): + def test_receive_window_update_on_closed_stream(self, frame_factory) -> None: """ WINDOW_UPDATE frame should be ignored if stream is in a closed state. See RFC 7540 Section 5.1 (closed state) @@ -154,15 +152,15 @@ def test_receive_window_update_on_closed_stream(self, frame_factory): # Some time passes and client sends DATA frame and closes stream, # so it is in a half-closed state - c.send_data(1, b'some data', end_stream=True) + c.send_data(1, b"some data", end_stream=True) # Server received HEADERS frame but DATA frame is still on the way. # Stream is in open state on the server-side. In this state server is # allowed to end stream and after that acknowledge received data by # sending WINDOW_UPDATE frames. headers_frame = frame_factory.build_headers_frame( - [(':status', '200')], - flags=['END_STREAM'], + [(":status", "200")], + flags=["END_STREAM"], stream_id=1, ) events = c.receive_data(headers_frame.serialize()) @@ -187,16 +185,16 @@ def test_receive_window_update_on_closed_stream(self, frame_factory): assert not events -class TestStreamsClosedByEndStream(object): +class TestStreamsClosedByEndStream: example_request_headers = [ - (':authority', 'example.com'), - (':path', '/'), - (':scheme', 'https'), - (':method', 'GET'), + (":authority", "example.com"), + (":path", "/"), + (":scheme", "https"), + (":method", "GET"), ] example_response_headers = [ - (':status', '200'), - ('server', 'fake-serv/0.1.0') + (":status", "200"), + ("server", "fake-serv/0.1.0"), ] server_config = h2.config.H2Configuration(client_side=False) @@ -204,16 +202,16 @@ class TestStreamsClosedByEndStream(object): "frame", [ lambda self, ff: ff.build_headers_frame( - self.example_request_headers, flags=['END_STREAM']), + self.example_request_headers, flags=["END_STREAM"]), lambda self, ff: ff.build_headers_frame( self.example_request_headers), - ] + ], ) @pytest.mark.parametrize("clear_streams", [True, False]) def test_frames_after_recv_end_will_error(self, frame_factory, frame, - clear_streams): + clear_streams) -> None: """ A stream that is closed by receiving END_STREAM raises ProtocolError when it receives an unexpected frame. @@ -223,13 +221,13 @@ def test_frames_after_recv_end_will_error(self, c.initiate_connection() f = frame_factory.build_headers_frame( - self.example_request_headers, flags=['END_STREAM'] + self.example_request_headers, flags=["END_STREAM"], ) c.receive_data(f.serialize()) c.send_headers( stream_id=1, headers=self.example_response_headers, - end_stream=True + end_stream=True, ) if clear_streams: @@ -253,16 +251,16 @@ def test_frames_after_recv_end_will_error(self, "frame", [ lambda self, ff: ff.build_headers_frame( - self.example_response_headers, flags=['END_STREAM']), + self.example_response_headers, flags=["END_STREAM"]), lambda self, ff: ff.build_headers_frame( self.example_response_headers), - ] + ], ) @pytest.mark.parametrize("clear_streams", [True, False]) def test_frames_after_send_end_will_error(self, frame_factory, frame, - clear_streams): + clear_streams) -> None: """ A stream that is closed by sending END_STREAM raises ProtocolError when it receives an unexpected frame. @@ -273,7 +271,7 @@ def test_frames_after_send_end_will_error(self, end_stream=True) f = frame_factory.build_headers_frame( - self.example_response_headers, flags=['END_STREAM'] + self.example_response_headers, flags=["END_STREAM"], ) c.receive_data(f.serialize()) @@ -298,12 +296,12 @@ def test_frames_after_send_end_will_error(self, "frame", [ lambda self, ff: ff.build_window_update_frame(1, 1), - lambda self, ff: ff.build_rst_stream_frame(1) - ] + lambda self, ff: ff.build_rst_stream_frame(1), + ], ) def test_frames_after_send_end_will_be_ignored(self, frame_factory, - frame): + frame) -> None: """ A stream that is closed by sending END_STREAM will raise ProtocolError when received unexpected frame. @@ -313,13 +311,13 @@ def test_frames_after_send_end_will_be_ignored(self, c.initiate_connection() f = frame_factory.build_headers_frame( - self.example_request_headers, flags=['END_STREAM'] + self.example_request_headers, flags=["END_STREAM"], ) c.receive_data(f.serialize()) c.send_headers( stream_id=1, headers=self.example_response_headers, - end_stream=True + end_stream=True, ) c.clear_outbound_data_buffer() @@ -330,16 +328,16 @@ def test_frames_after_send_end_will_be_ignored(self, assert not events -class TestStreamsClosedByRstStream(object): +class TestStreamsClosedByRstStream: example_request_headers = [ - (':authority', 'example.com'), - (':path', '/'), - (':scheme', 'https'), - (':method', 'GET'), + (":authority", "example.com"), + (":path", "/"), + (":scheme", "https"), + (":method", "GET"), ] example_response_headers = [ - (':status', '200'), - ('server', 'fake-serv/0.1.0') + (":status", "200"), + ("server", "fake-serv/0.1.0"), ] server_config = h2.config.H2Configuration(client_side=False) @@ -349,12 +347,12 @@ class TestStreamsClosedByRstStream(object): lambda self, ff: ff.build_headers_frame( self.example_request_headers), lambda self, ff: ff.build_headers_frame( - self.example_request_headers, flags=['END_STREAM']), - ] + self.example_request_headers, flags=["END_STREAM"]), + ], ) def test_resets_further_frames_after_recv_reset(self, frame_factory, - frame): + frame) -> None: """ A stream that is closed by receive RST_STREAM can receive further frames: it simply sends RST_STREAM for it, and additionally @@ -365,18 +363,18 @@ def test_resets_further_frames_after_recv_reset(self, c.initiate_connection() header_frame = frame_factory.build_headers_frame( - self.example_request_headers, flags=['END_STREAM'] + self.example_request_headers, flags=["END_STREAM"], ) c.receive_data(header_frame.serialize()) c.send_headers( stream_id=1, headers=self.example_response_headers, - end_stream=False + end_stream=False, ) rst_frame = frame_factory.build_rst_stream_frame( - 1, h2.errors.ErrorCodes.STREAM_CLOSED + 1, h2.errors.ErrorCodes.STREAM_CLOSED, ) c.receive_data(rst_frame.serialize()) c.clear_outbound_data_buffer() @@ -385,7 +383,7 @@ def test_resets_further_frames_after_recv_reset(self, events = c.receive_data(f.serialize()) rst_frame = frame_factory.build_rst_stream_frame( - 1, h2.errors.ErrorCodes.STREAM_CLOSED + 1, h2.errors.ErrorCodes.STREAM_CLOSED, ) assert not events assert c.data_to_send() == rst_frame.serialize() @@ -403,7 +401,7 @@ def test_resets_further_frames_after_recv_reset(self, assert c.data_to_send() == rst_frame.serialize() * 3 def test_resets_further_data_frames_after_recv_reset(self, - frame_factory): + frame_factory) -> None: """ A stream that is closed by receive RST_STREAM can receive further DATA frames: it simply sends WINDOW_UPDATE for the connection flow @@ -414,24 +412,24 @@ def test_resets_further_data_frames_after_recv_reset(self, c.initiate_connection() header_frame = frame_factory.build_headers_frame( - self.example_request_headers, flags=['END_STREAM'] + self.example_request_headers, flags=["END_STREAM"], ) c.receive_data(header_frame.serialize()) c.send_headers( stream_id=1, headers=self.example_response_headers, - end_stream=False + end_stream=False, ) rst_frame = frame_factory.build_rst_stream_frame( - 1, h2.errors.ErrorCodes.STREAM_CLOSED + 1, h2.errors.ErrorCodes.STREAM_CLOSED, ) c.receive_data(rst_frame.serialize()) c.clear_outbound_data_buffer() f = frame_factory.build_data_frame( - data=b'some data' + data=b"some data", ) events = c.receive_data(f.serialize()) @@ -461,12 +459,12 @@ def test_resets_further_data_frames_after_recv_reset(self, lambda self, ff: ff.build_headers_frame( self.example_request_headers), lambda self, ff: ff.build_headers_frame( - self.example_request_headers, flags=['END_STREAM']), - ] + self.example_request_headers, flags=["END_STREAM"]), + ], ) def test_resets_further_frames_after_send_reset(self, frame_factory, - frame): + frame) -> None: """ A stream that is closed by sent RST_STREAM can receive further frames: it simply sends RST_STREAM for it. @@ -476,20 +474,20 @@ def test_resets_further_frames_after_send_reset(self, c.initiate_connection() header_frame = frame_factory.build_headers_frame( - self.example_request_headers, flags=['END_STREAM'] + self.example_request_headers, flags=["END_STREAM"], ) c.receive_data(header_frame.serialize()) c.send_headers( stream_id=1, headers=self.example_response_headers, - end_stream=False + end_stream=False, ) c.reset_stream(1, h2.errors.ErrorCodes.INTERNAL_ERROR) rst_frame = frame_factory.build_rst_stream_frame( - 1, h2.errors.ErrorCodes.STREAM_CLOSED + 1, h2.errors.ErrorCodes.STREAM_CLOSED, ) c.clear_outbound_data_buffer() @@ -497,7 +495,7 @@ def test_resets_further_frames_after_send_reset(self, events = c.receive_data(f.serialize()) rst_frame = frame_factory.build_rst_stream_frame( - 1, h2.errors.ErrorCodes.STREAM_CLOSED + 1, h2.errors.ErrorCodes.STREAM_CLOSED, ) assert not events assert c.data_to_send() == rst_frame.serialize() @@ -515,7 +513,7 @@ def test_resets_further_frames_after_send_reset(self, assert c.data_to_send() == rst_frame.serialize() * 3 def test_resets_further_data_frames_after_send_reset(self, - frame_factory): + frame_factory) -> None: """ A stream that is closed by sent RST_STREAM can receive further data frames: it simply sends WINDOW_UPDATE and RST_STREAM for it. @@ -525,14 +523,14 @@ def test_resets_further_data_frames_after_send_reset(self, c.initiate_connection() header_frame = frame_factory.build_headers_frame( - self.example_request_headers, flags=['END_STREAM'] + self.example_request_headers, flags=["END_STREAM"], ) c.receive_data(header_frame.serialize()) c.send_headers( stream_id=1, headers=self.example_response_headers, - end_stream=False + end_stream=False, ) c.reset_stream(1, h2.errors.ErrorCodes.INTERNAL_ERROR) @@ -540,7 +538,7 @@ def test_resets_further_data_frames_after_send_reset(self, c.clear_outbound_data_buffer() f = frame_factory.build_data_frame( - data=b'some data' + data=b"some data", ) events = c.receive_data(f.serialize()) assert not events diff --git a/tests/test_complex_logic.py b/tests/test_complex_logic.py index ff90bb8b..93774821 100644 --- a/tests/test_complex_logic.py +++ b/tests/test_complex_logic.py @@ -1,14 +1,12 @@ -# -*- coding: utf-8 -*- """ -test_complex_logic -~~~~~~~~~~~~~~~~ - More complex tests that try to do more. Certain tests don't really eliminate incorrect behaviour unless they do quite a bit. These tests should live here, to keep the pain in once place rather than hide it in the other parts of the test suite. """ +from __future__ import annotations + import pytest import h2 @@ -16,22 +14,23 @@ import h2.connection -class TestComplexClient(object): +class TestComplexClient: """ Complex tests for client-side stacks. """ + example_request_headers = [ - (':authority', 'example.com'), - (':path', '/'), - (':scheme', 'https'), - (':method', 'GET'), + (":authority", "example.com"), + (":path", "/"), + (":scheme", "https"), + (":method", "GET"), ] example_response_headers = [ - (':status', '200'), - ('server', 'fake-serv/0.1.0') + (":status", "200"), + ("server", "fake-serv/0.1.0"), ] - def test_correctly_count_server_streams(self, frame_factory): + def test_correctly_count_server_streams(self, frame_factory) -> None: """ We correctly count the number of server streams, both inbound and outbound. @@ -83,7 +82,7 @@ def test_correctly_count_server_streams(self, frame_factory): f = frame_factory.build_headers_frame( stream_id=stream_id, headers=self.example_response_headers, - flags=['END_STREAM'], + flags=["END_STREAM"], ) c.receive_data(f.serialize()) expected_outbound_streams -= 1 @@ -93,8 +92,8 @@ def test_correctly_count_server_streams(self, frame_factory): # Pushed streams can only be closed remotely. f = frame_factory.build_data_frame( stream_id=stream_id+1, - data=b'the content', - flags=['END_STREAM'], + data=b"the content", + flags=["END_STREAM"], ) c.receive_data(f.serialize()) expected_inbound_streams -= 1 @@ -105,23 +104,24 @@ def test_correctly_count_server_streams(self, frame_factory): assert c.open_outbound_streams == 0 -class TestComplexServer(object): +class TestComplexServer: """ Complex tests for server-side stacks. """ + example_request_headers = [ - (b':authority', b'example.com'), - (b':path', b'/'), - (b':scheme', b'https'), - (b':method', b'GET'), + (b":authority", b"example.com"), + (b":path", b"/"), + (b":scheme", b"https"), + (b":method", b"GET"), ] example_response_headers = [ - (b':status', b'200'), - (b'server', b'fake-serv/0.1.0') + (b":status", b"200"), + (b"server", b"fake-serv/0.1.0"), ] server_config = h2.config.H2Configuration(client_side=False) - def test_correctly_count_server_streams(self, frame_factory): + def test_correctly_count_server_streams(self, frame_factory) -> None: """ We correctly count the number of server streams, both inbound and outbound. @@ -160,8 +160,8 @@ def test_correctly_count_server_streams(self, frame_factory): for stream_id in range(13, 0, -2): # Close an inbound stream. f = frame_factory.build_data_frame( - data=b'', - flags=['END_STREAM'], + data=b"", + flags=["END_STREAM"], stream_id=stream_id, ) c.receive_data(f.serialize()) @@ -170,7 +170,7 @@ def test_correctly_count_server_streams(self, frame_factory): assert c.open_inbound_streams == expected_inbound_streams assert c.open_outbound_streams == expected_outbound_streams - c.send_data(stream_id, b'', end_stream=True) + c.send_data(stream_id, b"", end_stream=True) expected_inbound_streams -= 1 assert c.open_inbound_streams == expected_inbound_streams assert c.open_outbound_streams == expected_outbound_streams @@ -178,7 +178,7 @@ def test_correctly_count_server_streams(self, frame_factory): # Pushed streams, however, we can close ourselves. c.send_data( stream_id=stream_id+1, - data=b'', + data=b"", end_stream=True, ) expected_outbound_streams -= 1 @@ -189,15 +189,16 @@ def test_correctly_count_server_streams(self, frame_factory): assert c.open_outbound_streams == 0 -class TestContinuationFrames(object): +class TestContinuationFrames: """ Tests for the relatively complex CONTINUATION frame logic. """ + example_request_headers = [ - (b':authority', b'example.com'), - (b':path', b'/'), - (b':scheme', b'https'), - (b':method', b'GET'), + (b":authority", b"example.com"), + (b":path", b"/"), + (b":scheme", b"https"), + (b":method", b"GET"), ] server_config = h2.config.H2Configuration(client_side=False) @@ -212,12 +213,12 @@ def _build_continuation_sequence(self, headers, block_size, frame_factory): frames = [ frame_factory.build_continuation_frame(c) for c in chunks ] - f.flags = {'END_STREAM'} - frames[-1].flags.add('END_HEADERS') + f.flags = {"END_STREAM"} + frames[-1].flags.add("END_HEADERS") frames.insert(0, f) return frames - def test_continuation_frame_basic(self, frame_factory): + def test_continuation_frame_basic(self, frame_factory) -> None: """ Test that we correctly decode a header block split across continuation frames. @@ -231,7 +232,7 @@ def test_continuation_frame_basic(self, frame_factory): block_size=5, frame_factory=frame_factory, ) - data = b''.join(f.serialize() for f in frames) + data = b"".join(f.serialize() for f in frames) events = c.receive_data(data) assert len(events) == 2 @@ -244,10 +245,10 @@ def test_continuation_frame_basic(self, frame_factory): assert isinstance(second_event, h2.events.StreamEnded) assert second_event.stream_id == 1 - @pytest.mark.parametrize('stream_id', [3, 1]) + @pytest.mark.parametrize("stream_id", [3, 1]) def test_continuation_cannot_interleave_headers(self, frame_factory, - stream_id): + stream_id) -> None: """ We cannot interleave a new headers block with a CONTINUATION sequence. """ @@ -266,17 +267,17 @@ def test_continuation_cannot_interleave_headers(self, bogus_frame = frame_factory.build_headers_frame( headers=self.example_request_headers, stream_id=stream_id, - flags=['END_STREAM'], + flags=["END_STREAM"], ) frames.insert(len(frames) - 2, bogus_frame) - data = b''.join(f.serialize() for f in frames) + data = b"".join(f.serialize() for f in frames) with pytest.raises(h2.exceptions.ProtocolError) as e: c.receive_data(data) assert "invalid frame" in str(e.value).lower() - def test_continuation_cannot_interleave_data(self, frame_factory): + def test_continuation_cannot_interleave_data(self, frame_factory) -> None: """ We cannot interleave a data frame with a CONTINUATION sequence. """ @@ -293,18 +294,18 @@ def test_continuation_cannot_interleave_data(self, frame_factory): assert len(frames) > 2 # This is mostly defensive. bogus_frame = frame_factory.build_data_frame( - data=b'hello', + data=b"hello", stream_id=1, ) frames.insert(len(frames) - 2, bogus_frame) - data = b''.join(f.serialize() for f in frames) + data = b"".join(f.serialize() for f in frames) with pytest.raises(h2.exceptions.ProtocolError) as e: c.receive_data(data) assert "invalid frame" in str(e.value).lower() - def test_continuation_cannot_interleave_unknown_frame(self, frame_factory): + def test_continuation_cannot_interleave_unknown_frame(self, frame_factory) -> None: """ We cannot interleave an unknown frame with a CONTINUATION sequence. """ @@ -321,19 +322,19 @@ def test_continuation_cannot_interleave_unknown_frame(self, frame_factory): assert len(frames) > 2 # This is mostly defensive. bogus_frame = frame_factory.build_data_frame( - data=b'hello', + data=b"hello", stream_id=1, ) bogus_frame.type = 88 frames.insert(len(frames) - 2, bogus_frame) - data = b''.join(f.serialize() for f in frames) + data = b"".join(f.serialize() for f in frames) with pytest.raises(h2.exceptions.ProtocolError) as e: c.receive_data(data) assert "invalid frame" in str(e.value).lower() - def test_continuation_frame_multiple_blocks(self, frame_factory): + def test_continuation_frame_multiple_blocks(self, frame_factory) -> None: """ Test that we correctly decode several header blocks split across continuation frames. @@ -351,7 +352,7 @@ def test_continuation_frame_multiple_blocks(self, frame_factory): for frame in frames: frame.stream_id = stream_id - data = b''.join(f.serialize() for f in frames) + data = b"".join(f.serialize() for f in frames) events = c.receive_data(data) assert len(events) == 2 @@ -365,25 +366,26 @@ def test_continuation_frame_multiple_blocks(self, frame_factory): assert second_event.stream_id == stream_id -class TestContinuationFramesPushPromise(object): +class TestContinuationFramesPushPromise: """ Tests for the relatively complex CONTINUATION frame logic working with PUSH_PROMISE frames. """ + example_request_headers = [ - (b':authority', b'example.com'), - (b':path', b'/'), - (b':scheme', b'https'), - (b':method', b'GET'), + (b":authority", b"example.com"), + (b":path", b"/"), + (b":scheme", b"https"), + (b":method", b"GET"), ] example_response_headers = [ - (b':status', b'200'), - (b'server', b'fake-serv/0.1.0') + (b":status", b"200"), + (b"server", b"fake-serv/0.1.0"), ] def _build_continuation_sequence(self, headers, block_size, frame_factory): f = frame_factory.build_push_promise_frame( - stream_id=1, promised_stream_id=2, headers=headers + stream_id=1, promised_stream_id=2, headers=headers, ) header_data = f.data chunks = [ @@ -394,12 +396,12 @@ def _build_continuation_sequence(self, headers, block_size, frame_factory): frames = [ frame_factory.build_continuation_frame(c) for c in chunks ] - f.flags = {'END_STREAM'} - frames[-1].flags.add('END_HEADERS') + f.flags = {"END_STREAM"} + frames[-1].flags.add("END_HEADERS") frames.insert(0, f) return frames - def test_continuation_frame_basic_push_promise(self, frame_factory): + def test_continuation_frame_basic_push_promise(self, frame_factory) -> None: """ Test that we correctly decode a header block split across continuation frames when that header block is initiated with a PUSH_PROMISE. @@ -413,7 +415,7 @@ def test_continuation_frame_basic_push_promise(self, frame_factory): block_size=5, frame_factory=frame_factory, ) - data = b''.join(f.serialize() for f in frames) + data = b"".join(f.serialize() for f in frames) events = c.receive_data(data) assert len(events) == 1 @@ -424,10 +426,10 @@ def test_continuation_frame_basic_push_promise(self, frame_factory): assert event.parent_stream_id == 1 assert event.pushed_stream_id == 2 - @pytest.mark.parametrize('stream_id', [3, 1, 2]) + @pytest.mark.parametrize("stream_id", [3, 1, 2]) def test_continuation_cannot_interleave_headers_pp(self, frame_factory, - stream_id): + stream_id) -> None: """ We cannot interleave a new headers block with a CONTINUATION sequence when the headers block is based on a PUSH_PROMISE frame. @@ -446,17 +448,17 @@ def test_continuation_cannot_interleave_headers_pp(self, bogus_frame = frame_factory.build_headers_frame( headers=self.example_response_headers, stream_id=stream_id, - flags=['END_STREAM'], + flags=["END_STREAM"], ) frames.insert(len(frames) - 2, bogus_frame) - data = b''.join(f.serialize() for f in frames) + data = b"".join(f.serialize() for f in frames) with pytest.raises(h2.exceptions.ProtocolError) as e: c.receive_data(data) assert "invalid frame" in str(e.value).lower() - def test_continuation_cannot_interleave_data(self, frame_factory): + def test_continuation_cannot_interleave_data(self, frame_factory) -> None: """ We cannot interleave a data frame with a CONTINUATION sequence when that sequence began with a PUSH_PROMISE frame. @@ -473,18 +475,18 @@ def test_continuation_cannot_interleave_data(self, frame_factory): assert len(frames) > 2 # This is mostly defensive. bogus_frame = frame_factory.build_data_frame( - data=b'hello', + data=b"hello", stream_id=1, ) frames.insert(len(frames) - 2, bogus_frame) - data = b''.join(f.serialize() for f in frames) + data = b"".join(f.serialize() for f in frames) with pytest.raises(h2.exceptions.ProtocolError) as e: c.receive_data(data) assert "invalid frame" in str(e.value).lower() - def test_continuation_cannot_interleave_unknown_frame(self, frame_factory): + def test_continuation_cannot_interleave_unknown_frame(self, frame_factory) -> None: """ We cannot interleave an unknown frame with a CONTINUATION sequence when that sequence began with a PUSH_PROMISE frame. @@ -501,22 +503,22 @@ def test_continuation_cannot_interleave_unknown_frame(self, frame_factory): assert len(frames) > 2 # This is mostly defensive. bogus_frame = frame_factory.build_data_frame( - data=b'hello', + data=b"hello", stream_id=1, ) bogus_frame.type = 88 frames.insert(len(frames) - 2, bogus_frame) - data = b''.join(f.serialize() for f in frames) + data = b"".join(f.serialize() for f in frames) with pytest.raises(h2.exceptions.ProtocolError) as e: c.receive_data(data) assert "invalid frame" in str(e.value).lower() - @pytest.mark.parametrize('evict', [True, False]) + @pytest.mark.parametrize("evict", [True, False]) def test_stream_remotely_closed_disallows_push_promise(self, evict, - frame_factory): + frame_factory) -> None: """ Streams closed normally by the remote peer disallow PUSH_PROMISE frames, and cause a GOAWAY. @@ -526,13 +528,13 @@ def test_stream_remotely_closed_disallows_push_promise(self, c.send_headers( stream_id=1, headers=self.example_request_headers, - end_stream=True + end_stream=True, ) f = frame_factory.build_headers_frame( stream_id=1, headers=self.example_response_headers, - flags=['END_STREAM'] + flags=["END_STREAM"], ) c.receive_data(f.serialize()) c.clear_outbound_data_buffer() @@ -557,7 +559,7 @@ def test_stream_remotely_closed_disallows_push_promise(self, ) assert c.data_to_send() == f.serialize() - def test_continuation_frame_multiple_push_promise(self, frame_factory): + def test_continuation_frame_multiple_push_promise(self, frame_factory) -> None: """ Test that we correctly decode header blocks split across continuation frames when those header block is initiated with a PUSH_PROMISE, for @@ -574,7 +576,7 @@ def test_continuation_frame_multiple_push_promise(self, frame_factory): frame_factory=frame_factory, ) frames[0].promised_stream_id = promised_stream_id - data = b''.join(f.serialize() for f in frames) + data = b"".join(f.serialize() for f in frames) events = c.receive_data(data) assert len(events) == 1 diff --git a/tests/test_config.py b/tests/test_config.py index e3274c37..5651f7ba 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -1,22 +1,21 @@ -# -*- coding: utf-8 -*- """ -test_config -~~~~~~~~~~~ - Test the configuration object. """ +from __future__ import annotations + import logging + import pytest import h2.config -class TestH2Config(object): +class TestH2Config: """ Tests of the H2 config object. """ - def test_defaults(self): + def test_defaults(self) -> None: """ The default values of the HTTP/2 config object are sensible. """ @@ -26,18 +25,18 @@ def test_defaults(self): assert isinstance(config.logger, h2.config.DummyLogger) boolean_config_options = [ - 'client_side', - 'validate_outbound_headers', - 'normalize_outbound_headers', - 'validate_inbound_headers', - 'normalize_inbound_headers', + "client_side", + "validate_outbound_headers", + "normalize_outbound_headers", + "validate_inbound_headers", + "normalize_inbound_headers", ] - @pytest.mark.parametrize('option_name', boolean_config_options) - @pytest.mark.parametrize('value', [None, 'False', 1]) + @pytest.mark.parametrize("option_name", boolean_config_options) + @pytest.mark.parametrize("value", [None, "False", 1]) def test_boolean_config_options_reject_non_bools_init( - self, option_name, value - ): + self, option_name, value, + ) -> None: """ The boolean config options raise an error if you try to set a value that isn't a boolean via the initializer. @@ -45,11 +44,11 @@ def test_boolean_config_options_reject_non_bools_init( with pytest.raises(ValueError): h2.config.H2Configuration(**{option_name: value}) - @pytest.mark.parametrize('option_name', boolean_config_options) - @pytest.mark.parametrize('value', [None, 'False', 1]) + @pytest.mark.parametrize("option_name", boolean_config_options) + @pytest.mark.parametrize("value", [None, "False", 1]) def test_boolean_config_options_reject_non_bools_attr( - self, option_name, value - ): + self, option_name, value, + ) -> None: """ The boolean config options raise an error if you try to set a value that isn't a boolean via attribute setter. @@ -58,9 +57,9 @@ def test_boolean_config_options_reject_non_bools_attr( with pytest.raises(ValueError): setattr(config, option_name, value) - @pytest.mark.parametrize('option_name', boolean_config_options) - @pytest.mark.parametrize('value', [True, False]) - def test_boolean_config_option_is_reflected_init(self, option_name, value): + @pytest.mark.parametrize("option_name", boolean_config_options) + @pytest.mark.parametrize("value", [True, False]) + def test_boolean_config_option_is_reflected_init(self, option_name, value) -> None: """ The value of the boolean config options, when set, is reflected in the value via the initializer. @@ -68,9 +67,9 @@ def test_boolean_config_option_is_reflected_init(self, option_name, value): config = h2.config.H2Configuration(**{option_name: value}) assert getattr(config, option_name) == value - @pytest.mark.parametrize('option_name', boolean_config_options) - @pytest.mark.parametrize('value', [True, False]) - def test_boolean_config_option_is_reflected_attr(self, option_name, value): + @pytest.mark.parametrize("option_name", boolean_config_options) + @pytest.mark.parametrize("value", [True, False]) + def test_boolean_config_option_is_reflected_attr(self, option_name, value) -> None: """ The value of the boolean config options, when set, is reflected in the value via attribute setter. @@ -79,10 +78,10 @@ def test_boolean_config_option_is_reflected_attr(self, option_name, value): setattr(config, option_name, value) assert getattr(config, option_name) == value - @pytest.mark.parametrize('header_encoding', [True, 1, object()]) + @pytest.mark.parametrize("header_encoding", [True, 1, object()]) def test_header_encoding_must_be_false_str_none_init( - self, header_encoding - ): + self, header_encoding, + ) -> None: """ The value of the ``header_encoding`` setting must be False, a string, or None via the initializer. @@ -90,10 +89,10 @@ def test_header_encoding_must_be_false_str_none_init( with pytest.raises(ValueError): h2.config.H2Configuration(header_encoding=header_encoding) - @pytest.mark.parametrize('header_encoding', [True, 1, object()]) + @pytest.mark.parametrize("header_encoding", [True, 1, object()]) def test_header_encoding_must_be_false_str_none_attr( - self, header_encoding - ): + self, header_encoding, + ) -> None: """ The value of the ``header_encoding`` setting must be False, a string, or None via attribute setter. @@ -102,8 +101,8 @@ def test_header_encoding_must_be_false_str_none_attr( with pytest.raises(ValueError): config.header_encoding = header_encoding - @pytest.mark.parametrize('header_encoding', [False, 'ascii', None]) - def test_header_encoding_is_reflected_init(self, header_encoding): + @pytest.mark.parametrize("header_encoding", [False, "ascii", None]) + def test_header_encoding_is_reflected_init(self, header_encoding) -> None: """ The value of ``header_encoding``, when set, is reflected in the value via the initializer. @@ -111,8 +110,8 @@ def test_header_encoding_is_reflected_init(self, header_encoding): config = h2.config.H2Configuration(header_encoding=header_encoding) assert config.header_encoding == header_encoding - @pytest.mark.parametrize('header_encoding', [False, 'ascii', None]) - def test_header_encoding_is_reflected_attr(self, header_encoding): + @pytest.mark.parametrize("header_encoding", [False, "ascii", None]) + def test_header_encoding_is_reflected_attr(self, header_encoding) -> None: """ The value of ``header_encoding``, when set, is reflected in the value via the attribute setter. @@ -121,17 +120,17 @@ def test_header_encoding_is_reflected_attr(self, header_encoding): config.header_encoding = header_encoding assert config.header_encoding == header_encoding - def test_logger_instance_is_reflected(self): + def test_logger_instance_is_reflected(self) -> None: """ The value of ``logger``, when set, is reflected in the value. """ - logger = logging.Logger('hyper-h2.test') + logger = logging.getLogger("hyper-h2.test") config = h2.config.H2Configuration() config.logger = logger assert config.logger is logger @pytest.mark.parametrize("trace_level", [False, True]) - def test_output_logger(self, capsys, trace_level): + def test_output_logger(self, capsys, trace_level) -> None: logger = h2.config.OutputLogger(trace_level=trace_level) logger.debug("This is a debug message %d.", 123) diff --git a/tests/test_events.py b/tests/test_events.py index c790fbaa..aac91358 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -1,25 +1,20 @@ -# -*- coding: utf-8 -*- """ -test_events.py -~~~~~~~~~~~~~~ - Specific tests for any function that is logically self-contained as part of events.py. """ +from __future__ import annotations + import inspect import sys -from hypothesis import given -from hypothesis.strategies import ( - integers, lists, tuples -) import pytest +from hypothesis import given +from hypothesis.strategies import integers, lists, tuples import h2.errors import h2.events import h2.settings - # We define a fairly complex Hypothesis strategy here. We want to build a list # of two tuples of (Setting, value). For Setting we want to make sure we can # handle settings that the rest of hyper knows nothing about, so we want to @@ -29,16 +24,17 @@ tuples( integers(min_value=0, max_value=2**16-1), integers(min_value=0, max_value=2**32-1), - ) + ), ) -class TestRemoteSettingsChanged(object): +class TestRemoteSettingsChanged: """ Validate the function of the RemoteSettingsChanged event. """ + @given(SETTINGS_STRATEGY) - def test_building_settings_from_scratch(self, settings_list): + def test_building_settings_from_scratch(self, settings_list) -> None: """ Missing old settings are defaulted to None. """ @@ -56,7 +52,7 @@ def test_building_settings_from_scratch(self, settings_list): @given(SETTINGS_STRATEGY, SETTINGS_STRATEGY) def test_only_reports_changed_settings(self, old_settings_list, - new_settings_list): + new_settings_list) -> None: """ Settings that were not changed are not reported. """ @@ -69,14 +65,14 @@ def test_only_reports_changed_settings(self, assert len(e.changed_settings) == len(new_settings_dict) assert ( - sorted(list(e.changed_settings.keys())) == - sorted(list(new_settings_dict.keys())) + sorted(e.changed_settings.keys()) == + sorted(new_settings_dict.keys()) ) @given(SETTINGS_STRATEGY, SETTINGS_STRATEGY) def test_correctly_reports_changed_settings(self, old_settings_list, - new_settings_list): + new_settings_list) -> None: """ Settings that are changed are correctly reported. """ @@ -94,26 +90,27 @@ def test_correctly_reports_changed_settings(self, assert e.changed_settings[setting].new_value == new_value -class TestEventReprs(object): +class TestEventReprs: """ Events have useful representations. """ + example_request_headers = [ - (':authority', 'example.com'), - (':path', '/'), - (':scheme', 'https'), - (':method', 'GET'), + (":authority", "example.com"), + (":path", "/"), + (":scheme", "https"), + (":method", "GET"), ] example_informational_headers = [ - (':status', '100'), - ('server', 'fake-serv/0.1.0') + (":status", "100"), + ("server", "fake-serv/0.1.0"), ] example_response_headers = [ - (':status', '200'), - ('server', 'fake-serv/0.1.0') + (":status", "200"), + ("server", "fake-serv/0.1.0"), ] - def test_requestreceived_repr(self): + def test_requestreceived_repr(self) -> None: """ RequestReceived has a useful debug representation. """ @@ -129,7 +126,7 @@ def test_requestreceived_repr(self): "(':method', 'GET')]>" ) - def test_responsereceived_repr(self): + def test_responsereceived_repr(self) -> None: """ ResponseReceived has a useful debug representation. """ @@ -143,7 +140,7 @@ def test_responsereceived_repr(self): "('server', 'fake-serv/0.1.0')]>" ) - def test_trailersreceived_repr(self): + def test_trailersreceived_repr(self) -> None: """ TrailersReceived has a useful debug representation. """ @@ -157,7 +154,7 @@ def test_trailersreceived_repr(self): "('server', 'fake-serv/0.1.0')]>" ) - def test_informationalresponsereceived_repr(self): + def test_informationalresponsereceived_repr(self) -> None: """ InformationalResponseReceived has a useful debug representation. """ @@ -171,7 +168,7 @@ def test_informationalresponsereceived_repr(self): "('server', 'fake-serv/0.1.0')]>" ) - def test_datareceived_repr(self): + def test_datareceived_repr(self) -> None: """ DataReceived has a useful debug representation. """ @@ -185,7 +182,7 @@ def test_datareceived_repr(self): "data:6162636465666768696a6b6c6d6e6f7071727374>" ) - def test_windowupdated_repr(self): + def test_windowupdated_repr(self) -> None: """ WindowUpdated has a useful debug representation. """ @@ -195,7 +192,7 @@ def test_windowupdated_repr(self): assert repr(e) == "" - def test_remotesettingschanged_repr(self): + def test_remotesettingschanged_repr(self) -> None: """ RemoteSettingsChanged has a useful debug representation. """ @@ -203,7 +200,7 @@ def test_remotesettingschanged_repr(self): e.changed_settings = { h2.settings.SettingCodes.INITIAL_WINDOW_SIZE: h2.settings.ChangedSetting( - h2.settings.SettingCodes.INITIAL_WINDOW_SIZE, 2**16, 2**15 + h2.settings.SettingCodes.INITIAL_WINDOW_SIZE, 2**16, 2**15, ), } @@ -220,25 +217,25 @@ def test_remotesettingschanged_repr(self): "new_value=32768)}>" ) - def test_pingreceived_repr(self): + def test_pingreceived_repr(self) -> None: """ PingReceived has a useful debug representation. """ e = h2.events.PingReceived() - e.ping_data = b'abcdefgh' + e.ping_data = b"abcdefgh" assert repr(e) == "" - def test_pingackreceived_repr(self): + def test_pingackreceived_repr(self) -> None: """ PingAckReceived has a useful debug representation. """ e = h2.events.PingAckReceived() - e.ping_data = b'abcdefgh' + e.ping_data = b"abcdefgh" assert repr(e) == "" - def test_streamended_repr(self): + def test_streamended_repr(self) -> None: """ StreamEnded has a useful debug representation. """ @@ -247,7 +244,7 @@ def test_streamended_repr(self): assert repr(e) == "" - def test_streamreset_repr(self): + def test_streamreset_repr(self) -> None: """ StreamEnded has a useful debug representation. """ @@ -267,7 +264,7 @@ def test_streamreset_repr(self): "error_code:ErrorCodes.ENHANCE_YOUR_CALM, remote_reset:False>" ) - def test_pushedstreamreceived_repr(self): + def test_pushedstreamreceived_repr(self) -> None: """ PushedStreamReceived has a useful debug representation. """ @@ -285,7 +282,7 @@ def test_pushedstreamreceived_repr(self): "(':method', 'GET')]>" ) - def test_settingsacknowledged_repr(self): + def test_settingsacknowledged_repr(self) -> None: """ SettingsAcknowledged has a useful debug representation. """ @@ -293,7 +290,7 @@ def test_settingsacknowledged_repr(self): e.changed_settings = { h2.settings.SettingCodes.INITIAL_WINDOW_SIZE: h2.settings.ChangedSetting( - h2.settings.SettingCodes.INITIAL_WINDOW_SIZE, 2**16, 2**15 + h2.settings.SettingCodes.INITIAL_WINDOW_SIZE, 2**16, 2**15, ), } @@ -310,7 +307,7 @@ def test_settingsacknowledged_repr(self): "new_value=32768)}>" ) - def test_priorityupdated_repr(self): + def test_priorityupdated_repr(self) -> None: """ PriorityUpdated has a useful debug representation. """ @@ -325,11 +322,11 @@ def test_priorityupdated_repr(self): "exclusive:True>" ) - @pytest.mark.parametrize("additional_data,data_repr", [ + @pytest.mark.parametrize(("additional_data", "data_repr"), [ (None, "None"), - (b'some data', "736f6d652064617461") + (b"some data", "736f6d652064617461"), ]) - def test_connectionterminated_repr(self, additional_data, data_repr): + def test_connectionterminated_repr(self, additional_data, data_repr) -> None: """ ConnectionTerminated has a useful debug representation. """ @@ -341,15 +338,15 @@ def test_connectionterminated_repr(self, additional_data, data_repr): if sys.version_info >= (3, 11): assert repr(e) == ( "" % data_repr + f"last_stream_id:33, additional_data:{data_repr}>" ) else: assert repr(e) == ( "" % data_repr + f"last_stream_id:33, additional_data:{data_repr}>" ) - def test_alternativeserviceavailable_repr(self): + def test_alternativeserviceavailable_repr(self) -> None: """ AlternativeServiceAvailable has a useful debug representation. """ @@ -362,31 +359,31 @@ def test_alternativeserviceavailable_repr(self): 'field_value:h2=":8000"; ma=60>' ) - def test_unknownframereceived_repr(self): + def test_unknownframereceived_repr(self) -> None: """ UnknownFrameReceived has a useful debug representation. """ e = h2.events.UnknownFrameReceived() - assert repr(e) == '' + assert repr(e) == "" def all_events(): """ Generates all the classes (i.e., events) defined in h2.events. """ - for _, obj in inspect.getmembers(sys.modules['h2.events']): + for _, obj in inspect.getmembers(sys.modules["h2.events"]): # We are only interested in objects that are defined in h2.events; # objects that are imported from other modules are not of interest. - if hasattr(obj, '__module__') and (obj.__module__ != 'h2.events'): + if hasattr(obj, "__module__") and (obj.__module__ != "h2.events"): continue if inspect.isclass(obj): yield obj -@pytest.mark.parametrize('event', all_events()) -def test_all_events_subclass_from_event(event): +@pytest.mark.parametrize("event", all_events()) +def test_all_events_subclass_from_event(event) -> None: """ Every event defined in h2.events subclasses from h2.events.Event. """ diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index 18904599..fa4e379b 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -1,15 +1,13 @@ -# -*- coding: utf-8 -*- """ -test_exceptions -~~~~~~~~~~~~~~~ - Tests that verify logic local to exceptions. """ +from __future__ import annotations + import h2.exceptions -class TestExceptions(object): - def test_stream_id_too_low_prints_properly(self): +class TestExceptions: + def test_stream_id_too_low_prints_properly(self) -> None: x = h2.exceptions.StreamIDTooLowError(5, 10) - assert "StreamIDTooLowError: 5 is lower than 10" == str(x) + assert str(x) == "StreamIDTooLowError: 5 is lower than 10" diff --git a/tests/test_flow_control_window.py b/tests/test_flow_control_window.py index 223cf39f..21cc7b8f 100644 --- a/tests/test_flow_control_window.py +++ b/tests/test_flow_control_window.py @@ -1,13 +1,13 @@ -# -*- coding: utf-8 -*- """ test_flow_control ~~~~~~~~~~~~~~~~~ Tests of the flow control management in h2 """ -import pytest +from __future__ import annotations -from hypothesis import given, settings, HealthCheck +import pytest +from hypothesis import HealthCheck, given, settings from hypothesis.strategies import integers import h2.config @@ -18,21 +18,22 @@ import h2.settings -class TestFlowControl(object): +class TestFlowControl: """ Tests of the flow control management in the connection objects. """ + example_request_headers = [ - (':authority', 'example.com'), - (':path', '/'), - (':scheme', 'https'), - (':method', 'GET'), + (":authority", "example.com"), + (":path", "/"), + (":scheme", "https"), + (":method", "GET"), ] server_config = h2.config.H2Configuration(client_side=False) DEFAULT_FLOW_WINDOW = 65535 - def test_flow_control_initializes_properly(self): + def test_flow_control_initializes_properly(self) -> None: """ The flow control window for a stream should initially be the default flow control value. @@ -43,20 +44,20 @@ def test_flow_control_initializes_properly(self): assert c.local_flow_control_window(1) == self.DEFAULT_FLOW_WINDOW assert c.remote_flow_control_window(1) == self.DEFAULT_FLOW_WINDOW - def test_flow_control_decreases_with_sent_data(self): + def test_flow_control_decreases_with_sent_data(self) -> None: """ When data is sent on a stream, the flow control window should drop. """ c = h2.connection.H2Connection() c.send_headers(1, self.example_request_headers) - c.send_data(1, b'some data') + c.send_data(1, b"some data") - remaining_length = self.DEFAULT_FLOW_WINDOW - len(b'some data') + remaining_length = self.DEFAULT_FLOW_WINDOW - len(b"some data") assert (c.local_flow_control_window(1) == remaining_length) @pytest.mark.parametrize("pad_length", [5, 0]) def test_flow_control_decreases_with_sent_data_with_padding(self, - pad_length): + pad_length) -> None: """ When padded data is sent on a stream, the flow control window drops by the length of the padding plus 1 for the 1-byte padding length @@ -65,13 +66,13 @@ def test_flow_control_decreases_with_sent_data_with_padding(self, c = h2.connection.H2Connection() c.send_headers(1, self.example_request_headers) - c.send_data(1, b'some data', pad_length=pad_length) + c.send_data(1, b"some data", pad_length=pad_length) remaining_length = ( - self.DEFAULT_FLOW_WINDOW - len(b'some data') - pad_length - 1 + self.DEFAULT_FLOW_WINDOW - len(b"some data") - pad_length - 1 ) assert c.local_flow_control_window(1) == remaining_length - def test_flow_control_decreases_with_received_data(self, frame_factory): + def test_flow_control_decreases_with_received_data(self, frame_factory) -> None: """ When data is received on a stream, the remote flow control window should drop. @@ -79,14 +80,14 @@ def test_flow_control_decreases_with_received_data(self, frame_factory): c = h2.connection.H2Connection(config=self.server_config) c.receive_data(frame_factory.preamble()) f1 = frame_factory.build_headers_frame(self.example_request_headers) - f2 = frame_factory.build_data_frame(b'some data') + f2 = frame_factory.build_data_frame(b"some data") c.receive_data(f1.serialize() + f2.serialize()) - remaining_length = self.DEFAULT_FLOW_WINDOW - len(b'some data') + remaining_length = self.DEFAULT_FLOW_WINDOW - len(b"some data") assert (c.remote_flow_control_window(1) == remaining_length) - def test_flow_control_decreases_with_padded_data(self, frame_factory): + def test_flow_control_decreases_with_padded_data(self, frame_factory) -> None: """ When padded data is received on a stream, the remote flow control window drops by an amount that includes the padding. @@ -94,29 +95,29 @@ def test_flow_control_decreases_with_padded_data(self, frame_factory): c = h2.connection.H2Connection(config=self.server_config) c.receive_data(frame_factory.preamble()) f1 = frame_factory.build_headers_frame(self.example_request_headers) - f2 = frame_factory.build_data_frame(b'some data', padding_len=10) + f2 = frame_factory.build_data_frame(b"some data", padding_len=10) c.receive_data(f1.serialize() + f2.serialize()) remaining_length = ( - self.DEFAULT_FLOW_WINDOW - len(b'some data') - 10 - 1 + self.DEFAULT_FLOW_WINDOW - len(b"some data") - 10 - 1 ) assert (c.remote_flow_control_window(1) == remaining_length) - def test_flow_control_is_limited_by_connection(self): + def test_flow_control_is_limited_by_connection(self) -> None: """ The flow control window is limited by the flow control of the connection. """ c = h2.connection.H2Connection() c.send_headers(1, self.example_request_headers) - c.send_data(1, b'some data') + c.send_data(1, b"some data") c.send_headers(3, self.example_request_headers) - remaining_length = self.DEFAULT_FLOW_WINDOW - len(b'some data') + remaining_length = self.DEFAULT_FLOW_WINDOW - len(b"some data") assert (c.local_flow_control_window(3) == remaining_length) - def test_remote_flow_control_is_limited_by_connection(self, frame_factory): + def test_remote_flow_control_is_limited_by_connection(self, frame_factory) -> None: """ The remote flow control window is limited by the flow control of the connection. @@ -124,17 +125,17 @@ def test_remote_flow_control_is_limited_by_connection(self, frame_factory): c = h2.connection.H2Connection(config=self.server_config) c.receive_data(frame_factory.preamble()) f1 = frame_factory.build_headers_frame(self.example_request_headers) - f2 = frame_factory.build_data_frame(b'some data') + f2 = frame_factory.build_data_frame(b"some data") f3 = frame_factory.build_headers_frame( self.example_request_headers, stream_id=3, ) c.receive_data(f1.serialize() + f2.serialize() + f3.serialize()) - remaining_length = self.DEFAULT_FLOW_WINDOW - len(b'some data') + remaining_length = self.DEFAULT_FLOW_WINDOW - len(b"some data") assert (c.remote_flow_control_window(3) == remaining_length) - def test_cannot_send_more_data_than_window(self): + def test_cannot_send_more_data_than_window(self) -> None: """ Sending more data than the remaining flow control window raises a FlowControlError. @@ -144,9 +145,9 @@ def test_cannot_send_more_data_than_window(self): c.outbound_flow_control_window = 5 with pytest.raises(h2.exceptions.FlowControlError): - c.send_data(1, b'some data') + c.send_data(1, b"some data") - def test_increasing_connection_window_allows_sending(self, frame_factory): + def test_increasing_connection_window_allows_sending(self, frame_factory) -> None: """ Confirm that sending a WindowUpdate frame on the connection frees up space for further frames. @@ -156,7 +157,7 @@ def test_increasing_connection_window_allows_sending(self, frame_factory): c.outbound_flow_control_window = 5 with pytest.raises(h2.exceptions.FlowControlError): - c.send_data(1, b'some data') + c.send_data(1, b"some data") f = frame_factory.build_window_update_frame( stream_id=0, @@ -165,10 +166,10 @@ def test_increasing_connection_window_allows_sending(self, frame_factory): c.receive_data(f.serialize()) c.clear_outbound_data_buffer() - c.send_data(1, b'some data') + c.send_data(1, b"some data") assert c.data_to_send() - def test_increasing_stream_window_allows_sending(self, frame_factory): + def test_increasing_stream_window_allows_sending(self, frame_factory) -> None: """ Confirm that sending a WindowUpdate frame on the connection frees up space for further frames. @@ -178,7 +179,7 @@ def test_increasing_stream_window_allows_sending(self, frame_factory): c._get_stream_by_id(1).outbound_flow_control_window = 5 with pytest.raises(h2.exceptions.FlowControlError): - c.send_data(1, b'some data') + c.send_data(1, b"some data") f = frame_factory.build_window_update_frame( stream_id=1, @@ -187,10 +188,10 @@ def test_increasing_stream_window_allows_sending(self, frame_factory): c.receive_data(f.serialize()) c.clear_outbound_data_buffer() - c.send_data(1, b'some data') + c.send_data(1, b"some data") assert c.data_to_send() - def test_flow_control_shrinks_in_response_to_settings(self, frame_factory): + def test_flow_control_shrinks_in_response_to_settings(self, frame_factory) -> None: """ Acknowledging SETTINGS_INITIAL_WINDOW_SIZE shrinks the flow control window. @@ -201,13 +202,13 @@ def test_flow_control_shrinks_in_response_to_settings(self, frame_factory): assert c.local_flow_control_window(1) == 65535 f = frame_factory.build_settings_frame( - settings={h2.settings.SettingCodes.INITIAL_WINDOW_SIZE: 1280} + settings={h2.settings.SettingCodes.INITIAL_WINDOW_SIZE: 1280}, ) c.receive_data(f.serialize()) assert c.local_flow_control_window(1) == 1280 - def test_flow_control_grows_in_response_to_settings(self, frame_factory): + def test_flow_control_grows_in_response_to_settings(self, frame_factory) -> None: """ Acknowledging SETTINGS_INITIAL_WINDOW_SIZE grows the flow control window. @@ -217,7 +218,7 @@ def test_flow_control_grows_in_response_to_settings(self, frame_factory): # Greatly increase the connection flow control window. f = frame_factory.build_window_update_frame( - stream_id=0, increment=128000 + stream_id=0, increment=128000, ) c.receive_data(f.serialize()) @@ -225,14 +226,14 @@ def test_flow_control_grows_in_response_to_settings(self, frame_factory): assert c.local_flow_control_window(1) == 65535 f = frame_factory.build_settings_frame( - settings={h2.settings.SettingCodes.INITIAL_WINDOW_SIZE: 128000} + settings={h2.settings.SettingCodes.INITIAL_WINDOW_SIZE: 128000}, ) c.receive_data(f.serialize()) # The stream window is still the bottleneck, but larger now. assert c.local_flow_control_window(1) == 128000 - def test_flow_control_settings_blocked_by_conn_window(self, frame_factory): + def test_flow_control_settings_blocked_by_conn_window(self, frame_factory) -> None: """ Changing SETTINGS_INITIAL_WINDOW_SIZE does not affect the effective flow control window if the connection window isn't changed. @@ -243,13 +244,13 @@ def test_flow_control_settings_blocked_by_conn_window(self, frame_factory): assert c.local_flow_control_window(1) == 65535 f = frame_factory.build_settings_frame( - settings={h2.settings.SettingCodes.INITIAL_WINDOW_SIZE: 128000} + settings={h2.settings.SettingCodes.INITIAL_WINDOW_SIZE: 128000}, ) c.receive_data(f.serialize()) assert c.local_flow_control_window(1) == 65535 - def test_new_streams_have_flow_control_per_settings(self, frame_factory): + def test_new_streams_have_flow_control_per_settings(self, frame_factory) -> None: """ After a SETTINGS_INITIAL_WINDOW_SIZE change is received, new streams have appropriate new flow control windows. @@ -257,20 +258,20 @@ def test_new_streams_have_flow_control_per_settings(self, frame_factory): c = h2.connection.H2Connection() f = frame_factory.build_settings_frame( - settings={h2.settings.SettingCodes.INITIAL_WINDOW_SIZE: 128000} + settings={h2.settings.SettingCodes.INITIAL_WINDOW_SIZE: 128000}, ) c.receive_data(f.serialize()) # Greatly increase the connection flow control window. f = frame_factory.build_window_update_frame( - stream_id=0, increment=128000 + stream_id=0, increment=128000, ) c.receive_data(f.serialize()) c.send_headers(1, self.example_request_headers) assert c.local_flow_control_window(1) == 128000 - def test_window_update_no_stream(self, frame_factory): + def test_window_update_no_stream(self, frame_factory) -> None: """ WindowUpdate frames received without streams fire an appropriate WindowUpdated event. @@ -280,7 +281,7 @@ def test_window_update_no_stream(self, frame_factory): f = frame_factory.build_window_update_frame( stream_id=0, - increment=5 + increment=5, ) events = c.receive_data(f.serialize()) @@ -291,7 +292,7 @@ def test_window_update_no_stream(self, frame_factory): assert event.stream_id == 0 assert event.delta == 5 - def test_window_update_with_stream(self, frame_factory): + def test_window_update_with_stream(self, frame_factory) -> None: """ WindowUpdate frames received with streams fire an appropriate WindowUpdated event. @@ -302,9 +303,9 @@ def test_window_update_with_stream(self, frame_factory): f1 = frame_factory.build_headers_frame(self.example_request_headers) f2 = frame_factory.build_window_update_frame( stream_id=1, - increment=66 + increment=66, ) - data = b''.join(map(lambda f: f.serialize(), [f1, f2])) + data = b"".join(f.serialize() for f in [f1, f2]) events = c.receive_data(data) assert len(events) == 2 @@ -314,7 +315,7 @@ def test_window_update_with_stream(self, frame_factory): assert event.stream_id == 1 assert event.delta == 66 - def test_we_can_increment_stream_flow_control(self, frame_factory): + def test_we_can_increment_stream_flow_control(self, frame_factory) -> None: """ It is possible for the user to increase the flow control window for streams. @@ -326,14 +327,14 @@ def test_we_can_increment_stream_flow_control(self, frame_factory): expected_frame = frame_factory.build_window_update_frame( stream_id=1, - increment=5 + increment=5, ) events = c.increment_flow_control_window(increment=5, stream_id=1) assert not events assert c.data_to_send() == expected_frame.serialize() - def test_we_can_increment_connection_flow_control(self, frame_factory): + def test_we_can_increment_connection_flow_control(self, frame_factory) -> None: """ It is possible for the user to increase the flow control window for the entire connection. @@ -345,14 +346,14 @@ def test_we_can_increment_connection_flow_control(self, frame_factory): expected_frame = frame_factory.build_window_update_frame( stream_id=0, - increment=5 + increment=5, ) events = c.increment_flow_control_window(increment=5) assert not events assert c.data_to_send() == expected_frame.serialize() - def test_we_enforce_our_flow_control_window(self, frame_factory): + def test_we_enforce_our_flow_control_window(self, frame_factory) -> None: """ The user can set a low flow control window, which leads to connection teardown if violated. @@ -362,7 +363,7 @@ def test_we_enforce_our_flow_control_window(self, frame_factory): # Change the flow control window to 80 bytes. c.update_settings( - {h2.settings.SettingCodes.INITIAL_WINDOW_SIZE: 80} + {h2.settings.SettingCodes.INITIAL_WINDOW_SIZE: 80}, ) f = frame_factory.build_settings_frame({}, ack=True) c.receive_data(f.serialize()) @@ -373,7 +374,7 @@ def test_we_enforce_our_flow_control_window(self, frame_factory): # Attempt to violate the flow control window. c.clear_outbound_data_buffer() - f = frame_factory.build_data_frame(b'\x01' * 100) + f = frame_factory.build_data_frame(b"\x01" * 100) with pytest.raises(h2.exceptions.FlowControlError): c.receive_data(f.serialize()) @@ -385,7 +386,7 @@ def test_we_enforce_our_flow_control_window(self, frame_factory): ) assert c.data_to_send() == expected_frame.serialize() - def test_shrink_remote_flow_control_settings(self, frame_factory): + def test_shrink_remote_flow_control_settings(self, frame_factory) -> None: """ The remote peer acknowledging our SETTINGS_INITIAL_WINDOW_SIZE shrinks the flow control window. @@ -402,7 +403,7 @@ def test_shrink_remote_flow_control_settings(self, frame_factory): assert c.remote_flow_control_window(1) == 1280 - def test_grow_remote_flow_control_settings(self, frame_factory): + def test_grow_remote_flow_control_settings(self, frame_factory) -> None: """ The remote peer acknowledging our SETTINGS_INITIAL_WINDOW_SIZE grows the flow control window. @@ -416,14 +417,14 @@ def test_grow_remote_flow_control_settings(self, frame_factory): assert c.remote_flow_control_window(1) == 65535 c.update_settings( - {h2.settings.SettingCodes.INITIAL_WINDOW_SIZE: 128000} + {h2.settings.SettingCodes.INITIAL_WINDOW_SIZE: 128000}, ) f = frame_factory.build_settings_frame({}, ack=True) c.receive_data(f.serialize()) assert c.remote_flow_control_window(1) == 128000 - def test_new_streams_have_remote_flow_control(self, frame_factory): + def test_new_streams_have_remote_flow_control(self, frame_factory) -> None: """ After a SETTINGS_INITIAL_WINDOW_SIZE change is acknowledged by the remote peer, new streams have appropriate new flow control windows. @@ -431,7 +432,7 @@ def test_new_streams_have_remote_flow_control(self, frame_factory): c = h2.connection.H2Connection() c.update_settings( - {h2.settings.SettingCodes.INITIAL_WINDOW_SIZE: 128000} + {h2.settings.SettingCodes.INITIAL_WINDOW_SIZE: 128000}, ) f = frame_factory.build_settings_frame({}, ack=True) c.receive_data(f.serialize()) @@ -443,9 +444,9 @@ def test_new_streams_have_remote_flow_control(self, frame_factory): assert c.remote_flow_control_window(1) == 128000 @pytest.mark.parametrize( - 'increment', [0, -15, 2**31] + "increment", [0, -15, 2**31], ) - def test_reject_bad_attempts_to_increment_flow_control(self, increment): + def test_reject_bad_attempts_to_increment_flow_control(self, increment) -> None: """ Attempting to increment a flow control increment outside the valid range causes a ValueError to be raised. @@ -462,8 +463,8 @@ def test_reject_bad_attempts_to_increment_flow_control(self, increment): with pytest.raises(ValueError): c.increment_flow_control_window(increment=increment) - @pytest.mark.parametrize('stream_id', [0, 1]) - def test_reject_bad_remote_increments(self, frame_factory, stream_id): + @pytest.mark.parametrize("stream_id", [0, 1]) + def test_reject_bad_remote_increments(self, frame_factory, stream_id) -> None: """ Remote peers attempting to increment flow control outside the valid range cause connection errors of type PROTOCOL_ERROR. @@ -476,7 +477,7 @@ def test_reject_bad_remote_increments(self, frame_factory, stream_id): c.clear_outbound_data_buffer() f = frame_factory.build_window_update_frame( - stream_id=stream_id, increment=0 + stream_id=stream_id, increment=0, ) with pytest.raises(h2.exceptions.ProtocolError): @@ -488,7 +489,7 @@ def test_reject_bad_remote_increments(self, frame_factory, stream_id): ) assert c.data_to_send() == expected_frame.serialize() - def test_reject_increasing_connection_window_too_far(self, frame_factory): + def test_reject_increasing_connection_window_too_far(self, frame_factory) -> None: """ Attempts by the remote peer to increase the connection flow control window beyond 2**31 - 1 are rejected. @@ -500,7 +501,7 @@ def test_reject_increasing_connection_window_too_far(self, frame_factory): increment = 2**31 - c.outbound_flow_control_window f = frame_factory.build_window_update_frame( - stream_id=0, increment=increment + stream_id=0, increment=increment, ) with pytest.raises(h2.exceptions.FlowControlError): @@ -512,7 +513,7 @@ def test_reject_increasing_connection_window_too_far(self, frame_factory): ) assert c.data_to_send() == expected_frame.serialize() - def test_reject_increasing_stream_window_too_far(self, frame_factory): + def test_reject_increasing_stream_window_too_far(self, frame_factory) -> None: """ Attempts by the remote peer to increase the stream flow control window beyond 2**31 - 1 are rejected. @@ -525,7 +526,7 @@ def test_reject_increasing_stream_window_too_far(self, frame_factory): increment = 2**31 - c.outbound_flow_control_window f = frame_factory.build_window_update_frame( - stream_id=1, increment=increment + stream_id=1, increment=increment, ) events = c.receive_data(f.serialize()) @@ -543,7 +544,7 @@ def test_reject_increasing_stream_window_too_far(self, frame_factory): ) assert c.data_to_send() == expected_frame.serialize() - def test_reject_overlarge_conn_window_settings(self, frame_factory): + def test_reject_overlarge_conn_window_settings(self, frame_factory) -> None: """ SETTINGS frames cannot change the size of the connection flow control window. @@ -555,7 +556,7 @@ def test_reject_overlarge_conn_window_settings(self, frame_factory): increment = 2**31 - 1 - c.outbound_flow_control_window f = frame_factory.build_window_update_frame( - stream_id=0, increment=increment + stream_id=0, increment=increment, ) c.receive_data(f.serialize()) @@ -563,8 +564,8 @@ def test_reject_overlarge_conn_window_settings(self, frame_factory): f = frame_factory.build_settings_frame( settings={ h2.settings.SettingCodes.INITIAL_WINDOW_SIZE: - self.DEFAULT_FLOW_WINDOW + 1 - } + self.DEFAULT_FLOW_WINDOW + 1, + }, ) c.clear_outbound_data_buffer() @@ -575,11 +576,11 @@ def test_reject_overlarge_conn_window_settings(self, frame_factory): expected_frame = frame_factory.build_settings_frame( settings={}, - ack=True + ack=True, ) assert c.data_to_send() == expected_frame.serialize() - def test_reject_overlarge_stream_window_settings(self, frame_factory): + def test_reject_overlarge_stream_window_settings(self, frame_factory) -> None: """ Remote attempts to create overlarge stream windows via SETTINGS frames are rejected. @@ -592,7 +593,7 @@ def test_reject_overlarge_stream_window_settings(self, frame_factory): increment = 2**31 - 1 - c.outbound_flow_control_window f = frame_factory.build_window_update_frame( - stream_id=1, increment=increment + stream_id=1, increment=increment, ) c.receive_data(f.serialize()) @@ -600,8 +601,8 @@ def test_reject_overlarge_stream_window_settings(self, frame_factory): f = frame_factory.build_settings_frame( settings={ h2.settings.SettingCodes.INITIAL_WINDOW_SIZE: - self.DEFAULT_FLOW_WINDOW + 1 - } + self.DEFAULT_FLOW_WINDOW + 1, + }, ) c.clear_outbound_data_buffer() with pytest.raises(h2.exceptions.FlowControlError): @@ -613,7 +614,7 @@ def test_reject_overlarge_stream_window_settings(self, frame_factory): ) assert c.data_to_send() == expected_frame.serialize() - def test_reject_local_overlarge_increase_connection_window(self): + def test_reject_local_overlarge_increase_connection_window(self) -> None: """ Local attempts to increase the connection window too far are rejected. """ @@ -625,7 +626,7 @@ def test_reject_local_overlarge_increase_connection_window(self): with pytest.raises(h2.exceptions.FlowControlError): c.increment_flow_control_window(increment=increment) - def test_reject_local_overlarge_increase_stream_window(self): + def test_reject_local_overlarge_increase_stream_window(self) -> None: """ Local attempts to increase the connection window too far are rejected. """ @@ -638,7 +639,7 @@ def test_reject_local_overlarge_increase_stream_window(self): with pytest.raises(h2.exceptions.FlowControlError): c.increment_flow_control_window(increment=increment, stream_id=1) - def test_send_update_on_closed_streams(self, frame_factory): + def test_send_update_on_closed_streams(self, frame_factory) -> None: c = h2.connection.H2Connection() c.initiate_connection() c.send_headers(1, self.example_request_headers) @@ -648,7 +649,7 @@ def test_send_update_on_closed_streams(self, frame_factory): c.open_outbound_streams c.open_inbound_streams - f = frame_factory.build_data_frame(b'some data'*1500) + f = frame_factory.build_data_frame(b"some data"*1500) events = c.receive_data(f.serialize()*3) assert not events @@ -664,7 +665,7 @@ def test_send_update_on_closed_streams(self, frame_factory): ).serialize() assert c.data_to_send() == expected - f = frame_factory.build_data_frame(b'') + f = frame_factory.build_data_frame(b"") events = c.receive_data(f.serialize()) assert not events @@ -675,15 +676,16 @@ def test_send_update_on_closed_streams(self, frame_factory): assert c.data_to_send() == expected -class TestAutomaticFlowControl(object): +class TestAutomaticFlowControl: """ Tests for the automatic flow control logic. """ + example_request_headers = [ - (':authority', 'example.com'), - (':path', '/'), - (':scheme', 'https'), - (':method', 'GET'), + (":authority", "example.com"), + (":path", "/"), + (":scheme", "https"), + (":method", "GET"), ] server_config = h2.config.H2Configuration(client_side=False) @@ -699,16 +701,16 @@ def _setup_connection_and_send_headers(self, frame_factory): c.receive_data(frame_factory.preamble()) c.update_settings( - {h2.settings.SettingCodes.MAX_FRAME_SIZE: self.DEFAULT_FLOW_WINDOW} + {h2.settings.SettingCodes.MAX_FRAME_SIZE: self.DEFAULT_FLOW_WINDOW}, ) settings_frame = frame_factory.build_settings_frame( - settings={}, ack=True + settings={}, ack=True, ) c.receive_data(settings_frame.serialize()) c.clear_outbound_data_buffer() headers_frame = frame_factory.build_headers_frame( - headers=self.example_request_headers + headers=self.example_request_headers, ) c.receive_data(headers_frame.serialize()) c.clear_outbound_data_buffer() @@ -716,7 +718,7 @@ def _setup_connection_and_send_headers(self, frame_factory): @given(stream_id=integers(max_value=0)) @settings(suppress_health_check=[HealthCheck.function_scoped_fixture]) - def test_must_acknowledge_for_stream(self, frame_factory, stream_id): + def test_must_acknowledge_for_stream(self, frame_factory, stream_id) -> None: """ Flow control acknowledgements must be done on a stream ID that is greater than zero. @@ -731,18 +733,18 @@ def test_must_acknowledge_for_stream(self, frame_factory, stream_id): # data acknolwedgement. c = self._setup_connection_and_send_headers(frame_factory) data_frame = frame_factory.build_data_frame( - b'some data', flags=['END_STREAM'] + b"some data", flags=["END_STREAM"], ) c.receive_data(data_frame.serialize()) with pytest.raises(ValueError): c.acknowledge_received_data( - acknowledged_size=5, stream_id=stream_id + acknowledged_size=5, stream_id=stream_id, ) @given(size=integers(max_value=-1)) @settings(suppress_health_check=[HealthCheck.function_scoped_fixture]) - def test_cannot_acknowledge_less_than_zero(self, frame_factory, size): + def test_cannot_acknowledge_less_than_zero(self, frame_factory, size) -> None: """ The user must acknowledge at least 0 bytes. """ @@ -756,14 +758,14 @@ def test_cannot_acknowledge_less_than_zero(self, frame_factory, size): # data acknolwedgement. c = self._setup_connection_and_send_headers(frame_factory) data_frame = frame_factory.build_data_frame( - b'some data', flags=['END_STREAM'] + b"some data", flags=["END_STREAM"], ) c.receive_data(data_frame.serialize()) with pytest.raises(ValueError): c.acknowledge_received_data(acknowledged_size=size, stream_id=1) - def test_acknowledging_small_chunks_does_nothing(self, frame_factory): + def test_acknowledging_small_chunks_does_nothing(self, frame_factory) -> None: """ When a small amount of data is received and acknowledged, no window update is emitted. @@ -771,17 +773,17 @@ def test_acknowledging_small_chunks_does_nothing(self, frame_factory): c = self._setup_connection_and_send_headers(frame_factory) data_frame = frame_factory.build_data_frame( - b'some data', flags=['END_STREAM'] + b"some data", flags=["END_STREAM"], ) data_event = c.receive_data(data_frame.serialize())[0] c.acknowledge_received_data( - data_event.flow_controlled_length, stream_id=1 + data_event.flow_controlled_length, stream_id=1, ) assert not c.data_to_send() - def test_acknowledging_no_data_does_nothing(self, frame_factory): + def test_acknowledging_no_data_does_nothing(self, frame_factory) -> None: """ If a user accidentally acknowledges no data, nothing happens. """ @@ -789,28 +791,28 @@ def test_acknowledging_no_data_does_nothing(self, frame_factory): # Send an empty data frame, just to give the user impetus to ack the # data. - data_frame = frame_factory.build_data_frame(b'') + data_frame = frame_factory.build_data_frame(b"") c.receive_data(data_frame.serialize()) c.acknowledge_received_data(0, stream_id=1) assert not c.data_to_send() - @pytest.mark.parametrize('force_cleanup', (True, False)) + @pytest.mark.parametrize("force_cleanup", [True, False]) def test_acknowledging_data_on_closed_stream(self, frame_factory, - force_cleanup): + force_cleanup) -> None: """ When acknowledging data on a stream that has just been closed, no acknowledgement is given for that stream, only for the connection. """ c = self._setup_connection_and_send_headers(frame_factory) - data_to_send = b'\x00' * self.DEFAULT_FLOW_WINDOW + data_to_send = b"\x00" * self.DEFAULT_FLOW_WINDOW data_frame = frame_factory.build_data_frame(data_to_send) c.receive_data(data_frame.serialize()) rst_frame = frame_factory.build_rst_stream_frame( - stream_id=1 + stream_id=1, ) c.receive_data(rst_frame.serialize()) c.clear_outbound_data_buffer() @@ -823,11 +825,11 @@ def test_acknowledging_data_on_closed_stream(self, c.acknowledge_received_data(2048, stream_id=1) expected = frame_factory.build_window_update_frame( - stream_id=0, increment=2048 + stream_id=0, increment=2048, ) assert c.data_to_send() == expected.serialize() - def test_acknowledging_streams_we_never_saw(self, frame_factory): + def test_acknowledging_streams_we_never_saw(self, frame_factory) -> None: """ If the user acknowledges a stream ID we've never seen, that raises a NoSuchStreamError. @@ -842,7 +844,7 @@ def test_acknowledging_streams_we_never_saw(self, frame_factory): @settings(suppress_health_check=[HealthCheck.function_scoped_fixture]) def test_acknowledging_1024_bytes_when_empty_increments(self, frame_factory, - increment): + increment) -> None: """ If the flow control window is empty and we acknowledge 1024 bytes or more, we will emit a WINDOW_UPDATE frame just to move the connection @@ -856,20 +858,20 @@ def test_acknowledging_1024_bytes_when_empty_increments(self, c = self._setup_connection_and_send_headers(frame_factory) - data_to_send = b'\x00' * self.DEFAULT_FLOW_WINDOW + data_to_send = b"\x00" * self.DEFAULT_FLOW_WINDOW data_frame = frame_factory.build_data_frame(data_to_send) c.receive_data(data_frame.serialize()) c.acknowledge_received_data(increment, stream_id=1) first_expected = frame_factory.build_window_update_frame( - stream_id=0, increment=increment + stream_id=0, increment=increment, ) second_expected = frame_factory.build_window_update_frame( - stream_id=1, increment=increment + stream_id=1, increment=increment, ) - expected_data = b''.join( - [first_expected.serialize(), second_expected.serialize()] + expected_data = b"".join( + [first_expected.serialize(), second_expected.serialize()], ) assert c.data_to_send() == expected_data @@ -877,7 +879,7 @@ def test_acknowledging_1024_bytes_when_empty_increments(self, # increment the stream window anyway. @given(integers(min_value=1025, max_value=(DEFAULT_FLOW_WINDOW // 4) - 1)) @settings(suppress_health_check=[HealthCheck.function_scoped_fixture]) - def test_connection_only_empty(self, frame_factory, increment): + def test_connection_only_empty(self, frame_factory, increment) -> None: """ If the connection flow control window is empty, but the stream flow control windows aren't, and 1024 bytes or more are acknowledged by the @@ -894,20 +896,20 @@ def test_connection_only_empty(self, frame_factory, increment): for stream_id in [3, 5, 7]: f = frame_factory.build_headers_frame( - headers=self.example_request_headers, stream_id=stream_id + headers=self.example_request_headers, stream_id=stream_id, ) c.receive_data(f.serialize()) # Now we send 1/4 of the connection window per stream. Annoyingly, # that's an odd number, so we need to round the last frame up. - data_to_send = b'\x00' * (self.DEFAULT_FLOW_WINDOW // 4) + data_to_send = b"\x00" * (self.DEFAULT_FLOW_WINDOW // 4) for stream_id in [1, 3, 5]: f = frame_factory.build_data_frame( - data_to_send, stream_id=stream_id + data_to_send, stream_id=stream_id, ) c.receive_data(f.serialize()) - data_to_send = b'\x00' * c.remote_flow_control_window(7) + data_to_send = b"\x00" * c.remote_flow_control_window(7) data_frame = frame_factory.build_data_frame(data_to_send, stream_id=7) c.receive_data(data_frame.serialize()) @@ -915,13 +917,13 @@ def test_connection_only_empty(self, frame_factory, increment): c.acknowledge_received_data(increment, stream_id=1) expected_data = frame_factory.build_window_update_frame( - stream_id=0, increment=increment + stream_id=0, increment=increment, ).serialize() assert c.data_to_send() == expected_data @given(integers(min_value=1025, max_value=DEFAULT_FLOW_WINDOW)) @settings(suppress_health_check=[HealthCheck.function_scoped_fixture]) - def test_mixing_update_forms(self, frame_factory, increment): + def test_mixing_update_forms(self, frame_factory, increment) -> None: """ If the user mixes acknowledging data with manually incrementing windows, we still keep track of what's going on. @@ -934,14 +936,14 @@ def test_mixing_update_forms(self, frame_factory, increment): # Empty the flow control window. c = self._setup_connection_and_send_headers(frame_factory) - data_to_send = b'\x00' * self.DEFAULT_FLOW_WINDOW + data_to_send = b"\x00" * self.DEFAULT_FLOW_WINDOW data_frame = frame_factory.build_data_frame(data_to_send) c.receive_data(data_frame.serialize()) # Manually increment the connection flow control window back to fully # open, but leave the stream window closed. c.increment_flow_control_window( - stream_id=None, increment=self.DEFAULT_FLOW_WINDOW + stream_id=None, increment=self.DEFAULT_FLOW_WINDOW, ) c.clear_outbound_data_buffer() @@ -952,6 +954,6 @@ def test_mixing_update_forms(self, frame_factory, increment): # We expect to see one window update frame only, for the stream. expected_data = frame_factory.build_window_update_frame( - stream_id=1, increment=increment + stream_id=1, increment=increment, ).serialize() assert c.data_to_send() == expected_data diff --git a/tests/test_h2_upgrade.py b/tests/test_h2_upgrade.py index 7954a52c..fae46082 100644 --- a/tests/test_h2_upgrade.py +++ b/tests/test_h2_upgrade.py @@ -1,15 +1,12 @@ -# -*- coding: utf-8 -*- """ -test_h2_upgrade.py -~~~~~~~~~~~~~~~~~~ - -This module contains tests that exercise the HTTP Upgrade functionality of +Tests that exercise the HTTP Upgrade functionality of hyper-h2, ensuring that clients and servers can upgrade their plaintext HTTP/1.1 connections to HTTP/2. """ +from __future__ import annotations + import base64 -from h2.utilities import utf8_encode_headers import pytest import h2.config @@ -17,32 +14,33 @@ import h2.errors import h2.events import h2.exceptions - +from h2.utilities import utf8_encode_headers EXAMPLE_REQUEST_HEADERS = [ - (':authority', 'example.com'), - (':path', '/'), - (':scheme', 'https'), - (':method', 'GET'), + (":authority", "example.com"), + (":path", "/"), + (":scheme", "https"), + (":method", "GET"), ] EXAMPLE_REQUEST_HEADERS_BYTES = [ - (b':authority', b'example.com'), - (b':path', b'/'), - (b':scheme', b'https'), - (b':method', b'GET'), + (b":authority", b"example.com"), + (b":path", b"/"), + (b":scheme", b"https"), + (b":method", b"GET"), ] -class TestClientUpgrade(object): +class TestClientUpgrade: """ Tests of the client-side of the HTTP/2 upgrade dance. """ + example_response_headers = [ - (b':status', b'200'), - (b'server', b'fake-serv/0.1.0') + (b":status", b"200"), + (b"server", b"fake-serv/0.1.0"), ] - def test_returns_http2_settings(self, frame_factory): + def test_returns_http2_settings(self, frame_factory) -> None: """ Calling initiate_upgrade_connection returns a base64url encoded Settings frame with the settings used by the connection. @@ -51,16 +49,16 @@ def test_returns_http2_settings(self, frame_factory): data = conn.initiate_upgrade_connection() # The base64 encoding must not be padded. - assert not data.endswith(b'=') + assert not data.endswith(b"=") # However, SETTINGS frames should never need to be padded. decoded_frame = base64.urlsafe_b64decode(data) expected_frame = frame_factory.build_settings_frame( - settings=conn.local_settings + settings=conn.local_settings, ) assert decoded_frame == expected_frame.serialize_body() - def test_emits_preamble(self, frame_factory): + def test_emits_preamble(self, frame_factory) -> None: """ Calling initiate_upgrade_connection emits the connection preamble. """ @@ -72,11 +70,11 @@ def test_emits_preamble(self, frame_factory): data = data[len(frame_factory.preamble()):] expected_frame = frame_factory.build_settings_frame( - settings=conn.local_settings + settings=conn.local_settings, ) assert data == expected_frame.serialize() - def test_can_receive_response(self, frame_factory): + def test_can_receive_response(self, frame_factory) -> None: """ After upgrading, we can safely receive a response. """ @@ -90,8 +88,8 @@ def test_can_receive_response(self, frame_factory): ) f2 = frame_factory.build_data_frame( stream_id=1, - data=b'some data', - flags=['END_STREAM'] + data=b"some data", + flags=["END_STREAM"], ) events = c.receive_data(f1.serialize() + f2.serialize()) assert len(events) == 3 @@ -101,13 +99,13 @@ def test_can_receive_response(self, frame_factory): assert isinstance(events[2], h2.events.StreamEnded) assert events[0].headers == self.example_response_headers - assert events[1].data == b'some data' + assert events[1].data == b"some data" assert all(e.stream_id == 1 for e in events) assert not c.data_to_send() @pytest.mark.parametrize("headers", [EXAMPLE_REQUEST_HEADERS, EXAMPLE_REQUEST_HEADERS_BYTES]) - def test_can_receive_pushed_stream(self, frame_factory, headers): + def test_can_receive_pushed_stream(self, frame_factory, headers) -> None: """ After upgrading, we can safely receive a pushed stream. """ @@ -118,7 +116,7 @@ def test_can_receive_pushed_stream(self, frame_factory, headers): f = frame_factory.build_push_promise_frame( stream_id=1, promised_stream_id=2, - headers=headers + headers=headers, ) events = c.receive_data(f.serialize()) assert len(events) == 1 @@ -129,7 +127,7 @@ def test_can_receive_pushed_stream(self, frame_factory, headers): assert events[0].pushed_stream_id == 2 @pytest.mark.parametrize("headers", [EXAMPLE_REQUEST_HEADERS, EXAMPLE_REQUEST_HEADERS_BYTES]) - def test_cannot_send_headers_stream_1(self, frame_factory, headers): + def test_cannot_send_headers_stream_1(self, frame_factory, headers) -> None: """ After upgrading, we cannot send headers on stream 1. """ @@ -140,7 +138,7 @@ def test_cannot_send_headers_stream_1(self, frame_factory, headers): with pytest.raises(h2.exceptions.ProtocolError): c.send_headers(stream_id=1, headers=headers) - def test_cannot_send_data_stream_1(self, frame_factory): + def test_cannot_send_data_stream_1(self, frame_factory) -> None: """ After upgrading, we cannot send data on stream 1. """ @@ -149,20 +147,21 @@ def test_cannot_send_data_stream_1(self, frame_factory): c.clear_outbound_data_buffer() with pytest.raises(h2.exceptions.ProtocolError): - c.send_data(stream_id=1, data=b'some data') + c.send_data(stream_id=1, data=b"some data") -class TestServerUpgrade(object): +class TestServerUpgrade: """ Tests of the server-side of the HTTP/2 upgrade dance. """ + example_response_headers = [ - (b':status', b'200'), - (b'server', b'fake-serv/0.1.0') + (b":status", b"200"), + (b"server", b"fake-serv/0.1.0"), ] server_config = h2.config.H2Configuration(client_side=False) - def test_returns_nothing(self, frame_factory): + def test_returns_nothing(self, frame_factory) -> None: """ Calling initiate_upgrade_connection returns nothing. """ @@ -171,7 +170,7 @@ def test_returns_nothing(self, frame_factory): data = conn.initiate_upgrade_connection(curl_header) assert data is None - def test_emits_preamble(self, frame_factory): + def test_emits_preamble(self, frame_factory) -> None: """ Calling initiate_upgrade_connection emits the connection preamble. """ @@ -180,11 +179,11 @@ def test_emits_preamble(self, frame_factory): data = conn.data_to_send() expected_frame = frame_factory.build_settings_frame( - settings=conn.local_settings + settings=conn.local_settings, ) assert data == expected_frame.serialize() - def test_can_send_response(self, frame_factory): + def test_can_send_response(self, frame_factory) -> None: """ After upgrading, we can safely send a response. """ @@ -193,7 +192,7 @@ def test_can_send_response(self, frame_factory): c.clear_outbound_data_buffer() c.send_headers(stream_id=1, headers=self.example_response_headers) - c.send_data(stream_id=1, data=b'some data', end_stream=True) + c.send_data(stream_id=1, data=b"some data", end_stream=True) f1 = frame_factory.build_headers_frame( stream_id=1, @@ -201,15 +200,15 @@ def test_can_send_response(self, frame_factory): ) f2 = frame_factory.build_data_frame( stream_id=1, - data=b'some data', - flags=['END_STREAM'] + data=b"some data", + flags=["END_STREAM"], ) expected_data = f1.serialize() + f2.serialize() assert c.data_to_send() == expected_data @pytest.mark.parametrize("headers", [EXAMPLE_REQUEST_HEADERS, EXAMPLE_REQUEST_HEADERS_BYTES]) - def test_can_push_stream(self, frame_factory, headers): + def test_can_push_stream(self, frame_factory, headers) -> None: """ After upgrading, we can safely push a stream. """ @@ -220,7 +219,7 @@ def test_can_push_stream(self, frame_factory, headers): c.push_stream( stream_id=1, promised_stream_id=2, - request_headers=headers + request_headers=headers, ) f = frame_factory.build_push_promise_frame( @@ -231,7 +230,7 @@ def test_can_push_stream(self, frame_factory, headers): assert c.data_to_send() == f.serialize() @pytest.mark.parametrize("headers", [EXAMPLE_REQUEST_HEADERS, EXAMPLE_REQUEST_HEADERS_BYTES]) - def test_cannot_receive_headers_stream_1(self, frame_factory, headers): + def test_cannot_receive_headers_stream_1(self, frame_factory, headers) -> None: """ After upgrading, we cannot receive headers on stream 1. """ @@ -252,7 +251,7 @@ def test_cannot_receive_headers_stream_1(self, frame_factory, headers): ) assert c.data_to_send() == expected_frame.serialize() - def test_cannot_receive_data_stream_1(self, frame_factory): + def test_cannot_receive_data_stream_1(self, frame_factory) -> None: """ After upgrading, we cannot receive data on stream 1. """ @@ -263,7 +262,7 @@ def test_cannot_receive_data_stream_1(self, frame_factory): f = frame_factory.build_data_frame( stream_id=1, - data=b'some data', + data=b"some data", ) c.receive_data(f.serialize()) @@ -273,7 +272,7 @@ def test_cannot_receive_data_stream_1(self, frame_factory): ).serialize() assert c.data_to_send() == expected - def test_client_settings_are_applied(self, frame_factory): + def test_client_settings_are_applied(self, frame_factory) -> None: """ The settings provided by the client are applied and immediately ACK'ed. @@ -300,7 +299,7 @@ def test_client_settings_are_applied(self, frame_factory): # and has not sent a SETTINGS ack, and also that the server has the # correct settings. expected_frame = frame_factory.build_settings_frame( - server.local_settings + server.local_settings, ) assert server.data_to_send() == expected_frame.serialize() diff --git a/tests/test_head_request.py b/tests/test_head_request.py index 2f46b72f..e32d1525 100644 --- a/tests/test_head_request.py +++ b/tests/test_head_request.py @@ -1,43 +1,40 @@ -# -*- coding; utf-8 -*- -""" -test_head_request -~~~~~~~~~~~~~~~~~ -""" -import h2.connection +from __future__ import annotations + import pytest +import h2.connection EXAMPLE_REQUEST_HEADERS_BYTES = [ - (b':authority', b'example.com'), - (b':path', b'/'), - (b':scheme', b'https'), - (b':method', b'HEAD'), + (b":authority", b"example.com"), + (b":path", b"/"), + (b":scheme", b"https"), + (b":method", b"HEAD"), ] EXAMPLE_REQUEST_HEADERS = [ - (':authority', 'example.com'), - (':path', '/'), - (':scheme', 'https'), - (':method', 'HEAD'), + (":authority", "example.com"), + (":path", "/"), + (":scheme", "https"), + (":method", "HEAD"), ] -class TestHeadRequest(object): +class TestHeadRequest: example_response_headers = [ - (b':status', b'200'), - (b'server', b'fake-serv/0.1.0'), - (b'content_length', b'1'), + (b":status", b"200"), + (b"server", b"fake-serv/0.1.0"), + (b"content_length", b"1"), ] - @pytest.mark.parametrize('headers', [EXAMPLE_REQUEST_HEADERS, EXAMPLE_REQUEST_HEADERS_BYTES]) - def test_non_zero_content_and_no_body(self, frame_factory, headers): + @pytest.mark.parametrize("headers", [EXAMPLE_REQUEST_HEADERS, EXAMPLE_REQUEST_HEADERS_BYTES]) + def test_non_zero_content_and_no_body(self, frame_factory, headers) -> None: c = h2.connection.H2Connection() c.initiate_connection() c.send_headers(1, headers, end_stream=True) f = frame_factory.build_headers_frame( self.example_response_headers, - flags=['END_STREAM'] + flags=["END_STREAM"], ) events = c.receive_data(f.serialize()) @@ -48,16 +45,16 @@ def test_non_zero_content_and_no_body(self, frame_factory, headers): assert event.stream_id == 1 assert event.headers == self.example_response_headers - @pytest.mark.parametrize('headers', [EXAMPLE_REQUEST_HEADERS, EXAMPLE_REQUEST_HEADERS_BYTES]) - def test_reject_non_zero_content_and_body(self, frame_factory, headers): + @pytest.mark.parametrize("headers", [EXAMPLE_REQUEST_HEADERS, EXAMPLE_REQUEST_HEADERS_BYTES]) + def test_reject_non_zero_content_and_body(self, frame_factory, headers) -> None: c = h2.connection.H2Connection() c.initiate_connection() c.send_headers(1, headers) headers = frame_factory.build_headers_frame( - self.example_response_headers + self.example_response_headers, ) - data = frame_factory.build_data_frame(data=b'\x01') + data = frame_factory.build_data_frame(data=b"\x01") c.receive_data(headers.serialize()) with pytest.raises(h2.exceptions.InvalidBodyLengthError): diff --git a/tests/test_header_indexing.py b/tests/test_header_indexing.py index 122db330..f50f23d8 100644 --- a/tests/test_header_indexing.py +++ b/tests/test_header_indexing.py @@ -1,20 +1,17 @@ -# -*- coding: utf-8 -*- """ -test_header_indexing.py -~~~~~~~~~~~~~~~~~~~~~~~ - This module contains tests that use HPACK header tuples that provide additional metadata to the hpack module about how to encode the headers. """ -import pytest +from __future__ import annotations +import pytest from hpack import HeaderTuple, NeverIndexedHeaderTuple import h2.config import h2.connection -def assert_header_blocks_actually_equal(block_a, block_b): +def assert_header_blocks_actually_equal(block_a, block_b) -> None: """ Asserts that two header bocks are really, truly equal, down to the types of their tuples. Doesn't return anything. @@ -26,70 +23,71 @@ def assert_header_blocks_actually_equal(block_a, block_b): assert a.__class__ is b.__class__ -class TestHeaderIndexing(object): +class TestHeaderIndexing: """ Test that Hyper-h2 can correctly handle never indexed header fields using the appropriate hpack data structures. """ + example_request_headers = [ - HeaderTuple(':authority', 'example.com'), - HeaderTuple(':path', '/'), - HeaderTuple(':scheme', 'https'), - HeaderTuple(':method', 'GET'), + HeaderTuple(":authority", "example.com"), + HeaderTuple(":path", "/"), + HeaderTuple(":scheme", "https"), + HeaderTuple(":method", "GET"), ] bytes_example_request_headers = [ - HeaderTuple(b':authority', b'example.com'), - HeaderTuple(b':path', b'/'), - HeaderTuple(b':scheme', b'https'), - HeaderTuple(b':method', b'GET'), + HeaderTuple(b":authority", b"example.com"), + HeaderTuple(b":path", b"/"), + HeaderTuple(b":scheme", b"https"), + HeaderTuple(b":method", b"GET"), ] extended_request_headers = [ - HeaderTuple(':authority', 'example.com'), - HeaderTuple(':path', '/'), - HeaderTuple(':scheme', 'https'), - HeaderTuple(':method', 'GET'), - NeverIndexedHeaderTuple('authorization', 'realpassword'), + HeaderTuple(":authority", "example.com"), + HeaderTuple(":path", "/"), + HeaderTuple(":scheme", "https"), + HeaderTuple(":method", "GET"), + NeverIndexedHeaderTuple("authorization", "realpassword"), ] bytes_extended_request_headers = [ - HeaderTuple(b':authority', b'example.com'), - HeaderTuple(b':path', b'/'), - HeaderTuple(b':scheme', b'https'), - HeaderTuple(b':method', b'GET'), - NeverIndexedHeaderTuple(b'authorization', b'realpassword'), + HeaderTuple(b":authority", b"example.com"), + HeaderTuple(b":path", b"/"), + HeaderTuple(b":scheme", b"https"), + HeaderTuple(b":method", b"GET"), + NeverIndexedHeaderTuple(b"authorization", b"realpassword"), ] example_response_headers = [ - HeaderTuple(':status', '200'), - HeaderTuple('server', 'fake-serv/0.1.0') + HeaderTuple(":status", "200"), + HeaderTuple("server", "fake-serv/0.1.0"), ] bytes_example_response_headers = [ - HeaderTuple(b':status', b'200'), - HeaderTuple(b'server', b'fake-serv/0.1.0') + HeaderTuple(b":status", b"200"), + HeaderTuple(b"server", b"fake-serv/0.1.0"), ] extended_response_headers = [ - HeaderTuple(':status', '200'), - HeaderTuple('server', 'fake-serv/0.1.0'), - NeverIndexedHeaderTuple('secure', 'you-bet'), + HeaderTuple(":status", "200"), + HeaderTuple("server", "fake-serv/0.1.0"), + NeverIndexedHeaderTuple("secure", "you-bet"), ] bytes_extended_response_headers = [ - HeaderTuple(b':status', b'200'), - HeaderTuple(b'server', b'fake-serv/0.1.0'), - NeverIndexedHeaderTuple(b'secure', b'you-bet'), + HeaderTuple(b":status", b"200"), + HeaderTuple(b"server", b"fake-serv/0.1.0"), + NeverIndexedHeaderTuple(b"secure", b"you-bet"), ] server_config = h2.config.H2Configuration(client_side=False) @pytest.mark.parametrize( - 'headers', ( + "headers", [ example_request_headers, bytes_example_request_headers, extended_request_headers, bytes_extended_request_headers, - ) + ], ) - def test_sending_header_tuples(self, headers, frame_factory): + def test_sending_header_tuples(self, headers, frame_factory) -> None: """ Providing HeaderTuple and HeaderTuple subclasses preserves the metadata about indexing. @@ -105,14 +103,14 @@ def test_sending_header_tuples(self, headers, frame_factory): assert c.data_to_send() == f.serialize() @pytest.mark.parametrize( - 'headers', ( + "headers", [ example_request_headers, bytes_example_request_headers, extended_request_headers, bytes_extended_request_headers, - ) + ], ) - def test_header_tuples_in_pushes(self, headers, frame_factory): + def test_header_tuples_in_pushes(self, headers, frame_factory) -> None: """ Providing HeaderTuple and HeaderTuple subclasses to push promises preserves metadata about indexing. @@ -122,7 +120,7 @@ def test_header_tuples_in_pushes(self, headers, frame_factory): # We can use normal headers for the request. f = frame_factory.build_headers_frame( - self.example_request_headers + self.example_request_headers, ) c.receive_data(f.serialize()) @@ -131,36 +129,36 @@ def test_header_tuples_in_pushes(self, headers, frame_factory): stream_id=1, promised_stream_id=2, headers=headers, - flags=['END_HEADERS'], + flags=["END_HEADERS"], ) c.clear_outbound_data_buffer() c.push_stream( stream_id=1, promised_stream_id=2, - request_headers=headers + request_headers=headers, ) assert c.data_to_send() == expected_frame.serialize() @pytest.mark.parametrize( - 'headers,encoding', ( - (example_request_headers, 'utf-8'), + ("headers", "encoding"), [ + (example_request_headers, "utf-8"), (bytes_example_request_headers, None), - (extended_request_headers, 'utf-8'), + (extended_request_headers, "utf-8"), (bytes_extended_request_headers, None), - ) + ], ) def test_header_tuples_are_decoded_request(self, headers, encoding, - frame_factory): + frame_factory) -> None: """ The indexing status of the header is preserved when emitting RequestReceived events. """ config = h2.config.H2Configuration( - client_side=False, header_encoding=encoding + client_side=False, header_encoding=encoding, ) c = h2.connection.H2Connection(config=config) c.receive_data(frame_factory.preamble()) @@ -176,23 +174,23 @@ def test_header_tuples_are_decoded_request(self, assert_header_blocks_actually_equal(headers, event.headers) @pytest.mark.parametrize( - 'headers,encoding', ( - (example_response_headers, 'utf-8'), + ("headers", "encoding"), [ + (example_response_headers, "utf-8"), (bytes_example_response_headers, None), - (extended_response_headers, 'utf-8'), + (extended_response_headers, "utf-8"), (bytes_extended_response_headers, None), - ) + ], ) def test_header_tuples_are_decoded_response(self, headers, encoding, - frame_factory): + frame_factory) -> None: """ The indexing status of the header is preserved when emitting ResponseReceived events. """ config = h2.config.H2Configuration( - header_encoding=encoding + header_encoding=encoding, ) c = h2.connection.H2Connection(config=config) c.initiate_connection() @@ -209,17 +207,17 @@ def test_header_tuples_are_decoded_response(self, assert_header_blocks_actually_equal(headers, event.headers) @pytest.mark.parametrize( - 'headers,encoding', ( - (example_response_headers, 'utf-8'), + ("headers", "encoding"), [ + (example_response_headers, "utf-8"), (bytes_example_response_headers, None), - (extended_response_headers, 'utf-8'), + (extended_response_headers, "utf-8"), (bytes_extended_response_headers, None), - ) + ], ) def test_header_tuples_are_decoded_info_response(self, headers, encoding, - frame_factory): + frame_factory) -> None: """ The indexing status of the header is preserved when emitting InformationalResponseReceived events. @@ -228,12 +226,12 @@ def test_header_tuples_are_decoded_info_response(self, # to avoid breaking the example headers. headers = headers[:] if encoding: - headers[0] = HeaderTuple(':status', '100') + headers[0] = HeaderTuple(":status", "100") else: - headers[0] = HeaderTuple(b':status', b'100') + headers[0] = HeaderTuple(b":status", b"100") config = h2.config.H2Configuration( - header_encoding=encoding + header_encoding=encoding, ) c = h2.connection.H2Connection(config=config) c.initiate_connection() @@ -250,17 +248,17 @@ def test_header_tuples_are_decoded_info_response(self, assert_header_blocks_actually_equal(headers, event.headers) @pytest.mark.parametrize( - 'headers,encoding', ( - (example_response_headers, 'utf-8'), + ("headers", "encoding"), [ + (example_response_headers, "utf-8"), (bytes_example_response_headers, None), - (extended_response_headers, 'utf-8'), + (extended_response_headers, "utf-8"), (bytes_extended_response_headers, None), - ) + ], ) def test_header_tuples_are_decoded_trailers(self, headers, encoding, - frame_factory): + frame_factory) -> None: """ The indexing status of the header is preserved when emitting TrailersReceived events. @@ -271,7 +269,7 @@ def test_header_tuples_are_decoded_trailers(self, headers = headers[1:] config = h2.config.H2Configuration( - header_encoding=encoding + header_encoding=encoding, ) c = h2.connection.H2Connection(config=config) c.initiate_connection() @@ -280,7 +278,7 @@ def test_header_tuples_are_decoded_trailers(self, data = f.serialize() c.receive_data(data) - f = frame_factory.build_headers_frame(headers, flags=['END_STREAM']) + f = frame_factory.build_headers_frame(headers, flags=["END_STREAM"]) data = f.serialize() events = c.receive_data(data) @@ -291,23 +289,23 @@ def test_header_tuples_are_decoded_trailers(self, assert_header_blocks_actually_equal(headers, event.headers) @pytest.mark.parametrize( - 'headers,encoding', ( - (example_request_headers, 'utf-8'), + ("headers", "encoding"), [ + (example_request_headers, "utf-8"), (bytes_example_request_headers, None), - (extended_request_headers, 'utf-8'), + (extended_request_headers, "utf-8"), (bytes_extended_request_headers, None), - ) + ], ) def test_header_tuples_are_decoded_push_promise(self, headers, encoding, - frame_factory): + frame_factory) -> None: """ The indexing status of the header is preserved when emitting PushedStreamReceived events. """ config = h2.config.H2Configuration( - header_encoding=encoding + header_encoding=encoding, ) c = h2.connection.H2Connection(config=config) c.initiate_connection() @@ -317,7 +315,7 @@ def test_header_tuples_are_decoded_push_promise(self, stream_id=1, promised_stream_id=2, headers=headers, - flags=['END_HEADERS'], + flags=["END_HEADERS"], ) data = f.serialize() events = c.receive_data(data) @@ -329,116 +327,115 @@ def test_header_tuples_are_decoded_push_promise(self, assert_header_blocks_actually_equal(headers, event.headers) -class TestSecureHeaders(object): +class TestSecureHeaders: """ Certain headers should always be transformed to their never-indexed form. """ + example_request_headers = [ - (':authority', 'example.com'), - (':path', '/'), - (':scheme', 'https'), - (':method', 'GET'), + (":authority", "example.com"), + (":path", "/"), + (":scheme", "https"), + (":method", "GET"), ] bytes_example_request_headers = [ - (b':authority', b'example.com'), - (b':path', b'/'), - (b':scheme', b'https'), - (b':method', b'GET'), + (b":authority", b"example.com"), + (b":path", b"/"), + (b":scheme", b"https"), + (b":method", b"GET"), ] possible_auth_headers = [ - ('authorization', 'test'), - ('Authorization', 'test'), - ('authorization', 'really long test'), - HeaderTuple('authorization', 'test'), - HeaderTuple('Authorization', 'test'), - HeaderTuple('authorization', 'really long test'), - NeverIndexedHeaderTuple('authorization', 'test'), - NeverIndexedHeaderTuple('Authorization', 'test'), - NeverIndexedHeaderTuple('authorization', 'really long test'), - (b'authorization', b'test'), - (b'Authorization', b'test'), - (b'authorization', b'really long test'), - HeaderTuple(b'authorization', b'test'), - HeaderTuple(b'Authorization', b'test'), - HeaderTuple(b'authorization', b'really long test'), - NeverIndexedHeaderTuple(b'authorization', b'test'), - NeverIndexedHeaderTuple(b'Authorization', b'test'), - NeverIndexedHeaderTuple(b'authorization', b'really long test'), - ('proxy-authorization', 'test'), - ('Proxy-Authorization', 'test'), - ('proxy-authorization', 'really long test'), - HeaderTuple('proxy-authorization', 'test'), - HeaderTuple('Proxy-Authorization', 'test'), - HeaderTuple('proxy-authorization', 'really long test'), - NeverIndexedHeaderTuple('proxy-authorization', 'test'), - NeverIndexedHeaderTuple('Proxy-Authorization', 'test'), - NeverIndexedHeaderTuple('proxy-authorization', 'really long test'), - (b'proxy-authorization', b'test'), - (b'Proxy-Authorization', b'test'), - (b'proxy-authorization', b'really long test'), - HeaderTuple(b'proxy-authorization', b'test'), - HeaderTuple(b'Proxy-Authorization', b'test'), - HeaderTuple(b'proxy-authorization', b'really long test'), - NeverIndexedHeaderTuple(b'proxy-authorization', b'test'), - NeverIndexedHeaderTuple(b'Proxy-Authorization', b'test'), - NeverIndexedHeaderTuple(b'proxy-authorization', b'really long test'), + ("authorization", "test"), + ("Authorization", "test"), + ("authorization", "really long test"), + HeaderTuple("authorization", "test"), + HeaderTuple("Authorization", "test"), + HeaderTuple("authorization", "really long test"), + NeverIndexedHeaderTuple("authorization", "test"), + NeverIndexedHeaderTuple("Authorization", "test"), + NeverIndexedHeaderTuple("authorization", "really long test"), + (b"authorization", b"test"), + (b"Authorization", b"test"), + (b"authorization", b"really long test"), + HeaderTuple(b"authorization", b"test"), + HeaderTuple(b"Authorization", b"test"), + HeaderTuple(b"authorization", b"really long test"), + NeverIndexedHeaderTuple(b"authorization", b"test"), + NeverIndexedHeaderTuple(b"Authorization", b"test"), + NeverIndexedHeaderTuple(b"authorization", b"really long test"), + ("proxy-authorization", "test"), + ("Proxy-Authorization", "test"), + ("proxy-authorization", "really long test"), + HeaderTuple("proxy-authorization", "test"), + HeaderTuple("Proxy-Authorization", "test"), + HeaderTuple("proxy-authorization", "really long test"), + NeverIndexedHeaderTuple("proxy-authorization", "test"), + NeverIndexedHeaderTuple("Proxy-Authorization", "test"), + NeverIndexedHeaderTuple("proxy-authorization", "really long test"), + (b"proxy-authorization", b"test"), + (b"Proxy-Authorization", b"test"), + (b"proxy-authorization", b"really long test"), + HeaderTuple(b"proxy-authorization", b"test"), + HeaderTuple(b"Proxy-Authorization", b"test"), + HeaderTuple(b"proxy-authorization", b"really long test"), + NeverIndexedHeaderTuple(b"proxy-authorization", b"test"), + NeverIndexedHeaderTuple(b"Proxy-Authorization", b"test"), + NeverIndexedHeaderTuple(b"proxy-authorization", b"really long test"), ] secured_cookie_headers = [ - ('cookie', 'short'), - ('Cookie', 'short'), - ('cookie', 'nineteen byte cooki'), - HeaderTuple('cookie', 'short'), - HeaderTuple('Cookie', 'short'), - HeaderTuple('cookie', 'nineteen byte cooki'), - NeverIndexedHeaderTuple('cookie', 'short'), - NeverIndexedHeaderTuple('Cookie', 'short'), - NeverIndexedHeaderTuple('cookie', 'nineteen byte cooki'), - NeverIndexedHeaderTuple('cookie', 'longer manually secured cookie'), - (b'cookie', b'short'), - (b'Cookie', b'short'), - (b'cookie', b'nineteen byte cooki'), - HeaderTuple(b'cookie', b'short'), - HeaderTuple(b'Cookie', b'short'), - HeaderTuple(b'cookie', b'nineteen byte cooki'), - NeverIndexedHeaderTuple(b'cookie', b'short'), - NeverIndexedHeaderTuple(b'Cookie', b'short'), - NeverIndexedHeaderTuple(b'cookie', b'nineteen byte cooki'), - NeverIndexedHeaderTuple(b'cookie', b'longer manually secured cookie'), + ("cookie", "short"), + ("Cookie", "short"), + ("cookie", "nineteen byte cooki"), + HeaderTuple("cookie", "short"), + HeaderTuple("Cookie", "short"), + HeaderTuple("cookie", "nineteen byte cooki"), + NeverIndexedHeaderTuple("cookie", "short"), + NeverIndexedHeaderTuple("Cookie", "short"), + NeverIndexedHeaderTuple("cookie", "nineteen byte cooki"), + NeverIndexedHeaderTuple("cookie", "longer manually secured cookie"), + (b"cookie", b"short"), + (b"Cookie", b"short"), + (b"cookie", b"nineteen byte cooki"), + HeaderTuple(b"cookie", b"short"), + HeaderTuple(b"Cookie", b"short"), + HeaderTuple(b"cookie", b"nineteen byte cooki"), + NeverIndexedHeaderTuple(b"cookie", b"short"), + NeverIndexedHeaderTuple(b"Cookie", b"short"), + NeverIndexedHeaderTuple(b"cookie", b"nineteen byte cooki"), + NeverIndexedHeaderTuple(b"cookie", b"longer manually secured cookie"), ] unsecured_cookie_headers = [ - ('cookie', 'twenty byte cookie!!'), - ('Cookie', 'twenty byte cookie!!'), - ('cookie', 'substantially longer than 20 byte cookie'), - HeaderTuple('cookie', 'twenty byte cookie!!'), - HeaderTuple('cookie', 'twenty byte cookie!!'), - HeaderTuple('Cookie', 'twenty byte cookie!!'), - (b'cookie', b'twenty byte cookie!!'), - (b'Cookie', b'twenty byte cookie!!'), - (b'cookie', b'substantially longer than 20 byte cookie'), - HeaderTuple(b'cookie', b'twenty byte cookie!!'), - HeaderTuple(b'cookie', b'twenty byte cookie!!'), - HeaderTuple(b'Cookie', b'twenty byte cookie!!'), + ("cookie", "twenty byte cookie!!"), + ("Cookie", "twenty byte cookie!!"), + ("cookie", "substantially longer than 20 byte cookie"), + HeaderTuple("cookie", "twenty byte cookie!!"), + HeaderTuple("cookie", "twenty byte cookie!!"), + HeaderTuple("Cookie", "twenty byte cookie!!"), + (b"cookie", b"twenty byte cookie!!"), + (b"Cookie", b"twenty byte cookie!!"), + (b"cookie", b"substantially longer than 20 byte cookie"), + HeaderTuple(b"cookie", b"twenty byte cookie!!"), + HeaderTuple(b"cookie", b"twenty byte cookie!!"), + HeaderTuple(b"Cookie", b"twenty byte cookie!!"), ] server_config = h2.config.H2Configuration(client_side=False) @pytest.mark.parametrize( - 'headers', (example_request_headers, bytes_example_request_headers) + "headers", [example_request_headers, bytes_example_request_headers], ) - @pytest.mark.parametrize('auth_header', possible_auth_headers) + @pytest.mark.parametrize("auth_header", possible_auth_headers) def test_authorization_headers_never_indexed(self, headers, auth_header, - frame_factory): + frame_factory) -> None: """ Authorization and Proxy-Authorization headers are always forced to be never-indexed, regardless of their form. """ # Regardless of what we send, we expect it to be never indexed. - send_headers = headers + [auth_header] - expected_headers = headers + [ - NeverIndexedHeaderTuple(auth_header[0].lower(), auth_header[1]) - ] + send_headers = [*headers, auth_header] + expected_headers = [*headers, NeverIndexedHeaderTuple(auth_header[0].lower(), auth_header[1])] c = h2.connection.H2Connection() c.initiate_connection() @@ -451,29 +448,27 @@ def test_authorization_headers_never_indexed(self, assert c.data_to_send() == f.serialize() @pytest.mark.parametrize( - 'headers', (example_request_headers, bytes_example_request_headers) + "headers", [example_request_headers, bytes_example_request_headers], ) - @pytest.mark.parametrize('auth_header', possible_auth_headers) + @pytest.mark.parametrize("auth_header", possible_auth_headers) def test_authorization_headers_never_indexed_push(self, headers, auth_header, - frame_factory): + frame_factory) -> None: """ Authorization and Proxy-Authorization headers are always forced to be never-indexed, regardless of their form, when pushed by a server. """ # Regardless of what we send, we expect it to be never indexed. - send_headers = headers + [auth_header] - expected_headers = headers + [ - NeverIndexedHeaderTuple(auth_header[0].lower(), auth_header[1]) - ] + send_headers = [*headers, auth_header] + expected_headers = [*headers, NeverIndexedHeaderTuple(auth_header[0].lower(), auth_header[1])] c = h2.connection.H2Connection(config=self.server_config) c.receive_data(frame_factory.preamble()) # We can use normal headers for the request. f = frame_factory.build_headers_frame( - self.example_request_headers + self.example_request_headers, ) c.receive_data(f.serialize()) @@ -482,35 +477,33 @@ def test_authorization_headers_never_indexed_push(self, stream_id=1, promised_stream_id=2, headers=expected_headers, - flags=['END_HEADERS'], + flags=["END_HEADERS"], ) c.clear_outbound_data_buffer() c.push_stream( stream_id=1, promised_stream_id=2, - request_headers=send_headers + request_headers=send_headers, ) assert c.data_to_send() == expected_frame.serialize() @pytest.mark.parametrize( - 'headers', (example_request_headers, bytes_example_request_headers) + "headers", [example_request_headers, bytes_example_request_headers], ) - @pytest.mark.parametrize('cookie_header', secured_cookie_headers) + @pytest.mark.parametrize("cookie_header", secured_cookie_headers) def test_short_cookie_headers_never_indexed(self, headers, cookie_header, - frame_factory): + frame_factory) -> None: """ Short cookie headers, and cookies provided as NeverIndexedHeaderTuple, are never indexed. """ # Regardless of what we send, we expect it to be never indexed. - send_headers = headers + [cookie_header] - expected_headers = headers + [ - NeverIndexedHeaderTuple(cookie_header[0].lower(), cookie_header[1]) - ] + send_headers = [*headers, cookie_header] + expected_headers = [*headers, NeverIndexedHeaderTuple(cookie_header[0].lower(), cookie_header[1])] c = h2.connection.H2Connection() c.initiate_connection() @@ -523,29 +516,27 @@ def test_short_cookie_headers_never_indexed(self, assert c.data_to_send() == f.serialize() @pytest.mark.parametrize( - 'headers', (example_request_headers, bytes_example_request_headers) + "headers", [example_request_headers, bytes_example_request_headers], ) - @pytest.mark.parametrize('cookie_header', secured_cookie_headers) + @pytest.mark.parametrize("cookie_header", secured_cookie_headers) def test_short_cookie_headers_never_indexed_push(self, headers, cookie_header, - frame_factory): + frame_factory) -> None: """ Short cookie headers, and cookies provided as NeverIndexedHeaderTuple, are never indexed when pushed by servers. """ # Regardless of what we send, we expect it to be never indexed. - send_headers = headers + [cookie_header] - expected_headers = headers + [ - NeverIndexedHeaderTuple(cookie_header[0].lower(), cookie_header[1]) - ] + send_headers = [*headers, cookie_header] + expected_headers = [*headers, NeverIndexedHeaderTuple(cookie_header[0].lower(), cookie_header[1])] c = h2.connection.H2Connection(config=self.server_config) c.receive_data(frame_factory.preamble()) # We can use normal headers for the request. f = frame_factory.build_headers_frame( - self.example_request_headers + self.example_request_headers, ) c.receive_data(f.serialize()) @@ -554,34 +545,32 @@ def test_short_cookie_headers_never_indexed_push(self, stream_id=1, promised_stream_id=2, headers=expected_headers, - flags=['END_HEADERS'], + flags=["END_HEADERS"], ) c.clear_outbound_data_buffer() c.push_stream( stream_id=1, promised_stream_id=2, - request_headers=send_headers + request_headers=send_headers, ) assert c.data_to_send() == expected_frame.serialize() @pytest.mark.parametrize( - 'headers', (example_request_headers, bytes_example_request_headers) + "headers", [example_request_headers, bytes_example_request_headers], ) - @pytest.mark.parametrize('cookie_header', unsecured_cookie_headers) + @pytest.mark.parametrize("cookie_header", unsecured_cookie_headers) def test_long_cookie_headers_can_be_indexed(self, headers, cookie_header, - frame_factory): + frame_factory) -> None: """ Longer cookie headers can be indexed. """ # Regardless of what we send, we expect it to be indexed. - send_headers = headers + [cookie_header] - expected_headers = headers + [ - HeaderTuple(cookie_header[0].lower(), cookie_header[1]) - ] + send_headers = [*headers, cookie_header] + expected_headers = [*headers, HeaderTuple(cookie_header[0].lower(), cookie_header[1])] c = h2.connection.H2Connection() c.initiate_connection() @@ -594,28 +583,26 @@ def test_long_cookie_headers_can_be_indexed(self, assert c.data_to_send() == f.serialize() @pytest.mark.parametrize( - 'headers', (example_request_headers, bytes_example_request_headers) + "headers", [example_request_headers, bytes_example_request_headers], ) - @pytest.mark.parametrize('cookie_header', unsecured_cookie_headers) + @pytest.mark.parametrize("cookie_header", unsecured_cookie_headers) def test_long_cookie_headers_can_be_indexed_push(self, headers, cookie_header, - frame_factory): + frame_factory) -> None: """ Longer cookie headers can be indexed. """ # Regardless of what we send, we expect it to be never indexed. - send_headers = headers + [cookie_header] - expected_headers = headers + [ - HeaderTuple(cookie_header[0].lower(), cookie_header[1]) - ] + send_headers = [*headers, cookie_header] + expected_headers = [*headers, HeaderTuple(cookie_header[0].lower(), cookie_header[1])] c = h2.connection.H2Connection(config=self.server_config) c.receive_data(frame_factory.preamble()) # We can use normal headers for the request. f = frame_factory.build_headers_frame( - self.example_request_headers + self.example_request_headers, ) c.receive_data(f.serialize()) @@ -624,14 +611,14 @@ def test_long_cookie_headers_can_be_indexed_push(self, stream_id=1, promised_stream_id=2, headers=expected_headers, - flags=['END_HEADERS'], + flags=["END_HEADERS"], ) c.clear_outbound_data_buffer() c.push_stream( stream_id=1, promised_stream_id=2, - request_headers=send_headers + request_headers=send_headers, ) assert c.data_to_send() == expected_frame.serialize() diff --git a/tests/test_informational_responses.py b/tests/test_informational_responses.py index 64d3a6e7..bae9c74f 100644 --- a/tests/test_informational_responses.py +++ b/tests/test_informational_responses.py @@ -1,11 +1,9 @@ -# -*- coding: utf-8 -*- """ -test_informational_responses -~~~~~~~~~~~~~~~~~~~~~~~~~~ - Tests that validate that hyper-h2 correctly handles informational (1XX) responses in its state machine. """ +from __future__ import annotations + import pytest import h2.config @@ -14,31 +12,32 @@ import h2.exceptions -class TestReceivingInformationalResponses(object): +class TestReceivingInformationalResponses: """ Tests for receiving informational responses. """ + example_request_headers = [ - (b':authority', b'example.com'), - (b':path', b'/'), - (b':scheme', b'https'), - (b':method', b'GET'), - (b'expect', b'100-continue'), + (b":authority", b"example.com"), + (b":path", b"/"), + (b":scheme", b"https"), + (b":method", b"GET"), + (b"expect", b"100-continue"), ] example_informational_headers = [ - (b':status', b'100'), - (b'server', b'fake-serv/0.1.0') + (b":status", b"100"), + (b"server", b"fake-serv/0.1.0"), ] example_response_headers = [ - (b':status', b'200'), - (b'server', b'fake-serv/0.1.0') + (b":status", b"200"), + (b"server", b"fake-serv/0.1.0"), ] example_trailers = [ - (b'trailer', b'you-bet'), + (b"trailer", b"you-bet"), ] - @pytest.mark.parametrize('end_stream', (True, False)) - def test_single_informational_response(self, frame_factory, end_stream): + @pytest.mark.parametrize("end_stream", [True, False]) + def test_single_informational_response(self, frame_factory, end_stream) -> None: """ When receiving a informational response, the appropriate event is signaled. @@ -48,7 +47,7 @@ def test_single_informational_response(self, frame_factory, end_stream): c.send_headers( stream_id=1, headers=self.example_request_headers, - end_stream=end_stream + end_stream=end_stream, ) f = frame_factory.build_headers_frame( @@ -64,8 +63,8 @@ def test_single_informational_response(self, frame_factory, end_stream): assert event.headers == self.example_informational_headers assert event.stream_id == 1 - @pytest.mark.parametrize('end_stream', (True, False)) - def test_receiving_multiple_header_blocks(self, frame_factory, end_stream): + @pytest.mark.parametrize("end_stream", [True, False]) + def test_receiving_multiple_header_blocks(self, frame_factory, end_stream) -> None: """ At least three header blocks can be received: informational, headers, trailers. @@ -75,7 +74,7 @@ def test_receiving_multiple_header_blocks(self, frame_factory, end_stream): c.send_headers( stream_id=1, headers=self.example_request_headers, - end_stream=end_stream + end_stream=end_stream, ) f1 = frame_factory.build_headers_frame( @@ -89,10 +88,10 @@ def test_receiving_multiple_header_blocks(self, frame_factory, end_stream): f3 = frame_factory.build_headers_frame( headers=self.example_trailers, stream_id=1, - flags=['END_STREAM'], + flags=["END_STREAM"], ) events = c.receive_data( - f1.serialize() + f2.serialize() + f3.serialize() + f1.serialize() + f2.serialize() + f3.serialize(), ) assert len(events) == 4 @@ -109,10 +108,10 @@ def test_receiving_multiple_header_blocks(self, frame_factory, end_stream): assert events[2].headers == self.example_trailers assert events[2].stream_id == 1 - @pytest.mark.parametrize('end_stream', (True, False)) + @pytest.mark.parametrize("end_stream", [True, False]) def test_receiving_multiple_informational_responses(self, frame_factory, - end_stream): + end_stream) -> None: """ More than one informational response is allowed. """ @@ -121,7 +120,7 @@ def test_receiving_multiple_informational_responses(self, c.send_headers( stream_id=1, headers=self.example_request_headers, - end_stream=end_stream + end_stream=end_stream, ) f1 = frame_factory.build_headers_frame( @@ -129,7 +128,7 @@ def test_receiving_multiple_informational_responses(self, stream_id=1, ) f2 = frame_factory.build_headers_frame( - headers=[(':status', '101')], + headers=[(":status", "101")], stream_id=1, ) events = c.receive_data(f1.serialize() + f2.serialize()) @@ -141,13 +140,13 @@ def test_receiving_multiple_informational_responses(self, assert events[0].stream_id == 1 assert isinstance(events[1], h2.events.InformationalResponseReceived) - assert events[1].headers == [(b':status', b'101')] + assert events[1].headers == [(b":status", b"101")] assert events[1].stream_id == 1 - @pytest.mark.parametrize('end_stream', (True, False)) + @pytest.mark.parametrize("end_stream", [True, False]) def test_receive_provisional_response_with_end_stream(self, frame_factory, - end_stream): + end_stream) -> None: """ Receiving provisional responses with END_STREAM set causes ProtocolErrors. @@ -157,14 +156,14 @@ def test_receive_provisional_response_with_end_stream(self, c.send_headers( stream_id=1, headers=self.example_request_headers, - end_stream=end_stream + end_stream=end_stream, ) c.clear_outbound_data_buffer() f = frame_factory.build_headers_frame( headers=self.example_informational_headers, stream_id=1, - flags=['END_STREAM'] + flags=["END_STREAM"], ) with pytest.raises(h2.exceptions.ProtocolError): @@ -176,8 +175,8 @@ def test_receive_provisional_response_with_end_stream(self, ) assert c.data_to_send() == expected.serialize() - @pytest.mark.parametrize('end_stream', (True, False)) - def test_receiving_out_of_order_headers(self, frame_factory, end_stream): + @pytest.mark.parametrize("end_stream", [True, False]) + def test_receiving_out_of_order_headers(self, frame_factory, end_stream) -> None: """ When receiving a informational response after the actual response headers we consider it a ProtocolError and raise it. @@ -187,7 +186,7 @@ def test_receiving_out_of_order_headers(self, frame_factory, end_stream): c.send_headers( stream_id=1, headers=self.example_request_headers, - end_stream=end_stream + end_stream=end_stream, ) f1 = frame_factory.build_headers_frame( @@ -211,53 +210,54 @@ def test_receiving_out_of_order_headers(self, frame_factory, end_stream): assert c.data_to_send() == expected.serialize() -class TestSendingInformationalResponses(object): +class TestSendingInformationalResponses: """ Tests for sending informational responses. """ + example_request_headers = [ - (':authority', 'example.com'), - (':path', '/'), - (':scheme', 'https'), - (':method', 'GET'), - ('expect', '100-continue'), + (":authority", "example.com"), + (":path", "/"), + (":scheme", "https"), + (":method", "GET"), + ("expect", "100-continue"), ] bytes_example_request_headers = [ - (b':authority', b'example.com'), - (b':path', b'/'), - (b':scheme', b'https'), - (b':method', b'GET'), - (b'expect', b'100-continue'), + (b":authority", b"example.com"), + (b":path", b"/"), + (b":scheme", b"https"), + (b":method", b"GET"), + (b"expect", b"100-continue"), ] informational_headers = [ - (':status', '100'), - ('server', 'fake-serv/0.1.0') + (":status", "100"), + ("server", "fake-serv/0.1.0"), ] bytes_informational_headers = [ - (b':status', b'100'), - (b'server', b'fake-serv/0.1.0') + (b":status", b"100"), + (b"server", b"fake-serv/0.1.0"), ] example_response_headers = [ - (b':status', b'200'), - (b'server', b'fake-serv/0.1.0') + (b":status", b"200"), + (b"server", b"fake-serv/0.1.0"), ] example_trailers = [ - (b'trailer', b'you-bet'), + (b"trailer", b"you-bet"), ] server_config = h2.config.H2Configuration(client_side=False) @pytest.mark.parametrize( - 'hdrs', (informational_headers, bytes_informational_headers), + "hdrs", [informational_headers, bytes_informational_headers], ) @pytest.mark.parametrize( - 'request_headers', (example_request_headers, bytes_example_request_headers), + "request_headers", [example_request_headers, bytes_example_request_headers], ) - @pytest.mark.parametrize('end_stream', (True, False)) + @pytest.mark.parametrize("end_stream", [True, False]) def test_single_informational_response(self, frame_factory, hdrs, request_headers, - end_stream): + end_stream) -> None: """ When sending a informational response, the appropriate frames are emitted. @@ -265,7 +265,7 @@ def test_single_informational_response(self, c = h2.connection.H2Connection(config=self.server_config) c.initiate_connection() c.receive_data(frame_factory.preamble()) - flags = ['END_STREAM'] if end_stream else [] + flags = ["END_STREAM"] if end_stream else [] f = frame_factory.build_headers_frame( headers=request_headers, stream_id=1, @@ -277,7 +277,7 @@ def test_single_informational_response(self, c.send_headers( stream_id=1, - headers=hdrs + headers=hdrs, ) f = frame_factory.build_headers_frame( @@ -287,17 +287,17 @@ def test_single_informational_response(self, assert c.data_to_send() == f.serialize() @pytest.mark.parametrize( - 'hdrs', (informational_headers, bytes_informational_headers), + "hdrs", [informational_headers, bytes_informational_headers], ) @pytest.mark.parametrize( - 'request_headers', (example_request_headers, bytes_example_request_headers), + "request_headers", [example_request_headers, bytes_example_request_headers], ) - @pytest.mark.parametrize('end_stream', (True, False)) + @pytest.mark.parametrize("end_stream", [True, False]) def test_sending_multiple_header_blocks(self, frame_factory, hdrs, request_headers, - end_stream): + end_stream) -> None: """ At least three header blocks can be sent: informational, headers, trailers. @@ -305,7 +305,7 @@ def test_sending_multiple_header_blocks(self, c = h2.connection.H2Connection(config=self.server_config) c.initiate_connection() c.receive_data(frame_factory.preamble()) - flags = ['END_STREAM'] if end_stream else [] + flags = ["END_STREAM"] if end_stream else [] f = frame_factory.build_headers_frame( headers=request_headers, stream_id=1, @@ -318,16 +318,16 @@ def test_sending_multiple_header_blocks(self, # Send the three header blocks. c.send_headers( stream_id=1, - headers=hdrs + headers=hdrs, ) c.send_headers( stream_id=1, - headers=self.example_response_headers + headers=self.example_response_headers, ) c.send_headers( stream_id=1, headers=self.example_trailers, - end_stream=True + end_stream=True, ) # Check that we sent them properly. @@ -342,7 +342,7 @@ def test_sending_multiple_header_blocks(self, f3 = frame_factory.build_headers_frame( headers=self.example_trailers, stream_id=1, - flags=['END_STREAM'] + flags=["END_STREAM"], ) assert ( c.data_to_send() == @@ -350,24 +350,24 @@ def test_sending_multiple_header_blocks(self, ) @pytest.mark.parametrize( - 'hdrs', (informational_headers, bytes_informational_headers), + "hdrs", [informational_headers, bytes_informational_headers], ) @pytest.mark.parametrize( - 'request_headers', (example_request_headers, bytes_example_request_headers), + "request_headers", [example_request_headers, bytes_example_request_headers], ) - @pytest.mark.parametrize('end_stream', (True, False)) + @pytest.mark.parametrize("end_stream", [True, False]) def test_sending_multiple_informational_responses(self, frame_factory, hdrs, request_headers, - end_stream): + end_stream) -> None: """ More than one informational response is allowed. """ c = h2.connection.H2Connection(config=self.server_config) c.initiate_connection() c.receive_data(frame_factory.preamble()) - flags = ['END_STREAM'] if end_stream else [] + flags = ["END_STREAM"] if end_stream else [] f = frame_factory.build_headers_frame( headers=request_headers, stream_id=1, @@ -384,7 +384,7 @@ def test_sending_multiple_informational_responses(self, ) c.send_headers( stream_id=1, - headers=[(b':status', b'101')] + headers=[(b":status", b"101")], ) # Check we sent them both. @@ -393,23 +393,23 @@ def test_sending_multiple_informational_responses(self, stream_id=1, ) f2 = frame_factory.build_headers_frame( - headers=[(':status', '101')], + headers=[(":status", "101")], stream_id=1, ) assert c.data_to_send() == f1.serialize() + f2.serialize() @pytest.mark.parametrize( - 'hdrs', (informational_headers, bytes_informational_headers), + "hdrs", [informational_headers, bytes_informational_headers], ) @pytest.mark.parametrize( - 'request_headers', (example_request_headers, bytes_example_request_headers), + "request_headers", [example_request_headers, bytes_example_request_headers], ) - @pytest.mark.parametrize('end_stream', (True, False)) + @pytest.mark.parametrize("end_stream", [True, False]) def test_send_provisional_response_with_end_stream(self, frame_factory, hdrs, request_headers, - end_stream): + end_stream) -> None: """ Sending provisional responses with END_STREAM set causes ProtocolErrors. @@ -417,7 +417,7 @@ def test_send_provisional_response_with_end_stream(self, c = h2.connection.H2Connection(config=self.server_config) c.initiate_connection() c.receive_data(frame_factory.preamble()) - flags = ['END_STREAM'] if end_stream else [] + flags = ["END_STREAM"] if end_stream else [] f = frame_factory.build_headers_frame( headers=request_headers, stream_id=1, @@ -433,17 +433,17 @@ def test_send_provisional_response_with_end_stream(self, ) @pytest.mark.parametrize( - 'hdrs', (informational_headers, bytes_informational_headers), + "hdrs", [informational_headers, bytes_informational_headers], ) @pytest.mark.parametrize( - 'request_headers', (example_request_headers, bytes_example_request_headers), + "request_headers", [example_request_headers, bytes_example_request_headers], ) - @pytest.mark.parametrize('end_stream', (True, False)) + @pytest.mark.parametrize("end_stream", [True, False]) def test_reject_sending_out_of_order_headers(self, frame_factory, hdrs, request_headers, - end_stream): + end_stream) -> None: """ When sending an informational response after the actual response headers we consider it a ProtocolError and raise it. @@ -451,7 +451,7 @@ def test_reject_sending_out_of_order_headers(self, c = h2.connection.H2Connection(config=self.server_config) c.initiate_connection() c.receive_data(frame_factory.preamble()) - flags = ['END_STREAM'] if end_stream else [] + flags = ["END_STREAM"] if end_stream else [] f = frame_factory.build_headers_frame( headers=request_headers, stream_id=1, @@ -461,11 +461,11 @@ def test_reject_sending_out_of_order_headers(self, c.send_headers( stream_id=1, - headers=self.example_response_headers + headers=self.example_response_headers, ) with pytest.raises(h2.exceptions.ProtocolError): c.send_headers( stream_id=1, - headers=hdrs + headers=hdrs, ) diff --git a/tests/test_interacting_stacks.py b/tests/test_interacting_stacks.py index 70f2235b..c657986d 100644 --- a/tests/test_interacting_stacks.py +++ b/tests/test_interacting_stacks.py @@ -1,8 +1,4 @@ -# -*- coding: utf-8 -*- """ -test_interacting_stacks -~~~~~~~~~~~~~~~~~~~~~~~ - These tests run two entities, a client and a server, in parallel threads. These two entities talk to each other, running what amounts to a number of carefully controlled simulations of real flows. @@ -18,7 +14,7 @@ these tests, so that they can be written more easily, as they are remarkably useful. """ -from . import coroutine_tests +from __future__ import annotations import pytest @@ -27,37 +23,40 @@ import h2.events import h2.settings +from . import coroutine_tests + class TestCommunication(coroutine_tests.CoroutineTestCase): """ Test that two communicating state machines can work together. """ + server_config = h2.config.H2Configuration(client_side=False) request_headers = [ - (':method', 'GET'), - (':path', '/'), - (':authority', 'example.com'), - (':scheme', 'https'), - ('user-agent', 'test-client/0.1.0'), + (":method", "GET"), + (":path", "/"), + (":authority", "example.com"), + (":scheme", "https"), + ("user-agent", "test-client/0.1.0"), ] request_headers_bytes = [ - (b':method', b'GET'), - (b':path', b'/'), - (b':authority', b'example.com'), - (b':scheme', b'https'), - (b'user-agent', b'test-client/0.1.0'), + (b":method", b"GET"), + (b":path", b"/"), + (b":authority", b"example.com"), + (b":scheme", b"https"), + (b"user-agent", b"test-client/0.1.0"), ] response_headers = [ - (b':status', b'204'), - (b'server', b'test-server/0.1.0'), - (b'content-length', b'0'), + (b":status", b"204"), + (b"server", b"test-server/0.1.0"), + (b"content-length", b"0"), ] - @pytest.mark.parametrize('request_headers', [request_headers, request_headers_bytes]) - def test_basic_request_response(self, request_headers): + @pytest.mark.parametrize("request_headers", [request_headers, request_headers_bytes]) + def test_basic_request_response(self, request_headers) -> None: """ A request issued by hyper-h2 can be responded to by hyper-h2. """ diff --git a/tests/test_invalid_content_lengths.py b/tests/test_invalid_content_lengths.py index b33e9c6a..39401ea2 100644 --- a/tests/test_invalid_content_lengths.py +++ b/tests/test_invalid_content_lengths.py @@ -1,11 +1,9 @@ -# -*- coding: utf-8 -*- """ -test_invalid_content_lengths.py -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - This module contains tests that use invalid content lengths, and validates that they fail appropriately. """ +from __future__ import annotations + import pytest import h2.config @@ -15,33 +13,34 @@ import h2.exceptions -class TestInvalidContentLengths(object): +class TestInvalidContentLengths: """ Hyper-h2 raises Protocol Errors when the content-length sent by a remote peer is not valid. """ + example_request_headers = [ - (':authority', 'example.com'), - (':path', '/'), - (':scheme', 'https'), - (':method', 'POST'), - ('content-length', '15'), + (":authority", "example.com"), + (":path", "/"), + (":scheme", "https"), + (":method", "POST"), + ("content-length", "15"), ] example_request_headers_bytes = [ - (b':authority', b'example.com'), - (b':path', b'/'), - (b':scheme', b'https'), - (b':method', b'POST'), - (b'content-length', b'15'), + (b":authority", b"example.com"), + (b":path", b"/"), + (b":scheme", b"https"), + (b":method", b"POST"), + (b"content-length", b"15"), ] example_response_headers = [ - (':status', '200'), - ('server', 'fake-serv/0.1.0') + (":status", "200"), + ("server", "fake-serv/0.1.0"), ] server_config = h2.config.H2Configuration(client_side=False) @pytest.mark.parametrize("request_headers", [example_request_headers, example_request_headers_bytes]) - def test_too_much_data(self, frame_factory, request_headers): + def test_too_much_data(self, frame_factory, request_headers) -> None: """ Remote peers sending data in excess of content-length causes Protocol Errors. @@ -51,13 +50,13 @@ def test_too_much_data(self, frame_factory, request_headers): c.receive_data(frame_factory.preamble()) headers = frame_factory.build_headers_frame( - headers=request_headers + headers=request_headers, ) - first_data = frame_factory.build_data_frame(data=b'\x01'*15) + first_data = frame_factory.build_data_frame(data=b"\x01"*15) c.receive_data(headers.serialize() + first_data.serialize()) c.clear_outbound_data_buffer() - second_data = frame_factory.build_data_frame(data=b'\x01') + second_data = frame_factory.build_data_frame(data=b"\x01") with pytest.raises(h2.exceptions.InvalidBodyLengthError) as exp: c.receive_data(second_data.serialize()) @@ -74,7 +73,7 @@ def test_too_much_data(self, frame_factory, request_headers): assert c.data_to_send() == expected_frame.serialize() @pytest.mark.parametrize("request_headers", [example_request_headers, example_request_headers_bytes]) - def test_insufficient_data(self, frame_factory, request_headers): + def test_insufficient_data(self, frame_factory, request_headers) -> None: """ Remote peers sending less data than content-length causes Protocol Errors. @@ -84,15 +83,15 @@ def test_insufficient_data(self, frame_factory, request_headers): c.receive_data(frame_factory.preamble()) headers = frame_factory.build_headers_frame( - headers=request_headers + headers=request_headers, ) - first_data = frame_factory.build_data_frame(data=b'\x01'*13) + first_data = frame_factory.build_data_frame(data=b"\x01"*13) c.receive_data(headers.serialize() + first_data.serialize()) c.clear_outbound_data_buffer() second_data = frame_factory.build_data_frame( - data=b'\x01', - flags=['END_STREAM'], + data=b"\x01", + flags=["END_STREAM"], ) with pytest.raises(h2.exceptions.InvalidBodyLengthError) as exp: c.receive_data(second_data.serialize()) @@ -110,7 +109,7 @@ def test_insufficient_data(self, frame_factory, request_headers): assert c.data_to_send() == expected_frame.serialize() @pytest.mark.parametrize("request_headers", [example_request_headers, example_request_headers_bytes]) - def test_insufficient_data_empty_frame(self, frame_factory, request_headers): + def test_insufficient_data_empty_frame(self, frame_factory, request_headers) -> None: """ Remote peers sending less data than content-length where the last data frame is empty causes Protocol Errors. @@ -120,15 +119,15 @@ def test_insufficient_data_empty_frame(self, frame_factory, request_headers): c.receive_data(frame_factory.preamble()) headers = frame_factory.build_headers_frame( - headers=request_headers + headers=request_headers, ) - first_data = frame_factory.build_data_frame(data=b'\x01'*14) + first_data = frame_factory.build_data_frame(data=b"\x01"*14) c.receive_data(headers.serialize() + first_data.serialize()) c.clear_outbound_data_buffer() second_data = frame_factory.build_data_frame( - data=b'', - flags=['END_STREAM'], + data=b"", + flags=["END_STREAM"], ) with pytest.raises(h2.exceptions.InvalidBodyLengthError) as exp: c.receive_data(second_data.serialize()) diff --git a/tests/test_invalid_frame_sequences.py b/tests/test_invalid_frame_sequences.py index 81bf07c4..f78e9f14 100644 --- a/tests/test_invalid_frame_sequences.py +++ b/tests/test_invalid_frame_sequences.py @@ -1,11 +1,8 @@ -# -*- coding: utf-8 -*- """ -test_invalid_frame_sequences.py -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -This module contains tests that use invalid frame sequences, and validates that -they fail appropriately. +Tests invalid frame sequences, and validates that they fail appropriately. """ +from __future__ import annotations + import pytest import h2.config @@ -15,32 +12,33 @@ import h2.exceptions -class TestInvalidFrameSequences(object): +class TestInvalidFrameSequences: """ Invalid frame sequences, either sent or received, cause ProtocolErrors to be thrown. """ + example_request_headers = [ - (':authority', 'example.com'), - (':path', '/'), - (':scheme', 'https'), - (':method', 'GET'), + (":authority", "example.com"), + (":path", "/"), + (":scheme", "https"), + (":method", "GET"), ] example_request_headers_bytes = [ - (b':authority', b'example.com'), - (b':path', b'/'), - (b':scheme', b'https'), - (b':method', b'GET'), + (b":authority", b"example.com"), + (b":path", b"/"), + (b":scheme", b"https"), + (b":method", b"GET"), ] example_response_headers = [ - (':status', '200'), - ('server', 'fake-serv/0.1.0') + (":status", "200"), + ("server", "fake-serv/0.1.0"), ] server_config = h2.config.H2Configuration(client_side=False) client_config = h2.config.H2Configuration(client_side=True) @pytest.mark.parametrize("request_headers", [example_request_headers, example_request_headers_bytes]) - def test_cannot_send_on_closed_stream(self, request_headers): + def test_cannot_send_on_closed_stream(self, request_headers) -> None: """ When we've closed a stream locally, we cannot send further data. """ @@ -49,23 +47,23 @@ def test_cannot_send_on_closed_stream(self, request_headers): c.send_headers(1, request_headers, end_stream=True) with pytest.raises(h2.exceptions.ProtocolError): - c.send_data(1, b'some data') + c.send_data(1, b"some data") - def test_missing_preamble_errors(self): + def test_missing_preamble_errors(self) -> None: """ Server side connections require the preamble. """ c = h2.connection.H2Connection(config=self.server_config) encoded_headers_frame = ( - b'\x00\x00\r\x01\x04\x00\x00\x00\x01' - b'A\x88/\x91\xd3]\x05\\\x87\xa7\x84\x87\x82' + b"\x00\x00\r\x01\x04\x00\x00\x00\x01" + b"A\x88/\x91\xd3]\x05\\\x87\xa7\x84\x87\x82" ) with pytest.raises(h2.exceptions.ProtocolError): c.receive_data(encoded_headers_frame) @pytest.mark.parametrize("request_headers", [example_request_headers, example_request_headers_bytes]) - def test_server_connections_reject_even_streams(self, frame_factory, request_headers): + def test_server_connections_reject_even_streams(self, frame_factory, request_headers) -> None: """ Servers do not allow clients to initiate even-numbered streams. """ @@ -74,14 +72,14 @@ def test_server_connections_reject_even_streams(self, frame_factory, request_hea c.receive_data(frame_factory.preamble()) f = frame_factory.build_headers_frame( - request_headers, stream_id=2 + request_headers, stream_id=2, ) with pytest.raises(h2.exceptions.ProtocolError): c.receive_data(f.serialize()) @pytest.mark.parametrize("request_headers", [example_request_headers, example_request_headers_bytes]) - def test_clients_reject_odd_stream_pushes(self, frame_factory, request_headers): + def test_clients_reject_odd_stream_pushes(self, frame_factory, request_headers) -> None: """ Clients do not allow servers to push odd numbered streams. """ @@ -92,14 +90,14 @@ def test_clients_reject_odd_stream_pushes(self, frame_factory, request_headers): f = frame_factory.build_push_promise_frame( stream_id=1, headers=request_headers, - promised_stream_id=3 + promised_stream_id=3, ) with pytest.raises(h2.exceptions.ProtocolError): c.receive_data(f.serialize()) @pytest.mark.parametrize("request_headers", [example_request_headers, example_request_headers_bytes]) - def test_can_handle_frames_with_invalid_padding(self, frame_factory, request_headers): + def test_can_handle_frames_with_invalid_padding(self, frame_factory, request_headers) -> None: """ Frames with invalid padding cause connection teardown. """ @@ -112,18 +110,18 @@ def test_can_handle_frames_with_invalid_padding(self, frame_factory, request_hea c.clear_outbound_data_buffer() invalid_data_frame = ( - b'\x00\x00\x05\x00\x0b\x00\x00\x00\x01\x06\x54\x65\x73\x74' + b"\x00\x00\x05\x00\x0b\x00\x00\x00\x01\x06\x54\x65\x73\x74" ) with pytest.raises(h2.exceptions.ProtocolError): c.receive_data(invalid_data_frame) expected_frame = frame_factory.build_goaway_frame( - last_stream_id=1, error_code=1 + last_stream_id=1, error_code=1, ) assert c.data_to_send() == expected_frame.serialize() - def test_receiving_frames_with_insufficent_size(self, frame_factory): + def test_receiving_frames_with_insufficent_size(self, frame_factory) -> None: """ Frames with not enough data cause connection teardown. """ @@ -133,19 +131,19 @@ def test_receiving_frames_with_insufficent_size(self, frame_factory): c.clear_outbound_data_buffer() invalid_window_update_frame = ( - b'\x00\x00\x03\x08\x00\x00\x00\x00\x00\x00\x00\x02' + b"\x00\x00\x03\x08\x00\x00\x00\x00\x00\x00\x00\x02" ) with pytest.raises(h2.exceptions.FrameDataMissingError): c.receive_data(invalid_window_update_frame) expected_frame = frame_factory.build_goaway_frame( - last_stream_id=0, error_code=h2.errors.ErrorCodes.FRAME_SIZE_ERROR + last_stream_id=0, error_code=h2.errors.ErrorCodes.FRAME_SIZE_ERROR, ) assert c.data_to_send() == expected_frame.serialize() @pytest.mark.parametrize("request_headers", [example_request_headers, example_request_headers_bytes]) - def test_reject_data_on_closed_streams(self, frame_factory, request_headers): + def test_reject_data_on_closed_streams(self, frame_factory, request_headers) -> None: """ When a stream is not open to the remote peer, we reject receiving data frames from them. @@ -156,13 +154,13 @@ def test_reject_data_on_closed_streams(self, frame_factory, request_headers): f = frame_factory.build_headers_frame( request_headers, - flags=['END_STREAM'] + flags=["END_STREAM"], ) c.receive_data(f.serialize()) c.clear_outbound_data_buffer() bad_frame = frame_factory.build_data_frame( - data=b'some data' + data=b"some data", ) c.receive_data(bad_frame.serialize()) @@ -173,7 +171,7 @@ def test_reject_data_on_closed_streams(self, frame_factory, request_headers): assert c.data_to_send() == expected @pytest.mark.parametrize("request_headers", [example_request_headers, example_request_headers_bytes]) - def test_unexpected_continuation_on_closed_stream(self, frame_factory, request_headers): + def test_unexpected_continuation_on_closed_stream(self, frame_factory, request_headers) -> None: """ CONTINUATION frames received on closed streams cause connection errors of type PROTOCOL_ERROR. @@ -184,13 +182,13 @@ def test_unexpected_continuation_on_closed_stream(self, frame_factory, request_h f = frame_factory.build_headers_frame( request_headers, - flags=['END_STREAM'] + flags=["END_STREAM"], ) c.receive_data(f.serialize()) c.clear_outbound_data_buffer() bad_frame = frame_factory.build_continuation_frame( - header_block=b'hello' + header_block=b"hello", ) with pytest.raises(h2.exceptions.ProtocolError): @@ -198,12 +196,12 @@ def test_unexpected_continuation_on_closed_stream(self, frame_factory, request_h expected_frame = frame_factory.build_goaway_frame( error_code=h2.errors.ErrorCodes.PROTOCOL_ERROR, - last_stream_id=1 + last_stream_id=1, ) assert c.data_to_send() == expected_frame.serialize() @pytest.mark.parametrize("request_headers", [example_request_headers, example_request_headers_bytes]) - def test_prevent_continuation_dos(self, frame_factory, request_headers): + def test_prevent_continuation_dos(self, frame_factory, request_headers) -> None: """ Receiving too many CONTINUATION frames in one block causes a protocol error. @@ -215,20 +213,20 @@ def test_prevent_continuation_dos(self, frame_factory, request_headers): f = frame_factory.build_headers_frame( request_headers, ) - f.flags = {'END_STREAM'} + f.flags = {"END_STREAM"} c.receive_data(f.serialize()) c.clear_outbound_data_buffer() # Send 63 additional frames. - for _ in range(0, 63): + for _ in range(63): extra_frame = frame_factory.build_continuation_frame( - header_block=b'hello' + header_block=b"hello", ) c.receive_data(extra_frame.serialize()) # The final continuation frame should cause a protocol error. extra_frame = frame_factory.build_continuation_frame( - header_block=b'hello' + header_block=b"hello", ) with pytest.raises(h2.exceptions.ProtocolError): c.receive_data(extra_frame.serialize()) @@ -247,9 +245,9 @@ def test_prevent_continuation_dos(self, frame_factory, request_headers): {0x4: 2**31}, {0x5: 5}, {0x5: 2**24}, - ] + ], ) - def test_reject_invalid_settings_values(self, frame_factory, settings): + def test_reject_invalid_settings_values(self, frame_factory, settings) -> None: """ When a SETTINGS frame is received with invalid settings values it causes connection teardown with the appropriate error code. @@ -269,7 +267,7 @@ def test_reject_invalid_settings_values(self, frame_factory, settings): ) @pytest.mark.parametrize("request_headers", [example_request_headers, example_request_headers_bytes]) - def test_invalid_frame_headers_are_protocol_errors(self, frame_factory, request_headers): + def test_invalid_frame_headers_are_protocol_errors(self, frame_factory, request_headers) -> None: """ When invalid frame headers are received they cause ProtocolErrors to be raised. @@ -279,7 +277,7 @@ def test_invalid_frame_headers_are_protocol_errors(self, frame_factory, request_ c.receive_data(frame_factory.preamble()) f = frame_factory.build_headers_frame( - headers=request_headers + headers=request_headers, ) # Do some annoying bit twiddling here: the stream ID is currently set @@ -287,7 +285,7 @@ def test_invalid_frame_headers_are_protocol_errors(self, frame_factory, request_ # replace any instances of the byte '\x01', and then graft it onto the # remaining bytes. frame_data = f.serialize() - frame_data = frame_data[:9].replace(b'\x01', b'\x00') + frame_data[9:] + frame_data = frame_data[:9].replace(b"\x01", b"\x00") + frame_data[9:] with pytest.raises(h2.exceptions.ProtocolError) as e: c.receive_data(frame_data) @@ -295,7 +293,7 @@ def test_invalid_frame_headers_are_protocol_errors(self, frame_factory, request_ assert "Received frame with invalid header" in str(e.value) @pytest.mark.parametrize("request_headers", [example_request_headers, example_request_headers_bytes]) - def test_data_before_headers(self, frame_factory, request_headers): + def test_data_before_headers(self, frame_factory, request_headers) -> None: """ When data frames are received before headers they cause ProtocolErrors to be raised. @@ -313,7 +311,7 @@ def test_data_before_headers(self, frame_factory, request_headers): assert "cannot receive data before headers" in str(e.value) @pytest.mark.parametrize("request_headers", [example_request_headers, example_request_headers_bytes]) - def test_get_stream_reset_event_on_auto_reset(self, frame_factory, request_headers): + def test_get_stream_reset_event_on_auto_reset(self, frame_factory, request_headers) -> None: """ When hyper-h2 resets a stream automatically, a StreamReset event fires. """ @@ -323,13 +321,13 @@ def test_get_stream_reset_event_on_auto_reset(self, frame_factory, request_heade f = frame_factory.build_headers_frame( request_headers, - flags=['END_STREAM'] + flags=["END_STREAM"], ) c.receive_data(f.serialize()) c.clear_outbound_data_buffer() bad_frame = frame_factory.build_data_frame( - data=b'some data' + data=b"some data", ) events = c.receive_data(bad_frame.serialize()) @@ -347,7 +345,7 @@ def test_get_stream_reset_event_on_auto_reset(self, frame_factory, request_heade assert not event.remote_reset @pytest.mark.parametrize("request_headers", [example_request_headers, example_request_headers_bytes]) - def test_one_one_stream_reset(self, frame_factory, request_headers): + def test_one_one_stream_reset(self, frame_factory, request_headers) -> None: """ When hyper-h2 resets a stream automatically, a StreamReset event fires, but only for the first reset: the others are silent. @@ -358,13 +356,13 @@ def test_one_one_stream_reset(self, frame_factory, request_headers): f = frame_factory.build_headers_frame( request_headers, - flags=['END_STREAM'] + flags=["END_STREAM"], ) c.receive_data(f.serialize()) c.clear_outbound_data_buffer() bad_frame = frame_factory.build_data_frame( - data=b'some data' + data=b"some data", ) # Receive 5 frames. events = c.receive_data(bad_frame.serialize() * 5) @@ -383,8 +381,8 @@ def test_one_one_stream_reset(self, frame_factory, request_headers): assert not event.remote_reset @pytest.mark.parametrize("request_headers", [example_request_headers, example_request_headers_bytes]) - @pytest.mark.parametrize('value', ['', 'twelve']) - def test_error_on_invalid_content_length(self, frame_factory, value, request_headers): + @pytest.mark.parametrize("value", ["", "twelve"]) + def test_error_on_invalid_content_length(self, frame_factory, value, request_headers) -> None: """ When an invalid content-length is received, a ProtocolError is thrown. """ @@ -395,19 +393,19 @@ def test_error_on_invalid_content_length(self, frame_factory, value, request_hea f = frame_factory.build_headers_frame( stream_id=1, - headers=request_headers + [('content-length', value)] + headers=[*request_headers, ("content-length", value)], ) with pytest.raises(h2.exceptions.ProtocolError): c.receive_data(f.serialize()) expected_frame = frame_factory.build_goaway_frame( last_stream_id=1, - error_code=h2.errors.ErrorCodes.PROTOCOL_ERROR + error_code=h2.errors.ErrorCodes.PROTOCOL_ERROR, ) assert c.data_to_send() == expected_frame.serialize() @pytest.mark.parametrize("request_headers", [example_request_headers, example_request_headers_bytes]) - def test_invalid_header_data_protocol_error(self, frame_factory, request_headers): + def test_invalid_header_data_protocol_error(self, frame_factory, request_headers) -> None: """ If an invalid header block is received, we raise a ProtocolError. """ @@ -418,21 +416,21 @@ def test_invalid_header_data_protocol_error(self, frame_factory, request_headers f = frame_factory.build_headers_frame( stream_id=1, - headers=request_headers + headers=request_headers, ) - f.data = b'\x00\x00\x00\x00' + f.data = b"\x00\x00\x00\x00" with pytest.raises(h2.exceptions.ProtocolError): c.receive_data(f.serialize()) expected_frame = frame_factory.build_goaway_frame( last_stream_id=0, - error_code=h2.errors.ErrorCodes.PROTOCOL_ERROR + error_code=h2.errors.ErrorCodes.PROTOCOL_ERROR, ) assert c.data_to_send() == expected_frame.serialize() @pytest.mark.parametrize("request_headers", [example_request_headers, example_request_headers_bytes]) - def test_invalid_push_promise_data_protocol_error(self, frame_factory, request_headers): + def test_invalid_push_promise_data_protocol_error(self, frame_factory, request_headers) -> None: """ If an invalid header block is received on a PUSH_PROMISE, we raise a ProtocolError. @@ -445,21 +443,21 @@ def test_invalid_push_promise_data_protocol_error(self, frame_factory, request_h f = frame_factory.build_push_promise_frame( stream_id=1, promised_stream_id=2, - headers=request_headers + headers=request_headers, ) - f.data = b'\x00\x00\x00\x00' + f.data = b"\x00\x00\x00\x00" with pytest.raises(h2.exceptions.ProtocolError): c.receive_data(f.serialize()) expected_frame = frame_factory.build_goaway_frame( last_stream_id=0, - error_code=h2.errors.ErrorCodes.PROTOCOL_ERROR + error_code=h2.errors.ErrorCodes.PROTOCOL_ERROR, ) assert c.data_to_send() == expected_frame.serialize() @pytest.mark.parametrize("request_headers", [example_request_headers, example_request_headers_bytes]) - def test_cannot_receive_push_on_pushed_stream(self, frame_factory, request_headers): + def test_cannot_receive_push_on_pushed_stream(self, frame_factory, request_headers) -> None: """ If a PUSH_PROMISE frame is received with the parent stream ID being a pushed stream, this is rejected with a PROTOCOL_ERROR. @@ -469,7 +467,7 @@ def test_cannot_receive_push_on_pushed_stream(self, frame_factory, request_heade c.send_headers( stream_id=1, headers=request_headers, - end_stream=True + end_stream=True, ) f1 = frame_factory.build_push_promise_frame( @@ -495,12 +493,12 @@ def test_cannot_receive_push_on_pushed_stream(self, frame_factory, request_heade expected_frame = frame_factory.build_goaway_frame( last_stream_id=2, - error_code=h2.errors.ErrorCodes.PROTOCOL_ERROR + error_code=h2.errors.ErrorCodes.PROTOCOL_ERROR, ) assert c.data_to_send() == expected_frame.serialize() @pytest.mark.parametrize("request_headers", [example_request_headers, example_request_headers_bytes]) - def test_cannot_send_push_on_pushed_stream(self, frame_factory, request_headers): + def test_cannot_send_push_on_pushed_stream(self, frame_factory, request_headers) -> None: """ If a user tries to send a PUSH_PROMISE frame with the parent stream ID being a pushed stream, this is rejected with a PROTOCOL_ERROR. @@ -509,14 +507,14 @@ def test_cannot_send_push_on_pushed_stream(self, frame_factory, request_headers) c.initiate_connection() c.receive_data(frame_factory.preamble()) f = frame_factory.build_headers_frame( - stream_id=1, headers=request_headers + stream_id=1, headers=request_headers, ) c.receive_data(f.serialize()) c.push_stream( stream_id=1, promised_stream_id=2, - request_headers=request_headers + request_headers=request_headers, ) c.send_headers(stream_id=2, headers=self.example_response_headers) @@ -524,5 +522,5 @@ def test_cannot_send_push_on_pushed_stream(self, frame_factory, request_headers) c.push_stream( stream_id=2, promised_stream_id=4, - request_headers=request_headers + request_headers=request_headers, ) diff --git a/tests/test_invalid_headers.py b/tests/test_invalid_headers.py index f388cf3f..192ba10d 100644 --- a/tests/test_invalid_headers.py +++ b/tests/test_invalid_headers.py @@ -1,14 +1,14 @@ -# -*- coding: utf-8 -*- """ -test_invalid_headers.py -~~~~~~~~~~~~~~~~~~~~~~~ - -This module contains tests that use invalid header blocks, and validates that -they fail appropriately. +Tests invalid header blocks, and validates that they fail appropriately. """ +from __future__ import annotations + import itertools +import hyperframe.frame import pytest +from hypothesis import given +from hypothesis.strategies import binary, lists, tuples import h2.config import h2.connection @@ -18,53 +18,49 @@ import h2.settings import h2.utilities -import hyperframe.frame - -from hypothesis import given -from hypothesis.strategies import binary, lists, tuples - HEADERS_STRATEGY = lists(tuples(binary(min_size=1), binary())) -class TestInvalidFrameSequences(object): +class TestInvalidFrameSequences: """ Invalid header sequences cause ProtocolErrors to be thrown when received. """ + base_request_headers = [ - (':authority', 'example.com'), - (':path', '/'), - (':scheme', 'https'), - (':method', 'GET'), - ('user-agent', 'someua/0.0.1'), + (":authority", "example.com"), + (":path", "/"), + (":scheme", "https"), + (":method", "GET"), + ("user-agent", "someua/0.0.1"), ] base_invalid_header_blocks = [ - base_request_headers + [('Uppercase', 'name')], - base_request_headers + [(':late', 'pseudo-header')], - [(':path', 'duplicate-pseudo-header')] + base_request_headers, - base_request_headers + [('connection', 'close')], - base_request_headers + [('proxy-connection', 'close')], - base_request_headers + [('keep-alive', 'close')], - base_request_headers + [('transfer-encoding', 'gzip')], - base_request_headers + [('upgrade', 'super-protocol/1.1')], - base_request_headers + [('te', 'chunked')], - base_request_headers + [('host', 'notexample.com')], - base_request_headers + [(' name', 'name with leading space')], - base_request_headers + [('name ', 'name with trailing space')], - base_request_headers + [('name', ' value with leading space')], - base_request_headers + [('name', 'value with trailing space ')], + [*base_request_headers, ("Uppercase", "name")], + [*base_request_headers, (":late", "pseudo-header")], + [(":path", "duplicate-pseudo-header"), *base_request_headers], + [*base_request_headers, ("connection", "close")], + [*base_request_headers, ("proxy-connection", "close")], + [*base_request_headers, ("keep-alive", "close")], + [*base_request_headers, ("transfer-encoding", "gzip")], + [*base_request_headers, ("upgrade", "super-protocol/1.1")], + [*base_request_headers, ("te", "chunked")], + [*base_request_headers, ("host", "notexample.com")], + [*base_request_headers, (" name", "name with leading space")], + [*base_request_headers, ("name ", "name with trailing space")], + [*base_request_headers, ("name", " value with leading space")], + [*base_request_headers, ("name", "value with trailing space ")], [header for header in base_request_headers - if header[0] != ':authority'], - [(':protocol', 'websocket')] + base_request_headers, + if header[0] != ":authority"], + [(":protocol", "websocket"), *base_request_headers], ] invalid_header_blocks = base_invalid_header_blocks + [ h2.utilities.utf8_encode_headers(headers) for headers in base_invalid_header_blocks ] server_config = h2.config.H2Configuration( - client_side=False, header_encoding='utf-8' + client_side=False, header_encoding="utf-8", ) - @pytest.mark.parametrize('headers', invalid_header_blocks) - def test_headers_event(self, frame_factory, headers): + @pytest.mark.parametrize("headers", invalid_header_blocks) + def test_headers_event(self, frame_factory, headers) -> None: """ Test invalid headers are rejected with PROTOCOL_ERROR. """ @@ -79,12 +75,12 @@ def test_headers_event(self, frame_factory, headers): c.receive_data(data) expected_frame = frame_factory.build_goaway_frame( - last_stream_id=1, error_code=h2.errors.ErrorCodes.PROTOCOL_ERROR + last_stream_id=1, error_code=h2.errors.ErrorCodes.PROTOCOL_ERROR, ) assert c.data_to_send() == expected_frame.serialize() - @pytest.mark.parametrize('headers', invalid_header_blocks) - def test_push_promise_event(self, frame_factory, headers): + @pytest.mark.parametrize("headers", invalid_header_blocks) + def test_push_promise_event(self, frame_factory, headers) -> None: """ If a PUSH_PROMISE header frame is received with an invalid header block it is rejected with a PROTOCOL_ERROR. @@ -92,14 +88,14 @@ def test_push_promise_event(self, frame_factory, headers): c = h2.connection.H2Connection() c.initiate_connection() c.send_headers( - stream_id=1, headers=self.base_request_headers, end_stream=True + stream_id=1, headers=self.base_request_headers, end_stream=True, ) c.clear_outbound_data_buffer() f = frame_factory.build_push_promise_frame( stream_id=1, promised_stream_id=2, - headers=headers + headers=headers, ) data = f.serialize() @@ -107,12 +103,12 @@ def test_push_promise_event(self, frame_factory, headers): c.receive_data(data) expected_frame = frame_factory.build_goaway_frame( - last_stream_id=0, error_code=h2.errors.ErrorCodes.PROTOCOL_ERROR + last_stream_id=0, error_code=h2.errors.ErrorCodes.PROTOCOL_ERROR, ) assert c.data_to_send() == expected_frame.serialize() - @pytest.mark.parametrize('headers', invalid_header_blocks) - def test_push_promise_skipping_validation(self, frame_factory, headers): + @pytest.mark.parametrize("headers", invalid_header_blocks) + def test_push_promise_skipping_validation(self, frame_factory, headers) -> None: """ If we have ``validate_inbound_headers`` disabled, then invalid header blocks in push promise frames are allowed to pass. @@ -125,14 +121,14 @@ def test_push_promise_skipping_validation(self, frame_factory, headers): c = h2.connection.H2Connection(config=config) c.initiate_connection() c.send_headers( - stream_id=1, headers=self.base_request_headers, end_stream=True + stream_id=1, headers=self.base_request_headers, end_stream=True, ) c.clear_outbound_data_buffer() f = frame_factory.build_push_promise_frame( stream_id=1, promised_stream_id=2, - headers=headers + headers=headers, ) data = f.serialize() @@ -141,8 +137,8 @@ def test_push_promise_skipping_validation(self, frame_factory, headers): pp_event = events[0] assert pp_event.headers == h2.utilities.utf8_encode_headers(headers) - @pytest.mark.parametrize('headers', invalid_header_blocks) - def test_headers_event_skipping_validation(self, frame_factory, headers): + @pytest.mark.parametrize("headers", invalid_header_blocks) + def test_headers_event_skipping_validation(self, frame_factory, headers) -> None: """ If we have ``validate_inbound_headers`` disabled, then all of these invalid header blocks are allowed to pass. @@ -163,12 +159,12 @@ def test_headers_event_skipping_validation(self, frame_factory, headers): request_event = events[0] assert request_event.headers == h2.utilities.utf8_encode_headers(headers) - def test_te_trailers_is_valid(self, frame_factory): + def test_te_trailers_is_valid(self, frame_factory) -> None: """ `te: trailers` is allowed by the filter. """ headers = ( - self.base_request_headers + [('te', 'trailers')] + [*self.base_request_headers, ("te", "trailers")] ) c = h2.connection.H2Connection(config=self.server_config) @@ -182,21 +178,21 @@ def test_te_trailers_is_valid(self, frame_factory): request_event = events[0] assert request_event.headers == headers - def test_pseudo_headers_rejected_in_trailer(self, frame_factory): + def test_pseudo_headers_rejected_in_trailer(self, frame_factory) -> None: """ Ensure we reject pseudo headers included in trailers """ - trailers = [(':path', '/'), ('extra', 'value')] + trailers = [(":path", "/"), ("extra", "value")] c = h2.connection.H2Connection(config=self.server_config) c.receive_data(frame_factory.preamble()) c.clear_outbound_data_buffer() header_frame = frame_factory.build_headers_frame( - self.base_request_headers + self.base_request_headers, ) trailer_frame = frame_factory.build_headers_frame( - trailers, flags=["END_STREAM"] + trailers, flags=["END_STREAM"], ) head = header_frame.serialize() trailer = trailer_frame.serialize() @@ -209,44 +205,45 @@ def test_pseudo_headers_rejected_in_trailer(self, frame_factory): # Test appropriate response frame is generated expected_frame = frame_factory.build_goaway_frame( - last_stream_id=1, error_code=h2.errors.ErrorCodes.PROTOCOL_ERROR + last_stream_id=1, error_code=h2.errors.ErrorCodes.PROTOCOL_ERROR, ) assert c.data_to_send() == expected_frame.serialize() -class TestSendingInvalidFrameSequences(object): +class TestSendingInvalidFrameSequences: """ Trying to send invalid header sequences cause ProtocolErrors to be thrown. """ + base_request_headers = [ - (':authority', 'example.com'), - (':path', '/'), - (':scheme', 'https'), - (':method', 'GET'), - ('user-agent', 'someua/0.0.1'), + (":authority", "example.com"), + (":path", "/"), + (":scheme", "https"), + (":method", "GET"), + ("user-agent", "someua/0.0.1"), ] invalid_header_blocks = [ - base_request_headers + [(':late', 'pseudo-header')], - [(':path', 'duplicate-pseudo-header')] + base_request_headers, - base_request_headers + [('te', 'chunked')], - base_request_headers + [('host', 'notexample.com')], + [*base_request_headers, (":late", "pseudo-header")], + [(":path", "duplicate-pseudo-header"), *base_request_headers], + [*base_request_headers, ("te", "chunked")], + [*base_request_headers, ("host", "notexample.com")], [header for header in base_request_headers - if header[0] != ':authority'], + if header[0] != ":authority"], ] strippable_header_blocks = [ - base_request_headers + [('connection', 'close')], - base_request_headers + [('proxy-connection', 'close')], - base_request_headers + [('keep-alive', 'close')], - base_request_headers + [('transfer-encoding', 'gzip')], - base_request_headers + [('upgrade', 'super-protocol/1.1')] + [*base_request_headers, ("connection", "close")], + [*base_request_headers, ("proxy-connection", "close")], + [*base_request_headers, ("keep-alive", "close")], + [*base_request_headers, ("transfer-encoding", "gzip")], + [*base_request_headers, ("upgrade", "super-protocol/1.1")], ] all_header_blocks = invalid_header_blocks + strippable_header_blocks server_config = h2.config.H2Configuration(client_side=False) - @pytest.mark.parametrize('headers', invalid_header_blocks) - def test_headers_event(self, frame_factory, headers): + @pytest.mark.parametrize("headers", invalid_header_blocks) + def test_headers_event(self, frame_factory, headers) -> None: """ Test sending invalid headers raise a ProtocolError. """ @@ -258,8 +255,8 @@ def test_headers_event(self, frame_factory, headers): with pytest.raises(h2.exceptions.ProtocolError): c.send_headers(1, headers) - @pytest.mark.parametrize('headers', invalid_header_blocks) - def test_send_push_promise(self, frame_factory, headers): + @pytest.mark.parametrize("headers", invalid_header_blocks) + def test_send_push_promise(self, frame_factory, headers) -> None: """ Sending invalid headers in a push promise raises a ProtocolError. """ @@ -268,7 +265,7 @@ def test_send_push_promise(self, frame_factory, headers): c.receive_data(frame_factory.preamble()) header_frame = frame_factory.build_headers_frame( - self.base_request_headers + self.base_request_headers, ) c.receive_data(header_frame.serialize()) @@ -276,17 +273,17 @@ def test_send_push_promise(self, frame_factory, headers): c.clear_outbound_data_buffer() with pytest.raises(h2.exceptions.ProtocolError): c.push_stream( - stream_id=1, promised_stream_id=2, request_headers=headers + stream_id=1, promised_stream_id=2, request_headers=headers, ) - @pytest.mark.parametrize('headers', all_header_blocks) - def test_headers_event_skipping_validation(self, frame_factory, headers): + @pytest.mark.parametrize("headers", all_header_blocks) + def test_headers_event_skipping_validation(self, frame_factory, headers) -> None: """ If we have ``validate_outbound_headers`` disabled, then all of these invalid header blocks are allowed to pass. """ config = h2.config.H2Configuration( - validate_outbound_headers=False + validate_outbound_headers=False, ) c = h2.connection.H2Connection(config=config) @@ -299,13 +296,13 @@ def test_headers_event_skipping_validation(self, frame_factory, headers): # Ensure headers are still normalized. headers = h2.utilities.utf8_encode_headers(headers) norm_headers = h2.utilities.normalize_outbound_headers( - headers, None, False + headers, None, False, ) f = frame_factory.build_headers_frame(norm_headers) assert c.data_to_send() == f.serialize() - @pytest.mark.parametrize('headers', all_header_blocks) - def test_push_promise_skipping_validation(self, frame_factory, headers): + @pytest.mark.parametrize("headers", all_header_blocks) + def test_push_promise_skipping_validation(self, frame_factory, headers) -> None: """ If we have ``validate_outbound_headers`` disabled, then all of these invalid header blocks are allowed to pass. @@ -320,7 +317,7 @@ def test_push_promise_skipping_validation(self, frame_factory, headers): c.receive_data(frame_factory.preamble()) header_frame = frame_factory.build_headers_frame( - self.base_request_headers + self.base_request_headers, ) c.receive_data(header_frame.serialize()) @@ -329,28 +326,28 @@ def test_push_promise_skipping_validation(self, frame_factory, headers): # Create push promise frame with normalized headers. headers = h2.utilities.utf8_encode_headers(headers) norm_headers = h2.utilities.normalize_outbound_headers( - headers, None, False + headers, None, False, ) pp_frame = frame_factory.build_push_promise_frame( - stream_id=1, promised_stream_id=2, headers=norm_headers + stream_id=1, promised_stream_id=2, headers=norm_headers, ) # Clear the data, then send a push promise. c.clear_outbound_data_buffer() c.push_stream( - stream_id=1, promised_stream_id=2, request_headers=headers + stream_id=1, promised_stream_id=2, request_headers=headers, ) assert c.data_to_send() == pp_frame.serialize() - @pytest.mark.parametrize('headers', all_header_blocks) - def test_headers_event_skip_normalization(self, frame_factory, headers): + @pytest.mark.parametrize("headers", all_header_blocks) + def test_headers_event_skip_normalization(self, frame_factory, headers) -> None: """ If we have ``normalize_outbound_headers`` disabled, then all of these invalid header blocks are sent through unmodified. """ config = h2.config.H2Configuration( validate_outbound_headers=False, - normalize_outbound_headers=False + normalize_outbound_headers=False, ) c = h2.connection.H2Connection(config=config) @@ -366,8 +363,8 @@ def test_headers_event_skip_normalization(self, frame_factory, headers): c.send_headers(1, headers) assert c.data_to_send() == f.serialize() - @pytest.mark.parametrize('headers', all_header_blocks) - def test_push_promise_skip_normalization(self, frame_factory, headers): + @pytest.mark.parametrize("headers", all_header_blocks) + def test_push_promise_skip_normalization(self, frame_factory, headers) -> None: """ If we have ``normalize_outbound_headers`` disabled, then all of these invalid header blocks are allowed to pass unmodified. @@ -383,24 +380,24 @@ def test_push_promise_skip_normalization(self, frame_factory, headers): c.receive_data(frame_factory.preamble()) header_frame = frame_factory.build_headers_frame( - self.base_request_headers + self.base_request_headers, ) c.receive_data(header_frame.serialize()) frame_factory.refresh_encoder() pp_frame = frame_factory.build_push_promise_frame( - stream_id=1, promised_stream_id=2, headers=headers + stream_id=1, promised_stream_id=2, headers=headers, ) # Clear the data, then send a push promise. c.clear_outbound_data_buffer() c.push_stream( - stream_id=1, promised_stream_id=2, request_headers=headers + stream_id=1, promised_stream_id=2, request_headers=headers, ) assert c.data_to_send() == pp_frame.serialize() - @pytest.mark.parametrize('headers', strippable_header_blocks) - def test_strippable_headers(self, frame_factory, headers): + @pytest.mark.parametrize("headers", strippable_header_blocks) + def test_strippable_headers(self, frame_factory, headers) -> None: """ Test connection related headers are removed before sending. """ @@ -415,7 +412,7 @@ def test_strippable_headers(self, frame_factory, headers): assert c.data_to_send() == f.serialize() -class TestFilter(object): +class TestFilter: """ Test the filter function directly. @@ -424,14 +421,15 @@ class TestFilter(object): HTTP/2 and so may never hit the function, but it's worth validating that it behaves as expected anyway. """ + validation_functions = [ h2.utilities.validate_headers, - h2.utilities.validate_outbound_headers + h2.utilities.validate_outbound_headers, ] hdr_validation_combos = [ h2.utilities.HeaderValidationFlags( - is_client, is_trailer, is_response_header, is_push_promise + is_client, is_trailer, is_response_header, is_push_promise, ) for is_client, is_trailer, is_response_header, is_push_promise in ( itertools.product([True, False], repeat=4) @@ -451,42 +449,42 @@ class TestFilter(object): invalid_request_header_blocks_bytes = ( # First, missing :method ( - (b':authority', b'google.com'), - (b':path', b'/'), - (b':scheme', b'https'), + (b":authority", b"google.com"), + (b":path", b"/"), + (b":scheme", b"https"), ), # Next, missing :path ( - (b':authority', b'google.com'), - (b':method', b'GET'), - (b':scheme', b'https'), + (b":authority", b"google.com"), + (b":method", b"GET"), + (b":scheme", b"https"), ), # Next, missing :scheme ( - (b':authority', b'google.com'), - (b':method', b'GET'), - (b':path', b'/'), + (b":authority", b"google.com"), + (b":method", b"GET"), + (b":path", b"/"), ), # Finally, path present but empty. ( - (b':authority', b'google.com'), - (b':method', b'GET'), - (b':scheme', b'https'), - (b':path', b''), + (b":authority", b"google.com"), + (b":method", b"GET"), + (b":scheme", b"https"), + (b":path", b""), ), ) # All headers that are forbidden from either request or response blocks. - forbidden_request_headers_bytes = (b':status',) - forbidden_response_headers_bytes = (b':path', b':scheme', b':authority', b':method') + forbidden_request_headers_bytes = (b":status",) + forbidden_response_headers_bytes = (b":path", b":scheme", b":authority", b":method") - @pytest.mark.parametrize('validation_function', validation_functions) - @pytest.mark.parametrize('hdr_validation_flags', hdr_validation_combos) + @pytest.mark.parametrize("validation_function", validation_functions) + @pytest.mark.parametrize("hdr_validation_flags", hdr_validation_combos) @given(headers=HEADERS_STRATEGY) def test_range_of_acceptable_outputs(self, headers, validation_function, - hdr_validation_flags): + hdr_validation_flags) -> None: """ The header validation functions either return the data unchanged or throw a ProtocolError. @@ -497,175 +495,175 @@ def test_range_of_acceptable_outputs(self, except h2.exceptions.ProtocolError: assert True - @pytest.mark.parametrize('hdr_validation_flags', hdr_validation_combos) - def test_invalid_pseudo_headers(self, hdr_validation_flags): - headers = [(b':custom', b'value')] + @pytest.mark.parametrize("hdr_validation_flags", hdr_validation_combos) + def test_invalid_pseudo_headers(self, hdr_validation_flags) -> None: + headers = [(b":custom", b"value")] with pytest.raises(h2.exceptions.ProtocolError): list(h2.utilities.validate_headers(headers, hdr_validation_flags)) - @pytest.mark.parametrize('validation_function', validation_functions) + @pytest.mark.parametrize("validation_function", validation_functions) @pytest.mark.parametrize( - 'hdr_validation_flags', hdr_validation_request_headers_no_trailer + "hdr_validation_flags", hdr_validation_request_headers_no_trailer, ) def test_matching_authority_host_headers(self, validation_function, - hdr_validation_flags): + hdr_validation_flags) -> None: """ If a header block has :authority and Host headers and they match, the headers should pass through unchanged. """ headers = [ - (b':authority', b'example.com'), - (b':path', b'/'), - (b':scheme', b'https'), - (b':method', b'GET'), - (b'host', b'example.com'), + (b":authority", b"example.com"), + (b":path", b"/"), + (b":scheme", b"https"), + (b":method", b"GET"), + (b"host", b"example.com"), ] assert headers == list(h2.utilities.validate_headers( - headers, hdr_validation_flags + headers, hdr_validation_flags, )) @pytest.mark.parametrize( - 'hdr_validation_flags', hdr_validation_response_headers + "hdr_validation_flags", hdr_validation_response_headers, ) - def test_response_header_without_status(self, hdr_validation_flags): - headers = [(b'content-length', b'42')] + def test_response_header_without_status(self, hdr_validation_flags) -> None: + headers = [(b"content-length", b"42")] with pytest.raises(h2.exceptions.ProtocolError): list(h2.utilities.validate_headers(headers, hdr_validation_flags)) @pytest.mark.parametrize( - 'hdr_validation_flags', hdr_validation_request_headers_no_trailer + "hdr_validation_flags", hdr_validation_request_headers_no_trailer, ) @pytest.mark.parametrize( - 'header_block', + "header_block", (invalid_request_header_blocks_bytes), ) def test_outbound_req_header_missing_pseudo_headers(self, hdr_validation_flags, - header_block): + header_block) -> None: with pytest.raises(h2.exceptions.ProtocolError): list( h2.utilities.validate_outbound_headers( - header_block, hdr_validation_flags - ) + header_block, hdr_validation_flags, + ), ) @pytest.mark.parametrize( - 'hdr_validation_flags', hdr_validation_request_headers_no_trailer + "hdr_validation_flags", hdr_validation_request_headers_no_trailer, ) @pytest.mark.parametrize( - 'header_block', invalid_request_header_blocks_bytes + "header_block", invalid_request_header_blocks_bytes, ) def test_inbound_req_header_missing_pseudo_headers(self, hdr_validation_flags, - header_block): + header_block) -> None: with pytest.raises(h2.exceptions.ProtocolError): list( h2.utilities.validate_headers( - header_block, hdr_validation_flags - ) + header_block, hdr_validation_flags, + ), ) @pytest.mark.parametrize( - 'hdr_validation_flags', hdr_validation_request_headers_no_trailer + "hdr_validation_flags", hdr_validation_request_headers_no_trailer, ) @pytest.mark.parametrize( - 'invalid_header', + "invalid_header", forbidden_request_headers_bytes, ) def test_outbound_req_header_extra_pseudo_headers(self, hdr_validation_flags, - invalid_header): + invalid_header) -> None: """ Outbound request header blocks containing the forbidden request headers fail validation. """ headers = [ - (b':path', b'/'), - (b':scheme', b'https'), - (b':authority', b'google.com'), - (b':method', b'GET'), + (b":path", b"/"), + (b":scheme", b"https"), + (b":authority", b"google.com"), + (b":method", b"GET"), ] - headers.append((invalid_header, b'some value')) + headers.append((invalid_header, b"some value")) with pytest.raises(h2.exceptions.ProtocolError): list( h2.utilities.validate_outbound_headers( - headers, hdr_validation_flags - ) + headers, hdr_validation_flags, + ), ) @pytest.mark.parametrize( - 'hdr_validation_flags', hdr_validation_request_headers_no_trailer + "hdr_validation_flags", hdr_validation_request_headers_no_trailer, ) @pytest.mark.parametrize( - 'invalid_header', - forbidden_request_headers_bytes + "invalid_header", + forbidden_request_headers_bytes, ) def test_inbound_req_header_extra_pseudo_headers(self, hdr_validation_flags, - invalid_header): + invalid_header) -> None: """ Inbound request header blocks containing the forbidden request headers fail validation. """ headers = [ - (b':path', b'/'), - (b':scheme', b'https'), - (b':authority', b'google.com'), - (b':method', b'GET'), + (b":path", b"/"), + (b":scheme", b"https"), + (b":authority", b"google.com"), + (b":method", b"GET"), ] - headers.append((invalid_header, b'some value')) + headers.append((invalid_header, b"some value")) with pytest.raises(h2.exceptions.ProtocolError): list(h2.utilities.validate_headers(headers, hdr_validation_flags)) @pytest.mark.parametrize( - 'hdr_validation_flags', hdr_validation_response_headers + "hdr_validation_flags", hdr_validation_response_headers, ) @pytest.mark.parametrize( - 'invalid_header', + "invalid_header", forbidden_response_headers_bytes, ) def test_outbound_resp_header_extra_pseudo_headers(self, hdr_validation_flags, - invalid_header): + invalid_header) -> None: """ Outbound response header blocks containing the forbidden response headers fail validation. """ - headers = [(b':status', b'200')] - headers.append((invalid_header, b'some value')) + headers = [(b":status", b"200")] + headers.append((invalid_header, b"some value")) with pytest.raises(h2.exceptions.ProtocolError): list( h2.utilities.validate_outbound_headers( - headers, hdr_validation_flags - ) + headers, hdr_validation_flags, + ), ) @pytest.mark.parametrize( - 'hdr_validation_flags', hdr_validation_response_headers + "hdr_validation_flags", hdr_validation_response_headers, ) @pytest.mark.parametrize( - 'invalid_header', - forbidden_response_headers_bytes + "invalid_header", + forbidden_response_headers_bytes, ) def test_inbound_resp_header_extra_pseudo_headers(self, hdr_validation_flags, - invalid_header): + invalid_header) -> None: """ Inbound response header blocks containing the forbidden response headers fail validation. """ - headers = [(b':status', b'200')] - headers.append((invalid_header, b'some value')) + headers = [(b":status", b"200")] + headers.append((invalid_header, b"some value")) with pytest.raises(h2.exceptions.ProtocolError): list(h2.utilities.validate_headers(headers, hdr_validation_flags)) - @pytest.mark.parametrize('hdr_validation_flags', hdr_validation_combos) - def test_inbound_header_name_length(self, hdr_validation_flags): + @pytest.mark.parametrize("hdr_validation_flags", hdr_validation_combos) + def test_inbound_header_name_length(self, hdr_validation_flags) -> None: with pytest.raises(h2.exceptions.ProtocolError): - list(h2.utilities.validate_headers([(b'', b'foobar')], hdr_validation_flags)) + list(h2.utilities.validate_headers([(b"", b"foobar")], hdr_validation_flags)) - def test_inbound_header_name_length_full_frame_decode(self, frame_factory): + def test_inbound_header_name_length_full_frame_decode(self, frame_factory) -> None: f = frame_factory.build_headers_frame([]) f.data = b"\x00\x00\x05\x00\x00\x00\x00\x04" data = f.serialize() @@ -679,20 +677,21 @@ def test_inbound_header_name_length_full_frame_decode(self, frame_factory): c.receive_data(data) -class TestOversizedHeaders(object): +class TestOversizedHeaders: """ Tests that oversized header blocks are correctly rejected. This replicates the "HPACK Bomb" attack, and confirms that we're resistant against it. """ + request_header_block = [ - (b':method', b'GET'), - (b':authority', b'example.com'), - (b':scheme', b'https'), - (b':path', b'/'), + (b":method", b"GET"), + (b":authority", b"example.com"), + (b":scheme", b"https"), + (b":path", b"/"), ] response_header_block = [ - (b':status', b'200'), + (b":status", b"200"), ] # The first header block contains a single header that fills the header @@ -701,18 +700,18 @@ class TestOversizedHeaders(object): # table. It must come last, so that it evicts all other headers. # This block must be appended to either a request or response block. first_header_block = [ - (b'a', b'a' * 4063), + (b"a", b"a" * 4063), ] # The second header "block" is actually a custom HEADERS frame body that # simply repeatedly refers to the first entry for 16kB. Each byte has the # high bit set (0x80), and then uses the remaining 7 bits to encode the # number 62 (0x3e), leading to a repeat of the byte 0xbe. - second_header_block = b'\xbe' * 2**14 + second_header_block = b"\xbe" * 2**14 server_config = h2.config.H2Configuration(client_side=False) - def test_hpack_bomb_request(self, frame_factory): + def test_hpack_bomb_request(self, frame_factory) -> None: """ A HPACK bomb request causes the connection to be torn down with the error code ENHANCE_YOUR_CALM. @@ -722,7 +721,7 @@ def test_hpack_bomb_request(self, frame_factory): c.clear_outbound_data_buffer() f = frame_factory.build_headers_frame( - self.request_header_block + self.first_header_block + self.request_header_block + self.first_header_block, ) data = f.serialize() c.receive_data(data) @@ -730,18 +729,18 @@ def test_hpack_bomb_request(self, frame_factory): # Build the attack payload. attack_frame = hyperframe.frame.HeadersFrame(stream_id=3) attack_frame.data = self.second_header_block - attack_frame.flags.add('END_HEADERS') + attack_frame.flags.add("END_HEADERS") data = attack_frame.serialize() with pytest.raises(h2.exceptions.DenialOfServiceError): c.receive_data(data) expected_frame = frame_factory.build_goaway_frame( - last_stream_id=1, error_code=h2.errors.ErrorCodes.ENHANCE_YOUR_CALM + last_stream_id=1, error_code=h2.errors.ErrorCodes.ENHANCE_YOUR_CALM, ) assert c.data_to_send() == expected_frame.serialize() - def test_hpack_bomb_response(self, frame_factory): + def test_hpack_bomb_response(self, frame_factory) -> None: """ A HPACK bomb response causes the connection to be torn down with the error code ENHANCE_YOUR_CALM. @@ -749,15 +748,15 @@ def test_hpack_bomb_response(self, frame_factory): c = h2.connection.H2Connection() c.initiate_connection() c.send_headers( - stream_id=1, headers=self.request_header_block + stream_id=1, headers=self.request_header_block, ) c.send_headers( - stream_id=3, headers=self.request_header_block + stream_id=3, headers=self.request_header_block, ) c.clear_outbound_data_buffer() f = frame_factory.build_headers_frame( - self.response_header_block + self.first_header_block + self.response_header_block + self.first_header_block, ) data = f.serialize() c.receive_data(data) @@ -765,18 +764,18 @@ def test_hpack_bomb_response(self, frame_factory): # Build the attack payload. attack_frame = hyperframe.frame.HeadersFrame(stream_id=3) attack_frame.data = self.second_header_block - attack_frame.flags.add('END_HEADERS') + attack_frame.flags.add("END_HEADERS") data = attack_frame.serialize() with pytest.raises(h2.exceptions.DenialOfServiceError): c.receive_data(data) expected_frame = frame_factory.build_goaway_frame( - last_stream_id=0, error_code=h2.errors.ErrorCodes.ENHANCE_YOUR_CALM + last_stream_id=0, error_code=h2.errors.ErrorCodes.ENHANCE_YOUR_CALM, ) assert c.data_to_send() == expected_frame.serialize() - def test_hpack_bomb_push(self, frame_factory): + def test_hpack_bomb_push(self, frame_factory) -> None: """ A HPACK bomb push causes the connection to be torn down with the error code ENHANCE_YOUR_CALM. @@ -784,12 +783,12 @@ def test_hpack_bomb_push(self, frame_factory): c = h2.connection.H2Connection() c.initiate_connection() c.send_headers( - stream_id=1, headers=self.request_header_block + stream_id=1, headers=self.request_header_block, ) c.clear_outbound_data_buffer() f = frame_factory.build_headers_frame( - self.response_header_block + self.first_header_block + self.response_header_block + self.first_header_block, ) data = f.serialize() c.receive_data(data) @@ -799,18 +798,18 @@ def test_hpack_bomb_push(self, frame_factory): attack_frame = hyperframe.frame.PushPromiseFrame(stream_id=3) attack_frame.promised_stream_id = 2 attack_frame.data = self.second_header_block[:-4] - attack_frame.flags.add('END_HEADERS') + attack_frame.flags.add("END_HEADERS") data = attack_frame.serialize() with pytest.raises(h2.exceptions.DenialOfServiceError): c.receive_data(data) expected_frame = frame_factory.build_goaway_frame( - last_stream_id=0, error_code=h2.errors.ErrorCodes.ENHANCE_YOUR_CALM + last_stream_id=0, error_code=h2.errors.ErrorCodes.ENHANCE_YOUR_CALM, ) assert c.data_to_send() == expected_frame.serialize() - def test_reject_headers_when_list_size_shrunk(self, frame_factory): + def test_reject_headers_when_list_size_shrunk(self, frame_factory) -> None: """ When we've shrunk the header list size, we reject new header blocks that violate the new size. @@ -822,7 +821,7 @@ def test_reject_headers_when_list_size_shrunk(self, frame_factory): # Receive the first request, which causes no problem. f = frame_factory.build_headers_frame( stream_id=1, - headers=self.request_header_block + headers=self.request_header_block, ) data = f.serialize() c.receive_data(data) @@ -833,7 +832,7 @@ def test_reject_headers_when_list_size_shrunk(self, frame_factory): c.clear_outbound_data_buffer() f = frame_factory.build_headers_frame( stream_id=3, - headers=self.request_header_block + headers=self.request_header_block, ) data = f.serialize() c.receive_data(data) @@ -846,7 +845,7 @@ def test_reject_headers_when_list_size_shrunk(self, frame_factory): # Now a third request comes in. This explodes. f = frame_factory.build_headers_frame( stream_id=5, - headers=self.request_header_block + headers=self.request_header_block, ) data = f.serialize() @@ -854,11 +853,11 @@ def test_reject_headers_when_list_size_shrunk(self, frame_factory): c.receive_data(data) expected_frame = frame_factory.build_goaway_frame( - last_stream_id=3, error_code=h2.errors.ErrorCodes.ENHANCE_YOUR_CALM + last_stream_id=3, error_code=h2.errors.ErrorCodes.ENHANCE_YOUR_CALM, ) assert c.data_to_send() == expected_frame.serialize() - def test_reject_headers_when_table_size_shrunk(self, frame_factory): + def test_reject_headers_when_table_size_shrunk(self, frame_factory) -> None: """ When we've shrunk the header table size, we reject header blocks that do not respect the change. @@ -870,7 +869,7 @@ def test_reject_headers_when_table_size_shrunk(self, frame_factory): # Receive the first request, which causes no problem. f = frame_factory.build_headers_frame( stream_id=1, - headers=self.request_header_block + headers=self.request_header_block, ) data = f.serialize() c.receive_data(data) @@ -881,7 +880,7 @@ def test_reject_headers_when_table_size_shrunk(self, frame_factory): c.clear_outbound_data_buffer() f = frame_factory.build_headers_frame( stream_id=3, - headers=self.request_header_block + headers=self.request_header_block, ) data = f.serialize() c.receive_data(data) @@ -895,7 +894,7 @@ def test_reject_headers_when_table_size_shrunk(self, frame_factory): # a dynamic table size update. f = frame_factory.build_headers_frame( stream_id=5, - headers=self.request_header_block + headers=self.request_header_block, ) data = f.serialize() @@ -903,11 +902,11 @@ def test_reject_headers_when_table_size_shrunk(self, frame_factory): c.receive_data(data) expected_frame = frame_factory.build_goaway_frame( - last_stream_id=3, error_code=h2.errors.ErrorCodes.PROTOCOL_ERROR + last_stream_id=3, error_code=h2.errors.ErrorCodes.PROTOCOL_ERROR, ) assert c.data_to_send() == expected_frame.serialize() - def test_reject_headers_exceeding_table_size(self, frame_factory): + def test_reject_headers_exceeding_table_size(self, frame_factory) -> None: """ When the remote peer sends a dynamic table size update that exceeds our setting, we reject it. @@ -919,7 +918,7 @@ def test_reject_headers_exceeding_table_size(self, frame_factory): # Receive the first request, which causes no problem. f = frame_factory.build_headers_frame( stream_id=1, - headers=self.request_header_block + headers=self.request_header_block, ) data = f.serialize() c.receive_data(data) @@ -929,7 +928,7 @@ def test_reject_headers_exceeding_table_size(self, frame_factory): frame_factory.change_table_size(c.local_settings.header_table_size + 1) f = frame_factory.build_headers_frame( stream_id=5, - headers=self.request_header_block + headers=self.request_header_block, ) data = f.serialize() @@ -937,6 +936,6 @@ def test_reject_headers_exceeding_table_size(self, frame_factory): c.receive_data(data) expected_frame = frame_factory.build_goaway_frame( - last_stream_id=1, error_code=h2.errors.ErrorCodes.PROTOCOL_ERROR + last_stream_id=1, error_code=h2.errors.ErrorCodes.PROTOCOL_ERROR, ) assert c.data_to_send() == expected_frame.serialize() diff --git a/tests/test_priority.py b/tests/test_priority.py index 086cf68f..761df4ab 100644 --- a/tests/test_priority.py +++ b/tests/test_priority.py @@ -1,10 +1,5 @@ -# -*- coding: utf-8 -*- -""" -test_priority -~~~~~~~~~~~~~ +from __future__ import annotations -Test the priority logic of Hyper-h2. -""" import pytest import h2.config @@ -15,29 +10,30 @@ import h2.stream -class TestPriority(object): +class TestPriority: """ Basic priority tests. """ + example_request_headers = [ - (':authority', 'example.com'), - (':path', '/'), - (':scheme', 'https'), - (':method', 'GET'), + (":authority", "example.com"), + (":path", "/"), + (":scheme", "https"), + (":method", "GET"), ] example_request_headers_bytes = [ - (b':authority', b'example.com'), - (b':path', b'/'), - (b':scheme', b'https'), - (b':method', b'GET'), + (b":authority", b"example.com"), + (b":path", b"/"), + (b":scheme", b"https"), + (b":method", b"GET"), ] example_response_headers = [ - (':status', '200'), - ('server', 'pytest-h2'), + (":status", "200"), + ("server", "pytest-h2"), ] server_config = h2.config.H2Configuration(client_side=False) - def test_receiving_priority_emits_priority_update(self, frame_factory): + def test_receiving_priority_emits_priority_update(self, frame_factory) -> None: """ Receiving a priority frame emits a PriorityUpdated event. """ @@ -63,7 +59,7 @@ def test_receiving_priority_emits_priority_update(self, frame_factory): assert event.exclusive is False @pytest.mark.parametrize("request_headers", [example_request_headers, example_request_headers_bytes]) - def test_headers_with_priority_info(self, frame_factory, request_headers): + def test_headers_with_priority_info(self, frame_factory, request_headers) -> None: """ Receiving a HEADERS frame with priority information on it emits a PriorityUpdated event. @@ -76,7 +72,7 @@ def test_headers_with_priority_info(self, frame_factory, request_headers): f = frame_factory.build_headers_frame( headers=request_headers, stream_id=3, - flags=['PRIORITY'], + flags=["PRIORITY"], stream_weight=15, depends_on=1, exclusive=True, @@ -94,7 +90,7 @@ def test_headers_with_priority_info(self, frame_factory, request_headers): assert event.exclusive is True @pytest.mark.parametrize("request_headers", [example_request_headers, example_request_headers_bytes]) - def test_streams_may_not_depend_on_themselves(self, frame_factory, request_headers): + def test_streams_may_not_depend_on_themselves(self, frame_factory, request_headers) -> None: """ A stream adjusted to depend on itself causes a Protocol Error. """ @@ -106,7 +102,7 @@ def test_streams_may_not_depend_on_themselves(self, frame_factory, request_heade f = frame_factory.build_headers_frame( headers=request_headers, stream_id=3, - flags=['PRIORITY'], + flags=["PRIORITY"], stream_weight=15, depends_on=1, exclusive=True, @@ -117,7 +113,7 @@ def test_streams_may_not_depend_on_themselves(self, frame_factory, request_heade f = frame_factory.build_priority_frame( stream_id=3, depends_on=3, - weight=15 + weight=15, ) with pytest.raises(h2.exceptions.ProtocolError): c.receive_data(f.serialize()) @@ -130,15 +126,15 @@ def test_streams_may_not_depend_on_themselves(self, frame_factory, request_heade @pytest.mark.parametrize("request_headers", [example_request_headers, example_request_headers_bytes]) @pytest.mark.parametrize( - 'depends_on,weight,exclusive', + ("depends_on", "weight", "exclusive"), [ (0, 256, False), (3, 128, False), (3, 128, True), - ] + ], ) def test_can_prioritize_stream(self, depends_on, weight, exclusive, - frame_factory, request_headers): + frame_factory, request_headers) -> None: """ hyper-h2 can emit priority frames. """ @@ -153,7 +149,7 @@ def test_can_prioritize_stream(self, depends_on, weight, exclusive, stream_id=1, depends_on=depends_on, weight=weight, - exclusive=exclusive + exclusive=exclusive, ) f = frame_factory.build_priority_frame( @@ -166,15 +162,15 @@ def test_can_prioritize_stream(self, depends_on, weight, exclusive, @pytest.mark.parametrize("request_headers", [example_request_headers, example_request_headers_bytes]) @pytest.mark.parametrize( - 'depends_on,weight,exclusive', + ("depends_on", "weight", "exclusive"), [ (0, 256, False), (1, 128, False), (1, 128, True), - ] + ], ) def test_emit_headers_with_priority_info(self, depends_on, weight, - exclusive, frame_factory, request_headers): + exclusive, frame_factory, request_headers) -> None: """ It is possible to send a headers frame with priority information on it. @@ -194,7 +190,7 @@ def test_emit_headers_with_priority_info(self, depends_on, weight, f = frame_factory.build_headers_frame( headers=request_headers, stream_id=3, - flags=['PRIORITY'], + flags=["PRIORITY"], stream_weight=weight - 1, depends_on=depends_on, exclusive=exclusive, @@ -202,7 +198,7 @@ def test_emit_headers_with_priority_info(self, depends_on, weight, assert c.data_to_send() == f.serialize() @pytest.mark.parametrize("request_headers", [example_request_headers, example_request_headers_bytes]) - def test_may_not_prioritize_stream_to_depend_on_self(self, frame_factory, request_headers): + def test_may_not_prioritize_stream_to_depend_on_self(self, frame_factory, request_headers) -> None: """ A stream adjusted to depend on itself causes a Protocol Error. """ @@ -227,7 +223,7 @@ def test_may_not_prioritize_stream_to_depend_on_self(self, frame_factory, reques assert not c.data_to_send() @pytest.mark.parametrize("request_headers", [example_request_headers, example_request_headers_bytes]) - def test_may_not_initially_set_stream_depend_on_self(self, frame_factory, request_headers): + def test_may_not_initially_set_stream_depend_on_self(self, frame_factory, request_headers) -> None: """ A stream that starts by depending on itself causes a Protocol Error. """ @@ -245,8 +241,8 @@ def test_may_not_initially_set_stream_depend_on_self(self, frame_factory, reques assert not c.data_to_send() - @pytest.mark.parametrize('weight', [0, -15, 257]) - def test_prioritize_requires_valid_weight(self, weight): + @pytest.mark.parametrize("weight", [0, -15, 257]) + def test_prioritize_requires_valid_weight(self, weight) -> None: """ A call to prioritize with an invalid weight causes a ProtocolError. """ @@ -260,8 +256,8 @@ def test_prioritize_requires_valid_weight(self, weight): assert not c.data_to_send() @pytest.mark.parametrize("request_headers", [example_request_headers, example_request_headers_bytes]) - @pytest.mark.parametrize('weight', [0, -15, 257]) - def test_send_headers_requires_valid_weight(self, weight, request_headers): + @pytest.mark.parametrize("weight", [0, -15, 257]) + def test_send_headers_requires_valid_weight(self, weight, request_headers) -> None: """ A call to send_headers with an invalid weight causes a ProtocolError. """ @@ -273,12 +269,12 @@ def test_send_headers_requires_valid_weight(self, weight, request_headers): c.send_headers( stream_id=1, headers=request_headers, - priority_weight=weight + priority_weight=weight, ) assert not c.data_to_send() - def test_prioritize_defaults(self, frame_factory): + def test_prioritize_defaults(self, frame_factory) -> None: """ When prioritize() is called with no explicit arguments, it emits a weight of 16, depending on stream zero non-exclusively. @@ -299,14 +295,14 @@ def test_prioritize_defaults(self, frame_factory): @pytest.mark.parametrize("request_headers", [example_request_headers, example_request_headers_bytes]) @pytest.mark.parametrize( - 'priority_kwargs', + "priority_kwargs", [ - {'priority_weight': 16}, - {'priority_depends_on': 0}, - {'priority_exclusive': False}, - ] + {"priority_weight": 16}, + {"priority_depends_on": 0}, + {"priority_exclusive": False}, + ], ) - def test_send_headers_defaults(self, priority_kwargs, frame_factory, request_headers): + def test_send_headers_defaults(self, priority_kwargs, frame_factory, request_headers) -> None: """ When send_headers() is called with only one explicit argument, it emits default values for everything else. @@ -318,13 +314,13 @@ def test_send_headers_defaults(self, priority_kwargs, frame_factory, request_hea c.send_headers( stream_id=1, headers=request_headers, - **priority_kwargs + **priority_kwargs, ) f = frame_factory.build_headers_frame( headers=request_headers, stream_id=1, - flags=['PRIORITY'], + flags=["PRIORITY"], stream_weight=15, depends_on=0, exclusive=False, @@ -332,7 +328,7 @@ def test_send_headers_defaults(self, priority_kwargs, frame_factory, request_hea assert c.data_to_send() == f.serialize() @pytest.mark.parametrize("request_headers", [example_request_headers, example_request_headers_bytes]) - def test_servers_cannot_prioritize(self, frame_factory, request_headers): + def test_servers_cannot_prioritize(self, frame_factory, request_headers) -> None: """ Server stacks are not allowed to call ``prioritize()``. """ @@ -351,7 +347,7 @@ def test_servers_cannot_prioritize(self, frame_factory, request_headers): c.prioritize(stream_id=1) @pytest.mark.parametrize("request_headers", [example_request_headers, example_request_headers_bytes]) - def test_servers_cannot_prioritize_with_headers(self, frame_factory, request_headers): + def test_servers_cannot_prioritize_with_headers(self, frame_factory, request_headers) -> None: """ Server stacks are not allowed to prioritize on headers either. """ diff --git a/tests/test_related_events.py b/tests/test_related_events.py index 19334408..aed7c187 100644 --- a/tests/test_related_events.py +++ b/tests/test_related_events.py @@ -1,11 +1,9 @@ -# -*- coding: utf-8 -*- """ -test_related_events.py -~~~~~~~~~~~~~~~~~~~~~~ - Specific tests to validate the "related events" logic used by certain events inside hyper-h2. """ +from __future__ import annotations + import pytest import h2.config @@ -13,42 +11,43 @@ import h2.events -class TestRelatedEvents(object): +class TestRelatedEvents: """ Related events correlate all those events that happen on a single frame. """ + example_request_headers = [ - (':authority', 'example.com'), - (':path', '/'), - (':scheme', 'https'), - (':method', 'GET'), + (":authority", "example.com"), + (":path", "/"), + (":scheme", "https"), + (":method", "GET"), ] example_request_headers_bytes = [ - (b':authority', b'example.com'), - (b':path', b'/'), - (b':scheme', b'https'), - (b':method', b'GET'), + (b":authority", b"example.com"), + (b":path", b"/"), + (b":scheme", b"https"), + (b":method", b"GET"), ] example_response_headers = [ - (':status', '200'), - ('server', 'fake-serv/0.1.0') + (":status", "200"), + ("server", "fake-serv/0.1.0"), ] informational_response_headers = [ - (':status', '100'), - ('server', 'fake-serv/0.1.0') + (":status", "100"), + ("server", "fake-serv/0.1.0"), ] example_trailers = [ - ('another', 'field'), + ("another", "field"), ] server_config = h2.config.H2Configuration(client_side=False) @pytest.mark.parametrize("request_headers", [example_request_headers, example_request_headers_bytes]) - def test_request_received_related_all(self, frame_factory, request_headers): + def test_request_received_related_all(self, frame_factory, request_headers) -> None: """ RequestReceived has two possible related events: PriorityUpdated and StreamEnded, all fired when a single HEADERS frame is received. @@ -59,7 +58,7 @@ def test_request_received_related_all(self, frame_factory, request_headers): input_frame = frame_factory.build_headers_frame( headers=request_headers, - flags=['END_STREAM', 'PRIORITY'], + flags=["END_STREAM", "PRIORITY"], stream_weight=15, depends_on=0, exclusive=False, @@ -74,11 +73,11 @@ def test_request_received_related_all(self, frame_factory, request_headers): assert isinstance(base_event.stream_ended, h2.events.StreamEnded) assert base_event.priority_updated in other_events assert isinstance( - base_event.priority_updated, h2.events.PriorityUpdated + base_event.priority_updated, h2.events.PriorityUpdated, ) @pytest.mark.parametrize("request_headers", [example_request_headers, example_request_headers_bytes]) - def test_request_received_related_priority(self, frame_factory, request_headers): + def test_request_received_related_priority(self, frame_factory, request_headers) -> None: """ RequestReceived can be related to PriorityUpdated. """ @@ -88,7 +87,7 @@ def test_request_received_related_priority(self, frame_factory, request_headers) input_frame = frame_factory.build_headers_frame( headers=request_headers, - flags=['PRIORITY'], + flags=["PRIORITY"], stream_weight=15, depends_on=0, exclusive=False, @@ -102,11 +101,11 @@ def test_request_received_related_priority(self, frame_factory, request_headers) assert base_event.priority_updated is priority_updated_event assert base_event.stream_ended is None assert isinstance( - base_event.priority_updated, h2.events.PriorityUpdated + base_event.priority_updated, h2.events.PriorityUpdated, ) @pytest.mark.parametrize("request_headers", [example_request_headers, example_request_headers_bytes]) - def test_request_received_related_stream_ended(self, frame_factory, request_headers): + def test_request_received_related_stream_ended(self, frame_factory, request_headers) -> None: """ RequestReceived can be related to StreamEnded. """ @@ -116,7 +115,7 @@ def test_request_received_related_stream_ended(self, frame_factory, request_head input_frame = frame_factory.build_headers_frame( headers=request_headers, - flags=['END_STREAM'], + flags=["END_STREAM"], ) events = c.receive_data(input_frame.serialize()) @@ -129,7 +128,7 @@ def test_request_received_related_stream_ended(self, frame_factory, request_head assert isinstance(base_event.stream_ended, h2.events.StreamEnded) @pytest.mark.parametrize("request_headers", [example_request_headers, example_request_headers_bytes]) - def test_response_received_related_nothing(self, frame_factory, request_headers): + def test_response_received_related_nothing(self, frame_factory, request_headers) -> None: """ ResponseReceived is ordinarily related to no events. """ @@ -149,7 +148,7 @@ def test_response_received_related_nothing(self, frame_factory, request_headers) assert base_event.priority_updated is None @pytest.mark.parametrize("request_headers", [example_request_headers, example_request_headers_bytes]) - def test_response_received_related_all(self, frame_factory, request_headers): + def test_response_received_related_all(self, frame_factory, request_headers) -> None: """ ResponseReceived has two possible related events: PriorityUpdated and StreamEnded, all fired when a single HEADERS frame is received. @@ -160,7 +159,7 @@ def test_response_received_related_all(self, frame_factory, request_headers): input_frame = frame_factory.build_headers_frame( headers=self.example_response_headers, - flags=['END_STREAM', 'PRIORITY'], + flags=["END_STREAM", "PRIORITY"], stream_weight=15, depends_on=0, exclusive=False, @@ -175,11 +174,11 @@ def test_response_received_related_all(self, frame_factory, request_headers): assert isinstance(base_event.stream_ended, h2.events.StreamEnded) assert base_event.priority_updated in other_events assert isinstance( - base_event.priority_updated, h2.events.PriorityUpdated + base_event.priority_updated, h2.events.PriorityUpdated, ) @pytest.mark.parametrize("request_headers", [example_request_headers, example_request_headers_bytes]) - def test_response_received_related_priority(self, frame_factory, request_headers): + def test_response_received_related_priority(self, frame_factory, request_headers) -> None: """ ResponseReceived can be related to PriorityUpdated. """ @@ -189,7 +188,7 @@ def test_response_received_related_priority(self, frame_factory, request_headers input_frame = frame_factory.build_headers_frame( headers=self.example_response_headers, - flags=['PRIORITY'], + flags=["PRIORITY"], stream_weight=15, depends_on=0, exclusive=False, @@ -203,11 +202,11 @@ def test_response_received_related_priority(self, frame_factory, request_headers assert base_event.priority_updated is priority_updated_event assert base_event.stream_ended is None assert isinstance( - base_event.priority_updated, h2.events.PriorityUpdated + base_event.priority_updated, h2.events.PriorityUpdated, ) @pytest.mark.parametrize("request_headers", [example_request_headers, example_request_headers_bytes]) - def test_response_received_related_stream_ended(self, frame_factory, request_headers): + def test_response_received_related_stream_ended(self, frame_factory, request_headers) -> None: """ ResponseReceived can be related to StreamEnded. """ @@ -217,7 +216,7 @@ def test_response_received_related_stream_ended(self, frame_factory, request_hea input_frame = frame_factory.build_headers_frame( headers=self.example_response_headers, - flags=['END_STREAM'], + flags=["END_STREAM"], ) events = c.receive_data(input_frame.serialize()) @@ -230,7 +229,7 @@ def test_response_received_related_stream_ended(self, frame_factory, request_hea assert isinstance(base_event.stream_ended, h2.events.StreamEnded) @pytest.mark.parametrize("request_headers", [example_request_headers, example_request_headers_bytes]) - def test_trailers_received_related_all(self, frame_factory, request_headers): + def test_trailers_received_related_all(self, frame_factory, request_headers) -> None: """ TrailersReceived has two possible related events: PriorityUpdated and StreamEnded, all fired when a single HEADERS frame is received. @@ -246,7 +245,7 @@ def test_trailers_received_related_all(self, frame_factory, request_headers): input_frame = frame_factory.build_headers_frame( headers=self.example_trailers, - flags=['END_STREAM', 'PRIORITY'], + flags=["END_STREAM", "PRIORITY"], stream_weight=15, depends_on=0, exclusive=False, @@ -261,11 +260,11 @@ def test_trailers_received_related_all(self, frame_factory, request_headers): assert isinstance(base_event.stream_ended, h2.events.StreamEnded) assert base_event.priority_updated in other_events assert isinstance( - base_event.priority_updated, h2.events.PriorityUpdated + base_event.priority_updated, h2.events.PriorityUpdated, ) @pytest.mark.parametrize("request_headers", [example_request_headers, example_request_headers_bytes]) - def test_trailers_received_related_stream_ended(self, frame_factory, request_headers): + def test_trailers_received_related_stream_ended(self, frame_factory, request_headers) -> None: """ TrailersReceived can be related to StreamEnded by itself. """ @@ -280,7 +279,7 @@ def test_trailers_received_related_stream_ended(self, frame_factory, request_hea input_frame = frame_factory.build_headers_frame( headers=self.example_trailers, - flags=['END_STREAM'], + flags=["END_STREAM"], ) events = c.receive_data(input_frame.serialize()) @@ -293,7 +292,7 @@ def test_trailers_received_related_stream_ended(self, frame_factory, request_hea assert isinstance(base_event.stream_ended, h2.events.StreamEnded) @pytest.mark.parametrize("request_headers", [example_request_headers, example_request_headers_bytes]) - def test_informational_response_related_nothing(self, frame_factory, request_headers): + def test_informational_response_related_nothing(self, frame_factory, request_headers) -> None: """ InformationalResponseReceived in the standard case is related to nothing. @@ -313,7 +312,7 @@ def test_informational_response_related_nothing(self, frame_factory, request_hea assert base_event.priority_updated is None @pytest.mark.parametrize("request_headers", [example_request_headers, example_request_headers_bytes]) - def test_informational_response_received_related_all(self, frame_factory, request_headers): + def test_informational_response_received_related_all(self, frame_factory, request_headers) -> None: """ InformationalResponseReceived has one possible related event: PriorityUpdated, fired when a single HEADERS frame is received. @@ -324,7 +323,7 @@ def test_informational_response_received_related_all(self, frame_factory, reques input_frame = frame_factory.build_headers_frame( headers=self.informational_response_headers, - flags=['PRIORITY'], + flags=["PRIORITY"], stream_weight=15, depends_on=0, exclusive=False, @@ -337,11 +336,11 @@ def test_informational_response_received_related_all(self, frame_factory, reques assert base_event.priority_updated is priority_updated_event assert isinstance( - base_event.priority_updated, h2.events.PriorityUpdated + base_event.priority_updated, h2.events.PriorityUpdated, ) @pytest.mark.parametrize("request_headers", [example_request_headers, example_request_headers_bytes]) - def test_data_received_normally_relates_to_nothing(self, frame_factory, request_headers): + def test_data_received_normally_relates_to_nothing(self, frame_factory, request_headers) -> None: """ A plain DATA frame leads to DataReceieved with no related events. """ @@ -355,7 +354,7 @@ def test_data_received_normally_relates_to_nothing(self, frame_factory, request_ c.receive_data(f.serialize()) input_frame = frame_factory.build_data_frame( - data=b'some data', + data=b"some data", ) events = c.receive_data(input_frame.serialize()) @@ -365,7 +364,7 @@ def test_data_received_normally_relates_to_nothing(self, frame_factory, request_ assert base_event.stream_ended is None @pytest.mark.parametrize("request_headers", [example_request_headers, example_request_headers_bytes]) - def test_data_received_related_stream_ended(self, frame_factory, request_headers): + def test_data_received_related_stream_ended(self, frame_factory, request_headers) -> None: """ DataReceived can be related to StreamEnded by itself. """ @@ -379,8 +378,8 @@ def test_data_received_related_stream_ended(self, frame_factory, request_headers c.receive_data(f.serialize()) input_frame = frame_factory.build_data_frame( - data=b'some data', - flags=['END_STREAM'], + data=b"some data", + flags=["END_STREAM"], ) events = c.receive_data(input_frame.serialize()) diff --git a/tests/test_rfc7838.py b/tests/test_rfc7838.py index 2396a3f3..63a4d2cc 100644 --- a/tests/test_rfc7838.py +++ b/tests/test_rfc7838.py @@ -1,10 +1,8 @@ -# -*- coding: utf-8 -*- """ -test_rfc7838 -~~~~~~~~~~~~ - Test the RFC 7838 ALTSVC support. """ +from __future__ import annotations + import pytest import h2.config @@ -13,28 +11,29 @@ import h2.exceptions -class TestRFC7838Client(object): +class TestRFC7838Client: """ Tests that the client supports receiving the RFC 7838 AltSvc frame. """ + example_request_headers = [ - (':authority', 'example.com'), - (':path', '/'), - (':scheme', 'https'), - (':method', 'GET'), + (":authority", "example.com"), + (":path", "/"), + (":scheme", "https"), + (":method", "GET"), ] example_request_headers_bytes = [ - (b':authority', b'example.com'), - (b':path', b'/'), - (b':scheme', b'https'), - (b':method', b'GET'), + (b":authority", b"example.com"), + (b":path", b"/"), + (b":scheme", b"https"), + (b":method", b"GET"), ] example_response_headers = [ - (':status', '200'), - ('server', 'fake-serv/0.1.0') + (":status", "200"), + ("server", "fake-serv/0.1.0"), ] - def test_receiving_altsvc_stream_zero(self, frame_factory): + def test_receiving_altsvc_stream_zero(self, frame_factory) -> None: """ An ALTSVC frame received on stream zero correctly transposes all the fields from the frames. @@ -44,7 +43,7 @@ def test_receiving_altsvc_stream_zero(self, frame_factory): c.clear_outbound_data_buffer() f = frame_factory.build_alt_svc_frame( - stream_id=0, origin=b"example.com", field=b'h2=":8000"; ma=60' + stream_id=0, origin=b"example.com", field=b'h2=":8000"; ma=60', ) events = c.receive_data(f.serialize()) @@ -58,7 +57,7 @@ def test_receiving_altsvc_stream_zero(self, frame_factory): # No data gets sent. assert not c.data_to_send() - def test_receiving_altsvc_stream_zero_no_origin(self, frame_factory): + def test_receiving_altsvc_stream_zero_no_origin(self, frame_factory) -> None: """ An ALTSVC frame received on stream zero without an origin field is ignored. @@ -68,7 +67,7 @@ def test_receiving_altsvc_stream_zero_no_origin(self, frame_factory): c.clear_outbound_data_buffer() f = frame_factory.build_alt_svc_frame( - stream_id=0, origin=b"", field=b'h2=":8000"; ma=60' + stream_id=0, origin=b"", field=b'h2=":8000"; ma=60', ) events = c.receive_data(f.serialize()) @@ -76,7 +75,7 @@ def test_receiving_altsvc_stream_zero_no_origin(self, frame_factory): assert not c.data_to_send() @pytest.mark.parametrize("request_headers", [example_request_headers, example_request_headers_bytes]) - def test_receiving_altsvc_on_stream(self, frame_factory, request_headers): + def test_receiving_altsvc_on_stream(self, frame_factory, request_headers) -> None: """ An ALTSVC frame received on a stream correctly transposes all the fields from the frame and attaches the expected origin. @@ -87,7 +86,7 @@ def test_receiving_altsvc_on_stream(self, frame_factory, request_headers): c.clear_outbound_data_buffer() f = frame_factory.build_alt_svc_frame( - stream_id=1, origin=b"", field=b'h2=":8000"; ma=60' + stream_id=1, origin=b"", field=b'h2=":8000"; ma=60', ) events = c.receive_data(f.serialize()) @@ -102,7 +101,7 @@ def test_receiving_altsvc_on_stream(self, frame_factory, request_headers): assert not c.data_to_send() @pytest.mark.parametrize("request_headers", [example_request_headers, example_request_headers_bytes]) - def test_receiving_altsvc_on_stream_with_origin(self, frame_factory, request_headers): + def test_receiving_altsvc_on_stream_with_origin(self, frame_factory, request_headers) -> None: """ An ALTSVC frame received on a stream with an origin field present gets ignored. @@ -113,14 +112,14 @@ def test_receiving_altsvc_on_stream_with_origin(self, frame_factory, request_hea c.clear_outbound_data_buffer() f = frame_factory.build_alt_svc_frame( - stream_id=1, origin=b"example.com", field=b'h2=":8000"; ma=60' + stream_id=1, origin=b"example.com", field=b'h2=":8000"; ma=60', ) events = c.receive_data(f.serialize()) assert len(events) == 0 assert not c.data_to_send() - def test_receiving_altsvc_on_stream_not_yet_opened(self, frame_factory): + def test_receiving_altsvc_on_stream_not_yet_opened(self, frame_factory) -> None: """ When an ALTSVC frame is received on a stream the client hasn't yet opened, the frame is ignored. @@ -132,17 +131,17 @@ def test_receiving_altsvc_on_stream_not_yet_opened(self, frame_factory): # We'll test this twice, once on a client-initiated stream ID and once # on a server initiated one. f1 = frame_factory.build_alt_svc_frame( - stream_id=1, origin=b"", field=b'h2=":8000"; ma=60' + stream_id=1, origin=b"", field=b'h2=":8000"; ma=60', ) f2 = frame_factory.build_alt_svc_frame( - stream_id=2, origin=b"", field=b'h2=":8000"; ma=60' + stream_id=2, origin=b"", field=b'h2=":8000"; ma=60', ) events = c.receive_data(f1.serialize() + f2.serialize()) assert len(events) == 0 assert not c.data_to_send() - def test_receiving_altsvc_before_sending_headers(self, frame_factory): + def test_receiving_altsvc_before_sending_headers(self, frame_factory) -> None: """ When an ALTSVC frame is received but the client hasn't sent headers yet it gets ignored. @@ -155,12 +154,12 @@ def test_receiving_altsvc_before_sending_headers(self, frame_factory): # don't currently have a mechanism by which this could occur), it could # happen in the future and we defend against it. c._begin_new_stream( - stream_id=1, allowed_ids=h2.connection.AllowedStreamIDs.ODD + stream_id=1, allowed_ids=h2.connection.AllowedStreamIDs.ODD, ) c.clear_outbound_data_buffer() f = frame_factory.build_alt_svc_frame( - stream_id=1, origin=b"", field=b'h2=":8000"; ma=60' + stream_id=1, origin=b"", field=b'h2=":8000"; ma=60', ) events = c.receive_data(f.serialize()) @@ -168,7 +167,7 @@ def test_receiving_altsvc_before_sending_headers(self, frame_factory): assert not c.data_to_send() @pytest.mark.parametrize("request_headers", [example_request_headers, example_request_headers_bytes]) - def test_receiving_altsvc_after_receiving_headers(self, frame_factory, request_headers): + def test_receiving_altsvc_after_receiving_headers(self, frame_factory, request_headers) -> None: """ When an ALTSVC frame is received but the server has already sent headers it gets ignored. @@ -178,13 +177,13 @@ def test_receiving_altsvc_after_receiving_headers(self, frame_factory, request_h c.send_headers(stream_id=1, headers=request_headers) f = frame_factory.build_headers_frame( - headers=self.example_response_headers + headers=self.example_response_headers, ) c.receive_data(f.serialize()) c.clear_outbound_data_buffer() f = frame_factory.build_alt_svc_frame( - stream_id=1, origin=b"", field=b'h2=":8000"; ma=60' + stream_id=1, origin=b"", field=b'h2=":8000"; ma=60', ) events = c.receive_data(f.serialize()) @@ -192,25 +191,25 @@ def test_receiving_altsvc_after_receiving_headers(self, frame_factory, request_h assert not c.data_to_send() @pytest.mark.parametrize("request_headers", [example_request_headers, example_request_headers_bytes]) - def test_receiving_altsvc_on_closed_stream(self, frame_factory, request_headers): + def test_receiving_altsvc_on_closed_stream(self, frame_factory, request_headers) -> None: """ When an ALTSVC frame is received on a closed stream, we ignore it. """ c = h2.connection.H2Connection() c.initiate_connection() c.send_headers( - stream_id=1, headers=request_headers, end_stream=True + stream_id=1, headers=request_headers, end_stream=True, ) f = frame_factory.build_headers_frame( headers=self.example_response_headers, - flags=['END_STREAM'], + flags=["END_STREAM"], ) c.receive_data(f.serialize()) c.clear_outbound_data_buffer() f = frame_factory.build_alt_svc_frame( - stream_id=1, origin=b"", field=b'h2=":8000"; ma=60' + stream_id=1, origin=b"", field=b'h2=":8000"; ma=60', ) events = c.receive_data(f.serialize()) @@ -218,7 +217,7 @@ def test_receiving_altsvc_on_closed_stream(self, frame_factory, request_headers) assert not c.data_to_send() @pytest.mark.parametrize("request_headers", [example_request_headers, example_request_headers_bytes]) - def test_receiving_altsvc_on_pushed_stream(self, frame_factory, request_headers): + def test_receiving_altsvc_on_pushed_stream(self, frame_factory, request_headers) -> None: """ When an ALTSVC frame is received on a stream that the server pushed, the frame is accepted. @@ -230,13 +229,13 @@ def test_receiving_altsvc_on_pushed_stream(self, frame_factory, request_headers) f = frame_factory.build_push_promise_frame( stream_id=1, promised_stream_id=2, - headers=request_headers + headers=request_headers, ) c.receive_data(f.serialize()) c.clear_outbound_data_buffer() f = frame_factory.build_alt_svc_frame( - stream_id=2, origin=b"", field=b'h2=":8000"; ma=60' + stream_id=2, origin=b"", field=b'h2=":8000"; ma=60', ) events = c.receive_data(f.serialize()) @@ -251,7 +250,7 @@ def test_receiving_altsvc_on_pushed_stream(self, frame_factory, request_headers) assert not c.data_to_send() @pytest.mark.parametrize("request_headers", [example_request_headers, example_request_headers_bytes]) - def test_cannot_send_explicit_alternative_service(self, frame_factory, request_headers): + def test_cannot_send_explicit_alternative_service(self, frame_factory, request_headers) -> None: """ A client cannot send an explicit alternative service. """ @@ -267,7 +266,7 @@ def test_cannot_send_explicit_alternative_service(self, frame_factory, request_h ) @pytest.mark.parametrize("request_headers", [example_request_headers, example_request_headers_bytes]) - def test_cannot_send_implicit_alternative_service(self, frame_factory, request_headers): + def test_cannot_send_implicit_alternative_service(self, frame_factory, request_headers) -> None: """ A client cannot send an implicit alternative service. """ @@ -283,30 +282,31 @@ def test_cannot_send_implicit_alternative_service(self, frame_factory, request_h ) -class TestRFC7838Server(object): +class TestRFC7838Server: """ Tests that the server supports sending the RFC 7838 AltSvc frame. """ + example_request_headers = [ - (':authority', 'example.com'), - (':path', '/'), - (':scheme', 'https'), - (':method', 'GET'), + (":authority", "example.com"), + (":path", "/"), + (":scheme", "https"), + (":method", "GET"), ] example_request_headers_bytes = [ - (b':authority', b'example.com'), - (b':path', b'/'), - (b':scheme', b'https'), - (b':method', b'GET'), + (b":authority", b"example.com"), + (b":path", b"/"), + (b":scheme", b"https"), + (b":method", b"GET"), ] example_response_headers = [ - (u':status', u'200'), - (u'server', u'fake-serv/0.1.0') + (":status", "200"), + ("server", "fake-serv/0.1.0"), ] server_config = h2.config.H2Configuration(client_side=False) - def test_receiving_altsvc_as_server_stream_zero(self, frame_factory): + def test_receiving_altsvc_as_server_stream_zero(self, frame_factory) -> None: """ When an ALTSVC frame is received on stream zero and we are a server, we ignore it. @@ -317,7 +317,7 @@ def test_receiving_altsvc_as_server_stream_zero(self, frame_factory): c.clear_outbound_data_buffer() f = frame_factory.build_alt_svc_frame( - stream_id=0, origin=b"example.com", field=b'h2=":8000"; ma=60' + stream_id=0, origin=b"example.com", field=b'h2=":8000"; ma=60', ) events = c.receive_data(f.serialize()) @@ -325,7 +325,7 @@ def test_receiving_altsvc_as_server_stream_zero(self, frame_factory): assert not c.data_to_send() @pytest.mark.parametrize("request_headers", [example_request_headers, example_request_headers_bytes]) - def test_receiving_altsvc_as_server_on_stream(self, frame_factory, request_headers): + def test_receiving_altsvc_as_server_on_stream(self, frame_factory, request_headers) -> None: """ When an ALTSVC frame is received on a stream and we are a server, we ignore it. @@ -335,20 +335,20 @@ def test_receiving_altsvc_as_server_on_stream(self, frame_factory, request_heade c.receive_data(frame_factory.preamble()) f = frame_factory.build_headers_frame( - headers=request_headers + headers=request_headers, ) c.receive_data(f.serialize()) c.clear_outbound_data_buffer() f = frame_factory.build_alt_svc_frame( - stream_id=1, origin=b"", field=b'h2=":8000"; ma=60' + stream_id=1, origin=b"", field=b'h2=":8000"; ma=60', ) events = c.receive_data(f.serialize()) assert len(events) == 0 assert not c.data_to_send() - def test_sending_explicit_alternative_service(self, frame_factory): + def test_sending_explicit_alternative_service(self, frame_factory) -> None: """ A server can send an explicit alternative service. """ @@ -363,12 +363,12 @@ def test_sending_explicit_alternative_service(self, frame_factory): ) f = frame_factory.build_alt_svc_frame( - stream_id=0, origin=b"example.com", field=b'h2=":8000"; ma=60' + stream_id=0, origin=b"example.com", field=b'h2=":8000"; ma=60', ) assert c.data_to_send() == f.serialize() @pytest.mark.parametrize("request_headers", [example_request_headers, example_request_headers_bytes]) - def test_sending_implicit_alternative_service(self, frame_factory, request_headers): + def test_sending_implicit_alternative_service(self, frame_factory, request_headers) -> None: """ A server can send an implicit alternative service. """ @@ -377,7 +377,7 @@ def test_sending_implicit_alternative_service(self, frame_factory, request_heade c.receive_data(frame_factory.preamble()) f = frame_factory.build_headers_frame( - headers=request_headers + headers=request_headers, ) c.receive_data(f.serialize()) c.clear_outbound_data_buffer() @@ -388,12 +388,12 @@ def test_sending_implicit_alternative_service(self, frame_factory, request_heade ) f = frame_factory.build_alt_svc_frame( - stream_id=1, origin=b"", field=b'h2=":8000"; ma=60' + stream_id=1, origin=b"", field=b'h2=":8000"; ma=60', ) assert c.data_to_send() == f.serialize() def test_no_implicit_alternative_service_before_headers(self, - frame_factory): + frame_factory) -> None: """ If headers haven't been received yet, the server forbids sending an implicit alternative service. @@ -412,7 +412,7 @@ def test_no_implicit_alternative_service_before_headers(self, @pytest.mark.parametrize("request_headers", [example_request_headers, example_request_headers_bytes]) def test_no_implicit_alternative_service_after_response(self, frame_factory, - request_headers): + request_headers) -> None: """ If the server has sent response headers, hyper-h2 forbids sending an implicit alternative service. @@ -422,7 +422,7 @@ def test_no_implicit_alternative_service_after_response(self, c.receive_data(frame_factory.preamble()) f = frame_factory.build_headers_frame( - headers=request_headers + headers=request_headers, ) c.receive_data(f.serialize()) c.send_headers(stream_id=1, headers=self.example_response_headers) @@ -435,7 +435,7 @@ def test_no_implicit_alternative_service_after_response(self, ) @pytest.mark.parametrize("request_headers", [example_request_headers, example_request_headers_bytes]) - def test_cannot_provide_origin_and_stream_id(self, frame_factory, request_headers): + def test_cannot_provide_origin_and_stream_id(self, frame_factory, request_headers) -> None: """ The user cannot provide both the origin and stream_id arguments when advertising alternative services. @@ -444,7 +444,7 @@ def test_cannot_provide_origin_and_stream_id(self, frame_factory, request_header c.initiate_connection() c.receive_data(frame_factory.preamble()) f = frame_factory.build_headers_frame( - headers=request_headers + headers=request_headers, ) c.receive_data(f.serialize()) @@ -455,7 +455,7 @@ def test_cannot_provide_origin_and_stream_id(self, frame_factory, request_header stream_id=1, ) - def test_cannot_provide_unicode_altsvc_field(self, frame_factory): + def test_cannot_provide_unicode_altsvc_field(self, frame_factory) -> None: """ The user cannot provide the field value for alternative services as a unicode string. @@ -466,6 +466,6 @@ def test_cannot_provide_unicode_altsvc_field(self, frame_factory): with pytest.raises(ValueError): c.advertise_alternative_service( - field_value=u'h2=":8000"; ma=60', + field_value='h2=":8000"; ma=60', origin=b"example.com", ) diff --git a/tests/test_rfc8441.py b/tests/test_rfc8441.py index d3bbde40..0f6092c2 100644 --- a/tests/test_rfc8441.py +++ b/tests/test_rfc8441.py @@ -1,50 +1,48 @@ -# -*- coding: utf-8 -*- """ -test_rfc8441 -~~~~~~~~~~~~ - Test the RFC 8441 extended connect request support. """ -from h2.utilities import utf8_encode_headers +from __future__ import annotations + import pytest import h2.config import h2.connection import h2.events +from h2.utilities import utf8_encode_headers -class TestRFC8441(object): +class TestRFC8441: """ Tests that the client supports sending an extended connect request and the server supports receiving it. """ headers = [ - (':authority', 'example.com'), - (':path', '/'), - (':scheme', 'https'), - (':method', 'CONNECT'), - (':protocol', 'websocket'), - ('user-agent', 'someua/0.0.1'), + (":authority", "example.com"), + (":path", "/"), + (":scheme", "https"), + (":method", "CONNECT"), + (":protocol", "websocket"), + ("user-agent", "someua/0.0.1"), ] headers_bytes = [ - (b':authority', b'example.com'), - (b':path', b'/'), - (b':scheme', b'https'), - (b':method', b'CONNECT'), - (b':protocol', b'websocket'), - (b'user-agent', b'someua/0.0.1'), + (b":authority", b"example.com"), + (b":path", b"/"), + (b":scheme", b"https"), + (b":method", b"CONNECT"), + (b":protocol", b"websocket"), + (b"user-agent", b"someua/0.0.1"), ] @pytest.mark.parametrize("headers", [headers, headers_bytes]) - def test_can_send_headers(self, frame_factory, headers): + def test_can_send_headers(self, frame_factory, headers) -> None: client = h2.connection.H2Connection() client.initiate_connection() client.send_headers(stream_id=1, headers=headers) server = h2.connection.H2Connection( - config=h2.config.H2Configuration(client_side=False) + config=h2.config.H2Configuration(client_side=False), ) events = server.receive_data(client.data_to_send()) event = events[1] diff --git a/tests/test_settings.py b/tests/test_settings.py index d19f93a7..89acb90e 100644 --- a/tests/test_settings.py +++ b/tests/test_settings.py @@ -1,27 +1,23 @@ -# -*- coding: utf-8 -*- """ -test_settings -~~~~~~~~~~~~~ - Test the Settings object. """ +from __future__ import annotations + import pytest +from hypothesis import assume, given +from hypothesis.strategies import booleans, builds, fixed_dictionaries, integers import h2.errors import h2.exceptions import h2.settings -from hypothesis import given, assume -from hypothesis.strategies import ( - integers, booleans, fixed_dictionaries, builds -) - -class TestSettings(object): +class TestSettings: """ Test the Settings object behaves as expected. """ - def test_settings_defaults_client(self): + + def test_settings_defaults_client(self) -> None: """ The Settings object begins with the appropriate defaults for clients. """ @@ -33,7 +29,7 @@ def test_settings_defaults_client(self): assert s[h2.settings.SettingCodes.MAX_FRAME_SIZE] == 16384 assert s[h2.settings.SettingCodes.ENABLE_CONNECT_PROTOCOL] == 0 - def test_settings_defaults_server(self): + def test_settings_defaults_server(self) -> None: """ The Settings object begins with the appropriate defaults for servers. """ @@ -45,8 +41,8 @@ def test_settings_defaults_server(self): assert s[h2.settings.SettingCodes.MAX_FRAME_SIZE] == 16384 assert s[h2.settings.SettingCodes.ENABLE_CONNECT_PROTOCOL] == 0 - @pytest.mark.parametrize('client', [True, False]) - def test_can_set_initial_values(self, client): + @pytest.mark.parametrize("client", [True, False]) + def test_can_set_initial_values(self, client) -> None: """ The Settings object can be provided initial values that override the defaults. @@ -69,7 +65,7 @@ def test_can_set_initial_values(self, client): assert s[h2.settings.SettingCodes.ENABLE_CONNECT_PROTOCOL] == 1 @pytest.mark.parametrize( - 'setting,value', + ("setting", "value"), [ (h2.settings.SettingCodes.ENABLE_PUSH, 2), (h2.settings.SettingCodes.ENABLE_PUSH, -1), @@ -79,9 +75,9 @@ def test_can_set_initial_values(self, client): (h2.settings.SettingCodes.MAX_FRAME_SIZE, 2**30), (h2.settings.SettingCodes.MAX_HEADER_LIST_SIZE, -1), (h2.settings.SettingCodes.ENABLE_CONNECT_PROTOCOL, -1), - ] + ], ) - def test_cannot_set_invalid_initial_values(self, setting, value): + def test_cannot_set_invalid_initial_values(self, setting, value) -> None: """ The Settings object can be provided initial values that override the defaults. @@ -91,7 +87,7 @@ def test_cannot_set_invalid_initial_values(self, setting, value): with pytest.raises(h2.exceptions.InvalidSettingsValueError): h2.settings.Settings(initial_values=overrides) - def test_applying_value_doesnt_take_effect_immediately(self): + def test_applying_value_doesnt_take_effect_immediately(self) -> None: """ When a value is applied to the settings object, it doesn't immediately take effect. @@ -101,7 +97,7 @@ def test_applying_value_doesnt_take_effect_immediately(self): assert s[h2.settings.SettingCodes.HEADER_TABLE_SIZE] == 4096 - def test_acknowledging_values(self): + def test_acknowledging_values(self) -> None: """ When we acknowledge settings, the values change. """ @@ -121,7 +117,7 @@ def test_acknowledging_values(self): s.acknowledge() assert dict(s) == new_settings - def test_acknowledging_returns_the_changed_settings(self): + def test_acknowledging_returns_the_changed_settings(self) -> None: """ Acknowledging settings returns the changes. """ @@ -147,7 +143,7 @@ def test_acknowledging_returns_the_changed_settings(self): assert push_change.original_value == 1 assert push_change.new_value == 0 - def test_acknowledging_only_returns_changed_settings(self): + def test_acknowledging_only_returns_changed_settings(self) -> None: """ Acknowledging settings does not return unchanged settings. """ @@ -157,10 +153,10 @@ def test_acknowledging_only_returns_changed_settings(self): changes = s.acknowledge() assert len(changes) == 1 assert list(changes.keys()) == [ - h2.settings.SettingCodes.INITIAL_WINDOW_SIZE + h2.settings.SettingCodes.INITIAL_WINDOW_SIZE, ] - def test_deleting_values_deletes_all_of_them(self): + def test_deleting_values_deletes_all_of_them(self) -> None: """ When we delete a key we lose all state about it. """ @@ -172,7 +168,7 @@ def test_deleting_values_deletes_all_of_them(self): with pytest.raises(KeyError): s[h2.settings.SettingCodes.HEADER_TABLE_SIZE] - def test_length_correctly_reported(self): + def test_length_correctly_reported(self) -> None: """ Length is related only to the number of keys. """ @@ -188,7 +184,7 @@ def test_length_correctly_reported(self): del s[h2.settings.SettingCodes.HEADER_TABLE_SIZE] assert len(s) == 4 - def test_new_values_work(self): + def test_new_values_work(self) -> None: """ New values initially don't appear """ @@ -198,7 +194,7 @@ def test_new_values_work(self): with pytest.raises(KeyError): s[80] - def test_new_values_follow_basic_acknowledgement_rules(self): + def test_new_values_follow_basic_acknowledgement_rules(self) -> None: """ A new value properly appears when acknowledged. """ @@ -214,7 +210,7 @@ def test_new_values_follow_basic_acknowledgement_rules(self): assert changed.original_value is None assert changed.new_value == 81 - def test_single_values_arent_affected_by_acknowledgement(self): + def test_single_values_arent_affected_by_acknowledgement(self) -> None: """ When acknowledged, unchanged settings remain unchanged. """ @@ -224,7 +220,7 @@ def test_single_values_arent_affected_by_acknowledgement(self): s.acknowledge() assert s[h2.settings.SettingCodes.HEADER_TABLE_SIZE] == 4096 - def test_settings_getters(self): + def test_settings_getters(self) -> None: """ Getters exist for well-known settings. """ @@ -244,7 +240,7 @@ def test_settings_getters(self): h2.settings.SettingCodes.ENABLE_CONNECT_PROTOCOL ] - def test_settings_setters(self): + def test_settings_setters(self) -> None: """ Setters exist for well-known settings. """ @@ -268,7 +264,7 @@ def test_settings_setters(self): assert s[h2.settings.SettingCodes.ENABLE_CONNECT_PROTOCOL] == 1 @given(integers()) - def test_cannot_set_invalid_values_for_enable_push(self, val): + def test_cannot_set_invalid_values_for_enable_push(self, val) -> None: """ SETTINGS_ENABLE_PUSH only allows two values: 0, 1. """ @@ -290,7 +286,7 @@ def test_cannot_set_invalid_values_for_enable_push(self, val): assert s[h2.settings.SettingCodes.ENABLE_PUSH] == 1 @given(integers()) - def test_cannot_set_invalid_vals_for_initial_window_size(self, val): + def test_cannot_set_invalid_vals_for_initial_window_size(self, val) -> None: """ SETTINGS_INITIAL_WINDOW_SIZE only allows values between 0 and 2**32 - 1 inclusive. @@ -321,7 +317,7 @@ def test_cannot_set_invalid_vals_for_initial_window_size(self, val): assert s[h2.settings.SettingCodes.INITIAL_WINDOW_SIZE] == 65535 @given(integers()) - def test_cannot_set_invalid_values_for_max_frame_size(self, val): + def test_cannot_set_invalid_values_for_max_frame_size(self, val) -> None: """ SETTINGS_MAX_FRAME_SIZE only allows values between 2**14 and 2**24 - 1. """ @@ -347,7 +343,7 @@ def test_cannot_set_invalid_values_for_max_frame_size(self, val): assert s[h2.settings.SettingCodes.MAX_FRAME_SIZE] == 16384 @given(integers()) - def test_cannot_set_invalid_values_for_max_header_list_size(self, val): + def test_cannot_set_invalid_values_for_max_header_list_size(self, val) -> None: """ SETTINGS_MAX_HEADER_LIST_SIZE only allows non-negative values. """ @@ -375,7 +371,7 @@ def test_cannot_set_invalid_values_for_max_header_list_size(self, val): s[h2.settings.SettingCodes.MAX_HEADER_LIST_SIZE] @given(integers()) - def test_cannot_set_invalid_values_for_enable_connect_protocol(self, val): + def test_cannot_set_invalid_values_for_enable_connect_protocol(self, val) -> None: """ SETTINGS_ENABLE_CONNECT_PROTOCOL only allows two values: 0, 1. """ @@ -397,7 +393,7 @@ def test_cannot_set_invalid_values_for_enable_connect_protocol(self, val): assert s[h2.settings.SettingCodes.ENABLE_CONNECT_PROTOCOL] == 0 -class TestSettingsEquality(object): +class TestSettingsEquality: """ A class defining tests for the standard implementation of == and != . """ @@ -417,48 +413,48 @@ class TestSettingsEquality(object): integers(0, 2**32 - 1), h2.settings.SettingCodes.MAX_HEADER_LIST_SIZE: integers(0, 2**32 - 1), - }) + }), ) @given(settings=SettingsStrategy) - def test_equality_reflexive(self, settings): + def test_equality_reflexive(self, settings) -> None: """ An object compares equal to itself using the == operator and the != operator. """ assert (settings == settings) - assert not (settings != settings) + assert settings == settings @given(settings=SettingsStrategy, o_settings=SettingsStrategy) - def test_equality_multiple(self, settings, o_settings): + def test_equality_multiple(self, settings, o_settings) -> None: """ Two objects compare themselves using the == operator and the != operator. """ if settings == o_settings: assert settings == o_settings - assert not (settings != o_settings) + assert settings == o_settings else: assert settings != o_settings - assert not (settings == o_settings) + assert settings != o_settings @given(settings=SettingsStrategy) - def test_another_type_equality(self, settings): + def test_another_type_equality(self, settings) -> None: """ The object does not compare equal to an object of an unrelated type (which does not implement the comparison) using the == operator. """ obj = object() assert (settings != obj) - assert not (settings == obj) + assert settings != obj @given(settings=SettingsStrategy) - def test_delegated_eq(self, settings): + def test_delegated_eq(self, settings) -> None: """ The result of comparison is delegated to the right-hand operand if it is of an unrelated type. """ - class Delegate(object): + class Delegate: def __eq__(self, other): return [self] diff --git a/tests/test_state_machines.py b/tests/test_state_machines.py index 034ae909..b92ef597 100644 --- a/tests/test_state_machines.py +++ b/tests/test_state_machines.py @@ -1,29 +1,27 @@ -# -*- coding: utf-8 -*- """ -test_state_machines -~~~~~~~~~~~~~~~~~~~ - These tests validate the state machines directly. Writing meaningful tests for this case can be tricky, so the majority of these tests use Hypothesis to try to talk about general behaviours rather than specific cases. """ +from __future__ import annotations + import pytest +from hypothesis import given +from hypothesis.strategies import sampled_from import h2.connection import h2.exceptions import h2.stream -from hypothesis import given -from hypothesis.strategies import sampled_from - -class TestConnectionStateMachine(object): +class TestConnectionStateMachine: """ Tests of the connection state machine. """ + @given(state=sampled_from(h2.connection.ConnectionState), input_=sampled_from(h2.connection.ConnectionInputs)) - def test_state_transitions(self, state, input_): + def test_state_transitions(self, state, input_) -> None: c = h2.connection.H2ConnectionStateMachine() c.state = state @@ -34,7 +32,7 @@ def test_state_transitions(self, state, input_): else: assert c.state in h2.connection.ConnectionState - def test_state_machine_only_allows_connection_states(self): + def test_state_machine_only_allows_connection_states(self) -> None: """ The Connection state machine only allows ConnectionState inputs. """ @@ -54,10 +52,10 @@ def test_state_machine_only_allows_connection_states(self): "input_", [ h2.connection.ConnectionInputs.RECV_PRIORITY, - h2.connection.ConnectionInputs.SEND_PRIORITY - ] + h2.connection.ConnectionInputs.SEND_PRIORITY, + ], ) - def test_priority_frames_allowed_in_all_states(self, state, input_): + def test_priority_frames_allowed_in_all_states(self, state, input_) -> None: """ Priority frames can be sent/received in all connection states except closed. @@ -68,13 +66,14 @@ def test_priority_frames_allowed_in_all_states(self, state, input_): c.process_input(input_) -class TestStreamStateMachine(object): +class TestStreamStateMachine: """ Tests of the stream state machine. """ + @given(state=sampled_from(h2.stream.StreamState), input_=sampled_from(h2.stream.StreamInputs)) - def test_state_transitions(self, state, input_): + def test_state_transitions(self, state, input_) -> None: s = h2.stream.H2StreamStateMachine(stream_id=1) s.state = state @@ -105,7 +104,7 @@ def test_state_transitions(self, state, input_): else: assert s.state in h2.stream.StreamState - def test_state_machine_only_allows_stream_states(self): + def test_state_machine_only_allows_stream_states(self) -> None: """ The Stream state machine only allows StreamState inputs. """ @@ -114,7 +113,7 @@ def test_state_machine_only_allows_stream_states(self): with pytest.raises(ValueError): s.process_input(1) - def test_stream_state_machine_forbids_pushes_on_server_streams(self): + def test_stream_state_machine_forbids_pushes_on_server_streams(self) -> None: """ Streams where this peer is a server do not allow receiving pushed frames. @@ -125,7 +124,7 @@ def test_stream_state_machine_forbids_pushes_on_server_streams(self): with pytest.raises(h2.exceptions.ProtocolError): s.process_input(h2.stream.StreamInputs.RECV_PUSH_PROMISE) - def test_stream_state_machine_forbids_sending_pushes_from_clients(self): + def test_stream_state_machine_forbids_sending_pushes_from_clients(self) -> None: """ Streams where this peer is a client do not allow sending pushed frames. """ @@ -144,9 +143,9 @@ def test_stream_state_machine_forbids_sending_pushes_from_clients(self): h2.stream.StreamInputs.SEND_DATA, h2.stream.StreamInputs.SEND_WINDOW_UPDATE, h2.stream.StreamInputs.SEND_END_STREAM, - ] + ], ) - def test_cannot_send_on_closed_streams(self, input_): + def test_cannot_send_on_closed_streams(self, input_) -> None: """ Sending anything but a PRIORITY frame is forbidden on closed streams. """ diff --git a/tests/test_stream_reset.py b/tests/test_stream_reset.py index 77844551..df0b4f2b 100644 --- a/tests/test_stream_reset.py +++ b/tests/test_stream_reset.py @@ -1,14 +1,12 @@ -# -*- coding: utf-8 -*- """ -test_stream_reset -~~~~~~~~~~~~~~~~~ - More complex tests that exercise stream resetting functionality to validate that connection state is appropriately maintained. Specifically, these tests validate that streams that have been reset accurately keep track of connection-level state. """ +from __future__ import annotations + import pytest import h2.connection @@ -16,23 +14,24 @@ import h2.events -class TestStreamReset(object): +class TestStreamReset: """ Tests for resetting streams. """ + example_request_headers = [ - (b':authority', b'example.com'), - (b':path', b'/'), - (b':scheme', b'https'), - (b':method', b'GET'), + (b":authority", b"example.com"), + (b":path", b"/"), + (b":scheme", b"https"), + (b":method", b"GET"), ] example_response_headers = [ - (b':status', b'200'), - (b'server', b'fake-serv/0.1.0'), - (b'content-length', b'0') + (b":status", b"200"), + (b"server", b"fake-serv/0.1.0"), + (b"content-length", b"0"), ] - def test_reset_stream_keeps_header_state_correct(self, frame_factory): + def test_reset_stream_keeps_header_state_correct(self, frame_factory) -> None: """ A stream that has been reset still affects the header decoder. """ @@ -44,10 +43,10 @@ def test_reset_stream_keeps_header_state_correct(self, frame_factory): c.clear_outbound_data_buffer() f = frame_factory.build_headers_frame( - headers=self.example_response_headers, stream_id=1 + headers=self.example_response_headers, stream_id=1, ) rst_frame = frame_factory.build_rst_stream_frame( - 1, h2.errors.ErrorCodes.STREAM_CLOSED + 1, h2.errors.ErrorCodes.STREAM_CLOSED, ) events = c.receive_data(f.serialize()) assert not events @@ -56,7 +55,7 @@ def test_reset_stream_keeps_header_state_correct(self, frame_factory): # This works because the header state should be intact from the headers # frame that was send on stream 1, so they should decode cleanly. f = frame_factory.build_headers_frame( - headers=self.example_response_headers, stream_id=3 + headers=self.example_response_headers, stream_id=3, ) event = c.receive_data(f.serialize())[0] @@ -64,11 +63,11 @@ def test_reset_stream_keeps_header_state_correct(self, frame_factory): assert event.stream_id == 3 assert event.headers == self.example_response_headers - @pytest.mark.parametrize('close_id,other_id', [(1, 3), (3, 1)]) + @pytest.mark.parametrize(("close_id", "other_id"), [(1, 3), (3, 1)]) def test_reset_stream_keeps_flow_control_correct(self, close_id, other_id, - frame_factory): + frame_factory) -> None: """ A stream that has been reset does not affect the connection flow control window. @@ -82,15 +81,15 @@ def test_reset_stream_keeps_flow_control_correct(self, initial_window = c.remote_flow_control_window(stream_id=other_id) f = frame_factory.build_headers_frame( - headers=self.example_response_headers, stream_id=close_id + headers=self.example_response_headers, stream_id=close_id, ) c.receive_data(f.serialize()) c.reset_stream(stream_id=close_id) c.clear_outbound_data_buffer() f = frame_factory.build_data_frame( - data=b'some data', - stream_id=close_id + data=b"some data", + stream_id=close_id, ) c.receive_data(f.serialize()) @@ -101,12 +100,12 @@ def test_reset_stream_keeps_flow_control_correct(self, assert c.data_to_send() == expected new_window = c.remote_flow_control_window(stream_id=other_id) - assert initial_window - len(b'some data') == new_window + assert initial_window - len(b"some data") == new_window - @pytest.mark.parametrize('clear_streams', [True, False]) + @pytest.mark.parametrize("clear_streams", [True, False]) def test_reset_stream_automatically_resets_pushed_streams(self, frame_factory, - clear_streams): + clear_streams) -> None: """ Resetting a stream causes RST_STREAM frames to be automatically emitted to close any streams pushed after the reset. diff --git a/tests/test_utility_functions.py b/tests/test_utility_functions.py index 3aa0a245..31634f40 100644 --- a/tests/test_utility_functions.py +++ b/tests/test_utility_functions.py @@ -1,10 +1,8 @@ -# -*- coding: utf-8 -*- """ -test_utility_functions -~~~~~~~~~~~~~~~~~~~~~~ - Tests for the various utility functions provided by hyper-h2. """ +from __future__ import annotations + import pytest import h2.config @@ -15,23 +13,24 @@ from h2.utilities import SizeLimitDict, extract_method_header -class TestGetNextAvailableStreamID(object): +class TestGetNextAvailableStreamID: """ Tests for the ``H2Connection.get_next_available_stream_id`` method. """ + example_request_headers = [ - (':authority', 'example.com'), - (':path', '/'), - (':scheme', 'https'), - (':method', 'GET'), + (":authority", "example.com"), + (":path", "/"), + (":scheme", "https"), + (":method", "GET"), ] example_response_headers = [ - (':status', '200'), - ('server', 'fake-serv/0.1.0') + (":status", "200"), + ("server", "fake-serv/0.1.0"), ] server_config = h2.config.H2Configuration(client_side=False) - def test_returns_correct_sequence_for_clients(self, frame_factory): + def test_returns_correct_sequence_for_clients(self, frame_factory) -> None: """ For a client connection, the correct sequence of stream IDs is returned. @@ -55,12 +54,12 @@ def test_returns_correct_sequence_for_clients(self, frame_factory): c.send_headers( stream_id=stream_id, headers=self.example_request_headers, - end_stream=True + end_stream=True, ) f = frame_factory.build_headers_frame( headers=self.example_response_headers, stream_id=stream_id, - flags=['END_STREAM'], + flags=["END_STREAM"], ) c.receive_data(f.serialize()) c.clear_outbound_data_buffer() @@ -71,13 +70,13 @@ def test_returns_correct_sequence_for_clients(self, frame_factory): c.send_headers( stream_id=last_client_id, headers=self.example_request_headers, - end_stream=True + end_stream=True, ) with pytest.raises(h2.exceptions.NoAvailableStreamIDError): c.get_next_available_stream_id() - def test_returns_correct_sequence_for_servers(self, frame_factory): + def test_returns_correct_sequence_for_servers(self, frame_factory) -> None: """ For a server connection, the correct sequence of stream IDs is returned. @@ -94,7 +93,7 @@ def test_returns_correct_sequence_for_servers(self, frame_factory): c.initiate_connection() c.receive_data(frame_factory.preamble()) f = frame_factory.build_headers_frame( - headers=self.example_request_headers + headers=self.example_request_headers, ) c.receive_data(f.serialize()) @@ -107,12 +106,12 @@ def test_returns_correct_sequence_for_servers(self, frame_factory): c.push_stream( stream_id=1, promised_stream_id=stream_id, - request_headers=self.example_request_headers + request_headers=self.example_request_headers, ) c.send_headers( stream_id=stream_id, headers=self.example_response_headers, - end_stream=True + end_stream=True, ) c.clear_outbound_data_buffer() @@ -128,7 +127,7 @@ def test_returns_correct_sequence_for_servers(self, frame_factory): with pytest.raises(h2.exceptions.NoAvailableStreamIDError): c.get_next_available_stream_id() - def test_does_not_increment_without_stream_send(self): + def test_does_not_increment_without_stream_send(self) -> None: """ If a new stream isn't actually created, the next stream ID doesn't change. @@ -143,29 +142,29 @@ def test_does_not_increment_without_stream_send(self): c.send_headers( stream_id=first_stream_id, - headers=self.example_request_headers + headers=self.example_request_headers, ) third_stream_id = c.get_next_available_stream_id() assert third_stream_id == (first_stream_id + 2) -class TestExtractHeader(object): +class TestExtractHeader: example_headers_with_bytes = [ - (b':authority', b'example.com'), - (b':path', b'/'), - (b':scheme', b'https'), - (b':method', b'GET'), + (b":authority", b"example.com"), + (b":path", b"/"), + (b":scheme", b"https"), + (b":method", b"GET"), ] - def test_extract_header_method(self): + def test_extract_header_method(self) -> None: assert extract_method_header( - self.example_headers_with_bytes - ) == b'GET' + self.example_headers_with_bytes, + ) == b"GET" -def test_size_limit_dict_limit(): +def test_size_limit_dict_limit() -> None: dct = SizeLimitDict(size_limit=2) dct[1] = 1 @@ -183,7 +182,7 @@ def test_size_limit_dict_limit(): assert 1 not in dct -def test_size_limit_dict_limit_init(): +def test_size_limit_dict_limit_init() -> None: initial_dct = { 1: 1, 2: 2, @@ -195,7 +194,7 @@ def test_size_limit_dict_limit_init(): assert len(dct) == 2 -def test_size_limit_dict_no_limit(): +def test_size_limit_dict_no_limit() -> None: dct = SizeLimitDict(size_limit=None) dct[1] = 1