poset.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155
  1. # This code is part of Grandalf
  2. # Copyright (C) 2008-2015 Axel Tillequin (bdcht3@gmail.com) and others
  3. # published under GPLv2 license or EPLv1 license
  4. # Contributor(s): Axel Tillequin
  5. from collections import OrderedDict
  6. __all__ = ["Poset"]
  7. # ------------------------------------------------------------------------------
  8. # Poset class implements a set but allows to interate over the elements in a
  9. # deterministic way and to get specific objects in the set.
  10. # Membership operator defaults to comparing __hash__ of objects but Poset
  11. # allows to check for __cmp__/__eq__ membership by using contains__cmp__(obj)
  12. class Poset(object):
  13. def __init__(self, L):
  14. self.o = OrderedDict()
  15. for obj in L:
  16. self.add(obj)
  17. def __repr__(self):
  18. return "Poset(%r)" % (self.o,)
  19. def __str__(self):
  20. f = "%%%dd" % len(str(len(self.o)))
  21. s = []
  22. for i, x in enumerate(self.o.values()):
  23. s.append(f % i + ".| %s" % repr(x))
  24. return "\n".join(s)
  25. def add(self, obj):
  26. if obj in self:
  27. return self.get(obj)
  28. else:
  29. self.o[obj] = obj
  30. return obj
  31. def remove(self, obj):
  32. if obj in self:
  33. obj = self.get(obj)
  34. del self.o[obj]
  35. return obj
  36. return None
  37. def index(self, obj):
  38. return list(self.o.values()).index(obj)
  39. def get(self, obj):
  40. return self.o.get(obj, None)
  41. def __getitem__(self, i):
  42. return list(self.o.values())[i]
  43. def __len__(self):
  44. return len(self.o)
  45. def __iter__(self):
  46. for obj in iter(self.o.values()):
  47. yield obj
  48. def __cmp__(self, other):
  49. s1 = set(other.o.values())
  50. s2 = set(self.o.values())
  51. return cmp(s1, s2)
  52. def __eq__(self, other):
  53. s1 = set(other.o.values())
  54. s2 = set(self.o.values())
  55. return s1 == s2
  56. def __ne__(self, other):
  57. s1 = set(other.o.values())
  58. s2 = set(self.o.values())
  59. return s1 != s2
  60. def copy(self):
  61. return Poset(self.o.values())
  62. __copy__ = copy
  63. def deepcopy(self):
  64. from copy import deepcopy
  65. L = deepcopy(self.o.values())
  66. return Poset(L)
  67. def __or__(self, other):
  68. return self.union(other)
  69. def union(self, other):
  70. p = Poset([])
  71. p.o.update(self.o)
  72. p.o.update(other.o)
  73. return p
  74. def update(self, other):
  75. self.o.update(other.o)
  76. def __and__(self, other):
  77. s1 = set(self.o.values())
  78. s2 = set(other.o.values())
  79. return Poset(s1.intersection(s2))
  80. def intersection(self, *args):
  81. p = self
  82. for other in args:
  83. p = p & other
  84. return p
  85. def __xor__(self, other):
  86. s1 = set(self.o.values())
  87. s2 = set(other.o.values())
  88. return Poset(s1.symmetric_difference(s2))
  89. def symmetric_difference(self, *args):
  90. p = self
  91. for other in args:
  92. p = p ^ other
  93. return p
  94. def __sub__(self, other):
  95. s1 = set(self.o.values())
  96. s2 = set(other.o.values())
  97. return Poset(s1.difference(s2))
  98. def difference(self, *args):
  99. p = self
  100. for other in args:
  101. p = p - other
  102. return p
  103. def __contains__(self, obj):
  104. return obj in self.o
  105. def contains__cmp__(self, obj):
  106. return obj in self.o.values()
  107. def issubset(self, other):
  108. s1 = set(self.o.values())
  109. s2 = set(other.o.values())
  110. return s1.issubset(s2)
  111. def issuperset(self, other):
  112. s1 = set(self.o.values())
  113. s2 = set(other.o.values())
  114. return s1.issuperset(s2)
  115. __le__ = issubset
  116. __ge__ = issuperset
  117. def __lt__(self, other):
  118. return self <= other and len(self) != len(other)
  119. def __gt__(self, other):
  120. return self >= other and len(self) != len(other)