File indexing completed on 2025-01-18 10:11:03
0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018
0019
0020
0021
0022
0023
0024
0025
0026 #ifndef ROOT_TMVA_NodekNN
0027 #define ROOT_TMVA_NodekNN
0028
0029
0030 #include <cassert>
0031 #include <list>
0032 #include <string>
0033 #include <iostream>
0034 #include <utility>
0035
0036
0037 #include "RtypesCore.h"
0038
0039
0040
0041
0042
0043
0044
0045
0046
0047
0048
0049
0050
0051
0052
0053
0054
0055
0056
0057
0058
0059
0060
0061
0062 namespace TMVA
0063 {
0064 namespace kNN
0065 {
0066 template <class T>
0067 class Node
0068 {
0069
0070 public:
0071
0072 Node(const Node *parent, const T &event, Int_t mod);
0073 ~Node();
0074
0075 const Node* Add(const T &event, UInt_t depth);
0076
0077 void SetNodeL(Node *node);
0078 void SetNodeR(Node *node);
0079
0080 const T& GetEvent() const;
0081
0082 const Node* GetNodeL() const;
0083 const Node* GetNodeR() const;
0084 const Node* GetNodeP() const;
0085
0086 Double_t GetWeight() const;
0087
0088 Float_t GetVarDis() const;
0089 Float_t GetVarMin() const;
0090 Float_t GetVarMax() const;
0091
0092 UInt_t GetMod() const;
0093
0094 void Print() const;
0095 void Print(std::ostream& os, const std::string &offset = "") const;
0096
0097 private:
0098
0099
0100
0101 Node();
0102 Node(const Node &);
0103 const Node& operator=(const Node &);
0104
0105 private:
0106
0107 const Node* fNodeP;
0108
0109 Node* fNodeL;
0110 Node* fNodeR;
0111
0112 const T fEvent;
0113
0114 const Float_t fVarDis;
0115
0116 Float_t fVarMin;
0117 Float_t fVarMax;
0118
0119 const UInt_t fMod;
0120 };
0121
0122
0123 template<class T>
0124 UInt_t Find(std::list<std::pair<const Node<T> *, Float_t> > &nlist,
0125 const Node<T> *node, const T &event, UInt_t nfind);
0126
0127
0128
0129 template<class T>
0130 UInt_t Find(std::list<std::pair<const Node<T> *, Float_t> > &nlist,
0131 const Node<T> *node, const T &event, Double_t nfind, Double_t ncurr);
0132
0133
0134 template <class T>
0135 UInt_t Depth(const Node<T> *node);
0136
0137
0138
0139
0140
0141
0142
0143
0144 template <class T>
0145 inline void Node<T>::SetNodeL(Node<T> *node)
0146 {
0147 fNodeL = node;
0148 }
0149
0150 template <class T>
0151 inline void Node<T>::SetNodeR(Node<T> *node)
0152 {
0153 fNodeR = node;
0154 }
0155
0156 template <class T>
0157 inline const T& Node<T>::GetEvent() const
0158 {
0159 return fEvent;
0160 }
0161
0162 template <class T>
0163 inline const Node<T>* Node<T>::GetNodeL() const
0164 {
0165 return fNodeL;
0166 }
0167
0168 template <class T>
0169 inline const Node<T>* Node<T>::GetNodeR() const
0170 {
0171 return fNodeR;
0172 }
0173
0174 template <class T>
0175 inline const Node<T>* Node<T>::GetNodeP() const
0176 {
0177 return fNodeP;
0178 }
0179
0180 template <class T>
0181 inline Double_t Node<T>::GetWeight() const
0182 {
0183 return fEvent.GetWeight();
0184 }
0185
0186 template <class T>
0187 inline Float_t Node<T>::GetVarDis() const
0188 {
0189 return fVarDis;
0190 }
0191
0192 template <class T>
0193 inline Float_t Node<T>::GetVarMin() const
0194 {
0195 return fVarMin;
0196 }
0197
0198 template <class T>
0199 inline Float_t Node<T>::GetVarMax() const
0200 {
0201 return fVarMax;
0202 }
0203
0204 template <class T>
0205 inline UInt_t Node<T>::GetMod() const
0206 {
0207 return fMod;
0208 }
0209
0210
0211
0212
0213 template <class T>
0214 inline UInt_t Depth(const Node<T> *node)
0215 {
0216 if (!node) return 0;
0217 else return Depth(node->GetNodeP()) + 1;
0218 }
0219
0220 }
0221 }
0222
0223
0224 template<class T>
0225 TMVA::kNN::Node<T>::Node(const Node<T> *parent, const T &event, const Int_t mod)
0226 :fNodeP(parent),
0227 fNodeL(nullptr),
0228 fNodeR(nullptr),
0229 fEvent(event),
0230 fVarDis(event.GetVar(mod)),
0231 fVarMin(fVarDis),
0232 fVarMax(fVarDis),
0233 fMod(mod)
0234 {}
0235
0236
0237 template<class T>
0238 TMVA::kNN::Node<T>::~Node()
0239 {
0240 if (fNodeL) delete fNodeL;
0241 if (fNodeR) delete fNodeR;
0242 }
0243
0244
0245
0246
0247
0248
0249 template<class T>
0250 const TMVA::kNN::Node<T>* TMVA::kNN::Node<T>::Add(const T &event, const UInt_t depth)
0251 {
0252
0253 assert(fMod == depth % event.GetNVar() && "Wrong recursive depth in Node<>::Add");
0254
0255 const Float_t value = event.GetVar(fMod);
0256
0257 fVarMin = std::min(fVarMin, value);
0258 fVarMax = std::max(fVarMax, value);
0259
0260 Node<T> *node = nullptr;
0261 if (value < fVarDis) {
0262 if (fNodeL)
0263 {
0264 return fNodeL->Add(event, depth + 1);
0265 }
0266 else {
0267 fNodeL = new Node<T>(this, event, (depth + 1) % event.GetNVar());
0268 node = fNodeL;
0269 }
0270 }
0271 else {
0272 if (fNodeR) {
0273 return fNodeR->Add(event, depth + 1);
0274 }
0275 else {
0276 fNodeR = new Node<T>(this, event, (depth + 1) % event.GetNVar());
0277 node = fNodeR;
0278 }
0279 }
0280
0281 return node;
0282 }
0283
0284
0285 template<class T>
0286 void TMVA::kNN::Node<T>::Print() const
0287 {
0288 Print(std::cout);
0289 }
0290
0291
0292 template<class T>
0293 void TMVA::kNN::Node<T>::Print(std::ostream& os, const std::string &offset) const
0294 {
0295 os << offset << "-----------------------------------------------------------" << std::endl;
0296 os << offset << "Node: mod " << fMod
0297 << " at " << fVarDis
0298 << " with weight: " << GetWeight() << std::endl
0299 << offset << fEvent;
0300
0301 if (fNodeL) {
0302 os << offset << "Has left node " << std::endl;
0303 }
0304 if (fNodeR) {
0305 os << offset << "Has right node" << std::endl;
0306 }
0307
0308 if (fNodeL) {
0309 os << offset << "PrInt_t left node " << std::endl;
0310 fNodeL->Print(os, offset + " ");
0311 }
0312 if (fNodeR) {
0313 os << offset << "PrInt_t right node" << std::endl;
0314 fNodeR->Print(os, offset + " ");
0315 }
0316
0317 if (!fNodeL && !fNodeR) {
0318 os << std::endl;
0319 }
0320 }
0321
0322
0323
0324
0325
0326
0327
0328
0329
0330
0331
0332 template<class T>
0333 UInt_t TMVA::kNN::Find(std::list<std::pair<const TMVA::kNN::Node<T> *, Float_t> > &nlist,
0334 const TMVA::kNN::Node<T> *node, const T &event, const UInt_t nfind)
0335 {
0336 if (!node || nfind < 1) {
0337 return 0;
0338 }
0339
0340 const Float_t value = event.GetVar(node->GetMod());
0341
0342 if (node->GetWeight() > 0.0) {
0343
0344 Float_t max_dist = 0.0;
0345
0346 if (!nlist.empty()) {
0347
0348 max_dist = nlist.back().second;
0349
0350 if (nlist.size() == nfind) {
0351 if (value > node->GetVarMax() &&
0352 event.GetDist(node->GetVarMax(), node->GetMod()) > max_dist) {
0353 return 0;
0354 }
0355 if (value < node->GetVarMin() &&
0356 event.GetDist(node->GetVarMin(), node->GetMod()) > max_dist) {
0357 return 0;
0358 }
0359 }
0360 }
0361
0362 const Float_t distance = event.GetDist(node->GetEvent());
0363
0364 Bool_t insert_this = kFALSE;
0365 Bool_t remove_back = kFALSE;
0366
0367 if (nlist.size() < nfind) {
0368 insert_this = kTRUE;
0369 }
0370 else if (nlist.size() == nfind) {
0371 if (distance < max_dist) {
0372 insert_this = kTRUE;
0373 remove_back = kTRUE;
0374 }
0375 }
0376 else {
0377 std::cerr << "TMVA::kNN::Find() - logic error in recursive procedure" << std::endl;
0378 return 1;
0379 }
0380
0381 if (insert_this) {
0382
0383
0384 typename std::list<std::pair<const Node<T> *, Float_t> >::iterator lit = nlist.begin();
0385
0386
0387 for (; lit != nlist.end(); ++lit) {
0388 if (distance < lit->second) {
0389 break;
0390 }
0391 else {
0392 continue;
0393 }
0394 }
0395
0396 nlist.insert(lit, std::pair<const Node<T> *, Float_t>(node, distance));
0397
0398 if (remove_back) {
0399 nlist.pop_back();
0400 }
0401 }
0402 }
0403
0404 UInt_t count = 1;
0405 if (node->GetNodeL() && node->GetNodeR()) {
0406 if (value < node->GetVarDis()) {
0407 count += Find(nlist, node->GetNodeL(), event, nfind);
0408 count += Find(nlist, node->GetNodeR(), event, nfind);
0409 }
0410 else {
0411 count += Find(nlist, node->GetNodeR(), event, nfind);
0412 count += Find(nlist, node->GetNodeL(), event, nfind);
0413 }
0414 }
0415 else {
0416 if (node->GetNodeL()) {
0417 count += Find(nlist, node->GetNodeL(), event, nfind);
0418 }
0419 if (node->GetNodeR()) {
0420 count += Find(nlist, node->GetNodeR(), event, nfind);
0421 }
0422 }
0423
0424 return count;
0425 }
0426
0427
0428
0429
0430
0431
0432
0433
0434
0435
0436
0437
0438
0439
0440 template<class T>
0441 UInt_t TMVA::kNN::Find(std::list<std::pair<const TMVA::kNN::Node<T> *, Float_t> > &nlist,
0442 const TMVA::kNN::Node<T> *node, const T &event, const Double_t nfind, Double_t ncurr)
0443 {
0444
0445 if (!node || !(nfind < 0.0)) {
0446 return 0;
0447 }
0448
0449 const Float_t value = event.GetVar(node->GetMod());
0450
0451 if (node->GetWeight() > 0.0) {
0452
0453 Float_t max_dist = 0.0;
0454
0455 if (!nlist.empty()) {
0456
0457 max_dist = nlist.back().second;
0458
0459 if (!(ncurr < nfind)) {
0460 if (value > node->GetVarMax() &&
0461 event.GetDist(node->GetVarMax(), node->GetMod()) > max_dist) {
0462 return 0;
0463 }
0464 if (value < node->GetVarMin() &&
0465 event.GetDist(node->GetVarMin(), node->GetMod()) > max_dist) {
0466 return 0;
0467 }
0468 }
0469 }
0470
0471 const Float_t distance = event.GetDist(node->GetEvent());
0472
0473 Bool_t insert_this = kFALSE;
0474
0475 if (ncurr < nfind) {
0476 insert_this = kTRUE;
0477 }
0478 else if (!nlist.empty()) {
0479 if (distance < max_dist) {
0480 insert_this = kTRUE;
0481 }
0482 }
0483 else {
0484 std::cerr << "TMVA::kNN::Find() - logic error in recursive procedure" << std::endl;
0485 return 1;
0486 }
0487
0488 if (insert_this) {
0489
0490 ncurr = 0;
0491
0492
0493
0494 typename std::list<std::pair<const Node<T> *, Float_t> >::iterator lit = nlist.begin();
0495
0496
0497 for (; lit != nlist.end(); ++lit) {
0498 if (distance < lit->second) {
0499 break;
0500 }
0501
0502 ncurr += lit -> first -> GetWeight();
0503 }
0504
0505 lit = nlist.insert(lit, std::pair<const Node<T> *, Float_t>(node, distance));
0506
0507 for (; lit != nlist.end(); ++lit) {
0508 ncurr += lit -> first -> GetWeight();
0509 if (!(ncurr < nfind)) {
0510 ++lit;
0511 break;
0512 }
0513 }
0514
0515 if(lit != nlist.end())
0516 {
0517 nlist.erase(lit, nlist.end());
0518 }
0519 }
0520 }
0521
0522 UInt_t count = 1;
0523 if (node->GetNodeL() && node->GetNodeR()) {
0524 if (value < node->GetVarDis()) {
0525 count += Find(nlist, node->GetNodeL(), event, nfind, ncurr);
0526 count += Find(nlist, node->GetNodeR(), event, nfind, ncurr);
0527 }
0528 else {
0529 count += Find(nlist, node->GetNodeR(), event, nfind, ncurr);
0530 count += Find(nlist, node->GetNodeL(), event, nfind, ncurr);
0531 }
0532 }
0533 else {
0534 if (node->GetNodeL()) {
0535 count += Find(nlist, node->GetNodeL(), event, nfind, ncurr);
0536 }
0537 if (node->GetNodeR()) {
0538 count += Find(nlist, node->GetNodeR(), event, nfind, ncurr);
0539 }
0540 }
0541
0542 return count;
0543 }
0544
0545 #endif
0546