Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-01-18 10:10:38

0001 // Author: Enrico Guiraud, Danilo Piparo CERN  02/2018
0002 
0003 /*************************************************************************
0004  * Copyright (C) 1995-2018, Rene Brun and Fons Rademakers.               *
0005  * All rights reserved.                                                  *
0006  *                                                                       *
0007  * For the licensing terms see $ROOTSYS/LICENSE.                         *
0008  * For the list of contributors see $ROOTSYS/README/CREDITS.             *
0009  *************************************************************************/
0010 
0011 // This header contains helper free functions that slim down RDataFrame's programming model
0012 
0013 #ifndef ROOT_RDF_HELPERS
0014 #define ROOT_RDF_HELPERS
0015 
0016 #include <ROOT/RDF/GraphUtils.hxx>
0017 #include <ROOT/RDF/RActionBase.hxx>
0018 #include <ROOT/RDF/RResultMap.hxx>
0019 #include <ROOT/RResultHandle.hxx> // users of RunGraphs might rely on this transitive include
0020 #include <ROOT/TypeTraits.hxx>
0021 
#include <algorithm> // std::max, std::replace
#include <array>
#include <atomic>
#include <chrono>
#include <fstream>
#include <functional>
#include <iostream> // std::cout
#include <map>
#include <memory>
#include <mutex>
#include <stdexcept> // std::runtime_error
#include <string>
#include <tuple> // std::tie
#include <type_traits>
#include <utility> // std::index_sequence
#include <vector>
0032 
0033 namespace ROOT {
0034 namespace Internal {
0035 namespace RDF {
0036 template <typename... ArgTypes, typename F>
0037 std::function<bool(ArgTypes...)> NotHelper(ROOT::TypeTraits::TypeList<ArgTypes...>, F &&f)
0038 {
0039    return std::function<bool(ArgTypes...)>([=](ArgTypes... args) mutable { return !f(args...); });
0040 }
0041 
0042 template <typename... ArgTypes, typename Ret, typename... Args>
0043 std::function<bool(ArgTypes...)> NotHelper(ROOT::TypeTraits::TypeList<ArgTypes...>, Ret (*f)(Args...))
0044 {
0045    return std::function<bool(ArgTypes...)>([=](ArgTypes... args) mutable { return !f(args...); });
0046 }
0047 
0048 template <typename I, typename T, typename F>
0049 class PassAsVecHelper;
0050 
0051 template <std::size_t... N, typename T, typename F>
0052 class PassAsVecHelper<std::index_sequence<N...>, T, F> {
0053    template <std::size_t Idx>
0054    using AlwaysT = T;
0055    std::decay_t<F> fFunc;
0056 
0057 public:
0058    PassAsVecHelper(F &&f) : fFunc(std::forward<F>(f)) {}
0059    auto operator()(AlwaysT<N>... args) -> decltype(fFunc({args...})) { return fFunc({args...}); }
0060 };
0061 
0062 template <std::size_t N, typename T, typename F>
0063 auto PassAsVec(F &&f) -> PassAsVecHelper<std::make_index_sequence<N>, T, F>
0064 {
0065    return PassAsVecHelper<std::make_index_sequence<N>, T, F>(std::forward<F>(f));
0066 }
0067 
0068 } // namespace RDF
0069 } // namespace Internal
0070 
0071 namespace RDF {
0072 namespace RDFInternal = ROOT::Internal::RDF;
0073 
0074 // clang-format off
0075 /// Given a callable with signature bool(T1, T2, ...) return a callable with same signature that returns the negated result
0076 ///
0077 /// The callable must have one single non-template definition of operator(). This is a limitation with respect to
0078 /// std::not_fn, required for interoperability with RDataFrame.
0079 // clang-format on
0080 template <typename F,
0081           typename Args = typename ROOT::TypeTraits::CallableTraits<std::decay_t<F>>::arg_types_nodecay,
0082           typename Ret = typename ROOT::TypeTraits::CallableTraits<std::decay_t<F>>::ret_type>
0083 auto Not(F &&f) -> decltype(RDFInternal::NotHelper(Args(), std::forward<F>(f)))
0084 {
0085    static_assert(std::is_same<Ret, bool>::value, "RDF::Not requires a callable that returns a bool.");
0086    return RDFInternal::NotHelper(Args(), std::forward<F>(f));
0087 }
0088 
0089 // clang-format off
0090 /// PassAsVec is a callable generator that allows passing N variables of type T to a function as a single collection.
0091 ///
0092 /// PassAsVec<N, T>(func) returns a callable that takes N arguments of type T, passes them down to function `func` as
0093 /// an initializer list `{t1, t2, t3,..., tN}` and returns whatever f({t1, t2, t3, ..., tN}) returns.
0094 ///
0095 /// Note that for this to work with RDataFrame the type of all columns that the callable is applied to must be exactly T.
0096 /// Example usage together with RDataFrame ("varX" columns must all be `float` variables):
0097 /// \code
0098 /// bool myVecFunc(std::vector<float> args);
0099 /// df.Filter(PassAsVec<3, float>(myVecFunc), {"var1", "var2", "var3"});
0100 /// \endcode
0101 // clang-format on
0102 template <std::size_t N, typename T, typename F>
0103 auto PassAsVec(F &&f) -> RDFInternal::PassAsVecHelper<std::make_index_sequence<N>, T, F>
0104 {
0105    return RDFInternal::PassAsVecHelper<std::make_index_sequence<N>, T, F>(std::forward<F>(f));
0106 }
0107 
0108 // clang-format off
0109 /// Create a graphviz representation of the dataframe computation graph, return it as a string.
0110 /// \param[in] node any node of the graph. Called on the head (first) node, it prints the entire graph. Otherwise, only the branch the node belongs to.
0111 ///
0112 /// The output can be displayed with a command akin to `dot -Tpng output.dot > output.png && open output.png`.
0113 ///
0114 /// Note that "hanging" Defines, i.e. Defines without downstream nodes, will not be displayed by SaveGraph as they are
0115 /// effectively optimized away from the computation graph.
0116 ///
0117 /// Note that SaveGraph is not thread-safe and must not be called concurrently from different threads.
0118 // clang-format on
0119 template <typename NodeType>
0120 std::string SaveGraph(NodeType node)
0121 {
0122    ROOT::Internal::RDF::GraphDrawing::GraphCreatorHelper helper;
0123    return helper.RepresentGraph(node);
0124 }
0125 
0126 // clang-format off
0127 /// Create a graphviz representation of the dataframe computation graph, write it to the specified file.
0128 /// \param[in] node any node of the graph. Called on the head (first) node, it prints the entire graph. Otherwise, only the branch the node belongs to.
0129 /// \param[in] outputFile file where to save the representation.
0130 ///
0131 /// The output can be displayed with a command akin to `dot -Tpng output.dot > output.png && open output.png`.
0132 ///
0133 /// Note that "hanging" Defines, i.e. Defines without downstream nodes, will not be displayed by SaveGraph as they are
0134 /// effectively optimized away from the computation graph.
0135 ///
0136 /// Note that SaveGraph is not thread-safe and must not be called concurrently from different threads.
0137 // clang-format on
0138 template <typename NodeType>
0139 void SaveGraph(NodeType node, const std::string &outputFile)
0140 {
0141    ROOT::Internal::RDF::GraphDrawing::GraphCreatorHelper helper;
0142    std::string dotGraph = helper.RepresentGraph(node);
0143 
0144    std::ofstream out(outputFile);
0145    if (!out.is_open()) {
0146       throw std::runtime_error("Could not open output file \"" + outputFile  + "\"for reading");
0147    }
0148 
0149    out << dotGraph;
0150    out.close();
0151 }
0152 
0153 // clang-format off
0154 /// Cast a RDataFrame node to the common type ROOT::RDF::RNode
0155 /// \param[in] node Any node of a RDataFrame graph
0156 // clang-format on
0157 template <typename NodeType>
0158 RNode AsRNode(NodeType node)
0159 {
0160    return node;
0161 }
0162 
0163 // clang-format off
0164 /// Trigger the event loop of multiple RDataFrames concurrently
0165 /// \param[in] handles A vector of RResultHandles
0166 /// \return The number of distinct computation graphs that have been processed
0167 ///
0168 /// This function triggers the event loop of all computation graphs which relate to the
0169 /// given RResultHandles. The advantage compared to running the event loop implicitly by accessing the
0170 /// RResultPtr is that the event loops will run concurrently. Therefore, the overall
0171 /// computation of all results is generally more efficient.
0172 /// It should be noted that user-defined operations (e.g., Filters and Defines) of the different RDataFrame graphs are assumed to be safe to call concurrently.
0173 ///
0174 /// ~~~{.cpp}
0175 /// ROOT::RDataFrame df1("tree1", "file1.root");
0176 /// auto r1 = df1.Histo1D("var1");
0177 ///
0178 /// ROOT::RDataFrame df2("tree2", "file2.root");
0179 /// auto r2 = df2.Sum("var2");
0180 ///
0181 /// // RResultPtr -> RResultHandle conversion is automatic
0182 /// ROOT::RDF::RunGraphs({r1, r2});
0183 /// ~~~
0184 // clang-format on
0185 unsigned int RunGraphs(std::vector<RResultHandle> handles);
0186 
0187 namespace Experimental {
0188 
0189 /// \brief Produce all required systematic variations for the given result.
0190 /// \param[in] resPtr The result for which variations should be produced.
0191 /// \return A \ref ROOT::RDF::Experimental::RResultMap "RResultMap" object with full variation names as strings
0192 ///         (e.g. "pt:down") and the corresponding varied results as values.
0193 ///
0194 /// A given input RResultPtr<T> produces a corresponding RResultMap<T> with a "nominal"
0195 /// key that will return a value identical to the one contained in the original RResultPtr.
0196 /// Other keys correspond to the varied values of this result, one for each variation
0197 /// that the result depends on.
0198 /// VariationsFor does not trigger the event loop. The event loop is only triggered
0199 /// upon first access to a valid key, similarly to what happens with RResultPtr.
0200 ///
0201 /// If the result does not depend, directly or indirectly, from any registered systematic variation, the
0202 /// returned RResultMap will contain only the "nominal" key.
0203 ///
0204 /// See RDataFrame's \ref ROOT::RDF::RInterface::Vary() "Vary" method for more information and example usages.
0205 ///
0206 /// \note Currently, producing variations for the results of \ref ROOT::RDF::RInterface::Display() "Display",
0207 ///       \ref ROOT::RDF::RInterface::Report() "Report" and \ref ROOT::RDF::RInterface::Snapshot() "Snapshot"
0208 ///       actions is not supported.
0209 //
0210 // An overview of how systematic variations work internally. Given N variations (including the nominal):
0211 //
0212 // RResultMap   owns    RVariedAction
0213 //  N results            N action helpers
0214 //                       N previous filters
0215 //                       N*#input_cols column readers
0216 //
0217 // ...and each RFilter and RDefine knows for what universe it needs to construct column readers ("nominal" by default).
0218 template <typename T>
0219 RResultMap<T> VariationsFor(RResultPtr<T> resPtr)
0220 {
0221    R__ASSERT(resPtr != nullptr && "Calling VariationsFor on an empty RResultPtr");
0222 
0223    // populate parts of the computation graph for which we only have "empty shells", e.g. RJittedActions and
0224    // RJittedFilters
0225    resPtr.fLoopManager->Jit();
0226 
0227    std::unique_ptr<RDFInternal::RActionBase> variedAction;
0228    std::vector<std::shared_ptr<T>> variedResults;
0229 
0230    std::shared_ptr<RDFInternal::RActionBase> nominalAction = resPtr.fActionPtr;
0231    std::vector<std::string> variations = nominalAction->GetVariations();
0232    const auto nVariations = variations.size();
0233 
0234    if (nVariations > 0) {
0235       // clone the result once for each variation
0236       variedResults.reserve(nVariations);
0237       for (auto i = 0u; i < nVariations; ++i){
0238          // implicitly assuming that T is copiable: this should be the case
0239          // for all result types in use, as they are copied for each slot
0240          variedResults.emplace_back(new T{*resPtr.fObjPtr});
0241 
0242          // Check if the result's type T inherits from TNamed
0243          if constexpr (std::is_base_of<TNamed, T>::value) {
0244             // Get the current variation name
0245             std::string variationName = variations[i];
0246             // Replace the colon with an underscore
0247             std::replace(variationName.begin(), variationName.end(), ':', '_'); 
0248             // Get a pointer to the corresponding varied result
0249             auto &variedResult = variedResults.back();
0250             // Set the varied result's name to NOMINALNAME_VARIATIONAME
0251             variedResult->SetName((std::string(variedResult->GetName()) + "_" + variationName).c_str());
0252          }
0253       }
0254 
0255       std::vector<void *> typeErasedResults;
0256       typeErasedResults.reserve(variedResults.size());
0257       for (auto &res : variedResults)
0258          typeErasedResults.emplace_back(&res);
0259 
0260       // Create the RVariedAction and inject it in the computation graph.
0261       // This recursively creates all the required varied column readers and upstream nodes of the computation graph.
0262       variedAction = nominalAction->MakeVariedAction(std::move(typeErasedResults));
0263    }
0264 
0265    return RDFInternal::MakeResultMap<T>(resPtr.fObjPtr, std::move(variedResults), std::move(variations),
0266                                         *resPtr.fLoopManager, std::move(nominalAction), std::move(variedAction));
0267 }
0268 
0269 using SnapshotPtr_t = ROOT::RDF::RResultPtr<ROOT::RDF::RInterface<ROOT::Detail::RDF::RLoopManager, void>>;
0270 SnapshotPtr_t VariationsFor(SnapshotPtr_t resPtr);
0271 
0272 /// \brief Add ProgressBar to a ROOT::RDF::RNode
0273 /// \param[in] df RDataFrame node at which ProgressBar is called.
0274 ///
0275 /// The ProgressBar can be added not only at the RDataFrame head node, but also at any any computational node,
0276 /// such as Filter or Define.
0277 /// ###Example usage:
0278 /// ~~~{.cpp}
0279 /// ROOT::RDataFrame df("tree", "file.root");
0280 /// auto df_1 = ROOT::RDF::RNode(df.Filter("x>1"));
0281 /// ROOT::RDF::Experimental::AddProgressBar(df_1);
0282 /// ~~~
0283 void AddProgressBar(ROOT::RDF::RNode df);
0284 
0285 /// \brief Add ProgressBar to an RDataFrame
0286 /// \param[in] df RDataFrame for which ProgressBar is called.
0287 ///
0288 /// This function adds a ProgressBar to display the event statistics in the terminal every
0289 /// \b m events and every \b n seconds, including elapsed time, currently processed file,
0290 /// currently processed events, the rate of event processing
0291 /// and an estimated remaining time (per file being processed).
0292 /// ProgressBar should be added after the dataframe object (df) is created first:
0293 /// ~~~{.cpp}
0294 /// ROOT::RDataFrame df("tree", "file.root");
0295 /// ROOT::RDF::Experimental::AddProgressBar(df);
0296 /// ~~~
0297 /// For more details see ROOT::RDF::Experimental::ProgressHelper Class.
0298 void AddProgressBar(ROOT::RDataFrame df);
0299 
0300 class ProgressBarAction;
0301 
0302 /// RDF progress helper.
0303 /// This class provides callback functions to the RDataFrame. The event statistics
0304 /// (including elapsed time, currently processed file, currently processed events, the rate of event processing
0305 /// and an estimated remaining time (per file being processed))
0306 /// are recorded and printed in the terminal every m events and every n seconds.
0307 /// ProgressHelper::operator()(unsigned int, T&) is thread safe, and can be used as a callback in MT mode.
0308 /// ProgressBar should be added after creating the dataframe object (df):
0309 /// ~~~{.cpp}
0310 /// ROOT::RDataFrame df("tree", "file.root");
0311 /// ROOT::RDF::Experimental::AddProgressBar(df);
0312 /// ~~~
0313 /// alternatively RDataFrame can be cast to an RNode first giving it more flexibility.
0314 /// For example, it can be called at any computational node, such as Filter or Define, not only the head node,
0315 /// with no change to the ProgressBar function itself:
0316 /// ~~~{.cpp}
0317 /// ROOT::RDataFrame df("tree", "file.root");
0318 /// auto df_1 = ROOT::RDF::RNode(df.Filter("x>1"));
0319 /// ROOT::RDF::Experimental::AddProgressBar(df_1);
0320 /// ~~~
0321 class ProgressHelper {
0322 private:
0323    double EvtPerSec() const;
0324    std::pair<std::size_t, std::chrono::seconds> RecordEvtCountAndTime();
0325    void PrintStats(std::ostream &stream, std::size_t currentEventCount, std::chrono::seconds totalElapsedSeconds) const;
0326    void PrintStatsFinal(std::ostream &stream, std::chrono::seconds totalElapsedSeconds) const;
0327    void PrintProgressBar(std::ostream &stream, std::size_t currentEventCount) const;
0328 
0329    std::chrono::time_point<std::chrono::system_clock> fBeginTime = std::chrono::system_clock::now();
0330    std::chrono::time_point<std::chrono::system_clock> fLastPrintTime = fBeginTime;
0331    std::chrono::seconds fPrintInterval{1};
0332 
0333    std::atomic<std::size_t> fProcessedEvents{0};
0334    std::size_t fLastProcessedEvents{0};
0335    std::size_t fIncrement;
0336 
0337    mutable std::mutex fSampleNameToEventEntriesMutex;
0338    std::map<std::string, ULong64_t> fSampleNameToEventEntries; // Filename, events in the file
0339 
0340    std::array<double, 20> fEventsPerSecondStatistics;
0341    std::size_t fEventsPerSecondStatisticsIndex{0};
0342 
0343    unsigned int fBarWidth;
0344    unsigned int fTotalFiles;
0345 
0346    std::mutex fPrintMutex;
0347    bool fIsTTY;
0348    bool fUseShellColours;
0349 
0350    std::shared_ptr<TTree> fTree{nullptr};
0351 
0352 public:
0353    /// Create a progress helper.
0354    /// \param increment RDF callbacks are called every `n` events. Pass this `n` here.
0355    /// \param totalFiles read total number of files in the RDF.
0356    /// \param progressBarWidth Number of characters the progress bar will occupy.
0357    /// \param printInterval Update every stats every `n` seconds.
0358    /// \param useColors Use shell colour codes to colour the output. Automatically disabled when
0359    /// we are not writing to a tty.
0360    ProgressHelper(std::size_t increment, unsigned int totalFiles = 1, unsigned int progressBarWidth = 40,
0361                   unsigned int printInterval = 1, bool useColors = true);
0362 
0363    ~ProgressHelper() = default;
0364 
0365    friend class ProgressBarAction;
0366 
0367    /// Register a new sample for completion statistics.
0368    /// \see ROOT::RDF::RInterface::DefinePerSample().
0369    /// The *id.AsString()* refers to the name of the currently processed file.
0370    /// The idea is to populate the  event entries in the *fSampleNameToEventEntries* map
0371    /// by selecting the greater of the two values:
0372    /// *id.EntryRange().second* which is the upper event entry range of the processed sample
0373    /// and the current value of the event entries in the *fSampleNameToEventEntries* map.
0374    /// In the single threaded case, the two numbers are the same as the entry range corresponds
0375    /// to the number of events in an individual file (each sample is simply a single file).
0376    /// In the multithreaded case, the idea is to accumulate the higher event entry value until
0377    /// the total number of events in a given file is reached.
0378    void registerNewSample(unsigned int /*slot*/, const ROOT::RDF::RSampleInfo &id)
0379    {
0380       std::lock_guard<std::mutex> lock(fSampleNameToEventEntriesMutex);
0381       fSampleNameToEventEntries[id.AsString()] =
0382          std::max(id.EntryRange().second, fSampleNameToEventEntries[id.AsString()]);
0383    }
0384 
0385    /// Thread-safe callback for RDataFrame.
0386    /// It will record elapsed times and event statistics, and print a progress bar every n seconds (set by the
0387    /// fPrintInterval). \param slot Ignored. \param value Ignored.
0388    template <typename T>
0389    void operator()(unsigned int /*slot*/, T &value)
0390    {
0391       operator()(value);
0392    }
0393    // clang-format off
0394    /// Thread-safe callback for RDataFrame.
0395    /// It will record elapsed times and event statistics, and print a progress bar every n seconds (set by the fPrintInterval).
0396    /// \param value Ignored.
0397    // clang-format on
0398    template <typename T>
0399    void operator()(T & /*value*/)
0400    {
0401       using namespace std::chrono;
0402       // ***************************************************
0403       // Warning: Here, everything needs to be thread safe:
0404       // ***************************************************
0405       fProcessedEvents += fIncrement;
0406 
0407       // We only print every n seconds.
0408       if (duration_cast<seconds>(system_clock::now() - fLastPrintTime) < fPrintInterval) {
0409          return;
0410       }
0411 
0412       // ***************************************************
0413       // Protected by lock from here:
0414       // ***************************************************
0415       if (!fPrintMutex.try_lock())
0416          return;
0417       std::lock_guard<std::mutex> lockGuard(fPrintMutex, std::adopt_lock);
0418 
0419       std::size_t eventCount;
0420       seconds elapsedSeconds;
0421       std::tie(eventCount, elapsedSeconds) = RecordEvtCountAndTime();
0422 
0423       if (fIsTTY)
0424          std::cout << "\r";
0425 
0426       PrintProgressBar(std::cout, eventCount);
0427       PrintStats(std::cout, eventCount, elapsedSeconds);
0428 
0429       if (fIsTTY)
0430          std::cout << std::flush;
0431       else
0432          std::cout << std::endl;
0433    }
0434 
0435    std::size_t ComputeNEventsSoFar() const
0436    {
0437       std::unique_lock<std::mutex> lock(fSampleNameToEventEntriesMutex);
0438       std::size_t result = 0;
0439       for (const auto &item : fSampleNameToEventEntries)
0440          result += item.second;
0441       return result;
0442    }
0443 
0444    unsigned int ComputeCurrentFileIdx() const
0445    {
0446       std::unique_lock<std::mutex> lock(fSampleNameToEventEntriesMutex);
0447       return fSampleNameToEventEntries.size();
0448    }
0449 };
0450 } // namespace Experimental
0451 } // namespace RDF
0452 } // namespace ROOT
0453 #endif