![]() |
|
|||
File indexing completed on 2025-02-22 10:31:30
0001 //----------------------------------*-C++-*----------------------------------// 0002 // Copyright 2021-2024 UT-Battelle, LLC, and other Celeritas developers. 0003 // See the top-level COPYRIGHT file for details. 0004 // SPDX-License-Identifier: (Apache-2.0 OR MIT) 0005 //---------------------------------------------------------------------------// 0006 //! \file celeritas/random/Selector.hh 0007 //---------------------------------------------------------------------------// 0008 #pragma once 0009 0010 #include <type_traits> 0011 0012 #include "corecel/cont/Range.hh" 0013 #include "corecel/math/Algorithms.hh" 0014 #include "corecel/math/SoftEqual.hh" 0015 0016 #include "distribution/GenerateCanonical.hh" 0017 0018 namespace celeritas 0019 { 0020 //---------------------------------------------------------------------------// 0021 /*! 0022 * On-the-fly selection of a weighted discrete distribution. 0023 * 0024 * This algorithm encapsulates the loop for sampling from distributions by 0025 * integer index or by OpaqueId. Edge cases are thoroughly tested (it will 0026 * never iterate off the end, even for incorrect values of the "total" 0027 * probability/xs), and it uses one fewer register than the typical 0028 * accumulation algorithm. When building with debug checking, the constructor 0029 * asserts that the provided "total" value is consistent. 0030 * 0031 * The given function *must* return a consistent value for the same given 0032 * argument. 0033 * 0034 * \code 0035 auto select_el = make_selector( 0036 [](ElementId i) { return xs[i.get()]; }, 0037 ElementId{num_elements()}, 0038 tot_xs); 0039 ElementId el = select_el(rng); 0040 \endcode 0041 * or 0042 * \code 0043 auto select_val = make_selector([](size_type i) { return pdf[i]; }, 0044 pdf.size()); 0045 size_type idx = select_val(rng); 0046 \endcode 0047 */ 0048 template<class F, class T> 0049 class Selector 0050 { 0051 public: 0052 //!@{ 0053 //! \name Type aliases 0054 using value_type = T; 0055 using real_type = typename std::invoke_result<F, value_type>::type; 0056 //!@} 0057 0058 public: 0059 // Construct with function, size, and accumulated value 0060 inline CELER_FUNCTION Selector(F&& eval, value_type size, real_type total); 0061 0062 // Sample from the distribution 0063 template<class Engine> 0064 inline CELER_FUNCTION T operator()(Engine& rng) const; 0065 0066 private: 0067 using IterT = RangeIter<T>; 0068 0069 F eval_; 0070 IterT last_; 0071 real_type total_; 0072 0073 // Total value, for debug checking 0074 inline CELER_FUNCTION real_type debug_accumulated_total() const; 0075 }; 0076 0077 //---------------------------------------------------------------------------// 0078 /*! 0079 * Create a selector object from a function and total accumulated value. 0080 */ 0081 template<class F, class T> 0082 CELER_FUNCTION Selector<F, T> 0083 make_selector(F&& func, T size, decltype(func(size)) total = 1) 0084 { 0085 return {celeritas::forward<F>(func), size, total}; 0086 } 0087 0088 //---------------------------------------------------------------------------// 0089 // INLINE DEFINITIONS 0090 //---------------------------------------------------------------------------// 0091 /*! 0092 * Construct with function, size, and accumulated value. 0093 */ 0094 template<class F, class T> 0095 CELER_FUNCTION 0096 Selector<F, T>::Selector(F&& eval, value_type size, real_type total) 0097 : eval_{celeritas::forward<F>(eval)}, last_{size}, total_{total} 0098 { 0099 CELER_EXPECT(last_ != IterT{}); 0100 CELER_EXPECT(total_ > 0); 0101 CELER_EXPECT( 0102 celeritas::soft_equal(this->debug_accumulated_total(), total_)); 0103 0104 // Don't accumulate the last value except to assert that the 'total' 0105 // isn't out-of-bounds 0106 --last_; 0107 } 0108 0109 //---------------------------------------------------------------------------// 0110 /*! 0111 * Sample from the distribution. 0112 */ 0113 template<class F, class T> 0114 template<class Engine> 0115 CELER_FUNCTION T Selector<F, T>::operator()(Engine& rng) const 0116 { 0117 real_type accum = -total_ * generate_canonical(rng); 0118 for (IterT iter{}; iter != last_; ++iter) 0119 { 0120 accum += eval_(*iter); 0121 if (accum > 0) 0122 return *iter; 0123 } 0124 0125 return *last_; 0126 } 0127 0128 //---------------------------------------------------------------------------// 0129 /*! 0130 * Accumulate total value for debug checking. 0131 * 0132 * This should *only* be used in the constructor before last_ is decremented. 0133 */ 0134 template<class F, class T> 0135 CELER_FUNCTION auto Selector<F, T>::debug_accumulated_total() const -> real_type 0136 { 0137 real_type accum = 0; 0138 for (IterT iter{}; iter != last_; ++iter) 0139 { 0140 accum += eval_(*iter); 0141 } 0142 return accum; 0143 } 0144 0145 //---------------------------------------------------------------------------// 0146 } // namespace celeritas
[ Source navigation ] | [ Diff markup ] | [ Identifier search ] | [ general search ] |
This page was automatically generated by the 2.3.7 LXR engine. The LXR team |
![]() ![]() |