File indexing completed on 2026-05-27 07:24:23
0001
0002
0003
0004
0005
0006
0007
0008
0009
0010 #include "detray/test/utils/landau_distribution.hpp"
0011
0012
0013 #include <gtest/gtest.h>
0014
0015
0016 #include <algorithm>
0017 #include <iterator>
0018
0019 using namespace detray;
0020
0021
0022
0023
0024
0025 template <typename T>
0026 class detray_simulation_LandauSamplingValidation : public ::testing::Test {
0027 public:
0028 using scalar_type = T;
0029
0030
0031 std::size_t get_index(const scalar_type value) const {
0032 return static_cast<std::size_t>((value - min) / bin_size);
0033 }
0034
0035
0036
0037 constexpr static const scalar_type mu = 0.f;
0038 constexpr static const scalar_type sigma = 1.f;
0039 constexpr static const scalar_type mpv = -0.22278f;
0040
0041
0042 constexpr static const double bin_size = 0.05;
0043 constexpr static const double min = -2.;
0044 constexpr static const double max = 2.;
0045 constexpr static const std::size_t n_bins =
0046 static_cast<std::size_t>((max - min) / bin_size);
0047 };
0048
0049
// Scalar types the typed test suite below is instantiated for
using TestTypes = ::testing::Types<float, double>;
// NOTE(review): the trailing comma leaves the optional third macro argument
// empty - presumably to avoid variadic-macro warnings; confirm before removing
TYPED_TEST_SUITE(detray_simulation_LandauSamplingValidation, TestTypes, );
0052
0053 TYPED_TEST(detray_simulation_LandauSamplingValidation, landau_sampling) {
0054
0055 std::random_device rd{};
0056 std::mt19937_64 generator{rd()};
0057 generator.seed(0u);
0058
0059
0060 landau_distribution<typename TestFixture::scalar_type> ld;
0061
0062
0063 EXPECT_EQ(this->n_bins, 80u);
0064
0065
0066 std::vector<int> counter(this->n_bins, 0);
0067
0068
0069 std::size_t n_samples = 10000000u;
0070 const auto minf = static_cast<typename TestFixture::scalar_type>(this->min);
0071 const auto maxf = static_cast<typename TestFixture::scalar_type>(this->max);
0072 for (std::size_t i = 0u; i < n_samples; i++) {
0073 const auto sa = ld(generator, this->mu, this->sigma);
0074
0075 if (sa > minf && sa < maxf) {
0076 const std::size_t index = this->get_index(sa);
0077 counter[index]++;
0078 }
0079 }
0080
0081 const std::size_t mpv_index = this->get_index(this->mpv);
0082
0083 const auto max_index = static_cast<std::size_t>(
0084 std::distance(counter.begin(), std::ranges::max_element(counter)));
0085
0086
0087
0088 EXPECT_EQ(mpv_index, 35u);
0089
0090 EXPECT_TRUE(max_index == mpv_index || max_index == mpv_index - 1u);
0091 }