//===- llvm/Analysis/VectorUtils.h - Vector utilities -----------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines some vectorizer utilities.
//
//===----------------------------------------------------------------------===//

#ifndef LLVM_ANALYSIS_VECTORUTILS_H
#define LLVM_ANALYSIS_VECTORUTILS_H

#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Analysis/LoopAccessAnalysis.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/VFABIDemangler.h"
#include "llvm/IR/VectorTypeUtils.h"
#include "llvm/Support/CheckedArithmetic.h"

namespace llvm {
class TargetLibraryInfo;

/// The Vector Function Database.
///
/// Helper class used to find the vector functions associated with a
/// scalar CallInst.
class VFDatabase {
  /// The Module of the CallInst CI.
  const Module *M;
  /// The CallInst instance being queried for scalar to vector mappings.
  const CallInst &CI;
  /// List of vector function descriptors associated with the call
  /// instruction.
  const SmallVector<VFInfo, 8> ScalarToVectorMappings;

  /// Retrieve the scalar-to-vector mappings associated with the rules of
  /// a Vector Function ABI.
  static void getVFABIMappings(const CallInst &CI,
                               SmallVectorImpl<VFInfo> &Mappings) {
    if (!CI.getCalledFunction())
      return;

    const StringRef ScalarName = CI.getCalledFunction()->getName();

    SmallVector<std::string, 8> ListOfStrings;
    // The check for the vector-function-abi-variant attribute is done when
    // retrieving the vector variant names here.
    VFABI::getVectorVariantNames(CI, ListOfStrings);
    if (ListOfStrings.empty())
      return;
    for (const auto &MangledName : ListOfStrings) {
      const std::optional<VFInfo> Shape =
          VFABI::tryDemangleForVFABI(MangledName, CI.getFunctionType());
      // A match is found via scalar and vector names, and also by
      // ensuring that the variant described in the attribute has a
      // corresponding definition or declaration of the vector
      // function in the Module M.
      if (Shape && (Shape->ScalarName == ScalarName)) {
        assert(CI.getModule()->getFunction(Shape->VectorName) &&
               "Vector function is missing.");
        Mappings.push_back(*Shape);
      }
    }
  }

public:
  /// Retrieve all the VFInfo instances associated with the CallInst CI.
  static SmallVector<VFInfo, 8> getMappings(const CallInst &CI) {
    SmallVector<VFInfo, 8> Ret;

    // Get mappings from the Vector Function ABI variants.
    getVFABIMappings(CI, Ret);

    // Other non-VFABI variants should be retrieved here.

    return Ret;
  }

  static bool hasMaskedVariant(const CallInst &CI,
                               std::optional<ElementCount> VF = std::nullopt) {
    // Check whether we have at least one masked vector version of a scalar
    // function. If no VF is specified then we check for any masked variant,
    // otherwise we look for one that matches the supplied VF.
    auto Mappings = VFDatabase::getMappings(CI);
    for (VFInfo Info : Mappings)
      if (!VF || Info.Shape.VF == *VF)
        if (Info.isMasked())
          return true;

    return false;
  }

  /// Constructor, requires a CallInst instance.
  VFDatabase(CallInst &CI)
      : M(CI.getModule()), CI(CI),
        ScalarToVectorMappings(VFDatabase::getMappings(CI)) {}

  /// \defgroup VFDatabase query interface.
  ///
  /// @{
  /// Retrieve the Function with VFShape \p Shape.
  Function *getVectorizedFunction(const VFShape &Shape) const {
    if (Shape == VFShape::getScalarShape(CI.getFunctionType()))
      return CI.getCalledFunction();

    for (const auto &Info : ScalarToVectorMappings)
      if (Info.Shape == Shape)
        return M->getFunction(Info.VectorName);

    return nullptr;
  }
  /// @}
};
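
// Illustrative usage sketch (editorial addition, not part of the upstream
// header): querying the database for a call's vector variants. The 4-lane
// fixed vectorization factor below is an arbitrary assumption for this
// example; only APIs declared in this header are used.
//
//   void queryVariants(CallInst &CI) {
//     // All VFABI mappings attached to the call.
//     SmallVector<VFInfo, 8> Mappings = VFDatabase::getMappings(CI);
//     // Is there a masked variant for a fixed VF of 4 lanes?
//     bool HasMasked =
//         VFDatabase::hasMaskedVariant(CI, ElementCount::getFixed(4));
//     // Resolve each mapping to the concrete vector Function, if any.
//     VFDatabase DB(CI);
//     for (const VFInfo &Info : Mappings)
//       if (Function *VecF = DB.getVectorizedFunction(Info.Shape))
//         (void)VecF; // VecF declares or defines the vector variant.
//     (void)HasMasked;
//   }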

template <typename T> class ArrayRef;
class DemandedBits;
template <typename InstTy> class InterleaveGroup;
class IRBuilderBase;
class Loop;
class TargetTransformInfo;
class Value;

namespace Intrinsic {
typedef unsigned ID;
}

/// Identify if the intrinsic is trivially vectorizable.
/// This method returns true if the intrinsic's argument types are all scalars
/// for the scalar form of the intrinsic and all vectors (or scalars handled by
/// isVectorIntrinsicWithScalarOpAtArg) for the vector form of the intrinsic.
///
/// Note: isTriviallyVectorizable implies isTriviallyScalarizable.
bool isTriviallyVectorizable(Intrinsic::ID ID);

/// Identify if the intrinsic is trivially scalarizable.
/// This method returns true using the same predicates as
/// isTriviallyVectorizable.

/// Note: There are intrinsics where implementing vectorization for the
/// intrinsic is redundant, but we want to implement scalarization of the
/// vector. To prevent the requirement that an intrinsic also implements
/// vectorization we provide this separate function.
bool isTriviallyScalarizable(Intrinsic::ID ID, const TargetTransformInfo *TTI);

/// Identifies if the vector form of the intrinsic has a scalar operand.
/// \p TTI is used to consider target-specific intrinsics; if no target-specific
/// intrinsics are to be considered then it is appropriate to pass in nullptr.
bool isVectorIntrinsicWithScalarOpAtArg(Intrinsic::ID ID, unsigned ScalarOpdIdx,
                                        const TargetTransformInfo *TTI);

/// Identifies if the vector form of the intrinsic is overloaded on the type of
/// the operand at index \p OpdIdx, or on the return type if \p OpdIdx is -1.
/// \p TTI is used to consider target-specific intrinsics; if no target-specific
/// intrinsics are to be considered then it is appropriate to pass in nullptr.
bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID, int OpdIdx,
                                            const TargetTransformInfo *TTI);

/// Identifies if the vector form of the intrinsic that returns a struct is
/// overloaded at the struct element index \p RetIdx.
/// \p TTI is used to consider target-specific intrinsics; if no target-specific
/// intrinsics are to be considered then it is appropriate to pass in nullptr.
bool isVectorIntrinsicWithStructReturnOverloadAtField(
    Intrinsic::ID ID, int RetIdx, const TargetTransformInfo *TTI);

/// Returns intrinsic ID for call.
/// For the given call instruction it finds the mapped intrinsic and returns
/// its intrinsic ID; if no mapping is found, it returns not_intrinsic.
Intrinsic::ID getVectorIntrinsicIDForCall(const CallInst *CI,
                                          const TargetLibraryInfo *TLI);

/// Given a vector and an element number, see if the scalar value is
/// already around as a register, for example if it were inserted then extracted
/// from the vector.
Value *findScalarElement(Value *V, unsigned EltNo);

/// If all non-negative \p Mask elements are the same value, return that value.
/// If all elements are negative (undefined) or \p Mask contains different
/// non-negative values, return -1.
int getSplatIndex(ArrayRef<int> Mask);

/// Get splat value if the input is a splat vector or return nullptr.
/// The value may be extracted from a splat constants vector or from
/// a sequence of instructions that broadcast a single value into a vector.
Value *getSplatValue(const Value *V);

/// Return true if each element of the vector value \p V is poisoned or equal to
/// every other non-poisoned element. If an index element is specified, either
/// every element of the vector is poisoned or the element at that index is not
/// poisoned and equal to every other non-poisoned element.
/// This may be more powerful than the related getSplatValue() because it is
/// not limited by finding a scalar source value to a splatted vector.
bool isSplatValue(const Value *V, int Index = -1, unsigned Depth = 0);

/// Transform a shuffle mask's output demanded element mask into demanded
/// element masks for the 2 operands, returns false if the mask isn't valid.
/// Both \p DemandedLHS and \p DemandedRHS are initialised to [SrcWidth].
/// \p AllowUndefElts permits "-1" indices to be treated as undef.
bool getShuffleDemandedElts(int SrcWidth, ArrayRef<int> Mask,
                            const APInt &DemandedElts, APInt &DemandedLHS,
                            APInt &DemandedRHS, bool AllowUndefElts = false);

/// Replace each shuffle mask index with the scaled sequential indices for an
/// equivalent mask of narrowed elements. Mask elements that are less than 0
/// (sentinel values) are repeated in the output mask.
///
/// Example with Scale = 4:
///   <4 x i32> <3, 2, 0, -1> -->
///   <16 x i8> <12, 13, 14, 15, 8, 9, 10, 11, 0, 1, 2, 3, -1, -1, -1, -1>
///
/// This is the reverse process of widening shuffle mask elements, but it always
/// succeeds because the indexes can always be multiplied (scaled up) to map to
/// narrower vector elements.
void narrowShuffleMaskElts(int Scale, ArrayRef<int> Mask,
                           SmallVectorImpl<int> &ScaledMask);

/// Try to transform a shuffle mask by replacing elements with the scaled index
/// for an equivalent mask of widened elements. If all mask elements that would
/// map to a wider element of the new mask are the same negative number
/// (sentinel value), that element of the new mask is the same value. If any
/// element in a given slice is negative and some other element in that slice is
/// not the same value, return false (partial matches with sentinel values are
/// not allowed).
///
/// Example with Scale = 4:
///   <16 x i8> <12, 13, 14, 15, 8, 9, 10, 11, 0, 1, 2, 3, -1, -1, -1, -1> -->
///   <4 x i32> <3, 2, 0, -1>
///
/// This is the reverse process of narrowing shuffle mask elements if it
/// succeeds. This transform is not always possible because indexes may not
/// divide evenly (scale down) to map to wider vector elements.
bool widenShuffleMaskElts(int Scale, ArrayRef<int> Mask,
                          SmallVectorImpl<int> &ScaledMask);
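
// Illustrative sketch (editorial addition): the Scale = 4 round trip from the
// documented examples above, expressed as calls. The concrete mask values are
// taken directly from those examples.
//
//   SmallVector<int> Narrow;
//   narrowShuffleMaskElts(/*Scale=*/4, {3, 2, 0, -1}, Narrow);
//   // Narrow == {12, 13, 14, 15, 8, 9, 10, 11, 0, 1, 2, 3, -1, -1, -1, -1}
//
//   SmallVector<int> Wide;
//   bool Succeeded = widenShuffleMaskElts(/*Scale=*/4, Narrow, Wide);
//   // Succeeded == true and Wide == {3, 2, 0, -1} again.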

/// A variant of the previous method which is specialized for Scale=2, and
/// treats -1 as undef and allows widening when a wider element is partially
/// undef in the narrow form of the mask. This transformation discards
/// information about which bytes in the original shuffle were undef.
bool widenShuffleMaskElts(ArrayRef<int> M, SmallVectorImpl<int> &NewMask);

/// Attempt to narrow/widen the \p Mask shuffle mask to the \p NumDstElts target
/// width. Internally this will call narrowShuffleMaskElts/widenShuffleMaskElts.
/// This will assert unless NumDstElts is a multiple of Mask.size (or
/// vice-versa). Returns false on failure, and ScaledMask will be in an
/// undefined state.
bool scaleShuffleMaskElts(unsigned NumDstElts, ArrayRef<int> Mask,
                          SmallVectorImpl<int> &ScaledMask);

/// Repetitively apply `widenShuffleMaskElts()` for as long as it succeeds,
/// to get the shuffle mask with the widest possible elements.
void getShuffleMaskWithWidestElts(ArrayRef<int> Mask,
                                  SmallVectorImpl<int> &ScaledMask);

/// Splits and processes a shuffle mask depending on the number of input and
/// output registers. The function does two main things: 1) splits the
/// source/destination vectors into real registers; 2) does the mask analysis to
/// identify which real registers are permuted. Then the function processes the
/// resulting register masks using the provided action items. If no input
/// register is defined, \p NoInputAction is used. If only 1 input register is
/// used, \p SingleInputAction is used, otherwise \p ManyInputsAction is used to
/// process > 2 input registers and masks.
/// \param Mask Original shuffle mask.
/// \param NumOfSrcRegs Number of source registers.
/// \param NumOfDestRegs Number of destination registers.
/// \param NumOfUsedRegs Number of actually used destination registers.
void processShuffleMasks(
    ArrayRef<int> Mask, unsigned NumOfSrcRegs, unsigned NumOfDestRegs,
    unsigned NumOfUsedRegs, function_ref<void()> NoInputAction,
    function_ref<void(ArrayRef<int>, unsigned, unsigned)> SingleInputAction,
    function_ref<void(ArrayRef<int>, unsigned, unsigned, bool)>
        ManyInputsAction);

/// Compute the demanded elements mask of horizontal binary operations. A
/// horizontal operation combines two adjacent elements in a vector operand.
/// This function returns a mask for the elements that correspond to the first
/// operand of this horizontal combination. For example, for two vectors
/// [X1, X2, X3, X4] and [Y1, Y2, Y3, Y4], the resulting mask can include the
/// elements X1, X3, Y1, and Y3. To get the other operand, simply shift the
/// result of this function to the left by 1.
///
/// \param VectorBitWidth the total bit width of the vector
/// \param DemandedElts the demanded elements mask for the operation
/// \param DemandedLHS the demanded elements mask for the left operand
/// \param DemandedRHS the demanded elements mask for the right operand
void getHorizDemandedEltsForFirstOperand(unsigned VectorBitWidth,
                                         const APInt &DemandedElts,
                                         APInt &DemandedLHS,
                                         APInt &DemandedRHS);

/// Compute a map of integer instructions to their minimum legal type
/// size.
///
/// C semantics force sub-int-sized values (e.g. i8, i16) to be promoted to int
/// type (e.g. i32) whenever arithmetic is performed on them.
///
/// For targets with native i8 or i16 operations, usually InstCombine can shrink
/// the arithmetic type down again. However InstCombine refuses to create
/// illegal types, so for targets without i8 or i16 registers, the lengthening
/// and shrinking remains.
///
/// Most SIMD ISAs (e.g. NEON) however support vectors of i8 or i16 even when
/// their scalar equivalents do not, so during vectorization it is important to
/// remove these extensions and truncations when deciding the profitability of
/// vectorization.
///
/// This function analyzes the given range of instructions and determines the
/// minimum type size each can be converted to. It attempts to remove or
/// minimize type size changes across each def-use chain, so for example in the
/// following code:
///
///   %1 = load i8, i8*
///   %2 = add i8 %1, 2
///   %3 = load i16, i16*
///   %4 = zext i8 %2 to i32
///   %5 = zext i16 %3 to i32
///   %6 = add i32 %4, %5
///   %7 = trunc i32 %6 to i16
///
/// Instruction %6 must be done at least in i16, so computeMinimumValueSizes
/// will return: {%1: 16, %2: 16, %3: 16, %4: 16, %5: 16, %6: 16, %7: 16}.
///
/// If the optional TargetTransformInfo is provided, this function tries harder
/// to do less work by only looking at illegal types.
MapVector<Instruction *, uint64_t>
computeMinimumValueSizes(ArrayRef<BasicBlock *> Blocks, DemandedBits &DB,
                         const TargetTransformInfo *TTI = nullptr);
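
// Illustrative sketch (editorial addition): invoking the analysis over the
// blocks of a loop. `L`, `DB`, and `TTI` are placeholder names assumed to be
// provided by the surrounding pass.
//
//   MapVector<Instruction *, uint64_t> MinBWs =
//       computeMinimumValueSizes(L->getBlocks(), DB, &TTI);
//   for (const auto &Entry : MinBWs) {
//     // Entry.first can be computed in an integer type of Entry.second bits.
//   }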

/// Compute the union of two access-group lists.
///
/// If the list contains just one access group, it is returned directly. If the
/// list is empty, returns nullptr.
MDNode *uniteAccessGroups(MDNode *AccGroups1, MDNode *AccGroups2);

/// Compute the access-group list of access groups that @p Inst1 and @p Inst2
/// are both in. If either instruction does not access memory at all, it is
/// considered to be in every list.
///
/// If the list contains just one access group, it is returned directly. If the
/// list is empty, returns nullptr.
MDNode *intersectAccessGroups(const Instruction *Inst1,
                              const Instruction *Inst2);

/// Specifically, let Kinds = [MD_tbaa, MD_alias_scope, MD_noalias, MD_fpmath,
/// MD_nontemporal, MD_access_group, MD_mmra].
/// For K in Kinds, we get the MDNode for K from each of the
/// elements of VL, compute their "intersection" (i.e., the most generic
/// metadata value that covers all of the individual values), and set I's
/// metadata for K equal to the intersection value.
///
/// This function always sets a (possibly null) value for each K in Kinds.
Instruction *propagateMetadata(Instruction *I, ArrayRef<Value *> VL);

/// Create a mask that filters the members of an interleave group where there
/// are gaps.
///
/// For example, the mask for \p Group with interleave factor 3
/// and \p VF 4, that has only its first member present, is:
///
///   <1,0,0,1,0,0,1,0,0,1,0,0>
///
/// Note: The result is a mask of 0's and 1's, as opposed to the other
/// create[*]Mask() utilities which create a shuffle mask (mask that
/// consists of indices).
Constant *createBitMaskForGaps(IRBuilderBase &Builder, unsigned VF,
                               const InterleaveGroup<Instruction> &Group);

/// Create a mask with replicated elements.
///
/// This function creates a shuffle mask for replicating each of the \p VF
/// elements in a vector \p ReplicationFactor times. It can be used to
/// transform a mask of \p VF elements into a mask of
/// \p VF * \p ReplicationFactor elements used by a predicated
/// interleaved-group of loads/stores whose interleave factor equals
/// \p ReplicationFactor.
///
/// For example, the mask for \p ReplicationFactor=3 and \p VF=4 is:
///
///   <0,0,0,1,1,1,2,2,2,3,3,3>
llvm::SmallVector<int, 16> createReplicatedMask(unsigned ReplicationFactor,
                                                unsigned VF);

/// Create an interleave shuffle mask.
///
/// This function creates a shuffle mask for interleaving \p NumVecs vectors of
/// vectorization factor \p VF into a single wide vector. The mask is of the
/// form:
///
///   <0, VF, VF * 2, ..., VF * (NumVecs - 1), 1, VF + 1, VF * 2 + 1, ...>
///
/// For example, the mask for VF = 4 and NumVecs = 2 is:
///
///   <0, 4, 1, 5, 2, 6, 3, 7>.
llvm::SmallVector<int, 16> createInterleaveMask(unsigned VF, unsigned NumVecs);

/// Create a stride shuffle mask.
///
/// This function creates a shuffle mask whose elements begin at \p Start and
/// are incremented by \p Stride. The mask can be used to deinterleave an
/// interleaved vector into separate vectors of vectorization factor \p VF. The
/// mask is of the form:
///
///   <Start, Start + Stride, ..., Start + Stride * (VF - 1)>
///
/// For example, the mask for Start = 0, Stride = 2, and VF = 4 is:
///
///   <0, 2, 4, 6>
llvm::SmallVector<int, 16> createStrideMask(unsigned Start, unsigned Stride,
                                            unsigned VF);

/// Create a sequential shuffle mask.
///
/// This function creates a shuffle mask whose elements are sequential and begin
/// at \p Start. The mask contains \p NumInts integers and is padded with \p
/// NumUndefs undef values. The mask is of the form:
///
///   <Start, Start + 1, ... Start + NumInts - 1, undef_1, ... undef_NumUndefs>
///
/// For example, the mask for Start = 0, NumInts = 4, and NumUndefs = 4 is:
///
///   <0, 1, 2, 3, undef, undef, undef, undef>
llvm::SmallVector<int, 16>
createSequentialMask(unsigned Start, unsigned NumInts, unsigned NumUndefs);

/// Given a shuffle mask for a binary shuffle, create the equivalent shuffle
/// mask assuming both operands are identical. This assumes that the unary
/// shuffle will use elements from operand 0 (operand 1 will be unused).
llvm::SmallVector<int, 16> createUnaryMask(ArrayRef<int> Mask,
                                           unsigned NumElts);
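
// Illustrative sketch (editorial addition): the masks produced by the helpers
// above for an assumed vectorization factor of 4, matching the documented
// examples.
//
//   SmallVector<int, 16> Interleave =
//       createInterleaveMask(/*VF=*/4, /*NumVecs=*/2);       // <0,4,1,5,2,6,3,7>
//   SmallVector<int, 16> Even =
//       createStrideMask(/*Start=*/0, /*Stride=*/2, /*VF=*/4); // <0,2,4,6>
//   SmallVector<int, 16> Seq =
//       createSequentialMask(/*Start=*/0, /*NumInts=*/4,
//                            /*NumUndefs=*/4); // <0,1,2,3,undef,undef,undef,undef>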

/// Concatenate a list of vectors.
///
/// This function generates code that concatenates the vectors in \p Vecs into a
/// single large vector. The number of vectors should be greater than one, and
/// their element types should be the same. The number of elements in the
/// vectors should also be the same; however, if the last vector has fewer
/// elements, it will be padded with undefs.
Value *concatenateVectors(IRBuilderBase &Builder, ArrayRef<Value *> Vecs);

/// Given a mask vector of i1, return true if all of the elements of this
/// predicate mask are known to be false or undef. That is, return true if all
/// lanes can be assumed inactive.
bool maskIsAllZeroOrUndef(Value *Mask);

/// Given a mask vector of i1, return true if all of the elements of this
/// predicate mask are known to be true or undef. That is, return true if all
/// lanes can be assumed active.
bool maskIsAllOneOrUndef(Value *Mask);

/// Given a mask vector of i1, return true if any of the elements of this
/// predicate mask are known to be true or undef. That is, return true if at
/// least one lane can be assumed active.
bool maskContainsAllOneOrUndef(Value *Mask);

/// Given a mask vector of the form <Y x i1>, return an APInt (of bitwidth Y)
/// for each lane which may be active.
APInt possiblyDemandedEltsInMask(Value *Mask);

/// The group of interleaved loads/stores sharing the same stride and
/// close to each other.
///
/// Each member in this group has an index starting from 0, and the largest
/// index should be less than the interleave factor, which is equal to the
/// absolute value of the access's stride.
///
/// E.g. An interleaved load group of factor 4:
///        for (unsigned i = 0; i < 1024; i+=4) {
///          a = A[i];                           // Member of index 0
///          b = A[i+1];                         // Member of index 1
///          d = A[i+3];                         // Member of index 3
///          ...
///        }
///
///      An interleaved store group of factor 4:
///        for (unsigned i = 0; i < 1024; i+=4) {
///          ...
///          A[i]   = a;                         // Member of index 0
///          A[i+1] = b;                         // Member of index 1
///          A[i+2] = c;                         // Member of index 2
///          A[i+3] = d;                         // Member of index 3
///        }
///
/// Note: the interleaved load group could have gaps (missing members), but
/// the interleaved store group doesn't allow gaps.
template <typename InstTy> class InterleaveGroup {
public:
  InterleaveGroup(uint32_t Factor, bool Reverse, Align Alignment)
      : Factor(Factor), Reverse(Reverse), Alignment(Alignment),
        InsertPos(nullptr) {}

  InterleaveGroup(InstTy *Instr, int32_t Stride, Align Alignment)
      : Alignment(Alignment), InsertPos(Instr) {
    Factor = std::abs(Stride);
    assert(Factor > 1 && "Invalid interleave factor");

    Reverse = Stride < 0;
    Members[0] = Instr;
  }

  bool isReverse() const { return Reverse; }
  uint32_t getFactor() const { return Factor; }
  Align getAlign() const { return Alignment; }
  uint32_t getNumMembers() const { return Members.size(); }

  /// Try to insert a new member \p Instr with index \p Index and
  /// alignment \p NewAlign. The index is relative to the leader and it could be
  /// negative if the new member becomes the new leader.
  ///
  /// \returns false if the instruction doesn't belong to the group.
  bool insertMember(InstTy *Instr, int32_t Index, Align NewAlign) {
    // Make sure the key fits in an int32_t.
    std::optional<int32_t> MaybeKey = checkedAdd(Index, SmallestKey);
    if (!MaybeKey)
      return false;
    int32_t Key = *MaybeKey;

    // Skip if the key is used for either the tombstone or empty special values.
    if (DenseMapInfo<int32_t>::getTombstoneKey() == Key ||
        DenseMapInfo<int32_t>::getEmptyKey() == Key)
      return false;

    // Skip if there is already a member with the same index.
    if (Members.contains(Key))
      return false;

    if (Key > LargestKey) {
      // The largest index is always less than the interleave factor.
      if (Index >= static_cast<int32_t>(Factor))
        return false;

      LargestKey = Key;
    } else if (Key < SmallestKey) {

      // Make sure the largest index fits in an int32_t.
      std::optional<int32_t> MaybeLargestIndex = checkedSub(LargestKey, Key);
      if (!MaybeLargestIndex)
        return false;

      // The largest index is always less than the interleave factor.
      if (*MaybeLargestIndex >= static_cast<int64_t>(Factor))
        return false;

      SmallestKey = Key;
    }

    // It's always safe to select the minimum alignment.
    Alignment = std::min(Alignment, NewAlign);
    Members[Key] = Instr;
    return true;
  }

  /// Get the member with the given index \p Index.
  ///
  /// \returns nullptr if the group contains no such member.
  InstTy *getMember(uint32_t Index) const {
    int32_t Key = SmallestKey + Index;
    return Members.lookup(Key);
  }

  /// Get the index for the given member. Unlike the key in the member
  /// map, the index starts from 0.
  uint32_t getIndex(const InstTy *Instr) const {
    for (auto I : Members) {
      if (I.second == Instr)
        return I.first - SmallestKey;
    }

    llvm_unreachable("InterleaveGroup contains no such member");
  }

  InstTy *getInsertPos() const { return InsertPos; }
  void setInsertPos(InstTy *Inst) { InsertPos = Inst; }

  /// Add metadata (e.g. alias info) from the instructions in this group to \p
  /// NewInst.
  ///
  /// FIXME: this function currently does not add noalias metadata a la
  /// addNewMetadata. To do that we need to compute the intersection of the
  /// noalias info from all members.
  void addMetadata(InstTy *NewInst) const;

  /// Returns true if this Group requires a scalar iteration to handle gaps.
  bool requiresScalarEpilogue() const {
    // If the last member of the Group exists, then a scalar epilogue is not
    // needed for this group.
    if (getMember(getFactor() - 1))
      return false;

    // We have a group with gaps. It therefore can't be a reversed access,
    // because such groups get invalidated (TODO).
    assert(!isReverse() && "Group should have been invalidated");

    // This is a group of loads, with gaps, and without a last member.
    return true;
  }

private:
  uint32_t Factor; // Interleave Factor.
  bool Reverse;
  Align Alignment;
  DenseMap<int32_t, InstTy *> Members;
  int32_t SmallestKey = 0;
  int32_t LargestKey = 0;

  // To avoid breaking dependences, vectorized instructions of an interleave
  // group should be inserted at either the first load or the last store in
  // program order.
  //
  // E.g. %even = load i32               // Insert Position
  //      %add = add i32 %even           // Use of %even
  //      %odd = load i32
  //
  //      store i32 %even
  //      %odd = add i32                 // Def of %odd
  //      store i32 %odd                 // Insert Position
  InstTy *InsertPos;
};
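
// Illustrative sketch (editorial addition): building a small interleave group
// by hand. `LoadA` and `LoadB` are assumed, adjacent strided load instructions;
// the instruction passed to the constructor becomes the group leader.
//
//   InterleaveGroup<Instruction> Group(LoadA, /*Stride=*/2, Align(4));
//   if (Group.insertMember(LoadB, /*Index=*/1, Align(4))) {
//     Instruction *Leader = Group.getMember(0); // LoadA
//     uint32_t IdxB = Group.getIndex(LoadB);    // 1
//     (void)Leader;
//     (void)IdxB;
//   }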

/// Drive the analysis of interleaved memory accesses in the loop.
///
/// Use this class to analyze interleaved accesses only when we can vectorize
/// a loop. Otherwise it's meaningless to do analysis as the vectorization
/// on interleaved accesses is unsafe.
///
/// The analysis collects interleave groups and records the relationships
/// between the member and the group in a map.
class InterleavedAccessInfo {
public:
  InterleavedAccessInfo(PredicatedScalarEvolution &PSE, Loop *L,
                        DominatorTree *DT, LoopInfo *LI,
                        const LoopAccessInfo *LAI)
      : PSE(PSE), TheLoop(L), DT(DT), LI(LI), LAI(LAI) {}

  ~InterleavedAccessInfo() { invalidateGroups(); }

  /// Analyze the interleaved accesses and collect them in interleave
  /// groups. Substitute symbolic strides using \p Strides.
  /// Consider also predicated loads/stores in the analysis if
  /// \p EnableMaskedInterleavedGroup is true.
  void analyzeInterleaving(bool EnableMaskedInterleavedGroup);

  /// Invalidate groups, e.g., in case all blocks in the loop will be
  /// predicated, contrary to the original assumption. Although we currently
  /// prevent group formation for predicated accesses, we may be able to relax
  /// this limitation in the future once we handle more complicated blocks.
  /// Returns true if any groups were invalidated.
  bool invalidateGroups() {
    if (InterleaveGroups.empty()) {
      assert(
          !RequiresScalarEpilogue &&
          "RequiresScalarEpilog should not be set without interleave groups");
      return false;
    }

    InterleaveGroupMap.clear();
    for (auto *Ptr : InterleaveGroups)
      delete Ptr;
    InterleaveGroups.clear();
    RequiresScalarEpilogue = false;
    return true;
  }

  /// Check if \p Instr belongs to any interleave group.
  bool isInterleaved(Instruction *Instr) const {
    return InterleaveGroupMap.contains(Instr);
  }

  /// Get the interleave group that \p Instr belongs to.
  ///
  /// \returns nullptr if \p Instr does not belong to any group.
  InterleaveGroup<Instruction> *
  getInterleaveGroup(const Instruction *Instr) const {
    return InterleaveGroupMap.lookup(Instr);
  }

  iterator_range<SmallPtrSetIterator<llvm::InterleaveGroup<Instruction> *>>
  getInterleaveGroups() {
    return make_range(InterleaveGroups.begin(), InterleaveGroups.end());
  }

  /// Returns true if an interleaved group that may access memory
  /// out-of-bounds requires a scalar epilogue iteration for correctness.
  bool requiresScalarEpilogue() const { return RequiresScalarEpilogue; }

  /// Invalidate groups that require a scalar epilogue (due to gaps). This can
  /// happen when optimizing for size forbids a scalar epilogue, and the gap
  /// cannot be filtered by masking the load/store.
  void invalidateGroupsRequiringScalarEpilogue();

  /// Returns true if we have any interleave groups.
  bool hasGroups() const { return !InterleaveGroups.empty(); }

private:
  /// A wrapper around ScalarEvolution, used to add runtime SCEV checks.
  /// Simplifies SCEV expressions in the context of existing SCEV assumptions.
  /// The interleaved access analysis can also add new predicates (for example
  /// by versioning strides of pointers).
  PredicatedScalarEvolution &PSE;

  Loop *TheLoop;
  DominatorTree *DT;
  LoopInfo *LI;
  const LoopAccessInfo *LAI;

  /// True if the loop may contain non-reversed interleaved groups with
  /// out-of-bounds accesses. We ensure we don't speculatively access memory
  /// out-of-bounds by executing at least one scalar epilogue iteration.
  bool RequiresScalarEpilogue = false;

  /// Holds the relationships between the members and the interleave group.
  DenseMap<Instruction *, InterleaveGroup<Instruction> *> InterleaveGroupMap;

  SmallPtrSet<InterleaveGroup<Instruction> *, 4> InterleaveGroups;

  /// Holds dependences among the memory accesses in the loop. It maps a source
  /// access to a set of dependent sink accesses.
  DenseMap<Instruction *, SmallPtrSet<Instruction *, 2>> Dependences;

  /// The descriptor for a strided memory access.
  struct StrideDescriptor {
    StrideDescriptor() = default;
    StrideDescriptor(int64_t Stride, const SCEV *Scev, uint64_t Size,
                     Align Alignment)
        : Stride(Stride), Scev(Scev), Size(Size), Alignment(Alignment) {}

    // The access's stride. It is negative for a reverse access.
    int64_t Stride = 0;

    // The scalar expression of this access.
    const SCEV *Scev = nullptr;

    // The size of the memory object.
    uint64_t Size = 0;

    // The alignment of this access.
    Align Alignment;
  };

  /// A type for holding instructions and their stride descriptors.
  using StrideEntry = std::pair<Instruction *, StrideDescriptor>;

  /// Create a new interleave group with the given instruction \p Instr,
  /// stride \p Stride and alignment \p Alignment.
  ///
  /// \returns the newly created interleave group.
  InterleaveGroup<Instruction> *
  createInterleaveGroup(Instruction *Instr, int Stride, Align Alignment) {
    assert(!InterleaveGroupMap.count(Instr) &&
           "Already in an interleaved access group");
    InterleaveGroupMap[Instr] =
        new InterleaveGroup<Instruction>(Instr, Stride, Alignment);
    InterleaveGroups.insert(InterleaveGroupMap[Instr]);
    return InterleaveGroupMap[Instr];
  }

  /// Release the group and remove all the relationships.
  void releaseGroup(InterleaveGroup<Instruction> *Group) {
    InterleaveGroups.erase(Group);
    releaseGroupWithoutRemovingFromSet(Group);
  }

  /// Do everything necessary to release the group, apart from removing it from
  /// the InterleaveGroups set.
  void releaseGroupWithoutRemovingFromSet(InterleaveGroup<Instruction> *Group) {
    for (unsigned i = 0; i < Group->getFactor(); i++)
      if (Instruction *Member = Group->getMember(i))
        InterleaveGroupMap.erase(Member);

    delete Group;
  }

  /// Collect all the accesses with a constant stride in program order.
  void collectConstStrideAccesses(
      MapVector<Instruction *, StrideDescriptor> &AccessStrideInfo,
      const DenseMap<Value *, const SCEV *> &Strides);

  /// Returns true if \p Stride is allowed in an interleaved group.
  static bool isStrided(int Stride);

  /// Returns true if \p BB is a predicated block.
  bool isPredicated(BasicBlock *BB) const {
    return LoopAccessInfo::blockNeedsPredication(BB, TheLoop, DT);
  }

  /// Returns true if LoopAccessInfo can be used for dependence queries.
  bool areDependencesValid() const {
    return LAI && LAI->getDepChecker().getDependences();
  }

  /// Returns true if memory accesses \p A and \p B can be reordered, if
  /// necessary, when constructing interleaved groups.
  ///
  /// \p A must precede \p B in program order. We return false if reordering is
  /// not necessary or is prevented because \p A and \p B may be dependent.
  bool canReorderMemAccessesForInterleavedGroups(StrideEntry *A,
                                                 StrideEntry *B) const {
    // Code motion for interleaved accesses can potentially hoist strided loads
    // and sink strided stores. The code below checks the legality of the
    // following two conditions:
    //
    // 1. Potentially moving a strided load (B) before any store (A) that
    //    precedes B, or
    //
    // 2. Potentially moving a strided store (A) after any load or store (B)
    //    that A precedes.
    //
    // It's legal to reorder A and B if we know there isn't a dependence from A
    // to B. Note that this determination is conservative since some
    // dependences could potentially be reordered safely.

    // A is potentially the source of a dependence.
    auto *Src = A->first;
    auto SrcDes = A->second;

    // B is potentially the sink of a dependence.
    auto *Sink = B->first;
    auto SinkDes = B->second;

    // Code motion for interleaved accesses can't violate WAR dependences.
    // Thus, reordering is legal if the source isn't a write.
    if (!Src->mayWriteToMemory())
      return true;

    // At least one of the accesses must be strided.
    if (!isStrided(SrcDes.Stride) && !isStrided(SinkDes.Stride))
      return true;

    // If dependence information is not available from LoopAccessInfo,
    // conservatively assume the instructions can't be reordered.
    if (!areDependencesValid())
      return false;

    // If we know there is a dependence from source to sink, assume the
    // instructions can't be reordered. Otherwise, reordering is legal.
    return !Dependences.contains(Src) || !Dependences.lookup(Src).count(Sink);
  }

  /// Collect the dependences from LoopAccessInfo.
  ///
  /// We process the dependences once during the interleaved access analysis to
  /// enable constant-time dependence queries.
  void collectDependences() {
    if (!areDependencesValid())
      return;
    const auto &DepChecker = LAI->getDepChecker();
    auto *Deps = DepChecker.getDependences();
    for (auto Dep : *Deps)
      Dependences[Dep.getSource(DepChecker)].insert(
          Dep.getDestination(DepChecker));
  }
};

} // llvm namespace

#endif