Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-01-30 10:22:51

0001 // // @(#)root/tmva $Id$
0002 // Author: Andreas Hoecker, Peter Speckmayer, Joerg Stelzer, Helge Voss
0003 
0004 /**********************************************************************************
0005  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis       *
0006  * Package: TMVA                                                                  *
0007  * Class  : DataSetInfo                                                           *
0008  *                                             *
0009  *                                                                                *
0010  * Description:                                                                   *
0011  *      Contains all the data information                                         *
0012  *                                                                                *
0013  * Authors (alphabetical):                                                        *
0014  *      Peter Speckmayer <speckmay@mail.cern.ch> - CERN, Switzerland              *
0015  *      Joerg Stelzer   <Joerg.Stelzer@cern.ch>  - DESY, Germany                  *
0016  *                                                                                *
0017  * Copyright (c) 2008-2011:                                                       *
0018  *      CERN, Switzerland                                                         *
0019  *      MPI-K Heidelberg, Germany                                                 *
0020  *      DESY Hamburg, Germany                                                     *
0021  *                                                                                *
0022  * Redistribution and use in source and binary forms, with or without             *
0023  * modification, are permitted according to the terms listed in LICENSE           *
0024  * (see tmva/doc/LICENSE)                                          *
0025  **********************************************************************************/
0026 
0027 #ifndef ROOT_TMVA_DataSetInfo
0028 #define ROOT_TMVA_DataSetInfo
0029 
0030 //////////////////////////////////////////////////////////////////////////
0031 //                                                                      //
0032 // DataSetInfo                                                          //
0033 //                                                                      //
0034 // Class that contains all the data information                         //
0035 //                                                                      //
0036 //////////////////////////////////////////////////////////////////////////
0037 
0038 #include <iosfwd>
0039 #include <vector>
0040 #include <map>
0041 
0042 #include "TObject.h"
0043 #include "TString.h"
0044 #include "TTree.h"
0045 #include "TCut.h"
0046 #include "TMatrixDfwd.h"
0047 
0048 #include "TMVA/Types.h"
0049 #include "TMVA/VariableInfo.h"
0050 #include "TMVA/ClassInfo.h"
0051 #include "TMVA/Event.h"
0052 
0053 class TH2;
0054 
0055 namespace TMVA {
0056 
0057    class DataSet;
0058    class VariableTransformBase;
0059    class MsgLogger;
0060    class DataSetManager;
0061 
0062    class DataSetInfo : public TObject {
0063 
0064    public:
0065 
0066       enum { kIsArrayVariable = BIT(15) };
0067 
0068       DataSetInfo(const TString& name = "Default");
0069       virtual ~DataSetInfo();
0070 
0071       virtual const char* GetName() const { return fName.Data(); }
0072 
0073       // the data set
0074       void        ClearDataSet() const;
0075       DataSet*    GetDataSet() const;
0076 
0077       // ---
0078       // the variable data
0079       // ---
0080       VariableInfo&     AddVariable( const TString& expression, const TString& title = "", const TString& unit = "",
0081                                      Double_t min = 0, Double_t max = 0, char varType='F',
0082                                      Bool_t normalized = kTRUE, void* external = nullptr );
0083       VariableInfo&     AddVariable( const VariableInfo& varInfo );
0084 
0085       // NEW: add an array of variables (e.g. for image data)
0086       void AddVariablesArray(const TString &expression, Int_t size, const TString &title = "", const TString &unit = "",
0087                              Double_t min = 0, Double_t max = 0, char type = 'F', Bool_t normalized = kTRUE,
0088                              void *external = nullptr );
0089 
0090       VariableInfo&     AddTarget  ( const TString& expression, const TString& title, const TString& unit,
0091                                      Double_t min, Double_t max, Bool_t normalized = kTRUE, void* external = nullptr );
0092       VariableInfo&     AddTarget  ( const VariableInfo& varInfo );
0093 
0094       VariableInfo&     AddSpectator ( const TString& expression, const TString& title, const TString& unit,
0095                                        Double_t min, Double_t max, char type = 'F', Bool_t normalized = kTRUE, void* external = nullptr );
0096       VariableInfo&     AddSpectator ( const VariableInfo& varInfo );
0097 
0098       ClassInfo*        AddClass   ( const TString& className );
0099 
0100       // accessors
0101 
0102       // general
0103       std::vector<VariableInfo>&       GetVariableInfos()         { return fVariables; }
0104       const std::vector<VariableInfo>& GetVariableInfos() const   { return fVariables; }
0105       VariableInfo&                    GetVariableInfo( Int_t i ) { return fVariables.at(i); }
0106       const VariableInfo&              GetVariableInfo( Int_t i ) const { return fVariables.at(i); }
0107 
0108       Int_t GetVarArraySize(const TString &expression) const {
0109          auto element = fVarArrays.find(expression);
0110          return (element != fVarArrays.end()) ? element->second : -1;
0111        }
0112        Bool_t IsVariableFromArray(Int_t i) const { return GetVariableInfo(i).TestBit(DataSetInfo::kIsArrayVariable);  }
0113 
0114        std::vector<VariableInfo> &GetTargetInfos()
0115        {
0116           return fTargets;
0117        }
0118        const std::vector<VariableInfo> &GetTargetInfos() const { return fTargets; }
0119        VariableInfo &GetTargetInfo(Int_t i) { return fTargets.at(i); }
0120        const VariableInfo &GetTargetInfo(Int_t i) const { return fTargets.at(i); }
0121 
0122        std::vector<VariableInfo> &GetSpectatorInfos() { return fSpectators; }
0123        const std::vector<VariableInfo> &GetSpectatorInfos() const { return fSpectators; }
0124        VariableInfo &GetSpectatorInfo(Int_t i) { return fSpectators.at(i); }
0125        const VariableInfo &GetSpectatorInfo(Int_t i) const { return fSpectators.at(i); }
0126 
0127        UInt_t GetNVariables() const { return fVariables.size(); }
0128        UInt_t GetNTargets() const { return fTargets.size(); }
0129        UInt_t GetNSpectators(bool all = kTRUE) const;
0130 
0131        const TString &GetNormalization() const { return fNormalization; }
0132        void SetNormalization(const TString &norm) { fNormalization = norm; }
0133 
0134        void SetTrainingSumSignalWeights(Double_t trainingSumSignalWeights)
0135        {
0136           fTrainingSumSignalWeights = trainingSumSignalWeights;}
0137       void SetTrainingSumBackgrWeights(Double_t trainingSumBackgrWeights){fTrainingSumBackgrWeights = trainingSumBackgrWeights;}
0138       void SetTestingSumSignalWeights (Double_t testingSumSignalWeights ){fTestingSumSignalWeights  = testingSumSignalWeights ;}
0139       void SetTestingSumBackgrWeights (Double_t testingSumBackgrWeights ){fTestingSumBackgrWeights  = testingSumBackgrWeights ;}
0140 
0141       Double_t GetTrainingSumSignalWeights();
0142       Double_t GetTrainingSumBackgrWeights();
0143       Double_t GetTestingSumSignalWeights ();
0144       Double_t GetTestingSumBackgrWeights ();
0145 
0146 
0147 
0148       // classification information
0149       Int_t              GetClassNameMaxLength() const;
0150       Int_t              GetVariableNameMaxLength() const;
0151       Int_t              GetTargetNameMaxLength() const;
0152       ClassInfo*         GetClassInfo( Int_t clNum ) const;
0153       ClassInfo*         GetClassInfo( const TString& name ) const;
0154       void               PrintClasses() const;
0155       UInt_t             GetNClasses() const { return fClasses.size(); }
0156       Bool_t             IsSignal( const Event* ev ) const;
0157       std::vector<Float_t>* GetTargetsForMulticlass( const Event* ev );
0158       UInt_t             GetSignalClassIndex(){return fSignalClass;}
0159 
0160       // by variable
0161       Int_t              FindVarIndex( const TString& )      const;
0162 
0163       // weights
0164       const TString      GetWeightExpression(Int_t i)      const { return GetClassInfo(i)->GetWeight(); }
0165       void               SetWeightExpression( const TString& exp, const TString& className = "" );
0166 
0167       // cuts
0168       const TCut&        GetCut (Int_t i)                         const { return GetClassInfo(i)->GetCut(); }
0169       const TCut&        GetCut ( const TString& className )      const { return GetClassInfo(className)->GetCut(); }
0170       void               SetCut ( const TCut& cut, const TString& className );
0171       void               AddCut ( const TCut& cut, const TString& className );
0172       Bool_t             HasCuts() const;
0173 
0174       std::vector<TString> GetListOfVariables() const;
0175 
0176       // correlation matrix
0177       const TMatrixD*    CorrelationMatrix     ( const TString& className ) const;
0178       void               SetCorrelationMatrix  ( const TString& className, TMatrixD* matrix );
0179       void               PrintCorrelationMatrix( const TString& className );
0180       TH2*               CreateCorrelationMatrixHist( const TMatrixD* m,
0181                                                       const TString& hName,
0182                                                       const TString& hTitle ) const;
0183 
0184       // options
0185       void               SetSplitOptions(const TString& so) { fSplitOptions = so; fNeedsRebuilding = kTRUE; }
0186       const TString&     GetSplitOptions() const { return fSplitOptions; }
0187 
0188       // root dir
0189       void               SetRootDir(TDirectory* d) { fOwnRootDir = d; }
0190       TDirectory*        GetRootDir() const { return fOwnRootDir; }
0191 
0192       void               SetMsgType( EMsgType t ) const;
0193 
0194       DataSetManager*   GetDataSetManager(){return fDataSetManager;}
0195    private:
0196 
0197       TMVA::DataSetManager*            fDataSetManager; // DSMTEST
0198       void                       SetDataSetManager( DataSetManager* dsm ) { fDataSetManager = dsm; } // DSMTEST
0199       friend class DataSetManager;  // DSMTEST (datasetmanager test)
0200 
0201       DataSetInfo(const DataSetInfo &) = delete;
0202       DataSetInfo & operator= (const DataSetInfo &) = delete;
0203 
0204       void PrintCorrelationMatrix( TTree* theTree );
0205 
0206       TString                    fName;              ///< name of the dataset info object
0207 
0208       mutable DataSet*           fDataSet;           ///< dataset, owned by this datasetinfo object
0209       mutable Bool_t             fNeedsRebuilding;   ///< flag if rebuilding of dataset is needed (after change of cuts, vars, etc.)
0210 
0211       // expressions/formulas
0212       std::vector<VariableInfo>  fVariables;         ///< list of variable expressions/internal names
0213       std::vector<VariableInfo>  fTargets;           ///< list of targets expressions/internal names
0214       std::vector<VariableInfo>  fSpectators;        ///< list of spectators expressions/internal names
0215 
0216       // variable arrays
0217       std::map<TString, int> fVarArrays;
0218 
0219       // the classes
0220       mutable std::vector<ClassInfo*> fClasses;      ///< name and other infos of the classes
0221 
0222       TString                    fNormalization;
0223       TString                    fSplitOptions;
0224 
0225       Double_t                   fTrainingSumSignalWeights;
0226       Double_t                   fTrainingSumBackgrWeights;
0227       Double_t                   fTestingSumSignalWeights ;
0228       Double_t                   fTestingSumBackgrWeights ;
0229 
0230 
0231 
0232       TDirectory*                fOwnRootDir;        ///< ROOT output dir
0233       Bool_t                     fVerbose;           ///< Verbosity
0234 
0235       UInt_t                     fSignalClass;       ///< index of the class with the name signal
0236 
0237       std::vector<Float_t>*      fTargetsForMulticlass;///<-> all targets 0 except the one with index==classNumber
0238 
0239       mutable MsgLogger*         fLogger;            ///<! message logger
0240       MsgLogger& Log() const { return *fLogger; }
0241 
0242    public:
0243 
0244        ClassDef(DataSetInfo,1);
0245    };
0246 }
0247 
0248 #endif