_parser.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691
  1. # SPDX-License-Identifier: MIT
  2. # SPDX-FileCopyrightText: 2021 Taneli Hukkinen
  3. # Licensed to PSF under a Contributor Agreement.
  4. from __future__ import annotations
  5. from collections.abc import Iterable
  6. import string
  7. from types import MappingProxyType
  8. from typing import Any, BinaryIO, NamedTuple
  9. from ._re import (
  10. RE_DATETIME,
  11. RE_LOCALTIME,
  12. RE_NUMBER,
  13. match_to_datetime,
  14. match_to_localtime,
  15. match_to_number,
  16. )
  17. from ._types import Key, ParseFloat, Pos
  18. ASCII_CTRL = frozenset(chr(i) for i in range(32)) | frozenset(chr(127))
  19. # Neither of these sets include quotation mark or backslash. They are
  20. # currently handled as separate cases in the parser functions.
  21. ILLEGAL_BASIC_STR_CHARS = ASCII_CTRL - frozenset("\t")
  22. ILLEGAL_MULTILINE_BASIC_STR_CHARS = ASCII_CTRL - frozenset("\t\n")
  23. ILLEGAL_LITERAL_STR_CHARS = ILLEGAL_BASIC_STR_CHARS
  24. ILLEGAL_MULTILINE_LITERAL_STR_CHARS = ILLEGAL_MULTILINE_BASIC_STR_CHARS
  25. ILLEGAL_COMMENT_CHARS = ILLEGAL_BASIC_STR_CHARS
  26. TOML_WS = frozenset(" \t")
  27. TOML_WS_AND_NEWLINE = TOML_WS | frozenset("\n")
  28. BARE_KEY_CHARS = frozenset(string.ascii_letters + string.digits + "-_")
  29. KEY_INITIAL_CHARS = BARE_KEY_CHARS | frozenset("\"'")
  30. HEXDIGIT_CHARS = frozenset(string.hexdigits)
  31. BASIC_STR_ESCAPE_REPLACEMENTS = MappingProxyType(
  32. {
  33. "\\b": "\u0008", # backspace
  34. "\\t": "\u0009", # tab
  35. "\\n": "\u000A", # linefeed
  36. "\\f": "\u000C", # form feed
  37. "\\r": "\u000D", # carriage return
  38. '\\"': "\u0022", # quote
  39. "\\\\": "\u005C", # backslash
  40. }
  41. )
  42. class TOMLDecodeError(ValueError):
  43. """An error raised if a document is not valid TOML."""
  44. def load(__fp: BinaryIO, *, parse_float: ParseFloat = float) -> dict[str, Any]:
  45. """Parse TOML from a binary file object."""
  46. b = __fp.read()
  47. try:
  48. s = b.decode()
  49. except AttributeError:
  50. raise TypeError(
  51. "File must be opened in binary mode, e.g. use `open('foo.toml', 'rb')`"
  52. ) from None
  53. return loads(s, parse_float=parse_float)
  54. def loads(__s: str, *, parse_float: ParseFloat = float) -> dict[str, Any]: # noqa: C901
  55. """Parse TOML from a string."""
  56. # The spec allows converting "\r\n" to "\n", even in string
  57. # literals. Let's do so to simplify parsing.
  58. src = __s.replace("\r\n", "\n")
  59. pos = 0
  60. out = Output(NestedDict(), Flags())
  61. header: Key = ()
  62. parse_float = make_safe_parse_float(parse_float)
  63. # Parse one statement at a time
  64. # (typically means one line in TOML source)
  65. while True:
  66. # 1. Skip line leading whitespace
  67. pos = skip_chars(src, pos, TOML_WS)
  68. # 2. Parse rules. Expect one of the following:
  69. # - end of file
  70. # - end of line
  71. # - comment
  72. # - key/value pair
  73. # - append dict to list (and move to its namespace)
  74. # - create dict (and move to its namespace)
  75. # Skip trailing whitespace when applicable.
  76. try:
  77. char = src[pos]
  78. except IndexError:
  79. break
  80. if char == "\n":
  81. pos += 1
  82. continue
  83. if char in KEY_INITIAL_CHARS:
  84. pos = key_value_rule(src, pos, out, header, parse_float)
  85. pos = skip_chars(src, pos, TOML_WS)
  86. elif char == "[":
  87. try:
  88. second_char: str | None = src[pos + 1]
  89. except IndexError:
  90. second_char = None
  91. out.flags.finalize_pending()
  92. if second_char == "[":
  93. pos, header = create_list_rule(src, pos, out)
  94. else:
  95. pos, header = create_dict_rule(src, pos, out)
  96. pos = skip_chars(src, pos, TOML_WS)
  97. elif char != "#":
  98. raise suffixed_err(src, pos, "Invalid statement")
  99. # 3. Skip comment
  100. pos = skip_comment(src, pos)
  101. # 4. Expect end of line or end of file
  102. try:
  103. char = src[pos]
  104. except IndexError:
  105. break
  106. if char != "\n":
  107. raise suffixed_err(
  108. src, pos, "Expected newline or end of document after a statement"
  109. )
  110. pos += 1
  111. return out.data.dict
  112. class Flags:
  113. """Flags that map to parsed keys/namespaces."""
  114. # Marks an immutable namespace (inline array or inline table).
  115. FROZEN = 0
  116. # Marks a nest that has been explicitly created and can no longer
  117. # be opened using the "[table]" syntax.
  118. EXPLICIT_NEST = 1
  119. def __init__(self) -> None:
  120. self._flags: dict[str, dict] = {}
  121. self._pending_flags: set[tuple[Key, int]] = set()
  122. def add_pending(self, key: Key, flag: int) -> None:
  123. self._pending_flags.add((key, flag))
  124. def finalize_pending(self) -> None:
  125. for key, flag in self._pending_flags:
  126. self.set(key, flag, recursive=False)
  127. self._pending_flags.clear()
  128. def unset_all(self, key: Key) -> None:
  129. cont = self._flags
  130. for k in key[:-1]:
  131. if k not in cont:
  132. return
  133. cont = cont[k]["nested"]
  134. cont.pop(key[-1], None)
  135. def set(self, key: Key, flag: int, *, recursive: bool) -> None: # noqa: A003
  136. cont = self._flags
  137. key_parent, key_stem = key[:-1], key[-1]
  138. for k in key_parent:
  139. if k not in cont:
  140. cont[k] = {"flags": set(), "recursive_flags": set(), "nested": {}}
  141. cont = cont[k]["nested"]
  142. if key_stem not in cont:
  143. cont[key_stem] = {"flags": set(), "recursive_flags": set(), "nested": {}}
  144. cont[key_stem]["recursive_flags" if recursive else "flags"].add(flag)
  145. def is_(self, key: Key, flag: int) -> bool:
  146. if not key:
  147. return False # document root has no flags
  148. cont = self._flags
  149. for k in key[:-1]:
  150. if k not in cont:
  151. return False
  152. inner_cont = cont[k]
  153. if flag in inner_cont["recursive_flags"]:
  154. return True
  155. cont = inner_cont["nested"]
  156. key_stem = key[-1]
  157. if key_stem in cont:
  158. cont = cont[key_stem]
  159. return flag in cont["flags"] or flag in cont["recursive_flags"]
  160. return False
  161. class NestedDict:
  162. def __init__(self) -> None:
  163. # The parsed content of the TOML document
  164. self.dict: dict[str, Any] = {}
  165. def get_or_create_nest(
  166. self,
  167. key: Key,
  168. *,
  169. access_lists: bool = True,
  170. ) -> dict:
  171. cont: Any = self.dict
  172. for k in key:
  173. if k not in cont:
  174. cont[k] = {}
  175. cont = cont[k]
  176. if access_lists and isinstance(cont, list):
  177. cont = cont[-1]
  178. if not isinstance(cont, dict):
  179. raise KeyError("There is no nest behind this key")
  180. return cont
  181. def append_nest_to_list(self, key: Key) -> None:
  182. cont = self.get_or_create_nest(key[:-1])
  183. last_key = key[-1]
  184. if last_key in cont:
  185. list_ = cont[last_key]
  186. if not isinstance(list_, list):
  187. raise KeyError("An object other than list found behind this key")
  188. list_.append({})
  189. else:
  190. cont[last_key] = [{}]
  191. class Output(NamedTuple):
  192. data: NestedDict
  193. flags: Flags
  194. def skip_chars(src: str, pos: Pos, chars: Iterable[str]) -> Pos:
  195. try:
  196. while src[pos] in chars:
  197. pos += 1
  198. except IndexError:
  199. pass
  200. return pos
  201. def skip_until(
  202. src: str,
  203. pos: Pos,
  204. expect: str,
  205. *,
  206. error_on: frozenset[str],
  207. error_on_eof: bool,
  208. ) -> Pos:
  209. try:
  210. new_pos = src.index(expect, pos)
  211. except ValueError:
  212. new_pos = len(src)
  213. if error_on_eof:
  214. raise suffixed_err(src, new_pos, f"Expected {expect!r}") from None
  215. if not error_on.isdisjoint(src[pos:new_pos]):
  216. while src[pos] not in error_on:
  217. pos += 1
  218. raise suffixed_err(src, pos, f"Found invalid character {src[pos]!r}")
  219. return new_pos
  220. def skip_comment(src: str, pos: Pos) -> Pos:
  221. try:
  222. char: str | None = src[pos]
  223. except IndexError:
  224. char = None
  225. if char == "#":
  226. return skip_until(
  227. src, pos + 1, "\n", error_on=ILLEGAL_COMMENT_CHARS, error_on_eof=False
  228. )
  229. return pos
  230. def skip_comments_and_array_ws(src: str, pos: Pos) -> Pos:
  231. while True:
  232. pos_before_skip = pos
  233. pos = skip_chars(src, pos, TOML_WS_AND_NEWLINE)
  234. pos = skip_comment(src, pos)
  235. if pos == pos_before_skip:
  236. return pos
  237. def create_dict_rule(src: str, pos: Pos, out: Output) -> tuple[Pos, Key]:
  238. pos += 1 # Skip "["
  239. pos = skip_chars(src, pos, TOML_WS)
  240. pos, key = parse_key(src, pos)
  241. if out.flags.is_(key, Flags.EXPLICIT_NEST) or out.flags.is_(key, Flags.FROZEN):
  242. raise suffixed_err(src, pos, f"Cannot declare {key} twice")
  243. out.flags.set(key, Flags.EXPLICIT_NEST, recursive=False)
  244. try:
  245. out.data.get_or_create_nest(key)
  246. except KeyError:
  247. raise suffixed_err(src, pos, "Cannot overwrite a value") from None
  248. if not src.startswith("]", pos):
  249. raise suffixed_err(src, pos, "Expected ']' at the end of a table declaration")
  250. return pos + 1, key
  251. def create_list_rule(src: str, pos: Pos, out: Output) -> tuple[Pos, Key]:
  252. pos += 2 # Skip "[["
  253. pos = skip_chars(src, pos, TOML_WS)
  254. pos, key = parse_key(src, pos)
  255. if out.flags.is_(key, Flags.FROZEN):
  256. raise suffixed_err(src, pos, f"Cannot mutate immutable namespace {key}")
  257. # Free the namespace now that it points to another empty list item...
  258. out.flags.unset_all(key)
  259. # ...but this key precisely is still prohibited from table declaration
  260. out.flags.set(key, Flags.EXPLICIT_NEST, recursive=False)
  261. try:
  262. out.data.append_nest_to_list(key)
  263. except KeyError:
  264. raise suffixed_err(src, pos, "Cannot overwrite a value") from None
  265. if not src.startswith("]]", pos):
  266. raise suffixed_err(src, pos, "Expected ']]' at the end of an array declaration")
  267. return pos + 2, key
  268. def key_value_rule(
  269. src: str, pos: Pos, out: Output, header: Key, parse_float: ParseFloat
  270. ) -> Pos:
  271. pos, key, value = parse_key_value_pair(src, pos, parse_float)
  272. key_parent, key_stem = key[:-1], key[-1]
  273. abs_key_parent = header + key_parent
  274. relative_path_cont_keys = (header + key[:i] for i in range(1, len(key)))
  275. for cont_key in relative_path_cont_keys:
  276. # Check that dotted key syntax does not redefine an existing table
  277. if out.flags.is_(cont_key, Flags.EXPLICIT_NEST):
  278. raise suffixed_err(src, pos, f"Cannot redefine namespace {cont_key}")
  279. # Containers in the relative path can't be opened with the table syntax or
  280. # dotted key/value syntax in following table sections.
  281. out.flags.add_pending(cont_key, Flags.EXPLICIT_NEST)
  282. if out.flags.is_(abs_key_parent, Flags.FROZEN):
  283. raise suffixed_err(
  284. src, pos, f"Cannot mutate immutable namespace {abs_key_parent}"
  285. )
  286. try:
  287. nest = out.data.get_or_create_nest(abs_key_parent)
  288. except KeyError:
  289. raise suffixed_err(src, pos, "Cannot overwrite a value") from None
  290. if key_stem in nest:
  291. raise suffixed_err(src, pos, "Cannot overwrite a value")
  292. # Mark inline table and array namespaces recursively immutable
  293. if isinstance(value, (dict, list)):
  294. out.flags.set(header + key, Flags.FROZEN, recursive=True)
  295. nest[key_stem] = value
  296. return pos
  297. def parse_key_value_pair(
  298. src: str, pos: Pos, parse_float: ParseFloat
  299. ) -> tuple[Pos, Key, Any]:
  300. pos, key = parse_key(src, pos)
  301. try:
  302. char: str | None = src[pos]
  303. except IndexError:
  304. char = None
  305. if char != "=":
  306. raise suffixed_err(src, pos, "Expected '=' after a key in a key/value pair")
  307. pos += 1
  308. pos = skip_chars(src, pos, TOML_WS)
  309. pos, value = parse_value(src, pos, parse_float)
  310. return pos, key, value
  311. def parse_key(src: str, pos: Pos) -> tuple[Pos, Key]:
  312. pos, key_part = parse_key_part(src, pos)
  313. key: Key = (key_part,)
  314. pos = skip_chars(src, pos, TOML_WS)
  315. while True:
  316. try:
  317. char: str | None = src[pos]
  318. except IndexError:
  319. char = None
  320. if char != ".":
  321. return pos, key
  322. pos += 1
  323. pos = skip_chars(src, pos, TOML_WS)
  324. pos, key_part = parse_key_part(src, pos)
  325. key += (key_part,)
  326. pos = skip_chars(src, pos, TOML_WS)
  327. def parse_key_part(src: str, pos: Pos) -> tuple[Pos, str]:
  328. try:
  329. char: str | None = src[pos]
  330. except IndexError:
  331. char = None
  332. if char in BARE_KEY_CHARS:
  333. start_pos = pos
  334. pos = skip_chars(src, pos, BARE_KEY_CHARS)
  335. return pos, src[start_pos:pos]
  336. if char == "'":
  337. return parse_literal_str(src, pos)
  338. if char == '"':
  339. return parse_one_line_basic_str(src, pos)
  340. raise suffixed_err(src, pos, "Invalid initial character for a key part")
  341. def parse_one_line_basic_str(src: str, pos: Pos) -> tuple[Pos, str]:
  342. pos += 1
  343. return parse_basic_str(src, pos, multiline=False)
  344. def parse_array(src: str, pos: Pos, parse_float: ParseFloat) -> tuple[Pos, list]:
  345. pos += 1
  346. array: list = []
  347. pos = skip_comments_and_array_ws(src, pos)
  348. if src.startswith("]", pos):
  349. return pos + 1, array
  350. while True:
  351. pos, val = parse_value(src, pos, parse_float)
  352. array.append(val)
  353. pos = skip_comments_and_array_ws(src, pos)
  354. c = src[pos : pos + 1]
  355. if c == "]":
  356. return pos + 1, array
  357. if c != ",":
  358. raise suffixed_err(src, pos, "Unclosed array")
  359. pos += 1
  360. pos = skip_comments_and_array_ws(src, pos)
  361. if src.startswith("]", pos):
  362. return pos + 1, array
  363. def parse_inline_table(src: str, pos: Pos, parse_float: ParseFloat) -> tuple[Pos, dict]:
  364. pos += 1
  365. nested_dict = NestedDict()
  366. flags = Flags()
  367. pos = skip_chars(src, pos, TOML_WS)
  368. if src.startswith("}", pos):
  369. return pos + 1, nested_dict.dict
  370. while True:
  371. pos, key, value = parse_key_value_pair(src, pos, parse_float)
  372. key_parent, key_stem = key[:-1], key[-1]
  373. if flags.is_(key, Flags.FROZEN):
  374. raise suffixed_err(src, pos, f"Cannot mutate immutable namespace {key}")
  375. try:
  376. nest = nested_dict.get_or_create_nest(key_parent, access_lists=False)
  377. except KeyError:
  378. raise suffixed_err(src, pos, "Cannot overwrite a value") from None
  379. if key_stem in nest:
  380. raise suffixed_err(src, pos, f"Duplicate inline table key {key_stem!r}")
  381. nest[key_stem] = value
  382. pos = skip_chars(src, pos, TOML_WS)
  383. c = src[pos : pos + 1]
  384. if c == "}":
  385. return pos + 1, nested_dict.dict
  386. if c != ",":
  387. raise suffixed_err(src, pos, "Unclosed inline table")
  388. if isinstance(value, (dict, list)):
  389. flags.set(key, Flags.FROZEN, recursive=True)
  390. pos += 1
  391. pos = skip_chars(src, pos, TOML_WS)
  392. def parse_basic_str_escape(
  393. src: str, pos: Pos, *, multiline: bool = False
  394. ) -> tuple[Pos, str]:
  395. escape_id = src[pos : pos + 2]
  396. pos += 2
  397. if multiline and escape_id in {"\\ ", "\\\t", "\\\n"}:
  398. # Skip whitespace until next non-whitespace character or end of
  399. # the doc. Error if non-whitespace is found before newline.
  400. if escape_id != "\\\n":
  401. pos = skip_chars(src, pos, TOML_WS)
  402. try:
  403. char = src[pos]
  404. except IndexError:
  405. return pos, ""
  406. if char != "\n":
  407. raise suffixed_err(src, pos, "Unescaped '\\' in a string")
  408. pos += 1
  409. pos = skip_chars(src, pos, TOML_WS_AND_NEWLINE)
  410. return pos, ""
  411. if escape_id == "\\u":
  412. return parse_hex_char(src, pos, 4)
  413. if escape_id == "\\U":
  414. return parse_hex_char(src, pos, 8)
  415. try:
  416. return pos, BASIC_STR_ESCAPE_REPLACEMENTS[escape_id]
  417. except KeyError:
  418. raise suffixed_err(src, pos, "Unescaped '\\' in a string") from None
  419. def parse_basic_str_escape_multiline(src: str, pos: Pos) -> tuple[Pos, str]:
  420. return parse_basic_str_escape(src, pos, multiline=True)
  421. def parse_hex_char(src: str, pos: Pos, hex_len: int) -> tuple[Pos, str]:
  422. hex_str = src[pos : pos + hex_len]
  423. if len(hex_str) != hex_len or not HEXDIGIT_CHARS.issuperset(hex_str):
  424. raise suffixed_err(src, pos, "Invalid hex value")
  425. pos += hex_len
  426. hex_int = int(hex_str, 16)
  427. if not is_unicode_scalar_value(hex_int):
  428. raise suffixed_err(src, pos, "Escaped character is not a Unicode scalar value")
  429. return pos, chr(hex_int)
  430. def parse_literal_str(src: str, pos: Pos) -> tuple[Pos, str]:
  431. pos += 1 # Skip starting apostrophe
  432. start_pos = pos
  433. pos = skip_until(
  434. src, pos, "'", error_on=ILLEGAL_LITERAL_STR_CHARS, error_on_eof=True
  435. )
  436. return pos + 1, src[start_pos:pos] # Skip ending apostrophe
  437. def parse_multiline_str(src: str, pos: Pos, *, literal: bool) -> tuple[Pos, str]:
  438. pos += 3
  439. if src.startswith("\n", pos):
  440. pos += 1
  441. if literal:
  442. delim = "'"
  443. end_pos = skip_until(
  444. src,
  445. pos,
  446. "'''",
  447. error_on=ILLEGAL_MULTILINE_LITERAL_STR_CHARS,
  448. error_on_eof=True,
  449. )
  450. result = src[pos:end_pos]
  451. pos = end_pos + 3
  452. else:
  453. delim = '"'
  454. pos, result = parse_basic_str(src, pos, multiline=True)
  455. # Add at maximum two extra apostrophes/quotes if the end sequence
  456. # is 4 or 5 chars long instead of just 3.
  457. if not src.startswith(delim, pos):
  458. return pos, result
  459. pos += 1
  460. if not src.startswith(delim, pos):
  461. return pos, result + delim
  462. pos += 1
  463. return pos, result + (delim * 2)
  464. def parse_basic_str(src: str, pos: Pos, *, multiline: bool) -> tuple[Pos, str]:
  465. if multiline:
  466. error_on = ILLEGAL_MULTILINE_BASIC_STR_CHARS
  467. parse_escapes = parse_basic_str_escape_multiline
  468. else:
  469. error_on = ILLEGAL_BASIC_STR_CHARS
  470. parse_escapes = parse_basic_str_escape
  471. result = ""
  472. start_pos = pos
  473. while True:
  474. try:
  475. char = src[pos]
  476. except IndexError:
  477. raise suffixed_err(src, pos, "Unterminated string") from None
  478. if char == '"':
  479. if not multiline:
  480. return pos + 1, result + src[start_pos:pos]
  481. if src.startswith('"""', pos):
  482. return pos + 3, result + src[start_pos:pos]
  483. pos += 1
  484. continue
  485. if char == "\\":
  486. result += src[start_pos:pos]
  487. pos, parsed_escape = parse_escapes(src, pos)
  488. result += parsed_escape
  489. start_pos = pos
  490. continue
  491. if char in error_on:
  492. raise suffixed_err(src, pos, f"Illegal character {char!r}")
  493. pos += 1
  494. def parse_value( # noqa: C901
  495. src: str, pos: Pos, parse_float: ParseFloat
  496. ) -> tuple[Pos, Any]:
  497. try:
  498. char: str | None = src[pos]
  499. except IndexError:
  500. char = None
  501. # IMPORTANT: order conditions based on speed of checking and likelihood
  502. # Basic strings
  503. if char == '"':
  504. if src.startswith('"""', pos):
  505. return parse_multiline_str(src, pos, literal=False)
  506. return parse_one_line_basic_str(src, pos)
  507. # Literal strings
  508. if char == "'":
  509. if src.startswith("'''", pos):
  510. return parse_multiline_str(src, pos, literal=True)
  511. return parse_literal_str(src, pos)
  512. # Booleans
  513. if char == "t":
  514. if src.startswith("true", pos):
  515. return pos + 4, True
  516. if char == "f":
  517. if src.startswith("false", pos):
  518. return pos + 5, False
  519. # Arrays
  520. if char == "[":
  521. return parse_array(src, pos, parse_float)
  522. # Inline tables
  523. if char == "{":
  524. return parse_inline_table(src, pos, parse_float)
  525. # Dates and times
  526. datetime_match = RE_DATETIME.match(src, pos)
  527. if datetime_match:
  528. try:
  529. datetime_obj = match_to_datetime(datetime_match)
  530. except ValueError as e:
  531. raise suffixed_err(src, pos, "Invalid date or datetime") from e
  532. return datetime_match.end(), datetime_obj
  533. localtime_match = RE_LOCALTIME.match(src, pos)
  534. if localtime_match:
  535. return localtime_match.end(), match_to_localtime(localtime_match)
  536. # Integers and "normal" floats.
  537. # The regex will greedily match any type starting with a decimal
  538. # char, so needs to be located after handling of dates and times.
  539. number_match = RE_NUMBER.match(src, pos)
  540. if number_match:
  541. return number_match.end(), match_to_number(number_match, parse_float)
  542. # Special floats
  543. first_three = src[pos : pos + 3]
  544. if first_three in {"inf", "nan"}:
  545. return pos + 3, parse_float(first_three)
  546. first_four = src[pos : pos + 4]
  547. if first_four in {"-inf", "+inf", "-nan", "+nan"}:
  548. return pos + 4, parse_float(first_four)
  549. raise suffixed_err(src, pos, "Invalid value")
  550. def suffixed_err(src: str, pos: Pos, msg: str) -> TOMLDecodeError:
  551. """Return a `TOMLDecodeError` where error message is suffixed with
  552. coordinates in source."""
  553. def coord_repr(src: str, pos: Pos) -> str:
  554. if pos >= len(src):
  555. return "end of document"
  556. line = src.count("\n", 0, pos) + 1
  557. if line == 1:
  558. column = pos + 1
  559. else:
  560. column = pos - src.rindex("\n", 0, pos)
  561. return f"line {line}, column {column}"
  562. return TOMLDecodeError(f"{msg} (at {coord_repr(src, pos)})")
  563. def is_unicode_scalar_value(codepoint: int) -> bool:
  564. return (0 <= codepoint <= 55295) or (57344 <= codepoint <= 1114111)
  565. def make_safe_parse_float(parse_float: ParseFloat) -> ParseFloat:
  566. """A decorator to make `parse_float` safe.
  567. `parse_float` must not return dicts or lists, because these types
  568. would be mixed with parsed TOML tables and arrays, thus confusing
  569. the parser. The returned decorated callable raises `ValueError`
  570. instead of returning illegal types.
  571. """
  572. # The default `float` callable never returns illegal types. Optimize it.
  573. if parse_float is float: # type: ignore[comparison-overlap]
  574. return float
  575. def safe_parse_float(float_str: str) -> Any:
  576. float_value = parse_float(float_str)
  577. if isinstance(float_value, (dict, list)):
  578. raise ValueError("parse_float must not return dicts or lists")
  579. return float_value
  580. return safe_parse_float