File indexing completed on 2025-01-18 10:11:01
0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018
0019
0020
0021
0022
0023
0024
0025
0026 #ifndef ROOT_TMVA_ModulekNN
0027 #define ROOT_TMVA_ModulekNN
0028
0029
0030
0031
0032
0033
0034
0035
0036
0037
0038 #include <iosfwd>
0039 #include <map>
0040 #include <string>
0041 #include <vector>
0042 #include <list>
0043
0044
0045 #include "RtypesCore.h"
0046 #include "TRandom3.h"
0047 #include "ThreadLocalStorage.h"
0048 #include "TMVA/NodekNN.h"
0049
0050 namespace TMVA {
0051
0052 class MsgLogger;
0053
0054 namespace kNN {
0055
0056 typedef Float_t VarType;
0057 typedef std::vector<VarType> VarVec;
0058
0059 class Event {
0060 public:
0061
0062 Event();
0063 Event(const VarVec &vec, Double_t weight, Short_t type);
0064 Event(const VarVec &vec, Double_t weight, Short_t type, const VarVec &tvec);
0065 ~Event();
0066
0067 Double_t GetWeight() const;
0068
0069 VarType GetVar(UInt_t i) const;
0070 VarType GetTgt(UInt_t i) const;
0071
0072 UInt_t GetNVar() const;
0073 UInt_t GetNTgt() const;
0074
0075 Short_t GetType() const;
0076
0077
0078 VarType GetDist(VarType var, UInt_t ivar) const;
0079 VarType GetDist(const Event &other) const;
0080
0081 void SetTargets(const VarVec &tvec);
0082 const VarVec& GetTargets() const;
0083 const VarVec& GetVars() const;
0084
0085 void Print() const;
0086 void Print(std::ostream& os) const;
0087
0088 private:
0089
0090 VarVec fVar;
0091 VarVec fTgt;
0092
0093 Double_t fWeight;
0094 Short_t fType;
0095 };
0096
0097 typedef std::vector<TMVA::kNN::Event> EventVec;
0098 typedef std::pair<const Node<Event> *, VarType> Elem;
0099 typedef std::list<Elem> List;
0100
0101 std::ostream& operator<<(std::ostream& os, const Event& event);
0102
0103 class ModulekNN
0104 {
0105 public:
0106
0107 typedef std::map<int, std::vector<Double_t> > VarMap;
0108
0109 public:
0110
0111 ModulekNN();
0112 ~ModulekNN();
0113
0114 void Clear();
0115
0116 void Add(const Event &event);
0117
0118 Bool_t Fill(const UShort_t odepth, UInt_t ifrac, const std::string &option = "");
0119
0120 Bool_t Find(Event event, UInt_t nfind = 100, const std::string &option = "count") const;
0121 Bool_t Find(UInt_t nfind, const std::string &option) const;
0122
0123 const EventVec& GetEventVec() const;
0124
0125 const List& GetkNNList() const;
0126 const Event& GetkNNEvent() const;
0127
0128 const VarMap& GetVarMap() const;
0129
0130 const std::map<Int_t, Double_t>& GetMetric() const;
0131
0132 void Print() const;
0133 void Print(std::ostream &os) const;
0134
0135 private:
0136
0137 Node<Event>* Optimize(UInt_t optimize_depth);
0138
0139 void ComputeMetric(UInt_t ifrac);
0140
0141 const Event Scale(const Event &event) const;
0142
0143 private:
0144
0145
0146
0147 static TRandom3& GetRndmThreadLocal() {TTHREAD_TLS_DECL_ARG(TRandom3,fgRndm,1); return fgRndm;};
0148
0149 UInt_t fDimn;
0150
0151 Node<Event> *fTree;
0152
0153 std::map<Int_t, Double_t> fVarScale;
0154
0155 mutable List fkNNList;
0156 mutable Event fkNNEvent;
0157
0158 std::map<Short_t, UInt_t> fCount;
0159
0160 EventVec fEvent;
0161 VarMap fVar;
0162
0163 mutable MsgLogger* fLogger;
0164 MsgLogger& Log() const { return *fLogger; }
0165 };
0166
0167
0168
0169
0170 inline VarType Event::GetDist(const VarType var1, const UInt_t ivar) const
0171 {
0172 const VarType var2 = GetVar(ivar);
0173 return (var1 - var2) * (var1 - var2);
0174 }
0175 inline Double_t Event::GetWeight() const
0176 {
0177 return fWeight;
0178 }
0179 inline VarType Event::GetVar(const UInt_t i) const
0180 {
0181 return fVar[i];
0182 }
0183 inline VarType Event::GetTgt(const UInt_t i) const
0184 {
0185 return fTgt[i];
0186 }
0187
0188 inline UInt_t Event::GetNVar() const
0189 {
0190 return fVar.size();
0191 }
0192 inline UInt_t Event::GetNTgt() const
0193 {
0194 return fTgt.size();
0195 }
0196 inline Short_t Event::GetType() const
0197 {
0198 return fType;
0199 }
0200
0201
0202
0203
0204 inline const List& ModulekNN::GetkNNList() const
0205 {
0206 return fkNNList;
0207 }
0208 inline const Event& ModulekNN::GetkNNEvent() const
0209 {
0210 return fkNNEvent;
0211 }
0212 inline const EventVec& ModulekNN::GetEventVec() const
0213 {
0214 return fEvent;
0215 }
0216 inline const ModulekNN::VarMap& ModulekNN::GetVarMap() const
0217 {
0218 return fVar;
0219 }
0220 inline const std::map<Int_t, Double_t>& ModulekNN::GetMetric() const
0221 {
0222 return fVarScale;
0223 }
0224
0225 }
0226 }
0227
0228 #endif
0229