aboutsummaryrefslogtreecommitdiffstats
path: root/include/llvm/ADT/BitSetVector.h
blob: cdcd52d948659e3add170ed806aaaadcc17f1000 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
//===-- BitSetVector.h - A bit-vector representation of sets -----*- C++ -*--=//
//
// class BitSetVector --
// 
// An implementation of the bit-vector representation of sets.
// Unlike vector<bool>, this allows much more efficient parallel set
// operations on bits, by using the bitset template.  The bitset template
// unfortunately can only represent sets with a size chosen at compile-time.
// We therefore use a vector of bitsets.  The maximum size of our sets
// (i.e., the size of the universal set) can be chosen at creation time.
//
// External functions:
// 
// bool Disjoint(const BitSetVector& set1, const BitSetVector& set2):
//    Tests if two sets have an empty intersection.
//    This is more efficient than !(set1 & set2).any().
// 
//===----------------------------------------------------------------------===//

#ifndef SUPPORT_BITSETVECTOR_H
#define SUPPORT_BITSETVECTOR_H

#include <bitset>
#include <cassert>
#include <functional>
#include <iostream>
#include <vector>

class BitSetVector {
  enum { BITSET_WORDSIZE = sizeof(long)*8 };

  // Types used internal to the representation
  typedef std::bitset<BITSET_WORDSIZE> bitword;
  typedef bitword::reference reference;

  // Data used in the representation.
  // Layout: bitsetVec holds NumWords(maxSize) words and bit n lives at
  // bitsetVec[n / BITSET_WORDSIZE][n % BITSET_WORDSIZE].  The unused high
  // bits of the last word are meant to stay zero: set() and operator~ clear
  // them via ClearUnusedBits(), and operator==, count() and all() rely on it.
  std::vector<bitword> bitsetVec;
  unsigned maxSize;

private:
  // Utility functions for the representation

  // Number of words needed to hold Size bits, rounded up.  Written as
  // quotient plus a remainder test so it cannot overflow for Size close to
  // UINT_MAX (the old "(Size + W - 1) / W" form could wrap around).
  static unsigned NumWords(unsigned Size) {
    return Size / BITSET_WORDSIZE + (Size % BITSET_WORDSIZE != 0 ? 1 : 0);
  }
  // Number of bits in use in the last word, or 0 when the last word is
  // completely used (Size is a multiple of BITSET_WORDSIZE, including 0).
  static unsigned LastWordSize(unsigned Size) { return Size % BITSET_WORDSIZE; }

  // Clear the unused bits in the last word.
  // The unused bits are the high (BITSET_WORDSIZE - LastWordSize()) bits.
  // If LastWordSize() is 0, either every bit of the last word is in use or
  // there are no words at all, so there is nothing to clear.  Otherwise
  // 0 < lastBits < BITSET_WORDSIZE, and building the mask from 1UL (the
  // width of a word) keeps the shift well-defined.
  void ClearUnusedBits() {
    unsigned lastBits = LastWordSize(size());
    if (lastBits == 0)
      return;
    unsigned long usedBits = (1UL << lastBits) - 1;
    bitsetVec.back() &= bitword(usedBits);
  }

  const bitword& getWord(unsigned i) const { return bitsetVec[i]; }
        bitword& getWord(unsigned i)       { return bitsetVec[i]; }

  friend bool Disjoint(const BitSetVector& set1,
                       const BitSetVector& set2);

  BitSetVector();                       // do not implement!

public:
  /// 
  /// Constructor: create a set of the maximum size maxSetSize.
  /// The set is initialized to empty.
  ///
  BitSetVector(unsigned maxSetSize)
    : bitsetVec(NumWords(maxSetSize)), maxSize(maxSetSize) { }

  /// size - Return the number of bits tracked by this bit vector...
  unsigned size() const { return maxSize; }

