File indexing completed on 2025-01-30 10:22:52
0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018
0019
0020
0021
0022
0023
0024
0025
0026
0027
0028
0029
0030
0031 #ifndef ROOT_TMVA_LossFunction
0032 #define ROOT_TMVA_LossFunction
0033
0034
0035 #include <vector>
0036 #include <map>
0037 #include "TMVA/Event.h"
0038
0039 #include "TMVA/Types.h"
0040
0041
0042 namespace TMVA {
0043
0044
0045
0046
0047
0048 class LossFunctionEventInfo{
0049
0050 public:
0051 LossFunctionEventInfo(){
0052 trueValue = 0.;
0053 predictedValue = 0.;
0054 weight = 0.;
0055 };
0056 LossFunctionEventInfo(Double_t trueValue_, Double_t predictedValue_, Double_t weight_){
0057 trueValue = trueValue_;
0058 predictedValue = predictedValue_;
0059 weight = weight_;
0060 }
0061 ~LossFunctionEventInfo(){};
0062
0063 Double_t trueValue;
0064 Double_t predictedValue;
0065 Double_t weight;
0066 };
0067
0068
0069
0070
0071
0072
0073
0074 class LossFunction {
0075
0076 public:
0077
0078
0079 LossFunction(){};
0080 virtual ~LossFunction(){};
0081
0082
0083 virtual Double_t CalculateLoss(LossFunctionEventInfo& e) = 0;
0084 virtual Double_t CalculateNetLoss(std::vector<LossFunctionEventInfo>& evs) = 0;
0085 virtual Double_t CalculateMeanLoss(std::vector<LossFunctionEventInfo>& evs) = 0;
0086
0087 virtual TString Name() = 0;
0088 virtual Int_t Id() = 0;
0089 };
0090
0091
0092
0093
0094
0095
0096
0097
0098
0099
0100
0101
0102
0103
0104
0105
0106
0107
0108
0109
0110
0111
0112
0113
0114
0115
0116
0117
0118
0119
0120
0121
0122
0123
0124 class LossFunctionBDT : public virtual LossFunction{
0125
0126 public:
0127
0128
0129 LossFunctionBDT(){};
0130 virtual ~LossFunctionBDT(){};
0131
0132
0133 virtual void Init(std::map<const TMVA::Event*, LossFunctionEventInfo>& evinfomap, std::vector<double>& boostWeights) = 0;
0134 virtual void SetTargets(std::vector<const TMVA::Event*>& evs, std::map< const TMVA::Event*, LossFunctionEventInfo >& evinfomap) = 0;
0135 virtual Double_t Target(LossFunctionEventInfo& e) = 0;
0136 virtual Double_t Fit(std::vector<LossFunctionEventInfo>& evs) = 0;
0137
0138 };
0139
0140
0141
0142
0143
0144 class HuberLossFunction : public virtual LossFunction{
0145
0146 public:
0147 HuberLossFunction();
0148 HuberLossFunction(Double_t quantile);
0149 ~HuberLossFunction();
0150
0151
0152 Double_t CalculateLoss(LossFunctionEventInfo& e);
0153 Double_t CalculateNetLoss(std::vector<LossFunctionEventInfo>& evs);
0154 Double_t CalculateMeanLoss(std::vector<LossFunctionEventInfo>& evs);
0155
0156
0157 TString Name(){ return TString("Huber"); };
0158 Int_t Id(){ return 0; } ;
0159
0160
0161 void Init(std::vector<LossFunctionEventInfo>& evs);
0162 Double_t CalculateQuantile(std::vector<LossFunctionEventInfo>& evs, Double_t whichQuantile, Double_t sumOfWeights, bool abs);
0163 Double_t CalculateSumOfWeights(const std::vector<LossFunctionEventInfo>& evs);
0164 void SetTransitionPoint(std::vector<LossFunctionEventInfo>& evs);
0165 void SetSumOfWeights(std::vector<LossFunctionEventInfo>& evs);
0166
0167 protected:
0168 Double_t fQuantile;
0169 Double_t fTransitionPoint;
0170 Double_t fSumOfWeights;
0171 };
0172
0173
0174
0175
0176
0177
0178
0179 class HuberLossFunctionBDT : public LossFunctionBDT, public HuberLossFunction{
0180
0181 public:
0182 HuberLossFunctionBDT();
0183 HuberLossFunctionBDT(Double_t quantile):HuberLossFunction(quantile){};
0184 ~HuberLossFunctionBDT(){};
0185
0186
0187 void Init(std::map<const TMVA::Event*, LossFunctionEventInfo>& evinfomap, std::vector<double>& boostWeights);
0188 void SetTargets(std::vector<const TMVA::Event*>& evs, std::map< const TMVA::Event*, LossFunctionEventInfo >& evinfomap);
0189 Double_t Target(LossFunctionEventInfo& e);
0190 Double_t Fit(std::vector<LossFunctionEventInfo>& evs);
0191
0192 private:
0193
0194 };
0195
0196
0197
0198
0199
0200 class LeastSquaresLossFunction : public virtual LossFunction{
0201
0202 public:
0203 LeastSquaresLossFunction(){};
0204 ~LeastSquaresLossFunction(){};
0205
0206
0207 Double_t CalculateLoss(LossFunctionEventInfo& e);
0208 Double_t CalculateNetLoss(std::vector<LossFunctionEventInfo>& evs);
0209 Double_t CalculateMeanLoss(std::vector<LossFunctionEventInfo>& evs);
0210
0211
0212 TString Name(){ return TString("LeastSquares"); };
0213 Int_t Id(){ return 1; } ;
0214 };
0215
0216
0217
0218
0219
0220
0221
0222 class LeastSquaresLossFunctionBDT : public LossFunctionBDT, public LeastSquaresLossFunction{
0223
0224 public:
0225 LeastSquaresLossFunctionBDT(){};
0226 ~LeastSquaresLossFunctionBDT(){};
0227
0228
0229 void Init(std::map<const TMVA::Event*, LossFunctionEventInfo>& evinfomap, std::vector<double>& boostWeights);
0230 void SetTargets(std::vector<const TMVA::Event*>& evs, std::map< const TMVA::Event*, LossFunctionEventInfo >& evinfomap);
0231 Double_t Target(LossFunctionEventInfo& e);
0232 Double_t Fit(std::vector<LossFunctionEventInfo>& evs);
0233 };
0234
0235
0236
0237
0238
0239 class AbsoluteDeviationLossFunction : public virtual LossFunction{
0240
0241 public:
0242 AbsoluteDeviationLossFunction(){};
0243 ~AbsoluteDeviationLossFunction(){};
0244
0245
0246 Double_t CalculateLoss(LossFunctionEventInfo& e);
0247 Double_t CalculateNetLoss(std::vector<LossFunctionEventInfo>& evs);
0248 Double_t CalculateMeanLoss(std::vector<LossFunctionEventInfo>& evs);
0249
0250
0251 TString Name(){ return TString("AbsoluteDeviation"); };
0252 Int_t Id(){ return 2; } ;
0253 };
0254
0255
0256
0257
0258
0259
0260
0261 class AbsoluteDeviationLossFunctionBDT : public LossFunctionBDT, public AbsoluteDeviationLossFunction{
0262
0263 public:
0264 AbsoluteDeviationLossFunctionBDT(){};
0265 ~AbsoluteDeviationLossFunctionBDT(){};
0266
0267
0268 void Init(std::map<const TMVA::Event*, LossFunctionEventInfo>& evinfomap, std::vector<double>& boostWeights);
0269 void SetTargets(std::vector<const TMVA::Event*>& evs, std::map< const TMVA::Event*, LossFunctionEventInfo >& evinfomap);
0270 Double_t Target(LossFunctionEventInfo& e);
0271 Double_t Fit(std::vector<LossFunctionEventInfo>& evs);
0272 };
0273 }
0274
0275 #endif