Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-12-16 10:29:45

0001 // @(#)root/roostats:$Id$
0002 // Authors: Kevin Belasco        17/06/2009
0003 // Authors: Kyle Cranmer         17/06/2009
0004 /*************************************************************************
0005  * Copyright (C) 1995-2008, Rene Brun and Fons Rademakers.               *
0006  * All rights reserved.                                                  *
0007  *                                                                       *
0008  * For the licensing terms see $ROOTSYS/LICENSE.                         *
0009  * For the list of contributors see $ROOTSYS/README/CREDITS.             *
0010  *************************************************************************/
0011 
0012 #ifndef RooStats_MCMCInterval
0013 #define RooStats_MCMCInterval
0014 
0015 #include "Rtypes.h"
0016 
0017 #include "RooStats/ConfInterval.h"
0018 #include "RooArgSet.h"
0019 #include "RooArgList.h"
0020 #include "RooMsgService.h"
0021 #include "RooStats/MarkovChain.h"
0022 
0023 #include <vector>
0024 
0025 class RooNDKeysPdf;
0026 class RooProduct;
0027 
0028 
0029 namespace RooStats {
0030 
0031    class Heaviside;
0032 
0033    class MCMCInterval : public ConfInterval {
0034 
0035 
0036    public:
0037 
0038       /// default constructor
0039       explicit MCMCInterval(const char *name = nullptr);
0040 
0041       /// constructor from parameter of interest and Markov chain object
0042       MCMCInterval(const char* name, const RooArgSet& parameters,
0043                    MarkovChain& chain);
0044 
0045       enum {DEFAULT_NUM_BINS = 50};
0046       enum IntervalType {kShortest, kTailFraction};
0047 
0048       ~MCMCInterval() override;
0049 
0050       /// determine whether this point is in the confidence interval
0051       bool IsInInterval(const RooArgSet& point) const override;
0052 
0053       /// set the desired confidence level (see GetActualConfidenceLevel())
0054       /// Note: calling this function triggers the algorithm that determines
0055       /// the interval, so call this after initializing all other aspects
0056       /// of this IntervalCalculator
0057       /// Also, calling this function again with a different confidence level
0058       /// re-triggers the calculation of the interval
0059       void SetConfidenceLevel(double cl) override;
0060 
0061       /// get the desired confidence level (see GetActualConfidenceLevel())
0062       double ConfidenceLevel() const override {return fConfidenceLevel;}
0063 
0064       /// return a set containing the parameters of this interval
0065       /// the caller owns the returned RooArgSet*
0066       RooArgSet* GetParameters() const override;
0067 
0068       /// get the cutoff bin height for being considered in the
0069       /// confidence interval
0070       virtual double GetHistCutoff();
0071 
0072       /// get the cutoff RooNDKeysPdf value for being considered in the
0073       /// confidence interval
0074       virtual double GetKeysPdfCutoff();
0075       ///virtual double GetKeysPdfCutoff() { return fKeysCutoff; }
0076 
0077       /// get the actual value of the confidence level for this interval.
0078       virtual double GetActualConfidenceLevel();
0079 
0080       /// whether the specified confidence level is a floor for the actual
0081       /// confidence level (strict), or a ceiling (not strict)
0082       virtual void SetHistStrict(bool isHistStrict)
0083       { fIsHistStrict = isHistStrict; }
0084 
0085       /// check if parameters are correct. (dummy implementation to start)
0086       bool CheckParameters(const RooArgSet& point) const override;
0087 
0088       /// Set the parameters of interest for this interval
0089       /// and change other internal data members accordingly
0090       virtual void SetParameters(const RooArgSet& parameters);
0091 
0092       /// Set the MarkovChain that this interval is based on.
0093       /// \note The MCMCInterval object takes ownership of the passed MarkovChain.
0094       virtual void SetChain(MarkovChain& chain) { fChain.reset(&chain); }
0095 
0096       /// Set which parameters go on which axis.  The first list element
0097       /// goes on the x axis, second (if it exists) on y, third (if it
0098       /// exists) on z, etc
0099       virtual void SetAxes(RooArgList& axes);
0100 
0101       /// return a list of RooRealVars representing the axes
0102       /// you own the returned RooArgList
0103       virtual RooArgList* GetAxes()
0104       {
0105          RooArgList* axes = new RooArgList();
0106          for (Int_t i = 0; i < fDimension; i++)
0107             axes->addClone(*fAxes[i]);
0108          return axes;
0109       }
0110 
0111       /// get the lowest value of param that is within the confidence interval
0112       virtual double LowerLimit(RooRealVar& param);
0113 
0114       /// determine lower limit of the lower confidence interval
0115       virtual double LowerLimitTailFraction(RooRealVar& param);
0116 
0117       /// get the lower limit of param in the shortest confidence interval
0118       /// Note that this works better for some distributions (ones with exactly
0119       /// one maximum) than others, and sometimes has little value.
0120       virtual double LowerLimitShortest(RooRealVar& param);
0121 
0122       /// determine lower limit in the shortest interval by using keys pdf
0123       virtual double LowerLimitByKeys(RooRealVar& param);
0124 
0125       /// determine lower limit using histogram
0126       virtual double LowerLimitByHist(RooRealVar& param);
0127 
0128       /// determine lower limit using histogram
0129       virtual double LowerLimitBySparseHist(RooRealVar& param);
0130 
0131       /// determine lower limit using histogram
0132       virtual double LowerLimitByDataHist(RooRealVar& param);
0133 
0134       /// get the highest value of param that is within the confidence interval
0135       virtual double UpperLimit(RooRealVar& param);
0136 
0137       /// determine upper limit of the lower confidence interval
0138       virtual double UpperLimitTailFraction(RooRealVar& param);
0139 
0140       /// get the upper limit of param in the confidence interval
0141       /// Note that this works better for some distributions (ones with exactly
0142       /// one maximum) than others, and sometimes has little value.
0143       virtual double UpperLimitShortest(RooRealVar& param);
0144 
0145       /// determine upper limit in the shortest interval by using keys pdf
0146       virtual double UpperLimitByKeys(RooRealVar& param);
0147 
0148       /// determine upper limit using histogram
0149       virtual double UpperLimitByHist(RooRealVar& param);
0150 
0151       /// determine upper limit using histogram
0152       virtual double UpperLimitBySparseHist(RooRealVar& param);
0153 
0154       /// determine upper limit using histogram
0155       virtual double UpperLimitByDataHist(RooRealVar& param);
0156 
0157       /// Determine the approximate maximum value of the Keys PDF
0158       double GetKeysMax();
0159 
0160       /// set the number of steps in the chain to discard as burn-in,
0161       /// starting from the first
0162       virtual void SetNumBurnInSteps(Int_t numBurnInSteps)
0163       { fNumBurnInSteps = numBurnInSteps; }
0164 
0165       /// set whether to use kernel estimation to determine the interval
0166       virtual void SetUseKeys(bool useKeys) { fUseKeys = useKeys; }
0167 
0168       /// set whether to use a sparse histogram.  you MUST also call
0169       /// SetUseKeys(false) to use a histogram.
0170       virtual void SetUseSparseHist(bool useSparseHist)
0171       { fUseSparseHist = useSparseHist; }
0172 
0173       /// get whether we used kernel estimation to determine the interval
0174       virtual bool GetUseKeys() { return fUseKeys; }
0175 
0176       /// get the number of steps in the chain to discard as burn-in,
0177 
0178       /// get the number of steps in the chain to discard as burn-in,
0179       /// starting from the first
0180       virtual Int_t GetNumBurnInSteps() { return fNumBurnInSteps; }
0181 
0182       /// set the number of bins to use (same for all axes, for now)
0183       ///virtual void SetNumBins(Int_t numBins);
0184 
0185       /// Get a clone of the histogram of the posterior
0186       virtual TH1* GetPosteriorHist();
0187 
0188       /// Get a clone of the keys pdf of the posterior
0189       virtual RooNDKeysPdf* GetPosteriorKeysPdf();
0190 
0191       /// Get a clone of the (keyspdf * heaviside) product of the posterior
0192       virtual RooProduct* GetPosteriorKeysProduct();
0193 
0194       /// Get the number of parameters of interest in this interval
0195       virtual Int_t GetDimension() const { return fDimension; }
0196 
0197       /// Get the markov chain on which this interval is based
0198       /// You do not own the returned MarkovChain*
0199       virtual const MarkovChain* GetChain() { return fChain.get(); }
0200 
0201       /// Get a clone of the markov chain on which this interval is based
0202       /// as a RooDataSet.  You own the returned RooDataSet*
0203       virtual RooFit::OwningPtr<RooDataSet> GetChainAsDataSet(RooArgSet* whichVars = nullptr)
0204       { return fChain->GetAsDataSet(whichVars); }
0205 
0206       /// Get the markov chain on which this interval is based
0207       /// as a RooDataSet.  You do not own the returned RooDataSet*
0208       virtual const RooDataSet* GetChainAsConstDataSet()
0209       { return fChain->GetAsConstDataSet(); }
0210 
0211       /// Get a clone of the markov chain on which this interval is based
0212       /// as a RooDataHist.  You own the returned RooDataHist*
0213       virtual RooFit::OwningPtr<RooDataHist> GetChainAsDataHist(RooArgSet* whichVars = nullptr)
0214       { return fChain->GetAsDataHist(whichVars); }
0215 
0216       /// Get a clone of the markov chain on which this interval is based
0217       /// as a THnSparse.  You own the returned THnSparse*
0218       virtual THnSparse* GetChainAsSparseHist(RooArgSet* whichVars = nullptr)
0219       { return fChain->GetAsSparseHist(whichVars); }
0220 
0221       /// Get a clone of the NLL variable from the markov chain
0222       virtual RooRealVar* GetNLLVar() const
0223       { return fChain->GetNLLVar(); }
0224 
0225       /// Get a clone of the weight variable from the markov chain
0226       virtual RooRealVar* GetWeightVar() const
0227       { return fChain->GetWeightVar(); }
0228 
0229       /// set the acceptable level or error for Keys interval determination
0230       virtual void SetEpsilon(double epsilon)
0231       {
0232          if (epsilon < 0) {
0233             coutE(InputArguments) << "MCMCInterval::SetEpsilon will not allow "
0234                                   << "negative epsilon value" << std::endl;
0235          } else {
0236             fEpsilon = epsilon;
0237          }
0238       }
0239 
0240       /// Set the type of interval to find.  This will only have an effect for
0241       /// 1-D intervals.  If is more than 1 parameter of interest, then a
0242       /// "shortest" interval will always be used, since it generalizes directly
0243       /// to N dimensions
0244       virtual void SetIntervalType(enum IntervalType intervalType)
0245       { fIntervalType = intervalType; }
0246       virtual void SetShortestInterval() { SetIntervalType(kShortest); }
0247 
0248       /// Return the type of this interval
0249       virtual enum IntervalType GetIntervalType() { return fIntervalType; }
0250 
0251       /// set the left-side tail fraction for a tail-fraction interval
0252       virtual void SetLeftSideTailFraction(double a) {
0253          fIntervalType = kTailFraction;
0254          fLeftSideTF = a;
0255       }
0256 
0257       /// kbelasco: The inner-workings of the class really should not be exposed
0258       /// like this in a comment, but it seems to be the only way to give
0259       /// the user any control over this process, if they desire it
0260       ///
0261       /// Set the fraction delta such that
0262       /// topCutoff (a) is considered == bottomCutoff (b) iff
0263       /// (std::abs(a - b) < std::abs(fDelta * (a + b)/2))
0264       /// when determining the confidence interval by Keys
0265       virtual void SetDelta(double delta)
0266       {
0267          if (delta < 0.) {
0268             coutE(InputArguments) << "MCMCInterval::SetDelta will not allow "
0269                                   << "negative delta value" << std::endl;
0270          } else {
0271             fDelta = delta;
0272          }
0273       }
0274 
0275    private:
0276       inline bool AcceptableConfLevel(double confLevel);
0277       inline bool WithinDeltaFraction(double a, double b);
0278 
0279       constexpr static const double DEFAULT_EPSILON = 0.01;
0280       constexpr static const double DEFAULT_DELTA   = 10e-6;
0281 
0282    protected:
0283       RooArgSet fParameters;         ///< parameters of interest for this interval
0284       std::unique_ptr<MarkovChain> fChain; ///< the markov chain
0285       double fConfidenceLevel = 0.0; ///< Requested confidence level (eg. 0.95 for 95% CL)
0286 
0287       std::unique_ptr<RooDataHist> fDataHist; ///< the binned Markov Chain data
0288       std::unique_ptr<THnSparse> fSparseHist; ///< the binned Markov Chain data
0289       double fHistConfLevel = 0.0;      ///< the actual conf level determined by hist
0290       double fHistCutoff = -1;          ///< cutoff bin size to be in interval
0291 
0292       std::unique_ptr<RooNDKeysPdf> fKeysPdf;     ///< the kernel estimation pdf
0293       std::unique_ptr<RooProduct> fProduct;       ///< the (keysPdf * heaviside) product
0294       std::unique_ptr<Heaviside> fHeaviside;      ///< the Heaviside function
0295       std::unique_ptr<RooDataHist> fKeysDataHist; ///< data hist representing product
0296       std::unique_ptr<RooRealVar> fCutoffVar;     ///< cutoff variable to use for integrating keys pdf
0297       double fKeysConfLevel = 0.0;          ///< the actual conf level determined by keys
0298       double fKeysCutoff = -1;              ///< cutoff keys pdf value to be in interval
0299       double fFull = 0.0;                   ///< Value of intergral of fProduct
0300 
0301       double fLeftSideTF = -1;    ///< left side tail-fraction for interval
0302       double fTFConfLevel = 0.0;  ///< the actual conf level of tail-fraction interval
0303       std::vector<Int_t> fVector; ///< vector containing the Markov chain data
0304       double fVecWeight = 0;      ///< sum of weights of all entries in fVector
0305       double fTFLower;            ///< lower limit of the tail-fraction interval
0306       double fTFUpper;            ///< upper limit of the tail-fraction interval
0307 
0308       std::unique_ptr<TH1> fHist; ///< the binned Markov Chain data
0309 
0310       bool fUseKeys = false;        ///< whether to use kernel estimation
0311       bool fUseSparseHist = false;  ///< whether to use sparse hist (vs. RooDataHist)
0312       bool fIsHistStrict = true;    ///< whether the specified confidence level is a
0313                                     ///< floor for the actual confidence level (strict),
0314                                     ///< or a ceiling (not strict) for determination by
0315                                     ///< histogram
0316       Int_t fDimension = 1;         ///< number of variables
0317       Int_t fNumBurnInSteps = 0;    ///< number of steps to discard as burn in, starting
0318                                     ///< from the first
0319       std::vector<RooRealVar*> fAxes; ///< array of pointers to RooRealVars representing
0320                                     ///< the axes of the histogram
0321                                     ///< fAxes[0] represents x-axis, [1] y, [2] z, etc
0322 
0323       double fEpsilon = DEFAULT_EPSILON; ///< acceptable error for Keys interval determination
0324 
0325       double fDelta = DEFAULT_DELTA; ///< topCutoff (a) considered == bottomCutoff (b) iff
0326                                      ///< (std::abs(a - b) < std::abs(fDelta * (a + b)/2));
0327                                      ///< Theoretically, the Abs is not needed here, but
0328                                      ///< floating-point arithmetic does not always work
0329                                      ///< perfectly, and the Abs doesn't hurt
0330       enum IntervalType fIntervalType = kShortest;
0331 
0332       // functions
0333       virtual void DetermineInterval();
0334       virtual void DetermineShortestInterval();
0335       virtual void DetermineTailFractionInterval();
0336       virtual void DetermineByHist();
0337       virtual void DetermineBySparseHist();
0338       virtual void DetermineByDataHist();
0339       virtual void DetermineByKeys();
0340       virtual void CreateHist();
0341       virtual void CreateSparseHist();
0342       virtual void CreateDataHist();
0343       virtual void CreateKeysPdf();
0344       virtual void CreateKeysDataHist();
0345       virtual void CreateVector(RooRealVar* param);
0346       inline virtual double CalcConfLevel(double cutoff, double full);
0347 
0348       ClassDefOverride(MCMCInterval,2)  // Concrete implementation of a ConfInterval based on MCMC calculation
0349 
0350    };
0351 }
0352 
0353 #endif