File indexing completed on 2026-05-27 07:24:14
0001
0002
0003
0004
0005
0006
0007
0008
0009 #pragma once
0010
0011
0012 #include "detray/definitions/detail/qualifiers.hpp"
0013
0014
0015 #include "detray/test/device/device_fixture.hpp"
0016 #include "detray/test/framework/types.hpp"
0017
0018
0019 #include <vecmem/containers/data/vector_view.hpp>
0020 #include <vecmem/containers/device_vector.hpp>
0021 #include <vecmem/containers/vector.hpp>
0022 #include <vecmem/memory/memory_resource.hpp>
0023
0024
0025 #include <gtest/gtest.h>
0026
0027
0028 #include <cmath>
0029 #include <memory>
0030
0031 namespace detray::test {
0032
0033
0034 template <detray::concepts::algebra A>
0035 class transform_fixture : public device_fixture<dscalar<A>> {
0036 using scalar_t = dscalar<A>;
0037 using vector3_t = dvector3D<A>;
0038
0039 using result_t = scalar_t;
0040 using base_fixture = device_fixture<result_t>;
0041
0042 public:
0043
0044 transform_fixture(vecmem::memory_resource& mr) : base_fixture(mr) {}
0045
0046 protected:
0047
0048 virtual void SetUp() override {
0049
0050 base_fixture::SetUp();
0051
0052
0053 m_t1 = std::make_unique<vecmem::vector<vector3_t>>(this->size(),
0054 &this->resource());
0055 m_t2 = std::make_unique<vecmem::vector<vector3_t>>(this->size(),
0056 &this->resource());
0057 m_t3 = std::make_unique<vecmem::vector<vector3_t>>(this->size(),
0058 &this->resource());
0059
0060 m_v1 = std::make_unique<vecmem::vector<vector3_t>>(this->size(),
0061 &this->resource());
0062 m_v2 = std::make_unique<vecmem::vector<vector3_t>>(this->size(),
0063 &this->resource());
0064
0065
0066 for (std::size_t i = 0; i < this->size(); ++i) {
0067 m_t1->at(i) = {static_cast<scalar_t>(1.1), static_cast<scalar_t>(2.2),
0068 static_cast<scalar_t>(3.3)};
0069 m_t2->at(i) = {static_cast<scalar_t>(4.4), static_cast<scalar_t>(5.5),
0070 static_cast<scalar_t>(6.6)};
0071 m_t3->at(i) = {static_cast<scalar_t>(7.7), static_cast<scalar_t>(8.8),
0072 static_cast<scalar_t>(9.9)};
0073
0074 m_v1->at(i) = {static_cast<scalar_t>(i * 0.6),
0075 static_cast<scalar_t>((i + 1) * 1.2),
0076 static_cast<scalar_t>((i + 2) * 1.3)};
0077 m_v2->at(i) = {static_cast<scalar_t>((i + 1) * 1.8),
0078 static_cast<scalar_t>(i * 2.3),
0079 static_cast<scalar_t>((i + 2) * 3.4)};
0080 }
0081 }
0082
0083
0084 virtual void TearDown() override {
0085
0086 m_t1.reset();
0087 m_t2.reset();
0088 m_t3.reset();
0089 m_v1.reset();
0090 m_v2.reset();
0091
0092
0093 base_fixture::TearDown();
0094 }
0095
0096
0097
0098
0099 std::unique_ptr<vecmem::vector<vector3_t>> m_t1, m_t2, m_t3;
0100 std::unique_ptr<vecmem::vector<vector3_t>> m_v1, m_v2;
0101
0102
0103
0104 };
0105
0106
0107 template <detray::concepts::algebra A>
0108 class transform3_ops_functor {
0109 using scalar_t = dscalar<A>;
0110 using point3_t = dpoint3D<A>;
0111 using vector3_t = dvector3D<A>;
0112 using transform3_t = dtransform3D<A>;
0113
0114 public:
0115 DETRAY_HOST_DEVICE void operator()(
0116 std::size_t i, vecmem::data::vector_view<const vector3_t> t1,
0117 vecmem::data::vector_view<const vector3_t> t2,
0118 vecmem::data::vector_view<const vector3_t> t3,
0119 vecmem::data::vector_view<const vector3_t> a,
0120 vecmem::data::vector_view<const vector3_t> b,
0121 vecmem::data::vector_view<scalar_t> output) const {
0122
0123 vecmem::device_vector<const vector3_t> vec_t1(t1), vec_t2(t2), vec_t3(t3),
0124 vec_a(a), vec_b(b);
0125 vecmem::device_vector<scalar_t> vec_output(output);
0126
0127
0128 auto ii = static_cast<typename decltype(vec_output)::size_type>(i);
0129 vec_output[ii] = transform3_ops(vec_t1[ii], vec_t2[ii], vec_t3[ii],
0130 vec_a[ii], vec_b[ii]);
0131 }
0132
0133 private:
0134
0135 DETRAY_HOST_DEVICE
0136 scalar_t transform3_ops(vector3_t t1, vector3_t t2, vector3_t t3, vector3_t a,
0137 vector3_t b) const {
0138 using namespace algebra;
0139
0140 transform3_t tr1(t1, t2, t3);
0141 transform3_t tr2;
0142 tr2 = tr1;
0143
0144 point3_t translation = tr2.translation();
0145
0146 point3_t gpoint = tr2.point_to_global(a);
0147 point3_t lpoint = tr2.point_to_local(b);
0148
0149 vector3_t gvec = tr2.vector_to_global(a);
0150 vector3_t lvec = tr2.vector_to_local(b);
0151
0152 return {detray::vector::norm(translation) + detray::vector::perp(gpoint) +
0153 detray::vector::phi(lpoint) + detray::vector::dot(gvec, lvec)};
0154 }
0155 };
0156
0157 }