File indexing completed on 2025-01-18 10:10:16
0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014 #ifndef ROOT_Math_Functor
0015 #define ROOT_Math_Functor
0016
0017 #include "Math/IFunction.h"
0018
0019
0020
0021
0022
0023 #include <memory>
0024 #include <functional>
0025 #include <vector>
0026
0027 namespace ROOT {
0028
0029 namespace Math {
0030
0031
0032
0033
0034
0035
0036
0037
0038
0039
0040
0041
0042
0043
0044
0045
0046
0047 class Functor : public IBaseFunctionMultiDim {
0048
0049 public:
0050
0051
0052 Functor () {}
0053
0054
0055 template <class PtrObj, typename MemFn>
0056 Functor(const PtrObj& p, MemFn memFn, unsigned int dim )
0057 : fDim{dim}, fFunc{std::bind(memFn, p, std::placeholders::_1)}
0058 {}
0059
0060
0061
0062 Functor(std::function<double(double const *)> const& f, unsigned int dim ) : fDim{dim}, fFunc{f} {}
0063
0064
0065 Functor * Clone() const override { return new Functor(*this); }
0066
0067
0068 unsigned int NDim() const override { return fDim; }
0069
0070 private :
0071
0072 inline double DoEval (const double * x) const override {
0073 return fFunc(x);
0074 }
0075
0076 unsigned int fDim;
0077 std::function<double(double const *)> fFunc;
0078 };
0079
0080
0081
0082
0083
0084
0085
0086
0087
0088
0089
0090
0091
0092
0093
0094
0095 class Functor1D : public IBaseFunctionOneDim {
0096
0097 public:
0098
0099
0100 Functor1D() = default;
0101
0102
0103
0104 Functor1D(std::function<double(double)> const& f) : fFunc{f} {}
0105
0106
0107 template <class PtrObj, typename MemFn>
0108 Functor1D(const PtrObj& p, MemFn memFn) : fFunc{std::bind(memFn, p, std::placeholders::_1)} {}
0109
0110
0111 Functor1D * Clone() const override { return new Functor1D(*this); }
0112
0113 private :
0114
0115 inline double DoEval (double x) const override {
0116 return fFunc(x);
0117 }
0118
0119 std::function<double(double)> fFunc;
0120 };
0121
0122
0123
0124
0125
0126
0127
0128
0129
0130
0131
0132
0133
0134
0135
0136
0137
0138
0139
0140
0141
0142 class GradFunctor : public IGradientFunctionMultiDim {
0143
0144
0145 public:
0146
0147
0148 GradFunctor() = default;
0149
0150
0151
0152
0153
0154
0155 template <typename Func>
0156 GradFunctor( const Func & f, unsigned int dim ) :
0157 fDim{dim}, fFunc{f}, fDerivFunc{std::bind(&Func::Derivative, f, std::placeholders::_1, std::placeholders::_2)}
0158 {}
0159
0160
0161 template <class PtrObj, typename MemFn, typename DerivMemFn,
0162 std::enable_if_t<std::is_floating_point<decltype((std::declval<std::remove_pointer_t<PtrObj>>().*
0163 std::declval<DerivMemFn>())(
0164 std::declval<const double *>(), std::declval<int>()))>::value,
0165 bool> = true>
0166 GradFunctor(const PtrObj &p, MemFn memFn, DerivMemFn gradFn, unsigned int dim)
0167 : fDim{dim},
0168 fFunc{std::bind(memFn, p, std::placeholders::_1)},
0169 fDerivFunc{std::bind(gradFn, p, std::placeholders::_1, std::placeholders::_2)}
0170 {}
0171
0172
0173
0174 template <
0175 class PtrObj, typename MemFn, typename GradMemFn,
0176 std::enable_if_t<std::is_void<decltype((std::declval<std::remove_pointer_t<PtrObj>>().*std::declval<GradMemFn>())(
0177 std::declval<const double *>(), std::declval<double *>()))>::value,
0178 bool> = true>
0179 GradFunctor(const PtrObj &p, MemFn memFn, GradMemFn gradFn, unsigned int dim)
0180 : fDim{dim},
0181 fFunc{std::bind(memFn, p, std::placeholders::_1)},
0182 fGradFunc{std::bind(gradFn, p, std::placeholders::_1, std::placeholders::_2)}
0183 {
0184 }
0185
0186
0187
0188
0189 GradFunctor(std::function<double(double const *)> const& f,
0190 std::function<double(double const *, unsigned int)> const& g, unsigned int dim)
0191 : fDim{dim}, fFunc{f}, fDerivFunc{g}
0192 {}
0193
0194
0195
0196
0197
0198
0199
0200
0201
0202
0203
0204 GradFunctor(std::function<double(double const *)> const&f, unsigned int dim,
0205 std::function<void(double const *, double *)> const& g)
0206 : fDim{dim}, fFunc{f}, fGradFunc{g}
0207 {}
0208
0209
0210 GradFunctor * Clone() const override { return new GradFunctor(*this); }
0211
0212
0213 unsigned int NDim() const override { return fDim; }
0214
0215 void Gradient(const double *x, double *g) const override {
0216
0217
0218 if(!fGradFunc) {
0219 IGradientFunctionMultiDim::Gradient(x, g);
0220 return;
0221 }
0222 fGradFunc(x, g);
0223 }
0224
0225 private :
0226
0227 inline double DoEval (const double * x) const override {
0228 return fFunc(x);
0229 }
0230
0231 inline double DoDerivative (const double * x, unsigned int icoord ) const override {
0232 if(fDerivFunc) {
0233 return fDerivFunc(x, icoord);
0234 }
0235
0236
0237 std::vector<double> gradBuffer(fDim);
0238 std::fill(gradBuffer.begin(), gradBuffer.end(), 0.0);
0239 fGradFunc(x, gradBuffer.data());
0240 return gradBuffer[icoord];
0241 }
0242
0243 unsigned int fDim;
0244 std::function<double(const double *)> fFunc;
0245 std::function<double(double const *, unsigned int)> fDerivFunc;
0246 std::function<void(const double *, double*)> fGradFunc;
0247 };
0248
0249
0250
0251
0252
0253
0254
0255
0256
0257
0258
0259
0260
0261
0262
0263
0264
0265
0266
0267
0268
0269 class GradFunctor1D : public IGradientFunctionOneDim {
0270
0271 public:
0272
0273
0274 GradFunctor1D() = default;
0275
0276
0277
0278 template <typename Func>
0279 GradFunctor1D(const Func & f) : fFunc{f}, fDerivFunc{std::bind(&Func::Derivative, f, std::placeholders::_1)} {}
0280
0281
0282
0283
0284
0285
0286 template <class PtrObj, typename MemFn, typename GradMemFn>
0287 GradFunctor1D(const PtrObj& p, MemFn memFn, GradMemFn gradFn)
0288 : fFunc{std::bind(memFn, p, std::placeholders::_1)}, fDerivFunc{std::bind(gradFn, p, std::placeholders::_1)}
0289 {}
0290
0291
0292
0293
0294
0295 GradFunctor1D(std::function<double(double)> const& f, std::function<double(double)> const& g)
0296 : fFunc{f}, fDerivFunc{g}
0297 {}
0298
0299
0300 GradFunctor1D * Clone() const override { return new GradFunctor1D(*this); }
0301
0302 private :
0303
0304 inline double DoEval (double x) const override { return fFunc(x); }
0305 inline double DoDerivative (double x) const override { return fDerivFunc(x); }
0306
0307 std::function<double(double)> fFunc;
0308 std::function<double(double)> fDerivFunc;
0309 };
0310
0311
0312 }
0313
0314 }
0315
0316
0317 #endif