structs.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165
  1. import itertools
  2. from .compat import collections_abc
  3. class DirectedGraph(object):
  4. """A graph structure with directed edges."""
  5. def __init__(self):
  6. self._vertices = set()
  7. self._forwards = {} # <key> -> Set[<key>]
  8. self._backwards = {} # <key> -> Set[<key>]
  9. def __iter__(self):
  10. return iter(self._vertices)
  11. def __len__(self):
  12. return len(self._vertices)
  13. def __contains__(self, key):
  14. return key in self._vertices
  15. def copy(self):
  16. """Return a shallow copy of this graph."""
  17. other = DirectedGraph()
  18. other._vertices = set(self._vertices)
  19. other._forwards = {k: set(v) for k, v in self._forwards.items()}
  20. other._backwards = {k: set(v) for k, v in self._backwards.items()}
  21. return other
  22. def add(self, key):
  23. """Add a new vertex to the graph."""
  24. if key in self._vertices:
  25. raise ValueError("vertex exists")
  26. self._vertices.add(key)
  27. self._forwards[key] = set()
  28. self._backwards[key] = set()
  29. def remove(self, key):
  30. """Remove a vertex from the graph, disconnecting all edges from/to it."""
  31. self._vertices.remove(key)
  32. for f in self._forwards.pop(key):
  33. self._backwards[f].remove(key)
  34. for t in self._backwards.pop(key):
  35. self._forwards[t].remove(key)
  36. def connected(self, f, t):
  37. return f in self._backwards[t] and t in self._forwards[f]
  38. def connect(self, f, t):
  39. """Connect two existing vertices.
  40. Nothing happens if the vertices are already connected.
  41. """
  42. if t not in self._vertices:
  43. raise KeyError(t)
  44. self._forwards[f].add(t)
  45. self._backwards[t].add(f)
  46. def iter_edges(self):
  47. for f, children in self._forwards.items():
  48. for t in children:
  49. yield f, t
  50. def iter_children(self, key):
  51. return iter(self._forwards[key])
  52. def iter_parents(self, key):
  53. return iter(self._backwards[key])
  54. class IteratorMapping(collections_abc.Mapping):
  55. def __init__(self, mapping, accessor, appends=None):
  56. self._mapping = mapping
  57. self._accessor = accessor
  58. self._appends = appends or {}
  59. def __repr__(self):
  60. return "IteratorMapping({!r}, {!r}, {!r})".format(
  61. self._mapping,
  62. self._accessor,
  63. self._appends,
  64. )
  65. def __bool__(self):
  66. return bool(self._mapping or self._appends)
  67. __nonzero__ = __bool__ # XXX: Python 2.
  68. def __contains__(self, key):
  69. return key in self._mapping or key in self._appends
  70. def __getitem__(self, k):
  71. try:
  72. v = self._mapping[k]
  73. except KeyError:
  74. return iter(self._appends[k])
  75. return itertools.chain(self._accessor(v), self._appends.get(k, ()))
  76. def __iter__(self):
  77. more = (k for k in self._appends if k not in self._mapping)
  78. return itertools.chain(self._mapping, more)
  79. def __len__(self):
  80. more = sum(1 for k in self._appends if k not in self._mapping)
  81. return len(self._mapping) + more
  82. class _FactoryIterableView(object):
  83. """Wrap an iterator factory returned by `find_matches()`.
  84. Calling `iter()` on this class would invoke the underlying iterator
  85. factory, making it a "collection with ordering" that can be iterated
  86. through multiple times, but lacks random access methods presented in
  87. built-in Python sequence types.
  88. """
  89. def __init__(self, factory):
  90. self._factory = factory
  91. def __repr__(self):
  92. return "{}({})".format(type(self).__name__, list(self._factory()))
  93. def __bool__(self):
  94. try:
  95. next(self._factory())
  96. except StopIteration:
  97. return False
  98. return True
  99. __nonzero__ = __bool__ # XXX: Python 2.
  100. def __iter__(self):
  101. return self._factory()
  102. class _SequenceIterableView(object):
  103. """Wrap an iterable returned by find_matches().
  104. This is essentially just a proxy to the underlying sequence that provides
  105. the same interface as `_FactoryIterableView`.
  106. """
  107. def __init__(self, sequence):
  108. self._sequence = sequence
  109. def __repr__(self):
  110. return "{}({})".format(type(self).__name__, self._sequence)
  111. def __bool__(self):
  112. return bool(self._sequence)
  113. __nonzero__ = __bool__ # XXX: Python 2.
  114. def __iter__(self):
  115. return iter(self._sequence)
  116. def build_iter_view(matches):
  117. """Build an iterable view from the value returned by `find_matches()`."""
  118. if callable(matches):
  119. return _FactoryIterableView(matches)
  120. if not isinstance(matches, collections_abc.Sequence):
  121. matches = list(matches)
  122. return _SequenceIterableView(matches)