File indexing completed on 2025-09-16 09:08:09
0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014 #ifndef ROOT_Math_Functor
0015 #define ROOT_Math_Functor
0016
0017 #include "Math/IFunction.h"
0018
0019
0020
0021
0022
0023 #include <algorithm>
0024 #include <memory>
0025 #include <functional>
0026 #include <type_traits>
0027 #include <vector>
0028
0029 namespace ROOT {
0030
0031 namespace Math {
0032
0033
0034
0035
0036
0037
0038
0039
0040
0041
0042
0043
0044
0045
0046
0047
0048
0049 class Functor : public IBaseFunctionMultiDim {
0050
0051 public:
0052
0053
0054 Functor () {}
0055
0056
0057 template <class PtrObj, typename MemFn>
0058 Functor(const PtrObj& p, MemFn memFn, unsigned int dim )
0059 : fDim{dim}, fFunc{std::bind(memFn, p, std::placeholders::_1)}
0060 {}
0061
0062
0063
0064 Functor(std::function<double(double const *)> const& f, unsigned int dim ) : fDim{dim}, fFunc{f} {}
0065
0066
0067 Functor * Clone() const override { return new Functor(*this); }
0068
0069
0070 unsigned int NDim() const override { return fDim; }
0071
0072 private :
0073
0074 inline double DoEval (const double * x) const override {
0075 return fFunc(x);
0076 }
0077
0078 unsigned int fDim;
0079 std::function<double(double const *)> fFunc;
0080 };
0081
0082
0083
0084
0085
0086
0087
0088
0089
0090
0091
0092
0093
0094
0095
0096
0097 class Functor1D : public IBaseFunctionOneDim {
0098
0099 public:
0100
0101
0102 Functor1D() = default;
0103
0104
0105
0106 Functor1D(std::function<double(double)> const& f) : fFunc{f} {}
0107
0108
0109 template <class PtrObj, typename MemFn>
0110 Functor1D(const PtrObj& p, MemFn memFn) : fFunc{std::bind(memFn, p, std::placeholders::_1)} {}
0111
0112
0113 Functor1D * Clone() const override { return new Functor1D(*this); }
0114
0115 private :
0116
0117 inline double DoEval (double x) const override {
0118 return fFunc(x);
0119 }
0120
0121 std::function<double(double)> fFunc;
0122 };
0123
0124
0125
0126
0127
0128
0129
0130
0131
0132
0133
0134
0135
0136
0137
0138
0139
0140
0141
0142
0143
0144 class GradFunctor : public IGradientFunctionMultiDim {
0145
0146
0147 public:
0148
0149
0150 GradFunctor() = default;
0151
0152
0153
0154
0155
0156
0157 template <typename Func>
0158 GradFunctor( const Func & f, unsigned int dim ) :
0159 fDim{dim}, fFunc{f}, fDerivFunc{std::bind(&Func::Derivative, f, std::placeholders::_1, std::placeholders::_2)}
0160 {}
0161
0162
0163 template <class PtrObj, typename MemFn, typename DerivMemFn,
0164 std::enable_if_t<std::is_floating_point<decltype((std::declval<std::remove_pointer_t<PtrObj>>().*
0165 std::declval<DerivMemFn>())(
0166 std::declval<const double *>(), std::declval<int>()))>::value,
0167 bool> = true>
0168 GradFunctor(const PtrObj &p, MemFn memFn, DerivMemFn gradFn, unsigned int dim)
0169 : fDim{dim},
0170 fFunc{std::bind(memFn, p, std::placeholders::_1)},
0171 fDerivFunc{std::bind(gradFn, p, std::placeholders::_1, std::placeholders::_2)}
0172 {}
0173
0174
0175
0176 template <
0177 class PtrObj, typename MemFn, typename GradMemFn,
0178 std::enable_if_t<std::is_void<decltype((std::declval<std::remove_pointer_t<PtrObj>>().*std::declval<GradMemFn>())(
0179 std::declval<const double *>(), std::declval<double *>()))>::value,
0180 bool> = true>
0181 GradFunctor(const PtrObj &p, MemFn memFn, GradMemFn gradFn, unsigned int dim)
0182 : fDim{dim},
0183 fFunc{std::bind(memFn, p, std::placeholders::_1)},
0184 fGradFunc{std::bind(gradFn, p, std::placeholders::_1, std::placeholders::_2)}
0185 {
0186 }
0187
0188
0189
0190
0191 GradFunctor(std::function<double(double const *)> const& f,
0192 std::function<double(double const *, unsigned int)> const& g, unsigned int dim)
0193 : fDim{dim}, fFunc{f}, fDerivFunc{g}
0194 {}
0195
0196
0197
0198
0199
0200
0201
0202
0203
0204
0205
0206 GradFunctor(std::function<double(double const *)> const&f, unsigned int dim,
0207 std::function<void(double const *, double *)> const& g)
0208 : fDim{dim}, fFunc{f}, fGradFunc{g}
0209 {}
0210
0211
0212 GradFunctor * Clone() const override { return new GradFunctor(*this); }
0213
0214
0215 unsigned int NDim() const override { return fDim; }
0216
0217 void Gradient(const double *x, double *g) const override {
0218
0219
0220 if(!fGradFunc) {
0221 IGradientFunctionMultiDim::Gradient(x, g);
0222 return;
0223 }
0224 fGradFunc(x, g);
0225 }
0226
0227 private :
0228
0229 inline double DoEval (const double * x) const override {
0230 return fFunc(x);
0231 }
0232
0233 inline double DoDerivative (const double * x, unsigned int icoord ) const override {
0234 if(fDerivFunc) {
0235 return fDerivFunc(x, icoord);
0236 }
0237
0238
0239 std::vector<double> gradBuffer(fDim);
0240 std::fill(gradBuffer.begin(), gradBuffer.end(), 0.0);
0241 fGradFunc(x, gradBuffer.data());
0242 return gradBuffer[icoord];
0243 }
0244
0245 unsigned int fDim;
0246 std::function<double(const double *)> fFunc;
0247 std::function<double(double const *, unsigned int)> fDerivFunc;
0248 std::function<void(const double *, double*)> fGradFunc;
0249 };
0250
0251
0252
0253
0254
0255
0256
0257
0258
0259
0260
0261
0262
0263
0264
0265
0266
0267
0268
0269
0270
0271 class GradFunctor1D : public IGradientFunctionOneDim {
0272
0273 public:
0274
0275
0276 GradFunctor1D() = default;
0277
0278
0279
0280 template <typename Func>
0281 GradFunctor1D(const Func & f) : fFunc{f}, fDerivFunc{std::bind(&Func::Derivative, f, std::placeholders::_1)} {}
0282
0283
0284
0285
0286
0287
0288 template <class PtrObj, typename MemFn, typename GradMemFn>
0289 GradFunctor1D(const PtrObj& p, MemFn memFn, GradMemFn gradFn)
0290 : fFunc{std::bind(memFn, p, std::placeholders::_1)}, fDerivFunc{std::bind(gradFn, p, std::placeholders::_1)}
0291 {}
0292
0293
0294
0295
0296
0297 GradFunctor1D(std::function<double(double)> const& f, std::function<double(double)> const& g)
0298 : fFunc{f}, fDerivFunc{g}
0299 {}
0300
0301
0302 GradFunctor1D * Clone() const override { return new GradFunctor1D(*this); }
0303
0304 private :
0305
0306 inline double DoEval (double x) const override { return fFunc(x); }
0307 inline double DoDerivative (double x) const override { return fDerivFunc(x); }
0308
0309 std::function<double(double)> fFunc;
0310 std::function<double(double)> fDerivFunc;
0311 };
0312
0313
0314 }
0315
0316 }
0317
0318
0319 #endif