File indexing completed on 2025-12-16 10:29:49
0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011 #ifndef ROOT_RFILTER
0012 #define ROOT_RFILTER
0013
0014 #include "ROOT/RDF/ColumnReaderUtils.hxx"
0015 #include "ROOT/RDF/RColumnReaderBase.hxx"
0016 #include "ROOT/RDF/RCutFlowReport.hxx"
0017 #include "ROOT/RDF/Utils.hxx"
0018 #include "ROOT/RDF/RFilterBase.hxx"
0019 #include "ROOT/RDF/RLoopManager.hxx"
0020 #include "ROOT/TypeTraits.hxx"
0021 #include "RtypesCore.h"
0022
0023 #include <algorithm>
0024 #include <cassert>
0025 #include <memory>
0026 #include <string>
0027 #include <unordered_map>
0028 #include <utility> // std::index_sequence
0029 #include <vector>
0030
0031 namespace ROOT {
0032
0033 namespace Internal {
0034 namespace RDF {
0035 using namespace ROOT::Detail::RDF;
0036
0037
0038 namespace GraphDrawing {
0039 std::shared_ptr<GraphNode>
0040 CreateFilterNode(const RFilterBase *filterPtr, std::unordered_map<void *, std::shared_ptr<GraphNode>> &visitedMap);
0041
0042 std::shared_ptr<GraphNode> AddDefinesToGraph(std::shared_ptr<GraphNode> node, const RColumnRegister &colRegister,
0043 const std::vector<std::string> &prevNodeDefines,
0044 std::unordered_map<void *, std::shared_ptr<GraphNode>> &visitedMap);
0045 }
0046
0047 }
0048 }
0049
0050 namespace Detail {
0051 namespace RDF {
0052 using namespace ROOT::TypeTraits;
0053 namespace RDFGraphDrawing = ROOT::Internal::RDF::GraphDrawing;
0054 class RJittedFilter;
0055
0056 template <typename FilterF, typename PrevNodeRaw>
0057 class R__CLING_PTRCHECK(off) RFilter final : public RFilterBase {
0058 using ColumnTypes_t = typename CallableTraits<FilterF>::arg_types;
0059 using TypeInd_t = std::make_index_sequence<ColumnTypes_t::list_size>;
0060
0061
0062
0063 using PrevNode_t = std::conditional_t<std::is_same<PrevNodeRaw, RJittedFilter>::value, RFilterBase, PrevNodeRaw>;
0064
0065 FilterF fFilter;
0066
0067 std::vector<std::array<RColumnReaderBase *, ColumnTypes_t::list_size>> fValues;
0068 const std::shared_ptr<PrevNode_t> fPrevNodePtr;
0069 PrevNode_t &fPrevNode;
0070
0071 public:
0072 RFilter(FilterF f, const ROOT::RDF::ColumnNames_t &columns, std::shared_ptr<PrevNode_t> pd,
0073 const RDFInternal::RColumnRegister &colRegister, std::string_view name = "",
0074 const std::string &variationName = "nominal")
0075 : RFilterBase(pd->GetLoopManagerUnchecked(), name, pd->GetLoopManagerUnchecked()->GetNSlots(), colRegister,
0076 columns, pd->GetVariations(), variationName),
0077 fFilter(std::move(f)), fValues(pd->GetLoopManagerUnchecked()->GetNSlots()), fPrevNodePtr(std::move(pd)),
0078 fPrevNode(*fPrevNodePtr)
0079 {
0080 fLoopManager->Register(this);
0081 }
0082
0083
0084
0085 RFilter(const RFilter &) = delete;
0086 RFilter &operator=(const RFilter &) = delete;
0087 RFilter(RFilter &&) = delete;
0088 RFilter &operator=(RFilter &&) = delete;
0089 ~RFilter() final
0090 {
0091
0092
0093 fLoopManager->Deregister(this);
0094 }
0095
0096 bool CheckFilters(unsigned int slot, Long64_t entry) final
0097 {
0098 if (entry != fLastCheckedEntry[slot * RDFInternal::CacheLineStep<Long64_t>()]) {
0099 if (!fPrevNode.CheckFilters(slot, entry)) {
0100
0101 fLastResult[slot * RDFInternal::CacheLineStep<int>()] = false;
0102 } else {
0103
0104 auto passed = CheckFilterHelper(slot, entry, ColumnTypes_t{}, TypeInd_t{});
0105 passed ? ++fAccepted[slot * RDFInternal::CacheLineStep<ULong64_t>()]
0106 : ++fRejected[slot * RDFInternal::CacheLineStep<ULong64_t>()];
0107 fLastResult[slot * RDFInternal::CacheLineStep<int>()] = passed;
0108 }
0109 fLastCheckedEntry[slot * RDFInternal::CacheLineStep<Long64_t>()] = entry;
0110 }
0111 return fLastResult[slot * RDFInternal::CacheLineStep<int>()];
0112 }
0113
0114 template <typename ColType>
0115 auto GetValueChecked(unsigned int slot, std::size_t readerIdx, Long64_t entry) -> ColType &
0116 {
0117 if (auto *val = fValues[slot][readerIdx]->template TryGet<ColType>(entry))
0118 return *val;
0119
0120 throw std::out_of_range{"RDataFrame: Filter could not retrieve value for column '" + fColumnNames[readerIdx] +
0121 "' for entry " + std::to_string(entry) +
0122 ". You can use the DefaultValueFor operation to provide a default value, or "
0123 "FilterAvailable/FilterMissing to discard/keep entries with missing values instead."};
0124 }
0125
0126 template <typename... ColTypes, std::size_t... S>
0127 bool CheckFilterHelper(unsigned int slot, Long64_t entry, TypeList<ColTypes...>, std::index_sequence<S...>)
0128 {
0129 return fFilter(GetValueChecked<ColTypes>(slot, S, entry)...);
0130
0131 (void)slot;
0132 (void)entry;
0133 }
0134
0135 void InitSlot(TTreeReader *r, unsigned int slot) final
0136 {
0137 RDFInternal::RColumnReadersInfo info{fColumnNames, fColRegister, fIsDefine.data(), *fLoopManager};
0138 fValues[slot] = RDFInternal::GetColumnReaders(slot, r, ColumnTypes_t{}, info, fVariation);
0139 fLastCheckedEntry[slot * RDFInternal::CacheLineStep<Long64_t>()] = -1;
0140 }
0141
0142
0143 void Report(ROOT::RDF::RCutFlowReport &rep) const final { PartialReport(rep); }
0144
0145 void PartialReport(ROOT::RDF::RCutFlowReport &rep) const final
0146 {
0147 fPrevNode.PartialReport(rep);
0148 FillReport(rep);
0149 }
0150
0151 void StopProcessing() final
0152 {
0153 ++fNStopsReceived;
0154 if (fNStopsReceived == fNChildren)
0155 fPrevNode.StopProcessing();
0156 }
0157
0158 void IncrChildrenCount() final
0159 {
0160 ++fNChildren;
0161
0162 if (fNChildren == 1 && fName.empty())
0163 fPrevNode.IncrChildrenCount();
0164 }
0165
0166 void TriggerChildrenCount() final
0167 {
0168 assert(!fName.empty());
0169 fPrevNode.IncrChildrenCount();
0170 }
0171
0172 void AddFilterName(std::vector<std::string> &filters) final
0173 {
0174 fPrevNode.AddFilterName(filters);
0175 auto name = (HasName() ? fName : "Unnamed Filter");
0176 filters.push_back(name);
0177 }
0178
0179
0180 void FinalizeSlot(unsigned int slot) final { fValues[slot].fill(nullptr); }
0181
0182 std::shared_ptr<RDFGraphDrawing::GraphNode>
0183 GetGraph(std::unordered_map<void *, std::shared_ptr<RDFGraphDrawing::GraphNode>> &visitedMap) final
0184 {
0185
0186 auto prevNode = fPrevNode.GetGraph(visitedMap);
0187 const auto &prevColumns = prevNode->GetDefinedColumns();
0188
0189 auto thisNode = RDFGraphDrawing::CreateFilterNode(this, visitedMap);
0190
0191
0192
0193
0194 if (!thisNode->IsNew()) {
0195 return thisNode;
0196 }
0197
0198 auto upmostNode = AddDefinesToGraph(thisNode, fColRegister, prevColumns, visitedMap);
0199
0200
0201 thisNode->AddDefinedColumns(fColRegister.GenerateColumnNames());
0202
0203 upmostNode->SetPrevNode(prevNode);
0204 return thisNode;
0205 }
0206
0207
0208 std::shared_ptr<RNodeBase> GetVariedFilter(const std::string &variationName) final
0209 {
0210
0211 assert(fVariation == "nominal");
0212
0213
0214 assert(variationName != "nominal");
0215
0216
0217 assert(RDFInternal::IsStrInVec(variationName, fVariations));
0218
0219 auto it = fVariedFilters.find(variationName);
0220 if (it != fVariedFilters.end())
0221 return it->second;
0222
0223 auto prevNode = fPrevNodePtr;
0224 if (static_cast<RNodeBase *>(fPrevNodePtr.get()) != static_cast<RNodeBase *>(fLoopManager) &&
0225 RDFInternal::IsStrInVec(variationName, prevNode->GetVariations()))
0226 prevNode = std::static_pointer_cast<PrevNode_t>(prevNode->GetVariedFilter(variationName));
0227
0228
0229
0230 auto variedFilter = std::unique_ptr<RFilterBase>(
0231 new RFilter(fFilter, fColumnNames, std::move(prevNode), fColRegister, fName, variationName));
0232 auto e = fVariedFilters.insert({variationName, std::move(variedFilter)});
0233 return e.first->second;
0234 }
0235 };
0236
0237 }
0238 }
0239 }
0240
0241 #endif