_reloader.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458
  1. from __future__ import annotations
  2. import fnmatch
  3. import os
  4. import subprocess
  5. import sys
  6. import threading
  7. import time
  8. import typing as t
  9. from itertools import chain
  10. from pathlib import PurePath
  11. from ._internal import _log
  12. # The various system prefixes where imports are found. Base values are
  13. # different when running in a virtualenv. All reloaders will ignore the
  14. # base paths (usually the system installation). The stat reloader won't
  15. # scan the virtualenv paths, it will only include modules that are
  16. # already imported.
  17. _ignore_always = tuple({sys.base_prefix, sys.base_exec_prefix})
  18. prefix = {*_ignore_always, sys.prefix, sys.exec_prefix}
  19. if hasattr(sys, "real_prefix"):
  20. # virtualenv < 20
  21. prefix.add(sys.real_prefix)
  22. _stat_ignore_scan = tuple(prefix)
  23. del prefix
  24. _ignore_common_dirs = {
  25. "__pycache__",
  26. ".git",
  27. ".hg",
  28. ".tox",
  29. ".nox",
  30. ".pytest_cache",
  31. ".mypy_cache",
  32. }
  33. def _iter_module_paths() -> t.Iterator[str]:
  34. """Find the filesystem paths associated with imported modules."""
  35. # List is in case the value is modified by the app while updating.
  36. for module in list(sys.modules.values()):
  37. name = getattr(module, "__file__", None)
  38. if name is None or name.startswith(_ignore_always):
  39. continue
  40. while not os.path.isfile(name):
  41. # Zip file, find the base file without the module path.
  42. old = name
  43. name = os.path.dirname(name)
  44. if name == old: # skip if it was all directories somehow
  45. break
  46. else:
  47. yield name
  48. def _remove_by_pattern(paths: set[str], exclude_patterns: set[str]) -> None:
  49. for pattern in exclude_patterns:
  50. paths.difference_update(fnmatch.filter(paths, pattern))
  51. def _find_stat_paths(
  52. extra_files: set[str], exclude_patterns: set[str]
  53. ) -> t.Iterable[str]:
  54. """Find paths for the stat reloader to watch. Returns imported
  55. module files, Python files under non-system paths. Extra files and
  56. Python files under extra directories can also be scanned.
  57. System paths have to be excluded for efficiency. Non-system paths,
  58. such as a project root or ``sys.path.insert``, should be the paths
  59. of interest to the user anyway.
  60. """
  61. paths = set()
  62. for path in chain(list(sys.path), extra_files):
  63. path = os.path.abspath(path)
  64. if os.path.isfile(path):
  65. # zip file on sys.path, or extra file
  66. paths.add(path)
  67. continue
  68. parent_has_py = {os.path.dirname(path): True}
  69. for root, dirs, files in os.walk(path):
  70. # Optimizations: ignore system prefixes, __pycache__ will
  71. # have a py or pyc module at the import path, ignore some
  72. # common known dirs such as version control and tool caches.
  73. if (
  74. root.startswith(_stat_ignore_scan)
  75. or os.path.basename(root) in _ignore_common_dirs
  76. ):
  77. dirs.clear()
  78. continue
  79. has_py = False
  80. for name in files:
  81. if name.endswith((".py", ".pyc")):
  82. has_py = True
  83. paths.add(os.path.join(root, name))
  84. # Optimization: stop scanning a directory if neither it nor
  85. # its parent contained Python files.
  86. if not (has_py or parent_has_py[os.path.dirname(root)]):
  87. dirs.clear()
  88. continue
  89. parent_has_py[root] = has_py
  90. paths.update(_iter_module_paths())
  91. _remove_by_pattern(paths, exclude_patterns)
  92. return paths
  93. def _find_watchdog_paths(
  94. extra_files: set[str], exclude_patterns: set[str]
  95. ) -> t.Iterable[str]:
  96. """Find paths for the stat reloader to watch. Looks at the same
  97. sources as the stat reloader, but watches everything under
  98. directories instead of individual files.
  99. """
  100. dirs = set()
  101. for name in chain(list(sys.path), extra_files):
  102. name = os.path.abspath(name)
  103. if os.path.isfile(name):
  104. name = os.path.dirname(name)
  105. dirs.add(name)
  106. for name in _iter_module_paths():
  107. dirs.add(os.path.dirname(name))
  108. _remove_by_pattern(dirs, exclude_patterns)
  109. return _find_common_roots(dirs)
  110. def _find_common_roots(paths: t.Iterable[str]) -> t.Iterable[str]:
  111. root: dict[str, dict] = {}
  112. for chunks in sorted((PurePath(x).parts for x in paths), key=len, reverse=True):
  113. node = root
  114. for chunk in chunks:
  115. node = node.setdefault(chunk, {})
  116. node.clear()
  117. rv = set()
  118. def _walk(node: t.Mapping[str, dict], path: tuple[str, ...]) -> None:
  119. for prefix, child in node.items():
  120. _walk(child, path + (prefix,))
  121. if not node:
  122. rv.add(os.path.join(*path))
  123. _walk(root, ())
  124. return rv
  125. def _get_args_for_reloading() -> list[str]:
  126. """Determine how the script was executed, and return the args needed
  127. to execute it again in a new process.
  128. """
  129. if sys.version_info >= (3, 10):
  130. # sys.orig_argv, added in Python 3.10, contains the exact args used to invoke
  131. # Python. Still replace argv[0] with sys.executable for accuracy.
  132. return [sys.executable, *sys.orig_argv[1:]]
  133. rv = [sys.executable]
  134. py_script = sys.argv[0]
  135. args = sys.argv[1:]
  136. # Need to look at main module to determine how it was executed.
  137. __main__ = sys.modules["__main__"]
  138. # The value of __package__ indicates how Python was called. It may
  139. # not exist if a setuptools script is installed as an egg. It may be
  140. # set incorrectly for entry points created with pip on Windows.
  141. if getattr(__main__, "__package__", None) is None or (
  142. os.name == "nt"
  143. and __main__.__package__ == ""
  144. and not os.path.exists(py_script)
  145. and os.path.exists(f"{py_script}.exe")
  146. ):
  147. # Executed a file, like "python app.py".
  148. py_script = os.path.abspath(py_script)
  149. if os.name == "nt":
  150. # Windows entry points have ".exe" extension and should be
  151. # called directly.
  152. if not os.path.exists(py_script) and os.path.exists(f"{py_script}.exe"):
  153. py_script += ".exe"
  154. if (
  155. os.path.splitext(sys.executable)[1] == ".exe"
  156. and os.path.splitext(py_script)[1] == ".exe"
  157. ):
  158. rv.pop(0)
  159. rv.append(py_script)
  160. else:
  161. # Executed a module, like "python -m werkzeug.serving".
  162. if os.path.isfile(py_script):
  163. # Rewritten by Python from "-m script" to "/path/to/script.py".
  164. py_module = t.cast(str, __main__.__package__)
  165. name = os.path.splitext(os.path.basename(py_script))[0]
  166. if name != "__main__":
  167. py_module += f".{name}"
  168. else:
  169. # Incorrectly rewritten by pydevd debugger from "-m script" to "script".
  170. py_module = py_script
  171. rv.extend(("-m", py_module.lstrip(".")))
  172. rv.extend(args)
  173. return rv
  174. class ReloaderLoop:
  175. name = ""
  176. def __init__(
  177. self,
  178. extra_files: t.Iterable[str] | None = None,
  179. exclude_patterns: t.Iterable[str] | None = None,
  180. interval: int | float = 1,
  181. ) -> None:
  182. self.extra_files: set[str] = {os.path.abspath(x) for x in extra_files or ()}
  183. self.exclude_patterns: set[str] = set(exclude_patterns or ())
  184. self.interval = interval
  185. def __enter__(self) -> ReloaderLoop:
  186. """Do any setup, then run one step of the watch to populate the
  187. initial filesystem state.
  188. """
  189. self.run_step()
  190. return self
  191. def __exit__(self, exc_type, exc_val, exc_tb): # type: ignore
  192. """Clean up any resources associated with the reloader."""
  193. pass
  194. def run(self) -> None:
  195. """Continually run the watch step, sleeping for the configured
  196. interval after each step.
  197. """
  198. while True:
  199. self.run_step()
  200. time.sleep(self.interval)
  201. def run_step(self) -> None:
  202. """Run one step for watching the filesystem. Called once to set
  203. up initial state, then repeatedly to update it.
  204. """
  205. pass
  206. def restart_with_reloader(self) -> int:
  207. """Spawn a new Python interpreter with the same arguments as the
  208. current one, but running the reloader thread.
  209. """
  210. while True:
  211. _log("info", f" * Restarting with {self.name}")
  212. args = _get_args_for_reloading()
  213. new_environ = os.environ.copy()
  214. new_environ["WERKZEUG_RUN_MAIN"] = "true"
  215. exit_code = subprocess.call(args, env=new_environ, close_fds=False)
  216. if exit_code != 3:
  217. return exit_code
  218. def trigger_reload(self, filename: str) -> None:
  219. self.log_reload(filename)
  220. sys.exit(3)
  221. def log_reload(self, filename: str) -> None:
  222. filename = os.path.abspath(filename)
  223. _log("info", f" * Detected change in {filename!r}, reloading")
  224. class StatReloaderLoop(ReloaderLoop):
  225. name = "stat"
  226. def __enter__(self) -> ReloaderLoop:
  227. self.mtimes: dict[str, float] = {}
  228. return super().__enter__()
  229. def run_step(self) -> None:
  230. for name in _find_stat_paths(self.extra_files, self.exclude_patterns):
  231. try:
  232. mtime = os.stat(name).st_mtime
  233. except OSError:
  234. continue
  235. old_time = self.mtimes.get(name)
  236. if old_time is None:
  237. self.mtimes[name] = mtime
  238. continue
  239. if mtime > old_time:
  240. self.trigger_reload(name)
  241. class WatchdogReloaderLoop(ReloaderLoop):
  242. def __init__(self, *args: t.Any, **kwargs: t.Any) -> None:
  243. from watchdog.observers import Observer
  244. from watchdog.events import PatternMatchingEventHandler
  245. from watchdog.events import EVENT_TYPE_OPENED
  246. from watchdog.events import FileModifiedEvent
  247. super().__init__(*args, **kwargs)
  248. trigger_reload = self.trigger_reload
  249. class EventHandler(PatternMatchingEventHandler):
  250. def on_any_event(self, event: FileModifiedEvent): # type: ignore
  251. if event.event_type == EVENT_TYPE_OPENED:
  252. return
  253. trigger_reload(event.src_path)
  254. reloader_name = Observer.__name__.lower() # type: ignore[attr-defined]
  255. if reloader_name.endswith("observer"):
  256. reloader_name = reloader_name[:-8]
  257. self.name = f"watchdog ({reloader_name})"
  258. self.observer = Observer()
  259. # Extra patterns can be non-Python files, match them in addition
  260. # to all Python files in default and extra directories. Ignore
  261. # __pycache__ since a change there will always have a change to
  262. # the source file (or initial pyc file) as well. Ignore Git and
  263. # Mercurial internal changes.
  264. extra_patterns = [p for p in self.extra_files if not os.path.isdir(p)]
  265. self.event_handler = EventHandler(
  266. patterns=["*.py", "*.pyc", "*.zip", *extra_patterns],
  267. ignore_patterns=[
  268. *[f"*/{d}/*" for d in _ignore_common_dirs],
  269. *self.exclude_patterns,
  270. ],
  271. )
  272. self.should_reload = False
  273. def trigger_reload(self, filename: str) -> None:
  274. # This is called inside an event handler, which means throwing
  275. # SystemExit has no effect.
  276. # https://github.com/gorakhargosh/watchdog/issues/294
  277. self.should_reload = True
  278. self.log_reload(filename)
  279. def __enter__(self) -> ReloaderLoop:
  280. self.watches: dict[str, t.Any] = {}
  281. self.observer.start()
  282. return super().__enter__()
  283. def __exit__(self, exc_type, exc_val, exc_tb): # type: ignore
  284. self.observer.stop()
  285. self.observer.join()
  286. def run(self) -> None:
  287. while not self.should_reload:
  288. self.run_step()
  289. time.sleep(self.interval)
  290. sys.exit(3)
  291. def run_step(self) -> None:
  292. to_delete = set(self.watches)
  293. for path in _find_watchdog_paths(self.extra_files, self.exclude_patterns):
  294. if path not in self.watches:
  295. try:
  296. self.watches[path] = self.observer.schedule(
  297. self.event_handler, path, recursive=True
  298. )
  299. except OSError:
  300. # Clear this path from list of watches We don't want
  301. # the same error message showing again in the next
  302. # iteration.
  303. self.watches[path] = None
  304. to_delete.discard(path)
  305. for path in to_delete:
  306. watch = self.watches.pop(path, None)
  307. if watch is not None:
  308. self.observer.unschedule(watch)
  309. reloader_loops: dict[str, type[ReloaderLoop]] = {
  310. "stat": StatReloaderLoop,
  311. "watchdog": WatchdogReloaderLoop,
  312. }
  313. try:
  314. __import__("watchdog.observers")
  315. except ImportError:
  316. reloader_loops["auto"] = reloader_loops["stat"]
  317. else:
  318. reloader_loops["auto"] = reloader_loops["watchdog"]
  319. def ensure_echo_on() -> None:
  320. """Ensure that echo mode is enabled. Some tools such as PDB disable
  321. it which causes usability issues after a reload."""
  322. # tcgetattr will fail if stdin isn't a tty
  323. if sys.stdin is None or not sys.stdin.isatty():
  324. return
  325. try:
  326. import termios
  327. except ImportError:
  328. return
  329. attributes = termios.tcgetattr(sys.stdin)
  330. if not attributes[3] & termios.ECHO:
  331. attributes[3] |= termios.ECHO
  332. termios.tcsetattr(sys.stdin, termios.TCSANOW, attributes)
  333. def run_with_reloader(
  334. main_func: t.Callable[[], None],
  335. extra_files: t.Iterable[str] | None = None,
  336. exclude_patterns: t.Iterable[str] | None = None,
  337. interval: int | float = 1,
  338. reloader_type: str = "auto",
  339. ) -> None:
  340. """Run the given function in an independent Python interpreter."""
  341. import signal
  342. signal.signal(signal.SIGTERM, lambda *args: sys.exit(0))
  343. reloader = reloader_loops[reloader_type](
  344. extra_files=extra_files, exclude_patterns=exclude_patterns, interval=interval
  345. )
  346. try:
  347. if os.environ.get("WERKZEUG_RUN_MAIN") == "true":
  348. ensure_echo_on()
  349. t = threading.Thread(target=main_func, args=())
  350. t.daemon = True
  351. # Enter the reloader to set up initial state, then start
  352. # the app thread and reloader update loop.
  353. with reloader:
  354. t.start()
  355. reloader.run()
  356. else:
  357. sys.exit(reloader.restart_with_reloader())
  358. except KeyboardInterrupt:
  359. pass