File indexing completed on 2025-01-18 10:11:11
0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018
0019
0020
0021
0022
0023
0024
0025
0026
0027
0028
0029
0030
0031 #ifndef ROOT_TMVA_TransformationHandler
0032 #define ROOT_TMVA_TransformationHandler
0033
0034
0035
0036
0037
0038
0039
0040
0041
#include "TList.h"
#include "TString.h"
#include <iosfwd>
#include <vector>

#include "TMVA/DataSetInfo.h"

class TDirectory;
0047
0048 namespace TMVA {
0049
0050 class Event;
0051 class DataSet;
0052 class Ranking;
0053 class VariableTransformBase;
0054 class MsgLogger;
0055
0056 class TransformationHandler {
0057 public:
0058
0059 struct VariableStat {
0060 Double_t fMean;
0061 Double_t fRMS;
0062 Double_t fMin;
0063 Double_t fMax;
0064 };
0065
0066 TransformationHandler( DataSetInfo&, const TString& callerName );
0067 ~TransformationHandler();
0068
0069 TString GetName() const;
0070 TString GetVariableAxisTitle( const VariableInfo& info ) const;
0071
0072 const Event* Transform(const Event*) const;
0073 const Event* InverseTransform(const Event*, Bool_t suppressIfNoTargets=true ) const;
0074
0075
0076 void SetTransformationReferenceClass( Int_t cls );
0077
0078 VariableTransformBase* AddTransformation(VariableTransformBase*, Int_t cls );
0079 const TList& GetTransformationList() const { return fTransformations; }
0080 Int_t GetNumOfTransformations() const { return fTransformations.GetSize(); }
0081 const std::vector<Event*>* CalcTransformations( const std::vector<Event*>&, Bool_t createNewVector = kFALSE );
0082
0083 void CalcStats( const std::vector<Event*>& events );
0084 void AddStats ( Int_t k, UInt_t ivar, Double_t mean, Double_t rms, Double_t min, Double_t max );
0085 Double_t GetMean ( Int_t ivar, Int_t cls = -1 ) const;
0086 Double_t GetRMS ( Int_t ivar, Int_t cls = -1 ) const;
0087 Double_t GetMin ( Int_t ivar, Int_t cls = -1 ) const;
0088 Double_t GetMax ( Int_t ivar, Int_t cls = -1 ) const;
0089
0090 void WriteToStream ( std::ostream& o ) const;
0091 void AddXMLTo ( void* parent=nullptr ) const;
0092 void ReadFromStream( std::istream& istr );
0093 void ReadFromXML ( void* trfsnode );
0094
0095
0096 void MakeFunction(std::ostream& fout, const TString& fncName, Int_t part) const;
0097
0098
0099 void PrintVariableRanking() const;
0100
0101
0102 std::vector<TString>* GetTransformationStringsOfLastTransform() const;
0103 const char* GetNameOfLastTransform() const;
0104
0105
0106 void SetCallerName( const TString& name );
0107 const TString& GetCallerName() const { return fCallerName; }
0108
0109
0110 TDirectory* GetRootDir() const { return fRootBaseDir; }
0111 void SetRootDir( TDirectory *d ) { fRootBaseDir = d; }
0112
0113 void PlotVariables( const std::vector<Event*>& events, TDirectory* theDirectory = nullptr );
0114
0115 private:
0116
0117
0118
0119
0120
0121
0122 const TMVA::VariableInfo& Variable(UInt_t ivar) const { return fDataSetInfo.GetVariableInfos().at(ivar); }
0123 const TMVA::VariableInfo& Target (UInt_t itgt) const { return fDataSetInfo.GetTargetInfos()[itgt]; }
0124
0125 DataSet* Data() { return fDataSetInfo.GetDataSet(); }
0126
0127 DataSetInfo& fDataSetInfo;
0128 TList fTransformations;
0129 std::vector< Int_t > fTransformationsReferenceClasses;
0130 std::vector<std::vector<TMVA::TransformationHandler::VariableStat> > fVariableStats;
0131
0132 Int_t fNumC;
0133
0134 std::vector<Ranking*> fRanking;
0135 TDirectory* fRootBaseDir;
0136 TString fCallerName;
0137 mutable MsgLogger* fLogger;
0138 MsgLogger& Log() const { return *fLogger; }
0139 };
0140 }
0141
0142 #endif