File indexing completed on 2025-01-18 10:10:49
0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
#include <ROOT/RDataFrame.hxx>
#include <ROOT/RDataSource.hxx>
#include <ROOT/RVec.hxx>
#include <ROOT/TSeq.hxx>

#include <algorithm>
#include <functional>
#include <initializer_list>
#include <iterator>
#include <map>
#include <memory>
#include <stdexcept>
#include <string>
#include <tuple>
#include <typeinfo>
#include <utility>
#include <vector>
0025
0026 #ifndef ROOT_RVECDS
0027 #define ROOT_RVECDS
0028
0029 namespace ROOT {
0030
0031 namespace Internal {
0032
0033 namespace RDF {
0034
0035
0036
0037
0038
0039
0040
0041
0042
0043
0044 template <typename... ColumnTypes>
0045 class RVecDS final : public ROOT::RDF::RDataSource {
0046 using PointerHolderPtrs_t = std::vector<ROOT::Internal::TDS::TPointerHolder *>;
0047
0048 std::tuple<ROOT::RVec<ColumnTypes>...> fColumns;
0049 const std::vector<std::string> fColNames;
0050 const std::map<std::string, std::string> fColTypesMap;
0051
0052
0053
0054
0055 const PointerHolderPtrs_t fPointerHoldersModels;
0056 std::vector<PointerHolderPtrs_t> fPointerHolders;
0057 std::vector<std::pair<ULong64_t, ULong64_t>> fEntryRanges{};
0058 unsigned int fNSlots{0};
0059 std::function<void()> fDeleteRVecs;
0060
0061 Record_t GetColumnReadersImpl(std::string_view colName, const std::type_info &id)
0062 {
0063 auto colNameStr = std::string(colName);
0064
0065 const auto idName = ROOT::Internal::RDF::TypeID2TypeName(id);
0066 auto it = fColTypesMap.find(colNameStr);
0067 if (fColTypesMap.end() == it) {
0068 std::string err = "The specified column name, \"" + colNameStr + "\" is not known to the data source.";
0069 throw std::runtime_error(err);
0070 }
0071
0072 const auto colIdName = it->second;
0073 if (colIdName != idName) {
0074 std::string err = "Column " + colNameStr + " has type " + colIdName +
0075 " while the id specified is associated to type " + idName;
0076 throw std::runtime_error(err);
0077 }
0078
0079 const auto colBegin = fColNames.begin();
0080 const auto colEnd = fColNames.end();
0081 const auto namesIt = std::find(colBegin, colEnd, colName);
0082 const auto index = std::distance(colBegin, namesIt);
0083
0084 Record_t ret(fNSlots);
0085 for (auto slot : ROOT::TSeqU(fNSlots)) {
0086 ret[slot] = fPointerHolders[index][slot]->GetPointerAddr();
0087 }
0088 return ret;
0089 }
0090
0091 size_t GetEntriesNumber() { return std::get<0>(fColumns).size(); }
0092 template <std::size_t... S>
0093 void SetEntryHelper(unsigned int slot, ULong64_t entry, std::index_sequence<S...>)
0094 {
0095 std::initializer_list<int> expander{
0096 (*static_cast<ColumnTypes *>(fPointerHolders[S][slot]->GetPointer()) = std::get<S>(fColumns)[entry], 0)...};
0097 (void)expander;
0098 }
0099
0100 template <std::size_t... S>
0101 void ColLengthChecker(std::index_sequence<S...>)
0102 {
0103 if (sizeof...(S) < 2)
0104 return;
0105
0106 const std::vector<size_t> colLengths{std::get<S>(fColumns).size()...};
0107 const auto expectedLen = colLengths[0];
0108 std::string err;
0109 for (auto i : TSeqI(1, colLengths.size())) {
0110 if (expectedLen != colLengths[i]) {
0111 err += "Column \"" + fColNames[i] + "\" and column \"" + fColNames[0] +
0112 "\" have different lengths: " + std::to_string(expectedLen) + " and " +
0113 std::to_string(colLengths[i]);
0114 }
0115 }
0116 if (!err.empty()) {
0117 throw std::runtime_error(err);
0118 }
0119 }
0120
0121 protected:
0122 std::string AsString() { return "Numpy data source"; };
0123
0124 public:
0125 RVecDS(std::function<void()> deleteRVecs, std::pair<std::string, ROOT::RVec<ColumnTypes>> const &...colsNameVals)
0126 : fColumns(colsNameVals.second...),
0127 fColNames{colsNameVals.first...},
0128 fColTypesMap({{colsNameVals.first, ROOT::Internal::RDF::TypeID2TypeName(typeid(ColumnTypes))}...}),
0129 fPointerHoldersModels({new ROOT::Internal::TDS::TTypedPointerHolder<ColumnTypes>(new ColumnTypes())...}),
0130 fDeleteRVecs(deleteRVecs)
0131 {
0132 }
0133
0134 ~RVecDS()
0135 {
0136 for (auto &&ptrHolderv : fPointerHolders) {
0137 for (auto &&ptrHolder : ptrHolderv) {
0138 delete ptrHolder;
0139 }
0140 }
0141
0142 fDeleteRVecs();
0143 }
0144
0145 const std::vector<std::string> &GetColumnNames() const { return fColNames; }
0146
0147 std::vector<std::pair<ULong64_t, ULong64_t>> GetEntryRanges()
0148 {
0149 auto entryRanges(std::move(fEntryRanges));
0150 return entryRanges;
0151 }
0152
0153 std::string GetTypeName(std::string_view colName) const
0154 {
0155 const auto key = std::string(colName);
0156 return fColTypesMap.at(key);
0157 }
0158
0159 bool HasColumn(std::string_view colName) const
0160 {
0161 const auto key = std::string(colName);
0162 const auto endIt = fColTypesMap.end();
0163 return endIt != fColTypesMap.find(key);
0164 }
0165
0166 bool SetEntry(unsigned int slot, ULong64_t entry)
0167 {
0168 SetEntryHelper(slot, entry, std::index_sequence_for<ColumnTypes...>());
0169 return true;
0170 }
0171
0172 void SetNSlots(unsigned int nSlots)
0173 {
0174 fNSlots = nSlots;
0175 const auto nCols = fColNames.size();
0176 fPointerHolders.resize(nCols);
0177 auto colIndex = 0U;
0178 for (auto &&ptrHolderv : fPointerHolders) {
0179 for (auto slot : ROOT::TSeqI(fNSlots)) {
0180 auto ptrHolder = fPointerHoldersModels[colIndex]->GetDeepCopy();
0181 ptrHolderv.emplace_back(ptrHolder);
0182 (void)slot;
0183 }
0184 colIndex++;
0185 }
0186 for (auto &&ptrHolder : fPointerHoldersModels)
0187 delete ptrHolder;
0188 }
0189
0190 void Initialize()
0191 {
0192 ColLengthChecker(std::index_sequence_for<ColumnTypes...>());
0193 const auto nEntries = GetEntriesNumber();
0194 const auto nEntriesInRange = nEntries / fNSlots;
0195 auto reminder = 1U == fNSlots ? 0 : nEntries % fNSlots;
0196 fEntryRanges.resize(fNSlots);
0197 auto init = 0ULL;
0198 auto end = 0ULL;
0199 for (auto &&range : fEntryRanges) {
0200 end = init + nEntriesInRange;
0201 if (0 != reminder) {
0202 reminder--;
0203 end += 1;
0204 }
0205 range.first = init;
0206 range.second = end;
0207 init = end;
0208 }
0209 }
0210
0211 std::string GetLabel() { return "RVecDS"; }
0212 };
0213
0214
0215
0216
0217
0218 template <typename... ColumnTypes>
0219 std::unique_ptr<RDataFrame>
0220 MakeRVecDataFrame(std::function<void()> deleteRVecs,
0221 std::pair<std::string, ROOT::RVec<ColumnTypes>> const &...colNameProxyPairs)
0222 {
0223 return std::make_unique<RDataFrame>(std::make_unique<RVecDS<ColumnTypes...>>(deleteRVecs, colNameProxyPairs...));
0224 }
0225
0226 }
0227 }
0228 }
0229
0230 #endif