Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-01-30 10:22:52

0001 // @(#)root/tmva $Id$
0002 // Author: Andreas Hoecker, Peter Speckmayer, Joerg Stelzer, Helge Voss, Jan Therhaag
0003 
0004 /**********************************************************************************
0005  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis       *
0006  * Package: TMVA                                                                  *
 * Class  : LossFunction                                                          *
0008  *                                             *
0009  *                                                                                *
0010  * Description:                                                                   *
0011  *      LossFunction and associated classes                                       *
0012  *                                                                                *
0013  * Authors (alphabetical):                                                        *
0014  *      Andreas Hoecker <Andreas.Hocker@cern.ch> - CERN, Switzerland              *
0015  *      Joerg Stelzer   <Joerg.Stelzer@cern.ch>  - CERN, Switzerland              *
0016  *      Peter Speckmayer <Peter.Speckmayer@cern.ch>  - CERN, Switzerland          *
0017  *      Jan Therhaag       <Jan.Therhaag@cern.ch>     - U of Bonn, Germany        *
0018  *      Helge Voss      <Helge.Voss@cern.ch>     - MPI-K Heidelberg, Germany      *
0019  *                                                                                *
0020  * Copyright (c) 2005-2011:                                                       *
0021  *      CERN, Switzerland                                                         *
0022  *      U. of Victoria, Canada                                                    *
0023  *      MPI-K Heidelberg, Germany                                                 *
0024  *      U. of Bonn, Germany                                                       *
0025  *                                                                                *
0026  * Redistribution and use in source and binary forms, with or without             *
0027  * modification, are permitted according to the terms listed in LICENSE           *
 * (http://tmva.sourceforge.net/LICENSE)                                          *
0029  **********************************************************************************/
0030 
0031 #ifndef ROOT_TMVA_LossFunction
0032 #define ROOT_TMVA_LossFunction
0033 
0034 //#include <iosfwd>
0035 #include <vector>
0036 #include <map>
0037 #include "TMVA/Event.h"
0038 
0039 #include "TMVA/Types.h"
0040 
0041 
0042 namespace TMVA {
0043 
0044    ///////////////////////////////////////////////////////////////////////////////////////////////
0045    // Data Structure  used by LossFunction and LossFunctionBDT to calculate errors, targets, etc
0046    ///////////////////////////////////////////////////////////////////////////////////////////////
0047 
0048    class LossFunctionEventInfo{
0049 
0050    public:
0051       LossFunctionEventInfo(){
0052           trueValue = 0.;
0053           predictedValue = 0.;
0054           weight = 0.;
0055       };
0056       LossFunctionEventInfo(Double_t trueValue_, Double_t predictedValue_, Double_t weight_){
0057          trueValue = trueValue_;
0058          predictedValue = predictedValue_;
0059          weight = weight_;
0060       }
0061       ~LossFunctionEventInfo(){};
0062 
0063       Double_t trueValue;
0064       Double_t predictedValue;
0065       Double_t weight;
0066    };
0067 
0068 
0069    ///////////////////////////////////////////////////////////////////////////////////////////////
0070    // Loss Function interface defining base class for general error calculations in
0071    // regression/classification
0072    ///////////////////////////////////////////////////////////////////////////////////////////////
0073 
   // Abstract interface for a loss function used in regression error
   // calculations. Concrete implementations (Huber, LeastSquares,
   // AbsoluteDeviation below) supply the actual loss definitions.
   class LossFunction {

   public:

      // constructors
      LossFunction(){};
      virtual ~LossFunction(){};

      // abstract methods that need to be implemented
      // loss for a single event
      virtual Double_t CalculateLoss(LossFunctionEventInfo& e) = 0;
      // loss aggregated over a collection of events (net vs. mean —
      // exact aggregation is defined by the concrete implementation)
      virtual Double_t CalculateNetLoss(std::vector<LossFunctionEventInfo>& evs) = 0;
      virtual Double_t CalculateMeanLoss(std::vector<LossFunctionEventInfo>& evs) = 0;

      // human-readable name of the loss function (e.g. "Huber")
      virtual TString Name() = 0;
      // numeric identifier of the loss function
      virtual Int_t Id() = 0;
   };
0090 
0091    ///////////////////////////////////////////////////////////////////////////////////////////////
0092    // Loss Function interface for boosted decision trees. Inherits from LossFunction
0093    ///////////////////////////////////////////////////////////////////////////////////////////////
0094 
0095    /* Must inherit LossFunction with the virtual keyword so that we only have to implement
0096    * the LossFunction interface once.
0097    *
0098    *       LossFunction
0099    *      /            \
0100    *SomeLossFunction  LossFunctionBDT
0101    *      \            /
0102    *       \          /
0103    *    SomeLossFunctionBDT
0104    *
0105    * Without the virtual keyword the two would point to their own LossFunction objects
0106    * and SomeLossFunctionBDT would have to implement the virtual functions of LossFunction twice, once
0107    * for each object. See diagram below.
0108    *
0109    * LossFunction  LossFunction
0110    *     |             |
0111    *SomeLossFunction  LossFunctionBDT
0112    *      \            /
0113    *       \          /
0114    *     SomeLossFunctionBDT
0115    *
   * Multiple inheritance is often frowned upon. To avoid this, we could make LossFunctionBDT separate
0117    * from LossFunction but it really is a type of loss function.
0118    * We could also put LossFunction into LossFunctionBDT. In either of these scenarios, if you are doing
0119    * different regression methods and want to compare the Loss this makes it more convoluted.
0120    * I think that multiple inheritance seems justified in this case, but we could change it if it's a problem.
0121    * Usually it isn't a big deal with interfaces and this results in the simplest code in this case.
0122    */
0123 
   // Interface adding the operations a boosted decision tree needs from a
   // loss function: initialization, per-event target calculation and leaf
   // fitting. Inherits LossFunction virtually (see the diagram in the comment
   // above) so a concrete SomeLossFunctionBDT shares a single LossFunction
   // subobject with its SomeLossFunction base.
   class LossFunctionBDT : public virtual LossFunction{

   public:

      // constructors
      LossFunctionBDT(){};
      virtual ~LossFunctionBDT(){};

      // abstract methods that need to be implemented
      // set up internal state from the event info map and boost weights
      virtual void Init(std::map<const TMVA::Event*, LossFunctionEventInfo>& evinfomap, std::vector<double>& boostWeights) = 0;
      // compute and store the target for each event in evinfomap
      virtual void SetTargets(std::vector<const TMVA::Event*>& evs, std::map< const TMVA::Event*, LossFunctionEventInfo >& evinfomap) = 0;
      // target value for a single event
      virtual Double_t Target(LossFunctionEventInfo& e) = 0;
      // value fitted to a set of events — presumably the response of a
      // terminal node; see the concrete implementations to confirm
      virtual Double_t Fit(std::vector<LossFunctionEventInfo>& evs) = 0;

   };
0139 
0140    ///////////////////////////////////////////////////////////////////////////////////////////////
0141    // Huber loss function for regression error calculations
0142    ///////////////////////////////////////////////////////////////////////////////////////////////
0143 
0144    class HuberLossFunction : public virtual LossFunction{
0145 
0146    public:
0147       HuberLossFunction();
0148       HuberLossFunction(Double_t quantile);
0149       ~HuberLossFunction();
0150 
0151       // The LossFunction methods
0152       Double_t CalculateLoss(LossFunctionEventInfo& e);
0153       Double_t CalculateNetLoss(std::vector<LossFunctionEventInfo>& evs);
0154       Double_t CalculateMeanLoss(std::vector<LossFunctionEventInfo>& evs);
0155 
0156       // We go ahead and implement the simple ones
0157       TString Name(){ return TString("Huber"); };
0158       Int_t Id(){ return 0; } ;
0159 
0160       // Functions needed beyond the interface
0161       void Init(std::vector<LossFunctionEventInfo>& evs);
0162       Double_t CalculateQuantile(std::vector<LossFunctionEventInfo>& evs, Double_t whichQuantile, Double_t sumOfWeights, bool abs);
0163       Double_t CalculateSumOfWeights(const std::vector<LossFunctionEventInfo>& evs);
0164       void SetTransitionPoint(std::vector<LossFunctionEventInfo>& evs);
0165       void SetSumOfWeights(std::vector<LossFunctionEventInfo>& evs);
0166 
0167    protected:
0168       Double_t fQuantile;
0169       Double_t fTransitionPoint;
0170       Double_t fSumOfWeights;
0171    };
0172 
0173    ///////////////////////////////////////////////////////////////////////////////////////////////
0174    // Huber loss function with boosted decision tree functionality
0175    ///////////////////////////////////////////////////////////////////////////////////////////////
0176 
0177    // The bdt loss function implements the LossFunctionBDT interface and inherits the HuberLossFunction
0178    // functionality.
0179    class HuberLossFunctionBDT : public LossFunctionBDT, public HuberLossFunction{
0180 
0181    public:
0182       HuberLossFunctionBDT();
0183       HuberLossFunctionBDT(Double_t quantile):HuberLossFunction(quantile){};
0184       ~HuberLossFunctionBDT(){};
0185 
0186       // The LossFunctionBDT methods
0187       void Init(std::map<const TMVA::Event*, LossFunctionEventInfo>& evinfomap, std::vector<double>& boostWeights);
0188       void SetTargets(std::vector<const TMVA::Event*>& evs, std::map< const TMVA::Event*, LossFunctionEventInfo >& evinfomap);
0189       Double_t Target(LossFunctionEventInfo& e);
0190       Double_t Fit(std::vector<LossFunctionEventInfo>& evs);
0191 
0192    private:
0193       // some data fields
0194    };
0195 
0196    ///////////////////////////////////////////////////////////////////////////////////////////////
0197    // LeastSquares loss function for regression error calculations
0198    ///////////////////////////////////////////////////////////////////////////////////////////////
0199 
0200    class LeastSquaresLossFunction : public virtual LossFunction{
0201 
0202    public:
0203       LeastSquaresLossFunction(){};
0204       ~LeastSquaresLossFunction(){};
0205 
0206       // The LossFunction methods
0207       Double_t CalculateLoss(LossFunctionEventInfo& e);
0208       Double_t CalculateNetLoss(std::vector<LossFunctionEventInfo>& evs);
0209       Double_t CalculateMeanLoss(std::vector<LossFunctionEventInfo>& evs);
0210 
0211       // We go ahead and implement the simple ones
0212       TString Name(){ return TString("LeastSquares"); };
0213       Int_t Id(){ return 1; } ;
0214    };
0215 
0216    ///////////////////////////////////////////////////////////////////////////////////////////////
0217    // Least Squares loss function with boosted decision tree functionality
0218    ///////////////////////////////////////////////////////////////////////////////////////////////
0219 
0220    // The bdt loss function implements the LossFunctionBDT interface and inherits the LeastSquaresLossFunction
0221    // functionality.
0222    class LeastSquaresLossFunctionBDT : public LossFunctionBDT, public LeastSquaresLossFunction{
0223 
0224    public:
0225       LeastSquaresLossFunctionBDT(){};
0226       ~LeastSquaresLossFunctionBDT(){};
0227 
0228       // The LossFunctionBDT methods
0229       void Init(std::map<const TMVA::Event*, LossFunctionEventInfo>& evinfomap, std::vector<double>& boostWeights);
0230       void SetTargets(std::vector<const TMVA::Event*>& evs, std::map< const TMVA::Event*, LossFunctionEventInfo >& evinfomap);
0231       Double_t Target(LossFunctionEventInfo& e);
0232       Double_t Fit(std::vector<LossFunctionEventInfo>& evs);
0233    };
0234 
0235    ///////////////////////////////////////////////////////////////////////////////////////////////
0236    // Absolute Deviation loss function for regression error calculations
0237    ///////////////////////////////////////////////////////////////////////////////////////////////
0238 
0239    class AbsoluteDeviationLossFunction : public virtual LossFunction{
0240 
0241    public:
0242       AbsoluteDeviationLossFunction(){};
0243       ~AbsoluteDeviationLossFunction(){};
0244 
0245       // The LossFunction methods
0246       Double_t CalculateLoss(LossFunctionEventInfo& e);
0247       Double_t CalculateNetLoss(std::vector<LossFunctionEventInfo>& evs);
0248       Double_t CalculateMeanLoss(std::vector<LossFunctionEventInfo>& evs);
0249 
0250       // We go ahead and implement the simple ones
0251       TString Name(){ return TString("AbsoluteDeviation"); };
0252       Int_t Id(){ return 2; } ;
0253    };
0254 
0255    ///////////////////////////////////////////////////////////////////////////////////////////////
0256    // Absolute Deviation loss function with boosted decision tree functionality
0257    ///////////////////////////////////////////////////////////////////////////////////////////////
0258 
0259    // The bdt loss function implements the LossFunctionBDT interface and inherits the AbsoluteDeviationLossFunction
0260    // functionality.
0261    class AbsoluteDeviationLossFunctionBDT : public LossFunctionBDT, public AbsoluteDeviationLossFunction{
0262 
0263    public:
0264       AbsoluteDeviationLossFunctionBDT(){};
0265       ~AbsoluteDeviationLossFunctionBDT(){};
0266 
0267       // The LossFunctionBDT methods
0268       void Init(std::map<const TMVA::Event*, LossFunctionEventInfo>& evinfomap, std::vector<double>& boostWeights);
0269       void SetTargets(std::vector<const TMVA::Event*>& evs, std::map< const TMVA::Event*, LossFunctionEventInfo >& evinfomap);
0270       Double_t Target(LossFunctionEventInfo& e);
0271       Double_t Fit(std::vector<LossFunctionEventInfo>& evs);
0272    };
0273 }
0274 
0275 #endif