Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-02-22 10:53:09

0001 // Author: Enrico Guiraud, CERN 11/2021
0002 
0003 /*************************************************************************
0004  * Copyright (C) 1995-2022, Rene Brun and Fons Rademakers.               *
0005  * All rights reserved.                                                  *
0006  *                                                                       *
0007  * For the licensing terms see $ROOTSYS/LICENSE.                         *
0008  * For the list of contributors see $ROOTSYS/README/CREDITS.             *
0009  *************************************************************************/
0010 
0011 #ifndef ROOT_RVARIEDACTION
0012 #define ROOT_RVARIEDACTION
0013 
0014 #include "ColumnReaderUtils.hxx"
0015 #include "GraphNode.hxx"
0016 #include "RActionBase.hxx"
0017 #include "RColumnReaderBase.hxx"
0018 #include "RLoopManager.hxx"
0019 #include "RJittedFilter.hxx"
0020 #include "ROOT/RDF/RMergeableValue.hxx"
0021 #include "ROOT/RDF/RSampleInfo.hxx"
0022 
0023 #include <Rtypes.h> // R__CLING_PTRCHECK
0024 #include <ROOT/TypeTraits.hxx>
0025 
#include <algorithm>
#include <array>
#include <cassert>
#include <memory>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <unordered_map>
#include <utility> // make_index_sequence
#include <vector>
0031 
0032 namespace ROOT {
0033 namespace Internal {
0034 namespace RDF {
0035 
/// Shorthand for ROOT::Internal::RDF::GraphDrawing, used by RVariedAction::GetGraph to build its graph node.
namespace RDFGraphDrawing = ROOT::Internal::RDF::GraphDrawing;
0037 
0038 /// Just like an RAction, but it has N action helpers and N previous nodes (N is the number of variations).
0039 template <typename Helper, typename PrevNode, typename ColumnTypes_t>
0040 class R__CLING_PTRCHECK(off) RVariedAction final : public RActionBase {
0041    using TypeInd_t = std::make_index_sequence<ColumnTypes_t::list_size>;
0042    // If the PrevNode is a RJittedFilter, our collection of previous nodes will have to use the RNodeBase type:
0043    // we'll have a RJittedFilter for the nominal case, but the others will be concrete filters.
0044    using PrevNodeType = std::conditional_t<std::is_same<PrevNode, RJittedFilter>::value, RFilterBase, PrevNode>;
0045 
0046    std::vector<Helper> fHelpers; ///< Action helpers per variation.
0047    /// Owning pointers to upstream nodes for each systematic variation.
0048    std::vector<std::shared_ptr<PrevNodeType>> fPrevNodes;
0049 
0050    /// Column readers per slot (outer dimension), per variation and per input column (inner dimension, std::array).
0051    std::vector<std::vector<std::array<RColumnReaderBase *, ColumnTypes_t::list_size>>> fInputValues;
0052 
0053    /// The nth flag signals whether the nth input column is a custom column or not.
0054    std::array<bool, ColumnTypes_t::list_size> fIsDefine;
0055 
0056    /// \brief Creates new filter nodes, one per variation, from the upstream nominal one.
0057    /// \param nominal The nominal filter
0058    /// \return The varied filters
0059    ///
0060    /// The nominal filter is not included in the return value.
0061    std::vector<std::shared_ptr<PrevNodeType>> MakePrevFilters(std::shared_ptr<PrevNode> nominal) const
0062    {
0063       const auto &variations = GetVariations();
0064       std::vector<std::shared_ptr<PrevNodeType>> prevFilters;
0065       prevFilters.reserve(variations.size());
0066       if (static_cast<RNodeBase *>(nominal.get()) == fLoopManager) {
0067          // just fill this with the RLoopManager N times
0068          prevFilters.resize(variations.size(), nominal);
0069       } else {
0070          // create varied versions of the previous filter node
0071          const auto &prevVariations = nominal->GetVariations();
0072          for (const auto &variation : variations) {
0073             if (IsStrInVec(variation, prevVariations)) {
0074                prevFilters.emplace_back(std::static_pointer_cast<PrevNodeType>(nominal->GetVariedFilter(variation)));
0075             } else {
0076                prevFilters.emplace_back(nominal);
0077             }
0078          }
0079       }
0080 
0081       return prevFilters;
0082    }
0083 
0084    void SetupClass()
0085    {
0086       // The column register and names are private members of RActionBase
0087       const auto &colRegister = GetColRegister();
0088       const auto &columnNames = GetColumnNames();
0089 
0090       fLoopManager->Register(this);
0091 
0092       for (auto i = 0u; i < columnNames.size(); ++i) {
0093          auto *define = colRegister.GetDefine(columnNames[i]);
0094          fIsDefine[i] = define != nullptr;
0095          if (fIsDefine[i])
0096             define->MakeVariations(GetVariations());
0097       }
0098    }
0099 
0100    /// This constructor takes in input a vector of previous nodes, motivated by the CloneAction logic.
0101    RVariedAction(std::vector<Helper> &&helpers, const ColumnNames_t &columns,
0102                  const std::vector<std::shared_ptr<PrevNodeType>> &prevNodes, const RColumnRegister &colRegister)
0103       : RActionBase(prevNodes[0]->GetLoopManagerUnchecked(), columns, colRegister, prevNodes[0]->GetVariations()),
0104         fHelpers(std::move(helpers)),
0105         fPrevNodes(prevNodes),
0106         fInputValues(GetNSlots())
0107    {
0108       SetupClass();
0109    }
0110 
0111 public:
0112    RVariedAction(std::vector<Helper> &&helpers, const ColumnNames_t &columns, std::shared_ptr<PrevNode> prevNode,
0113                  const RColumnRegister &colRegister)
0114       : RActionBase(prevNode->GetLoopManagerUnchecked(), columns, colRegister, prevNode->GetVariations()),
0115         fHelpers(std::move(helpers)),
0116         fPrevNodes(MakePrevFilters(prevNode)),
0117         fInputValues(GetNSlots())
0118    {
0119       SetupClass();
0120    }
0121 
0122    RVariedAction(const RVariedAction &) = delete;
0123    RVariedAction &operator=(const RVariedAction &) = delete;
0124 
0125    ~RVariedAction() { fLoopManager->Deregister(this); }
0126 
0127    void Initialize() final
0128    {
0129       std::for_each(fHelpers.begin(), fHelpers.end(), [](Helper &h) { h.Initialize(); });
0130    }
0131 
0132    void InitSlot(TTreeReader *r, unsigned int slot) final
0133    {
0134       RColumnReadersInfo info{GetColumnNames(), GetColRegister(), fIsDefine.data(), *fLoopManager};
0135 
0136       // get readers for each systematic variation
0137       for (const auto &variation : GetVariations())
0138          fInputValues[slot].emplace_back(GetColumnReaders(slot, r, ColumnTypes_t{}, info, variation));
0139 
0140       std::for_each(fHelpers.begin(), fHelpers.end(), [=](Helper &h) { h.InitTask(r, slot); });
0141    }
0142 
0143    template <typename... ColTypes, std::size_t... S>
0144    void
0145    CallExec(unsigned int slot, unsigned int varIdx, Long64_t entry, TypeList<ColTypes...>, std::index_sequence<S...>)
0146    {
0147       fHelpers[varIdx].Exec(slot, fInputValues[slot][varIdx][S]->template Get<ColTypes>(entry)...);
0148       (void)entry;
0149    }
0150 
0151    void Run(unsigned int slot, Long64_t entry) final
0152    {
0153       for (auto varIdx = 0u; varIdx < GetVariations().size(); ++varIdx) {
0154          if (fPrevNodes[varIdx]->CheckFilters(slot, entry))
0155             CallExec(slot, varIdx, entry, ColumnTypes_t{}, TypeInd_t{});
0156       }
0157    }
0158 
0159    void TriggerChildrenCount() final
0160    {
0161       std::for_each(fPrevNodes.begin(), fPrevNodes.end(), [](auto &f) { f->IncrChildrenCount(); });
0162    }
0163 
0164    /// Clean-up operations to be performed at the end of a task.
0165    void FinalizeSlot(unsigned int slot) final
0166    {
0167       fInputValues[slot].clear();
0168       std::for_each(fHelpers.begin(), fHelpers.end(), [=](Helper &h) { h.CallFinalizeTask(slot); });
0169    }
0170 
0171    /// Clean-up and finalize the action result (e.g. merging slot-local results).
0172    /// It invokes the helper's Finalize method.
0173    void Finalize() final
0174    {
0175       std::for_each(fHelpers.begin(), fHelpers.end(), [](Helper &h) { h.Finalize(); });
0176       SetHasRun();
0177    }
0178 
0179    /// Return the partially-updated value connected to the first variation.
0180    void *PartialUpdate(unsigned int slot) final { return PartialUpdateImpl(slot); }
0181 
0182    /// Return a callback that in turn runs the callbacks of each variation's helper.
0183    ROOT::RDF::SampleCallback_t GetSampleCallback() final
0184    {
0185       if (fHelpers[0].GetSampleCallback()) {
0186          std::vector<ROOT::RDF::SampleCallback_t> callbacks;
0187          for (auto &h : fHelpers)
0188             callbacks.push_back(h.GetSampleCallback());
0189 
0190          auto callEachCallback = [cs = std::move(callbacks)](unsigned int slot, const RSampleInfo &info) {
0191             for (auto &c : cs)
0192                c(slot, info);
0193          };
0194 
0195          return callEachCallback;
0196       }
0197 
0198       return {};
0199    }
0200 
0201    std::shared_ptr<RDFGraphDrawing::GraphNode>
0202    GetGraph(std::unordered_map<void *, std::shared_ptr<RDFGraphDrawing::GraphNode>> &visitedMap) final
0203    {
0204       auto prevNode = fPrevNodes[0]->GetGraph(visitedMap);
0205       const auto &prevColumns = prevNode->GetDefinedColumns();
0206 
0207       // Action nodes do not need to go through CreateFilterNode: they are never common nodes between multiple branches
0208       const auto nodeType = HasRun() ? RDFGraphDrawing::ENodeType::kUsedAction : RDFGraphDrawing::ENodeType::kAction;
0209       auto thisNode = std::make_shared<RDFGraphDrawing::GraphNode>("Varied " + fHelpers[0].GetActionName(),
0210                                                                    visitedMap.size(), nodeType);
0211       visitedMap[(void *)this] = thisNode;
0212 
0213       auto upmostNode = AddDefinesToGraph(thisNode, GetColRegister(), prevColumns, visitedMap);
0214 
0215       thisNode->AddDefinedColumns(GetColRegister().GenerateColumnNames());
0216       upmostNode->SetPrevNode(prevNode);
0217       return thisNode;
0218    }
0219 
0220    /**
0221       Retrieve a container holding the names and values of the variations. It
0222       knows how to merge with others of the same type.
0223    */
0224    std::unique_ptr<RMergeableValueBase> GetMergeableValue() const final
0225    {
0226       std::vector<std::string> keys{GetVariations()};
0227 
0228       std::vector<std::unique_ptr<RDFDetail::RMergeableValueBase>> values;
0229       values.reserve(fHelpers.size());
0230       for (auto &&h : fHelpers)
0231          values.emplace_back(h.GetMergeableValue());
0232 
0233       return std::make_unique<RDFDetail::RMergeableVariationsBase>(std::move(keys), std::move(values));
0234    }
0235 
0236    [[noreturn]] std::unique_ptr<RActionBase> MakeVariedAction(std::vector<void *> &&) final
0237    {
0238       throw std::logic_error("Cannot produce a varied action from a varied action.");
0239    }
0240 
0241    std::unique_ptr<RActionBase> CloneAction(void *typeErasedResults) final
0242    {
0243       const auto &vectorOfTypeErasedResults = *reinterpret_cast<const std::vector<void *> *>(typeErasedResults);
0244       assert(vectorOfTypeErasedResults.size() == fHelpers.size() &&
0245              "The number of results and the number of helpers are not the same!");
0246 
0247       std::vector<Helper> clonedHelpers;
0248       clonedHelpers.reserve(fHelpers.size());
0249       for (std::size_t i = 0; i < fHelpers.size(); i++) {
0250          clonedHelpers.emplace_back(fHelpers[i].CallMakeNew(vectorOfTypeErasedResults[i]));
0251       }
0252 
0253       return std::unique_ptr<RVariedAction>(
0254          new RVariedAction(std::move(clonedHelpers), GetColumnNames(), fPrevNodes, GetColRegister()));
0255    }
0256 
0257 private:
0258    // this overload is SFINAE'd out if Helper does not implement `PartialUpdate`
0259    // the template parameter is required to defer instantiation of the method to SFINAE time
0260    template <typename H = Helper>
0261    auto PartialUpdateImpl(unsigned int slot) -> decltype(std::declval<H>().PartialUpdate(slot), (void *)(nullptr))
0262    {
0263       return &fHelpers[0].PartialUpdate(slot);
0264    }
0265 
0266    // this one is always available but has lower precedence thanks to `...`
0267    void *PartialUpdateImpl(...) { throw std::runtime_error("This action does not support callbacks!"); }
0268 };
0269 
0270 } // namespace RDF
0271 } // namespace Internal
0272 } // namespace ROOT
0273 
0274 #endif // ROOT_RVARIEDACTION