Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-02-22 10:31:30

0001 //----------------------------------*-C++-*----------------------------------//
0002 // Copyright 2021-2024 UT-Battelle, LLC, and other Celeritas developers.
0003 // See the top-level COPYRIGHT file for details.
0004 // SPDX-License-Identifier: (Apache-2.0 OR MIT)
0005 //---------------------------------------------------------------------------//
0006 //! \file celeritas/random/Selector.hh
0007 //---------------------------------------------------------------------------//
0008 #pragma once
0009 
0010 #include <type_traits>
0011 
0012 #include "corecel/cont/Range.hh"
0013 #include "corecel/math/Algorithms.hh"
0014 #include "corecel/math/SoftEqual.hh"
0015 
0016 #include "distribution/GenerateCanonical.hh"
0017 
0018 namespace celeritas
0019 {
0020 //---------------------------------------------------------------------------//
0021 /*!
0022  * On-the-fly selection of a weighted discrete distribution.
0023  *
0024  * This algorithm encapsulates the loop for sampling from distributions by
0025  * integer index or by OpaqueId. Edge cases are thoroughly tested (it will
0026  * never iterate off the end, even for incorrect values of the "total"
0027  * probability/xs), and it uses one fewer register than the typical
0028  * accumulation algorithm. When building with debug checking, the constructor
0029  * asserts that the provided "total" value is consistent.
0030  *
0031  * The given function *must* return a consistent value for the same given
0032  * argument.
0033  *
0034  * \code
0035     auto select_el = make_selector(
0036         [](ElementId i) { return xs[i.get()]; },
0037         ElementId{num_elements()},
0038         tot_xs);
0039     ElementId el = select_el(rng);
0040    \endcode
0041  * or
0042  * \code
0043     auto select_val = make_selector([](size_type i) { return pdf[i]; },
0044                                     pdf.size());
0045     size_type idx = select_val(rng);
0046    \endcode
0047  */
0048 template<class F, class T>
0049 class Selector
0050 {
0051   public:
0052     //!@{
0053     //! \name Type aliases
0054     using value_type = T;
0055     using real_type = typename std::invoke_result<F, value_type>::type;
0056     //!@}
0057 
0058   public:
0059     // Construct with function, size, and accumulated value
0060     inline CELER_FUNCTION Selector(F&& eval, value_type size, real_type total);
0061 
0062     // Sample from the distribution
0063     template<class Engine>
0064     inline CELER_FUNCTION T operator()(Engine& rng) const;
0065 
0066   private:
0067     using IterT = RangeIter<T>;
0068 
0069     F eval_;
0070     IterT last_;
0071     real_type total_;
0072 
0073     // Total value, for debug checking
0074     inline CELER_FUNCTION real_type debug_accumulated_total() const;
0075 };
0076 
0077 //---------------------------------------------------------------------------//
0078 /*!
0079  * Create a selector object from a function and total accumulated value.
0080  */
0081 template<class F, class T>
0082 CELER_FUNCTION Selector<F, T>
0083 make_selector(F&& func, T size, decltype(func(size)) total = 1)
0084 {
0085     return {celeritas::forward<F>(func), size, total};
0086 }
0087 
0088 //---------------------------------------------------------------------------//
0089 // INLINE DEFINITIONS
0090 //---------------------------------------------------------------------------//
0091 /*!
0092  * Construct with function, size, and accumulated value.
0093  */
0094 template<class F, class T>
0095 CELER_FUNCTION
0096 Selector<F, T>::Selector(F&& eval, value_type size, real_type total)
0097     : eval_{celeritas::forward<F>(eval)}, last_{size}, total_{total}
0098 {
0099     CELER_EXPECT(last_ != IterT{});
0100     CELER_EXPECT(total_ > 0);
0101     CELER_EXPECT(
0102         celeritas::soft_equal(this->debug_accumulated_total(), total_));
0103 
0104     // Don't accumulate the last value except to assert that the 'total'
0105     // isn't out-of-bounds
0106     --last_;
0107 }
0108 
0109 //---------------------------------------------------------------------------//
0110 /*!
0111  * Sample from the distribution.
0112  */
0113 template<class F, class T>
0114 template<class Engine>
0115 CELER_FUNCTION T Selector<F, T>::operator()(Engine& rng) const
0116 {
0117     real_type accum = -total_ * generate_canonical(rng);
0118     for (IterT iter{}; iter != last_; ++iter)
0119     {
0120         accum += eval_(*iter);
0121         if (accum > 0)
0122             return *iter;
0123     }
0124 
0125     return *last_;
0126 }
0127 
0128 //---------------------------------------------------------------------------//
0129 /*!
0130  * Accumulate total value for debug checking.
0131  *
0132  * This should *only* be used in the constructor before last_ is decremented.
0133  */
0134 template<class F, class T>
0135 CELER_FUNCTION auto Selector<F, T>::debug_accumulated_total() const -> real_type
0136 {
0137     real_type accum = 0;
0138     for (IterT iter{}; iter != last_; ++iter)
0139     {
0140         accum += eval_(*iter);
0141     }
0142     return accum;
0143 }
0144 
0145 //---------------------------------------------------------------------------//
0146 }  // namespace celeritas