Olivier Deprez | f4ef2d0 | 2021-04-20 13:36:24 +0200 | [diff] [blame^] | 1 | //===- llvm/ADT/CoalescingBitVector.h - A coalescing bitvector --*- C++ -*-===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | /// |
| 9 | /// \file A bitvector that uses an IntervalMap to coalesce adjacent elements |
| 10 | /// into intervals. |
| 11 | /// |
| 12 | //===----------------------------------------------------------------------===// |
| 13 | |
| 14 | #ifndef LLVM_ADT_COALESCINGBITVECTOR_H |
| 15 | #define LLVM_ADT_COALESCINGBITVECTOR_H |
| 16 | |
#include "llvm/ADT/IntervalMap.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/iterator_range.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <initializer_list>
#include <iterator>
#include <tuple>
#include <type_traits>
#include <utility>
| 25 | |
| 26 | namespace llvm { |
| 27 | |
| 28 | /// A bitvector that, under the hood, relies on an IntervalMap to coalesce |
| 29 | /// elements into intervals. Good for representing sets which predominantly |
| 30 | /// contain contiguous ranges. Bad for representing sets with lots of gaps |
| 31 | /// between elements. |
| 32 | /// |
| 33 | /// Compared to SparseBitVector, CoalescingBitVector offers more predictable |
| 34 | /// performance for non-sequential find() operations. |
| 35 | /// |
| 36 | /// \tparam IndexT - The type of the index into the bitvector. |
| 37 | template <typename IndexT> class CoalescingBitVector { |
| 38 | static_assert(std::is_unsigned<IndexT>::value, |
| 39 | "Index must be an unsigned integer."); |
| 40 | |
| 41 | using ThisT = CoalescingBitVector<IndexT>; |
| 42 | |
| 43 | /// An interval map for closed integer ranges. The mapped values are unused. |
| 44 | using MapT = IntervalMap<IndexT, char>; |
| 45 | |
| 46 | using UnderlyingIterator = typename MapT::const_iterator; |
| 47 | |
| 48 | using IntervalT = std::pair<IndexT, IndexT>; |
| 49 | |
| 50 | public: |
| 51 | using Allocator = typename MapT::Allocator; |
| 52 | |
| 53 | /// Construct by passing in a CoalescingBitVector<IndexT>::Allocator |
| 54 | /// reference. |
| 55 | CoalescingBitVector(Allocator &Alloc) |
| 56 | : Alloc(&Alloc), Intervals(Alloc) {} |
| 57 | |
| 58 | /// \name Copy/move constructors and assignment operators. |
| 59 | /// @{ |
| 60 | |
| 61 | CoalescingBitVector(const ThisT &Other) |
| 62 | : Alloc(Other.Alloc), Intervals(*Other.Alloc) { |
| 63 | set(Other); |
| 64 | } |
| 65 | |
| 66 | ThisT &operator=(const ThisT &Other) { |
| 67 | clear(); |
| 68 | set(Other); |
| 69 | return *this; |
| 70 | } |
| 71 | |
| 72 | CoalescingBitVector(ThisT &&Other) = delete; |
| 73 | ThisT &operator=(ThisT &&Other) = delete; |
| 74 | |
| 75 | /// @} |
| 76 | |
| 77 | /// Clear all the bits. |
| 78 | void clear() { Intervals.clear(); } |
| 79 | |
| 80 | /// Check whether no bits are set. |
| 81 | bool empty() const { return Intervals.empty(); } |
| 82 | |
| 83 | /// Count the number of set bits. |
| 84 | unsigned count() const { |
| 85 | unsigned Bits = 0; |
| 86 | for (auto It = Intervals.begin(), End = Intervals.end(); It != End; ++It) |
| 87 | Bits += 1 + It.stop() - It.start(); |
| 88 | return Bits; |
| 89 | } |
| 90 | |
| 91 | /// Set the bit at \p Index. |
| 92 | /// |
| 93 | /// This method does /not/ support setting a bit that has already been set, |
| 94 | /// for efficiency reasons. If possible, restructure your code to not set the |
| 95 | /// same bit multiple times, or use \ref test_and_set. |
| 96 | void set(IndexT Index) { |
| 97 | assert(!test(Index) && "Setting already-set bits not supported/efficient, " |
| 98 | "IntervalMap will assert"); |
| 99 | insert(Index, Index); |
| 100 | } |
| 101 | |
| 102 | /// Set the bits set in \p Other. |
| 103 | /// |
| 104 | /// This method does /not/ support setting already-set bits, see \ref set |
| 105 | /// for the rationale. For a safe set union operation, use \ref operator|=. |
| 106 | void set(const ThisT &Other) { |
| 107 | for (auto It = Other.Intervals.begin(), End = Other.Intervals.end(); |
| 108 | It != End; ++It) |
| 109 | insert(It.start(), It.stop()); |
| 110 | } |
| 111 | |
| 112 | /// Set the bits at \p Indices. Used for testing, primarily. |
| 113 | void set(std::initializer_list<IndexT> Indices) { |
| 114 | for (IndexT Index : Indices) |
| 115 | set(Index); |
| 116 | } |
| 117 | |
| 118 | /// Check whether the bit at \p Index is set. |
| 119 | bool test(IndexT Index) const { |
| 120 | const auto It = Intervals.find(Index); |
| 121 | if (It == Intervals.end()) |
| 122 | return false; |
| 123 | assert(It.stop() >= Index && "Interval must end after Index"); |
| 124 | return It.start() <= Index; |
| 125 | } |
| 126 | |
| 127 | /// Set the bit at \p Index. Supports setting an already-set bit. |
| 128 | void test_and_set(IndexT Index) { |
| 129 | if (!test(Index)) |
| 130 | set(Index); |
| 131 | } |
| 132 | |
| 133 | /// Reset the bit at \p Index. Supports resetting an already-unset bit. |
| 134 | void reset(IndexT Index) { |
| 135 | auto It = Intervals.find(Index); |
| 136 | if (It == Intervals.end()) |
| 137 | return; |
| 138 | |
| 139 | // Split the interval containing Index into up to two parts: one from |
| 140 | // [Start, Index-1] and another from [Index+1, Stop]. If Index is equal to |
| 141 | // either Start or Stop, we create one new interval. If Index is equal to |
| 142 | // both Start and Stop, we simply erase the existing interval. |
| 143 | IndexT Start = It.start(); |
| 144 | if (Index < Start) |
| 145 | // The index was not set. |
| 146 | return; |
| 147 | IndexT Stop = It.stop(); |
| 148 | assert(Index <= Stop && "Wrong interval for index"); |
| 149 | It.erase(); |
| 150 | if (Start < Index) |
| 151 | insert(Start, Index - 1); |
| 152 | if (Index < Stop) |
| 153 | insert(Index + 1, Stop); |
| 154 | } |
| 155 | |
| 156 | /// Set union. If \p RHS is guaranteed to not overlap with this, \ref set may |
| 157 | /// be a faster alternative. |
| 158 | void operator|=(const ThisT &RHS) { |
| 159 | // Get the overlaps between the two interval maps. |
| 160 | SmallVector<IntervalT, 8> Overlaps; |
| 161 | getOverlaps(RHS, Overlaps); |
| 162 | |
| 163 | // Insert the non-overlapping parts of all the intervals from RHS. |
| 164 | for (auto It = RHS.Intervals.begin(), End = RHS.Intervals.end(); |
| 165 | It != End; ++It) { |
| 166 | IndexT Start = It.start(); |
| 167 | IndexT Stop = It.stop(); |
| 168 | SmallVector<IntervalT, 8> NonOverlappingParts; |
| 169 | getNonOverlappingParts(Start, Stop, Overlaps, NonOverlappingParts); |
| 170 | for (IntervalT AdditivePortion : NonOverlappingParts) |
| 171 | insert(AdditivePortion.first, AdditivePortion.second); |
| 172 | } |
| 173 | } |
| 174 | |
| 175 | /// Set intersection. |
| 176 | void operator&=(const ThisT &RHS) { |
| 177 | // Get the overlaps between the two interval maps (i.e. the intersection). |
| 178 | SmallVector<IntervalT, 8> Overlaps; |
| 179 | getOverlaps(RHS, Overlaps); |
| 180 | // Rebuild the interval map, including only the overlaps. |
| 181 | clear(); |
| 182 | for (IntervalT Overlap : Overlaps) |
| 183 | insert(Overlap.first, Overlap.second); |
| 184 | } |
| 185 | |
| 186 | /// Reset all bits present in \p Other. |
| 187 | void intersectWithComplement(const ThisT &Other) { |
| 188 | SmallVector<IntervalT, 8> Overlaps; |
| 189 | if (!getOverlaps(Other, Overlaps)) { |
| 190 | // If there is no overlap with Other, the intersection is empty. |
| 191 | return; |
| 192 | } |
| 193 | |
| 194 | // Delete the overlapping intervals. Split up intervals that only partially |
| 195 | // intersect an overlap. |
| 196 | for (IntervalT Overlap : Overlaps) { |
| 197 | IndexT OlapStart, OlapStop; |
| 198 | std::tie(OlapStart, OlapStop) = Overlap; |
| 199 | |
| 200 | auto It = Intervals.find(OlapStart); |
| 201 | IndexT CurrStart = It.start(); |
| 202 | IndexT CurrStop = It.stop(); |
| 203 | assert(CurrStart <= OlapStart && OlapStop <= CurrStop && |
| 204 | "Expected some intersection!"); |
| 205 | |
| 206 | // Split the overlap interval into up to two parts: one from [CurrStart, |
| 207 | // OlapStart-1] and another from [OlapStop+1, CurrStop]. If OlapStart is |
| 208 | // equal to CurrStart, the first split interval is unnecessary. Ditto for |
| 209 | // when OlapStop is equal to CurrStop, we omit the second split interval. |
| 210 | It.erase(); |
| 211 | if (CurrStart < OlapStart) |
| 212 | insert(CurrStart, OlapStart - 1); |
| 213 | if (OlapStop < CurrStop) |
| 214 | insert(OlapStop + 1, CurrStop); |
| 215 | } |
| 216 | } |
| 217 | |
| 218 | bool operator==(const ThisT &RHS) const { |
| 219 | // We cannot just use std::equal because it checks the dereferenced values |
| 220 | // of an iterator pair for equality, not the iterators themselves. In our |
| 221 | // case that results in comparison of the (unused) IntervalMap values. |
| 222 | auto ItL = Intervals.begin(); |
| 223 | auto ItR = RHS.Intervals.begin(); |
| 224 | while (ItL != Intervals.end() && ItR != RHS.Intervals.end() && |
| 225 | ItL.start() == ItR.start() && ItL.stop() == ItR.stop()) { |
| 226 | ++ItL; |
| 227 | ++ItR; |
| 228 | } |
| 229 | return ItL == Intervals.end() && ItR == RHS.Intervals.end(); |
| 230 | } |
| 231 | |
| 232 | bool operator!=(const ThisT &RHS) const { return !operator==(RHS); } |
| 233 | |
| 234 | class const_iterator |
| 235 | : public std::iterator<std::forward_iterator_tag, IndexT> { |
| 236 | friend class CoalescingBitVector; |
| 237 | |
| 238 | // For performance reasons, make the offset at the end different than the |
| 239 | // one used in \ref begin, to optimize the common `It == end()` pattern. |
| 240 | static constexpr unsigned kIteratorAtTheEndOffset = ~0u; |
| 241 | |
| 242 | UnderlyingIterator MapIterator; |
| 243 | unsigned OffsetIntoMapIterator = 0; |
| 244 | |
| 245 | // Querying the start/stop of an IntervalMap iterator can be very expensive. |
| 246 | // Cache these values for performance reasons. |
| 247 | IndexT CachedStart = IndexT(); |
| 248 | IndexT CachedStop = IndexT(); |
| 249 | |
| 250 | void setToEnd() { |
| 251 | OffsetIntoMapIterator = kIteratorAtTheEndOffset; |
| 252 | CachedStart = IndexT(); |
| 253 | CachedStop = IndexT(); |
| 254 | } |
| 255 | |
| 256 | /// MapIterator has just changed, reset the cached state to point to the |
| 257 | /// start of the new underlying iterator. |
| 258 | void resetCache() { |
| 259 | if (MapIterator.valid()) { |
| 260 | OffsetIntoMapIterator = 0; |
| 261 | CachedStart = MapIterator.start(); |
| 262 | CachedStop = MapIterator.stop(); |
| 263 | } else { |
| 264 | setToEnd(); |
| 265 | } |
| 266 | } |
| 267 | |
| 268 | /// Advance the iterator to \p Index, if it is contained within the current |
| 269 | /// interval. The public-facing method which supports advancing past the |
| 270 | /// current interval is \ref advanceToLowerBound. |
| 271 | void advanceTo(IndexT Index) { |
| 272 | assert(Index <= CachedStop && "Cannot advance to OOB index"); |
| 273 | if (Index < CachedStart) |
| 274 | // We're already past this index. |
| 275 | return; |
| 276 | OffsetIntoMapIterator = Index - CachedStart; |
| 277 | } |
| 278 | |
| 279 | const_iterator(UnderlyingIterator MapIt) : MapIterator(MapIt) { |
| 280 | resetCache(); |
| 281 | } |
| 282 | |
| 283 | public: |
| 284 | const_iterator() { setToEnd(); } |
| 285 | |
| 286 | bool operator==(const const_iterator &RHS) const { |
| 287 | // Do /not/ compare MapIterator for equality, as this is very expensive. |
| 288 | // The cached start/stop values make that check unnecessary. |
| 289 | return std::tie(OffsetIntoMapIterator, CachedStart, CachedStop) == |
| 290 | std::tie(RHS.OffsetIntoMapIterator, RHS.CachedStart, |
| 291 | RHS.CachedStop); |
| 292 | } |
| 293 | |
| 294 | bool operator!=(const const_iterator &RHS) const { |
| 295 | return !operator==(RHS); |
| 296 | } |
| 297 | |
| 298 | IndexT operator*() const { return CachedStart + OffsetIntoMapIterator; } |
| 299 | |
| 300 | const_iterator &operator++() { // Pre-increment (++It). |
| 301 | if (CachedStart + OffsetIntoMapIterator < CachedStop) { |
| 302 | // Keep going within the current interval. |
| 303 | ++OffsetIntoMapIterator; |
| 304 | } else { |
| 305 | // We reached the end of the current interval: advance. |
| 306 | ++MapIterator; |
| 307 | resetCache(); |
| 308 | } |
| 309 | return *this; |
| 310 | } |
| 311 | |
| 312 | const_iterator operator++(int) { // Post-increment (It++). |
| 313 | const_iterator tmp = *this; |
| 314 | operator++(); |
| 315 | return tmp; |
| 316 | } |
| 317 | |
| 318 | /// Advance the iterator to the first set bit AT, OR AFTER, \p Index. If |
| 319 | /// no such set bit exists, advance to end(). This is like std::lower_bound. |
| 320 | /// This is useful if \p Index is close to the current iterator position. |
| 321 | /// However, unlike \ref find(), this has worst-case O(n) performance. |
| 322 | void advanceToLowerBound(IndexT Index) { |
| 323 | if (OffsetIntoMapIterator == kIteratorAtTheEndOffset) |
| 324 | return; |
| 325 | |
| 326 | // Advance to the first interval containing (or past) Index, or to end(). |
| 327 | while (Index > CachedStop) { |
| 328 | ++MapIterator; |
| 329 | resetCache(); |
| 330 | if (OffsetIntoMapIterator == kIteratorAtTheEndOffset) |
| 331 | return; |
| 332 | } |
| 333 | |
| 334 | advanceTo(Index); |
| 335 | } |
| 336 | }; |
| 337 | |
| 338 | const_iterator begin() const { return const_iterator(Intervals.begin()); } |
| 339 | |
| 340 | const_iterator end() const { return const_iterator(); } |
| 341 | |
| 342 | /// Return an iterator pointing to the first set bit AT, OR AFTER, \p Index. |
| 343 | /// If no such set bit exists, return end(). This is like std::lower_bound. |
| 344 | /// This has worst-case logarithmic performance (roughly O(log(gaps between |
| 345 | /// contiguous ranges))). |
| 346 | const_iterator find(IndexT Index) const { |
| 347 | auto UnderlyingIt = Intervals.find(Index); |
| 348 | if (UnderlyingIt == Intervals.end()) |
| 349 | return end(); |
| 350 | auto It = const_iterator(UnderlyingIt); |
| 351 | It.advanceTo(Index); |
| 352 | return It; |
| 353 | } |
| 354 | |
| 355 | /// Return a range iterator which iterates over all of the set bits in the |
| 356 | /// half-open range [Start, End). |
| 357 | iterator_range<const_iterator> half_open_range(IndexT Start, |
| 358 | IndexT End) const { |
| 359 | assert(Start < End && "Not a valid range"); |
| 360 | auto StartIt = find(Start); |
| 361 | if (StartIt == end() || *StartIt >= End) |
| 362 | return {end(), end()}; |
| 363 | auto EndIt = StartIt; |
| 364 | EndIt.advanceToLowerBound(End); |
| 365 | return {StartIt, EndIt}; |
| 366 | } |
| 367 | |
| 368 | void print(raw_ostream &OS) const { |
| 369 | OS << "{"; |
| 370 | for (auto It = Intervals.begin(), End = Intervals.end(); It != End; |
| 371 | ++It) { |
| 372 | OS << "[" << It.start(); |
| 373 | if (It.start() != It.stop()) |
| 374 | OS << ", " << It.stop(); |
| 375 | OS << "]"; |
| 376 | } |
| 377 | OS << "}"; |
| 378 | } |
| 379 | |
| 380 | #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) |
| 381 | LLVM_DUMP_METHOD void dump() const { |
| 382 | // LLDB swallows the first line of output after callling dump(). Add |
| 383 | // newlines before/after the braces to work around this. |
| 384 | dbgs() << "\n"; |
| 385 | print(dbgs()); |
| 386 | dbgs() << "\n"; |
| 387 | } |
| 388 | #endif |
| 389 | |
| 390 | private: |
| 391 | void insert(IndexT Start, IndexT End) { Intervals.insert(Start, End, 0); } |
| 392 | |
| 393 | /// Record the overlaps between \p this and \p Other in \p Overlaps. Return |
| 394 | /// true if there is any overlap. |
| 395 | bool getOverlaps(const ThisT &Other, |
| 396 | SmallVectorImpl<IntervalT> &Overlaps) const { |
| 397 | for (IntervalMapOverlaps<MapT, MapT> I(Intervals, Other.Intervals); |
| 398 | I.valid(); ++I) |
| 399 | Overlaps.emplace_back(I.start(), I.stop()); |
| 400 | assert(llvm::is_sorted(Overlaps, |
| 401 | [](IntervalT LHS, IntervalT RHS) { |
| 402 | return LHS.second < RHS.first; |
| 403 | }) && |
| 404 | "Overlaps must be sorted"); |
| 405 | return !Overlaps.empty(); |
| 406 | } |
| 407 | |
| 408 | /// Given the set of overlaps between this and some other bitvector, and an |
| 409 | /// interval [Start, Stop] from that bitvector, determine the portions of the |
| 410 | /// interval which do not overlap with this. |
| 411 | void getNonOverlappingParts(IndexT Start, IndexT Stop, |
| 412 | const SmallVectorImpl<IntervalT> &Overlaps, |
| 413 | SmallVectorImpl<IntervalT> &NonOverlappingParts) { |
| 414 | IndexT NextUncoveredBit = Start; |
| 415 | for (IntervalT Overlap : Overlaps) { |
| 416 | IndexT OlapStart, OlapStop; |
| 417 | std::tie(OlapStart, OlapStop) = Overlap; |
| 418 | |
| 419 | // [Start;Stop] and [OlapStart;OlapStop] overlap iff OlapStart <= Stop |
| 420 | // and Start <= OlapStop. |
| 421 | bool DoesOverlap = OlapStart <= Stop && Start <= OlapStop; |
| 422 | if (!DoesOverlap) |
| 423 | continue; |
| 424 | |
| 425 | // Cover the range [NextUncoveredBit, OlapStart). This puts the start of |
| 426 | // the next uncovered range at OlapStop+1. |
| 427 | if (NextUncoveredBit < OlapStart) |
| 428 | NonOverlappingParts.emplace_back(NextUncoveredBit, OlapStart - 1); |
| 429 | NextUncoveredBit = OlapStop + 1; |
| 430 | if (NextUncoveredBit > Stop) |
| 431 | break; |
| 432 | } |
| 433 | if (NextUncoveredBit <= Stop) |
| 434 | NonOverlappingParts.emplace_back(NextUncoveredBit, Stop); |
| 435 | } |
| 436 | |
| 437 | Allocator *Alloc; |
| 438 | MapT Intervals; |
| 439 | }; |
| 440 | |
| 441 | } // namespace llvm |
| 442 | |
| 443 | #endif // LLVM_ADT_COALESCINGBITVECTOR_H |