![]() |
|
|||
File indexing completed on 2025-09-18 09:32:03
0001 // Author: Enrico Guiraud, Danilo Piparo CERN 02/2018 0002 0003 /************************************************************************* 0004 * Copyright (C) 1995-2018, Rene Brun and Fons Rademakers. * 0005 * All rights reserved. * 0006 * * 0007 * For the licensing terms see $ROOTSYS/LICENSE. * 0008 * For the list of contributors see $ROOTSYS/README/CREDITS. * 0009 *************************************************************************/ 0010 0011 // This header contains helper free functions that slim down RDataFrame's programming model 0012 0013 #ifndef ROOT_RDF_HELPERS 0014 #define ROOT_RDF_HELPERS 0015 0016 #include <ROOT/RDF/GraphUtils.hxx> 0017 #include <ROOT/RDF/RActionBase.hxx> 0018 #include <ROOT/RDF/RResultMap.hxx> 0019 #include <ROOT/RResultHandle.hxx> // users of RunGraphs might rely on this transitive include 0020 #include <ROOT/TypeTraits.hxx> 0021 0022 #include <array> 0023 #include <chrono> 0024 #include <fstream> 0025 #include <functional> 0026 #include <map> 0027 #include <memory> 0028 #include <mutex> 0029 #include <type_traits> 0030 #include <utility> // std::index_sequence 0031 #include <vector> 0032 0033 namespace ROOT { 0034 namespace Internal { 0035 namespace RDF { 0036 template <typename... ArgTypes, typename F> 0037 std::function<bool(ArgTypes...)> NotHelper(ROOT::TypeTraits::TypeList<ArgTypes...>, F &&f) 0038 { 0039 return std::function<bool(ArgTypes...)>([=](ArgTypes... args) mutable { return !f(args...); }); 0040 } 0041 0042 template <typename... ArgTypes, typename Ret, typename... Args> 0043 std::function<bool(ArgTypes...)> NotHelper(ROOT::TypeTraits::TypeList<ArgTypes...>, Ret (*f)(Args...)) 0044 { 0045 return std::function<bool(ArgTypes...)>([=](ArgTypes... args) mutable { return !f(args...); }); 0046 } 0047 0048 template <typename I, typename T, typename F> 0049 class PassAsVecHelper; 0050 0051 template <std::size_t... N, typename T, typename F> 0052 class PassAsVecHelper<std::index_sequence<N...>, T, F> { 0053 template <std::size_t Idx> 0054 using AlwaysT = T; 0055 std::decay_t<F> fFunc; 0056 0057 public: 0058 PassAsVecHelper(F &&f) : fFunc(std::forward<F>(f)) {} 0059 auto operator()(AlwaysT<N>... args) -> decltype(fFunc({args...})) { return fFunc({args...}); } 0060 }; 0061 0062 template <std::size_t N, typename T, typename F> 0063 auto PassAsVec(F &&f) -> PassAsVecHelper<std::make_index_sequence<N>, T, F> 0064 { 0065 return PassAsVecHelper<std::make_index_sequence<N>, T, F>(std::forward<F>(f)); 0066 } 0067 0068 } // namespace RDF 0069 } // namespace Internal 0070 0071 namespace RDF { 0072 namespace RDFInternal = ROOT::Internal::RDF; 0073 0074 // clang-format off 0075 /// Given a callable with signature bool(T1, T2, ...) return a callable with same signature that returns the negated result 0076 /// 0077 /// The callable must have one single non-template definition of operator(). This is a limitation with respect to 0078 /// std::not_fn, required for interoperability with RDataFrame. 0079 // clang-format on 0080 template <typename F, 0081 typename Args = typename ROOT::TypeTraits::CallableTraits<std::decay_t<F>>::arg_types_nodecay, 0082 typename Ret = typename ROOT::TypeTraits::CallableTraits<std::decay_t<F>>::ret_type> 0083 auto Not(F &&f) -> decltype(RDFInternal::NotHelper(Args(), std::forward<F>(f))) 0084 { 0085 static_assert(std::is_same<Ret, bool>::value, "RDF::Not requires a callable that returns a bool."); 0086 return RDFInternal::NotHelper(Args(), std::forward<F>(f)); 0087 } 0088 0089 // clang-format off 0090 /// PassAsVec is a callable generator that allows passing N variables of type T to a function as a single collection. 0091 /// 0092 /// PassAsVec<N, T>(func) returns a callable that takes N arguments of type T, passes them down to function `func` as 0093 /// an initializer list `{t1, t2, t3,..., tN}` and returns whatever f({t1, t2, t3, ..., tN}) returns. 0094 /// 0095 /// Note that for this to work with RDataFrame the type of all columns that the callable is applied to must be exactly T. 0096 /// Example usage together with RDataFrame ("varX" columns must all be `float` variables): 0097 /// \code 0098 /// bool myVecFunc(std::vector<float> args); 0099 /// df.Filter(PassAsVec<3, float>(myVecFunc), {"var1", "var2", "var3"}); 0100 /// \endcode 0101 // clang-format on 0102 template <std::size_t N, typename T, typename F> 0103 auto PassAsVec(F &&f) -> RDFInternal::PassAsVecHelper<std::make_index_sequence<N>, T, F> 0104 { 0105 return RDFInternal::PassAsVecHelper<std::make_index_sequence<N>, T, F>(std::forward<F>(f)); 0106 } 0107 0108 // clang-format off 0109 /// Create a graphviz representation of the dataframe computation graph, return it as a string. 0110 /// \param[in] node any node of the graph. Called on the head (first) node, it prints the entire graph. Otherwise, only the branch the node belongs to. 0111 /// 0112 /// The output can be displayed with a command akin to `dot -Tpng output.dot > output.png && open output.png`. 0113 /// 0114 /// Note that "hanging" Defines, i.e. Defines without downstream nodes, will not be displayed by SaveGraph as they are 0115 /// effectively optimized away from the computation graph. 0116 /// 0117 /// Note that SaveGraph is not thread-safe and must not be called concurrently from different threads. 0118 // clang-format on 0119 template <typename NodeType> 0120 std::string SaveGraph(NodeType node) 0121 { 0122 ROOT::Internal::RDF::GraphDrawing::GraphCreatorHelper helper; 0123 return helper.RepresentGraph(node); 0124 } 0125 0126 // clang-format off 0127 /// Create a graphviz representation of the dataframe computation graph, write it to the specified file. 0128 /// \param[in] node any node of the graph. Called on the head (first) node, it prints the entire graph. Otherwise, only the branch the node belongs to. 0129 /// \param[in] outputFile file where to save the representation. 0130 /// 0131 /// The output can be displayed with a command akin to `dot -Tpng output.dot > output.png && open output.png`. 0132 /// 0133 /// Note that "hanging" Defines, i.e. Defines without downstream nodes, will not be displayed by SaveGraph as they are 0134 /// effectively optimized away from the computation graph. 0135 /// 0136 /// Note that SaveGraph is not thread-safe and must not be called concurrently from different threads. 0137 // clang-format on 0138 template <typename NodeType> 0139 void SaveGraph(NodeType node, const std::string &outputFile) 0140 { 0141 ROOT::Internal::RDF::GraphDrawing::GraphCreatorHelper helper; 0142 std::string dotGraph = helper.RepresentGraph(node); 0143 0144 std::ofstream out(outputFile); 0145 if (!out.is_open()) { 0146 throw std::runtime_error("Could not open output file \"" + outputFile + "\"for reading"); 0147 } 0148 0149 out << dotGraph; 0150 out.close(); 0151 } 0152 0153 // clang-format off 0154 /// Cast a RDataFrame node to the common type ROOT::RDF::RNode 0155 /// \param[in] node Any node of a RDataFrame graph 0156 // clang-format on 0157 template <typename NodeType> 0158 RNode AsRNode(NodeType node) 0159 { 0160 return node; 0161 } 0162 0163 // clang-format off 0164 /// Run the event loops of multiple RDataFrames concurrently. 0165 /// \param[in] handles A vector of RResultHandles whose event loops should be run. 0166 /// \return The number of distinct computation graphs that have been processed. 0167 /// 0168 /// This function triggers the event loop of all computation graphs which relate to the 0169 /// given RResultHandles. The advantage compared to running the event loop implicitly by accessing the 0170 /// RResultPtr is that the event loops will run concurrently. Therefore, the overall 0171 /// computation of all results can be scheduled more efficiently. 0172 /// It should be noted that user-defined operations (e.g., Filters and Defines) of the different RDataFrame graphs are assumed to be safe to call concurrently. 0173 /// RDataFrame will pass slot numbers in the range [0, NThread-1] to all helpers used in nodes such as DefineSlot. NThread is the number of threads ROOT was 0174 /// configured with in EnableImplicitMT(). 0175 /// Slot numbers are unique across all graphs, so no two tasks with the same slot number will run concurrently. Note that it is not guaranteed that each slot 0176 /// number will be reached in every graph. 0177 /// 0178 /// ~~~{.cpp} 0179 /// ROOT::RDataFrame df1("tree1", "file1.root"); 0180 /// auto r1 = df1.Histo1D("var1"); 0181 /// 0182 /// ROOT::RDataFrame df2("tree2", "file2.root"); 0183 /// auto r2 = df2.Sum("var2"); 0184 /// 0185 /// // RResultPtr -> RResultHandle conversion is automatic 0186 /// ROOT::RDF::RunGraphs({r1, r2}); 0187 /// ~~~ 0188 // clang-format on 0189 unsigned int RunGraphs(std::vector<RResultHandle> handles); 0190 0191 namespace Experimental { 0192 0193 /// \brief Produce all required systematic variations for the given result. 0194 /// \param[in] resPtr The result for which variations should be produced. 0195 /// \return A \ref ROOT::RDF::Experimental::RResultMap "RResultMap" object with full variation names as strings 0196 /// (e.g. "pt:down") and the corresponding varied results as values. 0197 /// 0198 /// A given input RResultPtr<T> produces a corresponding RResultMap<T> with a "nominal" 0199 /// key that will return a value identical to the one contained in the original RResultPtr. 0200 /// Other keys correspond to the varied values of this result, one for each variation 0201 /// that the result depends on. 0202 /// VariationsFor does not trigger the event loop. The event loop is only triggered 0203 /// upon first access to a valid key, similarly to what happens with RResultPtr. 0204 /// 0205 /// If the result does not depend, directly or indirectly, from any registered systematic variation, the 0206 /// returned RResultMap will contain only the "nominal" key. 0207 /// 0208 /// See RDataFrame's \ref ROOT::RDF::RInterface::Vary() "Vary" method for more information and example usages. 0209 /// 0210 /// \note Currently, producing variations for the results of \ref ROOT::RDF::RInterface::Display() "Display", 0211 /// \ref ROOT::RDF::RInterface::Report() "Report" and \ref ROOT::RDF::RInterface::Snapshot() "Snapshot" 0212 /// actions is not supported. 0213 // 0214 // An overview of how systematic variations work internally. Given N variations (including the nominal): 0215 // 0216 // RResultMap owns RVariedAction 0217 // N results N action helpers 0218 // N previous filters 0219 // N*#input_cols column readers 0220 // 0221 // ...and each RFilter and RDefine knows for what universe it needs to construct column readers ("nominal" by default). 0222 template <typename T> 0223 RResultMap<T> VariationsFor(RResultPtr<T> resPtr) 0224 { 0225 R__ASSERT(resPtr != nullptr && "Calling VariationsFor on an empty RResultPtr"); 0226 0227 // populate parts of the computation graph for which we only have "empty shells", e.g. RJittedActions and 0228 // RJittedFilters 0229 resPtr.fLoopManager->Jit(); 0230 0231 std::unique_ptr<RDFInternal::RActionBase> variedAction; 0232 std::vector<std::shared_ptr<T>> variedResults; 0233 0234 std::shared_ptr<RDFInternal::RActionBase> nominalAction = resPtr.fActionPtr; 0235 std::vector<std::string> variations = nominalAction->GetVariations(); 0236 const auto nVariations = variations.size(); 0237 0238 if (nVariations > 0) { 0239 // clone the result once for each variation 0240 variedResults.reserve(nVariations); 0241 for (auto i = 0u; i < nVariations; ++i){ 0242 // implicitly assuming that T is copiable: this should be the case 0243 // for all result types in use, as they are copied for each slot 0244 variedResults.emplace_back(new T{*resPtr.fObjPtr}); 0245 0246 // Check if the result's type T inherits from TNamed 0247 if constexpr (std::is_base_of<TNamed, T>::value) { 0248 // Get the current variation name 0249 std::string variationName = variations[i]; 0250 // Replace the colon with an underscore 0251 std::replace(variationName.begin(), variationName.end(), ':', '_'); 0252 // Get a pointer to the corresponding varied result 0253 auto &variedResult = variedResults.back(); 0254 // Set the varied result's name to NOMINALNAME_VARIATIONAME 0255 variedResult->SetName((std::string(variedResult->GetName()) + "_" + variationName).c_str()); 0256 } 0257 } 0258 0259 std::vector<void *> typeErasedResults; 0260 typeErasedResults.reserve(variedResults.size()); 0261 for (auto &res : variedResults) 0262 typeErasedResults.emplace_back(&res); 0263 0264 // Create the RVariedAction and inject it in the computation graph. 0265 // This recursively creates all the required varied column readers and upstream nodes of the computation graph. 0266 variedAction = nominalAction->MakeVariedAction(std::move(typeErasedResults)); 0267 } 0268 0269 return RDFInternal::MakeResultMap<T>(resPtr.fObjPtr, std::move(variedResults), std::move(variations), 0270 *resPtr.fLoopManager, std::move(nominalAction), std::move(variedAction)); 0271 } 0272 0273 using SnapshotPtr_t = ROOT::RDF::RResultPtr<ROOT::RDF::RInterface<ROOT::Detail::RDF::RLoopManager, void>>; 0274 SnapshotPtr_t VariationsFor(SnapshotPtr_t resPtr); 0275 0276 /// \brief Add ProgressBar to a ROOT::RDF::RNode 0277 /// \param[in] df RDataFrame node at which ProgressBar is called. 0278 /// 0279 /// The ProgressBar can be added not only at the RDataFrame head node, but also at any any computational node, 0280 /// such as Filter or Define. 0281 /// ###Example usage: 0282 /// ~~~{.cpp} 0283 /// ROOT::RDataFrame df("tree", "file.root"); 0284 /// auto df_1 = ROOT::RDF::RNode(df.Filter("x>1")); 0285 /// ROOT::RDF::Experimental::AddProgressBar(df_1); 0286 /// ~~~ 0287 void AddProgressBar(ROOT::RDF::RNode df); 0288 0289 /// \brief Add ProgressBar to an RDataFrame 0290 /// \param[in] df RDataFrame for which ProgressBar is called. 0291 /// 0292 /// This function adds a ProgressBar to display the event statistics in the terminal every 0293 /// \b m events and every \b n seconds, including elapsed time, currently processed file, 0294 /// currently processed events, the rate of event processing 0295 /// and an estimated remaining time (per file being processed). 0296 /// ProgressBar should be added after the dataframe object (df) is created first: 0297 /// ~~~{.cpp} 0298 /// ROOT::RDataFrame df("tree", "file.root"); 0299 /// ROOT::RDF::Experimental::AddProgressBar(df); 0300 /// ~~~ 0301 /// For more details see ROOT::RDF::Experimental::ProgressHelper Class. 0302 void AddProgressBar(ROOT::RDataFrame df); 0303 0304 class ProgressBarAction; 0305 0306 /// RDF progress helper. 0307 /// This class provides callback functions to the RDataFrame. The event statistics 0308 /// (including elapsed time, currently processed file, currently processed events, the rate of event processing 0309 /// and an estimated remaining time (per file being processed)) 0310 /// are recorded and printed in the terminal every m events and every n seconds. 0311 /// ProgressHelper::operator()(unsigned int, T&) is thread safe, and can be used as a callback in MT mode. 0312 /// ProgressBar should be added after creating the dataframe object (df): 0313 /// ~~~{.cpp} 0314 /// ROOT::RDataFrame df("tree", "file.root"); 0315 /// ROOT::RDF::Experimental::AddProgressBar(df); 0316 /// ~~~ 0317 /// alternatively RDataFrame can be cast to an RNode first giving it more flexibility. 0318 /// For example, it can be called at any computational node, such as Filter or Define, not only the head node, 0319 /// with no change to the ProgressBar function itself: 0320 /// ~~~{.cpp} 0321 /// ROOT::RDataFrame df("tree", "file.root"); 0322 /// auto df_1 = ROOT::RDF::RNode(df.Filter("x>1")); 0323 /// ROOT::RDF::Experimental::AddProgressBar(df_1); 0324 /// ~~~ 0325 class ProgressHelper { 0326 private: 0327 double EvtPerSec() const; 0328 std::pair<std::size_t, std::chrono::seconds> RecordEvtCountAndTime(); 0329 void PrintStats(std::ostream &stream, std::size_t currentEventCount, std::chrono::seconds totalElapsedSeconds) const; 0330 void PrintStatsFinal(std::ostream &stream, std::chrono::seconds totalElapsedSeconds) const; 0331 void PrintProgressBar(std::ostream &stream, std::size_t currentEventCount) const; 0332 0333 std::chrono::time_point<std::chrono::system_clock> fBeginTime = std::chrono::system_clock::now(); 0334 std::chrono::time_point<std::chrono::system_clock> fLastPrintTime = fBeginTime; 0335 std::chrono::seconds fPrintInterval{1}; 0336 0337 std::atomic<std::size_t> fProcessedEvents{0}; 0338 std::size_t fLastProcessedEvents{0}; 0339 std::size_t fIncrement; 0340 0341 mutable std::mutex fSampleNameToEventEntriesMutex; 0342 std::map<std::string, ULong64_t> fSampleNameToEventEntries; // Filename, events in the file 0343 0344 std::array<double, 20> fEventsPerSecondStatistics; 0345 std::size_t fEventsPerSecondStatisticsIndex{0}; 0346 0347 unsigned int fBarWidth; 0348 unsigned int fTotalFiles; 0349 0350 std::mutex fPrintMutex; 0351 bool fIsTTY; 0352 bool fUseShellColours; 0353 0354 std::shared_ptr<TTree> fTree{nullptr}; 0355 0356 public: 0357 /// Create a progress helper. 0358 /// \param increment RDF callbacks are called every `n` events. Pass this `n` here. 0359 /// \param totalFiles read total number of files in the RDF. 0360 /// \param progressBarWidth Number of characters the progress bar will occupy. 0361 /// \param printInterval Update every stats every `n` seconds. 0362 /// \param useColors Use shell colour codes to colour the output. Automatically disabled when 0363 /// we are not writing to a tty. 0364 ProgressHelper(std::size_t increment, unsigned int totalFiles = 1, unsigned int progressBarWidth = 40, 0365 unsigned int printInterval = 1, bool useColors = true); 0366 0367 ~ProgressHelper() = default; 0368 0369 friend class ProgressBarAction; 0370 0371 /// Register a new sample for completion statistics. 0372 /// \see ROOT::RDF::RInterface::DefinePerSample(). 0373 /// The *id.AsString()* refers to the name of the currently processed file. 0374 /// The idea is to populate the event entries in the *fSampleNameToEventEntries* map 0375 /// by selecting the greater of the two values: 0376 /// *id.EntryRange().second* which is the upper event entry range of the processed sample 0377 /// and the current value of the event entries in the *fSampleNameToEventEntries* map. 0378 /// In the single threaded case, the two numbers are the same as the entry range corresponds 0379 /// to the number of events in an individual file (each sample is simply a single file). 0380 /// In the multithreaded case, the idea is to accumulate the higher event entry value until 0381 /// the total number of events in a given file is reached. 0382 void registerNewSample(unsigned int /*slot*/, const ROOT::RDF::RSampleInfo &id) 0383 { 0384 std::lock_guard<std::mutex> lock(fSampleNameToEventEntriesMutex); 0385 fSampleNameToEventEntries[id.AsString()] = 0386 std::max(id.EntryRange().second, fSampleNameToEventEntries[id.AsString()]); 0387 } 0388 0389 /// Thread-safe callback for RDataFrame. 0390 /// It will record elapsed times and event statistics, and print a progress bar every n seconds (set by the 0391 /// fPrintInterval). \param slot Ignored. \param value Ignored. 0392 template <typename T> 0393 void operator()(unsigned int /*slot*/, T &value) 0394 { 0395 operator()(value); 0396 } 0397 // clang-format off 0398 /// Thread-safe callback for RDataFrame. 0399 /// It will record elapsed times and event statistics, and print a progress bar every n seconds (set by the fPrintInterval). 0400 /// \param value Ignored. 0401 // clang-format on 0402 template <typename T> 0403 void operator()(T & /*value*/) 0404 { 0405 using namespace std::chrono; 0406 // *************************************************** 0407 // Warning: Here, everything needs to be thread safe: 0408 // *************************************************** 0409 fProcessedEvents += fIncrement; 0410 0411 // We only print every n seconds. 0412 if (duration_cast<seconds>(system_clock::now() - fLastPrintTime) < fPrintInterval) { 0413 return; 0414 } 0415 0416 // *************************************************** 0417 // Protected by lock from here: 0418 // *************************************************** 0419 if (!fPrintMutex.try_lock()) 0420 return; 0421 std::lock_guard<std::mutex> lockGuard(fPrintMutex, std::adopt_lock); 0422 0423 std::size_t eventCount; 0424 seconds elapsedSeconds; 0425 std::tie(eventCount, elapsedSeconds) = RecordEvtCountAndTime(); 0426 0427 if (fIsTTY) 0428 std::cout << "\r"; 0429 0430 PrintProgressBar(std::cout, eventCount); 0431 PrintStats(std::cout, eventCount, elapsedSeconds); 0432 0433 if (fIsTTY) 0434 std::cout << std::flush; 0435 else 0436 std::cout << std::endl; 0437 } 0438 0439 std::size_t ComputeNEventsSoFar() const 0440 { 0441 std::unique_lock<std::mutex> lock(fSampleNameToEventEntriesMutex); 0442 std::size_t result = 0; 0443 for (const auto &item : fSampleNameToEventEntries) 0444 result += item.second; 0445 return result; 0446 } 0447 0448 unsigned int ComputeCurrentFileIdx() const 0449 { 0450 std::unique_lock<std::mutex> lock(fSampleNameToEventEntriesMutex); 0451 return fSampleNameToEventEntries.size(); 0452 } 0453 }; 0454 } // namespace Experimental 0455 } // namespace RDF 0456 } // namespace ROOT 0457 #endif
[ Source navigation ] | [ Diff markup ] | [ Identifier search ] | [ general search ] |
This page was automatically generated by the 2.3.7 LXR engine. The LXR team |
![]() ![]() |