File indexing completed on 2026-05-27 07:23:54
0001
0002
0003
0004
0005
0006
0007
0008 #pragma once
0009
0010
0011 #include "detray/algebra/common/matrix.hpp"
0012 #include "detray/algebra/common/vector.hpp"
0013 #include "detray/algebra/concepts.hpp"
0014
0015
0016 #include <cassert>
0017
0018 namespace detray::algebra::storage {
0019
0020
0021 struct element_getter {
0022
0023 template <template <typename, std::size_t> class array_t,
0024 concepts::scalar scalar_t, std::size_t ROW, std::size_t COL>
0025 DETRAY_HOST_DEVICE constexpr decltype(auto) operator()(
0026 const matrix<array_t, scalar_t, ROW, COL> &m, std::size_t row,
0027 std::size_t col) const {
0028
0029 assert(row < ROW);
0030 assert(col < COL);
0031
0032
0033 return m[col][row];
0034 }
0035
0036
0037 template <template <typename, std::size_t> class array_t,
0038 concepts::scalar scalar_t, std::size_t ROW, std::size_t COL>
0039 DETRAY_HOST_DEVICE constexpr decltype(auto) operator()(
0040 matrix<array_t, scalar_t, ROW, COL> &m, std::size_t row,
0041 std::size_t col) const {
0042 assert(row < ROW);
0043 assert(col < COL);
0044
0045 return m[col][row];
0046 }
0047
0048
0049 template <template <typename, std::size_t> class array_t,
0050 concepts::scalar scalar_t, std::size_t ROW>
0051 DETRAY_HOST_DEVICE constexpr decltype(auto) operator()(
0052 const matrix<array_t, scalar_t, ROW, 1> &m, std::size_t row) const {
0053 assert(row < ROW);
0054 return m[0][row];
0055 }
0056
0057
0058 template <template <typename, std::size_t> class array_t,
0059 concepts::scalar scalar_t, std::size_t ROW>
0060 DETRAY_HOST_DEVICE constexpr decltype(auto) operator()(
0061 matrix<array_t, scalar_t, ROW, 1> &m, std::size_t row) const {
0062 assert(row < ROW);
0063 return m[0][row];
0064 }
0065
0066
0067 template <template <typename, std::size_t> class array_t,
0068 concepts::scalar scalar_t, std::size_t N>
0069 DETRAY_HOST_DEVICE constexpr decltype(auto) operator()(
0070 const vector<N, scalar_t, array_t> &v, std::size_t row) const {
0071 assert(row < N);
0072 return v[row];
0073 }
0074
0075
0076 template <template <typename, std::size_t> class array_t,
0077 concepts::scalar scalar_t, std::size_t N>
0078 DETRAY_HOST_DEVICE constexpr decltype(auto) operator()(
0079 vector<N, scalar_t, array_t> &v, std::size_t row) const {
0080 assert(row < N);
0081 return v[row];
0082 }
0083
0084 };
0085
0086
0087 template <std::size_t ROW, std::size_t COL, concepts::scalar scalar_t,
0088 template <typename, std::size_t> class array_t>
0089 DETRAY_HOST_DEVICE constexpr decltype(auto) element(
0090 const matrix<array_t, scalar_t, ROW, COL> &m, std::size_t row,
0091 std::size_t col) {
0092 return element_getter{}(m, row, col);
0093 }
0094
0095
0096 template <std::size_t ROW, std::size_t COL, concepts::scalar scalar_t,
0097 template <typename, std::size_t> class array_t>
0098 DETRAY_HOST_DEVICE constexpr decltype(auto) element(
0099 matrix<array_t, scalar_t, ROW, COL> &m, std::size_t row, std::size_t col) {
0100 return element_getter{}(m, row, col);
0101 }
0102
0103
0104 template <std::size_t ROW, concepts::scalar scalar_t,
0105 template <typename, std::size_t> class array_t>
0106 DETRAY_HOST_DEVICE constexpr decltype(auto) element(
0107 const matrix<array_t, scalar_t, ROW, 1> &m, std::size_t row) {
0108 return element_getter{}(m, row);
0109 }
0110
0111
0112 template <std::size_t ROW, concepts::scalar scalar_t,
0113 template <typename, std::size_t> class array_t>
0114 DETRAY_HOST_DEVICE constexpr decltype(auto) element(
0115 matrix<array_t, scalar_t, ROW, 1> &m, std::size_t row) {
0116 return element_getter{}(m, row);
0117 }
0118
0119
0120 template <std::size_t N, concepts::scalar scalar_t,
0121 template <typename, std::size_t> class array_t>
0122 DETRAY_HOST_DEVICE constexpr decltype(auto) element(
0123 const vector<N, scalar_t, array_t> &v, std::size_t row) {
0124 return element_getter{}(v, row);
0125 }
0126
0127
0128 template <std::size_t N, concepts::scalar scalar_t,
0129 template <typename, std::size_t> class array_t>
0130 DETRAY_HOST_DEVICE constexpr decltype(auto) element(
0131 vector<N, scalar_t, array_t> &v, std::size_t row) {
0132 return element_getter{}(v, row);
0133 }
0134
0135 template <std::size_t I, std::size_t J, concepts::matrix M>
0136 DETRAY_HOST_DEVICE decltype(auto) element(M &matrix) {
0137 if constexpr (concepts::has_compile_time_2d_access<M>) {
0138 return matrix.template element<I, J>();
0139 } else {
0140 using index_t = detray::traits::index_t<std::decay_t<M>>;
0141 return element(matrix, static_cast<index_t>(I), static_cast<index_t>(J));
0142 }
0143 }
0144
0145 template <std::size_t I, concepts::vector V>
0146 DETRAY_HOST_DEVICE decltype(auto) element(V &vector) {
0147 if constexpr (concepts::has_compile_time_1d_access<V>) {
0148 return vector.template element<I>();
0149 } else {
0150 using index_t = detray::traits::index_t<std::decay_t<V>>;
0151 return element(vector, static_cast<index_t>(I));
0152 }
0153 }
0154
0155
0156 struct block_getter {
0157
0158 template <std::size_t ROWS, std::size_t COLS, std::size_t mROW,
0159 std::size_t mCOL, concepts::scalar scalar_t,
0160 template <typename, std::size_t> class array_t>
0161 DETRAY_HOST_DEVICE constexpr auto operator()(
0162 const matrix<array_t, scalar_t, mROW, mCOL> &m, const std::size_t row,
0163 const std::size_t col) const noexcept {
0164 static_assert(ROWS <= mROW);
0165 static_assert(COLS <= mCOL);
0166 assert(row + ROWS <= mROW);
0167 assert(col + COLS <= mCOL);
0168
0169 using input_matrix_t = matrix<array_t, scalar_t, mROW, mCOL>;
0170 using matrix_t = matrix<array_t, scalar_t, ROWS, COLS>;
0171
0172 matrix_t res_m;
0173
0174
0175 if constexpr (matrix_t::storage_rows() == input_matrix_t::storage_rows()) {
0176 if (row == 0u) {
0177 for (std::size_t j = col; j < col + COLS; ++j) {
0178 res_m[j - col] = m[j];
0179 }
0180
0181 return res_m;
0182 }
0183 }
0184
0185 for (std::size_t j = col; j < col + COLS; ++j) {
0186 for (std::size_t i = row; i < row + ROWS; ++i) {
0187 res_m[j - col][i - row] = m[j][i];
0188 }
0189 }
0190
0191 return res_m;
0192 }
0193
0194
0195 template <std::size_t SIZE, std::size_t ROWS, std::size_t COLS,
0196 concepts::scalar scalar_t,
0197 template <typename, std::size_t> class array_t>
0198 DETRAY_HOST_DEVICE constexpr auto vector(
0199 const matrix<array_t, scalar_t, ROWS, COLS> &m, const std::size_t row,
0200 const std::size_t col) const noexcept {
0201 static_assert(SIZE <= ROWS);
0202 static_assert(SIZE <= COLS);
0203 assert(row + SIZE <= ROWS);
0204 assert(col <= COLS);
0205
0206 using input_matrix_t = matrix<array_t, scalar_t, ROWS, COLS>;
0207 using vector_t = algebra::storage::vector<SIZE, scalar_t, array_t>;
0208
0209 vector_t res_v{};
0210
0211
0212 if constexpr (SIZE == input_matrix_t::storage_rows()) {
0213 if (row == 0u) {
0214 return m[col];
0215 }
0216 }
0217 for (std::size_t i = row; i < row + SIZE; ++i) {
0218 res_v[i - row] = m[col][i];
0219 }
0220
0221 return res_v;
0222 }
0223
0224
0225 template <std::size_t ROWS, std::size_t COLS, std::size_t mROW,
0226 std::size_t mCOL, concepts::scalar scalar_t,
0227 template <typename, std::size_t> class array_t>
0228 DETRAY_HOST_DEVICE constexpr void set(
0229 matrix<array_t, scalar_t, mROW, mCOL> &m,
0230 const matrix<array_t, scalar_t, ROWS, COLS> &b, const std::size_t row,
0231 const std::size_t col) const noexcept {
0232 static_assert(ROWS <= mROW);
0233 static_assert(COLS <= mCOL);
0234 assert(row + ROWS <= mROW);
0235 assert(col + COLS <= mCOL);
0236
0237 using input_matrix_t = matrix<array_t, scalar_t, mROW, mCOL>;
0238 using matrix_t = matrix<array_t, scalar_t, ROWS, COLS>;
0239
0240
0241 if constexpr (ROWS == mROW &&
0242 matrix_t::storage_rows() == input_matrix_t::storage_rows()) {
0243 if (row == 0u) {
0244 for (std::size_t j = col; j < col + COLS; ++j) {
0245 m[j] = b[j - col];
0246 }
0247 return;
0248 }
0249 }
0250 for (std::size_t j = col; j < col + COLS; ++j) {
0251 for (std::size_t i = row; i < row + ROWS; ++i) {
0252 m[j][i] = b[j - col][i - row];
0253 }
0254 }
0255 }
0256
0257
0258 template <std::size_t ROWS, std::size_t COLS, std::size_t N,
0259 concepts::scalar scalar_t,
0260 template <typename, std::size_t> class array_t>
0261 DETRAY_HOST_DEVICE constexpr void set(
0262 matrix<array_t, scalar_t, ROWS, COLS> &m,
0263 const detray::algebra::storage::vector<N, scalar_t, array_t> &b,
0264 const std::size_t row, const std::size_t col) const noexcept {
0265 using matrix_t = matrix<array_t, scalar_t, ROWS, COLS>;
0266 using vector_t = detray::algebra::storage::vector<N, scalar_t, array_t>;
0267
0268 static_assert(N <= ROWS);
0269 assert(row + N <= matrix_t::rows());
0270 assert(row < matrix_t::rows());
0271 assert(col < matrix_t::columns());
0272
0273 if constexpr (ROWS == N &&
0274 matrix_t::storage_rows() == vector_t::simd_size()) {
0275 if (row == 0u) {
0276 m[col] = b;
0277 return;
0278 }
0279 }
0280 for (std::size_t i = row; i < N + row; ++i) {
0281 m[col][i] = b[i - row];
0282 }
0283 }
0284
0285 };
0286
0287
0288 template <std::size_t ROWS, std::size_t COLS, std::size_t mROW,
0289 std::size_t mCOL, concepts::scalar scalar_t,
0290 template <typename, std::size_t> class array_t>
0291 DETRAY_HOST_DEVICE constexpr auto block(
0292 const matrix<array_t, scalar_t, mROW, mCOL> &m, const std::size_t row,
0293 const std::size_t col) noexcept {
0294 return block_getter{}.template operator()<ROWS, COLS>(m, row, col);
0295 }
0296
0297
0298 template <std::size_t ROWS, std::size_t COLS, std::size_t N,
0299 concepts::scalar scalar_t,
0300 template <typename, std::size_t> class array_t>
0301 DETRAY_HOST_DEVICE constexpr void set_block(
0302 matrix<array_t, scalar_t, ROWS, COLS> &m,
0303 const vector<N, scalar_t, array_t> &b, const std::size_t row,
0304 const std::size_t col) noexcept {
0305 block_getter{}.set(m, b, row, col);
0306 }
0307
0308
0309 template <std::size_t ROWS, std::size_t COLS, std::size_t mROW,
0310 std::size_t mCOL, concepts::scalar scalar_t,
0311 template <typename, std::size_t> class array_t>
0312 DETRAY_HOST_DEVICE constexpr void set_block(
0313 matrix<array_t, scalar_t, mROW, mCOL> &m,
0314 const matrix<array_t, scalar_t, ROWS, COLS> &b, const std::size_t row,
0315 const std::size_t col) noexcept {
0316 block_getter{}.set(m, b, row, col);
0317 }
0318
0319 }