File indexing completed on 2026-05-10 08:48:20
0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013 #ifndef POLLY_SCHEDULETREETRANSFORM_H
0014 #define POLLY_SCHEDULETREETRANSFORM_H
0015
0016 #include "polly/Support/ISLTools.h"
0017 #include "llvm/ADT/ArrayRef.h"
0018 #include "llvm/Support/ErrorHandling.h"
0019 #include "isl/isl-noexceptions.h"
0020 #include <cassert>
0021
0022 namespace polly {
0023 struct BandAttr;
0024
0025
0026
0027 template <typename Derived, typename RetTy = void, typename... Args>
0028 struct ScheduleTreeVisitor {
0029 Derived &getDerived() { return *static_cast<Derived *>(this); }
0030 const Derived &getDerived() const {
0031 return *static_cast<const Derived *>(this);
0032 }
0033
0034 RetTy visit(isl::schedule_node Node, Args... args) {
0035 assert(!Node.is_null());
0036 switch (isl_schedule_node_get_type(Node.get())) {
0037 case isl_schedule_node_domain:
0038 assert(isl_schedule_node_n_children(Node.get()) == 1);
0039 return getDerived().visitDomain(Node.as<isl::schedule_node_domain>(),
0040 std::forward<Args>(args)...);
0041 case isl_schedule_node_band:
0042 assert(isl_schedule_node_n_children(Node.get()) == 1);
0043 return getDerived().visitBand(Node.as<isl::schedule_node_band>(),
0044 std::forward<Args>(args)...);
0045 case isl_schedule_node_sequence:
0046 assert(isl_schedule_node_n_children(Node.get()) >= 2);
0047 return getDerived().visitSequence(Node.as<isl::schedule_node_sequence>(),
0048 std::forward<Args>(args)...);
0049 case isl_schedule_node_set:
0050 assert(isl_schedule_node_n_children(Node.get()) >= 2);
0051 return getDerived().visitSet(Node.as<isl::schedule_node_set>(),
0052 std::forward<Args>(args)...);
0053 case isl_schedule_node_leaf:
0054 assert(isl_schedule_node_n_children(Node.get()) == 0);
0055 return getDerived().visitLeaf(Node.as<isl::schedule_node_leaf>(),
0056 std::forward<Args>(args)...);
0057 case isl_schedule_node_mark:
0058 assert(isl_schedule_node_n_children(Node.get()) == 1);
0059 return getDerived().visitMark(Node.as<isl::schedule_node_mark>(),
0060 std::forward<Args>(args)...);
0061 case isl_schedule_node_extension:
0062 assert(isl_schedule_node_n_children(Node.get()) == 1);
0063 return getDerived().visitExtension(
0064 Node.as<isl::schedule_node_extension>(), std::forward<Args>(args)...);
0065 case isl_schedule_node_filter:
0066 assert(isl_schedule_node_n_children(Node.get()) == 1);
0067 return getDerived().visitFilter(Node.as<isl::schedule_node_filter>(),
0068 std::forward<Args>(args)...);
0069 default:
0070 llvm_unreachable("unimplemented schedule node type");
0071 }
0072 }
0073
0074 RetTy visitDomain(isl::schedule_node_domain Domain, Args... args) {
0075 return getDerived().visitSingleChild(std::move(Domain),
0076 std::forward<Args>(args)...);
0077 }
0078
0079 RetTy visitBand(isl::schedule_node_band Band, Args... args) {
0080 return getDerived().visitSingleChild(std::move(Band),
0081 std::forward<Args>(args)...);
0082 }
0083
0084 RetTy visitSequence(isl::schedule_node_sequence Sequence, Args... args) {
0085 return getDerived().visitMultiChild(std::move(Sequence),
0086 std::forward<Args>(args)...);
0087 }
0088
0089 RetTy visitSet(isl::schedule_node_set Set, Args... args) {
0090 return getDerived().visitMultiChild(std::move(Set),
0091 std::forward<Args>(args)...);
0092 }
0093
0094 RetTy visitLeaf(isl::schedule_node_leaf Leaf, Args... args) {
0095 return getDerived().visitNode(std::move(Leaf), std::forward<Args>(args)...);
0096 }
0097
0098 RetTy visitMark(isl::schedule_node_mark Mark, Args... args) {
0099 return getDerived().visitSingleChild(std::move(Mark),
0100 std::forward<Args>(args)...);
0101 }
0102
0103 RetTy visitExtension(isl::schedule_node_extension Extension, Args... args) {
0104 return getDerived().visitSingleChild(std::move(Extension),
0105 std::forward<Args>(args)...);
0106 }
0107
0108 RetTy visitFilter(isl::schedule_node_filter Filter, Args... args) {
0109 return getDerived().visitSingleChild(std::move(Filter),
0110 std::forward<Args>(args)...);
0111 }
0112
0113 RetTy visitSingleChild(isl::schedule_node Node, Args... args) {
0114 return getDerived().visitNode(std::move(Node), std::forward<Args>(args)...);
0115 }
0116
0117 RetTy visitMultiChild(isl::schedule_node Node, Args... args) {
0118 return getDerived().visitNode(std::move(Node), std::forward<Args>(args)...);
0119 }
0120
0121 RetTy visitNode(isl::schedule_node Node, Args... args) {
0122 llvm_unreachable("Unimplemented other");
0123 }
0124 };
0125
0126
0127 template <typename Derived, typename RetTy = void, typename... Args>
0128 struct RecursiveScheduleTreeVisitor
0129 : ScheduleTreeVisitor<Derived, RetTy, Args...> {
0130 using BaseTy = ScheduleTreeVisitor<Derived, RetTy, Args...>;
0131 BaseTy &getBase() { return *this; }
0132 const BaseTy &getBase() const { return *this; }
0133 Derived &getDerived() { return *static_cast<Derived *>(this); }
0134 const Derived &getDerived() const {
0135 return *static_cast<const Derived *>(this);
0136 }
0137
0138
0139 RetTy visit(isl::schedule Schedule, Args... args) {
0140 return getDerived().visit(Schedule.get_root(), std::forward<Args>(args)...);
0141 }
0142
0143
0144
0145 RetTy visit(isl::schedule_node Node, Args... args) {
0146 return getBase().visit(Node, std::forward<Args>(args)...);
0147 }
0148
0149
0150 RetTy visitNode(isl::schedule_node Node, Args... args) {
0151 for (unsigned i : rangeIslSize(0, Node.n_children()))
0152 getDerived().visit(Node.child(i), std::forward<Args>(args)...);
0153 return RetTy();
0154 }
0155 };
0156
0157
0158
0159
0160
0161
0162 template <typename Derived, typename... Args>
0163 struct ScheduleNodeRewriter
0164 : public RecursiveScheduleTreeVisitor<Derived, isl::schedule_node,
0165 Args...> {
0166 Derived &getDerived() { return *static_cast<Derived *>(this); }
0167 const Derived &getDerived() const {
0168 return *static_cast<const Derived *>(this);
0169 }
0170
0171 isl::schedule_node visitNode(isl::schedule_node Node, Args... args) {
0172 return getDerived().visitChildren(Node);
0173 }
0174
0175 isl::schedule_node visitChildren(isl::schedule_node Node, Args... args) {
0176 if (!Node.has_children())
0177 return Node;
0178
0179 isl::schedule_node It = Node.first_child();
0180 while (true) {
0181 It = getDerived().visit(It, std::forward<Args>(args)...);
0182 if (!It.has_next_sibling())
0183 break;
0184 It = It.next_sibling();
0185 }
0186 return It.parent();
0187 }
0188 };
0189
0190
0191 bool isBandMark(const isl::schedule_node &Node);
0192
0193
0194
0195
0196 BandAttr *getBandAttr(isl::schedule_node MarkOrBand);
0197
0198
0199
0200
0201
0202 isl::schedule hoistExtensionNodes(isl::schedule Sched);
0203
0204
0205
0206
0207
0208
0209 isl::schedule applyFullUnroll(isl::schedule_node BandToUnroll);
0210
0211
0212 isl::schedule applyPartialUnroll(isl::schedule_node BandToUnroll, int Factor);
0213
0214
0215 isl::schedule applyMaxFission(isl::schedule_node BandToFission);
0216
0217
0218
0219
0220
0221
0222
0223
0224
0225
0226
0227
0228
0229
0230
0231
0232
0233
0234
0235 isl::set getPartialTilePrefixes(isl::set ScheduleRange, int VectorWidth);
0236
0237
0238
0239
0240
0241
0242
0243
0244 isl::union_set getIsolateOptions(isl::set IsolateDomain, unsigned OutDimsNum);
0245
0246
0247
0248
0249
0250
0251 isl::union_set getDimOptions(isl::ctx Ctx, const char *Option);
0252
0253
0254
0255
0256
0257
0258
0259
0260
0261
0262
0263 isl::schedule_node tileNode(isl::schedule_node Node, const char *Identifier,
0264 llvm::ArrayRef<int> TileSizes, int DefaultTileSize);
0265
0266
0267
0268
0269
0270
0271
0272 isl::schedule_node applyRegisterTiling(isl::schedule_node Node,
0273 llvm::ArrayRef<int> TileSizes,
0274 int DefaultTileSize);
0275
0276
0277
0278
0279
0280
0281 isl::schedule applyGreedyFusion(isl::schedule Sched,
0282 const isl::union_map &Deps);
0283
0284 }
0285
0286 #endif