File indexing completed on 2025-02-22 10:53:09
0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011 #ifndef ROOT_RVARIEDACTION
0012 #define ROOT_RVARIEDACTION
0013
0014 #include "ColumnReaderUtils.hxx"
0015 #include "GraphNode.hxx"
0016 #include "RActionBase.hxx"
0017 #include "RColumnReaderBase.hxx"
0018 #include "RLoopManager.hxx"
0019 #include "RJittedFilter.hxx"
0020 #include "ROOT/RDF/RMergeableValue.hxx"
0021 #include "ROOT/RDF/RSampleInfo.hxx"
0022
0023 #include <Rtypes.h> // R__CLING_PTRCHECK
0024 #include <ROOT/TypeTraits.hxx>
0025
0026 #include <algorithm>
0027 #include <array>
0028 #include <memory>
0029 #include <utility> // make_index_sequence
0030 #include <vector>
0031
0032 namespace ROOT {
0033 namespace Internal {
0034 namespace RDF {
0035
0036 namespace RDFGraphDrawing = ROOT::Internal::RDF::GraphDrawing;
0037
0038
0039 template <typename Helper, typename PrevNode, typename ColumnTypes_t>
0040 class R__CLING_PTRCHECK(off) RVariedAction final : public RActionBase {
0041 using TypeInd_t = std::make_index_sequence<ColumnTypes_t::list_size>;
0042
0043
0044 using PrevNodeType = std::conditional_t<std::is_same<PrevNode, RJittedFilter>::value, RFilterBase, PrevNode>;
0045
0046 std::vector<Helper> fHelpers;
0047
0048 std::vector<std::shared_ptr<PrevNodeType>> fPrevNodes;
0049
0050
0051 std::vector<std::vector<std::array<RColumnReaderBase *, ColumnTypes_t::list_size>>> fInputValues;
0052
0053
0054 std::array<bool, ColumnTypes_t::list_size> fIsDefine;
0055
0056
0057
0058
0059
0060
0061 std::vector<std::shared_ptr<PrevNodeType>> MakePrevFilters(std::shared_ptr<PrevNode> nominal) const
0062 {
0063 const auto &variations = GetVariations();
0064 std::vector<std::shared_ptr<PrevNodeType>> prevFilters;
0065 prevFilters.reserve(variations.size());
0066 if (static_cast<RNodeBase *>(nominal.get()) == fLoopManager) {
0067
0068 prevFilters.resize(variations.size(), nominal);
0069 } else {
0070
0071 const auto &prevVariations = nominal->GetVariations();
0072 for (const auto &variation : variations) {
0073 if (IsStrInVec(variation, prevVariations)) {
0074 prevFilters.emplace_back(std::static_pointer_cast<PrevNodeType>(nominal->GetVariedFilter(variation)));
0075 } else {
0076 prevFilters.emplace_back(nominal);
0077 }
0078 }
0079 }
0080
0081 return prevFilters;
0082 }
0083
0084 void SetupClass()
0085 {
0086
0087 const auto &colRegister = GetColRegister();
0088 const auto &columnNames = GetColumnNames();
0089
0090 fLoopManager->Register(this);
0091
0092 for (auto i = 0u; i < columnNames.size(); ++i) {
0093 auto *define = colRegister.GetDefine(columnNames[i]);
0094 fIsDefine[i] = define != nullptr;
0095 if (fIsDefine[i])
0096 define->MakeVariations(GetVariations());
0097 }
0098 }
0099
0100
0101 RVariedAction(std::vector<Helper> &&helpers, const ColumnNames_t &columns,
0102 const std::vector<std::shared_ptr<PrevNodeType>> &prevNodes, const RColumnRegister &colRegister)
0103 : RActionBase(prevNodes[0]->GetLoopManagerUnchecked(), columns, colRegister, prevNodes[0]->GetVariations()),
0104 fHelpers(std::move(helpers)),
0105 fPrevNodes(prevNodes),
0106 fInputValues(GetNSlots())
0107 {
0108 SetupClass();
0109 }
0110
0111 public:
0112 RVariedAction(std::vector<Helper> &&helpers, const ColumnNames_t &columns, std::shared_ptr<PrevNode> prevNode,
0113 const RColumnRegister &colRegister)
0114 : RActionBase(prevNode->GetLoopManagerUnchecked(), columns, colRegister, prevNode->GetVariations()),
0115 fHelpers(std::move(helpers)),
0116 fPrevNodes(MakePrevFilters(prevNode)),
0117 fInputValues(GetNSlots())
0118 {
0119 SetupClass();
0120 }
0121
0122 RVariedAction(const RVariedAction &) = delete;
0123 RVariedAction &operator=(const RVariedAction &) = delete;
0124
0125 ~RVariedAction() { fLoopManager->Deregister(this); }
0126
0127 void Initialize() final
0128 {
0129 std::for_each(fHelpers.begin(), fHelpers.end(), [](Helper &h) { h.Initialize(); });
0130 }
0131
0132 void InitSlot(TTreeReader *r, unsigned int slot) final
0133 {
0134 RColumnReadersInfo info{GetColumnNames(), GetColRegister(), fIsDefine.data(), *fLoopManager};
0135
0136
0137 for (const auto &variation : GetVariations())
0138 fInputValues[slot].emplace_back(GetColumnReaders(slot, r, ColumnTypes_t{}, info, variation));
0139
0140 std::for_each(fHelpers.begin(), fHelpers.end(), [=](Helper &h) { h.InitTask(r, slot); });
0141 }
0142
0143 template <typename... ColTypes, std::size_t... S>
0144 void
0145 CallExec(unsigned int slot, unsigned int varIdx, Long64_t entry, TypeList<ColTypes...>, std::index_sequence<S...>)
0146 {
0147 fHelpers[varIdx].Exec(slot, fInputValues[slot][varIdx][S]->template Get<ColTypes>(entry)...);
0148 (void)entry;
0149 }
0150
0151 void Run(unsigned int slot, Long64_t entry) final
0152 {
0153 for (auto varIdx = 0u; varIdx < GetVariations().size(); ++varIdx) {
0154 if (fPrevNodes[varIdx]->CheckFilters(slot, entry))
0155 CallExec(slot, varIdx, entry, ColumnTypes_t{}, TypeInd_t{});
0156 }
0157 }
0158
0159 void TriggerChildrenCount() final
0160 {
0161 std::for_each(fPrevNodes.begin(), fPrevNodes.end(), [](auto &f) { f->IncrChildrenCount(); });
0162 }
0163
0164
0165 void FinalizeSlot(unsigned int slot) final
0166 {
0167 fInputValues[slot].clear();
0168 std::for_each(fHelpers.begin(), fHelpers.end(), [=](Helper &h) { h.CallFinalizeTask(slot); });
0169 }
0170
0171
0172
0173 void Finalize() final
0174 {
0175 std::for_each(fHelpers.begin(), fHelpers.end(), [](Helper &h) { h.Finalize(); });
0176 SetHasRun();
0177 }
0178
0179
0180 void *PartialUpdate(unsigned int slot) final { return PartialUpdateImpl(slot); }
0181
0182
0183 ROOT::RDF::SampleCallback_t GetSampleCallback() final
0184 {
0185 if (fHelpers[0].GetSampleCallback()) {
0186 std::vector<ROOT::RDF::SampleCallback_t> callbacks;
0187 for (auto &h : fHelpers)
0188 callbacks.push_back(h.GetSampleCallback());
0189
0190 auto callEachCallback = [cs = std::move(callbacks)](unsigned int slot, const RSampleInfo &info) {
0191 for (auto &c : cs)
0192 c(slot, info);
0193 };
0194
0195 return callEachCallback;
0196 }
0197
0198 return {};
0199 }
0200
0201 std::shared_ptr<RDFGraphDrawing::GraphNode>
0202 GetGraph(std::unordered_map<void *, std::shared_ptr<RDFGraphDrawing::GraphNode>> &visitedMap) final
0203 {
0204 auto prevNode = fPrevNodes[0]->GetGraph(visitedMap);
0205 const auto &prevColumns = prevNode->GetDefinedColumns();
0206
0207
0208 const auto nodeType = HasRun() ? RDFGraphDrawing::ENodeType::kUsedAction : RDFGraphDrawing::ENodeType::kAction;
0209 auto thisNode = std::make_shared<RDFGraphDrawing::GraphNode>("Varied " + fHelpers[0].GetActionName(),
0210 visitedMap.size(), nodeType);
0211 visitedMap[(void *)this] = thisNode;
0212
0213 auto upmostNode = AddDefinesToGraph(thisNode, GetColRegister(), prevColumns, visitedMap);
0214
0215 thisNode->AddDefinedColumns(GetColRegister().GenerateColumnNames());
0216 upmostNode->SetPrevNode(prevNode);
0217 return thisNode;
0218 }
0219
0220
0221
0222
0223
0224 std::unique_ptr<RMergeableValueBase> GetMergeableValue() const final
0225 {
0226 std::vector<std::string> keys{GetVariations()};
0227
0228 std::vector<std::unique_ptr<RDFDetail::RMergeableValueBase>> values;
0229 values.reserve(fHelpers.size());
0230 for (auto &&h : fHelpers)
0231 values.emplace_back(h.GetMergeableValue());
0232
0233 return std::make_unique<RDFDetail::RMergeableVariationsBase>(std::move(keys), std::move(values));
0234 }
0235
0236 [[noreturn]] std::unique_ptr<RActionBase> MakeVariedAction(std::vector<void *> &&) final
0237 {
0238 throw std::logic_error("Cannot produce a varied action from a varied action.");
0239 }
0240
0241 std::unique_ptr<RActionBase> CloneAction(void *typeErasedResults) final
0242 {
0243 const auto &vectorOfTypeErasedResults = *reinterpret_cast<const std::vector<void *> *>(typeErasedResults);
0244 assert(vectorOfTypeErasedResults.size() == fHelpers.size() &&
0245 "The number of results and the number of helpers are not the same!");
0246
0247 std::vector<Helper> clonedHelpers;
0248 clonedHelpers.reserve(fHelpers.size());
0249 for (std::size_t i = 0; i < fHelpers.size(); i++) {
0250 clonedHelpers.emplace_back(fHelpers[i].CallMakeNew(vectorOfTypeErasedResults[i]));
0251 }
0252
0253 return std::unique_ptr<RVariedAction>(
0254 new RVariedAction(std::move(clonedHelpers), GetColumnNames(), fPrevNodes, GetColRegister()));
0255 }
0256
0257 private:
0258
0259
0260 template <typename H = Helper>
0261 auto PartialUpdateImpl(unsigned int slot) -> decltype(std::declval<H>().PartialUpdate(slot), (void *)(nullptr))
0262 {
0263 return &fHelpers[0].PartialUpdate(slot);
0264 }
0265
0266
0267 void *PartialUpdateImpl(...) { throw std::runtime_error("This action does not support callbacks!"); }
0268 };
0269
0270 }
0271 }
0272 }
0273
0274 #endif