Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-09-14 08:50:58

0001 //------------------------------- -*- C++ -*- -------------------------------//
0002 // Copyright Celeritas contributors: see top-level COPYRIGHT file for details
0003 // SPDX-License-Identifier: (Apache-2.0 OR MIT)
0004 //---------------------------------------------------------------------------//
0005 //! \file corecel/random/distribution/Selector.hh
0006 //---------------------------------------------------------------------------//
0007 #pragma once
0008 
0009 #include <type_traits>
0010 
0011 #include "corecel/cont/Range.hh"
0012 #include "corecel/math/Algorithms.hh"
0013 #include "corecel/math/SoftEqual.hh"
0014 
0015 #include "GenerateCanonical.hh"
0016 
0017 namespace celeritas
0018 {
0019 //---------------------------------------------------------------------------//
0020 /*!
0021  * On-the-fly selection of a weighted discrete distribution.
0022  *
0023  * This algorithm encapsulates the loop for sampling from distributions by
0024  * integer index or by OpaqueId. Edge cases are thoroughly tested (it will
0025  * never iterate off the end, even for incorrect values of the "total"
0026  * probability/xs), and it uses one fewer register than the typical
0027  * accumulation algorithm. When building with debug checking, the constructor
0028  * asserts that the provided "total" value is consistent.
0029  *
0030  * The given function *must* return a consistent value for the same given
0031  * argument.
0032  *
0033  * \code
0034     auto select_el = make_selector(
0035         [](ElementId i) { return xs[i.get()]; },
0036         ElementId{num_elements()},
0037         tot_xs);
0038     ElementId el = select_el(rng);
0039    \endcode
0040  * or
0041  * \code
0042     auto select_val = make_selector([](size_type i) { return pdf[i]; },
0043                                     pdf.size());
0044     size_type idx = select_val(rng);
0045    \endcode
0046  */
0047 template<class F, class T>
0048 class Selector
0049 {
0050   public:
0051     //!@{
0052     //! \name Type aliases
0053     using value_type = T;
0054     using real_type = typename std::invoke_result<F, value_type>::type;
0055     //!@}
0056 
0057   public:
0058     // Construct with function, size, and accumulated value
0059     inline CELER_FUNCTION Selector(F&& eval, value_type size, real_type total);
0060 
0061     // Sample from the distribution
0062     template<class Engine>
0063     inline CELER_FUNCTION T operator()(Engine& rng) const;
0064 
0065   private:
0066     using IterT = RangeIter<T>;
0067 
0068     F eval_;
0069     IterT last_;
0070     real_type total_;
0071 
0072     // Total value, for debug checking
0073     inline CELER_FUNCTION real_type debug_accumulated_total() const;
0074 };
0075 
0076 //---------------------------------------------------------------------------//
0077 /*!
0078  * Create a selector object from a function and total accumulated value.
0079  */
0080 template<class F, class T>
0081 CELER_FUNCTION Selector<F, T>
0082 make_selector(F&& func, T size, decltype(func(size)) total = 1)
0083 {
0084     return {celeritas::forward<F>(func), size, total};
0085 }
0086 
0087 //---------------------------------------------------------------------------//
0088 // INLINE DEFINITIONS
0089 //---------------------------------------------------------------------------//
0090 /*!
0091  * Construct with function, size, and accumulated value.
0092  */
0093 template<class F, class T>
0094 CELER_FUNCTION
0095 Selector<F, T>::Selector(F&& eval, value_type size, real_type total)
0096     : eval_{celeritas::forward<F>(eval)}, last_{size}, total_{total}
0097 {
0098     CELER_EXPECT(last_ != IterT{});
0099     CELER_EXPECT(total_ > 0);
0100     CELER_EXPECT(
0101         celeritas::soft_equal(this->debug_accumulated_total(), total_));
0102 
0103     // Don't accumulate the last value except to assert that the 'total'
0104     // isn't out-of-bounds
0105     --last_;
0106 }
0107 
0108 //---------------------------------------------------------------------------//
0109 /*!
0110  * Sample from the distribution.
0111  */
0112 template<class F, class T>
0113 template<class Engine>
0114 CELER_FUNCTION T Selector<F, T>::operator()(Engine& rng) const
0115 {
0116     real_type accum = -total_ * generate_canonical(rng);
0117     for (IterT iter{}; iter != last_; ++iter)
0118     {
0119         accum += eval_(*iter);
0120         if (accum > 0)
0121             return *iter;
0122     }
0123 
0124     return *last_;
0125 }
0126 
0127 //---------------------------------------------------------------------------//
0128 /*!
0129  * Accumulate total value for debug checking.
0130  *
0131  * This should *only* be used in the constructor before last_ is decremented.
0132  */
0133 template<class F, class T>
0134 CELER_FUNCTION auto Selector<F, T>::debug_accumulated_total() const -> real_type
0135 {
0136     real_type accum = 0;
0137     for (IterT iter{}; iter != last_; ++iter)
0138     {
0139         accum += eval_(*iter);
0140     }
0141     return accum;
0142 }
0143 
0144 //---------------------------------------------------------------------------//
0145 }  // namespace celeritas