Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-09-16 09:08:24

0001 // Author: Enrico Guiraud CERN 08/2022
0002 
0003 /*************************************************************************
0004  * Copyright (C) 1995-2022, Rene Brun and Fons Rademakers.               *
0005  * All rights reserved.                                                  *
0006  *                                                                       *
0007  * For the licensing terms see $ROOTSYS/LICENSE.                         *
0008  * For the list of contributors see $ROOTSYS/README/CREDITS.             *
0009  *************************************************************************/
0010 
0011 #ifndef ROOT_RDF_RINTERFACEBASE
0012 #define ROOT_RDF_RINTERFACEBASE
0013 
0014 #include "ROOT/RVec.hxx"
0015 #include <ROOT/RDF/InterfaceUtils.hxx>
0016 #include <ROOT/RDF/RColumnRegister.hxx>
0017 #include <ROOT/RDF/RDisplay.hxx>
0018 #include <ROOT/RDF/RLoopManager.hxx>
0019 #include <ROOT/RDataSource.hxx>
0020 #include <ROOT/RResultPtr.hxx>
0021 #include <string_view>
0022 #include <TError.h> // R__ASSERT
0023 
0024 #include <memory>
0025 #include <set>
0026 #include <string>
0027 #include <vector>
0028 
0029 namespace ROOT {
0030 namespace RDF {
0031 
0032 class RDFDescription;
0033 class RVariationsDescription;
0034 
0035 using ColumnNames_t = std::vector<std::string>;
0036 
0037 namespace RDFDetail = ROOT::Detail::RDF;
0038 namespace RDFInternal = ROOT::Internal::RDF;
0039 
0040 // clang-format off
0041 /**
0042  * \class ROOT::Internal::RDF::RInterfaceBase
0043  * \ingroup dataframe
0044  * \brief The public interface to the RDataFrame federation of classes.
0045  * \tparam Proxied One of the "node" base types (e.g. RLoopManager, RFilterBase). The user never specifies this type manually.
0046  * \tparam DataSource The type of the RDataSource which is providing the data to the data frame. There is no source by default.
0047  *
0048  * The documentation of each method features a one liner illustrating how to use the method, for example showing how
0049  * the majority of the template parameters are automatically deduced requiring no or very little effort by the user.
0050  */
0051 // clang-format on
0052 class RInterfaceBase {
0053 protected:
0054    ///< The RLoopManager at the root of this computation graph. Never null.
0055    std::shared_ptr<ROOT::Detail::RDF::RLoopManager> fLoopManager;
0056 
0057    /// Contains the columns defined up to this node.
0058    RDFInternal::RColumnRegister fColRegister;
0059 
0060    std::string DescribeDataset() const;
0061 
0062    ColumnNames_t GetColumnTypeNamesList(const ColumnNames_t &columnList);
0063 
0064    void CheckIMTDisabled(std::string_view callerName);
0065 
0066    void AddDefaultColumns();
0067 
0068    template <typename RetType>
0069    void SanityChecksForVary(const std::vector<std::string> &colNames, const std::vector<std::string> &variationTags,
0070                             std::string_view variationName)
0071    {
0072       R__ASSERT(!variationTags.empty() && "Must have at least one variation.");
0073       R__ASSERT(!colNames.empty() && "Must have at least one varied column.");
0074       R__ASSERT(!variationName.empty() && "Must provide a variation name.");
0075 
0076       for (auto &colName : colNames) {
0077          RDFInternal::CheckForDefinition("Vary", colName, fColRegister, fLoopManager->GetBranchNames(),
0078                                          GetDataSource() ? GetDataSource()->GetColumnNames() : ColumnNames_t{});
0079       }
0080       RDFInternal::CheckValidCppVarName(variationName, "Vary");
0081 
0082       static_assert(ROOT::Internal::VecOps::IsRVec<RetType>::value, "Vary expressions must return an RVec.");
0083 
0084       if (colNames.size() > 1) { // we are varying multiple columns simultaneously, RetType is RVec<RVec<T>>
0085          constexpr bool hasInnerRVec = ROOT::Internal::VecOps::IsRVec<typename RetType::value_type>::value;
0086          if (!hasInnerRVec)
0087             throw std::runtime_error("This Vary call is varying multiple columns simultaneously but the expression "
0088                                      "does not return an RVec of RVecs.");
0089 
0090          // Check for type mismatches. We are interested in two cases:
0091          // - All columns that are going to be varied must be of the same type
0092          // - The return type of the expression must match the type of the nominal column
0093          auto colTypes = GetColumnTypeNamesList(colNames);
0094          auto &&nColTypes = colTypes.size();
0095          // Cache type_info when requested
0096          std::vector<const std::type_info *> colTypeIDs(nColTypes);
0097          const auto &innerTypeID = typeid(RDFInternal::InnerValueType_t<RetType>);
0098          for (decltype(nColTypes) i{}; i < nColTypes; ++i) {
0099             // Need to retrieve the type_info for each column. We start with
0100             // checking if the column comes from a Define, in which case the
0101             // type_info is cached already. Otherwise, we need to retrieve it
0102             // via TypeName2TypeID, which eventually might call the interpreter.
0103             const auto *define = fColRegister.GetDefine(colNames[i]);
0104             colTypeIDs[i] = define ? &define->GetTypeId() : &RDFInternal::TypeName2TypeID(colTypes[i]);
0105             // First check: whether the current column type is the same as the first one.
0106             if (*colTypeIDs[i] != *colTypeIDs[0]) {
0107                throw std::runtime_error("Cannot simultaneously vary multiple columns of different types.");
0108             }
0109             // Second check: mismatch between varied type and nominal type
0110             if (innerTypeID != *colTypeIDs[i])
0111                throw std::runtime_error("Varied values for column \"" + colNames[i] + "\" have a different type (" +
0112                                         RDFInternal::TypeID2TypeName(innerTypeID) + ") than the nominal value (" +
0113                                         colTypes[i] + ").");
0114          }
0115 
0116       } else { // we are varying a single column, RetType is RVec<T>
0117          const auto &retTypeID = typeid(typename RetType::value_type);
0118          const auto &colName = colNames[0]; // we have only one element in there
0119          const auto *define = fColRegister.GetDefine(colName);
0120          const auto *expectedTypeID =
0121             define ? &define->GetTypeId() : &RDFInternal::TypeName2TypeID(GetColumnType(colName));
0122          if (retTypeID != *expectedTypeID)
0123             throw std::runtime_error("Varied values for column \"" + colName + "\" have a different type (" +
0124                                      RDFInternal::TypeID2TypeName(retTypeID) + ") than the nominal value (" +
0125                                      GetColumnType(colName) + ").");
0126       }
0127 
0128       // when varying multiple columns, they must be different columns
0129       if (colNames.size() > 1) {
0130          std::set<std::string> uniqueCols(colNames.begin(), colNames.end());
0131          if (uniqueCols.size() != colNames.size())
0132             throw std::logic_error("A column name was passed to the same Vary invocation multiple times.");
0133       }
0134    }
0135 
0136    RDFDetail::RLoopManager *GetLoopManager() const { return fLoopManager.get(); }
0137    RDataSource *GetDataSource() const { return fLoopManager->GetDataSource(); }
0138 
0139    ColumnNames_t GetValidatedColumnNames(const unsigned int nColumns, const ColumnNames_t &columns)
0140    {
0141       return RDFInternal::GetValidatedColumnNames(*fLoopManager, nColumns, columns, fColRegister, GetDataSource());
0142    }
0143 
0144    template <typename... ColumnTypes>
0145    void CheckAndFillDSColumns(ColumnNames_t validCols, TTraits::TypeList<ColumnTypes...> typeList)
0146    {
0147       if (auto dataSource = GetDataSource())
0148          RDFInternal::AddDSColumns(validCols, *fLoopManager, *dataSource, typeList, fColRegister);
0149    }
0150 
0151    /// Create RAction object, return RResultPtr for the action
0152    /// Overload for the case in which all column types were specified (no jitting).
0153    /// For most actions, `r` and `helperArg` will refer to the same object, because the only argument to forward to
0154    /// the action helper is the result value itself. We need the distinction for actions such as Snapshot or Cache,
0155    /// for which the constructor arguments of the action helper are different from the returned value.
0156    template <typename ActionTag, typename... ColTypes, typename ActionResultType, typename RDFNode,
0157              typename HelperArgType = ActionResultType,
0158              std::enable_if_t<!RDFInternal::RNeedJitting<ColTypes...>::value, int> = 0>
0159    RResultPtr<ActionResultType> CreateAction(const ColumnNames_t &columns, const std::shared_ptr<ActionResultType> &r,
0160                                              const std::shared_ptr<HelperArgType> &helperArg,
0161                                              const std::shared_ptr<RDFNode> &proxiedPtr, const int /*nColumns*/ = -1)
0162    {
0163       constexpr auto nColumns = sizeof...(ColTypes);
0164 
0165       const auto validColumnNames = GetValidatedColumnNames(nColumns, columns);
0166       CheckAndFillDSColumns(validColumnNames, RDFInternal::TypeList<ColTypes...>());
0167 
0168       const auto nSlots = fLoopManager->GetNSlots();
0169 
0170       auto action = RDFInternal::BuildAction<ColTypes...>(validColumnNames, helperArg, nSlots, proxiedPtr, ActionTag{},
0171                                                           fColRegister);
0172       return MakeResultPtr(r, *fLoopManager, std::move(action));
0173    }
0174 
0175    /// Create RAction object, return RResultPtr for the action
0176    /// Overload for the case in which one or more column types were not specified (RTTI + jitting).
0177    /// This overload has a `nColumns` optional argument. If present, the number of required columns for
0178    /// this action is taken equal to nColumns, otherwise it is assumed to be sizeof...(ColTypes).
0179    template <typename ActionTag, typename... ColTypes, typename ActionResultType, typename RDFNode,
0180              typename HelperArgType = ActionResultType,
0181              std::enable_if_t<RDFInternal::RNeedJitting<ColTypes...>::value, int> = 0>
0182    RResultPtr<ActionResultType>
0183    CreateAction(const ColumnNames_t &columns, const std::shared_ptr<ActionResultType> &r,
0184                 const std::shared_ptr<HelperArgType> &helperArg, const std::shared_ptr<RDFNode> &proxiedPtr,
0185                 const int nColumns = -1, const bool vector2RVec = true)
0186    {
0187       auto realNColumns = (nColumns > -1 ? nColumns : sizeof...(ColTypes));
0188 
0189       const auto validColumnNames = GetValidatedColumnNames(realNColumns, columns);
0190       const unsigned int nSlots = fLoopManager->GetNSlots();
0191 
0192       auto *tree = fLoopManager->GetTree();
0193       auto *helperArgOnHeap = RDFInternal::MakeSharedOnHeap(helperArg);
0194 
0195       auto upcastNodeOnHeap = RDFInternal::MakeSharedOnHeap(RDFInternal::UpcastNode(proxiedPtr));
0196 
0197       const auto jittedAction = std::make_shared<RDFInternal::RJittedAction>(*fLoopManager, validColumnNames,
0198                                                                              fColRegister, proxiedPtr->GetVariations());
0199       auto jittedActionOnHeap = RDFInternal::MakeWeakOnHeap(jittedAction);
0200 
0201       auto toJit = RDFInternal::JitBuildAction(validColumnNames, upcastNodeOnHeap, typeid(HelperArgType),
0202                                                typeid(ActionTag), helperArgOnHeap, tree, nSlots, fColRegister,
0203                                                GetDataSource(), jittedActionOnHeap, vector2RVec);
0204       fLoopManager->ToJitExec(toJit);
0205       return MakeResultPtr(r, *fLoopManager, std::move(jittedAction));
0206    }
0207 
0208 public:
0209    RInterfaceBase(std::shared_ptr<RDFDetail::RLoopManager> lm);
0210    RInterfaceBase(RDFDetail::RLoopManager &lm, const RDFInternal::RColumnRegister &colRegister);
0211 
0212    ColumnNames_t GetColumnNames();
0213 
0214    std::string GetColumnType(std::string_view column);
0215 
0216    RDFDescription Describe();
0217 
0218    RVariationsDescription GetVariations() const;
0219    bool HasColumn(std::string_view columnName);
0220    ColumnNames_t GetDefinedColumnNames();
0221    unsigned int GetNSlots() const;
0222    unsigned int GetNRuns() const;
0223    unsigned int GetNFiles();
0224 };
0225 } // namespace RDF
0226 } // namespace ROOT
0227 
0228 #endif