Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-01-30 10:22:57

0001 // @(#)root/tmva $Id$
0002 // Author: Andreas Hoecker, Peter Speckmayer, Joerg Stelzer, Helge Voss, Jan Therhaag
0003 
0004 /**********************************************************************************
0005  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis       *
0006  * Package: TMVA                                                                  *
0007  * Class  : ResultsMulticlass                                                     *
0008  *                                                                                *
0009  *                                                                                *
0010  * Description:                                                                   *
0011  *      Derived-class for result-vectors                                          *
0012  *                                                                                *
0013  * Authors (alphabetical):                                                        *
0014  *      Andreas Hoecker <Andreas.Hocker@cern.ch> - CERN, Switzerland              *
0015  *      Peter Speckmayer <Peter.Speckmayer@cern.ch>  - CERN, Switzerland          *
0016  *      Joerg Stelzer   <Joerg.Stelzer@cern.ch>  - CERN, Switzerland              *
0017  *      Jan Therhaag       <Jan.Therhaag@cern.ch>     - U of Bonn, Germany        *
0018  *      Helge Voss      <Helge.Voss@cern.ch>     - MPI-K Heidelberg, Germany      *
0019  *                                                                                *
0020  * Copyright (c) 2005-2011:                                                       *
0021  *      CERN, Switzerland                                                         *
0022  *      U. of Victoria, Canada                                                    *
0023  *      MPI-K Heidelberg, Germany                                                 *
0024  *      U. of Bonn, Germany                                                       *
0025  *                                                                                *
0026  * Redistribution and use in source and binary forms, with or without             *
0027  * modification, are permitted according to the terms listed in LICENSE           *
0028  * (see tmva/doc/LICENSE)                                                         *
0029  **********************************************************************************/
0030 
0031 #ifndef ROOT_TMVA_ResultsMulticlass
0032 #define ROOT_TMVA_ResultsMulticlass
0033 
0034 //////////////////////////////////////////////////////////////////////////
0035 //                                                                      //
0036 // ResultsMulticlass                                                    //
0037 //                                                                      //
0038 // Class which holds the results of a multiclass classification         //
0039 //                                                                      //
0040 //////////////////////////////////////////////////////////////////////////
0041 
0042 #include "TH1F.h"
0043 #include "TH2F.h"
0044 
0045 #include "TMVA/Results.h"
0046 #include "TMVA/Event.h"
0047 #include "IFitterTarget.h"
0048 
0049 #include <vector>
0050 
0051 namespace TMVA {
0052 
0053 class MsgLogger;
0054 
0055 class ResultsMulticlass : public Results, public IFitterTarget {
0056 
0057 public:
0058    ResultsMulticlass(const DataSetInfo *dsi, TString resultsName);
0059    ~ResultsMulticlass();
0060 
0061    // setters
0062    void SetValue(std::vector<Float_t> &value, Int_t ievt);
0063    void Resize(Int_t entries) { fMultiClassValues.resize(entries); }
0064    using TObject::Clear;
0065    void Clear(Option_t *) override { fMultiClassValues.clear(); }
0066 
0067    // getters
0068    Long64_t GetSize() const { return fMultiClassValues.size(); }
0069    const std::vector<Float_t> &operator[](Int_t ievt) const override { return fMultiClassValues.at(ievt); }
0070    std::vector<std::vector<Float_t>> *GetValueVector() { return &fMultiClassValues; }
0071 
0072    Types::EAnalysisType GetAnalysisType() override { return Types::kMulticlass; }
0073    Float_t GetAchievableEff(UInt_t cls) { return fAchievableEff.at(cls); }
0074    Float_t GetAchievablePur(UInt_t cls) { return fAchievablePur.at(cls); }
0075    std::vector<Float_t> &GetAchievableEff() { return fAchievableEff; }
0076    std::vector<Float_t> &GetAchievablePur() { return fAchievablePur; }
0077 
0078    TMatrixD GetConfusionMatrix(Double_t effB);
0079 
0080    // histogramming
0081    void CreateMulticlassPerformanceHistos(TString prefix);
0082    void CreateMulticlassHistos(TString prefix, Int_t nbins, Int_t nbins_high);
0083 
0084    Double_t EstimatorFunction(std::vector<Double_t> &) override;
0085    std::vector<Double_t> GetBestMultiClassCuts(UInt_t targetClass);
0086 
0087 private:
0088    mutable std::vector<std::vector<Float_t>> fMultiClassValues; ///< mva values (Results)
0089    mutable MsgLogger *fLogger;                                  ///<! message logger
0090    MsgLogger &Log() const { return *fLogger; }
0091    UInt_t fClassToOptimize;
0092    std::vector<Float_t> fAchievableEff;
0093    std::vector<Float_t> fAchievablePur;
0094    std::vector<std::vector<Double_t>> fBestCuts;
0095 
0096    // Temporary storage used during GetBestMultiClassCuts
0097    std::vector<Float_t> fClassSumWeights;
0098    std::vector<Float_t> fEventWeights;
0099    std::vector<UInt_t> fEventClasses;
0100 
0101 protected:
0102    ClassDefOverride(ResultsMulticlass, 3);
0103 };
0104 
0105 } // namespace TMVA
0106 
0107 #endif