Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2026-05-10 08:44:45

0001 //===- VecUtils.h -----------------------------------------------*- C++ -*-===//
0002 //
0003 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
0004 // See https://llvm.org/LICENSE.txt for license information.
0005 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
0006 //
0007 //===----------------------------------------------------------------------===//
0008 //
0009 // Collector for SandboxVectorizer related convenience functions that don't
0010 // belong in other classes.
0011 
0012 #ifndef LLVM_TRANSFORMS_VECTORIZE_SANDBOXVECTORIZER_VECUTILS_H
0013 #define LLVM_TRANSFORMS_VECTORIZE_SANDBOXVECTORIZER_VECUTILS_H
0014 
0015 #include "llvm/Analysis/ScalarEvolution.h"
0016 #include "llvm/IR/DataLayout.h"
0017 #include "llvm/SandboxIR/Type.h"
0018 #include "llvm/SandboxIR/Utils.h"
0019 
0020 namespace llvm::sandboxir {
0021 
0022 class VecUtils {
0023 public:
0024   /// \Returns the number of elements in \p Ty. That is the number of lanes if a
0025   /// fixed vector or 1 if scalar. ScalableVectors have unknown size and
0026   /// therefore are unsupported.
0027   static int getNumElements(Type *Ty) {
0028     assert(!isa<ScalableVectorType>(Ty));
0029     return Ty->isVectorTy() ? cast<FixedVectorType>(Ty)->getNumElements() : 1;
0030   }
0031   /// Returns \p Ty if scalar or its element type if vector.
0032   static Type *getElementType(Type *Ty) {
0033     return Ty->isVectorTy() ? cast<FixedVectorType>(Ty)->getElementType() : Ty;
0034   }
0035 
0036   /// \Returns true if \p I1 and \p I2 are load/stores accessing consecutive
0037   /// memory addresses.
0038   template <typename LoadOrStoreT>
0039   static bool areConsecutive(LoadOrStoreT *I1, LoadOrStoreT *I2,
0040                              ScalarEvolution &SE, const DataLayout &DL) {
0041     static_assert(std::is_same<LoadOrStoreT, LoadInst>::value ||
0042                       std::is_same<LoadOrStoreT, StoreInst>::value,
0043                   "Expected Load or Store!");
0044     auto Diff = Utils::getPointerDiffInBytes(I1, I2, SE);
0045     if (!Diff)
0046       return false;
0047     int ElmBytes = Utils::getNumBits(I1) / 8;
0048     return *Diff == ElmBytes;
0049   }
0050 
0051   template <typename LoadOrStoreT>
0052   static bool areConsecutive(ArrayRef<Value *> &Bndl, ScalarEvolution &SE,
0053                              const DataLayout &DL) {
0054     static_assert(std::is_same<LoadOrStoreT, LoadInst>::value ||
0055                       std::is_same<LoadOrStoreT, StoreInst>::value,
0056                   "Expected Load or Store!");
0057     assert(isa<LoadOrStoreT>(Bndl[0]) && "Expected Load or Store!");
0058     auto *LastLS = cast<LoadOrStoreT>(Bndl[0]);
0059     for (Value *V : drop_begin(Bndl)) {
0060       assert(isa<LoadOrStoreT>(V) &&
0061              "Unimplemented: we only support StoreInst!");
0062       auto *LS = cast<LoadOrStoreT>(V);
0063       if (!VecUtils::areConsecutive(LastLS, LS, SE, DL))
0064         return false;
0065       LastLS = LS;
0066     }
0067     return true;
0068   }
0069 
0070   /// \Returns the number of vector lanes of \p Ty or 1 if not a vector.
0071   /// NOTE: It asserts that \p Ty is a fixed vector type.
0072   static unsigned getNumLanes(Type *Ty) {
0073     assert(!isa<ScalableVectorType>(Ty) && "Expect scalar or fixed vector");
0074     if (auto *FixedVecTy = dyn_cast<FixedVectorType>(Ty))
0075       return FixedVecTy->getNumElements();
0076     return 1u;
0077   }
0078 
0079   /// \Returns the expected vector lanes of \p V or 1 if not a vector.
0080   /// NOTE: It asserts that \p V is a fixed vector.
0081   static unsigned getNumLanes(Value *V) {
0082     return VecUtils::getNumLanes(Utils::getExpectedType(V));
0083   }
0084 
0085   /// \Returns the total number of lanes across all values in \p Bndl.
0086   static unsigned getNumLanes(ArrayRef<Value *> Bndl) {
0087     unsigned Lanes = 0;
0088     for (Value *V : Bndl)
0089       Lanes += getNumLanes(V);
0090     return Lanes;
0091   }
0092 
0093   /// \Returns <NumElts x ElemTy>.
0094   /// It works for both scalar and vector \p ElemTy.
0095   static Type *getWideType(Type *ElemTy, unsigned NumElts) {
0096     if (ElemTy->isVectorTy()) {
0097       auto *VecTy = cast<FixedVectorType>(ElemTy);
0098       ElemTy = VecTy->getElementType();
0099       NumElts = VecTy->getNumElements() * NumElts;
0100     }
0101     return FixedVectorType::get(ElemTy, NumElts);
0102   }
0103   /// \Returns the instruction in \p Instrs that is lowest in the BB. Expects
0104   /// that all instructions are in the same BB.
0105   static Instruction *getLowest(ArrayRef<Instruction *> Instrs) {
0106     Instruction *LowestI = Instrs.front();
0107     for (auto *I : drop_begin(Instrs)) {
0108       if (LowestI->comesBefore(I))
0109         LowestI = I;
0110     }
0111     return LowestI;
0112   }
0113   /// \Returns the lowest instruction in \p Vals, or nullptr if no instructions
0114   /// are found. Skips instructions not in \p BB.
0115   static Instruction *getLowest(ArrayRef<Value *> Vals, BasicBlock *BB) {
0116     // Find the first Instruction in Vals that is also in `BB`.
0117     auto It = find_if(Vals, [BB](Value *V) {
0118       return isa<Instruction>(V) && cast<Instruction>(V)->getParent() == BB;
0119     });
0120     // If we couldn't find an instruction return nullptr.
0121     if (It == Vals.end())
0122       return nullptr;
0123     Instruction *FirstI = cast<Instruction>(*It);
0124     // Now look for the lowest instruction in Vals starting from one position
0125     // after FirstI.
0126     Instruction *LowestI = FirstI;
0127     for (auto *V : make_range(std::next(It), Vals.end())) {
0128       auto *I = dyn_cast<Instruction>(V);
0129       // Skip non-instructions.
0130       if (I == nullptr)
0131         continue;
0132       // Skips instructions not in \p BB.
0133       if (I->getParent() != BB)
0134         continue;
0135       // If `LowestI` comes before `I` then `I` is the new lowest.
0136       if (LowestI->comesBefore(I))
0137         LowestI = I;
0138     }
0139     return LowestI;
0140   }
0141 
0142   /// If \p I is not a PHI it returns it. Else it walks down the instruction
0143   /// chain looking for the last PHI and returns it. \Returns nullptr if \p I is
0144   /// nullptr.
0145   static Instruction *getLastPHIOrSelf(Instruction *I) {
0146     Instruction *LastI = I;
0147     while (I != nullptr && isa<PHINode>(I)) {
0148       LastI = I;
0149       I = I->getNextNode();
0150     }
0151     return LastI;
0152   }
0153 
0154   /// If all values in \p Bndl are of the same scalar type then return it,
0155   /// otherwise return nullptr.
0156   static Type *tryGetCommonScalarType(ArrayRef<Value *> Bndl) {
0157     Value *V0 = Bndl[0];
0158     Type *Ty0 = Utils::getExpectedType(V0);
0159     Type *ScalarTy = VecUtils::getElementType(Ty0);
0160     for (auto *V : drop_begin(Bndl)) {
0161       Type *NTy = Utils::getExpectedType(V);
0162       Type *NScalarTy = VecUtils::getElementType(NTy);
0163       if (NScalarTy != ScalarTy)
0164         return nullptr;
0165     }
0166     return ScalarTy;
0167   }
0168 
0169   /// Similar to tryGetCommonScalarType() but will assert that there is a common
0170   /// type. So this is faster in release builds as it won't iterate through the
0171   /// values.
0172   static Type *getCommonScalarType(ArrayRef<Value *> Bndl) {
0173     Value *V0 = Bndl[0];
0174     Type *Ty0 = Utils::getExpectedType(V0);
0175     Type *ScalarTy = VecUtils::getElementType(Ty0);
0176     assert(tryGetCommonScalarType(Bndl) && "Expected common scalar type!");
0177     return ScalarTy;
0178   }
0179   /// \Returns the first integer power of 2 that is <= Num.
0180   static unsigned getFloorPowerOf2(unsigned Num);
0181 
0182 #ifndef NDEBUG
0183   /// Helper dump function for debugging.
0184   LLVM_DUMP_METHOD static void dump(ArrayRef<Value *> Bndl);
0185   LLVM_DUMP_METHOD static void dump(ArrayRef<Instruction *> Bndl);
0186 #endif // NDEBUG
0187 };
0188 
0189 } // namespace llvm::sandboxir
0190 
0191 #endif // LLVM_TRANSFORMS_VECTORIZE_SANDBOXVECTORIZER_VECUTILS_H