File indexing completed on 2025-01-18 10:11:11
0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018
0019
0020
0021
0022
0023
0024
0025
0026
0027
0028
0029 #ifndef ROOT_TMVA_VariableTransformBase
0030 #define ROOT_TMVA_VariableTransformBase
0031
0032
0033
0034
0035
0036
0037
0038
0039
0040 #include <vector>
0041 #include <utility>
0042
0043 #include "TH1.h"
0044 #include "TDirectory.h"
0045 #include "TString.h"
0046
0047 #include "TMVA/Types.h"
0048 #include "TMVA/Event.h"
0049 #include "TMVA/VariableInfo.h"
0050 #include "TMVA/DataSetInfo.h"
0051
0052 namespace TMVA {
0053
0054 class VariableTransformBase : public TObject {
0055
0056 public:
0057
0058 typedef std::vector<std::pair<Char_t,UInt_t> > VectorOfCharAndInt;
0059 typedef VectorOfCharAndInt::iterator ItVarTypeIdx;
0060 typedef VectorOfCharAndInt::const_iterator ItVarTypeIdxConst;
0061
0062 VariableTransformBase( DataSetInfo& dsi, Types::EVariableTransform tf, const TString& trfName );
0063 virtual ~VariableTransformBase( void );
0064
0065 virtual void Initialize() = 0;
0066 virtual Bool_t PrepareTransformation (const std::vector<Event*>& ) = 0;
0067 virtual const Event* Transform ( const Event* const, Int_t cls ) const = 0;
0068 virtual const Event* InverseTransform( const Event* const, Int_t cls ) const = 0;
0069
0070
0071 void SetEnabled ( Bool_t e ) { fEnabled = e; }
0072 void SetNormalise( Bool_t n ) { fNormalise = n; }
0073 Bool_t IsEnabled() const { return fEnabled; }
0074 Bool_t IsCreated() const { return fCreated; }
0075 Bool_t IsNormalised() const { return fNormalise; }
0076
0077
0078 virtual void SelectInput( const TString& inputVariables, Bool_t putIntoVariables = kFALSE );
0079 virtual Bool_t GetInput ( const Event* event, std::vector<Float_t>& input, std::vector<Char_t>& mask, Bool_t backTransform = kFALSE ) const;
0080 virtual void SetOutput( Event* event, std::vector<Float_t>& output, std::vector<Char_t>& mask, const Event* oldEvent = nullptr, Bool_t backTransform = kFALSE ) const;
0081 virtual void CountVariableTypes( UInt_t& nvars, UInt_t& ntgts, UInt_t& nspcts ) const;
0082
0083 void ToggleInputSortOrder( Bool_t sortOrder ) { fSortGet = sortOrder; }
0084 void SetOutputDataSetInfo( DataSetInfo* outputDsi ) { fDsiOutput = outputDsi; }
0085
0086
0087
0088 void SetUseSignalTransform( Bool_t e=kTRUE) { fUseSignalTransform = e; }
0089 Bool_t UseSignalTransform() const { return fUseSignalTransform; }
0090
0091 virtual const char* GetName() const { return fTransformName.Data(); }
0092 TString GetShortName() const { TString a(fTransformName); a.ReplaceAll("Transform",""); return a; }
0093
0094 virtual void WriteTransformationToStream ( std::ostream& o ) const = 0;
0095 virtual void ReadTransformationFromStream( std::istream& istr, const TString& classname="" ) = 0;
0096
0097 virtual void AttachXMLTo(void* parent) = 0;
0098 virtual void ReadFromXML( void* trfnode ) = 0;
0099
0100 Types::EVariableTransform GetVariableTransform() const { return fVariableTransform; }
0101
0102
0103 virtual void MakeFunction( std::ostream& fout, const TString& fncName, Int_t part,
0104 UInt_t trCounter, Int_t cls ) = 0;
0105
0106
0107 virtual std::vector<TString>* GetTransformationStrings( Int_t cls ) const;
0108 virtual void PrintTransformation( std::ostream & ) {}
0109
0110 const std::vector<TMVA::VariableInfo>& Variables() const { return fVariables; }
0111 const std::vector<TMVA::VariableInfo>& Targets() const { return fTargets; }
0112 const std::vector<TMVA::VariableInfo>& Spectators() const { return fSpectators; }
0113
0114 MsgLogger& Log() const { return *fLogger; }
0115
0116 void SetTMVAVersion(TMVAVersion_t v) { fTMVAVersion = v; }
0117
0118 protected:
0119
0120 void CalcNorm( const std::vector<const Event*>& );
0121
0122 void SetCreated( Bool_t c = kTRUE ) { fCreated = c; }
0123 void SetNVariables( UInt_t i ) { fNVars = i; }
0124 void SetName( const TString& c ) { fTransformName = c; }
0125
0126 UInt_t GetNVariables() const { return fDsi.GetNVariables(); }
0127 UInt_t GetNTargets() const { return fDsi.GetNTargets(); }
0128 UInt_t GetNSpectators() const { return fDsi.GetNSpectators(); }
0129
0130 DataSetInfo& fDsi;
0131 DataSetInfo* fDsiOutput;
0132
0133 std::vector<TMVA::VariableInfo>& Variables() { return fVariables; }
0134 std::vector<TMVA::VariableInfo>& Targets() { return fTargets; }
0135 std::vector<TMVA::VariableInfo>& Spectators() { return fSpectators; }
0136 Int_t GetNClasses() const { return fDsi.GetNClasses(); }
0137
0138
0139 mutable Event* fTransformedEvent;
0140 mutable Event* fBackTransformedEvent;
0141
0142
0143 VectorOfCharAndInt fGet;
0144 VectorOfCharAndInt fPut;
0145
0146 private:
0147
0148 Types::EVariableTransform fVariableTransform;
0149
0150 void UpdateNorm( Int_t ivar, Double_t x );
0151
0152 Bool_t fUseSignalTransform;
0153 Bool_t fEnabled;
0154 Bool_t fCreated;
0155 Bool_t fNormalise;
0156 UInt_t fNVars;
0157 TString fTransformName;
0158 std::vector<TMVA::VariableInfo> fVariables;
0159 std::vector<TMVA::VariableInfo> fTargets;
0160 std::vector<TMVA::VariableInfo> fSpectators;
0161
0162 mutable Bool_t fVariableTypesAreCounted;
0163 mutable UInt_t fNVariables;
0164 mutable UInt_t fNTargets;
0165 mutable UInt_t fNSpectators;
0166
0167 Bool_t fSortGet;
0168
0169
0170 protected:
0171
0172 TMVAVersion_t fTMVAVersion;
0173
0174 mutable MsgLogger* fLogger;
0175
0176 ClassDef(VariableTransformBase,0);
0177 };
0178
0179 }
0180
0181 #endif