Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2026-05-10 08:44:38

0001 //===-- CFGMST.h - Minimum Spanning Tree for CFG ----------------*- C++ -*-===//
0002 //
0003 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
0004 // See https://llvm.org/LICENSE.txt for license information.
0005 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
0006 //
0007 //===----------------------------------------------------------------------===//
0008 //
0009 // This file implements a Union-find algorithm to compute Minimum Spanning Tree
0010 // for a given CFG.
0011 //
0012 //===----------------------------------------------------------------------===//
0013 
0014 #ifndef LLVM_TRANSFORMS_INSTRUMENTATION_CFGMST_H
0015 #define LLVM_TRANSFORMS_INSTRUMENTATION_CFGMST_H
0016 
0017 #include "llvm/ADT/DenseMap.h"
0018 #include "llvm/ADT/STLExtras.h"
0019 #include "llvm/Analysis/BlockFrequencyInfo.h"
0020 #include "llvm/Analysis/BranchProbabilityInfo.h"
0021 #include "llvm/Analysis/CFG.h"
0022 #include "llvm/Analysis/LoopInfo.h"
0023 #include "llvm/IR/Instructions.h"
0024 #include "llvm/IR/IntrinsicInst.h"
0025 #include "llvm/Support/BranchProbability.h"
0026 #include "llvm/Support/Debug.h"
0027 #include "llvm/Support/raw_ostream.h"
0028 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
0029 #include <utility>
0030 #include <vector>
0031 
0032 #define DEBUG_TYPE "cfgmst"
0033 
0034 namespace llvm {
0035 
0036 /// An union-find based Minimum Spanning Tree for CFG
0037 ///
0038 /// Implements a Union-find algorithm to compute Minimum Spanning Tree
0039 /// for a given CFG.
0040 template <class Edge, class BBInfo> class CFGMST {
0041   Function &F;
0042 
0043   // Store all the edges in CFG. It may contain some stale edges
0044   // when Removed is set.
0045   std::vector<std::unique_ptr<Edge>> AllEdges;
0046 
0047   // This map records the auxiliary information for each BB.
0048   DenseMap<const BasicBlock *, std::unique_ptr<BBInfo>> BBInfos;
0049 
0050   // Whehter the function has an exit block with no successors.
0051   // (For function with an infinite loop, this block may be absent)
0052   bool ExitBlockFound = false;
0053 
0054   BranchProbabilityInfo *const BPI;
0055   BlockFrequencyInfo *const BFI;
0056   LoopInfo *const LI;
0057 
0058   // If function entry will be always instrumented.
0059   const bool InstrumentFuncEntry;
0060 
0061   // If true loop entries will be always instrumented.
0062   const bool InstrumentLoopEntries;
0063 
0064   // Find the root group of the G and compress the path from G to the root.
0065   BBInfo *findAndCompressGroup(BBInfo *G) {
0066     if (G->Group != G)
0067       G->Group = findAndCompressGroup(static_cast<BBInfo *>(G->Group));
0068     return static_cast<BBInfo *>(G->Group);
0069   }
0070 
0071   // Union BB1 and BB2 into the same group and return true.
0072   // Returns false if BB1 and BB2 are already in the same group.
0073   bool unionGroups(const BasicBlock *BB1, const BasicBlock *BB2) {
0074     BBInfo *BB1G = findAndCompressGroup(&getBBInfo(BB1));
0075     BBInfo *BB2G = findAndCompressGroup(&getBBInfo(BB2));
0076 
0077     if (BB1G == BB2G)
0078       return false;
0079 
0080     // Make the smaller rank tree a direct child or the root of high rank tree.
0081     if (BB1G->Rank < BB2G->Rank)
0082       BB1G->Group = BB2G;
0083     else {
0084       BB2G->Group = BB1G;
0085       // If the ranks are the same, increment root of one tree by one.
0086       if (BB1G->Rank == BB2G->Rank)
0087         BB1G->Rank++;
0088     }
0089     return true;
0090   }
0091 
0092   void handleCoroSuspendEdge(Edge *E) {
0093     // We must not add instrumentation to the BB representing the
0094     // "suspend" path, else CoroSplit won't be able to lower
0095     // llvm.coro.suspend to a tail call. We do want profiling info for
0096     // the other branches (resume/destroy). So we do 2 things:
0097     // 1. we prefer instrumenting those other edges by setting the weight
0098     //    of the "suspend" edge to max, and
0099     // 2. we mark the edge as "Removed" to guarantee it is not considered
0100     //    for instrumentation. That could technically happen:
0101     //    (from test/Transforms/Coroutines/coro-split-musttail.ll)
0102     //
0103     // %suspend = call i8 @llvm.coro.suspend(token %save, i1 false)
0104     // switch i8 %suspend, label %exit [
0105     //   i8 0, label %await.ready
0106     //   i8 1, label %exit
0107     // ]
0108     if (!E->DestBB)
0109       return;
0110     assert(E->SrcBB);
0111     if (llvm::isPresplitCoroSuspendExitEdge(*E->SrcBB, *E->DestBB))
0112       E->Removed = true;
0113   }
0114 
0115   // Traverse the CFG using a stack. Find all the edges and assign the weight.
0116   // Edges with large weight will be put into MST first so they are less likely
0117   // to be instrumented.
0118   void buildEdges() {
0119     LLVM_DEBUG(dbgs() << "Build Edge on " << F.getName() << "\n");
0120 
0121     BasicBlock *Entry = &(F.getEntryBlock());
0122     uint64_t EntryWeight =
0123         (BFI != nullptr ? BFI->getEntryFreq().getFrequency() : 2);
0124     // If we want to instrument the entry count, lower the weight to 0.
0125     if (InstrumentFuncEntry)
0126       EntryWeight = 0;
0127     Edge *EntryIncoming = nullptr, *EntryOutgoing = nullptr,
0128          *ExitOutgoing = nullptr, *ExitIncoming = nullptr;
0129     uint64_t MaxEntryOutWeight = 0, MaxExitOutWeight = 0, MaxExitInWeight = 0;
0130 
0131     // Add a fake edge to the entry.
0132     EntryIncoming = &addEdge(nullptr, Entry, EntryWeight);
0133     LLVM_DEBUG(dbgs() << "  Edge: from fake node to " << Entry->getName()
0134                       << " w = " << EntryWeight << "\n");
0135 
0136     // Special handling for single BB functions.
0137     if (succ_empty(Entry)) {
0138       addEdge(Entry, nullptr, EntryWeight);
0139       return;
0140     }
0141 
0142     static const uint32_t CriticalEdgeMultiplier = 1000;
0143 
0144     for (BasicBlock &BB : F) {
0145       Instruction *TI = BB.getTerminator();
0146       uint64_t BBWeight =
0147           (BFI != nullptr ? BFI->getBlockFreq(&BB).getFrequency() : 2);
0148       uint64_t Weight = 2;
0149       if (int successors = TI->getNumSuccessors()) {
0150         for (int i = 0; i != successors; ++i) {
0151           BasicBlock *TargetBB = TI->getSuccessor(i);
0152           bool Critical = isCriticalEdge(TI, i);
0153           uint64_t scaleFactor = BBWeight;
0154           if (Critical) {
0155             if (scaleFactor < UINT64_MAX / CriticalEdgeMultiplier)
0156               scaleFactor *= CriticalEdgeMultiplier;
0157             else
0158               scaleFactor = UINT64_MAX;
0159           }
0160           if (BPI != nullptr)
0161             Weight = BPI->getEdgeProbability(&BB, TargetBB).scale(scaleFactor);
0162           // If InstrumentLoopEntries is on and the current edge leads to a loop
0163           // (i.e., TargetBB is a loop head and BB is outside its loop), set
0164           // Weight to be minimal, so that the edge won't be chosen for the MST
0165           // and will be instrumented.
0166           if (InstrumentLoopEntries && LI->isLoopHeader(TargetBB)) {
0167             Loop *TargetLoop = LI->getLoopFor(TargetBB);
0168             assert(TargetLoop);
0169             if (!TargetLoop->contains(&BB))
0170               Weight = 0;
0171           }
0172           if (Weight == 0)
0173             Weight++;
0174           auto *E = &addEdge(&BB, TargetBB, Weight);
0175           E->IsCritical = Critical;
0176           handleCoroSuspendEdge(E);
0177           LLVM_DEBUG(dbgs() << "  Edge: from " << BB.getName() << " to "
0178                             << TargetBB->getName() << "  w=" << Weight << "\n");
0179 
0180           // Keep track of entry/exit edges:
0181           if (&BB == Entry) {
0182             if (Weight > MaxEntryOutWeight) {
0183               MaxEntryOutWeight = Weight;
0184               EntryOutgoing = E;
0185             }
0186           }
0187 
0188           auto *TargetTI = TargetBB->getTerminator();
0189           if (TargetTI && !TargetTI->getNumSuccessors()) {
0190             if (Weight > MaxExitInWeight) {
0191               MaxExitInWeight = Weight;
0192               ExitIncoming = E;
0193             }
0194           }
0195         }
0196       } else {
0197         ExitBlockFound = true;
0198         Edge *ExitO = &addEdge(&BB, nullptr, BBWeight);
0199         if (BBWeight > MaxExitOutWeight) {
0200           MaxExitOutWeight = BBWeight;
0201           ExitOutgoing = ExitO;
0202         }
0203         LLVM_DEBUG(dbgs() << "  Edge: from " << BB.getName() << " to fake exit"
0204                           << " w = " << BBWeight << "\n");
0205       }
0206     }
0207 
0208     // Entry/exit edge adjustment heurisitic:
0209     // prefer instrumenting entry edge over exit edge
0210     // if possible. Those exit edges may never have a chance to be
0211     // executed (for instance the program is an event handling loop)
0212     // before the profile is asynchronously dumped.
0213     //
0214     // If EntryIncoming and ExitOutgoing has similar weight, make sure
0215     // ExitOutging is selected as the min-edge. Similarly, if EntryOutgoing
0216     // and ExitIncoming has similar weight, make sure ExitIncoming becomes
0217     // the min-edge.
0218     uint64_t EntryInWeight = EntryWeight;
0219 
0220     if (EntryInWeight >= MaxExitOutWeight &&
0221         EntryInWeight * 2 < MaxExitOutWeight * 3) {
0222       EntryIncoming->Weight = MaxExitOutWeight;
0223       ExitOutgoing->Weight = EntryInWeight + 1;
0224     }
0225 
0226     if (MaxEntryOutWeight >= MaxExitInWeight &&
0227         MaxEntryOutWeight * 2 < MaxExitInWeight * 3) {
0228       EntryOutgoing->Weight = MaxExitInWeight;
0229       ExitIncoming->Weight = MaxEntryOutWeight + 1;
0230     }
0231   }
0232 
0233   // Sort CFG edges based on its weight.
0234   void sortEdgesByWeight() {
0235     llvm::stable_sort(AllEdges, [](const std::unique_ptr<Edge> &Edge1,
0236                                    const std::unique_ptr<Edge> &Edge2) {
0237       return Edge1->Weight > Edge2->Weight;
0238     });
0239   }
0240 
0241   // Traverse all the edges and compute the Minimum Weight Spanning Tree
0242   // using union-find algorithm.
0243   void computeMinimumSpanningTree() {
0244     // First, put all the critical edge with landing-pad as the Dest to MST.
0245     // This works around the insufficient support of critical edges split
0246     // when destination BB is a landing pad.
0247     for (auto &Ei : AllEdges) {
0248       if (Ei->Removed)
0249         continue;
0250       if (Ei->IsCritical) {
0251         if (Ei->DestBB && Ei->DestBB->isLandingPad()) {
0252           if (unionGroups(Ei->SrcBB, Ei->DestBB))
0253             Ei->InMST = true;
0254         }
0255       }
0256     }
0257 
0258     for (auto &Ei : AllEdges) {
0259       if (Ei->Removed)
0260         continue;
0261       // If we detect infinite loops, force
0262       // instrumenting the entry edge:
0263       if (!ExitBlockFound && Ei->SrcBB == nullptr)
0264         continue;
0265       if (unionGroups(Ei->SrcBB, Ei->DestBB))
0266         Ei->InMST = true;
0267     }
0268   }
0269 
0270   [[maybe_unused]] bool validateLoopEntryInstrumentation() {
0271     if (!InstrumentLoopEntries)
0272       return true;
0273     for (auto &Ei : AllEdges) {
0274       if (Ei->Removed)
0275         continue;
0276       if (Ei->DestBB && LI->isLoopHeader(Ei->DestBB) &&
0277           !LI->getLoopFor(Ei->DestBB)->contains(Ei->SrcBB) && Ei->InMST)
0278         return false;
0279     }
0280     return true;
0281   }
0282 
0283 public:
0284   // Dump the Debug information about the instrumentation.
0285   void dumpEdges(raw_ostream &OS, const Twine &Message) const {
0286     if (!Message.str().empty())
0287       OS << Message << "\n";
0288     OS << "  Number of Basic Blocks: " << BBInfos.size() << "\n";
0289     for (auto &BI : BBInfos) {
0290       const BasicBlock *BB = BI.first;
0291       OS << "  BB: " << (BB == nullptr ? "FakeNode" : BB->getName()) << "  "
0292          << BI.second->infoString() << "\n";
0293     }
0294 
0295     OS << "  Number of Edges: " << AllEdges.size()
0296        << " (*: Instrument, C: CriticalEdge, -: Removed)\n";
0297     uint32_t Count = 0;
0298     for (auto &EI : AllEdges)
0299       OS << "  Edge " << Count++ << ": " << getBBInfo(EI->SrcBB).Index << "-->"
0300          << getBBInfo(EI->DestBB).Index << EI->infoString() << "\n";
0301   }
0302 
0303   // Add an edge to AllEdges with weight W.
0304   Edge &addEdge(BasicBlock *Src, BasicBlock *Dest, uint64_t W) {
0305     uint32_t Index = BBInfos.size();
0306     auto Iter = BBInfos.end();
0307     bool Inserted;
0308     std::tie(Iter, Inserted) = BBInfos.insert(std::make_pair(Src, nullptr));
0309     if (Inserted) {
0310       // Newly inserted, update the real info.
0311       Iter->second = std::make_unique<BBInfo>(Index);
0312       Index++;
0313     }
0314     std::tie(Iter, Inserted) = BBInfos.insert(std::make_pair(Dest, nullptr));
0315     if (Inserted)
0316       // Newly inserted, update the real info.
0317       Iter->second = std::make_unique<BBInfo>(Index);
0318     AllEdges.emplace_back(new Edge(Src, Dest, W));
0319     return *AllEdges.back();
0320   }
0321 
0322   CFGMST(Function &Func, bool InstrumentFuncEntry, bool InstrumentLoopEntries,
0323          BranchProbabilityInfo *BPI = nullptr,
0324          BlockFrequencyInfo *BFI = nullptr, LoopInfo *LI = nullptr)
0325       : F(Func), BPI(BPI), BFI(BFI), LI(LI),
0326         InstrumentFuncEntry(InstrumentFuncEntry),
0327         InstrumentLoopEntries(InstrumentLoopEntries) {
0328     assert(!(InstrumentLoopEntries && !LI) &&
0329            "expected a LoopInfo to instrumenting loop entries");
0330     buildEdges();
0331     sortEdgesByWeight();
0332     computeMinimumSpanningTree();
0333     assert(validateLoopEntryInstrumentation() &&
0334            "Loop entries should not be in MST when "
0335            "InstrumentLoopEntries is on");
0336     if (AllEdges.size() > 1 && InstrumentFuncEntry)
0337       std::iter_swap(std::move(AllEdges.begin()),
0338                      std::move(AllEdges.begin() + AllEdges.size() - 1));
0339   }
0340 
0341   const std::vector<std::unique_ptr<Edge>> &allEdges() const {
0342     return AllEdges;
0343   }
0344 
0345   std::vector<std::unique_ptr<Edge>> &allEdges() { return AllEdges; }
0346 
0347   size_t numEdges() const { return AllEdges.size(); }
0348 
0349   size_t bbInfoSize() const { return BBInfos.size(); }
0350 
0351   // Give BB, return the auxiliary information.
0352   BBInfo &getBBInfo(const BasicBlock *BB) const {
0353     auto It = BBInfos.find(BB);
0354     assert(It->second.get() != nullptr);
0355     return *It->second.get();
0356   }
0357 
0358   // Give BB, return the auxiliary information if it's available.
0359   BBInfo *findBBInfo(const BasicBlock *BB) const {
0360     auto It = BBInfos.find(BB);
0361     if (It == BBInfos.end())
0362       return nullptr;
0363     return It->second.get();
0364   }
0365 };
0366 
0367 } // end namespace llvm
0368 
0369 #undef DEBUG_TYPE // "cfgmst"
0370 
0371 #endif // LLVM_TRANSFORMS_INSTRUMENTATION_CFGMST_H