File indexing completed on 2025-12-14 10:31:09
0001
0002
0003
0004
0005 #ifndef ROOT_RCategoricalAxis
0006 #define ROOT_RCategoricalAxis
0007
0008 #include "RBinIndex.hxx"
0009 #include "RBinIndexRange.hxx"
0010 #include "RLinearizedIndex.hxx"
0011
0012 #include <cassert>
0013 #include <cstddef>
0014 #include <stdexcept>
0015 #include <string>
0016 #include <string_view>
0017 #include <unordered_set>
0018 #include <utility>
0019 #include <vector>
0020
0021 class TBuffer;
0022
0023 namespace ROOT {
0024 namespace Experimental {
0025
0026
0027
0028
0029
0030
0031
0032
0033
0034
0035
0036
0037
0038
0039
0040
0041 class RCategoricalAxis final {
0042 public:
0043 using ArgumentType = std::string_view;
0044
0045 private:
0046
0047 std::vector<std::string> fCategories;
0048
0049 bool fEnableOverflowBin;
0050
0051 public:
0052
0053
0054
0055
0056 explicit RCategoricalAxis(std::vector<std::string> categories, bool enableOverflowBin = true)
0057 : fCategories(std::move(categories)), fEnableOverflowBin(enableOverflowBin)
0058 {
0059 if (fCategories.size() < 1) {
0060 throw std::invalid_argument("must have at least one category");
0061 }
0062
0063 std::unordered_set<std::string_view> set;
0064 for (std::size_t i = 0; i < fCategories.size(); i++) {
0065 if (!set.insert(fCategories[i]).second) {
0066 std::string msg = "duplicate category '" + fCategories[i] + "' for bin " + std::to_string(i);
0067 throw std::invalid_argument(msg);
0068 }
0069 }
0070 }
0071
0072 std::size_t GetNNormalBins() const { return fCategories.size(); }
0073 std::size_t GetTotalNBins() const { return fEnableOverflowBin ? fCategories.size() + 1 : fCategories.size(); }
0074 const std::vector<std::string> &GetCategories() const { return fCategories; }
0075 bool HasOverflowBin() const { return fEnableOverflowBin; }
0076
0077 friend bool operator==(const RCategoricalAxis &lhs, const RCategoricalAxis &rhs)
0078 {
0079 return lhs.fCategories == rhs.fCategories && lhs.fEnableOverflowBin == rhs.fEnableOverflowBin;
0080 }
0081
0082
0083
0084
0085
0086
0087
0088
0089
0090 RLinearizedIndex ComputeLinearizedIndex(std::string_view x) const
0091 {
0092
0093 for (std::size_t bin = 0; bin < fCategories.size(); bin++) {
0094 if (fCategories[bin] == x) {
0095 return {bin, true};
0096 }
0097 }
0098
0099
0100 return {fCategories.size(), fEnableOverflowBin};
0101 }
0102
0103
0104
0105
0106
0107
0108
0109
0110 RLinearizedIndex GetLinearizedIndex(RBinIndex index) const
0111 {
0112 if (index.IsUnderflow()) {
0113
0114 return {0, false};
0115 } else if (index.IsOverflow()) {
0116 return {fCategories.size(), fEnableOverflowBin};
0117 } else if (index.IsInvalid()) {
0118 return {0, false};
0119 }
0120 assert(index.IsNormal());
0121 std::size_t bin = index.GetIndex();
0122 return {bin, bin < fCategories.size()};
0123 }
0124
0125
0126
0127
0128 RBinIndexRange GetNormalRange() const
0129 {
0130 return Internal::CreateBinIndexRange(RBinIndex(0), RBinIndex(fCategories.size()), 0);
0131 }
0132
0133
0134
0135
0136
0137
0138 RBinIndexRange GetNormalRange(RBinIndex begin, RBinIndex end) const
0139 {
0140 if (!begin.IsNormal()) {
0141 throw std::invalid_argument("begin must be a normal bin");
0142 }
0143 if (begin.GetIndex() >= fCategories.size()) {
0144 throw std::invalid_argument("begin must be inside the axis");
0145 }
0146 if (!end.IsNormal()) {
0147 throw std::invalid_argument("end must be a normal bin");
0148 }
0149 if (end.GetIndex() > fCategories.size()) {
0150 throw std::invalid_argument("end must be inside or past the axis");
0151 }
0152 if (!(end >= begin)) {
0153 throw std::invalid_argument("end must be >= begin");
0154 }
0155 return Internal::CreateBinIndexRange(begin, end, 0);
0156 }
0157
0158
0159
0160
0161
0162
0163 RBinIndexRange GetFullRange() const
0164 {
0165 return fEnableOverflowBin ? Internal::CreateBinIndexRange(RBinIndex(0), RBinIndex(), fCategories.size())
0166 : GetNormalRange();
0167 }
0168
0169
0170 void Streamer(TBuffer &) { throw std::runtime_error("unable to store RCategoricalAxis"); }
0171 };
0172
0173 }
0174 }
0175
0176 #endif