|
|
|||
File indexing completed on 2025-12-25 10:03:12
0001 // @(#)root/minuit2:$Id$ 0002 // Author: L. Moneta 10/2006 0003 0004 /********************************************************************** 0005 * * 0006 * Copyright (c) 2006 ROOT Foundation, CERN/PH-SFT * 0007 * * 0008 **********************************************************************/ 0009 0010 #ifndef ROOT_Minuit2_FCNAdapter 0011 #define ROOT_Minuit2_FCNAdapter 0012 0013 #include "Minuit2/FCNBase.h" 0014 0015 #include <ROOT/RSpan.hxx> 0016 0017 #include <vector> 0018 #include <functional> 0019 0020 namespace ROOT::Minuit2 { 0021 0022 /// Adapter class to wrap user-provided functions into the FCNBase interface. 0023 /// 0024 /// This class allows users to supply their function (and optionally its gradient, 0025 /// diagonal second derivatives, and Hessian) in the form of `std::function` objects. 0026 /// It adapts these functions so that they can be used transparently with the MINUIT 0027 /// minimizers via the `FCNBase` interface. 0028 /// 0029 /// Typical usage: 0030 /// - Pass the function to minimize to the constructor. 0031 /// - Optionally set the gradient, G2 (second derivative diagonal), or Hessian 0032 /// functions using the provided setter methods. 0033 /// - MINUIT will then query these functions if available, or fall back to numerical 0034 /// approximations if they are not provided. 0035 /// 0036 /// \ingroup Minuit 0037 class FCNAdapter : public FCNBase { 0038 0039 public: 0040 /// Construct an adapter around a user-provided function. 0041 /// 0042 /// @param f Function to minimize. It must take a pointer to the parameter array 0043 /// (`double const*`) and return the function value. 0044 /// @param up Error definition parameter (defaults to 1.0). 0045 FCNAdapter(std::function<double(double const *)> f, double up = 1.) : fUp(up), fFunc(std::move(f)) {} 0046 0047 /// Indicate whether an analytic gradient has been provided. 0048 /// 0049 /// @return `true` if a gradient function was set, otherwise `false`. 0050 bool HasGradient() const override { return bool(fGradFunc); } 0051 0052 /// Indicate whether analytic second derivatives (diagonal of the Hessian) are available. 0053 /// 0054 /// @return `true` if a G2 function or a Hessian function has been set, otherwise `false`. 0055 bool HasG2() const override { return bool(fG2Func); } 0056 0057 /// Indicate whether an analytic Hessian has been provided. 0058 /// 0059 /// @return `true` if a Hessian function was set, otherwise `false`. 0060 bool HasHessian() const override { return bool(fHessianFunc); } 0061 0062 /// Evaluate the function at the given parameter vector. 0063 /// 0064 /// @param v Parameter vector. 0065 /// @return Function value at the specified parameters. 0066 double operator()(std::vector<double> const &v) const override { return fFunc(v.data()); } 0067 0068 /// Return the error definition parameter (`up`). 0069 /// 0070 /// @return Current error definition value. 0071 double Up() const override { return fUp; } 0072 0073 /// Evaluate the gradient of the function at the given parameter vector. 0074 /// 0075 /// @param v Parameter vector. 0076 /// @return Gradient vector (∂f/∂xᵢ) at the specified parameters. 0077 std::vector<double> Gradient(std::vector<double> const &v) const override 0078 { 0079 std::vector<double> output(v.size()); 0080 fGradFunc(v.data(), output.data()); 0081 return output; 0082 } 0083 0084 /// Return the diagonal elements of the Hessian (second derivatives). 0085 /// 0086 /// If a G2 function is set, it is used directly. If only a Hessian function 0087 /// is available, the diagonal is extracted from the full Hessian. 0088 /// 0089 /// @param x Parameter vector. 0090 /// @return Vector of second derivatives (one per parameter). 0091 std::vector<double> G2(std::vector<double> const &x) const override 0092 { 0093 std::vector<double> output; 0094 if (fG2Func) 0095 return fG2Func(x); 0096 if (fHessianFunc) { 0097 std::size_t n = x.size(); 0098 output.resize(n); 0099 if (fHessian.empty()) 0100 fHessian.resize(n * n); 0101 fHessianFunc(x, fHessian.data()); 0102 if (!fHessian.empty()) { 0103 // Extract diagonal elements of Hessian 0104 for (unsigned int i = 0; i < n; i++) 0105 output[i] = fHessian[i * n + i]; 0106 } 0107 } 0108 return output; 0109 } 0110 0111 /// Return the full Hessian matrix. 0112 /// 0113 /// If a Hessian function is available, it is used to fill the matrix. 0114 /// If the Hessian function fails, it is cleared and not used again. 0115 /// 0116 /// @param x Parameter vector. 0117 /// @return Flattened Hessian matrix in row-major order. 0118 std::vector<double> Hessian(std::vector<double> const &x) const override 0119 { 0120 std::vector<double> output; 0121 if (fHessianFunc) { 0122 std::size_t n = x.size(); 0123 output.resize(n * n); 0124 bool ret = fHessianFunc(x, output.data()); 0125 if (!ret) { 0126 output.clear(); 0127 fHessianFunc = nullptr; 0128 } 0129 } 0130 0131 return output; 0132 } 0133 0134 /// Set the analytic gradient function. 0135 /// 0136 /// @param f Gradient function of type `void(double const*, double*)`. 0137 /// The first argument is the parameter array, the second is 0138 /// the output array for the gradient values. 0139 void SetGradientFunction(std::function<void(double const *, double *)> f) { fGradFunc = std::move(f); } 0140 0141 /// Set the function providing diagonal second derivatives (G2). 0142 /// 0143 /// @param f Function taking a parameter vector and returning the 0144 /// diagonal of the Hessian matrix as a vector. 0145 void SetG2Function(std::function<std::vector<double>(std::vector<double> const &)> f) { fG2Func = std::move(f); } 0146 0147 /// Set the function providing the full Hessian matrix. 0148 /// 0149 /// @param f Function of type `bool(std::vector<double> const&, double*)`. 0150 /// The first argument is the parameter vector, the second is 0151 /// the output buffer (flattened matrix). The return value 0152 /// should be `true` on success, `false` on failure. 0153 void SetHessianFunction(std::function<bool(std::vector<double> const &, double *)> f) 0154 { 0155 fHessianFunc = std::move(f); 0156 } 0157 0158 /// Update the error definition parameter. 0159 /// 0160 /// @param up New error definition value. 0161 void SetErrorDef(double up) override { fUp = up; } 0162 0163 private: 0164 using Function = std::function<double(double const *)>; 0165 using GradFunction = std::function<void(double const *, double *)>; 0166 using G2Function = std::function<std::vector<double>(std::vector<double> const &)>; 0167 using HessianFunction = std::function<bool(std::vector<double> const &, double *)>; 0168 0169 double fUp = 1.; ///< Error definition parameter. 0170 mutable std::vector<double> fHessian; ///< Storage for intermediate Hessian values. 0171 0172 Function fFunc; ///< Wrapped function to minimize. 0173 GradFunction fGradFunc; ///< Optional gradient function. 0174 G2Function fG2Func; ///< Optional diagonal second-derivative function. 0175 mutable HessianFunction fHessianFunc; ///< Optional Hessian function. 0176 }; 0177 0178 } // namespace ROOT::Minuit2 0179 0180 #endif // ROOT_Minuit2_FCNAdapter
| [ Source navigation ] | [ Diff markup ] | [ Identifier search ] | [ general search ] |
|
This page was automatically generated by the 2.3.7 LXR engine. The LXR team |
|