Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-01-30 10:22:50

0001 // @(#)root/tmva $Id$
0002 // Author: Andreas Hoecker, Peter Speckmayer, Joerg Stelzer, Helge Voss
0003 
0004 /**********************************************************************************
0005  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis       *
0006  * Package: TMVA                                                                  *
0007  * Class  : DataSet                                                               *
0008  *                                                                                *
0009  *                                                                                *
0010  * Description:                                                                   *
0011  *      Contains all the data information                                         *
0012  *                                                                                *
0013  * Authors (alphabetical):                                                        *
0014  *      Andreas Hoecker <Andreas.Hocker@cern.ch> - CERN, Switzerland              *
0015  *      Joerg Stelzer   <Joerg.Stelzer@cern.ch>  - CERN, Switzerland              *
0016  *      Peter Speckmayer <Peter.Speckmayer@cern.ch>  - CERN, Switzerland          *
0017  *      Helge Voss      <Helge.Voss@cern.ch>     - MPI-K Heidelberg, Germany      *
0018  *                                                                                *
0019  * Copyright (c) 2006:                                                            *
0020  *      CERN, Switzerland                                                         *
0021  *      U. of Victoria, Canada                                                    *
0022  *      MPI-K Heidelberg, Germany                                                 *
0023  *                                                                                *
0024  * Redistribution and use in source and binary forms, with or without             *
0025  * modification, are permitted according to the terms listed in LICENSE           *
0026  * (see tmva/doc/LICENSE)                                                         *
0027  **********************************************************************************/
0028 
0029 #ifndef ROOT_TMVA_DataSet
0030 #define ROOT_TMVA_DataSet
0031 
0032 //////////////////////////////////////////////////////////////////////////
0033 //                                                                      //
0034 // DataSet                                                              //
0035 //                                                                      //
0036 // Class that contains all the data information                         //
0037 //                                                                      //
0038 //////////////////////////////////////////////////////////////////////////
0039 
0040 #include <vector>
0041 #include <map>
0042 
0043 #include "TNamed.h"
0044 #include "TString.h"
0045 #include "TTree.h"
0046 #include "TRandom3.h"
0047 
0048 #include "TMVA/Types.h"
0049 #include "TMVA/VariableInfo.h"
0050 
0051 namespace TMVA {
0052 
0053    class Event;
0054    class DataSetInfo;
0055    class MsgLogger;
0056    class Results;
0057 
0058    class DataSet :public TNamed {
0059       // Event store of a TMVA dataset: one event list per tree slot (slot numbering: see TreeIndex), plus per-method Results and optional sampling state.
0060    public:
0061       DataSet();
0062       DataSet(const DataSetInfo&);
0063       virtual ~DataSet();
0064 
0065       void      AddEvent( Event *, Types::ETreeType );
0066 
0067       Long64_t  GetNEvents( Types::ETreeType type = Types::kMaxTreeType ) const; ///< event count for a tree type; the default kMaxTreeType selects the current tree
0068       Long64_t  GetNTrainingEvents()              const { return GetNEvents(Types::kTraining); } ///< shorthand for GetNEvents(Types::kTraining)
0069       Long64_t  GetNTestEvents()                  const { return GetNEvents(Types::kTesting); }  ///< shorthand for GetNEvents(Types::kTesting)
0070 
0071       // const getters; the indexed overloads first store the requested event (and tree) in the mutable current-selection members
0072       const Event*    GetEvent()                        const; ///< returns event without transformations
0073       const Event*    GetEvent        ( Long64_t ievt ) const { fCurrentEventIdx = ievt; return GetEvent(); } ///< makes ievt the current event, then returns GetEvent() (event without transformations)
0074       const Event*    GetTrainingEvent( Long64_t ievt ) const { return GetEvent(ievt, Types::kTraining); } ///< event ievt of the training tree; training becomes the current tree
0075       const Event*    GetTestEvent    ( Long64_t ievt ) const { return GetEvent(ievt, Types::kTesting); }  ///< event ievt of the testing tree; testing becomes the current tree
0076       const Event*    GetEvent        ( Long64_t ievt, Types::ETreeType type ) const
0077       { // stores the tree slot and the event index as current selection, then delegates to GetEvent()
0078          fCurrentTreeIdx = TreeIndex(type); fCurrentEventIdx = ievt; return GetEvent();
0079       }
0080 
0081 
0082 
0083 
0084       UInt_t    GetNVariables()   const;
0085       UInt_t    GetNTargets()     const;
0086       UInt_t    GetNSpectators()  const;
0087       // the current-selection setters below are const because the two indices are mutable members
0088       void      SetCurrentEvent( Long64_t ievt         ) const { fCurrentEventIdx = ievt; }
0089       void      SetCurrentType ( Types::ETreeType type ) const { fCurrentTreeIdx = TreeIndex(type); }
0090       Types::ETreeType GetCurrentType() const; ///< tree type of the current slot; kMaxTreeType if the slot is none of 0..3
0091 
0092       void                       SetEventCollection( std::vector<Event*>*, Types::ETreeType, Bool_t deleteEvents = true );
0093       const std::vector<Event*>& GetEventCollection( Types::ETreeType type = Types::kMaxTreeType ) const; ///< event list of a tree type (default: current tree), looked up with bounds-checked at()
0094       const TTree*               GetEventCollectionAsTree();
0095 
0096       Long64_t  GetNEvtSigTest();
0097       Long64_t  GetNEvtBkgdTest();
0098       Long64_t  GetNEvtSigTrain();
0099       Long64_t  GetNEvtBkgdTrain();
0100 
0101       Bool_t    HasNegativeEventWeights() const { return fHasNegativeEventWeights; } ///< returns the fHasNegativeEventWeights flag
0102 
0103       Results*  GetResults   ( const TString &,
0104                                Types::ETreeType type,
0105                                Types::EAnalysisType analysistype );
0106       void      DeleteResults   ( const TString &,
0107                                   Types::ETreeType type,
0108                                   Types::EAnalysisType analysistype );
0109       void      DeleteAllResults(Types::ETreeType type,
0110                                  Types::EAnalysisType analysistype);
0111 
0112       void      SetVerbose( Bool_t ) {} ///< no-op, the argument is ignored
0113 
0114       // sets the number of blocks into which the training set is divided,
0115       // some of which can be given to the validation sample. By default they all belong to the training set.
0116       void      DivideTrainingSet( UInt_t blockNum );
0117 
0118       // assigns one block of the original training set to either the training or the validation set
0119       void      MoveTrainingBlock( Int_t blockInd,Types::ETreeType dest, Bool_t applyChanges = kTRUE );
0120       // per-class event counters; NOTE(review): `type` is presumably the tree slot index used by fClassEvents - confirm in the implementation
0121       void      IncrementNClassEvents( Int_t type, UInt_t classNumber );
0122       Long64_t  GetNClassEvents      ( Int_t type, UInt_t classNumber );
0123       void      ClearNClassEvents    ( Int_t type );
0124 
0125       TTree*    GetTree( Types::ETreeType type );
0126 
0127       // accessors for random and importance sampling
0128       void      InitSampling( Float_t fraction, Float_t weight, UInt_t seed = 0 );
0129       void      EventResult( Bool_t successful, Long64_t evtNumber = -1 );
0130       void      CreateSampling() const;
0131 
0132       UInt_t    TreeIndex(Types::ETreeType type) const; ///< slot of a tree type: training 0, testing 1, validation 2, original training 3, anything else the current slot
0133 
0134    private:
0135 
0136       // private helper, followed by the data members
0137       void DestroyCollection( Types::ETreeType type, Bool_t deleteEvents );
0138 
0139       const DataSetInfo         *fdsi;                        ///<-> datasetinfo that created this dataset
0140 
0141       std::vector< std::vector<Event*>  > fEventCollection;   ///< list of events for training/testing/...
0142 
0143       std::vector< std::map< TString, Results* > > fResults;  ///<!  [train/test/...][method-identifier]
0144       // current selection: tree slot and event index, written by the indexed GetEvent overloads and by SetCurrentType / SetCurrentEvent
0145       mutable UInt_t             fCurrentTreeIdx;
0146       mutable Long64_t           fCurrentEventIdx;
0147 
0148       // event sampling
0149       std::vector<Char_t>        fSampling;                   ///< random or importance sampling (not all events are taken) !! Bool_t are stored ( no std::vector<bool> taken for speed (performance) issues )
0150       std::vector<Int_t>         fSamplingNEvents;            ///< number of events which should be sampled
0151       std::vector<Float_t>       fSamplingWeight;             ///< weight change factor [weight is indicating if sampling is random (1.0) or importance (<1.0)]
0152       mutable std::vector< std::vector< std::pair< Float_t, Long64_t > > > fSamplingEventList;  ///< weights and indices for sampling
0153       mutable std::vector< std::vector< std::pair< Float_t, Long64_t > > > fSamplingSelected;   ///< selected events
0154       TRandom3                   *fSamplingRandom;            ///<-> random generator for sampling
0155 
0156 
0157       // further things
0158       std::vector< std::vector<Long64_t> > fClassEvents;      ///< number of events of class 0,1,2,... in training[0]
0159                                                               ///< and testing[1] (+validation, trainingoriginal)
0160 
0161       Bool_t                     fHasNegativeEventWeights;    ///< true if at least one signal or bkg event has negative weight
0162 
0163       mutable MsgLogger*         fLogger;                     ///<! message logger
0164       MsgLogger& Log() const { return *fLogger; } ///< logger accessor; dereferences fLogger without a null check
0165       std::vector<Char_t>        fBlockBelongToTraining;      ///< when dividing the dataset to blocks, sets whether
0166                                                               ///< the certain block is in the Training set or else
0167                                                               ///< in the validation set
0168                                                               ///< boolean are stored, taken std::vector<Char_t> for performance reasons (instead of std::vector<Bool_t>)
0169       Long64_t                   fTrainingBlockSize;          ///< block size into which the training dataset is divided
0170 
0171       void  ApplyTrainingBlockDivision();
0172       void  ApplyTrainingSetDivision();
0173    public:
0174 
0175        ClassDef(DataSet,1); // ROOT dictionary macro; the second argument is the class version
0176    };
0177 }
0178 
0179 
0180 //_______________________________________________________________________
0181 inline UInt_t TMVA::DataSet::TreeIndex(Types::ETreeType type) const
0182 {
0183    // Fixed slots: training = 0, testing = 1, validation = 2, original training = 3.
0184    if (type == Types::kTraining)         return 0;
0185    if (type == Types::kTesting)          return 1;
0186    if (type == Types::kValidation)       return 2;
0187    if (type == Types::kTrainingOriginal) return 3;
0188 
0189    // kMaxTreeType and every other value stand for the currently selected tree.
0190    return fCurrentTreeIdx;
0191 }
0192 
0193 //_______________________________________________________________________
0194 inline TMVA::Types::ETreeType TMVA::DataSet::GetCurrentType() const
0195 {
0196    // Inverse of TreeIndex() for the four fixed slots 0..3.
0197    if (fCurrentTreeIdx == 0) return Types::kTraining;
0198    if (fCurrentTreeIdx == 1) return Types::kTesting;
0199    if (fCurrentTreeIdx == 2) return Types::kValidation;
0200    if (fCurrentTreeIdx == 3) return Types::kTrainingOriginal;
0201    // Any other index has no tree type of its own.
0202    return Types::kMaxTreeType;
0203 }
0204 
0205 //_______________________________________________________________________
0206 inline Long64_t TMVA::DataSet::GetNEvents(Types::ETreeType type) const
0207 {
0208    // Event count for `type` (kMaxTreeType = current tree): the size of the selected-events
0209    // list when the fSampling flag of that tree slot is set, else the size of the full event list.
0210    const UInt_t idx = TreeIndex(type);   // kept unsigned, as returned by TreeIndex(); no signed round trip
0211    const bool sampled = idx < fSampling.size() && fSampling.at(idx);   // slots without a flag entry count as unsampled
0212    return sampled ? fSamplingSelected.at(idx).size() : GetEventCollection(type).size();
0213 }
0214 
0215 //_______________________________________________________________________
0216 inline const std::vector<TMVA::Event*>& TMVA::DataSet::GetEventCollection( TMVA::Types::ETreeType type ) const
0217 {
0218    // Event list stored in the slot TreeIndex() assigns to `type` (kMaxTreeType -> current tree); at() is bounds-checked.
0219    const std::vector<TMVA::Event*>& events = fEventCollection.at(TreeIndex(type));
0220    return events;
0221 }
0222 #endif