Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-01-18 10:11:03

0001 // @(#)root/tmva $Id$
0002 // Author: Rustem Ospanov
0003 
0004 /**********************************************************************************
0005  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis       *
0006  * Package: TMVA                                                                  *
0007  * Class  : Node                                                                  *
0008  *                                             *
0009  *                                                                                *
0010  * Description:                                                                   *
0011  *      kd-tree (binary tree) template                                            *
0012  *                                                                                *
0013  * Author:                                                                        *
0014  *      Rustem Ospanov <rustem@fnal.gov> - U. of Texas at Austin, USA             *
0015  *                                                                                *
0016  * Copyright (c) 2007:                                                            *
0017  *      CERN, Switzerland                                                         *
0018  *      MPI-K Heidelberg, Germany                                                 *
0019  *      U. of Texas at Austin, USA                                                *
0020  *                                                                                *
0021  * Redistribution and use in source and binary forms, with or without             *
0022  * modification, are permitted according to the terms listed in LICENSE           *
0023  * (see tmva/doc/LICENSE)                                          *
0024  **********************************************************************************/
0025 
0026 #ifndef ROOT_TMVA_NodekNN
0027 #define ROOT_TMVA_NodekNN
0028 
// C++
#include <algorithm>
#include <cassert>
#include <iostream>
#include <list>
#include <string>
#include <utility>
0035 
0036 // ROOT
0037 #include "RtypesCore.h"
0038 
0039 /*! \class TMVA::kNN::Node
0040 \ingroup TMVA
This file contains a binary (kd-)tree class template and global function
templates that search the tree for k-nearest neighbors.
0043 
0044 Node class template parameter T has to provide these functions:
0045   rtype GetVar(UInt_t) const;
0046   - rtype is any type convertible to Float_t
0047   UInt_t GetNVar(void) const;
0048   rtype GetWeight(void) const;
0049   - rtype is any type convertible to Double_t
0050 
0051 Find function template parameter T has to provide these functions:
0052 (in addition to above requirements)
0053   rtype GetDist(Float_t, UInt_t) const;
0054   - rtype is any type convertible to Float_t
0055   rtype GetDist(const T &) const;
0056   - rtype is any type convertible to Float_t
0057 
0058   where T::GetDist(Float_t, UInt_t) <= T::GetDist(const T &)
0059   for any pair of events and any variable number for these events
0060 */
0061 
0062 namespace TMVA
0063 {
0064    namespace kNN
0065    {
0066       template <class T>
0067          class Node
0068          {
0069 
0070          public:
0071 
0072             Node(const Node *parent, const T &event, Int_t mod);
0073             ~Node();
0074 
0075             const Node* Add(const T &event, UInt_t depth);
0076 
0077             void SetNodeL(Node *node);
0078             void SetNodeR(Node *node);
0079 
0080             const T& GetEvent() const;
0081 
0082             const Node* GetNodeL() const;
0083             const Node* GetNodeR() const;
0084             const Node* GetNodeP() const;
0085 
0086             Double_t GetWeight() const;
0087 
0088             Float_t GetVarDis() const;
0089             Float_t GetVarMin() const;
0090             Float_t GetVarMax() const;
0091 
0092             UInt_t GetMod() const;
0093 
0094             void Print() const;
0095             void Print(std::ostream& os, const std::string &offset = "") const;
0096 
0097          private:
0098 
0099             // these methods are private and not implemented by design
0100             // use provided public constructor for all uses of this template class
0101             Node();
0102             Node(const Node &);
0103             const Node& operator=(const Node &);
0104 
0105          private:
0106 
0107             const Node* fNodeP;
0108 
0109             Node* fNodeL;
0110             Node* fNodeR;
0111 
0112             const T fEvent;
0113 
0114             const Float_t fVarDis;
0115 
0116             Float_t fVarMin;
0117             Float_t fVarMax;
0118 
0119             const UInt_t fMod;
0120          };
0121 
0122       // recursive search for k-nearest neighbor: k = nfind
0123       template<class T>
0124          UInt_t Find(std::list<std::pair<const Node<T> *, Float_t> > &nlist,
0125                      const Node<T> *node, const T &event, UInt_t nfind);
0126 
0127       // recursive search for k-nearest neighbor
0128       // find k events with sum of event weights >= nfind
0129       template<class T>
0130          UInt_t Find(std::list<std::pair<const Node<T> *, Float_t> > &nlist,
0131                      const Node<T> *node, const T &event, Double_t nfind, Double_t ncurr);
0132 
0133       // recursively travel upward until root node is reached
0134       template <class T>
0135          UInt_t Depth(const Node<T> *node);
0136 
0137       // prInt_t node content and content of its children
0138       //template <class T>
0139       //std::ostream& operator<<(std::ostream& os, const Node<T> &node);
0140 
0141       //
0142       // Inlined functions for Node template
0143       //
0144       template <class T>
0145          inline void Node<T>::SetNodeL(Node<T> *node)
0146          {
0147             fNodeL = node;
0148          }
0149 
0150       template <class T>
0151          inline void Node<T>::SetNodeR(Node<T> *node)
0152          {
0153             fNodeR = node;
0154          }
0155 
0156       template <class T>
0157          inline const T& Node<T>::GetEvent() const
0158          {
0159             return fEvent;
0160          }
0161 
0162       template <class T>
0163          inline const Node<T>* Node<T>::GetNodeL() const
0164          {
0165             return fNodeL;
0166          }
0167 
0168       template <class T>
0169          inline const Node<T>* Node<T>::GetNodeR() const
0170          {
0171             return fNodeR;
0172          }
0173 
0174       template <class T>
0175          inline const Node<T>* Node<T>::GetNodeP() const
0176          {
0177             return fNodeP;
0178          }
0179 
0180       template <class T>
0181          inline Double_t Node<T>::GetWeight() const
0182          {
0183             return fEvent.GetWeight();
0184          }
0185 
0186       template <class T>
0187          inline Float_t Node<T>::GetVarDis() const
0188          {
0189             return fVarDis;
0190          }
0191 
0192       template <class T>
0193          inline Float_t Node<T>::GetVarMin() const
0194          {
0195             return fVarMin;
0196          }
0197 
0198       template <class T>
0199          inline Float_t Node<T>::GetVarMax() const
0200          {
0201             return fVarMax;
0202          }
0203 
0204       template <class T>
0205          inline UInt_t Node<T>::GetMod() const
0206          {
0207             return fMod;
0208          }
0209 
0210       //
0211       // Inlined global function(s)
0212       //
0213       template <class T>
0214          inline UInt_t Depth(const Node<T> *node)
0215          {
0216             if (!node) return 0;
0217             else return Depth(node->GetNodeP()) + 1;
0218          }
0219 
0220    } // end of kNN namespace
0221 } // end of TMVA namespace
0222 
0223 ////////////////////////////////////////////////////////////////////////////////
0224 template<class T>
0225 TMVA::kNN::Node<T>::Node(const Node<T> *parent, const T &event, const Int_t mod)
0226 :fNodeP(parent),
0227    fNodeL(nullptr),
0228    fNodeR(nullptr),
0229    fEvent(event),
0230    fVarDis(event.GetVar(mod)),
0231    fVarMin(fVarDis),
0232    fVarMax(fVarDis),
0233    fMod(mod)
0234 {}
0235 
0236 ////////////////////////////////////////////////////////////////////////////////
0237 template<class T>
0238 TMVA::kNN::Node<T>::~Node()
0239 {
0240    if (fNodeL) delete fNodeL;
0241    if (fNodeR) delete fNodeR;
0242 }
0243 
0244 ////////////////////////////////////////////////////////////////////////////////
0245 /// This is Node member function that adds a new node to a binary tree.
0246 /// each node contains maximum and minimum values of splitting variable
0247 /// left or right nodes are added based on value of splitting variable
0248 
0249 template<class T>
0250 const TMVA::kNN::Node<T>* TMVA::kNN::Node<T>::Add(const T &event, const UInt_t depth)
0251 {
0252 
0253    assert(fMod == depth % event.GetNVar() && "Wrong recursive depth in Node<>::Add");
0254 
0255    const Float_t value = event.GetVar(fMod);
0256 
0257    fVarMin = std::min(fVarMin, value);
0258    fVarMax = std::max(fVarMax, value);
0259 
0260    Node<T> *node = nullptr;
0261    if (value < fVarDis) {
0262       if (fNodeL)
0263          {
0264             return fNodeL->Add(event, depth + 1);
0265          }
0266       else {
0267          fNodeL = new Node<T>(this, event, (depth + 1) % event.GetNVar());
0268          node = fNodeL;
0269       }
0270    }
0271    else {
0272       if (fNodeR) {
0273          return fNodeR->Add(event, depth + 1);
0274       }
0275       else {
0276          fNodeR = new Node<T>(this, event, (depth + 1) % event.GetNVar());
0277          node = fNodeR;
0278       }
0279    }
0280 
0281    return node;
0282 }
0283 
0284 ////////////////////////////////////////////////////////////////////////////////
0285 template<class T>
0286 void TMVA::kNN::Node<T>::Print() const
0287 {
0288    Print(std::cout);
0289 }
0290 
0291 ////////////////////////////////////////////////////////////////////////////////
0292 template<class T>
0293 void TMVA::kNN::Node<T>::Print(std::ostream& os, const std::string &offset) const
0294 {
0295    os << offset << "-----------------------------------------------------------" << std::endl;
0296    os << offset << "Node: mod " << fMod
0297       << " at " << fVarDis
0298       << " with weight: " << GetWeight() << std::endl
0299       << offset << fEvent;
0300 
0301    if (fNodeL) {
0302       os << offset << "Has left node " << std::endl;
0303    }
0304    if (fNodeR) {
0305       os << offset << "Has right node" << std::endl;
0306    }
0307 
0308    if (fNodeL) {
0309       os << offset << "PrInt_t left node " << std::endl;
0310       fNodeL->Print(os, offset + " ");
0311    }
0312    if (fNodeR) {
0313       os << offset << "PrInt_t right node" << std::endl;
0314       fNodeR->Print(os, offset + " ");
0315    }
0316 
0317    if (!fNodeL && !fNodeR) {
0318       os << std::endl;
0319    }
0320 }
0321 
0322 ////////////////////////////////////////////////////////////////////////////////
0323 /// This is a global templated function that searches for k-nearest neighbors.
0324 /// list contains k or less nodes that are closest to event.
0325 /// only nodes with positive weights are added to list.
0326 /// each node contains maximum and minimum values of splitting variable
0327 /// for all its children - this range is checked to avoid descending into
0328 /// nodes that are definitely outside current minimum neighbourhood.
0329 ///
0330 /// This function should be modified with care.
0331 
0332 template<class T>
0333 UInt_t TMVA::kNN::Find(std::list<std::pair<const TMVA::kNN::Node<T> *, Float_t> > &nlist,
0334                        const TMVA::kNN::Node<T> *node, const T &event, const UInt_t nfind)
0335 {
0336    if (!node || nfind < 1) {
0337       return 0;
0338    }
0339 
0340    const Float_t value = event.GetVar(node->GetMod());
0341 
0342    if (node->GetWeight() > 0.0) {
0343 
0344       Float_t max_dist = 0.0;
0345 
0346       if (!nlist.empty()) {
0347 
0348          max_dist = nlist.back().second;
0349 
0350          if (nlist.size() == nfind) {
0351             if (value > node->GetVarMax() &&
0352                 event.GetDist(node->GetVarMax(), node->GetMod()) > max_dist) {
0353                return 0;
0354             }
0355             if (value < node->GetVarMin() &&
0356                 event.GetDist(node->GetVarMin(), node->GetMod()) > max_dist) {
0357                return 0;
0358             }
0359          }
0360       }
0361 
0362       const Float_t distance = event.GetDist(node->GetEvent());
0363 
0364       Bool_t insert_this = kFALSE;
0365       Bool_t remove_back = kFALSE;
0366 
0367       if (nlist.size() < nfind) {
0368          insert_this = kTRUE;
0369       }
0370       else if (nlist.size() == nfind) {
0371          if (distance < max_dist) {
0372             insert_this = kTRUE;
0373             remove_back = kTRUE;
0374          }
0375       }
0376       else {
0377          std::cerr << "TMVA::kNN::Find() - logic error in recursive procedure" << std::endl;
0378          return 1;
0379       }
0380 
0381       if (insert_this) {
0382          // need typename keyword because qualified dependent names
0383          // are not valid types unless preceded by 'typename'.
0384          typename std::list<std::pair<const Node<T> *, Float_t> >::iterator lit = nlist.begin();
0385 
0386          // find a place where current node should be inserted
0387          for (; lit != nlist.end(); ++lit) {
0388             if (distance < lit->second) {
0389                break;
0390             }
0391             else {
0392                continue;
0393             }
0394          }
0395 
0396          nlist.insert(lit, std::pair<const Node<T> *, Float_t>(node, distance));
0397 
0398          if (remove_back) {
0399             nlist.pop_back();
0400          }
0401       }
0402    }
0403 
0404    UInt_t count = 1;
0405    if (node->GetNodeL() && node->GetNodeR()) {
0406       if (value < node->GetVarDis()) {
0407          count += Find(nlist, node->GetNodeL(), event, nfind);
0408          count += Find(nlist, node->GetNodeR(), event, nfind);
0409       }
0410       else {
0411          count += Find(nlist, node->GetNodeR(), event, nfind);
0412          count += Find(nlist, node->GetNodeL(), event, nfind);
0413       }
0414    }
0415    else {
0416       if (node->GetNodeL()) {
0417          count += Find(nlist, node->GetNodeL(), event, nfind);
0418       }
0419       if (node->GetNodeR()) {
0420          count += Find(nlist, node->GetNodeR(), event, nfind);
0421       }
0422    }
0423 
0424    return count;
0425 }
0426 
0427 ////////////////////////////////////////////////////////////////////////////////
0428 /// This is a global templated function that searches for k-nearest neighbors.
0429 /// list contains all nodes that are closest to event
0430 /// and have sum of event weights >= nfind.
0431 /// Only nodes with positive weights are added to list.
0432 /// Requirement for used classes:
0433 ///  - each node contains maximum and minimum values of splitting variable
0434 ///    for all its children
0435 ///  - min and max range is checked to avoid descending into
0436 ///    nodes that are definitely outside current minimum neighbourhood.
0437 ///
0438 /// This function should be modified with care.
0439 
0440 template<class T>
0441 UInt_t TMVA::kNN::Find(std::list<std::pair<const TMVA::kNN::Node<T> *, Float_t> > &nlist,
0442                        const TMVA::kNN::Node<T> *node, const T &event, const Double_t nfind, Double_t ncurr)
0443 {
0444 
0445    if (!node || !(nfind < 0.0)) {
0446       return 0;
0447    }
0448 
0449    const Float_t value = event.GetVar(node->GetMod());
0450 
0451    if (node->GetWeight() > 0.0) {
0452 
0453       Float_t max_dist = 0.0;
0454 
0455       if (!nlist.empty()) {
0456 
0457          max_dist = nlist.back().second;
0458 
0459          if (!(ncurr < nfind)) {
0460             if (value > node->GetVarMax() &&
0461                 event.GetDist(node->GetVarMax(), node->GetMod()) > max_dist) {
0462                return 0;
0463             }
0464             if (value < node->GetVarMin() &&
0465                 event.GetDist(node->GetVarMin(), node->GetMod()) > max_dist) {
0466                return 0;
0467             }
0468          }
0469       }
0470 
0471       const Float_t distance = event.GetDist(node->GetEvent());
0472 
0473       Bool_t insert_this = kFALSE;
0474 
0475       if (ncurr < nfind) {
0476          insert_this = kTRUE;
0477       }
0478       else if (!nlist.empty()) {
0479          if (distance < max_dist) {
0480             insert_this = kTRUE;
0481          }
0482       }
0483       else {
0484          std::cerr << "TMVA::kNN::Find() - logic error in recursive procedure" << std::endl;
0485          return 1;
0486       }
0487 
0488       if (insert_this) {
0489          // (re)compute total current weight when inserting a new node
0490          ncurr = 0;
0491 
0492          // need typename keyword because qualified dependent names
0493          // are not valid types unless preceded by 'typename'.
0494          typename std::list<std::pair<const Node<T> *, Float_t> >::iterator lit = nlist.begin();
0495 
0496          // find a place where current node should be inserted
0497          for (; lit != nlist.end(); ++lit) {
0498             if (distance < lit->second) {
0499                break;
0500             }
0501 
0502             ncurr += lit -> first -> GetWeight();
0503          }
0504 
0505          lit = nlist.insert(lit, std::pair<const Node<T> *, Float_t>(node, distance));
0506 
0507          for (; lit != nlist.end(); ++lit) {
0508             ncurr += lit -> first -> GetWeight();
0509             if (!(ncurr < nfind)) {
0510                ++lit;
0511                break;
0512             }
0513          }
0514 
0515          if(lit != nlist.end())
0516             {
0517                nlist.erase(lit, nlist.end());
0518             }
0519       }
0520    }
0521 
0522    UInt_t count = 1;
0523    if (node->GetNodeL() && node->GetNodeR()) {
0524       if (value < node->GetVarDis()) {
0525          count += Find(nlist, node->GetNodeL(), event, nfind, ncurr);
0526          count += Find(nlist, node->GetNodeR(), event, nfind, ncurr);
0527       }
0528       else {
0529          count += Find(nlist, node->GetNodeR(), event, nfind, ncurr);
0530          count += Find(nlist, node->GetNodeL(), event, nfind, ncurr);
0531       }
0532    }
0533    else {
0534       if (node->GetNodeL()) {
0535          count += Find(nlist, node->GetNodeL(), event, nfind, ncurr);
0536       }
0537       if (node->GetNodeR()) {
0538          count += Find(nlist, node->GetNodeR(), event, nfind, ncurr);
0539       }
0540    }
0541 
0542    return count;
0543 }
0544 
0545 #endif
0546