  /// 
  ///  Modifier methods: reset, set for entire set, operator[] for one element.
  ///  
  void reset() {
    for (unsigned i=0, N = bitsetVec.size(); i < N; ++i)
      bitsetVec[i].reset();
  }
  void set() {
    // Fill every word, then zero the bits of the last word that lie beyond
    // size().
    for (unsigned i=0, N = bitsetVec.size(); i < N; ++i)
      bitsetVec[i].set();
    ClearUnusedBits();
  }
  reference operator[](unsigned n) {
    assert(n  < size() && "BitSetVector: Bit number out of range");
    unsigned ndiv = n / BITSET_WORDSIZE, nmod = n % BITSET_WORDSIZE;
    return bitsetVec[ndiv][nmod];
  }

  /// 
  ///  Comparison operations: equal, not equal
  /// 
  bool operator == (const BitSetVector& set2) const {
    assert(maxSize == set2.maxSize && "Illegal == comparison");
    for (unsigned i = 0; i < bitsetVec.size(); ++i)
      if (getWord(i) != set2.getWord(i))
        return false;
    return true;
  }
  bool operator != (const BitSetVector& set2) const {
    return ! (*this == set2);
  }

  /// 
  ///  Set membership operations: single element, any, none, count
  ///  
  bool test(unsigned n) const {
    assert(n  < size() && "BitSetVector: Bit number out of range");
    unsigned ndiv = n / BITSET_WORDSIZE, nmod = n % BITSET_WORDSIZE;
    return bitsetVec[ndiv].test(nmod);
  }
  bool any() const {
    for (unsigned i = 0; i < bitsetVec.size(); ++i)
      if (bitsetVec[i].any())
        return true;
    return false;
  }
  bool none() const {
    return ! any();
  }
  unsigned count() const {
    unsigned n = 0;
    for (unsigned i = 0; i < bitsetVec.size(); ++i)
      n += bitsetVec[i].count();
    return n;
  }
  bool all() const {
    return (count() == size());
  }

  /// 
  ///  Set operations: intersection, union, disjoint union, complement.
  ///  Both operands must have the same size() (checked by assert).
  ///  
  BitSetVector operator& (const BitSetVector& set2) const {
    assert(maxSize == set2.maxSize && "Illegal intersection");
    BitSetVector result(maxSize);
    for (unsigned i = 0; i < bitsetVec.size(); ++i)
      result.getWord(i) = getWord(i) & set2.getWord(i);
    return result;
  }
  BitSetVector operator| (const BitSetVector& set2) const {
    assert(maxSize == set2.maxSize && "Illegal union");
    BitSetVector result(maxSize);
    for (unsigned i = 0; i < bitsetVec.size(); ++i)
      result.getWord(i) = getWord(i) | set2.getWord(i);
    return result;
  }
  BitSetVector operator^ (const BitSetVector& set2) const {
    assert(maxSize == set2.maxSize && "Illegal symmetric difference");
    BitSetVector result(maxSize);
    for (unsigned i = 0; i < bitsetVec.size(); ++i)
      result.getWord(i) = getWord(i) ^ set2.getWord(i);
    return result;
  }
  BitSetVector operator~ () const {
    BitSetVector result(maxSize);
    for (unsigned i = 0; i < bitsetVec.size(); ++i)
      result.getWord(i) = ~getWord(i);
    // Flipping also set the unused high bits of the last word; clear them.
    result.ClearUnusedBits();
    return result;
  }

  /// 
  ///  Printing and debugging support
  ///  
  void print(std::ostream &O) const;
  void dump() const { print(std::cerr); }

  // 
  // An iterator to enumerate the bits in a BitSetVector.
  // Eventually, this needs to inherit from bidirectional_iterator.
  // But this iterator may not be as useful as I once thought and
  // may just go away.
  //
  // Position is (currentWord, currentBit), i.e. bit index
  // currentWord * BITSET_WORDSIZE + currentBit.  begin() is bit 0 and end()
  // is bit size(), so [begin(), end()) visits exactly size() bits.
  // 
  // It is declared only here, with public access.  A private forward
  // declaration followed by this public definition is ill-formed, because a
  // member may not be redeclared with different access.
  class iterator {
    unsigned   currentBit;
    unsigned   currentWord;
    BitSetVector* bitvec;
    iterator(unsigned B, unsigned W, BitSetVector& _bitvec)
      : currentBit(B), currentWord(W), bitvec(&_bitvec) { }

