Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2026-05-10 08:48:20

0001 //===- polly/ScheduleTreeTransform.h ----------------------------*- C++ -*-===//
0002 //
0003 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
0004 // See https://llvm.org/LICENSE.txt for license information.
0005 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
0006 //
0007 //===----------------------------------------------------------------------===//
0008 //
0009 // Make changes to isl's schedule tree data structure.
0010 //
0011 //===----------------------------------------------------------------------===//
0012 
0013 #ifndef POLLY_SCHEDULETREETRANSFORM_H
0014 #define POLLY_SCHEDULETREETRANSFORM_H
0015 
0016 #include "polly/Support/ISLTools.h"
0017 #include "llvm/ADT/ArrayRef.h"
0018 #include "llvm/Support/ErrorHandling.h"
0019 #include "isl/isl-noexceptions.h"
0020 #include <cassert>
0021 
0022 namespace polly {
0023 struct BandAttr;
0024 
0025 /// This class defines a simple visitor class that may be used for
0026 /// various schedule tree analysis purposes.
0027 template <typename Derived, typename RetTy = void, typename... Args>
0028 struct ScheduleTreeVisitor {
0029   Derived &getDerived() { return *static_cast<Derived *>(this); }
0030   const Derived &getDerived() const {
0031     return *static_cast<const Derived *>(this);
0032   }
0033 
0034   RetTy visit(isl::schedule_node Node, Args... args) {
0035     assert(!Node.is_null());
0036     switch (isl_schedule_node_get_type(Node.get())) {
0037     case isl_schedule_node_domain:
0038       assert(isl_schedule_node_n_children(Node.get()) == 1);
0039       return getDerived().visitDomain(Node.as<isl::schedule_node_domain>(),
0040                                       std::forward<Args>(args)...);
0041     case isl_schedule_node_band:
0042       assert(isl_schedule_node_n_children(Node.get()) == 1);
0043       return getDerived().visitBand(Node.as<isl::schedule_node_band>(),
0044                                     std::forward<Args>(args)...);
0045     case isl_schedule_node_sequence:
0046       assert(isl_schedule_node_n_children(Node.get()) >= 2);
0047       return getDerived().visitSequence(Node.as<isl::schedule_node_sequence>(),
0048                                         std::forward<Args>(args)...);
0049     case isl_schedule_node_set:
0050       assert(isl_schedule_node_n_children(Node.get()) >= 2);
0051       return getDerived().visitSet(Node.as<isl::schedule_node_set>(),
0052                                    std::forward<Args>(args)...);
0053     case isl_schedule_node_leaf:
0054       assert(isl_schedule_node_n_children(Node.get()) == 0);
0055       return getDerived().visitLeaf(Node.as<isl::schedule_node_leaf>(),
0056                                     std::forward<Args>(args)...);
0057     case isl_schedule_node_mark:
0058       assert(isl_schedule_node_n_children(Node.get()) == 1);
0059       return getDerived().visitMark(Node.as<isl::schedule_node_mark>(),
0060                                     std::forward<Args>(args)...);
0061     case isl_schedule_node_extension:
0062       assert(isl_schedule_node_n_children(Node.get()) == 1);
0063       return getDerived().visitExtension(
0064           Node.as<isl::schedule_node_extension>(), std::forward<Args>(args)...);
0065     case isl_schedule_node_filter:
0066       assert(isl_schedule_node_n_children(Node.get()) == 1);
0067       return getDerived().visitFilter(Node.as<isl::schedule_node_filter>(),
0068                                       std::forward<Args>(args)...);
0069     default:
0070       llvm_unreachable("unimplemented schedule node type");
0071     }
0072   }
0073 
0074   RetTy visitDomain(isl::schedule_node_domain Domain, Args... args) {
0075     return getDerived().visitSingleChild(std::move(Domain),
0076                                          std::forward<Args>(args)...);
0077   }
0078 
0079   RetTy visitBand(isl::schedule_node_band Band, Args... args) {
0080     return getDerived().visitSingleChild(std::move(Band),
0081                                          std::forward<Args>(args)...);
0082   }
0083 
0084   RetTy visitSequence(isl::schedule_node_sequence Sequence, Args... args) {
0085     return getDerived().visitMultiChild(std::move(Sequence),
0086                                         std::forward<Args>(args)...);
0087   }
0088 
0089   RetTy visitSet(isl::schedule_node_set Set, Args... args) {
0090     return getDerived().visitMultiChild(std::move(Set),
0091                                         std::forward<Args>(args)...);
0092   }
0093 
0094   RetTy visitLeaf(isl::schedule_node_leaf Leaf, Args... args) {
0095     return getDerived().visitNode(std::move(Leaf), std::forward<Args>(args)...);
0096   }
0097 
0098   RetTy visitMark(isl::schedule_node_mark Mark, Args... args) {
0099     return getDerived().visitSingleChild(std::move(Mark),
0100                                          std::forward<Args>(args)...);
0101   }
0102 
0103   RetTy visitExtension(isl::schedule_node_extension Extension, Args... args) {
0104     return getDerived().visitSingleChild(std::move(Extension),
0105                                          std::forward<Args>(args)...);
0106   }
0107 
0108   RetTy visitFilter(isl::schedule_node_filter Filter, Args... args) {
0109     return getDerived().visitSingleChild(std::move(Filter),
0110                                          std::forward<Args>(args)...);
0111   }
0112 
0113   RetTy visitSingleChild(isl::schedule_node Node, Args... args) {
0114     return getDerived().visitNode(std::move(Node), std::forward<Args>(args)...);
0115   }
0116 
0117   RetTy visitMultiChild(isl::schedule_node Node, Args... args) {
0118     return getDerived().visitNode(std::move(Node), std::forward<Args>(args)...);
0119   }
0120 
0121   RetTy visitNode(isl::schedule_node Node, Args... args) {
0122     llvm_unreachable("Unimplemented other");
0123   }
0124 };
0125 
0126 /// Recursively visit all nodes of a schedule tree.
0127 template <typename Derived, typename RetTy = void, typename... Args>
0128 struct RecursiveScheduleTreeVisitor
0129     : ScheduleTreeVisitor<Derived, RetTy, Args...> {
0130   using BaseTy = ScheduleTreeVisitor<Derived, RetTy, Args...>;
0131   BaseTy &getBase() { return *this; }
0132   const BaseTy &getBase() const { return *this; }
0133   Derived &getDerived() { return *static_cast<Derived *>(this); }
0134   const Derived &getDerived() const {
0135     return *static_cast<const Derived *>(this);
0136   }
0137 
0138   /// When visiting an entire schedule tree, start at its root node.
0139   RetTy visit(isl::schedule Schedule, Args... args) {
0140     return getDerived().visit(Schedule.get_root(), std::forward<Args>(args)...);
0141   }
0142 
0143   // Necessary to allow overload resolution with the added visit(isl::schedule)
0144   // overload.
0145   RetTy visit(isl::schedule_node Node, Args... args) {
0146     return getBase().visit(Node, std::forward<Args>(args)...);
0147   }
0148 
0149   /// By default, recursively visit the child nodes.
0150   RetTy visitNode(isl::schedule_node Node, Args... args) {
0151     for (unsigned i : rangeIslSize(0, Node.n_children()))
0152       getDerived().visit(Node.child(i), std::forward<Args>(args)...);
0153     return RetTy();
0154   }
0155 };
0156 
0157 /// Recursively visit all nodes of a schedule tree while allowing changes.
0158 ///
0159 /// The visit methods return an isl::schedule_node that is used to continue
0160 /// visiting the tree. Structural changes such as returning a different node
0161 /// will confuse the visitor.
0162 template <typename Derived, typename... Args>
0163 struct ScheduleNodeRewriter
0164     : public RecursiveScheduleTreeVisitor<Derived, isl::schedule_node,
0165                                           Args...> {
0166   Derived &getDerived() { return *static_cast<Derived *>(this); }
0167   const Derived &getDerived() const {
0168     return *static_cast<const Derived *>(this);
0169   }
0170 
0171   isl::schedule_node visitNode(isl::schedule_node Node, Args... args) {
0172     return getDerived().visitChildren(Node);
0173   }
0174 
0175   isl::schedule_node visitChildren(isl::schedule_node Node, Args... args) {
0176     if (!Node.has_children())
0177       return Node;
0178 
0179     isl::schedule_node It = Node.first_child();
0180     while (true) {
0181       It = getDerived().visit(It, std::forward<Args>(args)...);
0182       if (!It.has_next_sibling())
0183         break;
0184       It = It.next_sibling();
0185     }
0186     return It.parent();
0187   }
0188 };
0189 
/// Is @p Node the mark node that wraps a band, i.e. the band's marker from
/// which getBandAttr() extracts the BandAttr?
bool isBandMark(const isl::schedule_node &Node);
0192 
/// Extract the BandAttr from a band's wrapping marker. The band itself may
/// also be passed, in which case this method tries to find its wrapping mark.
///
/// @return The band's BandAttr, or nullptr if the band has no BandAttr.
BandAttr *getBandAttr(isl::schedule_node MarkOrBand);
0197 
/// Hoist all domains from extension nodes into the root domain node, so that
/// no extension nodes remain (isl does not support them for some operations).
///
/// This assumes that the domains added by extension nodes do not overlap.
isl::schedule hoistExtensionNodes(isl::schedule Sched);
0203 
/// Replace the AST band @p BandToUnroll by a sequence of all its iterations.
///
/// The implementation enumerates all points in the partial schedule and creates
/// an ISL sequence node for each point. The number of iterations must be a
/// constant.
///
/// @param BandToUnroll The band node to unroll completely.
/// @return The transformed schedule. NOTE(review): presumably the whole
///         schedule containing the unrolled band; confirm in the definition.
isl::schedule applyFullUnroll(isl::schedule_node BandToUnroll);
0210 
/// Replace the AST band @p BandToUnroll by a partially unrolled equivalent.
///
/// @param BandToUnroll The band node to unroll.
/// @param Factor       The unroll factor.
/// @return The transformed schedule.
isl::schedule applyPartialUnroll(isl::schedule_node BandToUnroll, int Factor);
0213 
/// Loop-distribute the band @p BandToFission as much as possible.
///
/// @param BandToFission The band node to distribute.
/// @return The transformed schedule.
isl::schedule applyMaxFission(isl::schedule_node BandToFission);
0216 
/// Build the desired set of partial tile prefixes.
///
/// We build a set of partial tile prefixes, which are prefixes of the vector
/// loop that have exactly VectorWidth iterations.
///
/// 1. Drop all constraints involving the dimension that represents the
///    vector loop.
/// 2. Constrain the last dimension to get a set that has exactly VectorWidth
///    iterations.
/// 3. Subtract the loop domain from it, project out the vector loop dimension
///    and get a set that contains the prefixes that do not have exactly
///    VectorWidth iterations.
/// 4. Project out the vector loop dimension of the set that was built in the
///    first step and subtract the set built in the previous step to get the
///    desired set of prefixes.
///
/// @param ScheduleRange A range of a map, which describes a prefix schedule
///                      relation.
/// @param VectorWidth   The number of iterations a prefix must have to be
///                      included in the result.
isl::set getPartialTilePrefixes(isl::set ScheduleRange, int VectorWidth);
0236 
/// Create an isl::union_set that describes the isolate option based on
/// @p IsolateDomain.
///
/// @param IsolateDomain An isl::set whose last @p OutDimsNum dimensions should
///                      belong to the current band node.
/// @param OutDimsNum    The number of dimensions that should belong to
///                      the current band node.
isl::union_set getIsolateOptions(isl::set IsolateDomain, unsigned OutDimsNum);
0245 
/// Create an isl::union_set that describes the specified option for the
/// dimension of the current node.
///
/// @param Ctx    An isl::ctx, which is used to create the isl::union_set.
/// @param Option The name of the option.
isl::union_set getDimOptions(isl::ctx Ctx, const char *Option);
0252 
/// Tile a schedule node.
///
/// @param Node            The node to tile.
/// @param Identifier      A name that identifies this kind of tiling and
///                        that is used to mark the tiled loops in the
///                        generated AST.
/// @param TileSizes       A vector of tile sizes that should be used for
///                        tiling.
/// @param DefaultTileSize A default tile size that is used for dimensions
///                        that are not covered by the TileSizes vector.
/// @return The tiled schedule node.
isl::schedule_node tileNode(isl::schedule_node Node, const char *Identifier,
                            llvm::ArrayRef<int> TileSizes, int DefaultTileSize);
0265 
/// Tile a schedule node and unroll the point loops.
///
/// @param Node            The node to register-tile.
/// @param TileSizes       A vector of tile sizes that should be used for
///                        tiling.
/// @param DefaultTileSize A default tile size that is used for dimensions
///                        that are not covered by the TileSizes vector.
/// @return The transformed schedule node.
isl::schedule_node applyRegisterTiling(isl::schedule_node Node,
                                       llvm::ArrayRef<int> TileSizes,
                                       int DefaultTileSize);
0275 
/// Apply greedy fusion, i.e. fuse, top-down, any loops that can be fused.
///
/// @param Sched  The schedule tree in which to fuse loops.
/// @param Deps   Validity constraints that must be preserved.
/// @return The schedule after fusion.
isl::schedule applyGreedyFusion(isl::schedule Sched,
                                const isl::union_map &Deps);
0283 
0284 } // namespace polly
0285 
0286 #endif // POLLY_SCHEDULETREETRANSFORM_H