blob: 0a7dcfe226315c00c08e3e4a95250128214181f0 [file] [log] [blame]
//===- llvm/ADT/CoalescingBitVector.h - A coalescing bitvector --*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
///
/// \file
/// A bitvector that uses an IntervalMap to coalesce adjacent elements
/// into intervals.
///
//===----------------------------------------------------------------------===//
13
14#ifndef LLVM_ADT_COALESCINGBITVECTOR_H
15#define LLVM_ADT_COALESCINGBITVECTOR_H
16
#include "llvm/ADT/IntervalMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/iterator_range.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <initializer_list>
#include <iterator>
#include <tuple>
#include <type_traits>
#include <utility>
25
26namespace llvm {
27
28/// A bitvector that, under the hood, relies on an IntervalMap to coalesce
29/// elements into intervals. Good for representing sets which predominantly
30/// contain contiguous ranges. Bad for representing sets with lots of gaps
31/// between elements.
32///
33/// Compared to SparseBitVector, CoalescingBitVector offers more predictable
34/// performance for non-sequential find() operations.
35///
36/// \tparam IndexT - The type of the index into the bitvector.
37template <typename IndexT> class CoalescingBitVector {
38 static_assert(std::is_unsigned<IndexT>::value,
39 "Index must be an unsigned integer.");
40
41 using ThisT = CoalescingBitVector<IndexT>;
42
43 /// An interval map for closed integer ranges. The mapped values are unused.
44 using MapT = IntervalMap<IndexT, char>;
45
46 using UnderlyingIterator = typename MapT::const_iterator;
47
48 using IntervalT = std::pair<IndexT, IndexT>;
49
50public:
51 using Allocator = typename MapT::Allocator;
52
53 /// Construct by passing in a CoalescingBitVector<IndexT>::Allocator
54 /// reference.
55 CoalescingBitVector(Allocator &Alloc)
56 : Alloc(&Alloc), Intervals(Alloc) {}
57
58 /// \name Copy/move constructors and assignment operators.
59 /// @{
60
61 CoalescingBitVector(const ThisT &Other)
62 : Alloc(Other.Alloc), Intervals(*Other.Alloc) {
63 set(Other);
64 }
65
66 ThisT &operator=(const ThisT &Other) {
67 clear();
68 set(Other);
69 return *this;
70 }
71
72 CoalescingBitVector(ThisT &&Other) = delete;
73 ThisT &operator=(ThisT &&Other) = delete;
74
75 /// @}
76
77 /// Clear all the bits.
78 void clear() { Intervals.clear(); }
79
80 /// Check whether no bits are set.
81 bool empty() const { return Intervals.empty(); }
82
83 /// Count the number of set bits.
84 unsigned count() const {
85 unsigned Bits = 0;
86 for (auto It = Intervals.begin(), End = Intervals.end(); It != End; ++It)
87 Bits += 1 + It.stop() - It.start();
88 return Bits;
89 }
90
91 /// Set the bit at \p Index.
92 ///
93 /// This method does /not/ support setting a bit that has already been set,
94 /// for efficiency reasons. If possible, restructure your code to not set the
95 /// same bit multiple times, or use \ref test_and_set.
96 void set(IndexT Index) {
97 assert(!test(Index) && "Setting already-set bits not supported/efficient, "
98 "IntervalMap will assert");
99 insert(Index, Index);
100 }
101
102 /// Set the bits set in \p Other.
103 ///
104 /// This method does /not/ support setting already-set bits, see \ref set
105 /// for the rationale. For a safe set union operation, use \ref operator|=.
106 void set(const ThisT &Other) {
107 for (auto It = Other.Intervals.begin(), End = Other.Intervals.end();
108 It != End; ++It)
109 insert(It.start(), It.stop());
110 }
111
112 /// Set the bits at \p Indices. Used for testing, primarily.
113 void set(std::initializer_list<IndexT> Indices) {
114 for (IndexT Index : Indices)
115 set(Index);
116 }
117
118 /// Check whether the bit at \p Index is set.
119 bool test(IndexT Index) const {
120 const auto It = Intervals.find(Index);
121 if (It == Intervals.end())
122 return false;
123 assert(It.stop() >= Index && "Interval must end after Index");
124 return It.start() <= Index;
125 }
126
127 /// Set the bit at \p Index. Supports setting an already-set bit.
128 void test_and_set(IndexT Index) {
129 if (!test(Index))
130 set(Index);
131 }
132
133 /// Reset the bit at \p Index. Supports resetting an already-unset bit.
134 void reset(IndexT Index) {
135 auto It = Intervals.find(Index);
136 if (It == Intervals.end())
137 return;
138
139 // Split the interval containing Index into up to two parts: one from
140 // [Start, Index-1] and another from [Index+1, Stop]. If Index is equal to
141 // either Start or Stop, we create one new interval. If Index is equal to
142 // both Start and Stop, we simply erase the existing interval.
143 IndexT Start = It.start();
144 if (Index < Start)
145 // The index was not set.
146 return;
147 IndexT Stop = It.stop();
148 assert(Index <= Stop && "Wrong interval for index");
149 It.erase();
150 if (Start < Index)
151 insert(Start, Index - 1);
152 if (Index < Stop)
153 insert(Index + 1, Stop);
154 }
155
156 /// Set union. If \p RHS is guaranteed to not overlap with this, \ref set may
157 /// be a faster alternative.
158 void operator|=(const ThisT &RHS) {
159 // Get the overlaps between the two interval maps.
160 SmallVector<IntervalT, 8> Overlaps;
161 getOverlaps(RHS, Overlaps);
162
163 // Insert the non-overlapping parts of all the intervals from RHS.
164 for (auto It = RHS.Intervals.begin(), End = RHS.Intervals.end();
165 It != End; ++It) {
166 IndexT Start = It.start();
167 IndexT Stop = It.stop();
168 SmallVector<IntervalT, 8> NonOverlappingParts;
169 getNonOverlappingParts(Start, Stop, Overlaps, NonOverlappingParts);
170 for (IntervalT AdditivePortion : NonOverlappingParts)
171 insert(AdditivePortion.first, AdditivePortion.second);
172 }
173 }
174
175 /// Set intersection.
176 void operator&=(const ThisT &RHS) {
177 // Get the overlaps between the two interval maps (i.e. the intersection).
178 SmallVector<IntervalT, 8> Overlaps;
179 getOverlaps(RHS, Overlaps);
180 // Rebuild the interval map, including only the overlaps.
181 clear();
182 for (IntervalT Overlap : Overlaps)
183 insert(Overlap.first, Overlap.second);
184 }
185
186 /// Reset all bits present in \p Other.
187 void intersectWithComplement(const ThisT &Other) {
188 SmallVector<IntervalT, 8> Overlaps;
189 if (!getOverlaps(Other, Overlaps)) {
190 // If there is no overlap with Other, the intersection is empty.
191 return;
192 }
193
194 // Delete the overlapping intervals. Split up intervals that only partially
195 // intersect an overlap.
196 for (IntervalT Overlap : Overlaps) {
197 IndexT OlapStart, OlapStop;
198 std::tie(OlapStart, OlapStop) = Overlap;
199
200 auto It = Intervals.find(OlapStart);
201 IndexT CurrStart = It.start();
202 IndexT CurrStop = It.stop();
203 assert(CurrStart <= OlapStart && OlapStop <= CurrStop &&
204 "Expected some intersection!");
205
206 // Split the overlap interval into up to two parts: one from [CurrStart,
207 // OlapStart-1] and another from [OlapStop+1, CurrStop]. If OlapStart is
208 // equal to CurrStart, the first split interval is unnecessary. Ditto for
209 // when OlapStop is equal to CurrStop, we omit the second split interval.
210 It.erase();
211 if (CurrStart < OlapStart)
212 insert(CurrStart, OlapStart - 1);
213 if (OlapStop < CurrStop)
214 insert(OlapStop + 1, CurrStop);
215 }
216 }
217
218 bool operator==(const ThisT &RHS) const {
219 // We cannot just use std::equal because it checks the dereferenced values
220 // of an iterator pair for equality, not the iterators themselves. In our
221 // case that results in comparison of the (unused) IntervalMap values.
222 auto ItL = Intervals.begin();
223 auto ItR = RHS.Intervals.begin();
224 while (ItL != Intervals.end() && ItR != RHS.Intervals.end() &&
225 ItL.start() == ItR.start() && ItL.stop() == ItR.stop()) {
226 ++ItL;
227 ++ItR;
228 }
229 return ItL == Intervals.end() && ItR == RHS.Intervals.end();
230 }
231
232 bool operator!=(const ThisT &RHS) const { return !operator==(RHS); }
233
234 class const_iterator
235 : public std::iterator<std::forward_iterator_tag, IndexT> {
236 friend class CoalescingBitVector;
237
238 // For performance reasons, make the offset at the end different than the
239 // one used in \ref begin, to optimize the common `It == end()` pattern.
240 static constexpr unsigned kIteratorAtTheEndOffset = ~0u;
241
242 UnderlyingIterator MapIterator;
243 unsigned OffsetIntoMapIterator = 0;
244
245 // Querying the start/stop of an IntervalMap iterator can be very expensive.
246 // Cache these values for performance reasons.
247 IndexT CachedStart = IndexT();
248 IndexT CachedStop = IndexT();
249
250 void setToEnd() {
251 OffsetIntoMapIterator = kIteratorAtTheEndOffset;
252 CachedStart = IndexT();
253 CachedStop = IndexT();
254 }
255
256 /// MapIterator has just changed, reset the cached state to point to the
257 /// start of the new underlying iterator.
258 void resetCache() {
259 if (MapIterator.valid()) {
260 OffsetIntoMapIterator = 0;
261 CachedStart = MapIterator.start();
262 CachedStop = MapIterator.stop();
263 } else {
264 setToEnd();
265 }
266 }
267
268 /// Advance the iterator to \p Index, if it is contained within the current
269 /// interval. The public-facing method which supports advancing past the
270 /// current interval is \ref advanceToLowerBound.
271 void advanceTo(IndexT Index) {
272 assert(Index <= CachedStop && "Cannot advance to OOB index");
273 if (Index < CachedStart)
274 // We're already past this index.
275 return;
276 OffsetIntoMapIterator = Index - CachedStart;
277 }
278
279 const_iterator(UnderlyingIterator MapIt) : MapIterator(MapIt) {
280 resetCache();
281 }
282
283 public:
284 const_iterator() { setToEnd(); }
285
286 bool operator==(const const_iterator &RHS) const {
287 // Do /not/ compare MapIterator for equality, as this is very expensive.
288 // The cached start/stop values make that check unnecessary.
289 return std::tie(OffsetIntoMapIterator, CachedStart, CachedStop) ==
290 std::tie(RHS.OffsetIntoMapIterator, RHS.CachedStart,
291 RHS.CachedStop);
292 }
293
294 bool operator!=(const const_iterator &RHS) const {
295 return !operator==(RHS);
296 }
297
298 IndexT operator*() const { return CachedStart + OffsetIntoMapIterator; }
299
300 const_iterator &operator++() { // Pre-increment (++It).
301 if (CachedStart + OffsetIntoMapIterator < CachedStop) {
302 // Keep going within the current interval.
303 ++OffsetIntoMapIterator;
304 } else {
305 // We reached the end of the current interval: advance.
306 ++MapIterator;
307 resetCache();
308 }
309 return *this;
310 }
311
312 const_iterator operator++(int) { // Post-increment (It++).
313 const_iterator tmp = *this;
314 operator++();
315 return tmp;
316 }
317
318 /// Advance the iterator to the first set bit AT, OR AFTER, \p Index. If
319 /// no such set bit exists, advance to end(). This is like std::lower_bound.
320 /// This is useful if \p Index is close to the current iterator position.
321 /// However, unlike \ref find(), this has worst-case O(n) performance.
322 void advanceToLowerBound(IndexT Index) {
323 if (OffsetIntoMapIterator == kIteratorAtTheEndOffset)
324 return;
325
326 // Advance to the first interval containing (or past) Index, or to end().
327 while (Index > CachedStop) {
328 ++MapIterator;
329 resetCache();
330 if (OffsetIntoMapIterator == kIteratorAtTheEndOffset)
331 return;
332 }
333
334 advanceTo(Index);
335 }
336 };
337
338 const_iterator begin() const { return const_iterator(Intervals.begin()); }
339
340 const_iterator end() const { return const_iterator(); }
341
342 /// Return an iterator pointing to the first set bit AT, OR AFTER, \p Index.
343 /// If no such set bit exists, return end(). This is like std::lower_bound.
344 /// This has worst-case logarithmic performance (roughly O(log(gaps between
345 /// contiguous ranges))).
346 const_iterator find(IndexT Index) const {
347 auto UnderlyingIt = Intervals.find(Index);
348 if (UnderlyingIt == Intervals.end())
349 return end();
350 auto It = const_iterator(UnderlyingIt);
351 It.advanceTo(Index);
352 return It;
353 }
354
355 /// Return a range iterator which iterates over all of the set bits in the
356 /// half-open range [Start, End).
357 iterator_range<const_iterator> half_open_range(IndexT Start,
358 IndexT End) const {
359 assert(Start < End && "Not a valid range");
360 auto StartIt = find(Start);
361 if (StartIt == end() || *StartIt >= End)
362 return {end(), end()};
363 auto EndIt = StartIt;
364 EndIt.advanceToLowerBound(End);
365 return {StartIt, EndIt};
366 }
367
368 void print(raw_ostream &OS) const {
369 OS << "{";
370 for (auto It = Intervals.begin(), End = Intervals.end(); It != End;
371 ++It) {
372 OS << "[" << It.start();
373 if (It.start() != It.stop())
374 OS << ", " << It.stop();
375 OS << "]";
376 }
377 OS << "}";
378 }
379
380#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
381 LLVM_DUMP_METHOD void dump() const {
382 // LLDB swallows the first line of output after callling dump(). Add
383 // newlines before/after the braces to work around this.
384 dbgs() << "\n";
385 print(dbgs());
386 dbgs() << "\n";
387 }
388#endif
389
390private:
391 void insert(IndexT Start, IndexT End) { Intervals.insert(Start, End, 0); }
392
393 /// Record the overlaps between \p this and \p Other in \p Overlaps. Return
394 /// true if there is any overlap.
395 bool getOverlaps(const ThisT &Other,
396 SmallVectorImpl<IntervalT> &Overlaps) const {
397 for (IntervalMapOverlaps<MapT, MapT> I(Intervals, Other.Intervals);
398 I.valid(); ++I)
399 Overlaps.emplace_back(I.start(), I.stop());
400 assert(llvm::is_sorted(Overlaps,
401 [](IntervalT LHS, IntervalT RHS) {
402 return LHS.second < RHS.first;
403 }) &&
404 "Overlaps must be sorted");
405 return !Overlaps.empty();
406 }
407
408 /// Given the set of overlaps between this and some other bitvector, and an
409 /// interval [Start, Stop] from that bitvector, determine the portions of the
410 /// interval which do not overlap with this.
411 void getNonOverlappingParts(IndexT Start, IndexT Stop,
412 const SmallVectorImpl<IntervalT> &Overlaps,
413 SmallVectorImpl<IntervalT> &NonOverlappingParts) {
414 IndexT NextUncoveredBit = Start;
415 for (IntervalT Overlap : Overlaps) {
416 IndexT OlapStart, OlapStop;
417 std::tie(OlapStart, OlapStop) = Overlap;
418
419 // [Start;Stop] and [OlapStart;OlapStop] overlap iff OlapStart <= Stop
420 // and Start <= OlapStop.
421 bool DoesOverlap = OlapStart <= Stop && Start <= OlapStop;
422 if (!DoesOverlap)
423 continue;
424
425 // Cover the range [NextUncoveredBit, OlapStart). This puts the start of
426 // the next uncovered range at OlapStop+1.
427 if (NextUncoveredBit < OlapStart)
428 NonOverlappingParts.emplace_back(NextUncoveredBit, OlapStart - 1);
429 NextUncoveredBit = OlapStop + 1;
430 if (NextUncoveredBit > Stop)
431 break;
432 }
433 if (NextUncoveredBit <= Stop)
434 NonOverlappingParts.emplace_back(NextUncoveredBit, Stop);
435 }
436
437 Allocator *Alloc;
438 MapT Intervals;
439};
440
441} // namespace llvm
442
443#endif // LLVM_ADT_COALESCINGBITVECTOR_H