File indexing completed on 2024-11-15 09:55:43
0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011 #ifndef ROOT_RDF_RINTERFACEBASE
0012 #define ROOT_RDF_RINTERFACEBASE
0013
0014 #include "ROOT/RVec.hxx"
0015 #include <ROOT/RDF/InterfaceUtils.hxx>
0016 #include <ROOT/RDF/RColumnRegister.hxx>
0017 #include <ROOT/RDF/RDisplay.hxx>
0018 #include <ROOT/RDF/RLoopManager.hxx>
0019 #include <ROOT/RDataSource.hxx>
0020 #include <ROOT/RResultPtr.hxx>
0021 #include <ROOT/RStringView.hxx>
0022 #include <TError.h> // R__ASSERT
0023
#include <algorithm> // std::all_of
#include <memory>
#include <set>
#include <stdexcept> // std::runtime_error, std::logic_error
#include <string>
#include <typeinfo> // typeid
#include <vector>
0028
0029 namespace ROOT {
0030 namespace RDF {
0031
0032 class RDFDescription;
0033 class RVariationsDescription;
0034
0035 using ColumnNames_t = std::vector<std::string>;
0036
0037 namespace RDFDetail = ROOT::Detail::RDF;
0038 namespace RDFInternal = ROOT::Internal::RDF;
0039
0040
0041
0042
0043
0044
0045
0046
0047
0048
0049
0050
0051
0052 class RInterfaceBase {
0053 protected:
0054
0055 RDFDetail::RLoopManager *fLoopManager;
0056
0057 RDataSource *fDataSource = nullptr;
0058
0059
0060 RDFInternal::RColumnRegister fColRegister;
0061
0062 std::string DescribeDataset() const;
0063
0064 ColumnNames_t GetColumnTypeNamesList(const ColumnNames_t &columnList);
0065
0066 void CheckIMTDisabled(std::string_view callerName);
0067
0068 void AddDefaultColumns();
0069
0070 template <typename RetType>
0071 void SanityChecksForVary(const std::vector<std::string> &colNames, const std::vector<std::string> &variationTags,
0072 std::string_view variationName)
0073 {
0074 R__ASSERT(variationTags.size() > 0 && "Must have at least one variation.");
0075 R__ASSERT(colNames.size() > 0 && "Must have at least one varied column.");
0076 R__ASSERT(!variationName.empty() && "Must provide a variation name.");
0077
0078 for (auto &colName : colNames) {
0079 RDFInternal::CheckForDefinition("Vary", colName, fColRegister, fLoopManager->GetBranchNames(),
0080 fDataSource ? fDataSource->GetColumnNames() : ColumnNames_t{});
0081 }
0082 RDFInternal::CheckValidCppVarName(variationName, "Vary");
0083
0084 static_assert(ROOT::Internal::VecOps::IsRVec<RetType>::value, "Vary expressions must return an RVec.");
0085
0086 if (colNames.size() > 1) {
0087 constexpr bool hasInnerRVec = ROOT::Internal::VecOps::IsRVec<typename RetType::value_type>::value;
0088 if (!hasInnerRVec)
0089 throw std::runtime_error("This Vary call is varying multiple columns simultaneously but the expression "
0090 "does not return an RVec of RVecs.");
0091
0092 auto colTypes = GetColumnTypeNamesList(colNames);
0093 auto allColTypesEqual =
0094 std::all_of(colTypes.begin() + 1, colTypes.end(), [&](const std::string &t) { return t == colTypes[0]; });
0095 if (!allColTypesEqual)
0096 throw std::runtime_error("Cannot simultaneously vary multiple columns of different types.");
0097
0098 const auto &innerTypeID = typeid(RDFInternal::InnerValueType_t<RetType>);
0099
0100 for (auto i = 0u; i < colTypes.size(); ++i) {
0101 const auto *define = fColRegister.GetDefine(colNames[i]);
0102 const auto *expectedTypeID = define ? &define->GetTypeId() : &RDFInternal::TypeName2TypeID(colTypes[i]);
0103 if (innerTypeID != *expectedTypeID)
0104 throw std::runtime_error("Varied values for column \"" + colNames[i] + "\" have a different type (" +
0105 RDFInternal::TypeID2TypeName(innerTypeID) + ") than the nominal value (" +
0106 colTypes[i] + ").");
0107 }
0108 } else {
0109 const auto &retTypeID = typeid(typename RetType::value_type);
0110 const auto &colName = colNames[0];
0111 const auto *define = fColRegister.GetDefine(colName);
0112 const auto *expectedTypeID =
0113 define ? &define->GetTypeId() : &RDFInternal::TypeName2TypeID(GetColumnType(colName));
0114 if (retTypeID != *expectedTypeID)
0115 throw std::runtime_error("Varied values for column \"" + colName + "\" have a different type (" +
0116 RDFInternal::TypeID2TypeName(retTypeID) + ") than the nominal value (" +
0117 GetColumnType(colName) + ").");
0118 }
0119
0120
0121 if (colNames.size() > 1) {
0122 std::set<std::string> uniqueCols(colNames.begin(), colNames.end());
0123 if (uniqueCols.size() != colNames.size())
0124 throw std::logic_error("A column name was passed to the same Vary invocation multiple times.");
0125 }
0126 }
0127
0128 RDFDetail::RLoopManager *GetLoopManager() const { return fLoopManager; }
0129
0130 ColumnNames_t GetValidatedColumnNames(const unsigned int nColumns, const ColumnNames_t &columns)
0131 {
0132 return RDFInternal::GetValidatedColumnNames(*fLoopManager, nColumns, columns, fColRegister, fDataSource);
0133 }
0134
0135 template <typename... ColumnTypes>
0136 void CheckAndFillDSColumns(ColumnNames_t validCols, TTraits::TypeList<ColumnTypes...> typeList)
0137 {
0138 if (fDataSource != nullptr)
0139 RDFInternal::AddDSColumns(validCols, *fLoopManager, *fDataSource, typeList, fColRegister);
0140 }
0141
0142
0143
0144
0145
0146
0147 template <typename ActionTag, typename... ColTypes, typename ActionResultType, typename RDFNode,
0148 typename HelperArgType = ActionResultType,
0149 std::enable_if_t<!RDFInternal::RNeedJitting<ColTypes...>::value, int> = 0>
0150 RResultPtr<ActionResultType> CreateAction(const ColumnNames_t &columns, const std::shared_ptr<ActionResultType> &r,
0151 const std::shared_ptr<HelperArgType> &helperArg,
0152 const std::shared_ptr<RDFNode> &proxiedPtr, const int = -1)
0153 {
0154 constexpr auto nColumns = sizeof...(ColTypes);
0155
0156 const auto validColumnNames = GetValidatedColumnNames(nColumns, columns);
0157 CheckAndFillDSColumns(validColumnNames, RDFInternal::TypeList<ColTypes...>());
0158
0159 const auto nSlots = fLoopManager->GetNSlots();
0160
0161 auto action = RDFInternal::BuildAction<ColTypes...>(validColumnNames, helperArg, nSlots, proxiedPtr, ActionTag{},
0162 fColRegister);
0163 return MakeResultPtr(r, *fLoopManager, std::move(action));
0164 }
0165
0166
0167
0168
0169
0170 template <typename ActionTag, typename... ColTypes, typename ActionResultType, typename RDFNode,
0171 typename HelperArgType = ActionResultType,
0172 std::enable_if_t<RDFInternal::RNeedJitting<ColTypes...>::value, int> = 0>
0173 RResultPtr<ActionResultType> CreateAction(const ColumnNames_t &columns, const std::shared_ptr<ActionResultType> &r,
0174 const std::shared_ptr<HelperArgType> &helperArg,
0175 const std::shared_ptr<RDFNode> &proxiedPtr, const int nColumns = -1)
0176 {
0177 auto realNColumns = (nColumns > -1 ? nColumns : sizeof...(ColTypes));
0178
0179 const auto validColumnNames = GetValidatedColumnNames(realNColumns, columns);
0180 const unsigned int nSlots = fLoopManager->GetNSlots();
0181
0182 auto *tree = fLoopManager->GetTree();
0183 auto *helperArgOnHeap = RDFInternal::MakeSharedOnHeap(helperArg);
0184
0185 auto upcastNodeOnHeap = RDFInternal::MakeSharedOnHeap(RDFInternal::UpcastNode(proxiedPtr));
0186
0187 const auto jittedAction = std::make_shared<RDFInternal::RJittedAction>(*fLoopManager, validColumnNames,
0188 fColRegister, proxiedPtr->GetVariations());
0189 auto jittedActionOnHeap = RDFInternal::MakeWeakOnHeap(jittedAction);
0190
0191 auto toJit =
0192 RDFInternal::JitBuildAction(validColumnNames, upcastNodeOnHeap, typeid(HelperArgType), typeid(ActionTag),
0193 helperArgOnHeap, tree, nSlots, fColRegister, fDataSource, jittedActionOnHeap);
0194 fLoopManager->ToJitExec(toJit);
0195 return MakeResultPtr(r, *fLoopManager, std::move(jittedAction));
0196 }
0197
0198 public:
0199 RInterfaceBase(std::shared_ptr<RDFDetail::RLoopManager> lm);
0200 RInterfaceBase(RDFDetail::RLoopManager &lm, const RDFInternal::RColumnRegister &colRegister);
0201
0202 ColumnNames_t GetColumnNames();
0203
0204 std::string GetColumnType(std::string_view column);
0205
0206 RDFDescription Describe();
0207
0208 RVariationsDescription GetVariations() const;
0209 bool HasColumn(std::string_view columnName);
0210 ColumnNames_t GetDefinedColumnNames();
0211 unsigned int GetNSlots() const;
0212 unsigned int GetNRuns() const;
0213 unsigned int GetNFiles();
0214 };
0215 }
0216 }
0217
0218 #endif