Update prebuilt Clang to r416183b from Android.

https://android.googlesource.com/platform/prebuilts/clang/host/
linux-x86/+/06a71ddac05c22edb2d10b590e1769b3f8619bef

clang 12.0.5 (based on r416183b) from build 7284624.

Change-Id: I277a316abcf47307562d8b748b84870f31a72866
Signed-off-by: Olivier Deprez <olivier.deprez@arm.com>
diff --git a/linux-x64/clang/include/llvm/Analysis/VectorUtils.h b/linux-x64/clang/include/llvm/Analysis/VectorUtils.h
index d93d2bc..d8fb970 100644
--- a/linux-x64/clang/include/llvm/Analysis/VectorUtils.h
+++ b/linux-x64/clang/include/llvm/Analysis/VectorUtils.h
@@ -14,17 +14,282 @@
 #define LLVM_ANALYSIS_VECTORUTILS_H
 
 #include "llvm/ADT/MapVector.h"
+#include "llvm/ADT/SmallVector.h"
 #include "llvm/Analysis/LoopAccessAnalysis.h"
-#include "llvm/Analysis/TargetLibraryInfo.h"
-#include "llvm/IR/IRBuilder.h"
 #include "llvm/Support/CheckedArithmetic.h"
 
 namespace llvm {
+class TargetLibraryInfo;
+
+/// Describes the type of Parameters
+enum class VFParamKind {
+  Vector,            // No semantic information.
+  OMP_Linear,        // declare simd linear(i)
+  OMP_LinearRef,     // declare simd linear(ref(i))
+  OMP_LinearVal,     // declare simd linear(val(i))
+  OMP_LinearUVal,    // declare simd linear(uval(i))
+  OMP_LinearPos,     // declare simd linear(i:c) uniform(c)
+  OMP_LinearValPos,  // declare simd linear(val(i:c)) uniform(c)
+  OMP_LinearRefPos,  // declare simd linear(ref(i:c)) uniform(c)
+  OMP_LinearUValPos, // declare simd linear(uval(i:c)) uniform(c)
+  OMP_Uniform,       // declare simd uniform(i)
+  GlobalPredicate,   // Global logical predicate that acts on all lanes
+                     // of the input and output mask concurrently. For
+                     // example, it is implied by the `M` token in the
+                     // Vector Function ABI mangled name.
+  Unknown
+};
+
+/// Describes the type of Instruction Set Architecture
+enum class VFISAKind {
+  AdvancedSIMD, // AArch64 Advanced SIMD (NEON)
+  SVE,          // AArch64 Scalable Vector Extension
+  SSE,          // x86 SSE
+  AVX,          // x86 AVX
+  AVX2,         // x86 AVX2
+  AVX512,       // x86 AVX512
+  LLVM,         // LLVM internal ISA for functions that are not
+  // attached to an existing ABI via name mangling.
+  Unknown // Unknown ISA
+};
+
+/// Encapsulates information needed to describe a parameter.
+///
+/// The description of the parameter is not linked directly to
+/// OpenMP or any other vector function description. This structure
+/// is extendible to handle other paradigms that describe vector
+/// functions and their parameters.
+struct VFParameter {
+  unsigned ParamPos;         // Parameter Position in Scalar Function.
+  VFParamKind ParamKind;     // Kind of Parameter.
+  int LinearStepOrPos = 0;   // Step or Position of the Parameter.
+  Align Alignment = Align(); // Optional alignment in bytes, defaulted to 1.
+
+  // Comparison operator.
+  bool operator==(const VFParameter &Other) const {
+    return std::tie(ParamPos, ParamKind, LinearStepOrPos, Alignment) ==
+           std::tie(Other.ParamPos, Other.ParamKind, Other.LinearStepOrPos,
+                    Other.Alignment);
+  }
+};
+
+/// Contains the information about the kind of vectorization
+/// available.
+///
+/// This object is independent of the paradigm used to
+/// represent vector functions. In particular, it is not attached to
+/// any target-specific ABI.
+struct VFShape {
+  unsigned VF;     // Vectorization factor.
+  bool IsScalable; // True if the function is a scalable function.
+  SmallVector<VFParameter, 8> Parameters; // List of parameter information.
+  // Comparison operator.
+  bool operator==(const VFShape &Other) const {
+    return std::tie(VF, IsScalable, Parameters) ==
+           std::tie(Other.VF, Other.IsScalable, Other.Parameters);
+  }
+
+  /// Update the parameter in position P.ParamPos to P.
+  void updateParam(VFParameter P) {
+    assert(P.ParamPos < Parameters.size() && "Invalid parameter position.");
+    Parameters[P.ParamPos] = P;
+    assert(hasValidParameterList() && "Invalid parameter list");
+  }
+
+  // Retrieve the VFShape that can be used to map a (scalar) function to itself,
+  // with VF = 1.
+  static VFShape getScalarShape(const CallInst &CI) {
+    return VFShape::get(CI, ElementCount::getFixed(1),
+                        /*HasGlobalPred=*/false);
+  }
+
+  // Retrieve the basic vectorization shape of the function, where all
+  // parameters are mapped to VFParamKind::Vector with \p EC
+  // lanes. Specifies whether the function has a Global Predicate
+  // argument via \p HasGlobalPred.
+  static VFShape get(const CallInst &CI, ElementCount EC, bool HasGlobalPred) {
+    SmallVector<VFParameter, 8> Parameters;
+    for (unsigned I = 0; I < CI.arg_size(); ++I)
+      Parameters.push_back(VFParameter({I, VFParamKind::Vector}));
+    if (HasGlobalPred)
+      Parameters.push_back(
+          VFParameter({CI.arg_size(), VFParamKind::GlobalPredicate}));
+
+    return {EC.getKnownMinValue(), EC.isScalable(), Parameters};
+  }
+  /// Sanity check on the Parameters in the VFShape.
+  bool hasValidParameterList() const;
+};
+
+/// Holds the VFShape for a specific scalar to vector function mapping.
+struct VFInfo {
+  VFShape Shape;          ///< Classification of the vector function.
+  std::string ScalarName; ///< Scalar Function Name.
+  std::string VectorName; ///< Vector Function Name associated to this VFInfo.
+  VFISAKind ISA;          ///< Instruction Set Architecture.
+
+  // Equality: Shape, ScalarName, VectorName and ISA must all match Other's.
+  bool operator==(const VFInfo &Other) const {
+    return std::tie(Shape, ScalarName, VectorName, ISA) ==
+           std::tie(Other.Shape, Other.ScalarName, Other.VectorName, Other.ISA);
+  }
+};
+
+namespace VFABI {
+/// LLVM Internal VFABI ISA token for vector functions.
+static constexpr char const *_LLVM_ = "_LLVM_";
+/// Prefix for internal name redirection for vector function that
+/// tells the compiler to scalarize the call using the scalar name
+/// of the function. For example, a mangled name like
+/// `_ZGV_LLVM_N2v_foo(_LLVM_Scalarize_foo)` would tell the
+/// vectorizer to vectorize the scalar call `foo`, and to scalarize
+/// it once vectorization is done.
+static constexpr char const *_LLVM_Scalarize_ = "_LLVM_Scalarize_";
+
+/// Function to construct a VFInfo out of a mangled name in the
+/// following format:
+///
+/// <VFABI_name>{(<redirection>)}
+///
+/// where <VFABI_name> is the name of the vector function, mangled according
+/// to the rules described in the Vector Function ABI of the target vector
+/// extension (or <isa> from now on). The <VFABI_name> is in the following
+/// format:
+///
+/// _ZGV<isa><mask><vlen><parameters>_<scalarname>[(<redirection>)]
+///
+/// This method supports demangling rules for the following <isa>:
+///
+/// * AArch64: https://developer.arm.com/docs/101129/latest
+///
+/// * x86 (libmvec): https://sourceware.org/glibc/wiki/libmvec and
+///  https://sourceware.org/glibc/wiki/libmvec?action=AttachFile&do=view&target=VectorABI.txt
+///
+/// \param MangledName -> input string in the format
+/// _ZGV<isa><mask><vlen><parameters>_<scalarname>[(<redirection>)].
+/// \param M -> Module used to retrieve information about the vector
+/// function that cannot be retrieved from the mangled
+/// name. At the moment, this parameter is needed only to retrieve the
+/// Vectorization Factor of scalable vector functions from their
+/// respective IR declarations.
+Optional<VFInfo> tryDemangleForVFABI(StringRef MangledName, const Module &M);
+
+/// This routine mangles the given VectorName according to the LangRef
+/// specification for vector-function-abi-variant attribute and is specific to
+/// the TLI mappings. It is the responsibility of the caller to make sure that
+/// this is only used if all parameters in the vector function are vector type.
+/// This returned string holds scalar-to-vector mapping:
+///    _ZGV<isa><mask><vlen><vparams>_<scalarname>(<vectorname>)
+///
+/// where:
+///
+/// <isa> = "_LLVM_"
+/// <mask> = "N". Note: TLI does not support masked interfaces.
+/// <vlen> = Number of concurrent lanes, stored in the `VectorizationFactor`
+///          field of the `VecDesc` struct.
+/// <vparams> = "v", as many as are the numArgs.
+/// <scalarname> = the name of the scalar function.
+/// <vectorname> = the name of the vector function.
+std::string mangleTLIVectorName(StringRef VectorName, StringRef ScalarName,
+                                unsigned numArgs, unsigned VF);
+
+/// Retrieve the `VFParamKind` from a string token.
+VFParamKind getVFParamKindFromString(const StringRef Token);
+
+// Name of the attribute where the variant mappings are stored.
+static constexpr char const *MappingsAttrName = "vector-function-abi-variant";
+
+/// Populates a set of strings representing the Vector Function ABI variants
+/// associated to the CallInst CI. If the CI does not contain the
+/// vector-function-abi-variant attribute, we return without populating
+/// VariantMappings, i.e. callers of getVectorVariantNames need not check for
+/// the presence of the attribute (see InjectTLIMappings).
+void getVectorVariantNames(const CallInst &CI,
+                           SmallVectorImpl<std::string> &VariantMappings);
+} // end namespace VFABI
+
+/// The Vector Function Database.
+///
+/// Helper class used to find the vector functions associated to a
+/// scalar CallInst.
+class VFDatabase {
+  /// The Module of the CallInst CI.
+  const Module *M;
+  /// The CallInst instance being queried for scalar to vector mappings.
+  const CallInst &CI;
+  /// List of vector function descriptors associated to the call
+  /// instruction.
+  const SmallVector<VFInfo, 8> ScalarToVectorMappings;
+
+  /// Retrieve the scalar-to-vector mappings associated to the rule of
+  /// a vector Function ABI.
+  static void getVFABIMappings(const CallInst &CI,
+                               SmallVectorImpl<VFInfo> &Mappings) {
+    if (!CI.getCalledFunction())
+      return;
+
+    const StringRef ScalarName = CI.getCalledFunction()->getName();
+
+    SmallVector<std::string, 8> ListOfStrings;
+    // The check for the vector-function-abi-variant attribute is done when
+    // retrieving the vector variant names here.
+    VFABI::getVectorVariantNames(CI, ListOfStrings);
+    if (ListOfStrings.empty())
+      return;
+    for (const auto &MangledName : ListOfStrings) {
+      const Optional<VFInfo> Shape =
+          VFABI::tryDemangleForVFABI(MangledName, *(CI.getModule()));
+      // A match is found via scalar and vector names, and also by
+      // ensuring that the variant described in the attribute has a
+      // corresponding definition or declaration of the vector
+      // function in the Module M.
+      if (Shape.hasValue() && (Shape.getValue().ScalarName == ScalarName)) {
+        assert(CI.getModule()->getFunction(Shape.getValue().VectorName) &&
+               "Vector function is missing.");
+        Mappings.push_back(Shape.getValue());
+      }
+    }
+  }
+
+public:
+  /// Retrieve all the VFInfo instances associated to the CallInst CI.
+  static SmallVector<VFInfo, 8> getMappings(const CallInst &CI) {
+    SmallVector<VFInfo, 8> Ret;
+
+    // Get mappings from the Vector Function ABI variants.
+    getVFABIMappings(CI, Ret);
+
+    // Other non-VFABI variants should be retrieved here.
+
+    return Ret;
+  }
+
+  /// Constructor, requires a CallInst instance.
+  VFDatabase(CallInst &CI)
+      : M(CI.getModule()), CI(CI),
+        ScalarToVectorMappings(VFDatabase::getMappings(CI)) {}
+  /// \defgroup VFDatabase query interface.
+  ///
+  /// @{
+  /// Retrieve the Function with VFShape \p Shape.
+  Function *getVectorizedFunction(const VFShape &Shape) const {
+    if (Shape == VFShape::getScalarShape(CI))
+      return CI.getCalledFunction();
+
+    for (const auto &Info : ScalarToVectorMappings)
+      if (Info.Shape == Shape)
+        return M->getFunction(Info.VectorName);
+
+    return nullptr;
+  }
+  /// @}
+};
 
 template <typename T> class ArrayRef;
 class DemandedBits;
 class GetElementPtrInst;
 template <typename InstTy> class InterleaveGroup;
+class IRBuilderBase;
 class Loop;
 class ScalarEvolution;
 class TargetTransformInfo;
@@ -32,7 +297,20 @@
 class Value;
 
 namespace Intrinsic {
-enum ID : unsigned;
+typedef unsigned ID;
+}
+
+/// A helper function for converting Scalar types to vector types. If
+/// the incoming type is void, we return void. If the EC represents a
+/// scalar, we return the scalar type.
+inline Type *ToVectorTy(Type *Scalar, ElementCount EC) {
+  if (Scalar->isVoidTy() || EC.isScalar())
+    return Scalar;
+  return VectorType::get(Scalar, EC);
+}
+
+inline Type *ToVectorTy(Type *Scalar, unsigned VF) {
+  return ToVectorTy(Scalar, ElementCount::getFixed(VF));
 }
 
 /// Identify if the intrinsic is trivially vectorizable.
@@ -72,16 +350,55 @@
 /// from the vector.
 Value *findScalarElement(Value *V, unsigned EltNo);
 
+/// If all non-negative \p Mask elements are the same value, return that value.
+/// If all elements are negative (undefined) or \p Mask contains different
+/// non-negative values, return -1.
+int getSplatIndex(ArrayRef<int> Mask);
+
 /// Get splat value if the input is a splat vector or return nullptr.
 /// The value may be extracted from a splat constants vector or from
 /// a sequence of instructions that broadcast a single value into a vector.
-const Value *getSplatValue(const Value *V);
+Value *getSplatValue(const Value *V);
 
-/// Return true if the input value is known to be a vector with all identical
-/// elements (potentially including undefined elements).
+/// Return true if each element of the vector value \p V is poisoned or equal to
+/// every other non-poisoned element. If an index element is specified, either
+/// every element of the vector is poisoned or the element at that index is not
+/// poisoned and equal to every other non-poisoned element.
 /// This may be more powerful than the related getSplatValue() because it is
 /// not limited by finding a scalar source value to a splatted vector.
-bool isSplatValue(const Value *V, unsigned Depth = 0);
+bool isSplatValue(const Value *V, int Index = -1, unsigned Depth = 0);
+
+/// Replace each shuffle mask index with the scaled sequential indices for an
+/// equivalent mask of narrowed elements. Mask elements that are less than 0
+/// (sentinel values) are repeated in the output mask.
+///
+/// Example with Scale = 4:
+///   <4 x i32> <3, 2, 0, -1> -->
+///   <16 x i8> <12, 13, 14, 15, 8, 9, 10, 11, 0, 1, 2, 3, -1, -1, -1, -1>
+///
+/// This is the reverse process of widening shuffle mask elements, but it always
+/// succeeds because the indexes can always be multiplied (scaled up) to map to
+/// narrower vector elements.
+void narrowShuffleMaskElts(int Scale, ArrayRef<int> Mask,
+                           SmallVectorImpl<int> &ScaledMask);
+
+/// Try to transform a shuffle mask by replacing elements with the scaled index
+/// for an equivalent mask of widened elements. If all mask elements that would
+/// map to a wider element of the new mask are the same negative number
+/// (sentinel value), that element of the new mask is the same value. If any
+/// element in a given slice is negative and some other element in that slice is
+/// not the same value, return false (partial matches with sentinel values are
+/// not allowed).
+///
+/// Example with Scale = 4:
+///   <16 x i8> <12, 13, 14, 15, 8, 9, 10, 11, 0, 1, 2, 3, -1, -1, -1, -1> -->
+///   <4 x i32> <3, 2, 0, -1>
+///
+/// This is the reverse process of narrowing shuffle mask elements if it
+/// succeeds. This transform is not always possible because indexes may not
+/// divide evenly (scale down) to map to wider vector elements.
+bool widenShuffleMaskElts(int Scale, ArrayRef<int> Mask,
+                          SmallVectorImpl<int> &ScaledMask);
 
 /// Compute a map of integer instructions to their minimum legal type
 /// size.
@@ -158,7 +475,7 @@
 /// Note: The result is a mask of 0's and 1's, as opposed to the other
 /// create[*]Mask() utilities which create a shuffle mask (mask that
 /// consists of indices).
-Constant *createBitMaskForGaps(IRBuilder<> &Builder, unsigned VF,
+Constant *createBitMaskForGaps(IRBuilderBase &Builder, unsigned VF,
                                const InterleaveGroup<Instruction> &Group);
 
 /// Create a mask with replicated elements.
@@ -173,8 +490,8 @@
 /// For example, the mask for \p ReplicationFactor=3 and \p VF=4 is:
 ///
 ///   <0,0,0,1,1,1,2,2,2,3,3,3>
-Constant *createReplicatedMask(IRBuilder<> &Builder, unsigned ReplicationFactor,
-                               unsigned VF);
+llvm::SmallVector<int, 16> createReplicatedMask(unsigned ReplicationFactor,
+                                                unsigned VF);
 
 /// Create an interleave shuffle mask.
 ///
@@ -187,8 +504,7 @@
 /// For example, the mask for VF = 4 and NumVecs = 2 is:
 ///
 ///   <0, 4, 1, 5, 2, 6, 3, 7>.
-Constant *createInterleaveMask(IRBuilder<> &Builder, unsigned VF,
-                               unsigned NumVecs);
+llvm::SmallVector<int, 16> createInterleaveMask(unsigned VF, unsigned NumVecs);
 
 /// Create a stride shuffle mask.
 ///
@@ -202,8 +518,8 @@
 /// For example, the mask for Start = 0, Stride = 2, and VF = 4 is:
 ///
 ///   <0, 2, 4, 6>
-Constant *createStrideMask(IRBuilder<> &Builder, unsigned Start,
-                           unsigned Stride, unsigned VF);
+llvm::SmallVector<int, 16> createStrideMask(unsigned Start, unsigned Stride,
+                                            unsigned VF);
 
 /// Create a sequential shuffle mask.
 ///
@@ -216,8 +532,8 @@
 /// For example, the mask for Start = 0, NumInsts = 4, and NumUndefs = 4 is:
 ///
 ///   <0, 1, 2, 3, undef, undef, undef, undef>
-Constant *createSequentialMask(IRBuilder<> &Builder, unsigned Start,
-                               unsigned NumInts, unsigned NumUndefs);
+llvm::SmallVector<int, 16>
+createSequentialMask(unsigned Start, unsigned NumInts, unsigned NumUndefs);
 
 /// Concatenate a list of vectors.
 ///
@@ -226,22 +542,22 @@
 /// their element types should be the same. The number of elements in the
 /// vectors should also be the same; however, if the last vector has fewer
 /// elements, it will be padded with undefs.
-Value *concatenateVectors(IRBuilder<> &Builder, ArrayRef<Value *> Vecs);
+Value *concatenateVectors(IRBuilderBase &Builder, ArrayRef<Value *> Vecs);
 
-/// Given a mask vector of the form <Y x i1>, Return true if all of the
-/// elements of this predicate mask are false or undef.  That is, return true
-/// if all lanes can be assumed inactive. 
+/// Given a mask vector of i1, return true if all of the elements of this
+/// predicate mask are known to be false or undef.  That is, return true if all
+/// lanes can be assumed inactive.
 bool maskIsAllZeroOrUndef(Value *Mask);
 
-/// Given a mask vector of the form <Y x i1>, Return true if all of the
-/// elements of this predicate mask are true or undef.  That is, return true
-/// if all lanes can be assumed active. 
+/// Given a mask vector of i1, return true if all of the elements of this
+/// predicate mask are known to be true or undef.  That is, return true if all
+/// lanes can be assumed active.
 bool maskIsAllOneOrUndef(Value *Mask);
 
 /// Given a mask vector of the form <Y x i1>, return an APInt (of bitwidth Y)
 /// for each lane which may be active.
 APInt possiblyDemandedEltsInMask(Value *Mask);
-  
+
 /// The group of interleaved loads/stores sharing the same stride and
 /// close to each other.
 ///
@@ -270,13 +586,12 @@
 /// the interleaved store group doesn't allow gaps.
 template <typename InstTy> class InterleaveGroup {
 public:
-  InterleaveGroup(uint32_t Factor, bool Reverse, uint32_t Align)
-      : Factor(Factor), Reverse(Reverse), Align(Align), InsertPos(nullptr) {}
+  InterleaveGroup(uint32_t Factor, bool Reverse, Align Alignment)
+      : Factor(Factor), Reverse(Reverse), Alignment(Alignment),
+        InsertPos(nullptr) {}
 
-  InterleaveGroup(InstTy *Instr, int32_t Stride, uint32_t Align)
-      : Align(Align), InsertPos(Instr) {
-    assert(Align && "The alignment should be non-zero");
-
+  InterleaveGroup(InstTy *Instr, int32_t Stride, Align Alignment)
+      : Alignment(Alignment), InsertPos(Instr) {
     Factor = std::abs(Stride);
     assert(Factor > 1 && "Invalid interleave factor");
 
@@ -286,7 +601,11 @@
 
   bool isReverse() const { return Reverse; }
   uint32_t getFactor() const { return Factor; }
-  uint32_t getAlignment() const { return Align; }
+  LLVM_ATTRIBUTE_DEPRECATED(uint32_t getAlignment() const,
+                            "Use getAlign instead.") {
+    return Alignment.value();
+  }
+  Align getAlign() const { return Alignment; }
   uint32_t getNumMembers() const { return Members.size(); }
 
   /// Try to insert a new member \p Instr with index \p Index and
@@ -294,9 +613,7 @@
   /// negative if it is the new leader.
   ///
   /// \returns false if the instruction doesn't belong to the group.
-  bool insertMember(InstTy *Instr, int32_t Index, uint32_t NewAlign) {
-    assert(NewAlign && "The new member's alignment should be non-zero");
-
+  bool insertMember(InstTy *Instr, int32_t Index, Align NewAlign) {
     // Make sure the key fits in an int32_t.
     Optional<int32_t> MaybeKey = checkedAdd(Index, SmallestKey);
     if (!MaybeKey)
@@ -328,7 +645,7 @@
     }
 
     // It's always safe to select the minimum alignment.
-    Align = std::min(Align, NewAlign);
+    Alignment = std::min(Alignment, NewAlign);
     Members[Key] = Instr;
     return true;
   }
@@ -338,11 +655,7 @@
   /// \returns nullptr if contains no such member.
   InstTy *getMember(uint32_t Index) const {
     int32_t Key = SmallestKey + Index;
-    auto Member = Members.find(Key);
-    if (Member == Members.end())
-      return nullptr;
-
-    return Member->second;
+    return Members.lookup(Key);
   }
 
   /// Get the index for the given member. Unlike the key in the member
@@ -387,7 +700,7 @@
 private:
   uint32_t Factor; // Interleave Factor.
   bool Reverse;
-  uint32_t Align;
+  Align Alignment;
   DenseMap<int32_t, InstTy *> Members;
   int32_t SmallestKey = 0;
   int32_t LargestKey = 0;
@@ -421,7 +734,7 @@
                         const LoopAccessInfo *LAI)
       : PSE(PSE), TheLoop(L), DT(DT), LI(LI), LAI(LAI) {}
 
-  ~InterleavedAccessInfo() { reset(); }
+  ~InterleavedAccessInfo() { invalidateGroups(); }
 
   /// Analyze the interleaved accesses and collect them in interleave
   /// groups. Substitute symbolic strides using \p Strides.
@@ -432,18 +745,23 @@
   /// Invalidate groups, e.g., in case all blocks in loop will be predicated
   /// contrary to original assumption. Although we currently prevent group
   /// formation for predicated accesses, we may be able to relax this limitation
-  /// in the future once we handle more complicated blocks.
-  void reset() {
-    SmallPtrSet<InterleaveGroup<Instruction> *, 4> DelSet;
-    // Avoid releasing a pointer twice.
-    for (auto &I : InterleaveGroupMap)
-      DelSet.insert(I.second);
-    for (auto *Ptr : DelSet)
-      delete Ptr;
-    InterleaveGroupMap.clear();
-    RequiresScalarEpilogue = false;
-  }
+  /// in the future once we handle more complicated blocks. Returns true if any
+  /// groups were invalidated.
+  bool invalidateGroups() {
+    if (InterleaveGroups.empty()) {
+      assert(
+          !RequiresScalarEpilogue &&
+          "RequiresScalarEpilog should not be set without interleave groups");
+      return false;
+    }
 
+    InterleaveGroupMap.clear();
+    for (auto *Ptr : InterleaveGroups)
+      delete Ptr;
+    InterleaveGroups.clear();
+    RequiresScalarEpilogue = false;
+    return true;
+  }
 
   /// Check if \p Instr belongs to any interleave group.
   bool isInterleaved(Instruction *Instr) const {
@@ -455,9 +773,7 @@
   /// \returns nullptr if doesn't have such group.
   InterleaveGroup<Instruction> *
   getInterleaveGroup(const Instruction *Instr) const {
-    if (InterleaveGroupMap.count(Instr))
-      return InterleaveGroupMap.find(Instr)->second;
-    return nullptr;
+    return InterleaveGroupMap.lookup(Instr);
   }
 
   iterator_range<SmallPtrSetIterator<llvm::InterleaveGroup<Instruction> *>>
@@ -504,8 +820,8 @@
   struct StrideDescriptor {
     StrideDescriptor() = default;
     StrideDescriptor(int64_t Stride, const SCEV *Scev, uint64_t Size,
-                     unsigned Align)
-        : Stride(Stride), Scev(Scev), Size(Size), Align(Align) {}
+                     Align Alignment)
+        : Stride(Stride), Scev(Scev), Size(Size), Alignment(Alignment) {}
 
     // The access's stride. It is negative for a reverse access.
     int64_t Stride = 0;
@@ -517,7 +833,7 @@
     uint64_t Size = 0;
 
     // The alignment of this access.
-    unsigned Align = 0;
+    Align Alignment;
   };
 
   /// A type for holding instructions and their stride descriptors.
@@ -528,11 +844,11 @@
   ///
   /// \returns the newly created interleave group.
   InterleaveGroup<Instruction> *
-  createInterleaveGroup(Instruction *Instr, int Stride, unsigned Align) {
+  createInterleaveGroup(Instruction *Instr, int Stride, Align Alignment) {
     assert(!InterleaveGroupMap.count(Instr) &&
            "Already in an interleaved access group");
     InterleaveGroupMap[Instr] =
-        new InterleaveGroup<Instruction>(Instr, Stride, Align);
+        new InterleaveGroup<Instruction>(Instr, Stride, Alignment);
     InterleaveGroups.insert(InterleaveGroupMap[Instr]);
     return InterleaveGroupMap[Instr];
   }