Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-09-13 08:53:32

0001 //------------------------------- -*- C++ -*- -------------------------------//
0002 // Copyright Celeritas contributors: see top-level COPYRIGHT file for details
0003 // SPDX-License-Identifier: (Apache-2.0 OR MIT)
0004 //---------------------------------------------------------------------------//
0005 //! \file celeritas/grid/SplineCalculator.hh
0006 //---------------------------------------------------------------------------//
0007 #pragma once
0008 
0009 #include <cmath>
0010 
0011 #include "corecel/grid/Interpolator.hh"
0012 #include "corecel/grid/UniformGrid.hh"
0013 #include "corecel/grid/UniformGridData.hh"
0014 #include "corecel/math/Quantity.hh"
0015 #include "celeritas/Quantities.hh"
0016 
0017 namespace celeritas
0018 {
0019 //---------------------------------------------------------------------------//
0020 /*!
0021  * Find and interpolate cross sections on a uniform log grid with an input
0022  * spline-order.
0023  *
0024  * \todo Currently this is hard-coded to use "cross section grid data"
0025  * which have energy coordinates uniform in log space. This should
0026  * be expanded to handle multiple parameterizations of the energy grid (e.g.,
0027  * arbitrary spacing needed for the Livermore sampling) and of the value
0028  * interpolation (e.g. log interpolation). It might also make sense to get rid
0029  * of the "prime energy" and just use log-log interpolation instead, or do a
0030  * piecewise change in the interpolation instead of storing the cross section
0031  * scaled by the energy.
0032  *
0033  * \code
0034     SplineCalculator calc_xs(xs_grid, xs_params.reals);
0035     real_type xs = calc_xs(particle);
0036    \endcode
0037  */
0038 class SplineCalculator
0039 {
0040   public:
0041     //!@{
0042     //! \name Type aliases
0043     using Energy = units::MevEnergy;
0044     using Values
0045         = Collection<real_type, Ownership::const_reference, MemSpace::native>;
0046     //!@}
0047 
0048   public:
0049     // Construct from state-independent data
0050     inline CELER_FUNCTION
0051     SplineCalculator(UniformGridRecord const& grid, Values const& reals);
0052 
0053     // Find and interpolate from the energy
0054     inline CELER_FUNCTION real_type operator()(Energy energy) const;
0055 
0056     // Get the value at the given index
0057     inline CELER_FUNCTION real_type operator[](size_type index) const;
0058 
0059     // Get the minimum energy
0060     CELER_FUNCTION Energy energy_min() const
0061     {
0062         return Energy(std::exp(loge_grid_.front()));
0063     }
0064 
0065     // Get the maximum energy
0066     CELER_FUNCTION Energy energy_max() const
0067     {
0068         return Energy(std::exp(loge_grid_.back()));
0069     }
0070 
0071   private:
0072     UniformGridRecord const& data_;
0073     Values const& reals_;
0074     UniformGrid loge_grid_;
0075 
0076     CELER_FORCEINLINE_FUNCTION real_type interpolate(real_type energy,
0077                                                      size_type low_idx,
0078                                                      size_type high_idx) const;
0079 };
0080 
0081 //---------------------------------------------------------------------------//
0082 // INLINE DEFINITIONS
0083 //---------------------------------------------------------------------------//
0084 /*!
0085  * Construct from cross section data.
0086  */
0087 CELER_FUNCTION
0088 SplineCalculator::SplineCalculator(UniformGridRecord const& grid,
0089                                    Values const& reals)
0090     : data_(grid), reals_(reals), loge_grid_(data_.grid)
0091 {
0092     CELER_EXPECT(data_);
0093 }
0094 
0095 //---------------------------------------------------------------------------//
0096 /*!
0097  * Calculate the cross section.
0098  *
0099  * If needed, we can add a "log(energy/MeV)" accessor if we constantly reuse
0100  * that value and don't want to repeat the `std::log` operation.
0101  */
0102 CELER_FUNCTION real_type SplineCalculator::operator()(Energy energy) const
0103 {
0104     real_type const loge = std::log(energy.value());
0105 
0106     // Snap out-of-bounds values to closest grid points
0107     if (loge <= loge_grid_.front())
0108     {
0109         return (*this)[0];
0110     }
0111     if (loge >= loge_grid_.back())
0112     {
0113         return (*this)[loge_grid_.size() - 1];
0114     }
0115 
0116     // Locate the energy bin
0117     size_type lower_idx = loge_grid_.find(loge);
0118     CELER_ASSERT(lower_idx + 1 < loge_grid_.size());
0119 
0120     // Number of grid indices away from the specified energy that need to be
0121     // checked in both directions
0122     size_type order_steps = data_.spline_order / 2 + 1;
0123 
0124     // True bounding indices of the grid that will be checked.
0125     // If the interpolation requests out-of-bounds indices, clip the
0126     // extents. This will reduce the order of the interpolation
0127     // TODO: instead of clipping the bounds, alter both the low and high
0128     // index to keep the range just shifted down
0129 
0130     size_type true_low_idx;
0131     if (lower_idx >= order_steps - 1)
0132     {
0133         true_low_idx = lower_idx - order_steps + 1;
0134     }
0135     else
0136     {
0137         true_low_idx = 0;
0138     }
0139     size_type true_high_idx
0140         = min(lower_idx + order_steps + 1, loge_grid_.size());
0141 
0142     if (data_.spline_order % 2 == 0)
0143     {
0144         // If the requested interpolation order is even, a direction must be
0145         // selected to interpolate to
0146         real_type low_dist = std::fabs(loge - loge_grid_[lower_idx]);
0147         real_type high_dist = std::fabs(loge_grid_[lower_idx + 1] - loge);
0148 
0149         if (true_high_idx - true_low_idx > data_.spline_order + 1)
0150         {
0151             // If we already clipped based on the bounding indices, don't clip
0152             // again
0153             if (low_dist > high_dist)
0154             {
0155                 true_low_idx += 1;
0156             }
0157             else
0158             {
0159                 true_high_idx -= 1;
0160             }
0161         }
0162     }
0163     return this->interpolate(energy.value(), true_low_idx, true_high_idx);
0164 }
0165 
0166 //---------------------------------------------------------------------------//
0167 /*!
0168  * Get the tabulated value at the given index.
0169  */
0170 CELER_FUNCTION real_type SplineCalculator::operator[](size_type index) const
0171 {
0172     CELER_EXPECT(index < data_.value.size());
0173     return reals_[data_.value[index]];
0174 }
0175 
0176 //---------------------------------------------------------------------------//
0177 /*!
0178  * Interpolate the value using spline.
0179  */
0180 CELER_FUNCTION real_type SplineCalculator::interpolate(real_type energy,
0181                                                        size_type low_idx,
0182                                                        size_type high_idx) const
0183 {
0184     CELER_EXPECT(high_idx <= loge_grid_.size());
0185     real_type result = 0;
0186 
0187     // Outer loop over indices for contributing to the result
0188     for (size_type outer_idx = low_idx; outer_idx < high_idx; ++outer_idx)
0189     {
0190         real_type outer_e = std::exp(loge_grid_[outer_idx]);
0191         real_type num = 1;
0192         real_type denom = 1;
0193 
0194         // Inner loop over indices for determining the weight
0195         for (size_type inner_idx = low_idx; inner_idx < high_idx; ++inner_idx)
0196         {
0197             // Don't contribute for inner and outer index the same
0198             if (inner_idx != outer_idx)
0199             {
0200                 real_type inner_e = std::exp(loge_grid_[inner_idx]);
0201                 num *= (energy - inner_e);
0202                 denom *= (outer_e - inner_e);
0203             }
0204         }
0205         result += (num / denom) * (*this)[outer_idx];
0206     }
0207     CELER_ENSURE(result >= 0);
0208     return result;
0209 }
0210 
0211 //---------------------------------------------------------------------------//
0212 }  // namespace celeritas