Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-01-18 10:11:01

0001 // @(#)root/tmva $Id$
0002 // Author: Rustem Ospanov
0003 
0004 /**********************************************************************************
0005  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis       *
0006  * Package: TMVA                                                                  *
0007  * Class  : ModulekNN                                                             *
0008  *                                             *
0009  *                                                                                *
0010  * Description:                                                                   *
0011  *      Module for k-nearest neighbor algorithm                                   *
0012  *                                                                                *
0013  * Author:                                                                        *
0014  *      Rustem Ospanov <rustem@fnal.gov> - U. of Texas at Austin, USA             *
0015  *                                                                                *
0016  * Copyright (c) 2007:                                                            *
0017  *      CERN, Switzerland                                                         *
0018  *      MPI-K Heidelberg, Germany                                                 *
0019  *      U. of Texas at Austin, USA                                                *
0020  *                                                                                *
0021  * Redistribution and use in source and binary forms, with or without             *
0022  * modification, are permitted according to the terms listed in LICENSE           *
0023  * (see tmva/doc/LICENSE)                                          *
0024  **********************************************************************************/
0025 
0026 #ifndef ROOT_TMVA_ModulekNN
0027 #define ROOT_TMVA_ModulekNN
0028 
0029 //______________________________________________________________________
0030 /*
  kNN::Event describes a point in the input-variable vector space, with
  additional functionality such as the distance between points.
0033 */
0034 //______________________________________________________________________
0035 
0036 
0037 // C++
0038 #include <iosfwd>
0039 #include <map>
0040 #include <string>
0041 #include <vector>
0042 #include <list>
0043 
0044 // ROOT
0045 #include "RtypesCore.h"
0046 #include "TRandom3.h"
0047 #include "ThreadLocalStorage.h"
0048 #include "TMVA/NodekNN.h"
0049 
0050 namespace TMVA {
0051 
0052    class MsgLogger;
0053 
0054    namespace kNN {
0055 
0056       typedef Float_t VarType;
0057       typedef std::vector<VarType> VarVec;
0058 
0059       class Event {
0060       public:
0061 
0062          Event();
0063          Event(const VarVec &vec, Double_t weight, Short_t type);
0064          Event(const VarVec &vec, Double_t weight, Short_t type, const VarVec &tvec);
0065          ~Event();
0066 
0067          Double_t GetWeight() const;
0068 
0069          VarType GetVar(UInt_t i) const;
0070          VarType GetTgt(UInt_t i) const;
0071 
0072          UInt_t GetNVar() const;
0073          UInt_t GetNTgt() const;
0074 
0075          Short_t GetType() const;
0076 
0077          // keep these two function separate
0078          VarType GetDist(VarType var, UInt_t ivar) const;
0079          VarType GetDist(const Event &other) const;
0080 
0081          void SetTargets(const VarVec &tvec);
0082          const VarVec& GetTargets() const;
0083          const VarVec& GetVars() const;
0084 
0085          void Print() const;
0086          void Print(std::ostream& os) const;
0087 
0088       private:
0089 
0090          VarVec fVar; ///< coordinates (variables) for knn search
0091          VarVec fTgt; ///< targets for regression analysis
0092 
0093          Double_t fWeight; // event weight
0094          Short_t fType; // event type ==0 or == 1, expand it to arbitrary class types?
0095       };
0096 
0097       typedef std::vector<TMVA::kNN::Event> EventVec;
0098       typedef std::pair<const Node<Event> *, VarType> Elem;
0099       typedef std::list<Elem> List;
0100 
0101       std::ostream& operator<<(std::ostream& os, const Event& event);
0102 
0103       class ModulekNN
0104       {
0105       public:
0106 
0107          typedef std::map<int, std::vector<Double_t> > VarMap;
0108 
0109       public:
0110 
0111          ModulekNN();
0112          ~ModulekNN();
0113 
0114          void Clear();
0115 
0116          void Add(const Event &event);
0117 
0118          Bool_t Fill(const UShort_t odepth, UInt_t ifrac, const std::string &option = "");
0119 
0120          Bool_t Find(Event event, UInt_t nfind = 100, const std::string &option = "count") const;
0121          Bool_t Find(UInt_t nfind, const std::string &option) const;
0122 
0123          const EventVec& GetEventVec() const;
0124 
0125          const List& GetkNNList() const;
0126          const Event& GetkNNEvent() const;
0127 
0128          const VarMap& GetVarMap() const;
0129 
0130          const std::map<Int_t, Double_t>& GetMetric() const;
0131 
0132          void Print() const;
0133          void Print(std::ostream &os) const;
0134 
0135       private:
0136 
0137          Node<Event>* Optimize(UInt_t optimize_depth);
0138 
0139          void ComputeMetric(UInt_t ifrac);
0140 
0141          const Event Scale(const Event &event) const;
0142 
0143       private:
0144 
0145          // This is a workaround for OSx where static thread_local data members are
0146          // not supported. The C++ solution would indeed be the following:
0147          static TRandom3& GetRndmThreadLocal() {TTHREAD_TLS_DECL_ARG(TRandom3,fgRndm,1); return fgRndm;};
0148 
0149          UInt_t fDimn;
0150 
0151          Node<Event> *fTree;
0152 
0153          std::map<Int_t, Double_t> fVarScale;
0154 
0155          mutable List  fkNNList;     // latest result from kNN search
0156          mutable Event fkNNEvent;    // latest event used for kNN search
0157 
0158          std::map<Short_t, UInt_t> fCount; // count number of events of each type
0159 
0160          EventVec fEvent; // vector of all events used to build tree and analysis
0161          VarMap   fVar;   // sorted map of variables in each dimension for all event types
0162 
0163          mutable MsgLogger* fLogger;   //! message logger
0164          MsgLogger& Log() const { return *fLogger; }
0165       };
0166 
0167       //
0168       // inlined functions for Event class
0169       //
0170       inline VarType Event::GetDist(const VarType var1, const UInt_t ivar) const
0171          {
0172             const VarType var2 = GetVar(ivar);
0173             return (var1 - var2) * (var1 - var2);
0174          }
0175       inline Double_t Event::GetWeight() const
0176          {
0177             return fWeight;
0178          }
0179       inline VarType Event::GetVar(const UInt_t i) const
0180          {
0181             return fVar[i];
0182          }
0183       inline VarType Event::GetTgt(const UInt_t i) const
0184          {
0185             return fTgt[i];
0186          }
0187 
0188       inline UInt_t Event::GetNVar() const
0189          {
0190             return fVar.size();
0191          }
0192       inline UInt_t Event::GetNTgt() const
0193          {
0194             return fTgt.size();
0195          }
0196       inline Short_t Event::GetType() const
0197          {
0198             return fType;
0199          }
0200 
0201       //
0202       // inline functions for ModulekNN class
0203       //
0204       inline const List& ModulekNN::GetkNNList() const
0205       {
0206          return fkNNList;
0207       }
0208       inline const Event& ModulekNN::GetkNNEvent() const
0209       {
0210          return fkNNEvent;
0211       }
0212       inline const EventVec& ModulekNN::GetEventVec() const
0213       {
0214          return fEvent;
0215       }
0216       inline const ModulekNN::VarMap& ModulekNN::GetVarMap() const
0217       {
0218          return fVar;
0219       }
0220       inline const std::map<Int_t, Double_t>& ModulekNN::GetMetric() const
0221          {
0222             return fVarScale;
0223          }
0224 
0225    } // end of kNN namespace
0226 } // end of TMVA namespace
0227 
0228 #endif
0229