    // True if this iterator designates a real bit, i.e. its bit index is
    // below bitvec->size().  Computed without multiplying so it cannot
    // overflow.
    bool inRange() const {
      unsigned fullWords = bitvec->size() / BITSET_WORDSIZE;
      return currentWord < fullWords ||
             (currentWord == fullWords &&
              currentBit < LastWordSize(bitvec->size()));
    }
  public:
    iterator(BitSetVector& _bitvec)
      : currentBit(0), currentWord(0), bitvec(&_bitvec) { }
    iterator(const iterator& I)
      : currentBit(I.currentBit),currentWord(I.currentWord),bitvec(I.bitvec) { }
    iterator& operator=(const iterator& I) {
      currentWord = I.currentWord;
      currentBit = I.currentBit;
      bitvec = I.bitvec;
      return *this;
    }

    // Increment and decrement operators (pre and post)
    iterator& operator++() {
      if (++currentBit == BITSET_WORDSIZE) {
        currentBit = 0;
        // Compare against the number of *words* (the old code compared the
        // word index with the number of bits).  This saturates just past the
        // last word.
        if (currentWord < NumWords(bitvec->size()))
          ++currentWord;
      }
      return *this;
    }
    iterator& operator--() {
      if (currentBit != 0) {
        --currentBit;
      } else if (currentWord != 0) {
        currentBit = BITSET_WORDSIZE-1;
        --currentWord;
      } else {
        // Decrementing begin() is not meaningful.  Park the iterator at
        // end() so it can never be dereferenced.
        *this = end(*bitvec);
      }
      return *this;
    }
    iterator operator++(int) { iterator copy(*this); ++*this; return copy; }
    iterator operator--(int) { iterator copy(*this); --*this; return copy; }

    // Dereferencing operators
    reference operator*() {
      assert(inRange() &&
             "Dereferencing iterator past the end of a BitSetVector");
      return bitvec->getWord(currentWord)[currentBit];
    }

    // Comparison operators
    bool operator==(const iterator& I) const {
      return (I.bitvec == bitvec &&
              I.currentWord == currentWord && I.currentBit == currentBit);
    }
    bool operator!=(const iterator& I) const {
      return !(*this == I);
    }

  protected:
    static iterator begin(BitSetVector& _bitvec) { return iterator(_bitvec); }
    // One past the last valid bit: bit index size().
    static iterator end(BitSetVector& _bitvec)   {
      return iterator(LastWordSize(_bitvec.size()),
                      _bitvec.size() / BITSET_WORDSIZE, _bitvec);
    }
    friend class BitSetVector;
  };

  /// 
  ///  Iteration over all size() bits, in increasing bit order.
  /// 
  iterator begin() { return iterator::begin(*this); }
  iterator end()   { return iterator::end(*this);   } 
};


inline void BitSetVector::print(std::ostream& O) const
{
  for (std::vector<bitword>::const_iterator
         I=bitsetVec.begin(), E=bitsetVec.end(); I != E; ++I)
    O << "<" << (*I) << ">" << (I+1 == E? "\n" : ", ");
}

/// Stream insertion for BitSetVector: writes bset via BitSetVector::print
/// and returns O so that insertions can be chained.
inline std::ostream& operator<< (std::ostream& O, const BitSetVector& bset)
{
  bset.print(O);
  return O;
}


///
/// Disjoint - true when set1 and set2 share no element.  Both sets must have
/// the same size (checked by assert).  The sets are scanned word by word and
/// the scan stops at the first word pair that overlaps, so no temporary set
/// is built, unlike !(set1 & set2).any().
/// 
inline bool Disjoint(const BitSetVector& set1,
                     const BitSetVector& set2)
{
  assert(set1.size() == set2.size() && "Illegal intersection");
  unsigned w = 0;
  while (w < set1.bitsetVec.size() &&
         (set1.getWord(w) & set2.getWord(w)).none())
    ++w;
  // Reaching the end means no word pair had a common bit.
  return w == set1.bitsetVec.size();
}

#endif