Warning: the file /include/root/TMVA/VariableTransformBase.h was not indexed,
or was modified since it was last indexed (in which case cross-reference links may be missing, inaccurate, or erroneous).
0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018
0019
0020
0021
0022
0023
0024
0025
0026
0027
0028
0029 #ifndef ROOT_TMVA_VariableTransformBase
0030 #define ROOT_TMVA_VariableTransformBase
0031
0032
0033
0034
0035
0036
0037
0038
0039
0040 #include <vector>
0041 #include <utility>
0042
0043 #include "TH1.h"
0044 #include "TDirectory.h"
0045 #include "TString.h"
0046
0047 #include "TMVA/Types.h"
0048 #include "TMVA/Event.h"
0049 #include "TMVA/VariableInfo.h"
0050 #include "TMVA/DataSetInfo.h"
0051
0052 namespace TMVA {
0053
0054 class VariableTransformBase : public TObject {
0055
0056 public:
0057
0058 typedef std::vector<std::pair<Char_t,UInt_t> > VectorOfCharAndInt;
0059 typedef VectorOfCharAndInt::iterator ItVarTypeIdx;
0060 typedef VectorOfCharAndInt::const_iterator ItVarTypeIdxConst;
0061
0062 VariableTransformBase( DataSetInfo& dsi, Types::EVariableTransform tf, const TString& trfName );
0063 virtual ~VariableTransformBase( void );
0064
0065 virtual void Initialize() = 0;
0066 virtual Bool_t PrepareTransformation (const std::vector<Event*>& ) = 0;
0067 virtual const Event* Transform ( const Event* const, Int_t cls ) const = 0;
0068 virtual const Event* InverseTransform( const Event* const, Int_t cls ) const = 0;
0069
0070
0071 void SetEnabled ( Bool_t e ) { fEnabled = e; }
0072 void SetNormalise( Bool_t n ) { fNormalise = n; }
0073 Bool_t IsEnabled() const { return fEnabled; }
0074 Bool_t IsCreated() const { return fCreated; }
0075 Bool_t IsNormalised() const { return fNormalise; }
0076
0077
0078 virtual void SelectInput( const TString& inputVariables, Bool_t putIntoVariables = kFALSE );
0079 virtual Bool_t GetInput ( const Event* event, std::vector<Float_t>& input, std::vector<Char_t>& mask, Bool_t backTransform = kFALSE ) const;
0080 virtual void SetOutput( Event* event, std::vector<Float_t>& output, std::vector<Char_t>& mask, const Event* oldEvent = nullptr, Bool_t backTransform = kFALSE ) const;
0081 virtual void CountVariableTypes( UInt_t& nvars, UInt_t& ntgts, UInt_t& nspcts ) const;
0082
0083 void ToggleInputSortOrder( Bool_t sortOrder ) { fSortGet = sortOrder; }
0084 void SetOutputDataSetInfo( DataSetInfo* outputDsi ) { fDsiOutput = outputDsi; }
0085
0086
0087
0088 void SetUseSignalTransform( Bool_t e=kTRUE) { fUseSignalTransform = e; }
0089 Bool_t UseSignalTransform() const { return fUseSignalTransform; }
0090
0091 const char* GetName() const override { return fTransformName.Data(); }
0092 TString GetShortName() const { TString a(fTransformName); a.ReplaceAll("Transform",""); return a; }
0093
0094 virtual void WriteTransformationToStream ( std::ostream& o ) const = 0;
0095 virtual void ReadTransformationFromStream( std::istream& istr, const TString& classname="" ) = 0;
0096
0097 virtual void AttachXMLTo(void* parent) = 0;
0098 virtual void ReadFromXML( void* trfnode ) = 0;
0099
0100 Types::EVariableTransform GetVariableTransform() const { return fVariableTransform; }
0101
0102
0103 virtual void MakeFunction( std::ostream& fout, const TString& fncName, Int_t part,
0104 UInt_t trCounter, Int_t cls ) = 0;
0105
0106
0107 virtual std::vector<TString>* GetTransformationStrings( Int_t cls ) const;
0108 virtual void PrintTransformation( std::ostream & ) {}
0109
0110 const std::vector<TMVA::VariableInfo>& Variables() const { return fVariables; }
0111 const std::vector<TMVA::VariableInfo>& Targets() const { return fTargets; }
0112 const std::vector<TMVA::VariableInfo>& Spectators() const { return fSpectators; }
0113
0114 MsgLogger& Log() const { return *fLogger; }
0115
0116 void SetTMVAVersion(TMVAVersion_t v) { fTMVAVersion = v; }
0117
0118 protected:
0119
0120 void CalcNorm( const std::vector<const Event*>& );
0121
0122 void SetCreated( Bool_t c = kTRUE ) { fCreated = c; }
0123 void SetNVariables( UInt_t i ) { fNVars = i; }
0124 void SetName( const TString& c ) { fTransformName = c; }
0125
0126 UInt_t GetNVariables() const { return fDsi.GetNVariables(); }
0127 UInt_t GetNTargets() const { return fDsi.GetNTargets(); }
0128 UInt_t GetNSpectators() const { return fDsi.GetNSpectators(); }
0129
0130 DataSetInfo& fDsi;
0131 DataSetInfo* fDsiOutput;
0132
0133 std::vector<TMVA::VariableInfo>& Variables() { return fVariables; }
0134 std::vector<TMVA::VariableInfo>& Targets() { return fTargets; }
0135 std::vector<TMVA::VariableInfo>& Spectators() { return fSpectators; }
0136 Int_t GetNClasses() const { return fDsi.GetNClasses(); }
0137
0138
0139 mutable Event* fTransformedEvent;
0140 mutable Event* fBackTransformedEvent;
0141
0142
0143 VectorOfCharAndInt fGet;
0144 VectorOfCharAndInt fPut;
0145
0146 private:
0147
0148 Types::EVariableTransform fVariableTransform;
0149
0150 void UpdateNorm( Int_t ivar, Double_t x );
0151
0152 Bool_t fUseSignalTransform;
0153 Bool_t fEnabled;
0154 Bool_t fCreated;
0155 Bool_t fNormalise;
0156 UInt_t fNVars;
0157 TString fTransformName;
0158 std::vector<TMVA::VariableInfo> fVariables;
0159 std::vector<TMVA::VariableInfo> fTargets;
0160 std::vector<TMVA::VariableInfo> fSpectators;
0161
0162 mutable Bool_t fVariableTypesAreCounted;
0163 mutable UInt_t fNVariables;
0164 mutable UInt_t fNTargets;
0165 mutable UInt_t fNSpectators;
0166
0167 Bool_t fSortGet;
0168
0169
0170 protected:
0171
0172 TMVAVersion_t fTMVAVersion;
0173
0174 mutable MsgLogger* fLogger;
0175
0176 ClassDefOverride(VariableTransformBase,0);
0177 };
0178
0179 }
0180
0181 #endif