_impl.py 2.8 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495
  1. from contextvars import ContextVar
  2. from typing import Optional
  3. import sys
  4. import threading
  # Context-local record of which async library is running. A value of None
  # means nothing has been recorded here; current_async_library() consults
  # this after the per-thread override and before falling back to sniffing.
  current_async_library_cvar: "ContextVar[Optional[str]]" = ContextVar(
      "current_async_library_cvar", default=None
  )
  class _ThreadLocal(threading.local):
      """Per-thread storage whose ``name`` attribute defaults to None.

      ``threading.local`` offers no explicit way to declare a default for an
      attribute, so a class-level attribute acts as the fallback: each thread
      sees ``name`` as None until that thread assigns its own value.
      """

      name: Optional[str] = None
  # Module-wide per-thread override slot; current_async_library() reads
  # thread_local.name before any other source of information.
  thread_local: _ThreadLocal = _ThreadLocal()
  class AsyncLibraryNotFoundError(RuntimeError):
      """Raised by current_async_library() when no async library is detected.

      Subclasses RuntimeError, so callers catching RuntimeError also see it.
      """
  def current_async_library() -> str:
      """Return the name of the async library driving the current code.

      Sources are consulted in order: the per-thread override
      ``thread_local.name``, then ``current_async_library_cvar``, then a check
      for a running asyncio task (only if asyncio is already imported), and
      finally curio's ``curio_running()`` probe (only if curio is imported).

      Recognized libraries:

      ================ =========== ============================
      Library          Requires    Magic string
      ================ =========== ============================
      **Trio**         Trio v0.6+  ``"trio"``
      **Curio**        -           ``"curio"``
      **asyncio**                  ``"asyncio"``
      **Trio-asyncio** v0.8.2+     ``"trio"`` or ``"asyncio"``,
                                   depending on current mode
      ================ =========== ============================

      Returns:
          A string like ``"trio"``.

      Raises:
          AsyncLibraryNotFoundError: if called from synchronous context,
              or if the current async library was not recognized.

      Examples:

          .. code-block:: python3

             from sniffio import current_async_library

             async def generic_sleep(seconds):
                 library = current_async_library()
                 if library == "trio":
                     import trio
                     await trio.sleep(seconds)
                 elif library == "asyncio":
                     import asyncio
                     await asyncio.sleep(seconds)
                 # ... and so on ...
                 else:
                     raise RuntimeError(f"Unsupported library {library!r}")
      """
      # Explicit overrides win. The context variable is only read when the
      # per-thread slot is unset; any non-None value (even "") is returned.
      override = thread_local.name
      if override is None:
          override = current_async_library_cvar.get()
      if override is not None:
          return override

      # No override recorded: sniff for asyncio, but only if something has
      # already imported it, so asyncio is never imported just to check.
      if "asyncio" in sys.modules:
          import asyncio

          # Prefer the module-level asyncio.current_task; when it is absent
          # (presumably an older Python), use the Task.current_task fallback.
          get_task = getattr(asyncio, "current_task", None)
          if get_task is None:
              get_task = asyncio.Task.current_task  # type: ignore[attr-defined]
          try:
              running_task = get_task()
          except RuntimeError:
              # NOTE(review): presumably raised when no event loop is
              # available in this thread; treated as "not asyncio".
              running_task = None
          if running_task is not None:
              return "asyncio"

      # Sniff for curio via its own probe, again only when already imported.
      if "curio" in sys.modules:
          from curio.meta import curio_running

          if curio_running():
              return "curio"

      raise AsyncLibraryNotFoundError(
          "unknown async library, or not in async context"
      )