_cmp.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150
  1. # SPDX-License-Identifier: MIT
  2. import functools
  3. import types
  4. from ._make import _make_ne
  5. _operation_names = {"eq": "==", "lt": "<", "le": "<=", "gt": ">", "ge": ">="}
  6. def cmp_using(
  7. eq=None,
  8. lt=None,
  9. le=None,
  10. gt=None,
  11. ge=None,
  12. require_same_type=True,
  13. class_name="Comparable",
  14. ):
  15. """
  16. Create a class that can be passed into `attrs.field`'s ``eq``, ``order``,
  17. and ``cmp`` arguments to customize field comparison.
  18. The resulting class will have a full set of ordering methods if at least
  19. one of ``{lt, le, gt, ge}`` and ``eq`` are provided.
  20. :param Optional[callable] eq: `callable` used to evaluate equality of two
  21. objects.
  22. :param Optional[callable] lt: `callable` used to evaluate whether one
  23. object is less than another object.
  24. :param Optional[callable] le: `callable` used to evaluate whether one
  25. object is less than or equal to another object.
  26. :param Optional[callable] gt: `callable` used to evaluate whether one
  27. object is greater than another object.
  28. :param Optional[callable] ge: `callable` used to evaluate whether one
  29. object is greater than or equal to another object.
  30. :param bool require_same_type: When `True`, equality and ordering methods
  31. will return `NotImplemented` if objects are not of the same type.
  32. :param Optional[str] class_name: Name of class. Defaults to 'Comparable'.
  33. See `comparison` for more details.
  34. .. versionadded:: 21.1.0
  35. """
  36. body = {
  37. "__slots__": ["value"],
  38. "__init__": _make_init(),
  39. "_requirements": [],
  40. "_is_comparable_to": _is_comparable_to,
  41. }
  42. # Add operations.
  43. num_order_functions = 0
  44. has_eq_function = False
  45. if eq is not None:
  46. has_eq_function = True
  47. body["__eq__"] = _make_operator("eq", eq)
  48. body["__ne__"] = _make_ne()
  49. if lt is not None:
  50. num_order_functions += 1
  51. body["__lt__"] = _make_operator("lt", lt)
  52. if le is not None:
  53. num_order_functions += 1
  54. body["__le__"] = _make_operator("le", le)
  55. if gt is not None:
  56. num_order_functions += 1
  57. body["__gt__"] = _make_operator("gt", gt)
  58. if ge is not None:
  59. num_order_functions += 1
  60. body["__ge__"] = _make_operator("ge", ge)
  61. type_ = types.new_class(
  62. class_name, (object,), {}, lambda ns: ns.update(body)
  63. )
  64. # Add same type requirement.
  65. if require_same_type:
  66. type_._requirements.append(_check_same_type)
  67. # Add total ordering if at least one operation was defined.
  68. if 0 < num_order_functions < 4:
  69. if not has_eq_function:
  70. # functools.total_ordering requires __eq__ to be defined,
  71. # so raise early error here to keep a nice stack.
  72. msg = "eq must be define is order to complete ordering from lt, le, gt, ge."
  73. raise ValueError(msg)
  74. type_ = functools.total_ordering(type_)
  75. return type_
  76. def _make_init():
  77. """
  78. Create __init__ method.
  79. """
  80. def __init__(self, value):
  81. """
  82. Initialize object with *value*.
  83. """
  84. self.value = value
  85. return __init__
  86. def _make_operator(name, func):
  87. """
  88. Create operator method.
  89. """
  90. def method(self, other):
  91. if not self._is_comparable_to(other):
  92. return NotImplemented
  93. result = func(self.value, other.value)
  94. if result is NotImplemented:
  95. return NotImplemented
  96. return result
  97. method.__name__ = f"__{name}__"
  98. method.__doc__ = (
  99. f"Return a {_operation_names[name]} b. Computed by attrs."
  100. )
  101. return method
  102. def _is_comparable_to(self, other):
  103. """
  104. Check whether `other` is comparable to `self`.
  105. """
  106. return all(func(self, other) for func in self._requirements)
  107. def _check_same_type(self, other):
  108. """
  109. Return True if *self* and *other* are of the same type, False otherwise.
  110. """
  111. return other.value.__class__ is self.value.__class__