File indexing completed on 2026-05-27 07:23:54
0001
0002
0003
0004
0005
0006
0007
0008 #pragma once
0009
0010
0011 #include "detray/algebra/common/vector.hpp"
0012 #include "detray/algebra/concepts.hpp"
0013 #include "detray/algebra/type_traits.hpp"
0014
0015
0016 #include <array>
0017 #include <cassert>
0018
0019 namespace detray::algebra::storage {
0020
0021
0022 template <template <typename, std::size_t> class array_t,
0023 concepts::scalar scalar_t, std::size_t ROW, std::size_t COL>
0024 struct DETRAY_ALIGN(alignof(algebra::storage::vector<ROW, scalar_t, array_t>))
0025 matrix {
0026
0027 using vector_type = algebra::storage::vector<ROW, scalar_t, array_t>;
0028
0029 using scalar_type = scalar_t;
0030
0031
0032 constexpr matrix() = default;
0033
0034
0035 template <concepts::vector... vector_t>
0036 DETRAY_HOST_DEVICE
0037 requires(sizeof...(vector_t) == COL)
0038 explicit matrix(vector_t &&...v) : m_storage{std::forward<vector_t>(v)...} {}
0039
0040
0041
0042 DETRAY_HOST_DEVICE
0043 constexpr const vector_type &operator[](const std::size_t i) const {
0044 assert(i < COL);
0045 return m_storage[i];
0046 }
0047 DETRAY_HOST_DEVICE
0048 constexpr vector_type &operator[](const std::size_t i) {
0049 assert(i < COL);
0050 return m_storage[i];
0051 }
0052
0053
0054
0055 DETRAY_HOST_DEVICE
0056 static consteval std::size_t rows() { return ROW; }
0057
0058
0059
0060 DETRAY_HOST_DEVICE
0061 static consteval std::size_t storage_rows() {
0062 return vector_type::simd_size();
0063 }
0064
0065
0066 DETRAY_HOST_DEVICE
0067 static consteval std::size_t columns() { return COL; }
0068
0069 private:
0070
0071 template <std::size_t R, std::size_t C, typename S,
0072 template <typename, std::size_t> class A>
0073 DETRAY_HOST_DEVICE friend constexpr bool operator==(
0074 const matrix<A, S, R, C> &lhs, const matrix<A, S, R, C> &rhs);
0075
0076
0077
0078
0079 template <std::size_t... I>
0080 DETRAY_HOST_DEVICE
0081 requires(!std::is_scalar_v<scalar_t>)
0082 constexpr bool equal(const matrix &rhs, std::index_sequence<I...>) const {
0083 return (... && (m_storage[I] == rhs[I]));
0084 }
0085
0086
0087 template <std::size_t... I>
0088 DETRAY_HOST
0089 requires(std::is_scalar_v<scalar_t>)
0090 constexpr bool equal(const matrix &rhs, std::index_sequence<I...>) const {
0091 return (... && ((m_storage[I].get() == rhs[I].get()).isFull()));
0092 }
0093
0094
0095
0096
0097 template <std::size_t R, std::size_t C, typename S,
0098 template <typename, std::size_t> class A>
0099 DETRAY_HOST_DEVICE friend constexpr decltype(auto) operator+(
0100 const matrix<A, S, R, C> &lhs, const matrix<A, S, R, C> &rhs) noexcept;
0101
0102 template <std::size_t R, std::size_t C, typename S,
0103 template <typename, std::size_t> class A>
0104 DETRAY_HOST_DEVICE friend constexpr decltype(auto) operator-(
0105 const matrix<A, S, R, C> &lhs, const matrix<A, S, R, C> &rhs) noexcept;
0106
0107 template <std::size_t R, std::size_t C, typename S1, typename S2,
0108 template <typename, std::size_t> class A>
0109 DETRAY_HOST_DEVICE friend constexpr decltype(auto) operator*(
0110 const S2 a, const matrix<A, S1, R, C> &rhs) noexcept;
0111
0112 template <std::size_t R, std::size_t C, concepts::scalar S1,
0113 concepts::scalar S2, template <typename, std::size_t> class A>
0114 DETRAY_HOST_DEVICE friend constexpr decltype(auto) operator*(
0115 const matrix<A, S1, R, C> &lhs, const S2 a) noexcept;
0116
0117
0118 template <std::size_t R, std::size_t C, typename S,
0119 template <typename, std::size_t> class A>
0120 DETRAY_HOST_DEVICE friend constexpr decltype(auto) operator*(
0121 const matrix<A, S, R, C> &lhs, const vector<C, S, A> &v) noexcept;
0122
0123
0124 template <std::size_t LR, std::size_t C, std::size_t RC, typename S,
0125 template <typename, std::size_t> class A>
0126 DETRAY_HOST_DEVICE friend constexpr decltype(auto) operator*(
0127 const matrix<A, S, LR, C> &lhs, const matrix<A, S, C, RC> &rhs) noexcept;
0128
0129
0130
0131 std::array<vector_type, COL> m_storage;
0132
0133 };
0134
0135
0136 template <concepts::matrix matrix_t, std::size_t COLS = matrix_t::columns()>
0137 DETRAY_HOST_DEVICE constexpr matrix_t zero() noexcept {
0138 matrix_t m;
0139
0140 DETRAY_UNROLL_N(COLS)
0141 for (std::size_t j = 0u; j < COLS; ++j) {
0142
0143 m[j] = typename matrix_t::vector_type{};
0144 }
0145
0146 return m;
0147 }
0148
0149
0150 template <concepts::matrix matrix_t>
0151 DETRAY_HOST_DEVICE constexpr void set_zero(matrix_t &m) noexcept {
0152 m = zero<matrix_t>();
0153 }
0154
0155
0156 template <concepts::matrix matrix_t,
0157 std::size_t R = detray::traits::max_rank<matrix_t>>
0158 DETRAY_HOST_DEVICE constexpr matrix_t identity() noexcept {
0159
0160 matrix_t m{zero<matrix_t>()};
0161
0162 DETRAY_UNROLL_N(R)
0163 for (std::size_t i = 0u; i < R; ++i) {
0164 m[i][i] = typename matrix_t::scalar_type(1);
0165 }
0166
0167 return m;
0168 }
0169
0170
0171 template <concepts::matrix matrix_t>
0172 DETRAY_HOST_DEVICE constexpr void set_identity(matrix_t &m) noexcept {
0173 m = identity<matrix_t>();
0174 }
0175
0176
0177 template <std::size_t ROW, std::size_t COL, concepts::scalar scalar_t,
0178 template <typename, std::size_t> class array_t, std::size_t... I>
0179 DETRAY_HOST_DEVICE constexpr auto transpose(
0180 const matrix<array_t, scalar_t, ROW, COL> &m,
0181 std::index_sequence<I...>) noexcept {
0182 using matrix_T_t = matrix<array_t, scalar_t, COL, ROW>;
0183 using column_t = typename matrix_T_t::vector_type;
0184
0185 matrix_T_t res_m;
0186
0187 DETRAY_UNROLL_N(ROW)
0188 for (std::size_t j = 0u; j < ROW; ++j) {
0189 res_m[j] = column_t{m[I][j]...};
0190 }
0191
0192 return res_m;
0193 }
0194
0195
0196 template <concepts::matrix matrix_t>
0197 DETRAY_HOST_DEVICE constexpr auto transpose(const matrix_t &m) noexcept {
0198 return transpose(m, std::make_index_sequence<matrix_t::columns()>());
0199 }
0200
0201
0202 template <std::size_t ROW, std::size_t COL, typename scalar_t,
0203 template <typename, std::size_t> class array_t>
0204 DETRAY_HOST_DEVICE constexpr bool operator==(
0205 const matrix<array_t, scalar_t, ROW, COL> &lhs,
0206 const matrix<array_t, scalar_t, ROW, COL> &rhs) {
0207 return lhs.equal(rhs, std::make_index_sequence<COL>());
0208 }
0209
0210
0211 template <concepts::matrix matrix_t, concepts::scalar scalar_t,
0212 std::size_t... J>
0213 DETRAY_HOST_DEVICE constexpr matrix_t matrix_scalar_mul(
0214 scalar_t a, const matrix_t &rhs, std::index_sequence<J...>) noexcept {
0215 using mat_scalar_t = detray::traits::scalar_t<matrix_t>;
0216
0217 return matrix_t{(static_cast<mat_scalar_t>(a) * rhs[J])...};
0218 }
0219
0220
0221 template <concepts::matrix matrix_t, std::size_t... J>
0222 DETRAY_HOST_DEVICE constexpr matrix_t matrix_add(
0223 const matrix_t &lhs, const matrix_t &rhs,
0224 std::index_sequence<J...>) noexcept {
0225 return matrix_t{(lhs[J] + rhs[J])...};
0226 }
0227
0228 template <concepts::matrix matrix_t, std::size_t... J>
0229 DETRAY_HOST_DEVICE constexpr decltype(auto) matrix_sub(
0230 const matrix_t &lhs, const matrix_t &rhs,
0231 std::index_sequence<J...>) noexcept {
0232 return matrix_t{(lhs[J] - rhs[J])...};
0233 }
0234
0235
0236
0237 template <std::size_t ROW, std::size_t COL, concepts::scalar scalar_t,
0238 template <typename, std::size_t> class array_t>
0239 DETRAY_HOST_DEVICE constexpr decltype(auto) operator+(
0240 const matrix<array_t, scalar_t, ROW, COL> &lhs,
0241 const matrix<array_t, scalar_t, ROW, COL> &rhs) noexcept {
0242 using matrix_t = matrix<array_t, scalar_t, ROW, COL>;
0243
0244 return matrix_add(lhs, rhs, std::make_index_sequence<matrix_t::columns()>());
0245 }
0246
0247 template <std::size_t ROW, std::size_t COL, concepts::scalar scalar_t,
0248 template <typename, std::size_t> class array_t>
0249 DETRAY_HOST_DEVICE constexpr decltype(auto) operator-(
0250 const matrix<array_t, scalar_t, ROW, COL> &lhs,
0251 const matrix<array_t, scalar_t, ROW, COL> &rhs) noexcept {
0252 using matrix_t = matrix<array_t, scalar_t, ROW, COL>;
0253
0254 return matrix_sub(lhs, rhs, std::make_index_sequence<matrix_t::columns()>());
0255 }
0256
0257 template <std::size_t R, std::size_t C, typename S1, typename S2,
0258 template <typename, std::size_t> class A>
0259 DETRAY_HOST_DEVICE constexpr decltype(auto) operator*(
0260 const S2 a, const matrix<A, S1, R, C> &rhs) noexcept {
0261 using matrix_t = matrix<A, S2, R, C>;
0262
0263 return matrix_scalar_mul(static_cast<S1>(a), rhs,
0264 std::make_index_sequence<matrix_t::columns()>());
0265 }
0266
0267 template <std::size_t R, std::size_t C, concepts::scalar S1,
0268 concepts::scalar S2, template <typename, std::size_t> class A>
0269 DETRAY_HOST_DEVICE constexpr decltype(auto) operator*(
0270 const matrix<A, S1, R, C> &lhs, const S2 a) noexcept {
0271 return static_cast<S1>(a) * lhs;
0272 }
0273
0274
0275 template <std::size_t ROW, std::size_t COL, concepts::scalar scalar_t,
0276 template <typename, std::size_t> class array_t>
0277 DETRAY_HOST_DEVICE constexpr decltype(auto) operator*(
0278 const matrix<array_t, scalar_t, ROW, COL> &lhs,
0279 const vector<COL, scalar_t, array_t> &v) noexcept {
0280
0281 vector<ROW, scalar_t, array_t> res_v{v[0] * lhs[0]};
0282
0283
0284 DETRAY_UNROLL_N(COL)
0285 for (std::size_t j = 1u; j < COL; ++j) {
0286
0287 res_v = res_v + v[j] * lhs[j];
0288 }
0289
0290 return res_v;
0291 }
0292
0293
0294 template <std::size_t LROW, std::size_t COL, std::size_t RCOL,
0295 concepts::scalar scalar_t,
0296 template <typename, std::size_t> class array_t>
0297 DETRAY_HOST_DEVICE constexpr decltype(auto) operator*(
0298 const matrix<array_t, scalar_t, LROW, COL> &lhs,
0299 const matrix<array_t, scalar_t, COL, RCOL> &rhs) noexcept {
0300 matrix<array_t, scalar_t, LROW, RCOL> res_m;
0301
0302 DETRAY_UNROLL_N(RCOL)
0303 for (std::size_t j = 0u; j < RCOL; ++j) {
0304
0305 res_m[j] = rhs[j][0] * lhs[0];
0306
0307
0308 DETRAY_UNROLL_N(COL)
0309 for (std::size_t i = 1u; i < COL; ++i) {
0310
0311 res_m[j] = res_m[j] + rhs[j][i] * lhs[i];
0312 }
0313 }
0314
0315 return res_m;
0316 }
0317
0318
0319 }