matcher.py 7.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202
  1. from __future__ import annotations
  2. import re
  3. import typing as t
  4. from dataclasses import dataclass
  5. from dataclasses import field
  6. from .converters import ValidationError
  7. from .exceptions import NoMatch
  8. from .exceptions import RequestAliasRedirect
  9. from .exceptions import RequestPath
  10. from .rules import Rule
  11. from .rules import RulePart
  class SlashRequired(Exception):
      """Internal signal that the path only matches with a trailing slash.

      Raised from the matcher's recursive search when a strict-slashes
      rule would accept the path with ``"/"`` appended; ``match`` catches
      it and raises ``RequestPath`` with the slash-terminated path instead.
      """
  @dataclass
  class State:
      """A single node in the rule-matching state tree.

      Each node records the *rules* that terminate at it together with its
      outgoing transitions: *static* ones keyed by the exact part text, and
      *dynamic* ones paired with the rule part that has to match.
      """

      # (rule part, next state) pairs; the matcher tries them in list order.
      dynamic: list[tuple[RulePart, State]] = field(default_factory=list)
      # Rules whose last part leads to this node.
      rules: list[Rule] = field(default_factory=list)
      # Exact part text -> next node.
      static: dict[str, State] = field(default_factory=dict)
  class StateMachineMatcher:
      """Match request paths against rules by walking a tree of states.

      Rules are inserted with :meth:`add`, which follows (and extends) a
      tree of :class:`State` nodes, one edge per rule part. Call
      :meth:`update` after adding rules so each node's dynamic transitions
      are ordered by weight; :meth:`match` then walks the tree to find the
      rule for a domain, path, method and websocket flag.
      """

      def __init__(self, merge_slashes: bool) -> None:
          # Root of the state tree; every match starts here.
          self._root = State()
          # When true, a failed match is retried with repeated slashes merged.
          self.merge_slashes = merge_slashes

      def add(self, rule: Rule) -> None:
          """Insert *rule* into the state tree.

          Static parts follow (or create) the entry for their text in
          ``state.static``. Dynamic parts reuse an existing transition whose
          part compares equal, otherwise a new transition is appended. The
          rule is recorded on the state reached after its last part.
          """
          state = self._root
          for part in rule._parts:
              if part.static:
                  state = state.static.setdefault(part.content, State())
              else:
                  for test_part, new_state in state.dynamic:
                      if test_part == part:
                          state = new_state
                          break
                  else:
                      new_state = State()
                      state.dynamic.append((part, new_state))
                      state = new_state
          state.rules.append(rule)

      def update(self) -> None:
          """Sort every state's dynamic transitions by their part's weight.

          :meth:`match` tries dynamic transitions in list order, so this sets
          their priority. Recurses through static and dynamic children.
          """

          def _update_state(state: State) -> None:
              state.dynamic.sort(key=lambda entry: entry[0].weight)
              for new_state in state.static.values():
                  _update_state(new_state)
              for _, new_state in state.dynamic:
                  _update_state(new_state)

          _update_state(self._root)

      @staticmethod
      def _merge_slashes(path: str) -> str:
          """Collapse every run of two or more slashes in *path* into one.

          The quantifier must be greedy: the lazy ``/{2,}?`` consumes exactly
          two slashes per match, turning ``"a///b"`` into ``"a//b"``.
          """
          return re.sub("/{2,}", "/", path)

      @staticmethod
      def _converter_values(groupdict: t.Mapping[str, t.Any]) -> list[str]:
          """Return the converter values from a part's regex match.

          Only groups whose name starts with ``"__werkzeug_"`` are kept. They
          are ordered by (name length, name), which equals plain name order
          when all names have the same length and keeps numeric order for
          numbered suffixes past 9 (plain string order would put ``_10``
          before ``_2``). NOTE(review): assumes the rule compiler names
          groups ``__werkzeug_<n>`` with an increasing decimal counter.
          """
          named = [
              (key, value)
              for key, value in groupdict.items()
              if key.startswith("__werkzeug_")
          ]
          named.sort(key=lambda entry: (len(entry[0]), entry[0]))
          return [value for _, value in named]

      def match(
          self, domain: str, path: str, method: str, websocket: bool
      ) -> tuple[Rule, t.MutableMapping[str, t.Any]]:
          """Find the rule matching *domain* and *path*.

          :param domain: First part fed to the state machine, ahead of the
              path segments.
          :param path: Request path; split on ``"/"`` into parts.
          :param method: Request method, checked against ``rule.methods``.
          :param websocket: Must equal the rule's ``websocket`` flag.
          :return: The matched rule and a dict of its converted values,
              updated with ``rule.defaults``.
          :raises RequestPath: Raised with the corrected path when the path
              only matches with a trailing slash appended (strict-slashes
              rule), or only after merging repeated slashes.
          :raises RequestAliasRedirect: If the matched rule is an alias and
              ``rule.map.redirect_defaults`` is set.
          :raises NoMatch: If no rule matches or a converter raises
              ``ValidationError``. It receives the methods of rules that
              matched the path but not the method, and whether a websocket
              mismatch was seen.
          """
          # Diagnostics accumulated across the whole search for NoMatch.
          have_match_for: set[str] = set()
          websocket_mismatch = False

          def _match(
              state: State, parts: list[str], values: list[str]
          ) -> tuple[Rule, list[str]] | None:
              # Recursively consume ``parts`` (head first) through the
              # transitions of ``state``. Returns the matched rule and the
              # converter values collected so far, or None.
              nonlocal websocket_mismatch

              # Base case: every part has been consumed. Return the first
              # rule here whose methods and websocket flag fit, recording
              # why the others were rejected.
              if parts == []:
                  for rule in state.rules:
                      if rule.methods is not None and method not in rule.methods:
                          have_match_for.update(rule.methods)
                      elif rule.websocket != websocket:
                          websocket_mismatch = True
                      else:
                          return rule, values

                  # Would a rule match with a trailing slash added? Strict
                  # rules require a redirect; lenient ones match as-is.
                  if "" in state.static:
                      for rule in state.static[""].rules:
                          if websocket == rule.websocket and (
                              rule.methods is None or method in rule.methods
                          ):
                              if rule.strict_slashes:
                                  raise SlashRequired()
                              return rule, values
                  return None

              part = parts[0]
              # Static transitions take priority over dynamic ones.
              if part in state.static:
                  rv = _match(state.static[part], parts[1:], values)
                  if rv is not None:
                      return rv

              for test_part, new_state in state.dynamic:
                  target = part
                  remaining = parts[1:]
                  # A final part consumes all remaining parts, i.e. it
                  # transitions straight to a final state.
                  if test_part.final:
                      target = "/".join(parts)
                      remaining = []
                  found = re.compile(test_part.content).match(target)
                  if found is None:
                      continue
                  # A part_isolating=False part may capture a trailing slash
                  # in its last group; leave an empty part so the trailing
                  # slash handling runs in the next state.
                  if test_part.suffixed and found.groups()[-1] == "/":
                      remaining = [""]
                  groups = StateMachineMatcher._converter_values(found.groupdict())
                  rv = _match(new_state, remaining, values + groups)
                  if rv is not None:
                      return rv

              # Nothing matched and only an empty trailing part is left:
              # rules here without strict_slashes accept the trailing slash.
              if parts == [""]:
                  for rule in state.rules:
                      if rule.strict_slashes:
                          continue
                      if rule.methods is not None and method not in rule.methods:
                          have_match_for.update(rule.methods)
                      elif rule.websocket != websocket:
                          websocket_mismatch = True
                      else:
                          return rule, values

              return None

          try:
              rv = _match(self._root, [domain, *path.split("/")], [])
          except SlashRequired:
              raise RequestPath(f"{path}/") from None

          if self.merge_slashes and rv is None:
              merged_path = self._merge_slashes(path)
              if merged_path == path:
                  # Nothing to merge: a retry would fail exactly the same way.
                  raise NoMatch(have_match_for, websocket_mismatch)
              try:
                  rv = _match(self._root, [domain, *merged_path.split("/")], [])
              except SlashRequired:
                  raise RequestPath(f"{merged_path}/") from None
              if rv is None or rv[0].merge_slashes is False:
                  raise NoMatch(have_match_for, websocket_mismatch)
              # Matched only after merging: redirect to the canonical path.
              raise RequestPath(merged_path)

          if rv is None:
              raise NoMatch(have_match_for, websocket_mismatch)

          rule, values = rv
          result: dict[str, t.Any] = {}
          for (name, converter), value in zip(rule._converters.items(), values):
              try:
                  result[str(name)] = converter.to_python(value)
              except ValidationError:
                  raise NoMatch(have_match_for, websocket_mismatch) from None
          if rule.defaults:
              result.update(rule.defaults)

          if rule.alias and rule.map.redirect_defaults:
              raise RequestAliasRedirect(result, rule.endpoint)

          return rule, result