File indexing completed on 2025-12-14 10:31:08
0001
0002
0003
0004
0005 #ifndef ROOT_RAxes
0006 #define ROOT_RAxes
0007
0008 #include "RBinIndex.hxx"
0009 #include "RCategoricalAxis.hxx"
0010 #include "RLinearizedIndex.hxx"
0011 #include "RRegularAxis.hxx"
0012 #include "RVariableBinAxis.hxx"
0013
#include <array>
#include <cassert>
#include <cstddef>
#include <limits>
#include <stdexcept>
#include <tuple>
#include <type_traits>
#include <utility>
#include <variant>
#include <vector>
0023
0024 class TBuffer;
0025
0026 namespace ROOT {
0027 namespace Experimental {
0028
0029
/// Variant over all axis types supported by the histogram classes.
///
/// Every axis-type dispatch in RAxes handles exactly these three alternatives. The order of the
/// alternatives determines std::variant::index(); NOTE(review): check for code relying on that
/// before reordering.
using RAxisVariant = std::variant<RRegularAxis, RVariableBinAxis, RCategoricalAxis>;

/// Forward declaration, needed so that RAxes can declare every RHistEngine<T> as a friend.
template <typename T>
class RHistEngine;
0035
0036 namespace Internal {
0037
0038
0039
0040
0041 class RAxes final {
0042 template <typename T>
0043 friend class ::ROOT::Experimental::RHistEngine;
0044
0045 std::vector<RAxisVariant> fAxes;
0046
0047 public:
0048
0049 explicit RAxes(std::vector<RAxisVariant> axes) : fAxes(std::move(axes))
0050 {
0051 if (fAxes.empty()) {
0052 throw std::invalid_argument("must have at least 1 axis object");
0053 }
0054 }
0055
0056 std::size_t GetNDimensions() const { return fAxes.size(); }
0057 const std::vector<RAxisVariant> &Get() const { return fAxes; }
0058
0059 friend bool operator==(const RAxes &lhs, const RAxes &rhs) { return lhs.fAxes == rhs.fAxes; }
0060 friend bool operator!=(const RAxes &lhs, const RAxes &rhs) { return !(lhs == rhs); }
0061
0062
0063
0064
0065
0066
0067 std::size_t ComputeTotalNBins() const
0068 {
0069 std::size_t totalNBins = 1;
0070 for (auto &&axis : fAxes) {
0071 if (auto *regular = std::get_if<RRegularAxis>(&axis)) {
0072 totalNBins *= regular->GetTotalNBins();
0073 } else if (auto *variable = std::get_if<RVariableBinAxis>(&axis)) {
0074 totalNBins *= variable->GetTotalNBins();
0075 } else if (auto *categorical = std::get_if<RCategoricalAxis>(&axis)) {
0076 totalNBins *= categorical->GetTotalNBins();
0077 } else {
0078 throw std::logic_error("unimplemented axis type");
0079 }
0080 }
0081 return totalNBins;
0082 }
0083
0084 private:
0085 template <std::size_t I, std::size_t N, typename... A>
0086 RLinearizedIndex ComputeGlobalIndexImpl(std::size_t index, const std::tuple<A...> &args) const
0087 {
0088 using ArgumentType = std::tuple_element_t<I, std::tuple<A...>>;
0089 const auto &axis = fAxes[I];
0090 RLinearizedIndex linIndex;
0091 if (auto *regular = std::get_if<RRegularAxis>(&axis)) {
0092 if constexpr (std::is_convertible_v<ArgumentType, RRegularAxis::ArgumentType>) {
0093 index *= regular->GetTotalNBins();
0094 linIndex = regular->ComputeLinearizedIndex(std::get<I>(args));
0095 } else {
0096 throw std::invalid_argument("invalid type of argument");
0097 }
0098 } else if (auto *variable = std::get_if<RVariableBinAxis>(&axis)) {
0099 if constexpr (std::is_convertible_v<ArgumentType, RVariableBinAxis::ArgumentType>) {
0100 index *= variable->GetTotalNBins();
0101 linIndex = variable->ComputeLinearizedIndex(std::get<I>(args));
0102 } else {
0103 throw std::invalid_argument("invalid type of argument");
0104 }
0105 } else if (auto *categorical = std::get_if<RCategoricalAxis>(&axis)) {
0106 if constexpr (std::is_convertible_v<ArgumentType, RCategoricalAxis::ArgumentType>) {
0107 index *= categorical->GetTotalNBins();
0108 linIndex = categorical->ComputeLinearizedIndex(std::get<I>(args));
0109 } else {
0110 throw std::invalid_argument("invalid type of argument");
0111 }
0112 } else {
0113 throw std::logic_error("unimplemented axis type");
0114 }
0115 if (!linIndex.fValid) {
0116 return {0, false};
0117 }
0118 index += linIndex.fIndex;
0119 if constexpr (I + 1 < N) {
0120 return ComputeGlobalIndexImpl<I + 1, N>(index, args);
0121 }
0122 return {index, true};
0123 }
0124
0125 template <std::size_t N, typename... A>
0126 RLinearizedIndex ComputeGlobalIndexImpl(const std::tuple<A...> &args) const
0127 {
0128 return ComputeGlobalIndexImpl<0, N>(0, args);
0129 }
0130
0131 public:
0132
0133
0134
0135
0136
0137
0138
0139 template <typename... A>
0140 RLinearizedIndex ComputeGlobalIndex(const std::tuple<A...> &args) const
0141 {
0142 if (sizeof...(A) != fAxes.size()) {
0143 throw std::invalid_argument("invalid number of arguments to ComputeGlobalIndex");
0144 }
0145 return ComputeGlobalIndexImpl<sizeof...(A)>(args);
0146 }
0147
0148
0149
0150
0151
0152 template <std::size_t N>
0153 RLinearizedIndex ComputeGlobalIndex(const std::array<RBinIndex, N> &indices) const
0154 {
0155 if (N != fAxes.size()) {
0156 throw std::invalid_argument("invalid number of indices passed to ComputeGlobalIndex");
0157 }
0158 std::size_t globalIndex = 0;
0159 for (std::size_t i = 0; i < N; i++) {
0160 const auto &index = indices[i];
0161 const auto &axis = fAxes[i];
0162 RLinearizedIndex linIndex;
0163 if (auto *regular = std::get_if<RRegularAxis>(&axis)) {
0164 globalIndex *= regular->GetTotalNBins();
0165 linIndex = regular->GetLinearizedIndex(index);
0166 } else if (auto *variable = std::get_if<RVariableBinAxis>(&axis)) {
0167 globalIndex *= variable->GetTotalNBins();
0168 linIndex = variable->GetLinearizedIndex(index);
0169 } else if (auto *categorical = std::get_if<RCategoricalAxis>(&axis)) {
0170 globalIndex *= categorical->GetTotalNBins();
0171 linIndex = categorical->GetLinearizedIndex(index);
0172 } else {
0173 throw std::logic_error("unimplemented axis type");
0174 }
0175 if (!linIndex.fValid) {
0176 return {0, false};
0177 }
0178 globalIndex += linIndex.fIndex;
0179 }
0180 return {globalIndex, true};
0181 }
0182
0183
0184 void Streamer(TBuffer &) { throw std::runtime_error("unable to store RAxes"); }
0185 };
0186
0187 }
0188 }
0189 }
0190
0191 #endif