blob: b267780f0ced73087c8670473fdf4e7fe22ff193 [file] [log] [blame]
# Olivier Deprez f4ef2d0 2021-04-20 13:36:24 +0200
# Access WeakSet through the weakref module.
# This code is separated out because it is needed
# by abc.py to load everything else at startup.
4
5from _weakref import ref
6from types import GenericAlias
7
8__all__ = ['WeakSet']
9
10
class _IterationGuard:
    """Context manager that postpones removals from a weak container.

    Entering registers the guard in the container's ``_iterating`` set, so
    weakref callbacks queue their removals instead of mutating the container.
    Leaving unregisters the guard; when it was the last one registered, the
    container's ``_commit_removals()`` is invoked to apply the queued work.
    The original authors note this should be relatively thread-safe because
    it relies only on set operations.
    """

    def __init__(self, weakcontainer):
        # Keep only a weak reference so the guard cannot form a cycle with
        # the container it protects.
        self.weakcontainer = ref(weakcontainer)

    def __enter__(self):
        container = self.weakcontainer()
        if container is not None:
            container._iterating.add(self)
        return self

    def __exit__(self, e, t, b):
        container = self.weakcontainer()
        if container is None:
            # Container already collected: nothing to unregister or commit.
            return
        active_guards = container._iterating
        active_guards.remove(self)
        if not active_guards:
            # Last iteration finished: apply removals deferred meanwhile.
            container._commit_removals()
