Back to home page

EIC code displayed by LXR

 
 

    


Warning, file /include/root/ROOT/RDF/RActionSnapshot.hxx was not indexed or was modified since last indexation (in which case cross-reference links may be missing, inaccurate or erroneous).

0001 // Author: Vincenzo Eduardo Padulano CERN 06/2025
0002 
0003 /*************************************************************************
0004  * Copyright (C) 1995-2025, Rene Brun and Fons Rademakers.               *
0005  * All rights reserved.                                                  *
0006  *                                                                       *
0007  * For the licensing terms see $ROOTSYS/LICENSE.                         *
0008  * For the list of contributors see $ROOTSYS/README/CREDITS.             *
0009  *************************************************************************/
0010 
0011 #ifndef ROOT_RACTIONSNAPSHOT
0012 #define ROOT_RACTIONSNAPSHOT
0013 
#include "ROOT/RDF/ColumnReaderUtils.hxx"
#include "ROOT/RDF/GraphNode.hxx"
#include "ROOT/RDF/RActionBase.hxx"
#include "ROOT/RDF/RFilterBase.hxx"
#include "ROOT/RDF/RJittedFilter.hxx"
#include "ROOT/RDF/RLoopManager.hxx"

#include <algorithm> // std::any_of
#include <cassert>
#include <cstddef> // std::size_t
#include <memory>
#include <stdexcept> // std::out_of_range, std::logic_error
#include <string>
#include <type_traits> // std::conditional_t, std::is_same_v
#include <typeinfo>    // std::type_info
#include <unordered_map>
#include <utility> // std::move
#include <vector>
0025 
0026 namespace ROOT::Internal::RDF {
0027 
0028 namespace GraphDrawing {
0029 std::shared_ptr<GraphNode> AddDefinesToGraph(std::shared_ptr<GraphNode> node, const RColumnRegister &colRegister,
0030                                              const std::vector<std::string> &prevNodeDefines,
0031                                              std::unordered_map<void *, std::shared_ptr<GraphNode>> &visitedMap);
0032 } // namespace GraphDrawing
0033 
// Forward declaration: this header only needs the identity of the type, for the
// `std::is_same_v<Helper, SnapshotHelperWithVariations>` compile-time checks in RActionSnapshot.
class SnapshotHelperWithVariations;
0035 
0036 template <typename Helper, typename PrevNode>
0037 class R__CLING_PTRCHECK(off) RActionSnapshot final : public RActionBase {
0038 
0039    // Template needed to avoid dependency on ActionHelpers.hxx
0040    Helper fHelper;
0041 
0042    // If the PrevNode is a RJittedFilter, our collection of previous nodes will have to use the RFilterBase type:
0043    // we'll have a RJittedFilter for the nominal case, but the others will be concrete filters.
0044    using PrevNodeCommon_t = std::conditional_t<std::is_same_v<PrevNode, ROOT::Detail::RDF::RJittedFilter>,
0045                                                ROOT::Detail::RDF::RFilterBase, PrevNode>;
0046    /// Previous nodes in the computation graph. First element is nominal, others are varied.
0047    std::vector<std::shared_ptr<PrevNodeCommon_t>> fPrevNodes;
0048 
0049    /// Column readers per slot and per input column
0050    std::vector<std::vector<RColumnReaderBase *>> fValues;
0051 
0052    /// The nth flag signals whether the nth input column is a custom column or not.
0053    std::vector<bool> fIsDefine;
0054 
0055    /// Types of the columns to Snapshot
0056    std::vector<const std::type_info *> fColTypeIDs;
0057 
0058    ROOT::RDF::SampleCallback_t GetSampleCallback() final { return fHelper.GetSampleCallback(); }
0059 
0060    void AppendVariedPrevNodes()
0061    {
0062       // This method only makes sense if we're appending the varied filters to the list after the nominal
0063       assert(fPrevNodes.size() == 1);
0064       const auto &currentVariations = GetVariations();
0065 
0066       // If this node hangs from the RLoopManager itself, just use that as the upstream node for each variation
0067       auto nominalPrevNode = fPrevNodes.front();
0068       if (static_cast<ROOT::Detail::RDF::RNodeBase *>(nominalPrevNode.get()) == fLoopManager) {
0069          fPrevNodes.resize(1 + currentVariations.size(), nominalPrevNode);
0070          return;
0071       }
0072 
0073       // Otherwise, append one varied filter per variation
0074       const auto &prevVariations = nominalPrevNode->GetVariations();
0075       fPrevNodes.reserve(1 + currentVariations.size());
0076 
0077       for (const auto &variation : currentVariations) {
0078          if (IsStrInVec(variation, prevVariations)) {
0079             fPrevNodes.emplace_back(
0080                std::static_pointer_cast<PrevNodeCommon_t>(nominalPrevNode->GetVariedFilter(variation)));
0081          } else {
0082             fPrevNodes.push_back(nominalPrevNode);
0083          }
0084       }
0085    }
0086 
0087 public:
0088    RActionSnapshot(Helper &&h, const std::vector<std::string> &columns,
0089                    const std::vector<const std::type_info *> &colTypeIDs, std::shared_ptr<PrevNode> pd,
0090                    const RColumnRegister &colRegister)
0091       : RActionBase(pd->GetLoopManagerUnchecked(), columns, colRegister, pd->GetVariations()),
0092         fHelper(std::move(h)),
0093         fPrevNodes{std::static_pointer_cast<PrevNodeCommon_t>(pd)},
0094         fValues(GetNSlots()),
0095         fColTypeIDs(colTypeIDs)
0096    {
0097       fLoopManager->Register(this);
0098 
0099       const auto nColumns = columns.size();
0100       fIsDefine.reserve(nColumns);
0101       for (auto i = 0u; i < nColumns; ++i)
0102          fIsDefine.push_back(colRegister.IsDefineOrAlias(columns[i]));
0103 
0104       if constexpr (std::is_same_v<Helper, SnapshotHelperWithVariations>) {
0105          // Need to populate parts of the computation graph for which we have empty shells, e.g. RJittedFilters and
0106          // varied Defines
0107          if (!GetVariations().empty())
0108             fLoopManager->Jit();
0109 
0110          AppendVariedPrevNodes();
0111 
0112          for (auto i = 0u; i < nColumns; ++i) {
0113             if (fIsDefine[i]) {
0114                auto define = colRegister.GetDefine(columns[i]);
0115                define->MakeVariations(GetVariations());
0116             }
0117          }
0118       }
0119    }
0120 
0121    RActionSnapshot(const RActionSnapshot &) = delete;
0122    RActionSnapshot &operator=(const RActionSnapshot &) = delete;
0123    RActionSnapshot(RActionSnapshot &&) = delete;
0124    RActionSnapshot &operator=(RActionSnapshot &&) = delete;
0125 
0126    ~RActionSnapshot() final { fLoopManager->Deregister(this); }
0127 
0128    /**
0129       Retrieve a wrapper to the result of the action that knows how to merge
0130       with others of the same type.
0131    */
0132    std::unique_ptr<ROOT::Detail::RDF::RMergeableValueBase> GetMergeableValue() const final
0133    {
0134       return fHelper.GetMergeableValue();
0135    }
0136 
0137    void Initialize() final { fHelper.Initialize(); }
0138 
0139    void InitSlot(TTreeReader *r, unsigned int slot) final
0140    {
0141       fValues[slot] = GetUntypedColumnReaders(slot, r, RActionBase::GetColRegister(), *fLoopManager,
0142                                               RActionBase::GetColumnNames(), fColTypeIDs);
0143 
0144       if constexpr (std::is_same_v<Helper, SnapshotHelperWithVariations>) {
0145          // In case of systematic variations, append also the varied column readers to the values
0146          // that get passed to the helpers
0147          auto const &variations = GetVariations();
0148          for (unsigned int variationIndex = 0; variationIndex < variations.size(); ++variationIndex) {
0149             auto const &readers =
0150                GetUntypedColumnReaders(slot, r, RActionBase::GetColRegister(), *fLoopManager,
0151                                        RActionBase::GetColumnNames(), fColTypeIDs, variations[variationIndex]);
0152             for (unsigned int i = 0; i < readers.size(); ++i) {
0153                if (fValues[slot][i] != readers[i]) {
0154                   // The reader with variations differs from nominal, so this column needs to be added to the output
0155                   fValues[slot].push_back(readers[i]);
0156                   // Both the original and the varied column need to be registered for masking
0157                   fHelper.RegisterVariedColumn(slot, i, i, 0,
0158                                                "nominal"); // (No harm flagging the nominal multiple times)
0159                   fHelper.RegisterVariedColumn(slot, fValues[slot].size() - 1, i, variationIndex + 1,
0160                                                variations[variationIndex]);
0161                }
0162             }
0163          }
0164       }
0165 
0166       fHelper.InitTask(r, slot);
0167    }
0168 
0169    void *GetValue(unsigned int slot, std::size_t readerIdx, Long64_t entry)
0170    {
0171       assert(slot < fValues.size());
0172       assert(readerIdx < fValues[slot].size());
0173       if (auto *val = fValues[slot][readerIdx]->template TryGet<void>(entry))
0174          return val;
0175 
0176       throw std::out_of_range{"RDataFrame: Action (" + fHelper.GetActionName() +
0177                               ") could not retrieve value for column '" + fColumnNames[readerIdx] + "' for entry " +
0178                               std::to_string(entry) +
0179                               ". You can use the DefaultValueFor operation to provide a default value, or "
0180                               "FilterAvailable/FilterMissing to discard/keep entries with missing values instead."};
0181    }
0182 
0183    void CallExec(unsigned int slot, Long64_t entry)
0184    {
0185       std::vector<void *> untypedValues;
0186       auto nReaders = fValues[slot].size();
0187       untypedValues.reserve(nReaders);
0188       for (decltype(nReaders) readerIdx{}; readerIdx < nReaders; readerIdx++)
0189          untypedValues.push_back(GetValue(slot, readerIdx, entry));
0190 
0191       fHelper.Exec(slot, untypedValues);
0192    }
0193 
0194    void Run(unsigned int slot, Long64_t entry) final
0195    {
0196       if constexpr (std::is_same_v<Helper, SnapshotHelperWithVariations>) {
0197          // check if entry passes all filters
0198          std::vector<bool> filterPassed(fPrevNodes.size(), false);
0199          for (unsigned int variation = 0; variation < fPrevNodes.size(); ++variation) {
0200             filterPassed[variation] = fPrevNodes[variation]->CheckFilters(slot, entry);
0201          }
0202 
0203          // Currently, every event where any of nominal or variations pass gets written to the output.
0204          // This logic could be extended for different use cases if the need arises.
0205          if (std::any_of(filterPassed.begin(), filterPassed.end(), [](bool val) { return val; })) {
0206             // TODO: Don't allocate
0207             std::vector<void *> untypedValues;
0208             auto nReaders = fValues[slot].size();
0209             untypedValues.reserve(nReaders);
0210             for (decltype(nReaders) readerIdx{}; readerIdx < nReaders; readerIdx++)
0211                untypedValues.push_back(GetValue(slot, readerIdx, entry));
0212 
0213             fHelper.Exec(slot, untypedValues, filterPassed);
0214          }
0215       } else {
0216          if (fPrevNodes.front()->CheckFilters(slot, entry))
0217             CallExec(slot, entry);
0218       }
0219    }
0220 
0221    void TriggerChildrenCount() final
0222    {
0223       for (auto const &node : fPrevNodes)
0224          node->IncrChildrenCount();
0225    }
0226 
0227    /// Clean-up operations to be performed at the end of a task.
0228    void FinalizeSlot(unsigned int slot) final
0229    {
0230       fValues[slot].clear();
0231       fHelper.CallFinalizeTask(slot);
0232    }
0233 
0234    /// Clean-up and finalize the action result (e.g. merging slot-local results).
0235    /// It invokes the helper's Finalize method.
0236    void Finalize() final
0237    {
0238       fHelper.Finalize();
0239       SetHasRun();
0240    }
0241 
0242    std::shared_ptr<GraphDrawing::GraphNode>
0243    GetGraph(std::unordered_map<void *, std::shared_ptr<GraphDrawing::GraphNode>> &visitedMap) final
0244    {
0245       // Action nodes do not need to go through CreateFilterNode: they are never common nodes between multiple branches
0246       const auto nodeType = HasRun() ? GraphDrawing::ENodeType::kUsedAction : GraphDrawing::ENodeType::kAction;
0247       auto thisNode = std::make_shared<GraphDrawing::GraphNode>(fHelper.GetActionName(), visitedMap.size(), nodeType);
0248       visitedMap[(void *)this] = thisNode;
0249 
0250       for (auto const &node : fPrevNodes) {
0251          auto prevNode = node->GetGraph(visitedMap);
0252          const auto &prevColumns = prevNode->GetDefinedColumns();
0253          auto upmostNode = AddDefinesToGraph(thisNode, GetColRegister(), prevColumns, visitedMap);
0254 
0255          thisNode->AddDefinedColumns(GetColRegister().GenerateColumnNames());
0256          upmostNode->SetPrevNode(prevNode);
0257       }
0258       return thisNode;
0259    }
0260 
0261    /// Forwards to the action helpers; will throw since PartialUpdate not supported for most snapshot helpers.
0262    void *PartialUpdate(unsigned int slot) final { return fHelper.CallPartialUpdate(slot); }
0263 
0264    /// Will throw, since varied actions are unsupported. Instead, set a flag in RSnapshotOptions.
0265    [[maybe_unused]] std::unique_ptr<RActionBase> MakeVariedAction(std::vector<void *> && /*results*/) final
0266    {
0267       throw std::logic_error("RDataFrame::Snapshot: The snapshot action cannot be varied. Instead, switch on "
0268                              "variations in RSnapshotOptions.");
0269    }
0270 
0271    /**
0272     * \brief Returns a new action with a cloned helper.
0273     *
0274     * \param[in] newResult The result to be filled by the new action (needed to clone the helper).
0275     * \return A unique pointer to the new action.
0276     */
0277    std::unique_ptr<RActionBase> CloneAction(void *newResult) final
0278    {
0279       return std::make_unique<RActionSnapshot>(fHelper.CallMakeNew(newResult), GetColumnNames(), fColTypeIDs,
0280                                                std::static_pointer_cast<PrevNode>(fPrevNodes.front()),
0281                                                GetColRegister());
0282    }
0283 };
0284 
0285 } // namespace ROOT::Internal::RDF
0286 
0287 #endif // ROOT_RACTIONSNAPSHOT