Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-01-18 10:11:11

0001 // @(#)root/tmva $Id$
0002 // Author: Andreas Hoecker, Peter Speckmayer,Joerg Stelzer, Helge Voss
0003 
0004 /**********************************************************************************
0005  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis       *
0006  * Package: TMVA                                                                  *
0007  * Class  : VariableTransformBase                                                 *
0008  *                                                                                *
0009  *                                                                                *
0010  * Description:                                                                   *
0011  *      Pre-transformation of input variables (base class)                        *
0012  *                                                                                *
0013  * Authors (alphabetical):                                                        *
0014  *      Andreas Hoecker <Andreas.Hocker@cern.ch> - CERN, Switzerland              *
0015  *      Peter Speckmayer <Peter.Speckmayer@cern.ch> - CERN, Switzerland           *
0016  *      Joerg Stelzer   <Joerg.Stelzer@cern.ch>  - CERN, Switzerland              *
0017  *      Helge Voss      <Helge.Voss@cern.ch>     - MPI-K Heidelberg, Germany      *
0018  *                                                                                *
0019  * Copyright (c) 2005:                                                            *
0020  *      CERN, Switzerland                                                         *
0021  *      U. of Victoria, Canada                                                    *
0022  *      MPI-K Heidelberg, Germany                                                 *
0023  *                                                                                *
0024  * Redistribution and use in source and binary forms, with or without             *
0025  * modification, are permitted according to the terms listed in LICENSE           *
0026  * (see tmva/doc/LICENSE)                                                         *
0027  **********************************************************************************/
0028 
0029 #ifndef ROOT_TMVA_VariableTransformBase
0030 #define ROOT_TMVA_VariableTransformBase
0031 
0032 //////////////////////////////////////////////////////////////////////////
0033 //                                                                      //
0034 // VariableTransformBase                                                //
0035 //                                                                      //
0036 // Base class for pre-transformation of input variables                 //
0037 //                                                                      //
0038 //////////////////////////////////////////////////////////////////////////
0039 
0040 #include <vector>
0041 #include <utility>
0042 
0043 #include "TH1.h"
0044 #include "TDirectory.h"
0045 #include "TString.h"
0046 
0047 #include "TMVA/Types.h"
0048 #include "TMVA/Event.h"
0049 #include "TMVA/VariableInfo.h"
0050 #include "TMVA/DataSetInfo.h"
0051 
0052 namespace TMVA {
0053 
0054    class VariableTransformBase : public TObject {
0055    // Abstract base for pre-transformations of input variables; concrete transformations implement the pure virtual methods below.
0056    public:
0057       // (Char_t, UInt_t) pair list, the type of fGet and fPut; presumably (entry kind, index) -- TODO confirm against SelectInput
0058       typedef std::vector<std::pair<Char_t,UInt_t> > VectorOfCharAndInt;
0059       typedef VectorOfCharAndInt::iterator       ItVarTypeIdx;
0060       typedef VectorOfCharAndInt::const_iterator ItVarTypeIdxConst;
0061       // construction / destruction (both defined out of line)
0062       VariableTransformBase( DataSetInfo& dsi, Types::EVariableTransform tf, const TString& trfName );
0063       virtual ~VariableTransformBase( void );
0064       // interface every concrete transformation must implement; `cls` is presumably a class index -- TODO confirm
0065       virtual void         Initialize() = 0;
0066       virtual Bool_t       PrepareTransformation (const std::vector<Event*>&  ) = 0;
0067       virtual const Event* Transform       ( const Event* const, Int_t cls ) const = 0;
0068       virtual const Event* InverseTransform( const Event* const, Int_t cls ) const = 0;
0069 
0070       // accessors for the enabled / normalise / created flags (the created flag is set via protected SetCreated)
0071       void   SetEnabled  ( Bool_t e ) { fEnabled = e; }
0072       void   SetNormalise( Bool_t n ) { fNormalise = n; }
0073       Bool_t IsEnabled()    const { return fEnabled; }
0074       Bool_t IsCreated()    const { return fCreated; }
0075       Bool_t IsNormalised() const { return fNormalise; }
0076 
0077       // variable selection (declared here, defined out of line)
0078       virtual void           SelectInput( const TString& inputVariables, Bool_t putIntoVariables = kFALSE );
0079       virtual Bool_t         GetInput ( const Event* event, std::vector<Float_t>& input, std::vector<Char_t>& mask, Bool_t backTransform = kFALSE  ) const;
0080       virtual void           SetOutput( Event* event, std::vector<Float_t>& output, std::vector<Char_t>& mask, const Event* oldEvent = nullptr, Bool_t backTransform = kFALSE ) const;
0081       virtual void           CountVariableTypes( UInt_t& nvars, UInt_t& ntgts, UInt_t& nspcts ) const;
0082       // set the input sort-order flag (meaning documented at fSortGet) and store the output DataSetInfo pointer as given
0083       void ToggleInputSortOrder( Bool_t sortOrder ) { fSortGet = sortOrder; }
0084       void SetOutputDataSetInfo( DataSetInfo* outputDsi ) { fDsiOutput = outputDsi; }
0085 
0086 
0087       // flag telling whether the transformation is based on signal data (see fUseSignalTransform); setter defaults to kTRUE
0088       void SetUseSignalTransform( Bool_t e=kTRUE) { fUseSignalTransform = e; }
0089       Bool_t UseSignalTransform() const { return fUseSignalTransform; }
0090       // GetName returns the character data of fTransformName; GetShortName returns a copy with every "Transform" substring removed
0091       virtual const char* GetName() const { return fTransformName.Data(); }
0092       TString GetShortName() const { TString a(fTransformName); a.ReplaceAll("Transform",""); return a; }
0093       // plain-stream persistence of the transformation (pure virtual)
0094       virtual void WriteTransformationToStream ( std::ostream& o ) const = 0;
0095       virtual void ReadTransformationFromStream( std::istream& istr, const TString& classname="" ) = 0;
0096       // XML persistence (pure virtual); nodes are passed as untyped void* handles
0097       virtual void AttachXMLTo(void* parent) = 0;
0098       virtual void ReadFromXML( void* trfnode ) = 0;
0099       // returns the stored transformation type (fVariableTransform)
0100       Types::EVariableTransform GetVariableTransform() const { return fVariableTransform; }
0101 
0102       // writer of function code
0103       virtual void MakeFunction( std::ostream& fout, const TString& fncName, Int_t part,
0104                                  UInt_t trCounter, Int_t cls ) = 0;
0105 
0106       // provides string vector giving explicit transformation; NOTE(review): ownership of the returned pointer is not visible here -- confirm in the implementation
0107       virtual std::vector<TString>* GetTransformationStrings( Int_t cls ) const;
0108       virtual void PrintTransformation( std::ostream & ) {} // default implementation does nothing
0109       // read-only access to the variable / target / spectator infos held by this transformation
0110       const std::vector<TMVA::VariableInfo>& Variables() const { return fVariables; }
0111       const std::vector<TMVA::VariableInfo>& Targets()   const { return fTargets;   }
0112       const std::vector<TMVA::VariableInfo>& Spectators()   const { return fSpectators;   }
0113       // message logger access; dereferences fLogger without a null check
0114       MsgLogger& Log() const { return *fLogger; }
0115 
0116       void SetTMVAVersion(TMVAVersion_t v) { fTMVAVersion = v; }
0117 
0118    protected:
0119 
0120       void CalcNorm( const std::vector<const Event*>& );
0121       // setters for derived classes; note SetNVariables writes fNVars, which GetNVariables() below does not read
0122       void SetCreated( Bool_t c = kTRUE ) { fCreated = c; }
0123       void SetNVariables( UInt_t i )      { fNVars = i; }
0124       void SetName( const TString& c )    { fTransformName = c; }
0125       // counts are forwarded from fDsi, not taken from fNVars / fNVariables / fNTargets / fNSpectators
0126       UInt_t GetNVariables() const { return fDsi.GetNVariables(); }
0127       UInt_t GetNTargets()   const { return fDsi.GetNTargets(); }
0128       UInt_t GetNSpectators() const { return fDsi.GetNSpectators(); }
0129       // fDsi: DataSetInfo reference passed to the constructor; fDsiOutput: pointer set via SetOutputDataSetInfo
0130       DataSetInfo& fDsi;
0131       DataSetInfo* fDsiOutput;
0132       // mutable overloads of the info accessors, for derived classes
0133       std::vector<TMVA::VariableInfo>& Variables() { return fVariables; }
0134       std::vector<TMVA::VariableInfo>& Targets() { return fTargets; }
0135       std::vector<TMVA::VariableInfo>& Spectators() { return fSpectators; }
0136       Int_t GetNClasses() const { return fDsi.GetNClasses(); }
0137 
0138       // declared mutable so that const methods (Transform / InverseTransform are const) can update them
0139       mutable Event*           fTransformedEvent;     ///< holds the current transformed event
0140       mutable Event*           fBackTransformedEvent; ///< holds the current back-transformed event
0141 
0142       // variable selection
0143       VectorOfCharAndInt               fGet;          ///< get variables/targets/spectators
0144       VectorOfCharAndInt               fPut;          ///< put variables/targets/spectators
0145 
0146    private:
0147 
0148       Types::EVariableTransform fVariableTransform;   ///< Decorrelation, PCA, etc.
0149 
0150       void UpdateNorm( Int_t ivar, Double_t x );
0151 
0152       Bool_t                           fUseSignalTransform; ///< true if transformation bases on signal data
0153       Bool_t                           fEnabled;            ///< has been enabled
0154       Bool_t                           fCreated;            ///< has been created
0155       Bool_t                           fNormalise;          ///< normalise input variables
0156       UInt_t                           fNVars;              ///< number of variables (written by SetNVariables)
0157       TString                          fTransformName;      ///< name of transformation
0158       std::vector<TMVA::VariableInfo>  fVariables;          ///< event variables [saved to weight file]
0159       std::vector<TMVA::VariableInfo>  fTargets;            ///< event targets [saved to weight file --> TODO ]
0160       std::vector<TMVA::VariableInfo>  fSpectators;         ///< event spectators [saved to weight file --> TODO ]
0161       // cached counts, mutable so a const method can fill them; presumably written by CountVariableTypes -- TODO confirm
0162       mutable Bool_t                   fVariableTypesAreCounted; ///< true if variable types have been counted already
0163       mutable UInt_t                   fNVariables;         ///< number of variables to be transformed
0164       mutable UInt_t                   fNTargets;           ///< number of targets to be transformed
0165       mutable UInt_t                   fNSpectators;        ///< number of spectators to be transformed
0166 
0167       Bool_t                           fSortGet;            ///< if true, sort the variables into the order as defined by the user at the var definition
0168                                                             ///< if false, sort the variables according to the order given for the var transformation
0169 
0170    protected:
0171 
0172       TMVAVersion_t                    fTMVAVersion;        ///< TMVA version value, written by SetTMVAVersion
0173 
0174       mutable MsgLogger* fLogger;                     ///<! message logger
0175       // ROOT dictionary macro; the class version number given here is 0
0176       ClassDef(VariableTransformBase,0);   //  Base class for variable transformations
0177    };
0178 
0179 } // namespace TMVA
0180 
0181 #endif