Warning: file /include/root/TMVA/NodekNN.h was not indexed
or was modified since the last indexation (in which case cross-reference links may be missing, inaccurate, or erroneous).
0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018
0019
0020
0021
0022
0023
0024
0025
0026 #ifndef ROOT_TMVA_NodekNN
0027 #define ROOT_TMVA_NodekNN
0028
0029
#include <algorithm>
#include <cassert>
#include <iostream>
#include <list>
#include <string>
#include <utility>

#include "RtypesCore.h"
0038
0039
0040
0041
0042
0043
0044
0045
0046
0047
0048
0049
0050
0051
0052
0053
0054
0055
0056
0057
0058
0059
0060
0061
0062 namespace TMVA
0063 {
0064 namespace kNN
0065 {
0066 template <class T>
0067 class Node
0068 {
0069
0070 public:
0071
0072 Node(const Node *parent, const T &event, Int_t mod);
0073 ~Node();
0074
0075 const Node* Add(const T &event, UInt_t depth);
0076
0077 void SetNodeL(Node *node);
0078 void SetNodeR(Node *node);
0079
0080 const T& GetEvent() const;
0081
0082 const Node* GetNodeL() const;
0083 const Node* GetNodeR() const;
0084 const Node* GetNodeP() const;
0085
0086 Double_t GetWeight() const;
0087
0088 Float_t GetVarDis() const;
0089 Float_t GetVarMin() const;
0090 Float_t GetVarMax() const;
0091
0092 UInt_t GetMod() const;
0093
0094 void Print() const;
0095 void Print(std::ostream& os, const std::string &offset = "") const;
0096
0097 private:
0098
0099
0100
0101 Node();
0102 Node(const Node &);
0103 const Node& operator=(const Node &);
0104
0105 private:
0106
0107 const Node* fNodeP;
0108
0109 Node* fNodeL;
0110 Node* fNodeR;
0111
0112 const T fEvent;
0113
0114 const Float_t fVarDis;
0115
0116 Float_t fVarMin;
0117 Float_t fVarMax;
0118
0119 const UInt_t fMod;
0120 };
0121
0122
0123 template<class T>
0124 UInt_t Find(std::list<std::pair<const Node<T> *, Float_t> > &nlist,
0125 const Node<T> *node, const T &event, UInt_t nfind);
0126
0127
0128
0129 template<class T>
0130 UInt_t Find(std::list<std::pair<const Node<T> *, Float_t> > &nlist,
0131 const Node<T> *node, const T &event, Double_t nfind, Double_t ncurr);
0132
0133
0134 template <class T>
0135 UInt_t Depth(const Node<T> *node);
0136
0137
0138
0139
0140
0141
0142
0143
0144 template <class T>
0145 inline void Node<T>::SetNodeL(Node<T> *node)
0146 {
0147 fNodeL = node;
0148 }
0149
0150 template <class T>
0151 inline void Node<T>::SetNodeR(Node<T> *node)
0152 {
0153 fNodeR = node;
0154 }
0155
0156 template <class T>
0157 inline const T& Node<T>::GetEvent() const
0158 {
0159 return fEvent;
0160 }
0161
0162 template <class T>
0163 inline const Node<T>* Node<T>::GetNodeL() const
0164 {
0165 return fNodeL;
0166 }
0167
0168 template <class T>
0169 inline const Node<T>* Node<T>::GetNodeR() const
0170 {
0171 return fNodeR;
0172 }
0173
0174 template <class T>
0175 inline const Node<T>* Node<T>::GetNodeP() const
0176 {
0177 return fNodeP;
0178 }
0179
0180 template <class T>
0181 inline Double_t Node<T>::GetWeight() const
0182 {
0183 return fEvent.GetWeight();
0184 }
0185
0186 template <class T>
0187 inline Float_t Node<T>::GetVarDis() const
0188 {
0189 return fVarDis;
0190 }
0191
0192 template <class T>
0193 inline Float_t Node<T>::GetVarMin() const
0194 {
0195 return fVarMin;
0196 }
0197
0198 template <class T>
0199 inline Float_t Node<T>::GetVarMax() const
0200 {
0201 return fVarMax;
0202 }
0203
0204 template <class T>
0205 inline UInt_t Node<T>::GetMod() const
0206 {
0207 return fMod;
0208 }
0209
0210
0211
0212
0213 template <class T>
0214 inline UInt_t Depth(const Node<T> *node)
0215 {
0216 if (!node) return 0;
0217 else return Depth(node->GetNodeP()) + 1;
0218 }
0219
0220 }
0221 }
0222
0223
0224 template<class T>
0225 TMVA::kNN::Node<T>::Node(const Node<T> *parent, const T &event, const Int_t mod)
0226 :fNodeP(parent),
0227 fNodeL(nullptr),
0228 fNodeR(nullptr),
0229 fEvent(event),
0230 fVarDis(event.GetVar(mod)),
0231 fVarMin(fVarDis),
0232 fVarMax(fVarDis),
0233 fMod(mod)
0234 {}
0235
0236
0237 template<class T>
0238 TMVA::kNN::Node<T>::~Node()
0239 {
0240 if (fNodeL) delete fNodeL;
0241 if (fNodeR) delete fNodeR;
0242 }
0243
0244
0245
0246
0247
0248
0249 template<class T>
0250 const TMVA::kNN::Node<T>* TMVA::kNN::Node<T>::Add(const T &event, const UInt_t depth)
0251 {
0252
0253 assert(fMod == depth % event.GetNVar() && "Wrong recursive depth in Node<>::Add");
0254
0255 const Float_t value = event.GetVar(fMod);
0256
0257 fVarMin = std::min(fVarMin, value);
0258 fVarMax = std::max(fVarMax, value);
0259
0260 Node<T> *node = nullptr;
0261 if (value < fVarDis) {
0262 if (fNodeL)
0263 {
0264 return fNodeL->Add(event, depth + 1);
0265 }
0266 else {
0267 fNodeL = new Node<T>(this, event, (depth + 1) % event.GetNVar());
0268 node = fNodeL;
0269 }
0270 }
0271 else {
0272 if (fNodeR) {
0273 return fNodeR->Add(event, depth + 1);
0274 }
0275 else {
0276 fNodeR = new Node<T>(this, event, (depth + 1) % event.GetNVar());
0277 node = fNodeR;
0278 }
0279 }
0280
0281 return node;
0282 }
0283
0284
0285 template<class T>
0286 void TMVA::kNN::Node<T>::Print() const
0287 {
0288 Print(std::cout);
0289 }
0290
0291
0292 template<class T>
0293 void TMVA::kNN::Node<T>::Print(std::ostream& os, const std::string &offset) const
0294 {
0295 os << offset << "-----------------------------------------------------------" << std::endl;
0296 os << offset << "Node: mod " << fMod
0297 << " at " << fVarDis
0298 << " with weight: " << GetWeight() << std::endl
0299 << offset << fEvent;
0300
0301 if (fNodeL) {
0302 os << offset << "Has left node " << std::endl;
0303 }
0304 if (fNodeR) {
0305 os << offset << "Has right node" << std::endl;
0306 }
0307
0308 if (fNodeL) {
0309 os << offset << "PrInt_t left node " << std::endl;
0310 fNodeL->Print(os, offset + " ");
0311 }
0312 if (fNodeR) {
0313 os << offset << "PrInt_t right node" << std::endl;
0314 fNodeR->Print(os, offset + " ");
0315 }
0316
0317 if (!fNodeL && !fNodeR) {
0318 os << std::endl;
0319 }
0320 }
0321
0322
0323
0324
0325
0326
0327
0328
0329
0330
0331
0332 template<class T>
0333 UInt_t TMVA::kNN::Find(std::list<std::pair<const TMVA::kNN::Node<T> *, Float_t> > &nlist,
0334 const TMVA::kNN::Node<T> *node, const T &event, const UInt_t nfind)
0335 {
0336 if (!node || nfind < 1) {
0337 return 0;
0338 }
0339
0340 const Float_t value = event.GetVar(node->GetMod());
0341
0342 if (node->GetWeight() > 0.0) {
0343
0344 Float_t max_dist = 0.0;
0345
0346 if (!nlist.empty()) {
0347
0348 max_dist = nlist.back().second;
0349
0350 if (nlist.size() == nfind) {
0351 if (value > node->GetVarMax() &&
0352 event.GetDist(node->GetVarMax(), node->GetMod()) > max_dist) {
0353 return 0;
0354 }
0355 if (value < node->GetVarMin() &&
0356 event.GetDist(node->GetVarMin(), node->GetMod()) > max_dist) {
0357 return 0;
0358 }
0359 }
0360 }
0361
0362 const Float_t distance = event.GetDist(node->GetEvent());
0363
0364 Bool_t insert_this = kFALSE;
0365 Bool_t remove_back = kFALSE;
0366
0367 if (nlist.size() < nfind) {
0368 insert_this = kTRUE;
0369 }
0370 else if (nlist.size() == nfind) {
0371 if (distance < max_dist) {
0372 insert_this = kTRUE;
0373 remove_back = kTRUE;
0374 }
0375 }
0376 else {
0377 std::cerr << "TMVA::kNN::Find() - logic error in recursive procedure" << std::endl;
0378 return 1;
0379 }
0380
0381 if (insert_this) {
0382
0383
0384 typename std::list<std::pair<const Node<T> *, Float_t> >::iterator lit = nlist.begin();
0385
0386
0387 for (; lit != nlist.end(); ++lit) {
0388 if (distance < lit->second) {
0389 break;
0390 }
0391 else {
0392 continue;
0393 }
0394 }
0395
0396 nlist.insert(lit, std::pair<const Node<T> *, Float_t>(node, distance));
0397
0398 if (remove_back) {
0399 nlist.pop_back();
0400 }
0401 }
0402 }
0403
0404 UInt_t count = 1;
0405 if (node->GetNodeL() && node->GetNodeR()) {
0406 if (value < node->GetVarDis()) {
0407 count += Find(nlist, node->GetNodeL(), event, nfind);
0408 count += Find(nlist, node->GetNodeR(), event, nfind);
0409 }
0410 else {
0411 count += Find(nlist, node->GetNodeR(), event, nfind);
0412 count += Find(nlist, node->GetNodeL(), event, nfind);
0413 }
0414 }
0415 else {
0416 if (node->GetNodeL()) {
0417 count += Find(nlist, node->GetNodeL(), event, nfind);
0418 }
0419 if (node->GetNodeR()) {
0420 count += Find(nlist, node->GetNodeR(), event, nfind);
0421 }
0422 }
0423
0424 return count;
0425 }
0426
0427
0428
0429
0430
0431
0432
0433
0434
0435
0436
0437
0438
0439
0440 template<class T>
0441 UInt_t TMVA::kNN::Find(std::list<std::pair<const TMVA::kNN::Node<T> *, Float_t> > &nlist,
0442 const TMVA::kNN::Node<T> *node, const T &event, const Double_t nfind, Double_t ncurr)
0443 {
0444
0445 if (!node || !(nfind < 0.0)) {
0446 return 0;
0447 }
0448
0449 const Float_t value = event.GetVar(node->GetMod());
0450
0451 if (node->GetWeight() > 0.0) {
0452
0453 Float_t max_dist = 0.0;
0454
0455 if (!nlist.empty()) {
0456
0457 max_dist = nlist.back().second;
0458
0459 if (!(ncurr < nfind)) {
0460 if (value > node->GetVarMax() &&
0461 event.GetDist(node->GetVarMax(), node->GetMod()) > max_dist) {
0462 return 0;
0463 }
0464 if (value < node->GetVarMin() &&
0465 event.GetDist(node->GetVarMin(), node->GetMod()) > max_dist) {
0466 return 0;
0467 }
0468 }
0469 }
0470
0471 const Float_t distance = event.GetDist(node->GetEvent());
0472
0473 Bool_t insert_this = kFALSE;
0474
0475 if (ncurr < nfind) {
0476 insert_this = kTRUE;
0477 }
0478 else if (!nlist.empty()) {
0479 if (distance < max_dist) {
0480 insert_this = kTRUE;
0481 }
0482 }
0483 else {
0484 std::cerr << "TMVA::kNN::Find() - logic error in recursive procedure" << std::endl;
0485 return 1;
0486 }
0487
0488 if (insert_this) {
0489
0490 ncurr = 0;
0491
0492
0493
0494 typename std::list<std::pair<const Node<T> *, Float_t> >::iterator lit = nlist.begin();
0495
0496
0497 for (; lit != nlist.end(); ++lit) {
0498 if (distance < lit->second) {
0499 break;
0500 }
0501
0502 ncurr += lit -> first -> GetWeight();
0503 }
0504
0505 lit = nlist.insert(lit, std::pair<const Node<T> *, Float_t>(node, distance));
0506
0507 for (; lit != nlist.end(); ++lit) {
0508 ncurr += lit -> first -> GetWeight();
0509 if (!(ncurr < nfind)) {
0510 ++lit;
0511 break;
0512 }
0513 }
0514
0515 if(lit != nlist.end())
0516 {
0517 nlist.erase(lit, nlist.end());
0518 }
0519 }
0520 }
0521
0522 UInt_t count = 1;
0523 if (node->GetNodeL() && node->GetNodeR()) {
0524 if (value < node->GetVarDis()) {
0525 count += Find(nlist, node->GetNodeL(), event, nfind, ncurr);
0526 count += Find(nlist, node->GetNodeR(), event, nfind, ncurr);
0527 }
0528 else {
0529 count += Find(nlist, node->GetNodeR(), event, nfind, ncurr);
0530 count += Find(nlist, node->GetNodeL(), event, nfind, ncurr);
0531 }
0532 }
0533 else {
0534 if (node->GetNodeL()) {
0535 count += Find(nlist, node->GetNodeL(), event, nfind, ncurr);
0536 }
0537 if (node->GetNodeR()) {
0538 count += Find(nlist, node->GetNodeR(), event, nfind, ncurr);
0539 }
0540 }
0541
0542 return count;
0543 }
0544
0545 #endif
0546