![]() |
|
|||
File indexing completed on 2025-09-14 08:50:58
0001 //------------------------------- -*- C++ -*- -------------------------------// 0002 // Copyright Celeritas contributors: see top-level COPYRIGHT file for details 0003 // SPDX-License-Identifier: (Apache-2.0 OR MIT) 0004 //---------------------------------------------------------------------------// 0005 //! \file corecel/random/distribution/Selector.hh 0006 //---------------------------------------------------------------------------// 0007 #pragma once 0008 0009 #include <type_traits> 0010 0011 #include "corecel/cont/Range.hh" 0012 #include "corecel/math/Algorithms.hh" 0013 #include "corecel/math/SoftEqual.hh" 0014 0015 #include "GenerateCanonical.hh" 0016 0017 namespace celeritas 0018 { 0019 //---------------------------------------------------------------------------// 0020 /*! 0021 * On-the-fly selection of a weighted discrete distribution. 0022 * 0023 * This algorithm encapsulates the loop for sampling from distributions by 0024 * integer index or by OpaqueId. Edge cases are thoroughly tested (it will 0025 * never iterate off the end, even for incorrect values of the "total" 0026 * probability/xs), and it uses one fewer register than the typical 0027 * accumulation algorithm. When building with debug checking, the constructor 0028 * asserts that the provided "total" value is consistent. 0029 * 0030 * The given function *must* return a consistent value for the same given 0031 * argument. 0032 * 0033 * \code 0034 auto select_el = make_selector( 0035 [](ElementId i) { return xs[i.get()]; }, 0036 ElementId{num_elements()}, 0037 tot_xs); 0038 ElementId el = select_el(rng); 0039 \endcode 0040 * or 0041 * \code 0042 auto select_val = make_selector([](size_type i) { return pdf[i]; }, 0043 pdf.size()); 0044 size_type idx = select_val(rng); 0045 \endcode 0046 */ 0047 template<class F, class T> 0048 class Selector 0049 { 0050 public: 0051 //!@{ 0052 //! \name Type aliases 0053 using value_type = T; 0054 using real_type = typename std::invoke_result<F, value_type>::type; 0055 //!@} 0056 0057 public: 0058 // Construct with function, size, and accumulated value 0059 inline CELER_FUNCTION Selector(F&& eval, value_type size, real_type total); 0060 0061 // Sample from the distribution 0062 template<class Engine> 0063 inline CELER_FUNCTION T operator()(Engine& rng) const; 0064 0065 private: 0066 using IterT = RangeIter<T>; 0067 0068 F eval_; 0069 IterT last_; 0070 real_type total_; 0071 0072 // Total value, for debug checking 0073 inline CELER_FUNCTION real_type debug_accumulated_total() const; 0074 }; 0075 0076 //---------------------------------------------------------------------------// 0077 /*! 0078 * Create a selector object from a function and total accumulated value. 0079 */ 0080 template<class F, class T> 0081 CELER_FUNCTION Selector<F, T> 0082 make_selector(F&& func, T size, decltype(func(size)) total = 1) 0083 { 0084 return {celeritas::forward<F>(func), size, total}; 0085 } 0086 0087 //---------------------------------------------------------------------------// 0088 // INLINE DEFINITIONS 0089 //---------------------------------------------------------------------------// 0090 /*! 0091 * Construct with function, size, and accumulated value. 0092 */ 0093 template<class F, class T> 0094 CELER_FUNCTION 0095 Selector<F, T>::Selector(F&& eval, value_type size, real_type total) 0096 : eval_{celeritas::forward<F>(eval)}, last_{size}, total_{total} 0097 { 0098 CELER_EXPECT(last_ != IterT{}); 0099 CELER_EXPECT(total_ > 0); 0100 CELER_EXPECT( 0101 celeritas::soft_equal(this->debug_accumulated_total(), total_)); 0102 0103 // Don't accumulate the last value except to assert that the 'total' 0104 // isn't out-of-bounds 0105 --last_; 0106 } 0107 0108 //---------------------------------------------------------------------------// 0109 /*! 0110 * Sample from the distribution. 0111 */ 0112 template<class F, class T> 0113 template<class Engine> 0114 CELER_FUNCTION T Selector<F, T>::operator()(Engine& rng) const 0115 { 0116 real_type accum = -total_ * generate_canonical(rng); 0117 for (IterT iter{}; iter != last_; ++iter) 0118 { 0119 accum += eval_(*iter); 0120 if (accum > 0) 0121 return *iter; 0122 } 0123 0124 return *last_; 0125 } 0126 0127 //---------------------------------------------------------------------------// 0128 /*! 0129 * Accumulate total value for debug checking. 0130 * 0131 * This should *only* be used in the constructor before last_ is decremented. 0132 */ 0133 template<class F, class T> 0134 CELER_FUNCTION auto Selector<F, T>::debug_accumulated_total() const -> real_type 0135 { 0136 real_type accum = 0; 0137 for (IterT iter{}; iter != last_; ++iter) 0138 { 0139 accum += eval_(*iter); 0140 } 0141 return accum; 0142 } 0143 0144 //---------------------------------------------------------------------------// 0145 } // namespace celeritas
[ Source navigation ] | [ Diff markup ] | [ Identifier search ] | [ general search ] |
This page was automatically generated by the 2.3.7 LXR engine. The LXR team |
![]() ![]() |