File indexing completed on 2026-05-10 08:43:17
0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013 #ifndef LLVM_ANALYSIS_SCALAREVOLUTIONEXPRESSIONS_H
0014 #define LLVM_ANALYSIS_SCALAREVOLUTIONEXPRESSIONS_H
0015
0016 #include "llvm/ADT/DenseMap.h"
0017 #include "llvm/ADT/SmallPtrSet.h"
0018 #include "llvm/ADT/SmallVector.h"
0019 #include "llvm/Analysis/ScalarEvolution.h"
0020 #include "llvm/IR/Constants.h"
0021 #include "llvm/IR/ValueHandle.h"
0022 #include "llvm/Support/Casting.h"
0023 #include "llvm/Support/ErrorHandling.h"
0024 #include <cassert>
0025 #include <cstddef>
0026
0027 namespace llvm {
0028
0029 class APInt;
0030 class Constant;
0031 class ConstantInt;
0032 class ConstantRange;
0033 class Loop;
0034 class Type;
0035 class Value;
0036
/// Kind discriminator returned by SCEV::getSCEVType(). The classof()
/// predicates, SCEVVisitor and SCEVTraversal in this header all dispatch on
/// it.
///
/// Enumerators are numbered implicitly from zero in declaration order, so the
/// order below is part of the contract and must not be shuffled.
enum SCEVTypes : unsigned short {
  scConstant,           ///< Kind of SCEVConstant.
  scVScale,             ///< Kind of SCEVVScale.
  scTruncate,           ///< Kind of SCEVTruncateExpr.
  scZeroExtend,         ///< Kind of SCEVZeroExtendExpr.
  scSignExtend,         ///< Kind of SCEVSignExtendExpr.
  scAddExpr,            ///< Kind of SCEVAddExpr.
  scMulExpr,            ///< Kind of SCEVMulExpr.
  scUDivExpr,           ///< Kind of SCEVUDivExpr.
  scAddRecExpr,         ///< Kind of SCEVAddRecExpr.
  scUMaxExpr,           ///< Kind of SCEVUMaxExpr.
  scSMaxExpr,           ///< Kind of SCEVSMaxExpr.
  scUMinExpr,           ///< Kind of SCEVUMinExpr.
  scSMinExpr,           ///< Kind of SCEVSMinExpr.
  scSequentialUMinExpr, ///< Kind of SCEVSequentialUMinExpr.
  scPtrToInt,           ///< Kind of SCEVPtrToIntExpr.
  scUnknown,            ///< Kind of SCEVUnknown.
  scCouldNotCompute     ///< Kind dispatched to visitCouldNotCompute().
};
0058
0059
0060 class SCEVConstant : public SCEV {
0061 friend class ScalarEvolution;
0062
0063 ConstantInt *V;
0064
0065 SCEVConstant(const FoldingSetNodeIDRef ID, ConstantInt *v)
0066 : SCEV(ID, scConstant, 1), V(v) {}
0067
0068 public:
0069 ConstantInt *getValue() const { return V; }
0070 const APInt &getAPInt() const { return getValue()->getValue(); }
0071
0072 Type *getType() const { return V->getType(); }
0073
0074
0075 static bool classof(const SCEV *S) { return S->getSCEVType() == scConstant; }
0076 };
0077
0078
0079
0080 class SCEVVScale : public SCEV {
0081 friend class ScalarEvolution;
0082
0083 SCEVVScale(const FoldingSetNodeIDRef ID, Type *ty)
0084 : SCEV(ID, scVScale, 0), Ty(ty) {}
0085
0086 Type *Ty;
0087
0088 public:
0089 Type *getType() const { return Ty; }
0090
0091
0092 static bool classof(const SCEV *S) { return S->getSCEVType() == scVScale; }
0093 };
0094
0095 inline unsigned short computeExpressionSize(ArrayRef<const SCEV *> Args) {
0096 APInt Size(16, 1);
0097 for (const auto *Arg : Args)
0098 Size = Size.uadd_sat(APInt(16, Arg->getExpressionSize()));
0099 return (unsigned short)Size.getZExtValue();
0100 }
0101
0102
0103 class SCEVCastExpr : public SCEV {
0104 protected:
0105 const SCEV *Op;
0106 Type *Ty;
0107
0108 SCEVCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy, const SCEV *op,
0109 Type *ty);
0110
0111 public:
0112 const SCEV *getOperand() const { return Op; }
0113 const SCEV *getOperand(unsigned i) const {
0114 assert(i == 0 && "Operand index out of range!");
0115 return Op;
0116 }
0117 ArrayRef<const SCEV *> operands() const { return Op; }
0118 size_t getNumOperands() const { return 1; }
0119 Type *getType() const { return Ty; }
0120
0121
0122 static bool classof(const SCEV *S) {
0123 return S->getSCEVType() == scPtrToInt || S->getSCEVType() == scTruncate ||
0124 S->getSCEVType() == scZeroExtend || S->getSCEVType() == scSignExtend;
0125 }
0126 };
0127
0128
0129
0130 class SCEVPtrToIntExpr : public SCEVCastExpr {
0131 friend class ScalarEvolution;
0132
0133 SCEVPtrToIntExpr(const FoldingSetNodeIDRef ID, const SCEV *Op, Type *ITy);
0134
0135 public:
0136
0137 static bool classof(const SCEV *S) { return S->getSCEVType() == scPtrToInt; }
0138 };
0139
0140
0141 class SCEVIntegralCastExpr : public SCEVCastExpr {
0142 protected:
0143 SCEVIntegralCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy,
0144 const SCEV *op, Type *ty);
0145
0146 public:
0147
0148 static bool classof(const SCEV *S) {
0149 return S->getSCEVType() == scTruncate || S->getSCEVType() == scZeroExtend ||
0150 S->getSCEVType() == scSignExtend;
0151 }
0152 };
0153
0154
0155
0156 class SCEVTruncateExpr : public SCEVIntegralCastExpr {
0157 friend class ScalarEvolution;
0158
0159 SCEVTruncateExpr(const FoldingSetNodeIDRef ID, const SCEV *op, Type *ty);
0160
0161 public:
0162
0163 static bool classof(const SCEV *S) { return S->getSCEVType() == scTruncate; }
0164 };
0165
0166
0167
0168 class SCEVZeroExtendExpr : public SCEVIntegralCastExpr {
0169 friend class ScalarEvolution;
0170
0171 SCEVZeroExtendExpr(const FoldingSetNodeIDRef ID, const SCEV *op, Type *ty);
0172
0173 public:
0174
0175 static bool classof(const SCEV *S) {
0176 return S->getSCEVType() == scZeroExtend;
0177 }
0178 };
0179
0180
0181
0182 class SCEVSignExtendExpr : public SCEVIntegralCastExpr {
0183 friend class ScalarEvolution;
0184
0185 SCEVSignExtendExpr(const FoldingSetNodeIDRef ID, const SCEV *op, Type *ty);
0186
0187 public:
0188
0189 static bool classof(const SCEV *S) {
0190 return S->getSCEVType() == scSignExtend;
0191 }
0192 };
0193
0194
0195
0196 class SCEVNAryExpr : public SCEV {
0197 protected:
0198
0199
0200
0201
0202 const SCEV *const *Operands;
0203 size_t NumOperands;
0204
0205 SCEVNAryExpr(const FoldingSetNodeIDRef ID, enum SCEVTypes T,
0206 const SCEV *const *O, size_t N)
0207 : SCEV(ID, T, computeExpressionSize(ArrayRef(O, N))), Operands(O),
0208 NumOperands(N) {}
0209
0210 public:
0211 size_t getNumOperands() const { return NumOperands; }
0212
0213 const SCEV *getOperand(unsigned i) const {
0214 assert(i < NumOperands && "Operand index out of range!");
0215 return Operands[i];
0216 }
0217
0218 ArrayRef<const SCEV *> operands() const {
0219 return ArrayRef(Operands, NumOperands);
0220 }
0221
0222 NoWrapFlags getNoWrapFlags(NoWrapFlags Mask = NoWrapMask) const {
0223 return (NoWrapFlags)(SubclassData & Mask);
0224 }
0225
0226 bool hasNoUnsignedWrap() const {
0227 return getNoWrapFlags(FlagNUW) != FlagAnyWrap;
0228 }
0229
0230 bool hasNoSignedWrap() const {
0231 return getNoWrapFlags(FlagNSW) != FlagAnyWrap;
0232 }
0233
0234 bool hasNoSelfWrap() const { return getNoWrapFlags(FlagNW) != FlagAnyWrap; }
0235
0236
0237 static bool classof(const SCEV *S) {
0238 return S->getSCEVType() == scAddExpr || S->getSCEVType() == scMulExpr ||
0239 S->getSCEVType() == scSMaxExpr || S->getSCEVType() == scUMaxExpr ||
0240 S->getSCEVType() == scSMinExpr || S->getSCEVType() == scUMinExpr ||
0241 S->getSCEVType() == scSequentialUMinExpr ||
0242 S->getSCEVType() == scAddRecExpr;
0243 }
0244 };
0245
0246
0247 class SCEVCommutativeExpr : public SCEVNAryExpr {
0248 protected:
0249 SCEVCommutativeExpr(const FoldingSetNodeIDRef ID, enum SCEVTypes T,
0250 const SCEV *const *O, size_t N)
0251 : SCEVNAryExpr(ID, T, O, N) {}
0252
0253 public:
0254
0255 static bool classof(const SCEV *S) {
0256 return S->getSCEVType() == scAddExpr || S->getSCEVType() == scMulExpr ||
0257 S->getSCEVType() == scSMaxExpr || S->getSCEVType() == scUMaxExpr ||
0258 S->getSCEVType() == scSMinExpr || S->getSCEVType() == scUMinExpr;
0259 }
0260
0261
0262 void setNoWrapFlags(NoWrapFlags Flags) { SubclassData |= Flags; }
0263 };
0264
0265
0266 class SCEVAddExpr : public SCEVCommutativeExpr {
0267 friend class ScalarEvolution;
0268
0269 Type *Ty;
0270
0271 SCEVAddExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N)
0272 : SCEVCommutativeExpr(ID, scAddExpr, O, N) {
0273 auto *FirstPointerTypedOp = find_if(operands(), [](const SCEV *Op) {
0274 return Op->getType()->isPointerTy();
0275 });
0276 if (FirstPointerTypedOp != operands().end())
0277 Ty = (*FirstPointerTypedOp)->getType();
0278 else
0279 Ty = getOperand(0)->getType();
0280 }
0281
0282 public:
0283 Type *getType() const { return Ty; }
0284
0285
0286 static bool classof(const SCEV *S) { return S->getSCEVType() == scAddExpr; }
0287 };
0288
0289
0290 class SCEVMulExpr : public SCEVCommutativeExpr {
0291 friend class ScalarEvolution;
0292
0293 SCEVMulExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N)
0294 : SCEVCommutativeExpr(ID, scMulExpr, O, N) {}
0295
0296 public:
0297 Type *getType() const { return getOperand(0)->getType(); }
0298
0299
0300 static bool classof(const SCEV *S) { return S->getSCEVType() == scMulExpr; }
0301 };
0302
0303
0304 class SCEVUDivExpr : public SCEV {
0305 friend class ScalarEvolution;
0306
0307 std::array<const SCEV *, 2> Operands;
0308
0309 SCEVUDivExpr(const FoldingSetNodeIDRef ID, const SCEV *lhs, const SCEV *rhs)
0310 : SCEV(ID, scUDivExpr, computeExpressionSize({lhs, rhs})) {
0311 Operands[0] = lhs;
0312 Operands[1] = rhs;
0313 }
0314
0315 public:
0316 const SCEV *getLHS() const { return Operands[0]; }
0317 const SCEV *getRHS() const { return Operands[1]; }
0318 size_t getNumOperands() const { return 2; }
0319 const SCEV *getOperand(unsigned i) const {
0320 assert((i == 0 || i == 1) && "Operand index out of range!");
0321 return i == 0 ? getLHS() : getRHS();
0322 }
0323
0324 ArrayRef<const SCEV *> operands() const { return Operands; }
0325
0326 Type *getType() const {
0327
0328
0329
0330
0331
0332 return getRHS()->getType();
0333 }
0334
0335
0336 static bool classof(const SCEV *S) { return S->getSCEVType() == scUDivExpr; }
0337 };
0338
0339
0340
0341
0342
0343
0344
0345
0346
0347 class SCEVAddRecExpr : public SCEVNAryExpr {
0348 friend class ScalarEvolution;
0349
0350 const Loop *L;
0351
0352 SCEVAddRecExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N,
0353 const Loop *l)
0354 : SCEVNAryExpr(ID, scAddRecExpr, O, N), L(l) {}
0355
0356 public:
0357 Type *getType() const { return getStart()->getType(); }
0358 const SCEV *getStart() const { return Operands[0]; }
0359 const Loop *getLoop() const { return L; }
0360
0361
0362
0363
0364
0365 const SCEV *getStepRecurrence(ScalarEvolution &SE) const {
0366 if (isAffine())
0367 return getOperand(1);
0368 return SE.getAddRecExpr(
0369 SmallVector<const SCEV *, 3>(operands().drop_front()), getLoop(),
0370 FlagAnyWrap);
0371 }
0372
0373
0374
0375 bool isAffine() const {
0376
0377
0378 return getNumOperands() == 2;
0379 }
0380
0381
0382
0383
0384 bool isQuadratic() const { return getNumOperands() == 3; }
0385
0386
0387
0388
0389 void setNoWrapFlags(NoWrapFlags Flags) {
0390 if (Flags & (FlagNUW | FlagNSW))
0391 Flags = ScalarEvolution::setFlags(Flags, FlagNW);
0392 SubclassData |= Flags;
0393 }
0394
0395
0396
0397 const SCEV *evaluateAtIteration(const SCEV *It, ScalarEvolution &SE) const;
0398
0399
0400
0401 static const SCEV *evaluateAtIteration(ArrayRef<const SCEV *> Operands,
0402 const SCEV *It, ScalarEvolution &SE);
0403
0404
0405
0406
0407
0408
0409
0410 const SCEV *getNumIterationsInRange(const ConstantRange &Range,
0411 ScalarEvolution &SE) const;
0412
0413
0414
0415 const SCEVAddRecExpr *getPostIncExpr(ScalarEvolution &SE) const;
0416
0417
0418 static bool classof(const SCEV *S) {
0419 return S->getSCEVType() == scAddRecExpr;
0420 }
0421 };
0422
0423
0424 class SCEVMinMaxExpr : public SCEVCommutativeExpr {
0425 friend class ScalarEvolution;
0426
0427 static bool isMinMaxType(enum SCEVTypes T) {
0428 return T == scSMaxExpr || T == scUMaxExpr || T == scSMinExpr ||
0429 T == scUMinExpr;
0430 }
0431
0432 protected:
0433
0434 SCEVMinMaxExpr(const FoldingSetNodeIDRef ID, enum SCEVTypes T,
0435 const SCEV *const *O, size_t N)
0436 : SCEVCommutativeExpr(ID, T, O, N) {
0437 assert(isMinMaxType(T));
0438
0439 setNoWrapFlags((NoWrapFlags)(FlagNUW | FlagNSW));
0440 }
0441
0442 public:
0443 Type *getType() const { return getOperand(0)->getType(); }
0444
0445 static bool classof(const SCEV *S) { return isMinMaxType(S->getSCEVType()); }
0446
0447 static enum SCEVTypes negate(enum SCEVTypes T) {
0448 switch (T) {
0449 case scSMaxExpr:
0450 return scSMinExpr;
0451 case scSMinExpr:
0452 return scSMaxExpr;
0453 case scUMaxExpr:
0454 return scUMinExpr;
0455 case scUMinExpr:
0456 return scUMaxExpr;
0457 default:
0458 llvm_unreachable("Not a min or max SCEV type!");
0459 }
0460 }
0461 };
0462
0463
0464 class SCEVSMaxExpr : public SCEVMinMaxExpr {
0465 friend class ScalarEvolution;
0466
0467 SCEVSMaxExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N)
0468 : SCEVMinMaxExpr(ID, scSMaxExpr, O, N) {}
0469
0470 public:
0471
0472 static bool classof(const SCEV *S) { return S->getSCEVType() == scSMaxExpr; }
0473 };
0474
0475
0476 class SCEVUMaxExpr : public SCEVMinMaxExpr {
0477 friend class ScalarEvolution;
0478
0479 SCEVUMaxExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N)
0480 : SCEVMinMaxExpr(ID, scUMaxExpr, O, N) {}
0481
0482 public:
0483
0484 static bool classof(const SCEV *S) { return S->getSCEVType() == scUMaxExpr; }
0485 };
0486
0487
0488 class SCEVSMinExpr : public SCEVMinMaxExpr {
0489 friend class ScalarEvolution;
0490
0491 SCEVSMinExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N)
0492 : SCEVMinMaxExpr(ID, scSMinExpr, O, N) {}
0493
0494 public:
0495
0496 static bool classof(const SCEV *S) { return S->getSCEVType() == scSMinExpr; }
0497 };
0498
0499
0500 class SCEVUMinExpr : public SCEVMinMaxExpr {
0501 friend class ScalarEvolution;
0502
0503 SCEVUMinExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N)
0504 : SCEVMinMaxExpr(ID, scUMinExpr, O, N) {}
0505
0506 public:
0507
0508 static bool classof(const SCEV *S) { return S->getSCEVType() == scUMinExpr; }
0509 };
0510
0511
0512
0513
0514
0515
0516
0517 class SCEVSequentialMinMaxExpr : public SCEVNAryExpr {
0518 friend class ScalarEvolution;
0519
0520 static bool isSequentialMinMaxType(enum SCEVTypes T) {
0521 return T == scSequentialUMinExpr;
0522 }
0523
0524
0525 void setNoWrapFlags(NoWrapFlags Flags) { SubclassData |= Flags; }
0526
0527 protected:
0528
0529 SCEVSequentialMinMaxExpr(const FoldingSetNodeIDRef ID, enum SCEVTypes T,
0530 const SCEV *const *O, size_t N)
0531 : SCEVNAryExpr(ID, T, O, N) {
0532 assert(isSequentialMinMaxType(T));
0533
0534 setNoWrapFlags((NoWrapFlags)(FlagNUW | FlagNSW));
0535 }
0536
0537 public:
0538 Type *getType() const { return getOperand(0)->getType(); }
0539
0540 static SCEVTypes getEquivalentNonSequentialSCEVType(SCEVTypes Ty) {
0541 assert(isSequentialMinMaxType(Ty));
0542 switch (Ty) {
0543 case scSequentialUMinExpr:
0544 return scUMinExpr;
0545 default:
0546 llvm_unreachable("Not a sequential min/max type.");
0547 }
0548 }
0549
0550 SCEVTypes getEquivalentNonSequentialSCEVType() const {
0551 return getEquivalentNonSequentialSCEVType(getSCEVType());
0552 }
0553
0554 static bool classof(const SCEV *S) {
0555 return isSequentialMinMaxType(S->getSCEVType());
0556 }
0557 };
0558
0559
0560 class SCEVSequentialUMinExpr : public SCEVSequentialMinMaxExpr {
0561 friend class ScalarEvolution;
0562
0563 SCEVSequentialUMinExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O,
0564 size_t N)
0565 : SCEVSequentialMinMaxExpr(ID, scSequentialUMinExpr, O, N) {}
0566
0567 public:
0568
0569 static bool classof(const SCEV *S) {
0570 return S->getSCEVType() == scSequentialUMinExpr;
0571 }
0572 };
0573
0574
0575
0576
0577 class SCEVUnknown final : public SCEV, private CallbackVH {
0578 friend class ScalarEvolution;
0579
0580
0581
0582
0583 ScalarEvolution *SE;
0584
0585
0586
0587 SCEVUnknown *Next;
0588
0589 SCEVUnknown(const FoldingSetNodeIDRef ID, Value *V, ScalarEvolution *se,
0590 SCEVUnknown *next)
0591 : SCEV(ID, scUnknown, 1), CallbackVH(V), SE(se), Next(next) {}
0592
0593
0594 void deleted() override;
0595 void allUsesReplacedWith(Value *New) override;
0596
0597 public:
0598 Value *getValue() const { return getValPtr(); }
0599
0600 Type *getType() const { return getValPtr()->getType(); }
0601
0602
0603 static bool classof(const SCEV *S) { return S->getSCEVType() == scUnknown; }
0604 };
0605
0606
0607
0608 template <typename SC, typename RetVal = void> struct SCEVVisitor {
0609 RetVal visit(const SCEV *S) {
0610 switch (S->getSCEVType()) {
0611 case scConstant:
0612 return ((SC *)this)->visitConstant((const SCEVConstant *)S);
0613 case scVScale:
0614 return ((SC *)this)->visitVScale((const SCEVVScale *)S);
0615 case scPtrToInt:
0616 return ((SC *)this)->visitPtrToIntExpr((const SCEVPtrToIntExpr *)S);
0617 case scTruncate:
0618 return ((SC *)this)->visitTruncateExpr((const SCEVTruncateExpr *)S);
0619 case scZeroExtend:
0620 return ((SC *)this)->visitZeroExtendExpr((const SCEVZeroExtendExpr *)S);
0621 case scSignExtend:
0622 return ((SC *)this)->visitSignExtendExpr((const SCEVSignExtendExpr *)S);
0623 case scAddExpr:
0624 return ((SC *)this)->visitAddExpr((const SCEVAddExpr *)S);
0625 case scMulExpr:
0626 return ((SC *)this)->visitMulExpr((const SCEVMulExpr *)S);
0627 case scUDivExpr:
0628 return ((SC *)this)->visitUDivExpr((const SCEVUDivExpr *)S);
0629 case scAddRecExpr:
0630 return ((SC *)this)->visitAddRecExpr((const SCEVAddRecExpr *)S);
0631 case scSMaxExpr:
0632 return ((SC *)this)->visitSMaxExpr((const SCEVSMaxExpr *)S);
0633 case scUMaxExpr:
0634 return ((SC *)this)->visitUMaxExpr((const SCEVUMaxExpr *)S);
0635 case scSMinExpr:
0636 return ((SC *)this)->visitSMinExpr((const SCEVSMinExpr *)S);
0637 case scUMinExpr:
0638 return ((SC *)this)->visitUMinExpr((const SCEVUMinExpr *)S);
0639 case scSequentialUMinExpr:
0640 return ((SC *)this)
0641 ->visitSequentialUMinExpr((const SCEVSequentialUMinExpr *)S);
0642 case scUnknown:
0643 return ((SC *)this)->visitUnknown((const SCEVUnknown *)S);
0644 case scCouldNotCompute:
0645 return ((SC *)this)->visitCouldNotCompute((const SCEVCouldNotCompute *)S);
0646 }
0647 llvm_unreachable("Unknown SCEV kind!");
0648 }
0649
0650 RetVal visitCouldNotCompute(const SCEVCouldNotCompute *S) {
0651 llvm_unreachable("Invalid use of SCEVCouldNotCompute!");
0652 }
0653 };
0654
0655
0656
0657
0658
0659
0660
0661
0662 template <typename SV> class SCEVTraversal {
0663 SV &Visitor;
0664 SmallVector<const SCEV *, 8> Worklist;
0665 SmallPtrSet<const SCEV *, 8> Visited;
0666
0667 void push(const SCEV *S) {
0668 if (Visited.insert(S).second && Visitor.follow(S))
0669 Worklist.push_back(S);
0670 }
0671
0672 public:
0673 SCEVTraversal(SV &V) : Visitor(V) {}
0674
0675 void visitAll(const SCEV *Root) {
0676 push(Root);
0677 while (!Worklist.empty() && !Visitor.isDone()) {
0678 const SCEV *S = Worklist.pop_back_val();
0679
0680 switch (S->getSCEVType()) {
0681 case scConstant:
0682 case scVScale:
0683 case scUnknown:
0684 continue;
0685 case scPtrToInt:
0686 case scTruncate:
0687 case scZeroExtend:
0688 case scSignExtend:
0689 case scAddExpr:
0690 case scMulExpr:
0691 case scUDivExpr:
0692 case scSMaxExpr:
0693 case scUMaxExpr:
0694 case scSMinExpr:
0695 case scUMinExpr:
0696 case scSequentialUMinExpr:
0697 case scAddRecExpr:
0698 for (const auto *Op : S->operands()) {
0699 push(Op);
0700 if (Visitor.isDone())
0701 break;
0702 }
0703 continue;
0704 case scCouldNotCompute:
0705 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
0706 }
0707 llvm_unreachable("Unknown SCEV kind!");
0708 }
0709 }
0710 };
0711
0712
0713 template <typename SV> void visitAll(const SCEV *Root, SV &Visitor) {
0714 SCEVTraversal<SV> T(Visitor);
0715 T.visitAll(Root);
0716 }
0717
0718
0719 template <typename PredTy>
0720 bool SCEVExprContains(const SCEV *Root, PredTy Pred) {
0721 struct FindClosure {
0722 bool Found = false;
0723 PredTy Pred;
0724
0725 FindClosure(PredTy Pred) : Pred(Pred) {}
0726
0727 bool follow(const SCEV *S) {
0728 if (!Pred(S))
0729 return true;
0730
0731 Found = true;
0732 return false;
0733 }
0734
0735 bool isDone() const { return Found; }
0736 };
0737
0738 FindClosure FC(Pred);
0739 visitAll(Root, FC);
0740 return FC.Found;
0741 }
0742
0743
0744
0745
0746 template <typename SC>
0747 class SCEVRewriteVisitor : public SCEVVisitor<SC, const SCEV *> {
0748 protected:
0749 ScalarEvolution &SE;
0750
0751
0752
0753
0754
0755 SmallDenseMap<const SCEV *, const SCEV *> RewriteResults;
0756
0757 public:
0758 SCEVRewriteVisitor(ScalarEvolution &SE) : SE(SE) {}
0759
0760 const SCEV *visit(const SCEV *S) {
0761 auto It = RewriteResults.find(S);
0762 if (It != RewriteResults.end())
0763 return It->second;
0764 auto *Visited = SCEVVisitor<SC, const SCEV *>::visit(S);
0765 auto Result = RewriteResults.try_emplace(S, Visited);
0766 assert(Result.second && "Should insert a new entry");
0767 return Result.first->second;
0768 }
0769
0770 const SCEV *visitConstant(const SCEVConstant *Constant) { return Constant; }
0771
0772 const SCEV *visitVScale(const SCEVVScale *VScale) { return VScale; }
0773
0774 const SCEV *visitPtrToIntExpr(const SCEVPtrToIntExpr *Expr) {
0775 const SCEV *Operand = ((SC *)this)->visit(Expr->getOperand());
0776 return Operand == Expr->getOperand()
0777 ? Expr
0778 : SE.getPtrToIntExpr(Operand, Expr->getType());
0779 }
0780
0781 const SCEV *visitTruncateExpr(const SCEVTruncateExpr *Expr) {
0782 const SCEV *Operand = ((SC *)this)->visit(Expr->getOperand());
0783 return Operand == Expr->getOperand()
0784 ? Expr
0785 : SE.getTruncateExpr(Operand, Expr->getType());
0786 }
0787
0788 const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
0789 const SCEV *Operand = ((SC *)this)->visit(Expr->getOperand());
0790 return Operand == Expr->getOperand()
0791 ? Expr
0792 : SE.getZeroExtendExpr(Operand, Expr->getType());
0793 }
0794
0795 const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
0796 const SCEV *Operand = ((SC *)this)->visit(Expr->getOperand());
0797 return Operand == Expr->getOperand()
0798 ? Expr
0799 : SE.getSignExtendExpr(Operand, Expr->getType());
0800 }
0801
0802 const SCEV *visitAddExpr(const SCEVAddExpr *Expr) {
0803 SmallVector<const SCEV *, 2> Operands;
0804 bool Changed = false;
0805 for (const auto *Op : Expr->operands()) {
0806 Operands.push_back(((SC *)this)->visit(Op));
0807 Changed |= Op != Operands.back();
0808 }
0809 return !Changed ? Expr : SE.getAddExpr(Operands);
0810 }
0811
0812 const SCEV *visitMulExpr(const SCEVMulExpr *Expr) {
0813 SmallVector<const SCEV *, 2> Operands;
0814 bool Changed = false;
0815 for (const auto *Op : Expr->operands()) {
0816 Operands.push_back(((SC *)this)->visit(Op));
0817 Changed |= Op != Operands.back();
0818 }
0819 return !Changed ? Expr : SE.getMulExpr(Operands);
0820 }
0821
0822 const SCEV *visitUDivExpr(const SCEVUDivExpr *Expr) {
0823 auto *LHS = ((SC *)this)->visit(Expr->getLHS());
0824 auto *RHS = ((SC *)this)->visit(Expr->getRHS());
0825 bool Changed = LHS != Expr->getLHS() || RHS != Expr->getRHS();
0826 return !Changed ? Expr : SE.getUDivExpr(LHS, RHS);
0827 }
0828
0829 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
0830 SmallVector<const SCEV *, 2> Operands;
0831 bool Changed = false;
0832 for (const auto *Op : Expr->operands()) {
0833 Operands.push_back(((SC *)this)->visit(Op));
0834 Changed |= Op != Operands.back();
0835 }
0836 return !Changed ? Expr
0837 : SE.getAddRecExpr(Operands, Expr->getLoop(),
0838 Expr->getNoWrapFlags());
0839 }
0840
0841 const SCEV *visitSMaxExpr(const SCEVSMaxExpr *Expr) {
0842 SmallVector<const SCEV *, 2> Operands;
0843 bool Changed = false;
0844 for (const auto *Op : Expr->operands()) {
0845 Operands.push_back(((SC *)this)->visit(Op));
0846 Changed |= Op != Operands.back();
0847 }
0848 return !Changed ? Expr : SE.getSMaxExpr(Operands);
0849 }
0850
0851 const SCEV *visitUMaxExpr(const SCEVUMaxExpr *Expr) {
0852 SmallVector<const SCEV *, 2> Operands;
0853 bool Changed = false;
0854 for (const auto *Op : Expr->operands()) {
0855 Operands.push_back(((SC *)this)->visit(Op));
0856 Changed |= Op != Operands.back();
0857 }
0858 return !Changed ? Expr : SE.getUMaxExpr(Operands);
0859 }
0860
0861 const SCEV *visitSMinExpr(const SCEVSMinExpr *Expr) {
0862 SmallVector<const SCEV *, 2> Operands;
0863 bool Changed = false;
0864 for (const auto *Op : Expr->operands()) {
0865 Operands.push_back(((SC *)this)->visit(Op));
0866 Changed |= Op != Operands.back();
0867 }
0868 return !Changed ? Expr : SE.getSMinExpr(Operands);
0869 }
0870
0871 const SCEV *visitUMinExpr(const SCEVUMinExpr *Expr) {
0872 SmallVector<const SCEV *, 2> Operands;
0873 bool Changed = false;
0874 for (const auto *Op : Expr->operands()) {
0875 Operands.push_back(((SC *)this)->visit(Op));
0876 Changed |= Op != Operands.back();
0877 }
0878 return !Changed ? Expr : SE.getUMinExpr(Operands);
0879 }
0880
0881 const SCEV *visitSequentialUMinExpr(const SCEVSequentialUMinExpr *Expr) {
0882 SmallVector<const SCEV *, 2> Operands;
0883 bool Changed = false;
0884 for (const auto *Op : Expr->operands()) {
0885 Operands.push_back(((SC *)this)->visit(Op));
0886 Changed |= Op != Operands.back();
0887 }
0888 return !Changed ? Expr : SE.getUMinExpr(Operands, true);
0889 }
0890
0891 const SCEV *visitUnknown(const SCEVUnknown *Expr) { return Expr; }
0892
0893 const SCEV *visitCouldNotCompute(const SCEVCouldNotCompute *Expr) {
0894 return Expr;
0895 }
0896 };
0897
0898 using ValueToValueMap = DenseMap<const Value *, Value *>;
0899 using ValueToSCEVMapTy = DenseMap<const Value *, const SCEV *>;
0900
0901
0902
0903 class SCEVParameterRewriter : public SCEVRewriteVisitor<SCEVParameterRewriter> {
0904 public:
0905 static const SCEV *rewrite(const SCEV *Scev, ScalarEvolution &SE,
0906 ValueToSCEVMapTy &Map) {
0907 SCEVParameterRewriter Rewriter(SE, Map);
0908 return Rewriter.visit(Scev);
0909 }
0910
0911 SCEVParameterRewriter(ScalarEvolution &SE, ValueToSCEVMapTy &M)
0912 : SCEVRewriteVisitor(SE), Map(M) {}
0913
0914 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
0915 auto I = Map.find(Expr->getValue());
0916 if (I == Map.end())
0917 return Expr;
0918 return I->second;
0919 }
0920
0921 private:
0922 ValueToSCEVMapTy ⤅
0923 };
0924
0925 using LoopToScevMapT = DenseMap<const Loop *, const SCEV *>;
0926
0927
0928
0929 class SCEVLoopAddRecRewriter
0930 : public SCEVRewriteVisitor<SCEVLoopAddRecRewriter> {
0931 public:
0932 SCEVLoopAddRecRewriter(ScalarEvolution &SE, LoopToScevMapT &M)
0933 : SCEVRewriteVisitor(SE), Map(M) {}
0934
0935 static const SCEV *rewrite(const SCEV *Scev, LoopToScevMapT &Map,
0936 ScalarEvolution &SE) {
0937 SCEVLoopAddRecRewriter Rewriter(SE, Map);
0938 return Rewriter.visit(Scev);
0939 }
0940
0941 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
0942 SmallVector<const SCEV *, 2> Operands;
0943 for (const SCEV *Op : Expr->operands())
0944 Operands.push_back(visit(Op));
0945
0946 const Loop *L = Expr->getLoop();
0947 if (0 == Map.count(L))
0948 return SE.getAddRecExpr(Operands, L, Expr->getNoWrapFlags());
0949
0950 return SCEVAddRecExpr::evaluateAtIteration(Operands, Map[L], SE);
0951 }
0952
0953 private:
0954 LoopToScevMapT ⤅
0955 };
0956
0957 }
0958
0959 #endif