0001 //===- llvm/Analysis/VectorUtils.h - Vector utilities -----------*- C++ -*-===//
0002 //
0003 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
0004 // See https://llvm.org/LICENSE.txt for license information.
0005 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
0006 //
0007 //===----------------------------------------------------------------------===//
0008 //
0009 // This file defines some vectorizer utilities.
0010 //
0011 //===----------------------------------------------------------------------===//
0012 
0013 #ifndef LLVM_ANALYSIS_VECTORUTILS_H
0014 #define LLVM_ANALYSIS_VECTORUTILS_H
0015 
0016 #include "llvm/ADT/MapVector.h"
0017 #include "llvm/ADT/SmallVector.h"
0018 #include "llvm/Analysis/LoopAccessAnalysis.h"
0019 #include "llvm/IR/Module.h"
0020 #include "llvm/IR/VFABIDemangler.h"
0021 #include "llvm/IR/VectorTypeUtils.h"
0022 #include "llvm/Support/CheckedArithmetic.h"
0023 
0024 namespace llvm {
0025 class TargetLibraryInfo;
0026 
0027 /// The Vector Function Database.
0028 ///
0029 /// Helper class used to find the vector functions associated to a
0030 /// scalar CallInst.
0031 class VFDatabase {
0032   /// The Module of the CallInst CI.
0033   const Module *M;
0034   /// The CallInst instance being queried for scalar to vector mappings.
0035   const CallInst &CI;
0036   /// List of vector function descriptors associated to the call
0037   /// instruction.
0038   const SmallVector<VFInfo, 8> ScalarToVectorMappings;
0039 
0040   /// Retrieve the scalar-to-vector mappings associated to the rule of
0041   /// a vector Function ABI.
0042   static void getVFABIMappings(const CallInst &CI,
0043                                SmallVectorImpl<VFInfo> &Mappings) {
0044     if (!CI.getCalledFunction())
0045       return;
0046 
0047     const StringRef ScalarName = CI.getCalledFunction()->getName();
0048 
0049     SmallVector<std::string, 8> ListOfStrings;
0050     // The check for the vector-function-abi-variant attribute is done when
0051     // retrieving the vector variant names here.
0052     VFABI::getVectorVariantNames(CI, ListOfStrings);
0053     if (ListOfStrings.empty())
0054       return;
0055     for (const auto &MangledName : ListOfStrings) {
0056       const std::optional<VFInfo> Shape =
0057           VFABI::tryDemangleForVFABI(MangledName, CI.getFunctionType());
0058       // A match is found via scalar and vector names, and also by
0059       // ensuring that the variant described in the attribute has a
0060       // corresponding definition or declaration of the vector
0061       // function in the Module M.
0062       if (Shape && (Shape->ScalarName == ScalarName)) {
0063         assert(CI.getModule()->getFunction(Shape->VectorName) &&
0064                "Vector function is missing.");
0065         Mappings.push_back(*Shape);
0066       }
0067     }
0068   }
0069 
0070 public:
0071   /// Retrieve all the VFInfo instances associated to the CallInst CI.
0072   static SmallVector<VFInfo, 8> getMappings(const CallInst &CI) {
0073     SmallVector<VFInfo, 8> Ret;
0074 
0075     // Get mappings from the Vector Function ABI variants.
0076     getVFABIMappings(CI, Ret);
0077 
0078     // Other non-VFABI variants should be retrieved here.
0079 
0080     return Ret;
0081   }
0082 
0083   static bool hasMaskedVariant(const CallInst &CI,
0084                                std::optional<ElementCount> VF = std::nullopt) {
0085     // Check whether we have at least one masked vector version of a scalar
0086     // function. If no VF is specified then we check for any masked variant,
0087     // otherwise we look for one that matches the supplied VF.
0088     auto Mappings = VFDatabase::getMappings(CI);
0089     for (VFInfo Info : Mappings)
0090       if (!VF || Info.Shape.VF == *VF)
0091         if (Info.isMasked())
0092           return true;
0093 
0094     return false;
0095   }
0096 
0097   /// Constructor, requires a CallInst instance.
0098   VFDatabase(CallInst &CI)
0099       : M(CI.getModule()), CI(CI),
0100         ScalarToVectorMappings(VFDatabase::getMappings(CI)) {}
0101 
0102   /// \defgroup VFDatabase query interface.
0103   ///
0104   /// @{
0105   /// Retrieve the Function with VFShape \p Shape.
0106   Function *getVectorizedFunction(const VFShape &Shape) const {
0107     if (Shape == VFShape::getScalarShape(CI.getFunctionType()))
0108       return CI.getCalledFunction();
0109 
0110     for (const auto &Info : ScalarToVectorMappings)
0111       if (Info.Shape == Shape)
0112         return M->getFunction(Info.VectorName);
0113 
0114     return nullptr;
0115   }
0116   /// @}
0117 };
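// A minimal usage sketch (illustrative only; the call `Call`, the fixed VF of
// 4, the unpredicated shape, and the VFShape::get overload taking the call's
// FunctionType are assumptions, not guarantees of this interface):
//
//   VFDatabase VFDB(Call);
//   VFShape Shape = VFShape::get(Call.getFunctionType(),
//                                ElementCount::getFixed(4),
//                                /*HasGlobalPred=*/false);
//   if (Function *VecF = VFDB.getVectorizedFunction(Shape))
//     ;  // VecF is the matching vector variant; nullptr means no match.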
0118 
0119 template <typename T> class ArrayRef;
0120 class DemandedBits;
0121 template <typename InstTy> class InterleaveGroup;
0122 class IRBuilderBase;
0123 class Loop;
0124 class TargetTransformInfo;
0125 class Value;
0126 
0127 namespace Intrinsic {
0128 typedef unsigned ID;
0129 }
0130 
0131 /// Identify if the intrinsic is trivially vectorizable.
0132 /// This method returns true if the intrinsic's argument types are all scalars
0133 /// for the scalar form of the intrinsic and all vectors (or scalars handled by
0134 /// isVectorIntrinsicWithScalarOpAtArg) for the vector form of the intrinsic.
0135 ///
0136 /// Note: isTriviallyVectorizable implies isTriviallyScalarizable.
0137 bool isTriviallyVectorizable(Intrinsic::ID ID);
0138 
0139 /// Identify if the intrinsic is trivially scalarizable.
0140 /// This method returns true following the same predicates as
0141 /// isTriviallyVectorizable.
0142 ///
0143 /// Note: There are intrinsics where implementing vectorization for the
0144 /// intrinsic is redundant, but we want to implement scalarization of the
0145 /// vector. To avoid requiring that an intrinsic also implement
0146 /// vectorization, we provide this separate function.
0147 bool isTriviallyScalarizable(Intrinsic::ID ID, const TargetTransformInfo *TTI);
0148 
0149 /// Identifies if the vector form of the intrinsic has a scalar operand.
0150 /// \p TTI is used to consider target-specific intrinsics; if no target-specific
0151 /// intrinsics need to be considered, it is appropriate to pass in nullptr.
0152 bool isVectorIntrinsicWithScalarOpAtArg(Intrinsic::ID ID, unsigned ScalarOpdIdx,
0153                                         const TargetTransformInfo *TTI);
0154 
0155 /// Identifies if the vector form of the intrinsic is overloaded on the type of
0156 /// the operand at index \p OpdIdx, or on the return type if \p OpdIdx is -1.
0157 /// \p TTI is used to consider target-specific intrinsics; if no target-specific
0158 /// intrinsics need to be considered, it is appropriate to pass in nullptr.
0159 bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID, int OpdIdx,
0160                                             const TargetTransformInfo *TTI);
0161 
0162 /// Identifies if the vector form of the intrinsic that returns a struct is
0163 /// overloaded at the struct element index \p RetIdx.
0164 /// \p TTI is used to consider target-specific intrinsics; if no target-specific
0165 /// intrinsics need to be considered, it is appropriate to pass in nullptr.
0166 bool isVectorIntrinsicWithStructReturnOverloadAtField(
0167     Intrinsic::ID ID, int RetIdx, const TargetTransformInfo *TTI);
0168 
0169 /// Returns the intrinsic ID for a call.
0170 /// For the input call instruction it finds the mapping intrinsic and returns
0171 /// its intrinsic ID; if it does not find one, it returns not_intrinsic.
0172 Intrinsic::ID getVectorIntrinsicIDForCall(const CallInst *CI,
0173                                           const TargetLibraryInfo *TLI);
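// A small sketch of how these queries typically combine (illustrative; `CI`
// and `TLI` are assumed to be a CallInst * and TargetLibraryInfo * available
// at the call site):
//
//   Intrinsic::ID ID = getVectorIntrinsicIDForCall(CI, TLI);
//   if (ID != Intrinsic::not_intrinsic && isTriviallyVectorizable(ID))
//     ;  // The call maps to an intrinsic whose vector form takes vector
//        // operands wherever the scalar form takes scalars.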
0174 
0175 /// Given a vector and an element number, see if the scalar value is
0176 /// already around as a register, for example if it were inserted then extracted
0177 /// from the vector.
0178 Value *findScalarElement(Value *V, unsigned EltNo);
0179 
0180 /// If all non-negative \p Mask elements are the same value, return that value.
0181 /// If all elements are negative (undefined) or \p Mask contains different
0182 /// non-negative values, return -1.
0183 int getSplatIndex(ArrayRef<int> Mask);
0184 
0185 /// Get splat value if the input is a splat vector or return nullptr.
0186 /// The value may be extracted from a splat constants vector or from
0187 /// a sequence of instructions that broadcast a single value into a vector.
0188 Value *getSplatValue(const Value *V);
0189 
0190 /// Return true if each element of the vector value \p V is poisoned or equal to
0191 /// every other non-poisoned element. If an index element is specified, either
0192 /// every element of the vector is poisoned or the element at that index is not
0193 /// poisoned and equal to every other non-poisoned element.
0194 /// This may be more powerful than the related getSplatValue() because it is
0195 /// not limited by finding a scalar source value to a splatted vector.
0196 bool isSplatValue(const Value *V, int Index = -1, unsigned Depth = 0);
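// For example, for IR that broadcasts a single scalar (names are arbitrary):
//
//   %ins   = insertelement <4 x i32> poison, i32 %x, i64 0
//   %splat = shufflevector <4 x i32> %ins, <4 x i32> poison, <4 x i32> zeroinitializer
//
// getSplatValue(%splat) returns %x, and isSplatValue(%splat) returns true.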
0197 
0198 /// Transform a shuffle mask's output demanded element mask into demanded
0199 /// element masks for the 2 operands, returns false if the mask isn't valid.
0200 /// Both \p DemandedLHS and \p DemandedRHS are initialised to [SrcWidth].
0201 /// \p AllowUndefElts permits "-1" indices to be treated as undef.
0202 bool getShuffleDemandedElts(int SrcWidth, ArrayRef<int> Mask,
0203                             const APInt &DemandedElts, APInt &DemandedLHS,
0204                             APInt &DemandedRHS, bool AllowUndefElts = false);
0205 
0206 /// Replace each shuffle mask index with the scaled sequential indices for an
0207 /// equivalent mask of narrowed elements. Mask elements that are less than 0
0208 /// (sentinel values) are repeated in the output mask.
0209 ///
0210 /// Example with Scale = 4:
0211 ///   <4 x i32> <3, 2, 0, -1> -->
0212 ///   <16 x i8> <12, 13, 14, 15, 8, 9, 10, 11, 0, 1, 2, 3, -1, -1, -1, -1>
0213 ///
0214 /// This is the reverse process of widening shuffle mask elements, but it always
0215 /// succeeds because the indexes can always be multiplied (scaled up) to map to
0216 /// narrower vector elements.
0217 void narrowShuffleMaskElts(int Scale, ArrayRef<int> Mask,
0218                            SmallVectorImpl<int> &ScaledMask);
0219 
0220 /// Try to transform a shuffle mask by replacing elements with the scaled index
0221 /// for an equivalent mask of widened elements. If all mask elements that would
0222 /// map to a wider element of the new mask are the same negative number
0223 /// (sentinel value), that element of the new mask is the same value. If any
0224 /// element in a given slice is negative and some other element in that slice is
0225 /// not the same value, return false (partial matches with sentinel values are
0226 /// not allowed).
0227 ///
0228 /// Example with Scale = 4:
0229 ///   <16 x i8> <12, 13, 14, 15, 8, 9, 10, 11, 0, 1, 2, 3, -1, -1, -1, -1> -->
0230 ///   <4 x i32> <3, 2, 0, -1>
0231 ///
0232 /// This is the reverse process of narrowing shuffle mask elements if it
0233 /// succeeds. This transform is not always possible because indexes may not
0234 /// divide evenly (scale down) to map to wider vector elements.
0235 bool widenShuffleMaskElts(int Scale, ArrayRef<int> Mask,
0236                           SmallVectorImpl<int> &ScaledMask);
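// A round-trip sketch of the two helpers above, using the Scale = 4 values
// from the examples (variable names are illustrative):
//
//   SmallVector<int> Narrow, Wide;
//   narrowShuffleMaskElts(4, {3, 2, 0, -1}, Narrow);
//   // Narrow == <12,13,14,15, 8,9,10,11, 0,1,2,3, -1,-1,-1,-1>
//   bool Ok = widenShuffleMaskElts(4, Narrow, Wide);
//   // Ok == true and Wide == <3, 2, 0, -1>; widening fails (returns false)
//   // when a 4-element slice mixes different values.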
0237 
0238 /// A variant of the previous method which is specialized for Scale=2, and
0239 /// treats -1 as undef and allows widening when a wider element is partially
0240 /// undef in the narrow form of the mask.  This transformation discards
0241 /// information about which bytes in the original shuffle were undef.
0242 bool widenShuffleMaskElts(ArrayRef<int> M, SmallVectorImpl<int> &NewMask);
0243 
0244 /// Attempt to narrow/widen the \p Mask shuffle mask to the \p NumDstElts target
0245 /// width. Internally this will call narrowShuffleMaskElts/widenShuffleMaskElts.
0246 /// This will assert unless NumDstElts is a multiple of Mask.size (or
0247 /// vice-versa). Returns false on failure, and ScaledMask will be in an
0248 /// undefined state.
0249 bool scaleShuffleMaskElts(unsigned NumDstElts, ArrayRef<int> Mask,
0250                           SmallVectorImpl<int> &ScaledMask);
0251 
0252 /// Repetitively apply `widenShuffleMaskElts()` for as long as it succeeds,
0253 /// to get the shuffle mask with widest possible elements.
0254 void getShuffleMaskWithWidestElts(ArrayRef<int> Mask,
0255                                   SmallVectorImpl<int> &ScaledMask);
0256 
0257 /// Splits and processes a shuffle mask depending on the number of input and
0258 /// output registers. The function does 2 main things: 1) splits the
0259 /// source/destination vectors into real registers; 2) does the mask analysis to
0260 /// identify which real registers are permuted. Then the function processes the
0261 /// resulting register masks using the provided action items. If no input
0262 /// register is defined, \p NoInputAction is used. If only 1 input register is
0263 /// used, \p SingleInputAction is used; otherwise \p ManyInputsAction is used to
0264 /// process 2 or more input registers and masks.
0265 /// \param Mask Original shuffle mask.
0266 /// \param NumOfSrcRegs Number of source registers.
0267 /// \param NumOfDestRegs Number of destination registers.
0268 /// \param NumOfUsedRegs Number of actually used destination registers.
0269 void processShuffleMasks(
0270     ArrayRef<int> Mask, unsigned NumOfSrcRegs, unsigned NumOfDestRegs,
0271     unsigned NumOfUsedRegs, function_ref<void()> NoInputAction,
0272     function_ref<void(ArrayRef<int>, unsigned, unsigned)> SingleInputAction,
0273     function_ref<void(ArrayRef<int>, unsigned, unsigned, bool)>
0274         ManyInputsAction);
0275 
0276 /// Compute the demanded elements mask of horizontal binary operations. A
0277 /// horizontal operation combines two adjacent elements in a vector operand.
0278 /// This function returns a mask for the elements that correspond to the first
0279 /// operand of this horizontal combination. For example, for two vectors
0280 /// [X1, X2, X3, X4] and [Y1, Y2, Y3, Y4], the resulting mask can include the
0281 /// elements X1, X3, Y1, and Y3. To get the other operands, simply shift the
0282 /// result of this function to the left by 1.
0283 ///
0284 /// \param VectorBitWidth the total bit width of the vector
0285 /// \param DemandedElts   the demanded elements mask for the operation
0286 /// \param DemandedLHS    the demanded elements mask for the left operand
0287 /// \param DemandedRHS    the demanded elements mask for the right operand
0288 void getHorizDemandedEltsForFirstOperand(unsigned VectorBitWidth,
0289                                          const APInt &DemandedElts,
0290                                          APInt &DemandedLHS,
0291                                          APInt &DemandedRHS);
0292 
0293 /// Compute a map of integer instructions to their minimum legal type
0294 /// size.
0295 ///
0296 /// C semantics force sub-int-sized values (e.g. i8, i16) to be promoted to int
0297 /// type (e.g. i32) whenever arithmetic is performed on them.
0298 ///
0299 /// For targets with native i8 or i16 operations, usually InstCombine can shrink
0300 /// the arithmetic type down again. However InstCombine refuses to create
0301 /// illegal types, so for targets without i8 or i16 registers, the lengthening
0302 /// and shrinking remain.
0303 ///
0304 /// Most SIMD ISAs (e.g. NEON) however support vectors of i8 or i16 even when
0305 /// their scalar equivalents do not, so during vectorization it is important to
0306 /// remove these extends and truncates when deciding the profitability of
0307 /// vectorization.
0308 ///
0309 /// This function analyzes the given range of instructions and determines the
0310 /// minimum type size each can be converted to. It attempts to remove or
0311 /// minimize type size changes across each def-use chain, so for example in the
0312 /// following code:
0313 ///
0314 ///   %1 = load i8, i8*
0315 ///   %2 = add i8 %1, 2
0316 ///   %3 = load i16, i16*
0317 ///   %4 = zext i8 %2 to i32
0318 ///   %5 = zext i16 %3 to i32
0319 ///   %6 = add i32 %4, %5
0320 ///   %7 = trunc i32 %6 to i16
0321 ///
0322 /// Instruction %6 must be done at least in i16, so computeMinimumValueSizes
0323 /// will return: {%1: 16, %2: 16, %3: 16, %4: 16, %5: 16, %6: 16, %7: 16}.
0324 ///
0325 /// If the optional TargetTransformInfo is provided, this function tries harder
0326 /// to do less work by only looking at illegal types.
0327 MapVector<Instruction*, uint64_t>
0328 computeMinimumValueSizes(ArrayRef<BasicBlock*> Blocks,
0329                          DemandedBits &DB,
0330                          const TargetTransformInfo *TTI=nullptr);
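// A hedged sketch of invoking the analysis on a loop's blocks (assuming a
// Loop *L, a DemandedBits &DB, and an optional TargetTransformInfo *TTI):
//
//   MapVector<Instruction *, uint64_t> MinBWs =
//       computeMinimumValueSizes(L->getBlocks(), DB, TTI);
//   // MinBWs maps each analyzed instruction to the smallest bit width it can
//   // be computed in, e.g. 16 for %6 in the example above.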
0331 
0332 /// Compute the union of two access-group lists.
0333 ///
0334 /// If the list contains just one access group, it is returned directly. If the
0335 /// list is empty, returns nullptr.
0336 MDNode *uniteAccessGroups(MDNode *AccGroups1, MDNode *AccGroups2);
0337 
0338 /// Compute the access-group list of access groups that @p Inst1 and @p Inst2
0339 /// are both in. If either instruction does not access memory at all, it is
0340 /// considered to be in every list.
0341 ///
0342 /// If the list contains just one access group, it is returned directly. If the
0343 /// list is empty, returns nullptr.
0344 MDNode *intersectAccessGroups(const Instruction *Inst1,
0345                               const Instruction *Inst2);
0346 
0347 /// Specifically, let Kinds = [MD_tbaa, MD_alias_scope, MD_noalias, MD_fpmath,
0348 /// MD_nontemporal, MD_access_group, MD_mmra].
0349 /// For K in Kinds, we get the MDNode for K from each of the
0350 /// elements of VL, compute their "intersection" (i.e., the most generic
0351 /// metadata value that covers all of the individual values), and set I's
0352 /// metadata for K equal to the intersection value.
0353 ///
0354 /// This function always sets a (possibly null) value for each K in Kinds.
0355 Instruction *propagateMetadata(Instruction *I, ArrayRef<Value *> VL);
0356 
0357 /// Create a mask that filters the members of an interleave group where there
0358 /// are gaps.
0359 ///
0360 /// For example, the mask for \p Group with interleave-factor 3
0361 /// and \p VF 4, that has only its first member present is:
0362 ///
0363 ///   <1,0,0,1,0,0,1,0,0,1,0,0>
0364 ///
0365 /// Note: The result is a mask of 0's and 1's, as opposed to the other
0366 /// create[*]Mask() utilities which create a shuffle mask (mask that
0367 /// consists of indices).
0368 Constant *createBitMaskForGaps(IRBuilderBase &Builder, unsigned VF,
0369                                const InterleaveGroup<Instruction> &Group);
0370 
0371 /// Create a mask with replicated elements.
0372 ///
0373 /// This function creates a shuffle mask for replicating each of the \p VF
0374 /// elements in a vector \p ReplicationFactor times. It can be used to
0375 /// transform a mask of \p VF elements into a mask of
0376 /// \p VF * \p ReplicationFactor elements used by a predicated
0377 /// interleaved group of loads/stores whose interleave factor ==
0378 /// \p ReplicationFactor.
0379 ///
0380 /// For example, the mask for \p ReplicationFactor=3 and \p VF=4 is:
0381 ///
0382 ///   <0,0,0,1,1,1,2,2,2,3,3,3>
0383 llvm::SmallVector<int, 16> createReplicatedMask(unsigned ReplicationFactor,
0384                                                 unsigned VF);
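// For instance, widening a 4-element predicate for an interleave factor of 3
// (a sketch; `Builder` and the narrow mask value `MaskVec` are assumptions):
//
//   SmallVector<int, 16> RepMask =
//       createReplicatedMask(/*ReplicationFactor=*/3, /*VF=*/4);
//   Value *WideMask = Builder.CreateShuffleVector(MaskVec, RepMask);
//   // RepMask == <0,0,0,1,1,1,2,2,2,3,3,3>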
0385 
0386 /// Create an interleave shuffle mask.
0387 ///
0388 /// This function creates a shuffle mask for interleaving \p NumVecs vectors of
0389 /// vectorization factor \p VF into a single wide vector. The mask is of the
0390 /// form:
0391 ///
0392 ///   <0, VF, VF * 2, ..., VF * (NumVecs - 1), 1, VF + 1, VF * 2 + 1, ...>
0393 ///
0394 /// For example, the mask for VF = 4 and NumVecs = 2 is:
0395 ///
0396 ///   <0, 4, 1, 5, 2, 6, 3, 7>.
0397 llvm::SmallVector<int, 16> createInterleaveMask(unsigned VF, unsigned NumVecs);
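// For example, interleaving two VF = 4 vectors (a sketch; `Builder`, `V0`, and
// `V1` are assumptions):
//
//   SmallVector<int, 16> Mask = createInterleaveMask(/*VF=*/4, /*NumVecs=*/2);
//   Value *Interleaved = Builder.CreateShuffleVector(V0, V1, Mask);
//   // Mask == <0, 4, 1, 5, 2, 6, 3, 7>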
0398 
0399 /// Create a stride shuffle mask.
0400 ///
0401 /// This function creates a shuffle mask whose elements begin at \p Start and
0402 /// are incremented by \p Stride. The mask can be used to deinterleave an
0403 /// interleaved vector into separate vectors of vectorization factor \p VF. The
0404 /// mask is of the form:
0405 ///
0406 ///   <Start, Start + Stride, ..., Start + Stride * (VF - 1)>
0407 ///
0408 /// For example, the mask for Start = 0, Stride = 2, and VF = 4 is:
0409 ///
0410 ///   <0, 2, 4, 6>
0411 llvm::SmallVector<int, 16> createStrideMask(unsigned Start, unsigned Stride,
0412                                             unsigned VF);
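// For example, extracting the even lanes of an 8-element vector (a sketch;
// `Builder` and `Wide` are assumptions):
//
//   SmallVector<int, 16> EvenMask =
//       createStrideMask(/*Start=*/0, /*Stride=*/2, /*VF=*/4);
//   Value *EvenVec = Builder.CreateShuffleVector(Wide, EvenMask);
//   // EvenMask == <0, 2, 4, 6>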
0413 
0414 /// Create a sequential shuffle mask.
0415 ///
0416 /// This function creates shuffle mask whose elements are sequential and begin
0417 /// at \p Start.  The mask contains \p NumInts integers and is padded with \p
0418 /// NumUndefs undef values. The mask is of the form:
0419 ///
0420 ///   <Start, Start + 1, ... Start + NumInts - 1, undef_1, ... undef_NumUndefs>
0421 ///
0422 /// For example, the mask for Start = 0, NumInts = 4, and NumUndefs = 4 is:
0423 ///
0424 ///   <0, 1, 2, 3, undef, undef, undef, undef>
0425 llvm::SmallVector<int, 16>
0426 createSequentialMask(unsigned Start, unsigned NumInts, unsigned NumUndefs);
0427 
0428 /// Given a shuffle mask for a binary shuffle, create the equivalent shuffle
0429 /// mask assuming both operands are identical. This assumes that the unary
0430 /// shuffle will use elements from operand 0 (operand 1 will be unused).
0431 llvm::SmallVector<int, 16> createUnaryMask(ArrayRef<int> Mask,
0432                                            unsigned NumElts);
0433 
0434 /// Concatenate a list of vectors.
0435 ///
0436 /// This function generates code that concatenates the vectors in \p Vecs into a
0437 /// single large vector. The number of vectors should be greater than one, and
0438 /// their element types should be the same. The number of elements in the
0439 /// vectors should also be the same; however, if the last vector has fewer
0440 /// elements, it will be padded with undefs.
0441 Value *concatenateVectors(IRBuilderBase &Builder, ArrayRef<Value *> Vecs);
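// For example, joining two <4 x float> values into one <8 x float>
// (a sketch; `Builder`, `A`, and `B` are assumptions):
//
//   Value *Wide = concatenateVectors(Builder, {A, B});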
0442 
0443 /// Given a mask vector of i1, return true if all of the elements of this
0444 /// predicate mask are known to be false or undef.  That is, return true if all
0445 /// lanes can be assumed inactive.
0446 bool maskIsAllZeroOrUndef(Value *Mask);
0447 
0448 /// Given a mask vector of i1, return true if all of the elements of this
0449 /// predicate mask are known to be true or undef.  That is, return true if all
0450 /// lanes can be assumed active.
0451 bool maskIsAllOneOrUndef(Value *Mask);
0452 
0453 /// Given a mask vector of i1, return true if any of the elements of this
0454 /// predicate mask are known to be true or undef.  That is, return true if at
0455 /// least one lane can be assumed active.
0456 bool maskContainsAllOneOrUndef(Value *Mask);
0457 
0458 /// Given a mask vector of the form <Y x i1>, return an APInt (of bitwidth Y)
0459 /// for each lane which may be active.
0460 APInt possiblyDemandedEltsInMask(Value *Mask);
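// These predicates are typically consulted when simplifying masked operations
// (a sketch; `MaskOp` names the i1 vector mask of a masked load/store and is
// an assumption):
//
//   if (maskIsAllZeroOrUndef(MaskOp))
//     ;  // No lane is active: the masked operation can be dropped.
//   else if (maskIsAllOneOrUndef(MaskOp))
//     ;  // Every lane is active: it can become an unmasked operation.
//   APInt Demanded = possiblyDemandedEltsInMask(MaskOp);  // one bit per lane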
0461 
0462 /// The group of interleaved loads/stores sharing the same stride and
0463 /// close to each other.
0464 ///
0465 /// Each member in this group has an index starting from 0, and the largest
0466 /// index should be less than the interleave factor, which is equal to the absolute
0467 /// value of the access's stride.
0468 ///
0469 /// E.g. An interleaved load group of factor 4:
0470 ///        for (unsigned i = 0; i < 1024; i+=4) {
0471 ///          a = A[i];                           // Member of index 0
0472 ///          b = A[i+1];                         // Member of index 1
0473 ///          d = A[i+3];                         // Member of index 3
0474 ///          ...
0475 ///        }
0476 ///
0477 ///      An interleaved store group of factor 4:
0478 ///        for (unsigned i = 0; i < 1024; i+=4) {
0479 ///          ...
0480 ///          A[i]   = a;                         // Member of index 0
0481 ///          A[i+1] = b;                         // Member of index 1
0482 ///          A[i+2] = c;                         // Member of index 2
0483 ///          A[i+3] = d;                         // Member of index 3
0484 ///        }
0485 ///
0486 /// Note: the interleaved load group could have gaps (missing members), but
0487 /// the interleaved store group doesn't allow gaps.
0488 template <typename InstTy> class InterleaveGroup {
0489 public:
0490   InterleaveGroup(uint32_t Factor, bool Reverse, Align Alignment)
0491       : Factor(Factor), Reverse(Reverse), Alignment(Alignment),
0492         InsertPos(nullptr) {}
0493 
0494   InterleaveGroup(InstTy *Instr, int32_t Stride, Align Alignment)
0495       : Alignment(Alignment), InsertPos(Instr) {
0496     Factor = std::abs(Stride);
0497     assert(Factor > 1 && "Invalid interleave factor");
0498 
0499     Reverse = Stride < 0;
0500     Members[0] = Instr;
0501   }
0502 
0503   bool isReverse() const { return Reverse; }
0504   uint32_t getFactor() const { return Factor; }
0505   Align getAlign() const { return Alignment; }
0506   uint32_t getNumMembers() const { return Members.size(); }
0507 
0508   /// Try to insert a new member \p Instr with index \p Index and
0509   /// alignment \p NewAlign. The index is relative to the leader and it could be
0510   /// negative if \p Instr is the new leader.
0511   ///
0512   /// \returns false if the instruction doesn't belong to the group.
0513   bool insertMember(InstTy *Instr, int32_t Index, Align NewAlign) {
0514     // Make sure the key fits in an int32_t.
0515     std::optional<int32_t> MaybeKey = checkedAdd(Index, SmallestKey);
0516     if (!MaybeKey)
0517       return false;
0518     int32_t Key = *MaybeKey;
0519 
0520     // Skip if the key is used for either the tombstone or empty special values.
0521     if (DenseMapInfo<int32_t>::getTombstoneKey() == Key ||
0522         DenseMapInfo<int32_t>::getEmptyKey() == Key)
0523       return false;
0524 
0525     // Skip if there is already a member with the same index.
0526     if (Members.contains(Key))
0527       return false;
0528 
0529     if (Key > LargestKey) {
0530       // The largest index is always less than the interleave factor.
0531       if (Index >= static_cast<int32_t>(Factor))
0532         return false;
0533 
0534       LargestKey = Key;
0535     } else if (Key < SmallestKey) {
0536 
0537       // Make sure the largest index fits in an int32_t.
0538       std::optional<int32_t> MaybeLargestIndex = checkedSub(LargestKey, Key);
0539       if (!MaybeLargestIndex)
0540         return false;
0541 
0542       // The largest index is always less than the interleave factor.
0543       if (*MaybeLargestIndex >= static_cast<int64_t>(Factor))
0544         return false;
0545 
0546       SmallestKey = Key;
0547     }
0548 
0549     // It's always safe to select the minimum alignment.
0550     Alignment = std::min(Alignment, NewAlign);
0551     Members[Key] = Instr;
0552     return true;
0553   }
0554 
0555   /// Get the member with the given index \p Index
0556   ///
0557   /// \returns nullptr if the group contains no such member.
0558   InstTy *getMember(uint32_t Index) const {
0559     int32_t Key = SmallestKey + Index;
0560     return Members.lookup(Key);
0561   }
0562 
0563   /// Get the index for the given member. Unlike the key in the member
0564   /// map, the index starts from 0.
0565   uint32_t getIndex(const InstTy *Instr) const {
0566     for (auto I : Members) {
0567       if (I.second == Instr)
0568         return I.first - SmallestKey;
0569     }
0570 
0571     llvm_unreachable("InterleaveGroup contains no such member");
0572   }
0573 
0574   InstTy *getInsertPos() const { return InsertPos; }
0575   void setInsertPos(InstTy *Inst) { InsertPos = Inst; }
0576 
0577   /// Add metadata (e.g. alias info) from the instructions in this group to \p
0578   /// NewInst.
0579   ///
0580   /// FIXME: this function currently does not add noalias metadata a la
0581   /// addNewMetadata.  To do that we need to compute the intersection of the
0582   /// noalias info from all members.
0583   void addMetadata(InstTy *NewInst) const;
0584 
0585   /// Returns true if this Group requires a scalar iteration to handle gaps.
0586   bool requiresScalarEpilogue() const {
0587     // If the last member of the Group exists, then a scalar epilogue is not
0588     // needed for this group.
0589     if (getMember(getFactor() - 1))
0590       return false;
0591 
0592     // We have a group with gaps. It therefore can't be a reversed access,
0593     // because such groups get invalidated (TODO).
0594     assert(!isReverse() && "Group should have been invalidated");
0595 
0596     // This is a group of loads, with gaps, and without a last member.
0597     return true;
0598   }
0599 
0600 private:
0601   uint32_t Factor; // Interleave Factor.
0602   bool Reverse;
0603   Align Alignment;
0604   DenseMap<int32_t, InstTy *> Members;
0605   int32_t SmallestKey = 0;
0606   int32_t LargestKey = 0;
0607 
0608   // To avoid breaking dependences, vectorized instructions of an interleave
0609   // group should be inserted at either the first load or the last store in
0610   // program order.
0611   //
0612   // E.g. %even = load i32             // Insert Position
0613   //      %add = add i32 %even         // Use of %even
0614   //      %odd = load i32
0615   //
0616   //      store i32 %even
0617   //      %odd = add i32               // Def of %odd
0618   //      store i32 %odd               // Insert Position
0619   InstTy *InsertPos;
0620 };
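// A minimal sketch of building a group by hand for the factor-4 load example
// above (illustrative; `LoadA0`, `LoadA1`, and `LoadA3` stand for the member
// instructions):
//
//   InterleaveGroup<Instruction> Group(LoadA0, /*Stride=*/4, Align(4));
//   Group.insertMember(LoadA1, /*Index=*/1, Align(4));
//   Group.insertMember(LoadA3, /*Index=*/3, Align(4));
//   // Group.getMember(2) == nullptr, so the group has a gap; since the last
//   // member (index 3) is present, requiresScalarEpilogue() returns false.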
0621 
0622 /// Drive the analysis of interleaved memory accesses in the loop.
0623 ///
0624 /// Use this class to analyze interleaved accesses only when we can vectorize
0625 /// a loop. Otherwise it's meaningless to do the analysis, as vectorization
0626 /// of interleaved accesses is unsafe.
0627 ///
0628 /// The analysis collects interleave groups and records the relationships
0629 /// between the member and the group in a map.
0630 class InterleavedAccessInfo {
0631 public:
0632   InterleavedAccessInfo(PredicatedScalarEvolution &PSE, Loop *L,
0633                         DominatorTree *DT, LoopInfo *LI,
0634                         const LoopAccessInfo *LAI)
0635       : PSE(PSE), TheLoop(L), DT(DT), LI(LI), LAI(LAI) {}
0636 
0637   ~InterleavedAccessInfo() { invalidateGroups(); }
0638 
0639   /// Analyze the interleaved accesses and collect them in interleave
0640   /// groups. Substitute symbolic strides using \p Strides.
0641   /// Consider also predicated loads/stores in the analysis if
0642   /// \p EnableMaskedInterleavedGroup is true.
0643   void analyzeInterleaving(bool EnableMaskedInterleavedGroup);
0644 
0645   /// Invalidate groups, e.g., in case all blocks in the loop will be predicated
0646   /// contrary to the original assumption. Although we currently prevent group
0647   /// formation for predicated accesses, we may be able to relax this limitation
0648   /// in the future once we handle more complicated blocks. Returns true if any
0649   /// groups were invalidated.
0650   bool invalidateGroups() {
0651     if (InterleaveGroups.empty()) {
0652       assert(
0653           !RequiresScalarEpilogue &&
0654           "RequiresScalarEpilog should not be set without interleave groups");
0655       return false;
0656     }
0657 
0658     InterleaveGroupMap.clear();
0659     for (auto *Ptr : InterleaveGroups)
0660       delete Ptr;
0661     InterleaveGroups.clear();
0662     RequiresScalarEpilogue = false;
0663     return true;
0664   }
0665 
0666   /// Check if \p Instr belongs to any interleave group.
0667   bool isInterleaved(Instruction *Instr) const {
0668     return InterleaveGroupMap.contains(Instr);
0669   }
0670 
0671   /// Get the interleave group that \p Instr belongs to.
0672   ///
0673   /// \returns nullptr if \p Instr does not belong to any interleave group.
0674   InterleaveGroup<Instruction> *
0675   getInterleaveGroup(const Instruction *Instr) const {
0676     return InterleaveGroupMap.lookup(Instr);
0677   }
0678 
0679   iterator_range<SmallPtrSetIterator<llvm::InterleaveGroup<Instruction> *>>
0680   getInterleaveGroups() {
0681     return make_range(InterleaveGroups.begin(), InterleaveGroups.end());
0682   }
0683 
0684   /// Returns true if an interleaved group that may access memory
0685   /// out-of-bounds requires a scalar epilogue iteration for correctness.
0686   bool requiresScalarEpilogue() const { return RequiresScalarEpilogue; }
0687 
0688   /// Invalidate groups that require a scalar epilogue (due to gaps). This can
0689   /// happen when optimizing for size forbids a scalar epilogue, and the gap
0690   /// cannot be filtered by masking the load/store.
0691   void invalidateGroupsRequiringScalarEpilogue();
0692 
0693   /// Returns true if we have any interleave groups.
0694   bool hasGroups() const { return !InterleaveGroups.empty(); }
0695 
0696 private:
0697   /// A wrapper around ScalarEvolution, used to add runtime SCEV checks.
0698   /// Simplifies SCEV expressions in the context of existing SCEV assumptions.
0699   /// The interleaved access analysis can also add new predicates (for example
0700   /// by versioning strides of pointers).
0701   PredicatedScalarEvolution &PSE;
0702 
0703   Loop *TheLoop;
0704   DominatorTree *DT;
0705   LoopInfo *LI;
0706   const LoopAccessInfo *LAI;
0707 
0708   /// True if the loop may contain non-reversed interleaved groups with
0709   /// out-of-bounds accesses. We ensure we don't speculatively access memory
0710   /// out-of-bounds by executing at least one scalar epilogue iteration.
0711   bool RequiresScalarEpilogue = false;
0712 
0713   /// Holds the relationships between the members and the interleave group.
0714   DenseMap<Instruction *, InterleaveGroup<Instruction> *> InterleaveGroupMap;
0715 
0716   SmallPtrSet<InterleaveGroup<Instruction> *, 4> InterleaveGroups;
0717 
0718   /// Holds dependences among the memory accesses in the loop. It maps a source
0719   /// access to a set of dependent sink accesses.
0720   DenseMap<Instruction *, SmallPtrSet<Instruction *, 2>> Dependences;
0721 
0722   /// The descriptor for a strided memory access.
0723   struct StrideDescriptor {
0724     StrideDescriptor() = default;
0725     StrideDescriptor(int64_t Stride, const SCEV *Scev, uint64_t Size,
0726                      Align Alignment)
0727         : Stride(Stride), Scev(Scev), Size(Size), Alignment(Alignment) {}
0728 
0729     // The access's stride. It is negative for a reverse access.
0730     int64_t Stride = 0;
0731 
0732     // The scalar expression of this access.
0733     const SCEV *Scev = nullptr;
0734 
0735     // The size of the memory object.
0736     uint64_t Size = 0;
0737 
0738     // The alignment of this access.
0739     Align Alignment;
0740   };
0741 
0742   /// A type for holding instructions and their stride descriptors.
0743   using StrideEntry = std::pair<Instruction *, StrideDescriptor>;
0744 
0745   /// Create a new interleave group with the given instruction \p Instr,
0746   /// stride \p Stride and alignment \p Alignment.
0747   ///
0748   /// \returns the newly created interleave group.
0749   InterleaveGroup<Instruction> *
0750   createInterleaveGroup(Instruction *Instr, int Stride, Align Alignment) {
0751     assert(!InterleaveGroupMap.count(Instr) &&
0752            "Already in an interleaved access group");
0753     InterleaveGroupMap[Instr] =
0754         new InterleaveGroup<Instruction>(Instr, Stride, Alignment);
0755     InterleaveGroups.insert(InterleaveGroupMap[Instr]);
0756     return InterleaveGroupMap[Instr];
0757   }
0758 
0759   /// Release the group and remove all the relationships.
0760   void releaseGroup(InterleaveGroup<Instruction> *Group) {
0761     InterleaveGroups.erase(Group);
0762     releaseGroupWithoutRemovingFromSet(Group);
0763   }
0764 
0765   /// Do everything necessary to release the group, apart from removing it from
0766   /// the InterleaveGroups set.
0767   void releaseGroupWithoutRemovingFromSet(InterleaveGroup<Instruction> *Group) {
0768     for (unsigned i = 0; i < Group->getFactor(); i++)
0769       if (Instruction *Member = Group->getMember(i))
0770         InterleaveGroupMap.erase(Member);
0771 
0772     delete Group;
0773   }
0774 
0775   /// Collect all the accesses with a constant stride in program order.
0776   void collectConstStrideAccesses(
0777       MapVector<Instruction *, StrideDescriptor> &AccessStrideInfo,
0778       const DenseMap<Value *, const SCEV *> &Strides);
0779 
0780   /// Returns true if \p Stride is allowed in an interleaved group.
0781   static bool isStrided(int Stride);
0782 
0783   /// Returns true if \p BB is a predicated block.
0784   bool isPredicated(BasicBlock *BB) const {
0785     return LoopAccessInfo::blockNeedsPredication(BB, TheLoop, DT);
0786   }
0787 
0788   /// Returns true if LoopAccessInfo can be used for dependence queries.
0789   bool areDependencesValid() const {
0790     return LAI && LAI->getDepChecker().getDependences();
0791   }
0792 
0793   /// Returns true if memory accesses \p A and \p B can be reordered, if
0794   /// necessary, when constructing interleaved groups.
0795   ///
0796   /// \p A must precede \p B in program order. We return false if reordering is
0797   /// not necessary or is prevented because \p A and \p B may be dependent.
0798   bool canReorderMemAccessesForInterleavedGroups(StrideEntry *A,
0799                                                  StrideEntry *B) const {
0800     // Code motion for interleaved accesses can potentially hoist strided loads
0801     // and sink strided stores. The code below checks the legality of the
0802     // following two conditions:
0803     //
0804     // 1. Potentially moving a strided load (B) before any store (A) that
0805     //    precedes B, or
0806     //
0807     // 2. Potentially moving a strided store (A) after any load or store (B)
0808     //    that A precedes.
0809     //
0810     // It's legal to reorder A and B if we know there isn't a dependence from A
0811     // to B. Note that this determination is conservative since some
0812     // dependences could potentially be reordered safely.
0813 
0814     // A is potentially the source of a dependence.
0815     auto *Src = A->first;
0816     auto SrcDes = A->second;
0817 
0818     // B is potentially the sink of a dependence.
0819     auto *Sink = B->first;
0820     auto SinkDes = B->second;
0821 
0822     // Code motion for interleaved accesses can't violate WAR dependences.
0823     // Thus, reordering is legal if the source isn't a write.
0824     if (!Src->mayWriteToMemory())
0825       return true;
0826 
0827     // At least one of the accesses must be strided.
0828     if (!isStrided(SrcDes.Stride) && !isStrided(SinkDes.Stride))
0829       return true;
0830 
0831     // If dependence information is not available from LoopAccessInfo,
0832     // conservatively assume the instructions can't be reordered.
0833     if (!areDependencesValid())
0834       return false;
0835 
0836     // If we know there is a dependence from source to sink, assume the
0837     // instructions can't be reordered. Otherwise, reordering is legal.
0838     return !Dependences.contains(Src) || !Dependences.lookup(Src).count(Sink);
0839   }
0840 
0841   /// Collect the dependences from LoopAccessInfo.
0842   ///
0843   /// We process the dependences once during the interleaved access analysis to
0844   /// enable constant-time dependence queries.
0845   void collectDependences() {
0846     if (!areDependencesValid())
0847       return;
0848     const auto &DepChecker = LAI->getDepChecker();
0849     auto *Deps = DepChecker.getDependences();
0850     for (auto Dep : *Deps)
0851       Dependences[Dep.getSource(DepChecker)].insert(
0852           Dep.getDestination(DepChecker));
0853   }
0854 };
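// A hedged sketch of driving the analysis, assuming the usual loop-vectorizer
// inputs (PSE, L, DT, LI, LAI) and a memory instruction `I` are available:
//
//   InterleavedAccessInfo IAI(PSE, L, DT, LI, LAI);
//   IAI.analyzeInterleaving(/*EnableMaskedInterleavedGroup=*/false);
//   if (auto *Group = IAI.getInterleaveGroup(I))
//     ;  // Use Group->getFactor(), Group->getIndex(I), Group->getMember(...).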
0855 
0856 } // llvm namespace
0857 
0858 #endif