_util.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135
  1. from typing import Any, Dict, NoReturn, Pattern, Tuple, Type, TypeVar, Union
  2. __all__ = [
  3. "ProtocolError",
  4. "LocalProtocolError",
  5. "RemoteProtocolError",
  6. "validate",
  7. "bytesify",
  8. ]
  9. class ProtocolError(Exception):
  10. """Exception indicating a violation of the HTTP/1.1 protocol.
  11. This as an abstract base class, with two concrete base classes:
  12. :exc:`LocalProtocolError`, which indicates that you tried to do something
  13. that HTTP/1.1 says is illegal, and :exc:`RemoteProtocolError`, which
  14. indicates that the remote peer tried to do something that HTTP/1.1 says is
  15. illegal. See :ref:`error-handling` for details.
  16. In addition to the normal :exc:`Exception` features, it has one attribute:
  17. .. attribute:: error_status_hint
  18. This gives a suggestion as to what status code a server might use if
  19. this error occurred as part of a request.
  20. For a :exc:`RemoteProtocolError`, this is useful as a suggestion for
  21. how you might want to respond to a misbehaving peer, if you're
  22. implementing a server.
  23. For a :exc:`LocalProtocolError`, this can be taken as a suggestion for
  24. how your peer might have responded to *you* if h11 had allowed you to
  25. continue.
  26. The default is 400 Bad Request, a generic catch-all for protocol
  27. violations.
  28. """
  29. def __init__(self, msg: str, error_status_hint: int = 400) -> None:
  30. if type(self) is ProtocolError:
  31. raise TypeError("tried to directly instantiate ProtocolError")
  32. Exception.__init__(self, msg)
  33. self.error_status_hint = error_status_hint
  34. # Strategy: there are a number of public APIs where a LocalProtocolError can
  35. # be raised (send(), all the different event constructors, ...), and only one
  36. # public API where RemoteProtocolError can be raised
  37. # (receive_data()). Therefore we always raise LocalProtocolError internally,
  38. # and then receive_data will translate this into a RemoteProtocolError.
  39. #
  40. # Internally:
  41. # LocalProtocolError is the generic "ProtocolError".
  42. # Externally:
  43. # LocalProtocolError is for local errors and RemoteProtocolError is for
  44. # remote errors.
class LocalProtocolError(ProtocolError):
    """The local side tried to do something that HTTP/1.1 says is illegal.

    Internally this is also the generic protocol-error type: it is raised
    everywhere, and ``receive_data()`` converts it into a
    RemoteProtocolError via :meth:`_reraise_as_remote_protocol_error`.
    """

    def _reraise_as_remote_protocol_error(self) -> NoReturn:
        """Re-raise this exception as a RemoteProtocolError. Never returns.

        Must be called from inside the ``except`` block that caught *self*.
        The exception object is converted in place (only its class is
        swapped), so its message, ``error_status_hint`` and traceback are kept.
        """
        # An easy way to get an equivalent RemoteProtocolError is just to
        # modify 'self' in place.
        self.__class__ = RemoteProtocolError  # type: ignore
        # A bare ``raise`` is deliberately not used: per the original
        # rationale, the interpreter may track the in-flight exception type
        # (exc_info[0]) separately from the exception object (exc_info[1]),
        # and only the latter was modified above. Raising the object
        # explicitly guarantees the new type is what propagates.
        # On Python 3 the traceback is stored on the exception object itself,
        # so the in-place modification preserved it and re-raising the object
        # keeps it.
        raise self
  64. class RemoteProtocolError(ProtocolError):
  65. pass
  66. def validate(
  67. regex: Pattern[bytes], data: bytes, msg: str = "malformed data", *format_args: Any
  68. ) -> Dict[str, bytes]:
  69. match = regex.fullmatch(data)
  70. if not match:
  71. if format_args:
  72. msg = msg.format(*format_args)
  73. raise LocalProtocolError(msg)
  74. return match.groupdict()
  75. # Sentinel values
  76. #
  77. # - Inherit identity-based comparison and hashing from object
  78. # - Have a nice repr
  79. # - Have a *bonus property*: type(sentinel) is sentinel
  80. #
  81. # The bonus property is useful if you want to take the return value from
  82. # next_event() and do some sort of dispatch based on type(event).
  83. _T_Sentinel = TypeVar("_T_Sentinel", bound="Sentinel")
  84. class Sentinel(type):
  85. def __new__(
  86. cls: Type[_T_Sentinel],
  87. name: str,
  88. bases: Tuple[type, ...],
  89. namespace: Dict[str, Any],
  90. **kwds: Any
  91. ) -> _T_Sentinel:
  92. assert bases == (Sentinel,)
  93. v = super().__new__(cls, name, bases, namespace, **kwds)
  94. v.__class__ = v # type: ignore
  95. return v
  96. def __repr__(self) -> str:
  97. return self.__name__
  98. # Used for methods, request targets, HTTP versions, header names, and header
  99. # values. Accepts ascii-strings, or bytes/bytearray/memoryview/..., and always
  100. # returns bytes.
  101. def bytesify(s: Union[bytes, bytearray, memoryview, int, str]) -> bytes:
  102. # Fast-path:
  103. if type(s) is bytes:
  104. return s
  105. if isinstance(s, str):
  106. s = s.encode("ascii")
  107. if isinstance(s, int):
  108. raise TypeError("expected bytes-like object, not int")
  109. return bytes(s)