File indexing completed on 2025-09-16 09:08:24
0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011 #ifndef ROOT_RDF_RINTERFACEBASE
0012 #define ROOT_RDF_RINTERFACEBASE
0013
0014 #include "ROOT/RVec.hxx"
0015 #include <ROOT/RDF/InterfaceUtils.hxx>
0016 #include <ROOT/RDF/RColumnRegister.hxx>
0017 #include <ROOT/RDF/RDisplay.hxx>
0018 #include <ROOT/RDF/RLoopManager.hxx>
0019 #include <ROOT/RDataSource.hxx>
0020 #include <ROOT/RResultPtr.hxx>
0021 #include <string_view>
0022 #include <TError.h> // R__ASSERT
0023
0024 #include <memory>
0025 #include <set>
0026 #include <string>
0027 #include <vector>
0028
0029 namespace ROOT {
0030 namespace RDF {
0031 // Forward declarations: RInterfaceBase below only names these types as return types of Describe() and GetVariations().
0032 class RDFDescription;
0033 class RVariationsDescription;
0034 // A list of column names, stored as plain strings.
0035 using ColumnNames_t = std::vector<std::string>;
0036 // Short-hand aliases for the Detail and Internal RDF namespaces, used throughout the declarations below.
0037 namespace RDFDetail = ROOT::Detail::RDF;
0038 namespace RDFInternal = ROOT::Internal::RDF;
0039
0040
0041
0042
0043
0044
0045
0046
0047
0048
0049
0050
0051
0052 class RInterfaceBase {
0053 protected:
0054
0055 std::shared_ptr<ROOT::Detail::RDF::RLoopManager> fLoopManager;
0056
0057
0058 RDFInternal::RColumnRegister fColRegister;
0059
0060 std::string DescribeDataset() const;
0061
0062 ColumnNames_t GetColumnTypeNamesList(const ColumnNames_t &columnList);
0063
0064 void CheckIMTDisabled(std::string_view callerName);
0065
0066 void AddDefaultColumns();
0067
0068 template <typename RetType>
0069 void SanityChecksForVary(const std::vector<std::string> &colNames, const std::vector<std::string> &variationTags,
0070 std::string_view variationName)
0071 {
0072 R__ASSERT(!variationTags.empty() && "Must have at least one variation.");
0073 R__ASSERT(!colNames.empty() && "Must have at least one varied column.");
0074 R__ASSERT(!variationName.empty() && "Must provide a variation name.");
0075
0076 for (auto &colName : colNames) {
0077 RDFInternal::CheckForDefinition("Vary", colName, fColRegister, fLoopManager->GetBranchNames(),
0078 GetDataSource() ? GetDataSource()->GetColumnNames() : ColumnNames_t{});
0079 }
0080 RDFInternal::CheckValidCppVarName(variationName, "Vary");
0081
0082 static_assert(ROOT::Internal::VecOps::IsRVec<RetType>::value, "Vary expressions must return an RVec.");
0083
0084 if (colNames.size() > 1) {
0085 constexpr bool hasInnerRVec = ROOT::Internal::VecOps::IsRVec<typename RetType::value_type>::value;
0086 if (!hasInnerRVec)
0087 throw std::runtime_error("This Vary call is varying multiple columns simultaneously but the expression "
0088 "does not return an RVec of RVecs.");
0089
0090
0091
0092
0093 auto colTypes = GetColumnTypeNamesList(colNames);
0094 auto &&nColTypes = colTypes.size();
0095
0096 std::vector<const std::type_info *> colTypeIDs(nColTypes);
0097 const auto &innerTypeID = typeid(RDFInternal::InnerValueType_t<RetType>);
0098 for (decltype(nColTypes) i{}; i < nColTypes; ++i) {
0099
0100
0101
0102
0103 const auto *define = fColRegister.GetDefine(colNames[i]);
0104 colTypeIDs[i] = define ? &define->GetTypeId() : &RDFInternal::TypeName2TypeID(colTypes[i]);
0105
0106 if (*colTypeIDs[i] != *colTypeIDs[0]) {
0107 throw std::runtime_error("Cannot simultaneously vary multiple columns of different types.");
0108 }
0109
0110 if (innerTypeID != *colTypeIDs[i])
0111 throw std::runtime_error("Varied values for column \"" + colNames[i] + "\" have a different type (" +
0112 RDFInternal::TypeID2TypeName(innerTypeID) + ") than the nominal value (" +
0113 colTypes[i] + ").");
0114 }
0115
0116 } else {
0117 const auto &retTypeID = typeid(typename RetType::value_type);
0118 const auto &colName = colNames[0];
0119 const auto *define = fColRegister.GetDefine(colName);
0120 const auto *expectedTypeID =
0121 define ? &define->GetTypeId() : &RDFInternal::TypeName2TypeID(GetColumnType(colName));
0122 if (retTypeID != *expectedTypeID)
0123 throw std::runtime_error("Varied values for column \"" + colName + "\" have a different type (" +
0124 RDFInternal::TypeID2TypeName(retTypeID) + ") than the nominal value (" +
0125 GetColumnType(colName) + ").");
0126 }
0127
0128
0129 if (colNames.size() > 1) {
0130 std::set<std::string> uniqueCols(colNames.begin(), colNames.end());
0131 if (uniqueCols.size() != colNames.size())
0132 throw std::logic_error("A column name was passed to the same Vary invocation multiple times.");
0133 }
0134 }
0135
0136 RDFDetail::RLoopManager *GetLoopManager() const { return fLoopManager.get(); }
0137 RDataSource *GetDataSource() const { return fLoopManager->GetDataSource(); }
0138
0139 ColumnNames_t GetValidatedColumnNames(const unsigned int nColumns, const ColumnNames_t &columns)
0140 {
0141 return RDFInternal::GetValidatedColumnNames(*fLoopManager, nColumns, columns, fColRegister, GetDataSource());
0142 }
0143
0144 template <typename... ColumnTypes>
0145 void CheckAndFillDSColumns(ColumnNames_t validCols, TTraits::TypeList<ColumnTypes...> typeList)
0146 {
0147 if (auto dataSource = GetDataSource())
0148 RDFInternal::AddDSColumns(validCols, *fLoopManager, *dataSource, typeList, fColRegister);
0149 }
0150
0151
0152
0153
0154
0155
0156 template <typename ActionTag, typename... ColTypes, typename ActionResultType, typename RDFNode,
0157 typename HelperArgType = ActionResultType,
0158 std::enable_if_t<!RDFInternal::RNeedJitting<ColTypes...>::value, int> = 0>
0159 RResultPtr<ActionResultType> CreateAction(const ColumnNames_t &columns, const std::shared_ptr<ActionResultType> &r,
0160 const std::shared_ptr<HelperArgType> &helperArg,
0161 const std::shared_ptr<RDFNode> &proxiedPtr, const int = -1)
0162 {
0163 constexpr auto nColumns = sizeof...(ColTypes);
0164
0165 const auto validColumnNames = GetValidatedColumnNames(nColumns, columns);
0166 CheckAndFillDSColumns(validColumnNames, RDFInternal::TypeList<ColTypes...>());
0167
0168 const auto nSlots = fLoopManager->GetNSlots();
0169
0170 auto action = RDFInternal::BuildAction<ColTypes...>(validColumnNames, helperArg, nSlots, proxiedPtr, ActionTag{},
0171 fColRegister);
0172 return MakeResultPtr(r, *fLoopManager, std::move(action));
0173 }
0174
0175
0176
0177
0178
0179 template <typename ActionTag, typename... ColTypes, typename ActionResultType, typename RDFNode,
0180 typename HelperArgType = ActionResultType,
0181 std::enable_if_t<RDFInternal::RNeedJitting<ColTypes...>::value, int> = 0>
0182 RResultPtr<ActionResultType>
0183 CreateAction(const ColumnNames_t &columns, const std::shared_ptr<ActionResultType> &r,
0184 const std::shared_ptr<HelperArgType> &helperArg, const std::shared_ptr<RDFNode> &proxiedPtr,
0185 const int nColumns = -1, const bool vector2RVec = true)
0186 {
0187 auto realNColumns = (nColumns > -1 ? nColumns : sizeof...(ColTypes));
0188
0189 const auto validColumnNames = GetValidatedColumnNames(realNColumns, columns);
0190 const unsigned int nSlots = fLoopManager->GetNSlots();
0191
0192 auto *tree = fLoopManager->GetTree();
0193 auto *helperArgOnHeap = RDFInternal::MakeSharedOnHeap(helperArg);
0194
0195 auto upcastNodeOnHeap = RDFInternal::MakeSharedOnHeap(RDFInternal::UpcastNode(proxiedPtr));
0196
0197 const auto jittedAction = std::make_shared<RDFInternal::RJittedAction>(*fLoopManager, validColumnNames,
0198 fColRegister, proxiedPtr->GetVariations());
0199 auto jittedActionOnHeap = RDFInternal::MakeWeakOnHeap(jittedAction);
0200
0201 auto toJit = RDFInternal::JitBuildAction(validColumnNames, upcastNodeOnHeap, typeid(HelperArgType),
0202 typeid(ActionTag), helperArgOnHeap, tree, nSlots, fColRegister,
0203 GetDataSource(), jittedActionOnHeap, vector2RVec);
0204 fLoopManager->ToJitExec(toJit);
0205 return MakeResultPtr(r, *fLoopManager, std::move(jittedAction));
0206 }
0207
0208 public:
0209 RInterfaceBase(std::shared_ptr<RDFDetail::RLoopManager> lm);
0210 RInterfaceBase(RDFDetail::RLoopManager &lm, const RDFInternal::RColumnRegister &colRegister);
0211
0212 ColumnNames_t GetColumnNames();
0213
0214 std::string GetColumnType(std::string_view column);
0215
0216 RDFDescription Describe();
0217
0218 RVariationsDescription GetVariations() const;
0219 bool HasColumn(std::string_view columnName);
0220 ColumnNames_t GetDefinedColumnNames();
0221 unsigned int GetNSlots() const;
0222 unsigned int GetNRuns() const;
0223 unsigned int GetNFiles();
0224 };
0225 }
0226 }
0227
0228 #endif