Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2024-11-15 09:55:43

0001 // Author: Enrico Guiraud CERN 08/2022
0002 
0003 /*************************************************************************
0004  * Copyright (C) 1995-2022, Rene Brun and Fons Rademakers.               *
0005  * All rights reserved.                                                  *
0006  *                                                                       *
0007  * For the licensing terms see $ROOTSYS/LICENSE.                         *
0008  * For the list of contributors see $ROOTSYS/README/CREDITS.             *
0009  *************************************************************************/
0010 
0011 #ifndef ROOT_RDF_RINTERFACEBASE
0012 #define ROOT_RDF_RINTERFACEBASE
0013 
#include "ROOT/RVec.hxx"
#include <ROOT/RDF/InterfaceUtils.hxx>
#include <ROOT/RDF/RColumnRegister.hxx>
#include <ROOT/RDF/RDisplay.hxx>
#include <ROOT/RDF/RLoopManager.hxx>
#include <ROOT/RDataSource.hxx>
#include <ROOT/RResultPtr.hxx>
#include <ROOT/RStringView.hxx>
#include <TError.h> // R__ASSERT

#include <algorithm> // std::all_of
#include <memory>
#include <set>
#include <stdexcept> // std::runtime_error, std::logic_error
#include <string>
#include <typeinfo> // typeid
#include <vector>
0028 
0029 namespace ROOT {
0030 namespace RDF {
0031 
0032 class RDFDescription;
0033 class RVariationsDescription;
0034 
0035 using ColumnNames_t = std::vector<std::string>;
0036 
0037 namespace RDFDetail = ROOT::Detail::RDF;
0038 namespace RDFInternal = ROOT::Internal::RDF;
0039 
0040 // clang-format off
0041 /**
0042  * \class ROOT::Internal::RDF::RInterfaceBase
0043  * \ingroup dataframe
0044  * \brief The public interface to the RDataFrame federation of classes.
0045  * \tparam Proxied One of the "node" base types (e.g. RLoopManager, RFilterBase). The user never specifies this type manually.
0046  * \tparam DataSource The type of the RDataSource which is providing the data to the data frame. There is no source by default.
0047  *
0048  * The documentation of each method features a one liner illustrating how to use the method, for example showing how
0049  * the majority of the template parameters are automatically deduced requiring no or very little effort by the user.
0050  */
0051 // clang-format on
0052 class RInterfaceBase {
0053 protected:
0054    ///< The RLoopManager at the root of this computation graph. Never null.
0055    RDFDetail::RLoopManager *fLoopManager;
0056    /// Non-owning pointer to a data-source object. Null if no data-source. RLoopManager has ownership of the object.
0057    RDataSource *fDataSource = nullptr;
0058 
0059    /// Contains the columns defined up to this node.
0060    RDFInternal::RColumnRegister fColRegister;
0061 
0062    std::string DescribeDataset() const;
0063 
0064    ColumnNames_t GetColumnTypeNamesList(const ColumnNames_t &columnList);
0065 
0066    void CheckIMTDisabled(std::string_view callerName);
0067 
0068    void AddDefaultColumns();
0069 
0070    template <typename RetType>
0071    void SanityChecksForVary(const std::vector<std::string> &colNames, const std::vector<std::string> &variationTags,
0072                             std::string_view variationName)
0073    {
0074       R__ASSERT(variationTags.size() > 0 && "Must have at least one variation.");
0075       R__ASSERT(colNames.size() > 0 && "Must have at least one varied column.");
0076       R__ASSERT(!variationName.empty() && "Must provide a variation name.");
0077 
0078       for (auto &colName : colNames) {
0079          RDFInternal::CheckForDefinition("Vary", colName, fColRegister, fLoopManager->GetBranchNames(),
0080                                          fDataSource ? fDataSource->GetColumnNames() : ColumnNames_t{});
0081       }
0082       RDFInternal::CheckValidCppVarName(variationName, "Vary");
0083 
0084       static_assert(ROOT::Internal::VecOps::IsRVec<RetType>::value, "Vary expressions must return an RVec.");
0085 
0086       if (colNames.size() > 1) { // we are varying multiple columns simultaneously, RetType is RVec<RVec<T>>
0087          constexpr bool hasInnerRVec = ROOT::Internal::VecOps::IsRVec<typename RetType::value_type>::value;
0088          if (!hasInnerRVec)
0089             throw std::runtime_error("This Vary call is varying multiple columns simultaneously but the expression "
0090                                      "does not return an RVec of RVecs.");
0091 
0092          auto colTypes = GetColumnTypeNamesList(colNames);
0093          auto allColTypesEqual =
0094             std::all_of(colTypes.begin() + 1, colTypes.end(), [&](const std::string &t) { return t == colTypes[0]; });
0095          if (!allColTypesEqual)
0096             throw std::runtime_error("Cannot simultaneously vary multiple columns of different types.");
0097 
0098          const auto &innerTypeID = typeid(RDFInternal::InnerValueType_t<RetType>);
0099 
0100          for (auto i = 0u; i < colTypes.size(); ++i) {
0101             const auto *define = fColRegister.GetDefine(colNames[i]);
0102             const auto *expectedTypeID = define ? &define->GetTypeId() : &RDFInternal::TypeName2TypeID(colTypes[i]);
0103             if (innerTypeID != *expectedTypeID)
0104                throw std::runtime_error("Varied values for column \"" + colNames[i] + "\" have a different type (" +
0105                                         RDFInternal::TypeID2TypeName(innerTypeID) + ") than the nominal value (" +
0106                                         colTypes[i] + ").");
0107          }
0108       } else { // we are varying a single column, RetType is RVec<T>
0109          const auto &retTypeID = typeid(typename RetType::value_type);
0110          const auto &colName = colNames[0]; // we have only one element in there
0111          const auto *define = fColRegister.GetDefine(colName);
0112          const auto *expectedTypeID =
0113             define ? &define->GetTypeId() : &RDFInternal::TypeName2TypeID(GetColumnType(colName));
0114          if (retTypeID != *expectedTypeID)
0115             throw std::runtime_error("Varied values for column \"" + colName + "\" have a different type (" +
0116                                      RDFInternal::TypeID2TypeName(retTypeID) + ") than the nominal value (" +
0117                                      GetColumnType(colName) + ").");
0118       }
0119 
0120       // when varying multiple columns, they must be different columns
0121       if (colNames.size() > 1) {
0122          std::set<std::string> uniqueCols(colNames.begin(), colNames.end());
0123          if (uniqueCols.size() != colNames.size())
0124             throw std::logic_error("A column name was passed to the same Vary invocation multiple times.");
0125       }
0126    }
0127 
0128    RDFDetail::RLoopManager *GetLoopManager() const { return fLoopManager; }
0129 
0130    ColumnNames_t GetValidatedColumnNames(const unsigned int nColumns, const ColumnNames_t &columns)
0131    {
0132       return RDFInternal::GetValidatedColumnNames(*fLoopManager, nColumns, columns, fColRegister, fDataSource);
0133    }
0134 
0135    template <typename... ColumnTypes>
0136    void CheckAndFillDSColumns(ColumnNames_t validCols, TTraits::TypeList<ColumnTypes...> typeList)
0137    {
0138       if (fDataSource != nullptr)
0139          RDFInternal::AddDSColumns(validCols, *fLoopManager, *fDataSource, typeList, fColRegister);
0140    }
0141 
0142    /// Create RAction object, return RResultPtr for the action
0143    /// Overload for the case in which all column types were specified (no jitting).
0144    /// For most actions, `r` and `helperArg` will refer to the same object, because the only argument to forward to
0145    /// the action helper is the result value itself. We need the distinction for actions such as Snapshot or Cache,
0146    /// for which the constructor arguments of the action helper are different from the returned value.
0147    template <typename ActionTag, typename... ColTypes, typename ActionResultType, typename RDFNode,
0148              typename HelperArgType = ActionResultType,
0149              std::enable_if_t<!RDFInternal::RNeedJitting<ColTypes...>::value, int> = 0>
0150    RResultPtr<ActionResultType> CreateAction(const ColumnNames_t &columns, const std::shared_ptr<ActionResultType> &r,
0151                                              const std::shared_ptr<HelperArgType> &helperArg,
0152                                              const std::shared_ptr<RDFNode> &proxiedPtr, const int /*nColumns*/ = -1)
0153    {
0154       constexpr auto nColumns = sizeof...(ColTypes);
0155 
0156       const auto validColumnNames = GetValidatedColumnNames(nColumns, columns);
0157       CheckAndFillDSColumns(validColumnNames, RDFInternal::TypeList<ColTypes...>());
0158 
0159       const auto nSlots = fLoopManager->GetNSlots();
0160 
0161       auto action = RDFInternal::BuildAction<ColTypes...>(validColumnNames, helperArg, nSlots, proxiedPtr, ActionTag{},
0162                                                           fColRegister);
0163       return MakeResultPtr(r, *fLoopManager, std::move(action));
0164    }
0165 
0166    /// Create RAction object, return RResultPtr for the action
0167    /// Overload for the case in which one or more column types were not specified (RTTI + jitting).
0168    /// This overload has a `nColumns` optional argument. If present, the number of required columns for
0169    /// this action is taken equal to nColumns, otherwise it is assumed to be sizeof...(ColTypes).
0170    template <typename ActionTag, typename... ColTypes, typename ActionResultType, typename RDFNode,
0171              typename HelperArgType = ActionResultType,
0172              std::enable_if_t<RDFInternal::RNeedJitting<ColTypes...>::value, int> = 0>
0173    RResultPtr<ActionResultType> CreateAction(const ColumnNames_t &columns, const std::shared_ptr<ActionResultType> &r,
0174                                              const std::shared_ptr<HelperArgType> &helperArg,
0175                                              const std::shared_ptr<RDFNode> &proxiedPtr, const int nColumns = -1)
0176    {
0177       auto realNColumns = (nColumns > -1 ? nColumns : sizeof...(ColTypes));
0178 
0179       const auto validColumnNames = GetValidatedColumnNames(realNColumns, columns);
0180       const unsigned int nSlots = fLoopManager->GetNSlots();
0181 
0182       auto *tree = fLoopManager->GetTree();
0183       auto *helperArgOnHeap = RDFInternal::MakeSharedOnHeap(helperArg);
0184 
0185       auto upcastNodeOnHeap = RDFInternal::MakeSharedOnHeap(RDFInternal::UpcastNode(proxiedPtr));
0186 
0187       const auto jittedAction = std::make_shared<RDFInternal::RJittedAction>(*fLoopManager, validColumnNames,
0188                                                                              fColRegister, proxiedPtr->GetVariations());
0189       auto jittedActionOnHeap = RDFInternal::MakeWeakOnHeap(jittedAction);
0190 
0191       auto toJit =
0192          RDFInternal::JitBuildAction(validColumnNames, upcastNodeOnHeap, typeid(HelperArgType), typeid(ActionTag),
0193                                      helperArgOnHeap, tree, nSlots, fColRegister, fDataSource, jittedActionOnHeap);
0194       fLoopManager->ToJitExec(toJit);
0195       return MakeResultPtr(r, *fLoopManager, std::move(jittedAction));
0196    }
0197 
0198 public:
0199    RInterfaceBase(std::shared_ptr<RDFDetail::RLoopManager> lm);
0200    RInterfaceBase(RDFDetail::RLoopManager &lm, const RDFInternal::RColumnRegister &colRegister);
0201 
0202    ColumnNames_t GetColumnNames();
0203 
0204    std::string GetColumnType(std::string_view column);
0205 
0206    RDFDescription Describe();
0207 
0208    RVariationsDescription GetVariations() const;
0209    bool HasColumn(std::string_view columnName);
0210    ColumnNames_t GetDefinedColumnNames();
0211    unsigned int GetNSlots() const;
0212    unsigned int GetNRuns() const;
0213    unsigned int GetNFiles();
0214 };
0215 } // namespace RDF
0216 } // namespace ROOT
0217 
0218 #endif