File indexing completed on 2025-01-30 10:22:57
0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018
0019
0020
0021
0022
0023
0024
0025
0026
0027
0028
0029
0030
0031 #ifndef ROOT_TMVA_ResultsMulticlass
0032 #define ROOT_TMVA_ResultsMulticlass
0033
0034
0035
0036
0037
0038
0039
0040
0041
0042 #include "TH1F.h"
0043 #include "TH2F.h"
0044
0045 #include "TMVA/Results.h"
0046 #include "TMVA/Event.h"
0047 #include "IFitterTarget.h"
0048
0049 #include <vector>
0050
0051 namespace TMVA {
0052
0053 class MsgLogger;
0054
0055 class ResultsMulticlass : public Results, public IFitterTarget {
0056
0057 public:
0058 ResultsMulticlass(const DataSetInfo *dsi, TString resultsName);
0059 ~ResultsMulticlass();
0060
0061
0062 void SetValue(std::vector<Float_t> &value, Int_t ievt);
0063 void Resize(Int_t entries) { fMultiClassValues.resize(entries); }
0064 using TObject::Clear;
0065 void Clear(Option_t *) override { fMultiClassValues.clear(); }
0066
0067
0068 Long64_t GetSize() const { return fMultiClassValues.size(); }
0069 const std::vector<Float_t> &operator[](Int_t ievt) const override { return fMultiClassValues.at(ievt); }
0070 std::vector<std::vector<Float_t>> *GetValueVector() { return &fMultiClassValues; }
0071
0072 Types::EAnalysisType GetAnalysisType() override { return Types::kMulticlass; }
0073 Float_t GetAchievableEff(UInt_t cls) { return fAchievableEff.at(cls); }
0074 Float_t GetAchievablePur(UInt_t cls) { return fAchievablePur.at(cls); }
0075 std::vector<Float_t> &GetAchievableEff() { return fAchievableEff; }
0076 std::vector<Float_t> &GetAchievablePur() { return fAchievablePur; }
0077
0078 TMatrixD GetConfusionMatrix(Double_t effB);
0079
0080
0081 void CreateMulticlassPerformanceHistos(TString prefix);
0082 void CreateMulticlassHistos(TString prefix, Int_t nbins, Int_t nbins_high);
0083
0084 Double_t EstimatorFunction(std::vector<Double_t> &) override;
0085 std::vector<Double_t> GetBestMultiClassCuts(UInt_t targetClass);
0086
0087 private:
0088 mutable std::vector<std::vector<Float_t>> fMultiClassValues;
0089 mutable MsgLogger *fLogger;
0090 MsgLogger &Log() const { return *fLogger; }
0091 UInt_t fClassToOptimize;
0092 std::vector<Float_t> fAchievableEff;
0093 std::vector<Float_t> fAchievablePur;
0094 std::vector<std::vector<Double_t>> fBestCuts;
0095
0096
0097 std::vector<Float_t> fClassSumWeights;
0098 std::vector<Float_t> fEventWeights;
0099 std::vector<UInt_t> fEventClasses;
0100
0101 protected:
0102 ClassDefOverride(ResultsMulticlass, 3);
0103 };
0104
0105 }
0106
0107 #endif