|
||||
File indexing completed on 2025-01-18 10:10:38
0001 // Author: Enrico Guiraud, Danilo Piparo CERN 02/2018 0002 0003 /************************************************************************* 0004 * Copyright (C) 1995-2018, Rene Brun and Fons Rademakers. * 0005 * All rights reserved. * 0006 * * 0007 * For the licensing terms see $ROOTSYS/LICENSE. * 0008 * For the list of contributors see $ROOTSYS/README/CREDITS. * 0009 *************************************************************************/ 0010 0011 // This header contains helper free functions that slim down RDataFrame's programming model 0012 0013 #ifndef ROOT_RDF_HELPERS 0014 #define ROOT_RDF_HELPERS 0015 0016 #include <ROOT/RDF/GraphUtils.hxx> 0017 #include <ROOT/RDF/RActionBase.hxx> 0018 #include <ROOT/RDF/RResultMap.hxx> 0019 #include <ROOT/RResultHandle.hxx> // users of RunGraphs might rely on this transitive include 0020 #include <ROOT/TypeTraits.hxx> 0021 0022 #include <array> 0023 #include <chrono> 0024 #include <fstream> 0025 #include <functional> 0026 #include <map> 0027 #include <memory> 0028 #include <mutex> 0029 #include <type_traits> 0030 #include <utility> // std::index_sequence 0031 #include <vector> 0032 0033 namespace ROOT { 0034 namespace Internal { 0035 namespace RDF { 0036 template <typename... ArgTypes, typename F> 0037 std::function<bool(ArgTypes...)> NotHelper(ROOT::TypeTraits::TypeList<ArgTypes...>, F &&f) 0038 { 0039 return std::function<bool(ArgTypes...)>([=](ArgTypes... args) mutable { return !f(args...); }); 0040 } 0041 0042 template <typename... ArgTypes, typename Ret, typename... Args> 0043 std::function<bool(ArgTypes...)> NotHelper(ROOT::TypeTraits::TypeList<ArgTypes...>, Ret (*f)(Args...)) 0044 { 0045 return std::function<bool(ArgTypes...)>([=](ArgTypes... args) mutable { return !f(args...); }); 0046 } 0047 0048 template <typename I, typename T, typename F> 0049 class PassAsVecHelper; 0050 0051 template <std::size_t... N, typename T, typename F> 0052 class PassAsVecHelper<std::index_sequence<N...>, T, F> { 0053 template <std::size_t Idx> 0054 using AlwaysT = T; 0055 std::decay_t<F> fFunc; 0056 0057 public: 0058 PassAsVecHelper(F &&f) : fFunc(std::forward<F>(f)) {} 0059 auto operator()(AlwaysT<N>... args) -> decltype(fFunc({args...})) { return fFunc({args...}); } 0060 }; 0061 0062 template <std::size_t N, typename T, typename F> 0063 auto PassAsVec(F &&f) -> PassAsVecHelper<std::make_index_sequence<N>, T, F> 0064 { 0065 return PassAsVecHelper<std::make_index_sequence<N>, T, F>(std::forward<F>(f)); 0066 } 0067 0068 } // namespace RDF 0069 } // namespace Internal 0070 0071 namespace RDF { 0072 namespace RDFInternal = ROOT::Internal::RDF; 0073 0074 // clang-format off 0075 /// Given a callable with signature bool(T1, T2, ...) return a callable with same signature that returns the negated result 0076 /// 0077 /// The callable must have one single non-template definition of operator(). This is a limitation with respect to 0078 /// std::not_fn, required for interoperability with RDataFrame. 0079 // clang-format on 0080 template <typename F, 0081 typename Args = typename ROOT::TypeTraits::CallableTraits<std::decay_t<F>>::arg_types_nodecay, 0082 typename Ret = typename ROOT::TypeTraits::CallableTraits<std::decay_t<F>>::ret_type> 0083 auto Not(F &&f) -> decltype(RDFInternal::NotHelper(Args(), std::forward<F>(f))) 0084 { 0085 static_assert(std::is_same<Ret, bool>::value, "RDF::Not requires a callable that returns a bool."); 0086 return RDFInternal::NotHelper(Args(), std::forward<F>(f)); 0087 } 0088 0089 // clang-format off 0090 /// PassAsVec is a callable generator that allows passing N variables of type T to a function as a single collection. 0091 /// 0092 /// PassAsVec<N, T>(func) returns a callable that takes N arguments of type T, passes them down to function `func` as 0093 /// an initializer list `{t1, t2, t3,..., tN}` and returns whatever f({t1, t2, t3, ..., tN}) returns. 0094 /// 0095 /// Note that for this to work with RDataFrame the type of all columns that the callable is applied to must be exactly T. 0096 /// Example usage together with RDataFrame ("varX" columns must all be `float` variables): 0097 /// \code 0098 /// bool myVecFunc(std::vector<float> args); 0099 /// df.Filter(PassAsVec<3, float>(myVecFunc), {"var1", "var2", "var3"}); 0100 /// \endcode 0101 // clang-format on 0102 template <std::size_t N, typename T, typename F> 0103 auto PassAsVec(F &&f) -> RDFInternal::PassAsVecHelper<std::make_index_sequence<N>, T, F> 0104 { 0105 return RDFInternal::PassAsVecHelper<std::make_index_sequence<N>, T, F>(std::forward<F>(f)); 0106 } 0107 0108 // clang-format off 0109 /// Create a graphviz representation of the dataframe computation graph, return it as a string. 0110 /// \param[in] node any node of the graph. Called on the head (first) node, it prints the entire graph. Otherwise, only the branch the node belongs to. 0111 /// 0112 /// The output can be displayed with a command akin to `dot -Tpng output.dot > output.png && open output.png`. 0113 /// 0114 /// Note that "hanging" Defines, i.e. Defines without downstream nodes, will not be displayed by SaveGraph as they are 0115 /// effectively optimized away from the computation graph. 0116 /// 0117 /// Note that SaveGraph is not thread-safe and must not be called concurrently from different threads. 0118 // clang-format on 0119 template <typename NodeType> 0120 std::string SaveGraph(NodeType node) 0121 { 0122 ROOT::Internal::RDF::GraphDrawing::GraphCreatorHelper helper; 0123 return helper.RepresentGraph(node); 0124 } 0125 0126 // clang-format off 0127 /// Create a graphviz representation of the dataframe computation graph, write it to the specified file. 0128 /// \param[in] node any node of the graph. Called on the head (first) node, it prints the entire graph. Otherwise, only the branch the node belongs to. 0129 /// \param[in] outputFile file where to save the representation. 0130 /// 0131 /// The output can be displayed with a command akin to `dot -Tpng output.dot > output.png && open output.png`. 0132 /// 0133 /// Note that "hanging" Defines, i.e. Defines without downstream nodes, will not be displayed by SaveGraph as they are 0134 /// effectively optimized away from the computation graph. 0135 /// 0136 /// Note that SaveGraph is not thread-safe and must not be called concurrently from different threads. 0137 // clang-format on 0138 template <typename NodeType> 0139 void SaveGraph(NodeType node, const std::string &outputFile) 0140 { 0141 ROOT::Internal::RDF::GraphDrawing::GraphCreatorHelper helper; 0142 std::string dotGraph = helper.RepresentGraph(node); 0143 0144 std::ofstream out(outputFile); 0145 if (!out.is_open()) { 0146 throw std::runtime_error("Could not open output file \"" + outputFile + "\"for reading"); 0147 } 0148 0149 out << dotGraph; 0150 out.close(); 0151 } 0152 0153 // clang-format off 0154 /// Cast a RDataFrame node to the common type ROOT::RDF::RNode 0155 /// \param[in] node Any node of a RDataFrame graph 0156 // clang-format on 0157 template <typename NodeType> 0158 RNode AsRNode(NodeType node) 0159 { 0160 return node; 0161 } 0162 0163 // clang-format off 0164 /// Trigger the event loop of multiple RDataFrames concurrently 0165 /// \param[in] handles A vector of RResultHandles 0166 /// \return The number of distinct computation graphs that have been processed 0167 /// 0168 /// This function triggers the event loop of all computation graphs which relate to the 0169 /// given RResultHandles. The advantage compared to running the event loop implicitly by accessing the 0170 /// RResultPtr is that the event loops will run concurrently. Therefore, the overall 0171 /// computation of all results is generally more efficient. 0172 /// It should be noted that user-defined operations (e.g., Filters and Defines) of the different RDataFrame graphs are assumed to be safe to call concurrently. 0173 /// 0174 /// ~~~{.cpp} 0175 /// ROOT::RDataFrame df1("tree1", "file1.root"); 0176 /// auto r1 = df1.Histo1D("var1"); 0177 /// 0178 /// ROOT::RDataFrame df2("tree2", "file2.root"); 0179 /// auto r2 = df2.Sum("var2"); 0180 /// 0181 /// // RResultPtr -> RResultHandle conversion is automatic 0182 /// ROOT::RDF::RunGraphs({r1, r2}); 0183 /// ~~~ 0184 // clang-format on 0185 unsigned int RunGraphs(std::vector<RResultHandle> handles); 0186 0187 namespace Experimental { 0188 0189 /// \brief Produce all required systematic variations for the given result. 0190 /// \param[in] resPtr The result for which variations should be produced. 0191 /// \return A \ref ROOT::RDF::Experimental::RResultMap "RResultMap" object with full variation names as strings 0192 /// (e.g. "pt:down") and the corresponding varied results as values. 0193 /// 0194 /// A given input RResultPtr<T> produces a corresponding RResultMap<T> with a "nominal" 0195 /// key that will return a value identical to the one contained in the original RResultPtr. 0196 /// Other keys correspond to the varied values of this result, one for each variation 0197 /// that the result depends on. 0198 /// VariationsFor does not trigger the event loop. The event loop is only triggered 0199 /// upon first access to a valid key, similarly to what happens with RResultPtr. 0200 /// 0201 /// If the result does not depend, directly or indirectly, from any registered systematic variation, the 0202 /// returned RResultMap will contain only the "nominal" key. 0203 /// 0204 /// See RDataFrame's \ref ROOT::RDF::RInterface::Vary() "Vary" method for more information and example usages. 0205 /// 0206 /// \note Currently, producing variations for the results of \ref ROOT::RDF::RInterface::Display() "Display", 0207 /// \ref ROOT::RDF::RInterface::Report() "Report" and \ref ROOT::RDF::RInterface::Snapshot() "Snapshot" 0208 /// actions is not supported. 0209 // 0210 // An overview of how systematic variations work internally. Given N variations (including the nominal): 0211 // 0212 // RResultMap owns RVariedAction 0213 // N results N action helpers 0214 // N previous filters 0215 // N*#input_cols column readers 0216 // 0217 // ...and each RFilter and RDefine knows for what universe it needs to construct column readers ("nominal" by default). 0218 template <typename T> 0219 RResultMap<T> VariationsFor(RResultPtr<T> resPtr) 0220 { 0221 R__ASSERT(resPtr != nullptr && "Calling VariationsFor on an empty RResultPtr"); 0222 0223 // populate parts of the computation graph for which we only have "empty shells", e.g. RJittedActions and 0224 // RJittedFilters 0225 resPtr.fLoopManager->Jit(); 0226 0227 std::unique_ptr<RDFInternal::RActionBase> variedAction; 0228 std::vector<std::shared_ptr<T>> variedResults; 0229 0230 std::shared_ptr<RDFInternal::RActionBase> nominalAction = resPtr.fActionPtr; 0231 std::vector<std::string> variations = nominalAction->GetVariations(); 0232 const auto nVariations = variations.size(); 0233 0234 if (nVariations > 0) { 0235 // clone the result once for each variation 0236 variedResults.reserve(nVariations); 0237 for (auto i = 0u; i < nVariations; ++i){ 0238 // implicitly assuming that T is copiable: this should be the case 0239 // for all result types in use, as they are copied for each slot 0240 variedResults.emplace_back(new T{*resPtr.fObjPtr}); 0241 0242 // Check if the result's type T inherits from TNamed 0243 if constexpr (std::is_base_of<TNamed, T>::value) { 0244 // Get the current variation name 0245 std::string variationName = variations[i]; 0246 // Replace the colon with an underscore 0247 std::replace(variationName.begin(), variationName.end(), ':', '_'); 0248 // Get a pointer to the corresponding varied result 0249 auto &variedResult = variedResults.back(); 0250 // Set the varied result's name to NOMINALNAME_VARIATIONAME 0251 variedResult->SetName((std::string(variedResult->GetName()) + "_" + variationName).c_str()); 0252 } 0253 } 0254 0255 std::vector<void *> typeErasedResults; 0256 typeErasedResults.reserve(variedResults.size()); 0257 for (auto &res : variedResults) 0258 typeErasedResults.emplace_back(&res); 0259 0260 // Create the RVariedAction and inject it in the computation graph. 0261 // This recursively creates all the required varied column readers and upstream nodes of the computation graph. 0262 variedAction = nominalAction->MakeVariedAction(std::move(typeErasedResults)); 0263 } 0264 0265 return RDFInternal::MakeResultMap<T>(resPtr.fObjPtr, std::move(variedResults), std::move(variations), 0266 *resPtr.fLoopManager, std::move(nominalAction), std::move(variedAction)); 0267 } 0268 0269 using SnapshotPtr_t = ROOT::RDF::RResultPtr<ROOT::RDF::RInterface<ROOT::Detail::RDF::RLoopManager, void>>; 0270 SnapshotPtr_t VariationsFor(SnapshotPtr_t resPtr); 0271 0272 /// \brief Add ProgressBar to a ROOT::RDF::RNode 0273 /// \param[in] df RDataFrame node at which ProgressBar is called. 0274 /// 0275 /// The ProgressBar can be added not only at the RDataFrame head node, but also at any any computational node, 0276 /// such as Filter or Define. 0277 /// ###Example usage: 0278 /// ~~~{.cpp} 0279 /// ROOT::RDataFrame df("tree", "file.root"); 0280 /// auto df_1 = ROOT::RDF::RNode(df.Filter("x>1")); 0281 /// ROOT::RDF::Experimental::AddProgressBar(df_1); 0282 /// ~~~ 0283 void AddProgressBar(ROOT::RDF::RNode df); 0284 0285 /// \brief Add ProgressBar to an RDataFrame 0286 /// \param[in] df RDataFrame for which ProgressBar is called. 0287 /// 0288 /// This function adds a ProgressBar to display the event statistics in the terminal every 0289 /// \b m events and every \b n seconds, including elapsed time, currently processed file, 0290 /// currently processed events, the rate of event processing 0291 /// and an estimated remaining time (per file being processed). 0292 /// ProgressBar should be added after the dataframe object (df) is created first: 0293 /// ~~~{.cpp} 0294 /// ROOT::RDataFrame df("tree", "file.root"); 0295 /// ROOT::RDF::Experimental::AddProgressBar(df); 0296 /// ~~~ 0297 /// For more details see ROOT::RDF::Experimental::ProgressHelper Class. 0298 void AddProgressBar(ROOT::RDataFrame df); 0299 0300 class ProgressBarAction; 0301 0302 /// RDF progress helper. 0303 /// This class provides callback functions to the RDataFrame. The event statistics 0304 /// (including elapsed time, currently processed file, currently processed events, the rate of event processing 0305 /// and an estimated remaining time (per file being processed)) 0306 /// are recorded and printed in the terminal every m events and every n seconds. 0307 /// ProgressHelper::operator()(unsigned int, T&) is thread safe, and can be used as a callback in MT mode. 0308 /// ProgressBar should be added after creating the dataframe object (df): 0309 /// ~~~{.cpp} 0310 /// ROOT::RDataFrame df("tree", "file.root"); 0311 /// ROOT::RDF::Experimental::AddProgressBar(df); 0312 /// ~~~ 0313 /// alternatively RDataFrame can be cast to an RNode first giving it more flexibility. 0314 /// For example, it can be called at any computational node, such as Filter or Define, not only the head node, 0315 /// with no change to the ProgressBar function itself: 0316 /// ~~~{.cpp} 0317 /// ROOT::RDataFrame df("tree", "file.root"); 0318 /// auto df_1 = ROOT::RDF::RNode(df.Filter("x>1")); 0319 /// ROOT::RDF::Experimental::AddProgressBar(df_1); 0320 /// ~~~ 0321 class ProgressHelper { 0322 private: 0323 double EvtPerSec() const; 0324 std::pair<std::size_t, std::chrono::seconds> RecordEvtCountAndTime(); 0325 void PrintStats(std::ostream &stream, std::size_t currentEventCount, std::chrono::seconds totalElapsedSeconds) const; 0326 void PrintStatsFinal(std::ostream &stream, std::chrono::seconds totalElapsedSeconds) const; 0327 void PrintProgressBar(std::ostream &stream, std::size_t currentEventCount) const; 0328 0329 std::chrono::time_point<std::chrono::system_clock> fBeginTime = std::chrono::system_clock::now(); 0330 std::chrono::time_point<std::chrono::system_clock> fLastPrintTime = fBeginTime; 0331 std::chrono::seconds fPrintInterval{1}; 0332 0333 std::atomic<std::size_t> fProcessedEvents{0}; 0334 std::size_t fLastProcessedEvents{0}; 0335 std::size_t fIncrement; 0336 0337 mutable std::mutex fSampleNameToEventEntriesMutex; 0338 std::map<std::string, ULong64_t> fSampleNameToEventEntries; // Filename, events in the file 0339 0340 std::array<double, 20> fEventsPerSecondStatistics; 0341 std::size_t fEventsPerSecondStatisticsIndex{0}; 0342 0343 unsigned int fBarWidth; 0344 unsigned int fTotalFiles; 0345 0346 std::mutex fPrintMutex; 0347 bool fIsTTY; 0348 bool fUseShellColours; 0349 0350 std::shared_ptr<TTree> fTree{nullptr}; 0351 0352 public: 0353 /// Create a progress helper. 0354 /// \param increment RDF callbacks are called every `n` events. Pass this `n` here. 0355 /// \param totalFiles read total number of files in the RDF. 0356 /// \param progressBarWidth Number of characters the progress bar will occupy. 0357 /// \param printInterval Update every stats every `n` seconds. 0358 /// \param useColors Use shell colour codes to colour the output. Automatically disabled when 0359 /// we are not writing to a tty. 0360 ProgressHelper(std::size_t increment, unsigned int totalFiles = 1, unsigned int progressBarWidth = 40, 0361 unsigned int printInterval = 1, bool useColors = true); 0362 0363 ~ProgressHelper() = default; 0364 0365 friend class ProgressBarAction; 0366 0367 /// Register a new sample for completion statistics. 0368 /// \see ROOT::RDF::RInterface::DefinePerSample(). 0369 /// The *id.AsString()* refers to the name of the currently processed file. 0370 /// The idea is to populate the event entries in the *fSampleNameToEventEntries* map 0371 /// by selecting the greater of the two values: 0372 /// *id.EntryRange().second* which is the upper event entry range of the processed sample 0373 /// and the current value of the event entries in the *fSampleNameToEventEntries* map. 0374 /// In the single threaded case, the two numbers are the same as the entry range corresponds 0375 /// to the number of events in an individual file (each sample is simply a single file). 0376 /// In the multithreaded case, the idea is to accumulate the higher event entry value until 0377 /// the total number of events in a given file is reached. 0378 void registerNewSample(unsigned int /*slot*/, const ROOT::RDF::RSampleInfo &id) 0379 { 0380 std::lock_guard<std::mutex> lock(fSampleNameToEventEntriesMutex); 0381 fSampleNameToEventEntries[id.AsString()] = 0382 std::max(id.EntryRange().second, fSampleNameToEventEntries[id.AsString()]); 0383 } 0384 0385 /// Thread-safe callback for RDataFrame. 0386 /// It will record elapsed times and event statistics, and print a progress bar every n seconds (set by the 0387 /// fPrintInterval). \param slot Ignored. \param value Ignored. 0388 template <typename T> 0389 void operator()(unsigned int /*slot*/, T &value) 0390 { 0391 operator()(value); 0392 } 0393 // clang-format off 0394 /// Thread-safe callback for RDataFrame. 0395 /// It will record elapsed times and event statistics, and print a progress bar every n seconds (set by the fPrintInterval). 0396 /// \param value Ignored. 0397 // clang-format on 0398 template <typename T> 0399 void operator()(T & /*value*/) 0400 { 0401 using namespace std::chrono; 0402 // *************************************************** 0403 // Warning: Here, everything needs to be thread safe: 0404 // *************************************************** 0405 fProcessedEvents += fIncrement; 0406 0407 // We only print every n seconds. 0408 if (duration_cast<seconds>(system_clock::now() - fLastPrintTime) < fPrintInterval) { 0409 return; 0410 } 0411 0412 // *************************************************** 0413 // Protected by lock from here: 0414 // *************************************************** 0415 if (!fPrintMutex.try_lock()) 0416 return; 0417 std::lock_guard<std::mutex> lockGuard(fPrintMutex, std::adopt_lock); 0418 0419 std::size_t eventCount; 0420 seconds elapsedSeconds; 0421 std::tie(eventCount, elapsedSeconds) = RecordEvtCountAndTime(); 0422 0423 if (fIsTTY) 0424 std::cout << "\r"; 0425 0426 PrintProgressBar(std::cout, eventCount); 0427 PrintStats(std::cout, eventCount, elapsedSeconds); 0428 0429 if (fIsTTY) 0430 std::cout << std::flush; 0431 else 0432 std::cout << std::endl; 0433 } 0434 0435 std::size_t ComputeNEventsSoFar() const 0436 { 0437 std::unique_lock<std::mutex> lock(fSampleNameToEventEntriesMutex); 0438 std::size_t result = 0; 0439 for (const auto &item : fSampleNameToEventEntries) 0440 result += item.second; 0441 return result; 0442 } 0443 0444 unsigned int ComputeCurrentFileIdx() const 0445 { 0446 std::unique_lock<std::mutex> lock(fSampleNameToEventEntriesMutex); 0447 return fSampleNameToEventEntries.size(); 0448 } 0449 }; 0450 } // namespace Experimental 0451 } // namespace RDF 0452 } // namespace ROOT 0453 #endif
[ Source navigation ] | [ Diff markup ] | [ Identifier search ] | [ general search ] |
This page was automatically generated by the 2.3.7 LXR engine. The LXR team |