File indexing completed on 2026-05-27 07:23:54
0001
0002
0003
0004
0005
0006
0007
0008 #pragma once
0009
0010
0011 #if defined(__GNUC__) && !defined(__clang__)
0012 #pragma GCC diagnostic warning "-Wdeprecated-declarations"
0013 #endif
0014
0015
0016 #include "detray/algebra/common/array_operators.hpp"
0017 #include "detray/algebra/concepts.hpp"
0018 #include "detray/definitions/detail/qualifiers.hpp"
0019
0020
0021 #include <array>
0022 #include <cstddef>
0023 #include <initializer_list>
0024 #include <type_traits>
0025 #include <utility>
0026
0027 namespace detray::algebra {
0028
0029 namespace storage {
0030
0031 namespace detail {
0032
0033
0034
0035 DETRAY_HOST_DEVICE
0036 consteval std::size_t nearest_power_of_two(std::size_t min_value,
0037 std::size_t current_value) {
0038
0039
0040 return min_value <= current_value
0041 ? current_value
0042 : nearest_power_of_two(min_value, current_value * 2u);
0043 }
0044
0045 }
0046
0047
0048
0049 template <std::size_t N, concepts::scalar scalar_t,
0050 template <typename, std::size_t> class array_t>
0051 class DETRAY_ALIGN(
0052 alignof(array_t<scalar_t, detail::nearest_power_of_two(N, 2u)>)) vector {
0053 public:
0054
0055 DETRAY_HOST_DEVICE
0056 static consteval std::size_t simd_size() {
0057 return concepts::value<scalar_t> ? detail::nearest_power_of_two(N, 2u) : N;
0058 }
0059
0060
0061 using scalar_type = scalar_t;
0062
0063 using array_type = array_t<scalar_t, simd_size()>;
0064
0065
0066 DETRAY_HOST_DEVICE
0067 constexpr vector() {
0068 if constexpr (!concepts::simd_scalar<scalar_type>) {
0069 zero_fill(std::make_index_sequence<simd_size()>{});
0070 }
0071 }
0072
0073
0074
0075
0076
0077
0078 template <typename... Scalars>
0079 requires(concepts::simd_scalar<scalar_t> && (sizeof...(Scalars) == N) &&
0080 ((concepts::simd_scalar<Scalars> ||
0081 std::convertible_to<Scalars, scalar_t>) &&
0082 ...))
0083 DETRAY_HOST_DEVICE constexpr vector(Scalars &&...scals)
0084 : m_data{std::forward<Scalars>(scals)...} {}
0085
0086
0087
0088
0089 template <typename... Values>
0090 requires(!concepts::simd_scalar<scalar_t> && (sizeof...(Values) > 1) &&
0091 ((concepts::value<Values> ||
0092 std::convertible_to<Values, scalar_t>) &&
0093 ...))
0094 DETRAY_HOST_DEVICE constexpr vector(Values &&...vals) {
0095 static_assert(sizeof...(Values) <= N);
0096
0097
0098 if constexpr ((simd_size() - N) == 1) {
0099 m_data = {std::forward<Values>(vals)..., 0.f};
0100 } else if constexpr ((simd_size() - N) == 2) {
0101 m_data = {std::forward<Values>(vals)..., 0.f, 0.f};
0102 } else if constexpr (sizeof...(Values) < simd_size()) {
0103
0104 zero_fill(std::make_index_sequence<simd_size() - sizeof...(Values)>{});
0105 } else {
0106 m_data = {std::forward<Values>(vals)...};
0107 }
0108 }
0109
0110
0111
0112
0113
0114 DETRAY_HOST_DEVICE
0115 constexpr vector(array_type &&vals) : m_data{std::move(vals)} {}
0116
0117 DETRAY_HOST_DEVICE
0118 constexpr vector(const array_type &vals) : m_data{vals} {}
0119
0120
0121
0122
0123
0124 template <std::size_t M>
0125 requires(vector<N, scalar_t, array_t>::simd_size() ==
0126 vector<M, scalar_t, array_t>::simd_size())
0127 DETRAY_HOST_DEVICE constexpr const vector &operator=(
0128 const vector<M, scalar_t, array_t> &lhs) {
0129 m_data = lhs;
0130 return *this;
0131 }
0132
0133
0134
0135 DETRAY_HOST_DEVICE
0136 constexpr operator array_type &() { return m_data; }
0137 DETRAY_HOST_DEVICE
0138 constexpr operator const array_type &() const { return m_data; }
0139
0140
0141 DETRAY_HOST_DEVICE
0142 constexpr const auto &get() const { return m_data; }
0143
0144
0145
0146 DETRAY_HOST_DEVICE
0147 constexpr decltype(auto) operator[](std::size_t i) { return m_data[i]; }
0148 DETRAY_HOST_DEVICE
0149 constexpr decltype(auto) operator[](std::size_t i) const { return m_data[i]; }
0150
0151
0152
0153
0154
0155 DETRAY_HOST_DEVICE
0156 constexpr decltype(auto) operator*=(scalar_type factor) noexcept {
0157 return m_data *= factor;
0158 }
0159
0160
0161
0162
0163 template <concepts::scalar S = scalar_t>
0164 requires(!concepts::simd_scalar<S>)
0165 DETRAY_HOST_DEVICE constexpr friend bool operator==(
0166 const vector &lhs, const vector &rhs) noexcept {
0167 const auto comp = lhs.compare(rhs);
0168 bool is_full = true;
0169
0170 DETRAY_UNROLL_N(N)
0171 for (unsigned int i{0u}; i < N; ++i) {
0172 is_full = is_full && comp[i];
0173 }
0174
0175 return is_full;
0176 }
0177
0178
0179 template <concepts::scalar S = scalar_t>
0180 requires(concepts::simd_scalar<S>)
0181 DETRAY_HOST_DEVICE constexpr friend bool operator==(
0182 const vector &lhs, const vector &rhs) noexcept {
0183 const auto comp = lhs.compare(rhs);
0184 bool is_full = true;
0185
0186 DETRAY_UNROLL_N(N)
0187 for (unsigned int i{0u}; i < N; ++i) {
0188
0189 is_full = is_full && comp[i].isFull();
0190 }
0191
0192 return is_full;
0193 }
0194
0195
0196
0197 template <typename other_t>
0198 DETRAY_HOST_DEVICE constexpr bool operator!=(
0199 const other_t &rhs) const noexcept {
0200 return ((*this == rhs) == false);
0201 }
0202
0203
0204 template <typename other_t>
0205 DETRAY_HOST_DEVICE constexpr auto compare(const other_t &rhs) const noexcept {
0206 using result_t = decltype(m_data[0] == rhs[0]);
0207
0208 std::array<result_t, N> comp;
0209
0210 DETRAY_UNROLL_N(N)
0211 for (unsigned int i{0u}; i < N; ++i) {
0212 comp[i] = (m_data[i] == rhs[i]);
0213 }
0214
0215 return comp;
0216 }
0217
0218
0219 array_t<scalar_t, simd_size()> m_data;
0220
0221 private:
0222
0223 template <std::size_t... Is>
0224 DETRAY_HOST_DEVICE constexpr void zero_fill(
0225 std::index_sequence<Is...>) noexcept {
0226 ((m_data[simd_size() - sizeof...(Is) + Is] = static_cast<scalar_t>(0)),
0227 ...);
0228 }
0229 };
0230
0231
0232 #define DECLARE_VECTOR_OPERATORS(OP) \
0233 template <std::size_t N, concepts::scalar scalar_t, concepts::value value_t, \
0234 template <typename, std::size_t> class array_t> \
0235 DETRAY_HOST_DEVICE constexpr decltype(auto) operator OP( \
0236 const vector<N, scalar_t, array_t> &lhs, value_t rhs) noexcept { \
0237 return lhs.m_data OP static_cast<scalar_t>(rhs); \
0238 } \
0239 template <std::size_t N, concepts::scalar scalar_t, concepts::value value_t, \
0240 template <typename, std::size_t> class array_t> \
0241 DETRAY_HOST_DEVICE constexpr decltype(auto) operator OP( \
0242 value_t lhs, const vector<N, scalar_t, array_t> &rhs) noexcept { \
0243 return static_cast<scalar_t>(lhs) OP rhs.m_data; \
0244 } \
0245 template <std::size_t N, concepts::scalar scalar_t, \
0246 template <typename, std::size_t> class array_t> \
0247 DETRAY_HOST_DEVICE constexpr decltype(auto) operator OP( \
0248 const vector<N, scalar_t, array_t> &lhs, \
0249 const vector<N, scalar_t, array_t> &rhs) noexcept { \
0250 return lhs.m_data OP rhs.m_data; \
0251 } \
0252 template <std::size_t N, concepts::scalar scalar_t, \
0253 template <typename, std::size_t> class array_t, typename other_t> \
0254 requires(concepts::vector<other_t> || concepts::simd_scalar<other_t>) \
0255 DETRAY_HOST_DEVICE constexpr decltype(auto) operator OP( \
0256 const vector<N, scalar_t, array_t> &lhs, const other_t &rhs) noexcept { \
0257 return lhs.m_data OP rhs; \
0258 } \
0259 template <std::size_t N, concepts::scalar scalar_t, \
0260 template <typename, std::size_t> class array_t, typename other_t> \
0261 requires(concepts::vector<other_t> || concepts::simd_scalar<other_t>) \
0262 DETRAY_HOST_DEVICE constexpr decltype(auto) operator OP( \
0263 const other_t &lhs, const vector<N, scalar_t, array_t> &rhs) noexcept { \
0264 return lhs OP rhs.m_data; \
0265 }
0266
0267
0268
0269 DECLARE_VECTOR_OPERATORS(+)
0270 DECLARE_VECTOR_OPERATORS(-)
0271 DECLARE_VECTOR_OPERATORS(*)
0272 DECLARE_VECTOR_OPERATORS(/)
0273
0274
0275
0276 #undef DECLARE_VECTOR_OPERATORS
0277
0278 }
0279
0280 namespace detail {
0281
0282 template <typename T>
0283 struct is_storage_vector : public std::false_type {};
0284
0285 template <std::size_t N, concepts::scalar scalar_t,
0286 template <typename, std::size_t> class array_t>
0287 struct is_storage_vector<algebra::storage::vector<N, scalar_t, array_t>>
0288 : public std::true_type {};
0289
0290 template <typename T>
0291 inline constexpr bool is_storage_vector_v = is_storage_vector<T>::value;
0292
0293 }
0294
0295 }