34
35
class WeakSet:
    """Set-like container that holds only weak references to its elements.

    Each element is stored in ``self.data`` as a weak reference whose
    callback (``self._remove``) discards it once the referent is collected.
    While an iteration is in progress (``self._iterating`` is non-empty) such
    removals are queued in ``self._pending_removals`` instead, and applied
    later by ``_commit_removals``. Elements must support weak references.
    """

    def __init__(self, data=None):
        """Create an empty WeakSet, then add every element of *data* if given."""
        self.data = set()
        # Callback attached to every stored weakref. It reaches the WeakSet
        # only through a weak reference (``selfref``), so the stored refs do
        # not form a reference cycle back to this object.
        def _remove(item, selfref=ref(self)):
            self = selfref()
            if self is not None:
                if self._iterating:
                    # An iterator is walking self.data: defer the removal.
                    self._pending_removals.append(item)
                else:
                    self.data.discard(item)
        self._remove = _remove
        # Dead weakrefs whose removal was deferred during iteration.
        self._pending_removals = []
        # Guards registered by in-progress iterations (see __iter__).
        self._iterating = set()
        if data is not None:
            self.update(data)

    def _commit_removals(self):
        """Discard every weakref queued in ``_pending_removals`` from ``data``."""
        # Pop until IndexError rather than ``while l: l.pop()``: if another
        # thread drains the list between the emptiness test and the pop, the
        # check-then-act form would leak an IndexError to the caller.
        pop = self._pending_removals.pop
        discard = self.data.discard
        while True:
            try:
                item = pop()
            except IndexError:
                return
            discard(item)

    def __iter__(self):
        # The guard defers callback-driven removals until every active
        # iteration over this set has finished.
        with _IterationGuard(self):
            for itemref in self.data:
                item = itemref()
                if item is not None:
                    # Caveat: the iterator will keep a strong reference to
                    # `item` until it is resumed or closed.
                    yield item

    def __len__(self):
        # Queued removals are still physically in self.data but are dead.
        return len(self.data) - len(self._pending_removals)

    def __contains__(self, item):
        try:
            wr = ref(item)
        except TypeError:
            # Objects that cannot be weakly referenced are never members.
            return False
        # A fresh ref compares equal to a stored ref while the referent lives.
        return wr in self.data

    def __reduce__(self):
        # Rebuild from the live elements and carry over only attributes that
        # are not this class's own bookkeeping. Passing the raw __dict__ made
        # copy.copy() overwrite the new object's `data` set, `_remove`
        # callback and pending/iterating state with the original's, so the
        # "copy" silently aliased the original (and pickling always failed on
        # the local `_remove` function).
        state = getattr(self, '__dict__', None)
        if state is not None:
            state = {key: value for key, value in state.items()
                     if key not in ('data', '_remove',
                                    '_pending_removals', '_iterating')}
        return (self.__class__, (list(self),), state or None)

    def add(self, item):
        """Add *item* (TypeError if it cannot be weakly referenced)."""
        if self._pending_removals:
            self._commit_removals()
        self.data.add(ref(item, self._remove))

    def clear(self):
        """Remove every element."""
        if self._pending_removals:
            self._commit_removals()
        self.data.clear()

    def copy(self):
        """Return a new instance of this class holding the same live elements."""
        return self.__class__(self)

    def pop(self):
        """Remove and return an arbitrary live element.

        Dead references encountered on the way are dropped. Raises
        ``KeyError('pop from empty WeakSet')`` when no live element remains.
        """
        if self._pending_removals:
            self._commit_removals()
        while True:
            try:
                itemref = self.data.pop()
            except KeyError:
                raise KeyError('pop from empty WeakSet') from None
            item = itemref()
            if item is not None:
                return item

    def remove(self, item):
        """Remove *item*; KeyError if it is not a member."""
        if self._pending_removals:
            self._commit_removals()
        self.data.remove(ref(item))

    def discard(self, item):
        """Remove *item* if it is a member."""
        if self._pending_removals:
            self._commit_removals()
        self.data.discard(ref(item))

    def update(self, other):
        """Add every element of the iterable *other*."""
        if self._pending_removals:
            self._commit_removals()
        for element in other:
            self.add(element)

    def __ior__(self, other):
        self.update(other)
        return self

    def difference(self, other):
        """Return a new set with the elements not in *other*."""
        newset = self.copy()
        newset.difference_update(other)
        return newset
    __sub__ = difference

    def difference_update(self, other):
        self.__isub__(other)
    def __isub__(self, other):
        if self._pending_removals:
            self._commit_removals()
        if self is other:
            self.data.clear()
        else:
            self.data.difference_update(ref(item) for item in other)
        return self

    def intersection(self, other):
        """Return a new set with the elements common to self and *other*."""
        return self.__class__(item for item in other if item in self)
    __and__ = intersection

    def intersection_update(self, other):
        self.__iand__(other)
    def __iand__(self, other):
        if self._pending_removals:
            self._commit_removals()
        # Remove exactly the refs that are not shared instead of calling
        # set.intersection_update(): CPython's implementation can keep the
        # *argument's* elements, i.e. fresh weakrefs without the _remove
        # callback, so survivors would linger in self.data (and be counted
        # by len()) after their referents die.
        other_refs = set(map(ref, other))
        self.data.difference_update(self.data - other_refs)
        return self

    def issubset(self, other):
        return self.data.issubset(ref(item) for item in other)
    __le__ = issubset

    def __lt__(self, other):
        return self.data < set(map(ref, other))

    def issuperset(self, other):
        return self.data.issuperset(ref(item) for item in other)
    __ge__ = issuperset

    def __gt__(self, other):
        return self.data > set(map(ref, other))

    def __eq__(self, other):
        if not isinstance(other, self.__class__):
            return NotImplemented
        return self.data == set(map(ref, other))

    def symmetric_difference(self, other):
        """Return a new set with elements in exactly one of self and *other*."""
        newset = self.copy()
        newset.symmetric_difference_update(other)
        return newset
    __xor__ = symmetric_difference

    def symmetric_difference_update(self, other):
        self.__ixor__(other)
    def __ixor__(self, other):
        if self._pending_removals:
            self._commit_removals()
        if self is other:
            self.data.clear()
        else:
            # Refs added from *other* need the removal callback too.
            self.data.symmetric_difference_update(ref(item, self._remove) for item in other)
        return self

    def union(self, other):
        """Return a new set with the elements of both self and *other*."""
        return self.__class__(e for s in (self, other) for e in s)
    __or__ = union

    def isdisjoint(self, other):
        """Return True if no element of *other* is in this set."""
        # Stop at the first shared element instead of materializing the
        # whole intersection as a temporary WeakSet.
        return not any(item in self for item in other)

    def __repr__(self):
        return repr(self.data)

    __class_getitem__ = classmethod(GenericAlias)