|
|
|||
File indexing completed on 2026-05-04 08:51:38
0001 // Author: Enrico Guiraud, Danilo Piparo CERN 02/2018 0002 0003 /************************************************************************* 0004 * Copyright (C) 1995-2018, Rene Brun and Fons Rademakers. * 0005 * All rights reserved. * 0006 * * 0007 * For the licensing terms see $ROOTSYS/LICENSE. * 0008 * For the list of contributors see $ROOTSYS/README/CREDITS. * 0009 *************************************************************************/ 0010 0011 // This header contains helper free functions that slim down RDataFrame's programming model 0012 0013 #ifndef ROOT_RDF_HELPERS 0014 #define ROOT_RDF_HELPERS 0015 0016 #include <ROOT/RDF/GraphUtils.hxx> 0017 #include <ROOT/RDF/RActionBase.hxx> 0018 #include <ROOT/RDF/RResultMap.hxx> 0019 #include <ROOT/RResultHandle.hxx> // users of RunGraphs might rely on this transitive include 0020 #include <ROOT/TypeTraits.hxx> 0021 0022 #include <array> 0023 #include <chrono> 0024 #include <fstream> 0025 #include <functional> 0026 #include <map> 0027 #include <memory> 0028 #include <mutex> 0029 #include <type_traits> 0030 #include <utility> // std::index_sequence 0031 #include <vector> 0032 0033 namespace ROOT { 0034 namespace Internal { 0035 namespace RDF { 0036 template <typename... ArgTypes, typename F> 0037 std::function<bool(ArgTypes...)> NotHelper(ROOT::TypeTraits::TypeList<ArgTypes...>, F &&f) 0038 { 0039 return std::function<bool(ArgTypes...)>([=](ArgTypes... args) mutable { return !f(args...); }); 0040 } 0041 0042 template <typename... ArgTypes, typename Ret, typename... Args> 0043 std::function<bool(ArgTypes...)> NotHelper(ROOT::TypeTraits::TypeList<ArgTypes...>, Ret (*f)(Args...)) 0044 { 0045 return std::function<bool(ArgTypes...)>([=](ArgTypes... args) mutable { return !f(args...); }); 0046 } 0047 0048 template <typename I, typename T, typename F> 0049 class PassAsVecHelper; 0050 0051 template <std::size_t... N, typename T, typename F> 0052 class PassAsVecHelper<std::index_sequence<N...>, T, F> { 0053 template <std::size_t Idx> 0054 using AlwaysT = T; 0055 std::decay_t<F> fFunc; 0056 0057 public: 0058 PassAsVecHelper(F &&f) : fFunc(std::forward<F>(f)) {} 0059 auto operator()(AlwaysT<N>... args) -> decltype(fFunc({args...})) { return fFunc({args...}); } 0060 }; 0061 0062 template <std::size_t N, typename T, typename F> 0063 auto PassAsVec(F &&f) -> PassAsVecHelper<std::make_index_sequence<N>, T, F> 0064 { 0065 return PassAsVecHelper<std::make_index_sequence<N>, T, F>(std::forward<F>(f)); 0066 } 0067 0068 } // namespace RDF 0069 } // namespace Internal 0070 0071 namespace RDF { 0072 namespace RDFInternal = ROOT::Internal::RDF; 0073 0074 // clang-format off 0075 /// Given a callable with signature bool(T1, T2, ...) return a callable with same signature that returns the negated result 0076 /// 0077 /// The callable must have one single non-template definition of operator(). This is a limitation with respect to 0078 /// std::not_fn, required for interoperability with RDataFrame. 0079 // clang-format on 0080 template <typename F, 0081 typename Args = typename ROOT::TypeTraits::CallableTraits<std::decay_t<F>>::arg_types_nodecay, 0082 typename Ret = typename ROOT::TypeTraits::CallableTraits<std::decay_t<F>>::ret_type> 0083 auto Not(F &&f) -> decltype(RDFInternal::NotHelper(Args(), std::forward<F>(f))) 0084 { 0085 static_assert(std::is_same<Ret, bool>::value, "RDF::Not requires a callable that returns a bool."); 0086 return RDFInternal::NotHelper(Args(), std::forward<F>(f)); 0087 } 0088 0089 // clang-format off 0090 /// PassAsVec is a callable generator that allows passing N variables of type T to a function as a single collection. 0091 /// 0092 /// PassAsVec<N, T>(func) returns a callable that takes N arguments of type T, passes them down to function `func` as 0093 /// an initializer list `{t1, t2, t3,..., tN}` and returns whatever f({t1, t2, t3, ..., tN}) returns. 0094 /// 0095 /// Note that for this to work with RDataFrame the type of all columns that the callable is applied to must be exactly T. 0096 /// Example usage together with RDataFrame ("varX" columns must all be `float` variables): 0097 /// \code 0098 /// bool myVecFunc(std::vector<float> args); 0099 /// df.Filter(PassAsVec<3, float>(myVecFunc), {"var1", "var2", "var3"}); 0100 /// \endcode 0101 // clang-format on 0102 template <std::size_t N, typename T, typename F> 0103 auto PassAsVec(F &&f) -> RDFInternal::PassAsVecHelper<std::make_index_sequence<N>, T, F> 0104 { 0105 return RDFInternal::PassAsVecHelper<std::make_index_sequence<N>, T, F>(std::forward<F>(f)); 0106 } 0107 0108 // clang-format off 0109 /// Create a graphviz representation of the dataframe computation graph, return it as a string. 0110 /// \param[in] node any node of the graph. Called on the head (first) node, it prints the entire graph. Otherwise, only the branch the node belongs to. 0111 /// 0112 /// The output can be displayed with a command akin to `dot -Tpng output.dot > output.png && open output.png`. 0113 /// 0114 /// Note that "hanging" Defines, i.e. Defines without downstream nodes, will not be displayed by SaveGraph as they are 0115 /// effectively optimized away from the computation graph. 0116 /// 0117 /// Note that SaveGraph is not thread-safe and must not be called concurrently from different threads. 0118 // clang-format on 0119 template <typename NodeType> 0120 std::string SaveGraph(NodeType node) 0121 { 0122 ROOT::Internal::RDF::GraphDrawing::GraphCreatorHelper helper; 0123 return helper.RepresentGraph(node); 0124 } 0125 0126 // clang-format off 0127 /// Create a graphviz representation of the dataframe computation graph, write it to the specified file. 0128 /// \param[in] node any node of the graph. Called on the head (first) node, it prints the entire graph. Otherwise, only the branch the node belongs to. 0129 /// \param[in] outputFile file where to save the representation. 0130 /// 0131 /// The output can be displayed with a command akin to `dot -Tpng output.dot > output.png && open output.png`. 0132 /// 0133 /// Note that "hanging" Defines, i.e. Defines without downstream nodes, will not be displayed by SaveGraph as they are 0134 /// effectively optimized away from the computation graph. 0135 /// 0136 /// Note that SaveGraph is not thread-safe and must not be called concurrently from different threads. 0137 // clang-format on 0138 template <typename NodeType> 0139 void SaveGraph(NodeType node, const std::string &outputFile) 0140 { 0141 ROOT::Internal::RDF::GraphDrawing::GraphCreatorHelper helper; 0142 std::string dotGraph = helper.RepresentGraph(node); 0143 0144 std::ofstream out(outputFile); 0145 if (!out.is_open()) { 0146 throw std::runtime_error("Could not open output file \"" + outputFile + "\"for reading"); 0147 } 0148 0149 out << dotGraph; 0150 out.close(); 0151 } 0152 0153 // clang-format off 0154 /// Cast a RDataFrame node to the common type ROOT::RDF::RNode 0155 /// \param[in] node Any node of a RDataFrame graph 0156 // clang-format on 0157 template <typename NodeType> 0158 RNode AsRNode(NodeType node) 0159 { 0160 return node; 0161 } 0162 0163 // clang-format off 0164 /// Run the event loops of multiple RDataFrames concurrently. 0165 /// \param[in] handles A vector of RResultHandles whose event loops should be run. 0166 /// \return The number of distinct computation graphs that have been processed. 0167 /// 0168 /// This function triggers the event loop of all computation graphs which relate to the 0169 /// given RResultHandles. The advantage compared to running the event loop implicitly by accessing the 0170 /// RResultPtr is that the event loops will run concurrently. Therefore, the overall 0171 /// computation of all results can be scheduled more efficiently. 0172 /// It should be noted that user-defined operations (e.g., Filters and Defines) of the different RDataFrame graphs are assumed to be safe to call concurrently. 0173 /// RDataFrame will pass slot numbers in the range [0, NThread-1] to all helpers used in nodes such as DefineSlot. NThread is the number of threads ROOT was 0174 /// configured with in EnableImplicitMT(). 0175 /// Slot numbers are unique across all graphs, so no two tasks with the same slot number will run concurrently. Note that it is not guaranteed that each slot 0176 /// number will be reached in every graph. 0177 /// 0178 /// ~~~{.cpp} 0179 /// ROOT::RDataFrame df1("tree1", "file1.root"); 0180 /// auto r1 = df1.Histo1D("var1"); 0181 /// 0182 /// ROOT::RDataFrame df2("tree2", "file2.root"); 0183 /// auto r2 = df2.Sum("var2"); 0184 /// 0185 /// // RResultPtr -> RResultHandle conversion is automatic 0186 /// ROOT::RDF::RunGraphs({r1, r2}); 0187 /// ~~~ 0188 // clang-format on 0189 unsigned int RunGraphs(std::vector<RResultHandle> handles); 0190 0191 namespace Experimental { 0192 0193 /// \brief Produce all required systematic variations for the given result. 0194 /// \param[in] resPtr The result for which variations should be produced. 0195 /// \return A \ref ROOT::RDF::Experimental::RResultMap "RResultMap" object with full variation names as strings 0196 /// (e.g. "pt:down") and the corresponding varied results as values. 0197 /// 0198 /// A given input RResultPtr<T> produces a corresponding RResultMap<T> with a "nominal" 0199 /// key that will return a value identical to the one contained in the original RResultPtr. 0200 /// Other keys correspond to the varied values of this result, one for each variation 0201 /// that the result depends on. 0202 /// VariationsFor does not trigger the event loop. The event loop is only triggered 0203 /// upon first access to a valid key, similarly to what happens with RResultPtr. 0204 /// 0205 /// If the result does not depend, directly or indirectly, from any registered systematic variation, the 0206 /// returned RResultMap will contain only the "nominal" key. 0207 /// 0208 /// See RDataFrame's \ref ROOT::RDF::RInterface::Vary() "Vary" method for more information and example usages. 0209 /// 0210 /// \note Currently, producing variations for the results of \ref ROOT::RDF::RInterface::Display() "Display", 0211 /// \ref ROOT::RDF::RInterface::Report() "Report" and \ref ROOT::RDF::RInterface::Snapshot() "Snapshot" 0212 /// actions is not supported. 0213 // 0214 // An overview of how systematic variations work internally. Given N variations (including the nominal): 0215 // 0216 // RResultMap owns RVariedAction 0217 // N results N action helpers 0218 // N previous filters 0219 // N*#input_cols column readers 0220 // 0221 // ...and each RFilter and RDefine knows for what universe it needs to construct column readers ("nominal" by default). 0222 template <typename T> 0223 RResultMap<T> VariationsFor(RResultPtr<T> resPtr) 0224 { 0225 using SnapshotResult_t = ROOT::RDF::RInterface<ROOT::Detail::RDF::RLoopManager, void>; 0226 static_assert(!std::is_same_v<T, SnapshotResult_t>, 0227 "Snapshot with variations can only be enabled via RSnapshotOptions."); 0228 0229 R__ASSERT(resPtr != nullptr && "Calling VariationsFor on an empty RResultPtr"); 0230 0231 // populate parts of the computation graph for which we only have "empty shells", e.g. RJittedActions and 0232 // RJittedFilters 0233 resPtr.fLoopManager->Jit(); 0234 0235 std::unique_ptr<RDFInternal::RActionBase> variedAction; 0236 std::vector<std::shared_ptr<T>> variedResults; 0237 0238 std::shared_ptr<RDFInternal::RActionBase> nominalAction = resPtr.fActionPtr; 0239 std::vector<std::string> variations = nominalAction->GetVariations(); 0240 const auto nVariations = variations.size(); 0241 0242 if (nVariations > 0) { 0243 // clone the result once for each variation 0244 variedResults.reserve(nVariations); 0245 for (auto i = 0u; i < nVariations; ++i){ 0246 // implicitly assuming that T is copiable: this should be the case 0247 // for all result types in use, as they are copied for each slot 0248 variedResults.emplace_back(new T{*resPtr.fObjPtr}); 0249 0250 // Check if the result's type T inherits from TNamed 0251 if constexpr (std::is_base_of<TNamed, T>::value) { 0252 // Get the current variation name 0253 std::string variationName = variations[i]; 0254 // Replace the colon with an underscore 0255 std::replace(variationName.begin(), variationName.end(), ':', '_'); 0256 // Get a pointer to the corresponding varied result 0257 auto &variedResult = variedResults.back(); 0258 // Set the varied result's name to NOMINALNAME_VARIATIONAME 0259 variedResult->SetName((std::string(variedResult->GetName()) + "_" + variationName).c_str()); 0260 } 0261 } 0262 0263 std::vector<void *> typeErasedResults; 0264 typeErasedResults.reserve(variedResults.size()); 0265 for (auto &res : variedResults) 0266 typeErasedResults.emplace_back(&res); 0267 0268 // Create the RVariedAction and inject it in the computation graph. 0269 // This recursively creates all the required varied column readers and upstream nodes of the computation graph. 0270 variedAction = nominalAction->MakeVariedAction(std::move(typeErasedResults)); 0271 } 0272 0273 return RDFInternal::MakeResultMap<T>(resPtr.fObjPtr, std::move(variedResults), std::move(variations), 0274 *resPtr.fLoopManager, std::move(nominalAction), std::move(variedAction)); 0275 } 0276 0277 /// \brief Add ProgressBar to a ROOT::RDF::RNode 0278 /// \param[in] df RDataFrame node at which ProgressBar is called. 0279 /// 0280 /// The ProgressBar can be added not only at the RDataFrame head node, but also at any any computational node, 0281 /// such as Filter or Define. 0282 /// ###Example usage: 0283 /// ~~~{.cpp} 0284 /// ROOT::RDataFrame df("tree", "file.root"); 0285 /// auto df_1 = ROOT::RDF::RNode(df.Filter("x>1")); 0286 /// ROOT::RDF::Experimental::AddProgressBar(df_1); 0287 /// ~~~ 0288 void AddProgressBar(ROOT::RDF::RNode df); 0289 0290 /// \brief Add ProgressBar to an RDataFrame 0291 /// \param[in] df RDataFrame for which ProgressBar is called. 0292 /// 0293 /// This function adds a ProgressBar to display the event statistics in the terminal every 0294 /// \b m events and every \b n seconds, including elapsed time, currently processed file, 0295 /// currently processed events, the rate of event processing 0296 /// and an estimated remaining time (per file being processed). 0297 /// ProgressBar should be added after the dataframe object (df) is created first: 0298 /// ~~~{.cpp} 0299 /// ROOT::RDataFrame df("tree", "file.root"); 0300 /// ROOT::RDF::Experimental::AddProgressBar(df); 0301 /// ~~~ 0302 /// For more details see ROOT::RDF::Experimental::ProgressHelper Class. 0303 void AddProgressBar(ROOT::RDataFrame df); 0304 0305 /// @brief Set the number of threads sharing one TH3 in RDataFrame. 0306 /// When RDF runs multi-threaded, each thread typically clones every histogram in the computation graph. 0307 /// If this consumes too much memory, N threads can share one clone. 0308 /// Higher values might slow down RDF because they lead to higher contention on the TH3Ds, but save memory. 0309 /// Lower values run faster with less contention at the cost of higher memory usage. 0310 /// @param nThread Number of threads that share a TH3D. 0311 void ThreadsPerTH3(unsigned int nThread = 1); 0312 0313 class ProgressBarAction; 0314 0315 /// RDF progress helper. 0316 /// This class provides callback functions to the RDataFrame. The event statistics 0317 /// (including elapsed time, currently processed file, currently processed events, the rate of event processing 0318 /// and an estimated remaining time (per file being processed)) 0319 /// are recorded and printed in the terminal every m events and every n seconds. 0320 /// ProgressHelper::operator()(unsigned int, T&) is thread safe, and can be used as a callback in MT mode. 0321 /// ProgressBar should be added after creating the dataframe object (df): 0322 /// ~~~{.cpp} 0323 /// ROOT::RDataFrame df("tree", "file.root"); 0324 /// ROOT::RDF::Experimental::AddProgressBar(df); 0325 /// ~~~ 0326 /// alternatively RDataFrame can be cast to an RNode first giving it more flexibility. 0327 /// For example, it can be called at any computational node, such as Filter or Define, not only the head node, 0328 /// with no change to the ProgressBar function itself: 0329 /// ~~~{.cpp} 0330 /// ROOT::RDataFrame df("tree", "file.root"); 0331 /// auto df_1 = ROOT::RDF::RNode(df.Filter("x>1")); 0332 /// ROOT::RDF::Experimental::AddProgressBar(df_1); 0333 /// ~~~ 0334 class ProgressHelper { 0335 private: 0336 double EvtPerSec() const; 0337 std::pair<std::size_t, std::chrono::seconds> RecordEvtCountAndTime(); 0338 void PrintStats(std::ostream &stream, std::size_t currentEventCount, std::chrono::seconds totalElapsedSeconds) const; 0339 void PrintStatsFinal(std::ostream &stream, std::chrono::seconds totalElapsedSeconds) const; 0340 void PrintProgressBar(std::ostream &stream, std::size_t currentEventCount) const; 0341 0342 std::chrono::time_point<std::chrono::system_clock> fBeginTime = std::chrono::system_clock::now(); 0343 std::chrono::time_point<std::chrono::system_clock> fLastPrintTime = fBeginTime; 0344 std::chrono::seconds fPrintInterval{1}; 0345 0346 std::atomic<std::size_t> fProcessedEvents{0}; 0347 std::size_t fLastProcessedEvents{0}; 0348 std::size_t fIncrement; 0349 0350 mutable std::mutex fSampleNameToEventEntriesMutex; 0351 std::map<std::string, ULong64_t> fSampleNameToEventEntries; // Filename, events in the file 0352 0353 std::array<double, 20> fEventsPerSecondStatistics; 0354 std::size_t fEventsPerSecondStatisticsIndex{0}; 0355 0356 unsigned int fBarWidth; 0357 unsigned int fTotalFiles; 0358 0359 std::mutex fPrintMutex; 0360 bool fIsTTY; 0361 bool fUseShellColours; 0362 0363 std::shared_ptr<TTree> fTree{nullptr}; 0364 0365 public: 0366 /// Create a progress helper. 0367 /// \param increment RDF callbacks are called every `n` events. Pass this `n` here. 0368 /// \param totalFiles read total number of files in the RDF. 0369 /// \param progressBarWidth Number of characters the progress bar will occupy. 0370 /// \param printInterval Update every stats every `n` seconds. 0371 /// \param useColors Use shell colour codes to colour the output. Automatically disabled when 0372 /// we are not writing to a tty. 0373 ProgressHelper(std::size_t increment, unsigned int totalFiles = 1, unsigned int progressBarWidth = 40, 0374 unsigned int printInterval = 1, bool useColors = true); 0375 0376 ~ProgressHelper() = default; 0377 0378 friend class ProgressBarAction; 0379 0380 /// Register a new sample for completion statistics. 0381 /// \see ROOT::RDF::RInterface::DefinePerSample(). 0382 /// The *id.AsString()* refers to the name of the currently processed file. 0383 /// The idea is to populate the event entries in the *fSampleNameToEventEntries* map 0384 /// by selecting the greater of the two values: 0385 /// *id.EntryRange().second* which is the upper event entry range of the processed sample 0386 /// and the current value of the event entries in the *fSampleNameToEventEntries* map. 0387 /// In the single threaded case, the two numbers are the same as the entry range corresponds 0388 /// to the number of events in an individual file (each sample is simply a single file). 0389 /// In the multithreaded case, the idea is to accumulate the higher event entry value until 0390 /// the total number of events in a given file is reached. 0391 void registerNewSample(unsigned int /*slot*/, const ROOT::RDF::RSampleInfo &id) 0392 { 0393 std::lock_guard<std::mutex> lock(fSampleNameToEventEntriesMutex); 0394 fSampleNameToEventEntries[id.AsString()] = 0395 std::max(id.EntryRange().second, fSampleNameToEventEntries[id.AsString()]); 0396 } 0397 0398 /// Thread-safe callback for RDataFrame. 0399 /// It will record elapsed times and event statistics, and print a progress bar every n seconds (set by the 0400 /// fPrintInterval). \param slot Ignored. \param value Ignored. 0401 template <typename T> 0402 void operator()(unsigned int /*slot*/, T &value) 0403 { 0404 operator()(value); 0405 } 0406 // clang-format off 0407 /// Thread-safe callback for RDataFrame. 0408 /// It will record elapsed times and event statistics, and print a progress bar every n seconds (set by the fPrintInterval). 0409 /// \param value Ignored. 0410 // clang-format on 0411 template <typename T> 0412 void operator()(T & /*value*/) 0413 { 0414 using namespace std::chrono; 0415 // *************************************************** 0416 // Warning: Here, everything needs to be thread safe: 0417 // *************************************************** 0418 fProcessedEvents += fIncrement; 0419 0420 // We only print every n seconds. 0421 if (duration_cast<seconds>(system_clock::now() - fLastPrintTime) < fPrintInterval) { 0422 return; 0423 } 0424 0425 // *************************************************** 0426 // Protected by lock from here: 0427 // *************************************************** 0428 if (!fPrintMutex.try_lock()) 0429 return; 0430 std::lock_guard<std::mutex> lockGuard(fPrintMutex, std::adopt_lock); 0431 0432 std::size_t eventCount; 0433 seconds elapsedSeconds; 0434 std::tie(eventCount, elapsedSeconds) = RecordEvtCountAndTime(); 0435 0436 if (fIsTTY) 0437 std::cout << "\r"; 0438 0439 PrintProgressBar(std::cout, eventCount); 0440 PrintStats(std::cout, eventCount, elapsedSeconds); 0441 0442 if (fIsTTY) 0443 std::cout << std::flush; 0444 else 0445 std::cout << std::endl; 0446 } 0447 0448 std::size_t ComputeNEventsSoFar() const 0449 { 0450 std::unique_lock<std::mutex> lock(fSampleNameToEventEntriesMutex); 0451 std::size_t result = 0; 0452 for (const auto &item : fSampleNameToEventEntries) 0453 result += item.second; 0454 return result; 0455 } 0456 0457 unsigned int ComputeCurrentFileIdx() const 0458 { 0459 std::unique_lock<std::mutex> lock(fSampleNameToEventEntriesMutex); 0460 return fSampleNameToEventEntries.size(); 0461 } 0462 }; 0463 } // namespace Experimental 0464 } // namespace RDF 0465 } // namespace ROOT 0466 #endif
| [ Source navigation ] | [ Diff markup ] | [ Identifier search ] | [ general search ] |
|
This page was automatically generated by the 2.3.7 LXR engine. The LXR team |
|