_multidict_py.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527
  1. import sys
  2. import types
  3. from array import array
  4. from collections import abc
  5. from ._abc import MultiMapping, MutableMultiMapping
  6. _marker = object()
  7. if sys.version_info >= (3, 9):
  8. GenericAlias = types.GenericAlias
  9. else:
  10. def GenericAlias(cls):
  11. return cls
  12. class istr(str):
  13. """Case insensitive str."""
  14. __is_istr__ = True
  15. upstr = istr # for relaxing backward compatibility problems
  16. def getversion(md):
  17. if not isinstance(md, _Base):
  18. raise TypeError("Parameter should be multidict or proxy")
  19. return md._impl._version
  20. _version = array("Q", [0])
  21. class _Impl:
  22. __slots__ = ("_items", "_version")
  23. def __init__(self):
  24. self._items = []
  25. self.incr_version()
  26. def incr_version(self):
  27. global _version
  28. v = _version
  29. v[0] += 1
  30. self._version = v[0]
  31. if sys.implementation.name != "pypy":
  32. def __sizeof__(self):
  33. return object.__sizeof__(self) + sys.getsizeof(self._items)
  34. class _Base:
  35. def _title(self, key):
  36. return key
  37. def getall(self, key, default=_marker):
  38. """Return a list of all values matching the key."""
  39. identity = self._title(key)
  40. res = [v for i, k, v in self._impl._items if i == identity]
  41. if res:
  42. return res
  43. if not res and default is not _marker:
  44. return default
  45. raise KeyError("Key not found: %r" % key)
  46. def getone(self, key, default=_marker):
  47. """Get first value matching the key.
  48. Raises KeyError if the key is not found and no default is provided.
  49. """
  50. identity = self._title(key)
  51. for i, k, v in self._impl._items:
  52. if i == identity:
  53. return v
  54. if default is not _marker:
  55. return default
  56. raise KeyError("Key not found: %r" % key)
  57. # Mapping interface #
  58. def __getitem__(self, key):
  59. return self.getone(key)
  60. def get(self, key, default=None):
  61. """Get first value matching the key.
  62. If the key is not found, returns the default (or None if no default is provided)
  63. """
  64. return self.getone(key, default)
  65. def __iter__(self):
  66. return iter(self.keys())
  67. def __len__(self):
  68. return len(self._impl._items)
  69. def keys(self):
  70. """Return a new view of the dictionary's keys."""
  71. return _KeysView(self._impl)
  72. def items(self):
  73. """Return a new view of the dictionary's items *(key, value) pairs)."""
  74. return _ItemsView(self._impl)
  75. def values(self):
  76. """Return a new view of the dictionary's values."""
  77. return _ValuesView(self._impl)
  78. def __eq__(self, other):
  79. if not isinstance(other, abc.Mapping):
  80. return NotImplemented
  81. if isinstance(other, _Base):
  82. lft = self._impl._items
  83. rht = other._impl._items
  84. if len(lft) != len(rht):
  85. return False
  86. for (i1, k2, v1), (i2, k2, v2) in zip(lft, rht):
  87. if i1 != i2 or v1 != v2:
  88. return False
  89. return True
  90. if len(self._impl._items) != len(other):
  91. return False
  92. for k, v in self.items():
  93. nv = other.get(k, _marker)
  94. if v != nv:
  95. return False
  96. return True
  97. def __contains__(self, key):
  98. identity = self._title(key)
  99. for i, k, v in self._impl._items:
  100. if i == identity:
  101. return True
  102. return False
  103. def __repr__(self):
  104. body = ", ".join("'{}': {!r}".format(k, v) for k, v in self.items())
  105. return "<{}({})>".format(self.__class__.__name__, body)
  106. __class_getitem__ = classmethod(GenericAlias)
  107. class MultiDictProxy(_Base, MultiMapping):
  108. """Read-only proxy for MultiDict instance."""
  109. def __init__(self, arg):
  110. if not isinstance(arg, (MultiDict, MultiDictProxy)):
  111. raise TypeError(
  112. "ctor requires MultiDict or MultiDictProxy instance"
  113. ", not {}".format(type(arg))
  114. )
  115. self._impl = arg._impl
  116. def __reduce__(self):
  117. raise TypeError("can't pickle {} objects".format(self.__class__.__name__))
  118. def copy(self):
  119. """Return a copy of itself."""
  120. return MultiDict(self.items())
  121. class CIMultiDictProxy(MultiDictProxy):
  122. """Read-only proxy for CIMultiDict instance."""
  123. def __init__(self, arg):
  124. if not isinstance(arg, (CIMultiDict, CIMultiDictProxy)):
  125. raise TypeError(
  126. "ctor requires CIMultiDict or CIMultiDictProxy instance"
  127. ", not {}".format(type(arg))
  128. )
  129. self._impl = arg._impl
  130. def _title(self, key):
  131. return key.title()
  132. def copy(self):
  133. """Return a copy of itself."""
  134. return CIMultiDict(self.items())
  135. class MultiDict(_Base, MutableMultiMapping):
  136. """Dictionary with the support for duplicate keys."""
  137. def __init__(self, *args, **kwargs):
  138. self._impl = _Impl()
  139. self._extend(args, kwargs, self.__class__.__name__, self._extend_items)
  140. if sys.implementation.name != "pypy":
  141. def __sizeof__(self):
  142. return object.__sizeof__(self) + sys.getsizeof(self._impl)
  143. def __reduce__(self):
  144. return (self.__class__, (list(self.items()),))
  145. def _title(self, key):
  146. return key
  147. def _key(self, key):
  148. if isinstance(key, str):
  149. return key
  150. else:
  151. raise TypeError(
  152. "MultiDict keys should be either str " "or subclasses of str"
  153. )
  154. def add(self, key, value):
  155. identity = self._title(key)
  156. self._impl._items.append((identity, self._key(key), value))
  157. self._impl.incr_version()
  158. def copy(self):
  159. """Return a copy of itself."""
  160. cls = self.__class__
  161. return cls(self.items())
  162. __copy__ = copy
  163. def extend(self, *args, **kwargs):
  164. """Extend current MultiDict with more values.
  165. This method must be used instead of update.
  166. """
  167. self._extend(args, kwargs, "extend", self._extend_items)
  168. def _extend(self, args, kwargs, name, method):
  169. if len(args) > 1:
  170. raise TypeError(
  171. "{} takes at most 1 positional argument"
  172. " ({} given)".format(name, len(args))
  173. )
  174. if args:
  175. arg = args[0]
  176. if isinstance(args[0], (MultiDict, MultiDictProxy)) and not kwargs:
  177. items = arg._impl._items
  178. else:
  179. if hasattr(arg, "items"):
  180. arg = arg.items()
  181. if kwargs:
  182. arg = list(arg)
  183. arg.extend(list(kwargs.items()))
  184. items = []
  185. for item in arg:
  186. if not len(item) == 2:
  187. raise TypeError(
  188. "{} takes either dict or list of (key, value) "
  189. "tuples".format(name)
  190. )
  191. items.append((self._title(item[0]), self._key(item[0]), item[1]))
  192. method(items)
  193. else:
  194. method(
  195. [
  196. (self._title(key), self._key(key), value)
  197. for key, value in kwargs.items()
  198. ]
  199. )
  200. def _extend_items(self, items):
  201. for identity, key, value in items:
  202. self.add(key, value)
  203. def clear(self):
  204. """Remove all items from MultiDict."""
  205. self._impl._items.clear()
  206. self._impl.incr_version()
  207. # Mapping interface #
  208. def __setitem__(self, key, value):
  209. self._replace(key, value)
  210. def __delitem__(self, key):
  211. identity = self._title(key)
  212. items = self._impl._items
  213. found = False
  214. for i in range(len(items) - 1, -1, -1):
  215. if items[i][0] == identity:
  216. del items[i]
  217. found = True
  218. if not found:
  219. raise KeyError(key)
  220. else:
  221. self._impl.incr_version()
  222. def setdefault(self, key, default=None):
  223. """Return value for key, set value to default if key is not present."""
  224. identity = self._title(key)
  225. for i, k, v in self._impl._items:
  226. if i == identity:
  227. return v
  228. self.add(key, default)
  229. return default
  230. def popone(self, key, default=_marker):
  231. """Remove specified key and return the corresponding value.
  232. If key is not found, d is returned if given, otherwise
  233. KeyError is raised.
  234. """
  235. identity = self._title(key)
  236. for i in range(len(self._impl._items)):
  237. if self._impl._items[i][0] == identity:
  238. value = self._impl._items[i][2]
  239. del self._impl._items[i]
  240. self._impl.incr_version()
  241. return value
  242. if default is _marker:
  243. raise KeyError(key)
  244. else:
  245. return default
  246. pop = popone # type: ignore
  247. def popall(self, key, default=_marker):
  248. """Remove all occurrences of key and return the list of corresponding
  249. values.
  250. If key is not found, default is returned if given, otherwise
  251. KeyError is raised.
  252. """
  253. found = False
  254. identity = self._title(key)
  255. ret = []
  256. for i in range(len(self._impl._items) - 1, -1, -1):
  257. item = self._impl._items[i]
  258. if item[0] == identity:
  259. ret.append(item[2])
  260. del self._impl._items[i]
  261. self._impl.incr_version()
  262. found = True
  263. if not found:
  264. if default is _marker:
  265. raise KeyError(key)
  266. else:
  267. return default
  268. else:
  269. ret.reverse()
  270. return ret
  271. def popitem(self):
  272. """Remove and return an arbitrary (key, value) pair."""
  273. if self._impl._items:
  274. i = self._impl._items.pop(0)
  275. self._impl.incr_version()
  276. return i[1], i[2]
  277. else:
  278. raise KeyError("empty multidict")
  279. def update(self, *args, **kwargs):
  280. """Update the dictionary from *other*, overwriting existing keys."""
  281. self._extend(args, kwargs, "update", self._update_items)
  282. def _update_items(self, items):
  283. if not items:
  284. return
  285. used_keys = {}
  286. for identity, key, value in items:
  287. start = used_keys.get(identity, 0)
  288. for i in range(start, len(self._impl._items)):
  289. item = self._impl._items[i]
  290. if item[0] == identity:
  291. used_keys[identity] = i + 1
  292. self._impl._items[i] = (identity, key, value)
  293. break
  294. else:
  295. self._impl._items.append((identity, key, value))
  296. used_keys[identity] = len(self._impl._items)
  297. # drop tails
  298. i = 0
  299. while i < len(self._impl._items):
  300. item = self._impl._items[i]
  301. identity = item[0]
  302. pos = used_keys.get(identity)
  303. if pos is None:
  304. i += 1
  305. continue
  306. if i >= pos:
  307. del self._impl._items[i]
  308. else:
  309. i += 1
  310. self._impl.incr_version()
  311. def _replace(self, key, value):
  312. key = self._key(key)
  313. identity = self._title(key)
  314. items = self._impl._items
  315. for i in range(len(items)):
  316. item = items[i]
  317. if item[0] == identity:
  318. items[i] = (identity, key, value)
  319. # i points to last found item
  320. rgt = i
  321. self._impl.incr_version()
  322. break
  323. else:
  324. self._impl._items.append((identity, key, value))
  325. self._impl.incr_version()
  326. return
  327. # remove all tail items
  328. i = rgt + 1
  329. while i < len(items):
  330. item = items[i]
  331. if item[0] == identity:
  332. del items[i]
  333. else:
  334. i += 1
  335. class CIMultiDict(MultiDict):
  336. """Dictionary with the support for duplicate case-insensitive keys."""
  337. def _title(self, key):
  338. return key.title()
  339. class _Iter:
  340. __slots__ = ("_size", "_iter")
  341. def __init__(self, size, iterator):
  342. self._size = size
  343. self._iter = iterator
  344. def __iter__(self):
  345. return self
  346. def __next__(self):
  347. return next(self._iter)
  348. def __length_hint__(self):
  349. return self._size
  350. class _ViewBase:
  351. def __init__(self, impl):
  352. self._impl = impl
  353. def __len__(self):
  354. return len(self._impl._items)
  355. class _ItemsView(_ViewBase, abc.ItemsView):
  356. def __contains__(self, item):
  357. assert isinstance(item, tuple) or isinstance(item, list)
  358. assert len(item) == 2
  359. for i, k, v in self._impl._items:
  360. if item[0] == k and item[1] == v:
  361. return True
  362. return False
  363. def __iter__(self):
  364. return _Iter(len(self), self._iter(self._impl._version))
  365. def _iter(self, version):
  366. for i, k, v in self._impl._items:
  367. if version != self._impl._version:
  368. raise RuntimeError("Dictionary changed during iteration")
  369. yield k, v
  370. def __repr__(self):
  371. lst = []
  372. for item in self._impl._items:
  373. lst.append("{!r}: {!r}".format(item[1], item[2]))
  374. body = ", ".join(lst)
  375. return "{}({})".format(self.__class__.__name__, body)
  376. class _ValuesView(_ViewBase, abc.ValuesView):
  377. def __contains__(self, value):
  378. for item in self._impl._items:
  379. if item[2] == value:
  380. return True
  381. return False
  382. def __iter__(self):
  383. return _Iter(len(self), self._iter(self._impl._version))
  384. def _iter(self, version):
  385. for item in self._impl._items:
  386. if version != self._impl._version:
  387. raise RuntimeError("Dictionary changed during iteration")
  388. yield item[2]
  389. def __repr__(self):
  390. lst = []
  391. for item in self._impl._items:
  392. lst.append("{!r}".format(item[2]))
  393. body = ", ".join(lst)
  394. return "{}({})".format(self.__class__.__name__, body)
  395. class _KeysView(_ViewBase, abc.KeysView):
  396. def __contains__(self, key):
  397. for item in self._impl._items:
  398. if item[1] == key:
  399. return True
  400. return False
  401. def __iter__(self):
  402. return _Iter(len(self), self._iter(self._impl._version))
  403. def _iter(self, version):
  404. for item in self._impl._items:
  405. if version != self._impl._version:
  406. raise RuntimeError("Dictionary changed during iteration")
  407. yield item[1]
  408. def __repr__(self):
  409. lst = []
  410. for item in self._impl._items:
  411. lst.append("{!r}".format(item[1]))
  412. body = ", ".join(lst)
  413. return "{}({})".format(self.__class__.__name__, body)