Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2026-05-27 07:23:54

0001 // This file is part of the ACTS project.
0002 //
0003 // Copyright (C) 2016 CERN for the benefit of the ACTS project
0004 //
0005 // This Source Code Form is subject to the terms of the Mozilla Public
0006 // License, v. 2.0. If a copy of the MPL was not distributed with this
0007 // file, You can obtain one at https://mozilla.org/MPL/2.0/.
0008 #pragma once
0009 
0010 // Project include(s).
0011 #include "detray/algebra/common/matrix.hpp"
0012 #include "detray/algebra/common/vector.hpp"
0013 #include "detray/algebra/concepts.hpp"
0014 
0015 // System include(s).
0016 #include <cassert>
0017 
0018 namespace detray::algebra::storage {
0019 
0020 /// Functor used to access elements of a matrix
0021 struct element_getter {
0022   /// Get const access to a matrix element
0023   template <template <typename, std::size_t> class array_t,
0024             concepts::scalar scalar_t, std::size_t ROW, std::size_t COL>
0025   DETRAY_HOST_DEVICE constexpr decltype(auto) operator()(
0026       const matrix<array_t, scalar_t, ROW, COL> &m, std::size_t row,
0027       std::size_t col) const {
0028     // Make sure that the indices are valid.
0029     assert(row < ROW);
0030     assert(col < COL);
0031 
0032     // Return the selected element.
0033     return m[col][row];
0034   }
0035 
0036   /// Get non-const access to a matrix element
0037   template <template <typename, std::size_t> class array_t,
0038             concepts::scalar scalar_t, std::size_t ROW, std::size_t COL>
0039   DETRAY_HOST_DEVICE constexpr decltype(auto) operator()(
0040       matrix<array_t, scalar_t, ROW, COL> &m, std::size_t row,
0041       std::size_t col) const {
0042     assert(row < ROW);
0043     assert(col < COL);
0044 
0045     return m[col][row];
0046   }
0047 
0048   /// Get const access to a matrix element
0049   template <template <typename, std::size_t> class array_t,
0050             concepts::scalar scalar_t, std::size_t ROW>
0051   DETRAY_HOST_DEVICE constexpr decltype(auto) operator()(
0052       const matrix<array_t, scalar_t, ROW, 1> &m, std::size_t row) const {
0053     assert(row < ROW);
0054     return m[0][row];
0055   }
0056 
0057   /// Get non-const access to a matrix element
0058   template <template <typename, std::size_t> class array_t,
0059             concepts::scalar scalar_t, std::size_t ROW>
0060   DETRAY_HOST_DEVICE constexpr decltype(auto) operator()(
0061       matrix<array_t, scalar_t, ROW, 1> &m, std::size_t row) const {
0062     assert(row < ROW);
0063     return m[0][row];
0064   }
0065 
0066   /// Get const access to a vector element
0067   template <template <typename, std::size_t> class array_t,
0068             concepts::scalar scalar_t, std::size_t N>
0069   DETRAY_HOST_DEVICE constexpr decltype(auto) operator()(
0070       const vector<N, scalar_t, array_t> &v, std::size_t row) const {
0071     assert(row < N);
0072     return v[row];
0073   }
0074 
0075   /// Get non-const access to a vector element
0076   template <template <typename, std::size_t> class array_t,
0077             concepts::scalar scalar_t, std::size_t N>
0078   DETRAY_HOST_DEVICE constexpr decltype(auto) operator()(
0079       vector<N, scalar_t, array_t> &v, std::size_t row) const {
0080     assert(row < N);
0081     return v[row];
0082   }
0083 
0084 };  // struct element_getter
0085 
0086 /// Function extracting an element from a matrix (const)
0087 template <std::size_t ROW, std::size_t COL, concepts::scalar scalar_t,
0088           template <typename, std::size_t> class array_t>
0089 DETRAY_HOST_DEVICE constexpr decltype(auto) element(
0090     const matrix<array_t, scalar_t, ROW, COL> &m, std::size_t row,
0091     std::size_t col) {
0092   return element_getter{}(m, row, col);
0093 }
0094 
0095 /// Function extracting an element from a matrix (non-const)
0096 template <std::size_t ROW, std::size_t COL, concepts::scalar scalar_t,
0097           template <typename, std::size_t> class array_t>
0098 DETRAY_HOST_DEVICE constexpr decltype(auto) element(
0099     matrix<array_t, scalar_t, ROW, COL> &m, std::size_t row, std::size_t col) {
0100   return element_getter{}(m, row, col);
0101 }
0102 
0103 /// Function extracting an element from a 1D matrix (const)
0104 template <std::size_t ROW, concepts::scalar scalar_t,
0105           template <typename, std::size_t> class array_t>
0106 DETRAY_HOST_DEVICE constexpr decltype(auto) element(
0107     const matrix<array_t, scalar_t, ROW, 1> &m, std::size_t row) {
0108   return element_getter{}(m, row);
0109 }
0110 
0111 /// Function extracting an element from a 1D matrix (non-const)
0112 template <std::size_t ROW, concepts::scalar scalar_t,
0113           template <typename, std::size_t> class array_t>
0114 DETRAY_HOST_DEVICE constexpr decltype(auto) element(
0115     matrix<array_t, scalar_t, ROW, 1> &m, std::size_t row) {
0116   return element_getter{}(m, row);
0117 }
0118 
0119 /// Function extracting an element from a vector (const)
0120 template <std::size_t N, concepts::scalar scalar_t,
0121           template <typename, std::size_t> class array_t>
0122 DETRAY_HOST_DEVICE constexpr decltype(auto) element(
0123     const vector<N, scalar_t, array_t> &v, std::size_t row) {
0124   return element_getter{}(v, row);
0125 }
0126 
0127 /// Function extracting an element from a vector (non-const)
0128 template <std::size_t N, concepts::scalar scalar_t,
0129           template <typename, std::size_t> class array_t>
0130 DETRAY_HOST_DEVICE constexpr decltype(auto) element(
0131     vector<N, scalar_t, array_t> &v, std::size_t row) {
0132   return element_getter{}(v, row);
0133 }
0134 
0135 template <std::size_t I, std::size_t J, concepts::matrix M>
0136 DETRAY_HOST_DEVICE decltype(auto) element(M &matrix) {
0137   if constexpr (concepts::has_compile_time_2d_access<M>) {
0138     return matrix.template element<I, J>();
0139   } else {
0140     using index_t = detray::traits::index_t<std::decay_t<M>>;
0141     return element(matrix, static_cast<index_t>(I), static_cast<index_t>(J));
0142   }
0143 }
0144 
0145 template <std::size_t I, concepts::vector V>
0146 DETRAY_HOST_DEVICE decltype(auto) element(V &vector) {
0147   if constexpr (concepts::has_compile_time_1d_access<V>) {
0148     return vector.template element<I>();
0149   } else {
0150     using index_t = detray::traits::index_t<std::decay_t<V>>;
0151     return element(vector, static_cast<index_t>(I));
0152   }
0153 }
0154 
0155 /// Functor used to access a submatrix of a matrix
0156 struct block_getter {
0157   /// Get a block of a const matrix
0158   template <std::size_t ROWS, std::size_t COLS, std::size_t mROW,
0159             std::size_t mCOL, concepts::scalar scalar_t,
0160             template <typename, std::size_t> class array_t>
0161   DETRAY_HOST_DEVICE constexpr auto operator()(
0162       const matrix<array_t, scalar_t, mROW, mCOL> &m, const std::size_t row,
0163       const std::size_t col) const noexcept {
0164     static_assert(ROWS <= mROW);
0165     static_assert(COLS <= mCOL);
0166     assert(row + ROWS <= mROW);
0167     assert(col + COLS <= mCOL);
0168 
0169     using input_matrix_t = matrix<array_t, scalar_t, mROW, mCOL>;
0170     using matrix_t = matrix<array_t, scalar_t, ROWS, COLS>;
0171 
0172     matrix_t res_m;
0173 
0174     // Don't access single elements in underlying vectors unless necessary
0175     if constexpr (matrix_t::storage_rows() == input_matrix_t::storage_rows()) {
0176       if (row == 0u) {
0177         for (std::size_t j = col; j < col + COLS; ++j) {
0178           res_m[j - col] = m[j];
0179         }
0180 
0181         return res_m;
0182       }
0183     }
0184 
0185     for (std::size_t j = col; j < col + COLS; ++j) {
0186       for (std::size_t i = row; i < row + ROWS; ++i) {
0187         res_m[j - col][i - row] = m[j][i];
0188       }
0189     }
0190 
0191     return res_m;
0192   }
0193 
0194   /// Get a vector of a const matrix
0195   template <std::size_t SIZE, std::size_t ROWS, std::size_t COLS,
0196             concepts::scalar scalar_t,
0197             template <typename, std::size_t> class array_t>
0198   DETRAY_HOST_DEVICE constexpr auto vector(
0199       const matrix<array_t, scalar_t, ROWS, COLS> &m, const std::size_t row,
0200       const std::size_t col) const noexcept {
0201     static_assert(SIZE <= ROWS);
0202     static_assert(SIZE <= COLS);
0203     assert(row + SIZE <= ROWS);
0204     assert(col <= COLS);
0205 
0206     using input_matrix_t = matrix<array_t, scalar_t, ROWS, COLS>;
0207     using vector_t = algebra::storage::vector<SIZE, scalar_t, array_t>;
0208 
0209     vector_t res_v{};
0210 
0211     // Don't access single elements in underlying vectors unless necessary
0212     if constexpr (SIZE == input_matrix_t::storage_rows()) {
0213       if (row == 0u) {
0214         return m[col];
0215       }
0216     }
0217     for (std::size_t i = row; i < row + SIZE; ++i) {
0218       res_v[i - row] = m[col][i];
0219     }
0220 
0221     return res_v;
0222   }
0223 
0224   /// Get a block of a const matrix
0225   template <std::size_t ROWS, std::size_t COLS, std::size_t mROW,
0226             std::size_t mCOL, concepts::scalar scalar_t,
0227             template <typename, std::size_t> class array_t>
0228   DETRAY_HOST_DEVICE constexpr void set(
0229       matrix<array_t, scalar_t, mROW, mCOL> &m,
0230       const matrix<array_t, scalar_t, ROWS, COLS> &b, const std::size_t row,
0231       const std::size_t col) const noexcept {
0232     static_assert(ROWS <= mROW);
0233     static_assert(COLS <= mCOL);
0234     assert(row + ROWS <= mROW);
0235     assert(col + COLS <= mCOL);
0236 
0237     using input_matrix_t = matrix<array_t, scalar_t, mROW, mCOL>;
0238     using matrix_t = matrix<array_t, scalar_t, ROWS, COLS>;
0239 
0240     // Don't access single elements in underlying vectors unless necessary
0241     if constexpr (ROWS == mROW &&
0242                   matrix_t::storage_rows() == input_matrix_t::storage_rows()) {
0243       if (row == 0u) {
0244         for (std::size_t j = col; j < col + COLS; ++j) {
0245           m[j] = b[j - col];
0246         }
0247         return;
0248       }
0249     }
0250     for (std::size_t j = col; j < col + COLS; ++j) {
0251       for (std::size_t i = row; i < row + ROWS; ++i) {
0252         m[j][i] = b[j - col][i - row];
0253       }
0254     }
0255   }
0256 
0257   /// Operator setting a block with a vector
0258   template <std::size_t ROWS, std::size_t COLS, std::size_t N,
0259             concepts::scalar scalar_t,
0260             template <typename, std::size_t> class array_t>
0261   DETRAY_HOST_DEVICE constexpr void set(
0262       matrix<array_t, scalar_t, ROWS, COLS> &m,
0263       const detray::algebra::storage::vector<N, scalar_t, array_t> &b,
0264       const std::size_t row, const std::size_t col) const noexcept {
0265     using matrix_t = matrix<array_t, scalar_t, ROWS, COLS>;
0266     using vector_t = detray::algebra::storage::vector<N, scalar_t, array_t>;
0267 
0268     static_assert(N <= ROWS);
0269     assert(row + N <= matrix_t::rows());
0270     assert(row < matrix_t::rows());
0271     assert(col < matrix_t::columns());
0272 
0273     if constexpr (ROWS == N &&
0274                   matrix_t::storage_rows() == vector_t::simd_size()) {
0275       if (row == 0u) {
0276         m[col] = b;
0277         return;
0278       }
0279     }
0280     for (std::size_t i = row; i < N + row; ++i) {
0281       m[col][i] = b[i - row];
0282     }
0283   }
0284 
0285 };  // struct block_getter
0286 
0287 /// Get a block of a const matrix
0288 template <std::size_t ROWS, std::size_t COLS, std::size_t mROW,
0289           std::size_t mCOL, concepts::scalar scalar_t,
0290           template <typename, std::size_t> class array_t>
0291 DETRAY_HOST_DEVICE constexpr auto block(
0292     const matrix<array_t, scalar_t, mROW, mCOL> &m, const std::size_t row,
0293     const std::size_t col) noexcept {
0294   return block_getter{}.template operator()<ROWS, COLS>(m, row, col);
0295 }
0296 
0297 /// Set a block of a const matrix from a vector
0298 template <std::size_t ROWS, std::size_t COLS, std::size_t N,
0299           concepts::scalar scalar_t,
0300           template <typename, std::size_t> class array_t>
0301 DETRAY_HOST_DEVICE constexpr void set_block(
0302     matrix<array_t, scalar_t, ROWS, COLS> &m,
0303     const vector<N, scalar_t, array_t> &b, const std::size_t row,
0304     const std::size_t col) noexcept {
0305   block_getter{}.set(m, b, row, col);
0306 }
0307 
0308 /// Set a block of a const matrix from another matrix
0309 template <std::size_t ROWS, std::size_t COLS, std::size_t mROW,
0310           std::size_t mCOL, concepts::scalar scalar_t,
0311           template <typename, std::size_t> class array_t>
0312 DETRAY_HOST_DEVICE constexpr void set_block(
0313     matrix<array_t, scalar_t, mROW, mCOL> &m,
0314     const matrix<array_t, scalar_t, ROWS, COLS> &b, const std::size_t row,
0315     const std::size_t col) noexcept {
0316   block_getter{}.set(m, b, row, col);
0317 }
0318 
0319 }  // namespace detray::algebra::storage