Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2026-05-27 07:23:54

0001 // This file is part of the ACTS project.
0002 //
0003 // Copyright (C) 2016 CERN for the benefit of the ACTS project
0004 //
0005 // This Source Code Form is subject to the terms of the Mozilla Public
0006 // License, v. 2.0. If a copy of the MPL was not distributed with this
0007 // file, You can obtain one at https://mozilla.org/MPL/2.0/.
0008 #pragma once
0009 
0010 // Project include(s).
0011 #include "detray/algebra/common/vector.hpp"
0012 #include "detray/algebra/concepts.hpp"
0013 #include "detray/algebra/type_traits.hpp"
0014 
0015 // System include(s).
0016 #include <array>
0017 #include <cassert>
0018 
0019 namespace detray::algebra::storage {
0020 
0021 /// Generic matrix type that can take vectors as columns
0022 template <template <typename, std::size_t> class array_t,
0023           concepts::scalar scalar_t, std::size_t ROW, std::size_t COL>
0024 struct DETRAY_ALIGN(alignof(algebra::storage::vector<ROW, scalar_t, array_t>))
0025     matrix {
0026   // The matrix consists of column vectors
0027   using vector_type = algebra::storage::vector<ROW, scalar_t, array_t>;
0028   // Value type: Can be simd types
0029   using scalar_type = scalar_t;
0030 
0031   /// Default constructor
0032   constexpr matrix() = default;
0033 
0034   /// Construct from given column vectors @param v
0035   template <concepts::vector... vector_t>
0036   DETRAY_HOST_DEVICE
0037     requires(sizeof...(vector_t) == COL)
0038   explicit matrix(vector_t &&...v) : m_storage{std::forward<vector_t>(v)...} {}
0039 
0040   /// Subscript operator
0041   /// @{
0042   DETRAY_HOST_DEVICE
0043   constexpr const vector_type &operator[](const std::size_t i) const {
0044     assert(i < COL);
0045     return m_storage[i];
0046   }
0047   DETRAY_HOST_DEVICE
0048   constexpr vector_type &operator[](const std::size_t i) {
0049     assert(i < COL);
0050     return m_storage[i];
0051   }
0052   /// @}
0053 
0054   /// @returns the number of rows
0055   DETRAY_HOST_DEVICE
0056   static consteval std::size_t rows() { return ROW; }
0057 
0058   /// @returns the number of rows of the underlying vector storage
0059   /// @note can be different from the matrix rows due to padding
0060   DETRAY_HOST_DEVICE
0061   static consteval std::size_t storage_rows() {
0062     return vector_type::simd_size();
0063   }
0064 
0065   /// @returns the number of columns
0066   DETRAY_HOST_DEVICE
0067   static consteval std::size_t columns() { return COL; }
0068 
0069  private:
0070   /// Equality operator between two matrices
0071   template <std::size_t R, std::size_t C, typename S,
0072             template <typename, std::size_t> class A>
0073   DETRAY_HOST_DEVICE friend constexpr bool operator==(
0074       const matrix<A, S, R, C> &lhs, const matrix<A, S, R, C> &rhs);
0075 
0076   /// Sets the trailing uninitialized values to zero.
0077   /// @{
0078   // AoS
0079   template <std::size_t... I>
0080   DETRAY_HOST_DEVICE
0081     requires(!std::is_scalar_v<scalar_t>)
0082   constexpr bool equal(const matrix &rhs, std::index_sequence<I...>) const {
0083     return (... && (m_storage[I] == rhs[I]));
0084   }
0085 
0086   // SoA
0087   template <std::size_t... I>
0088   DETRAY_HOST
0089     requires(std::is_scalar_v<scalar_t>)
0090   constexpr bool equal(const matrix &rhs, std::index_sequence<I...>) const {
0091     return (... && ((m_storage[I].get() == rhs[I].get()).isFull()));
0092   }
0093   /// @}
0094 
0095   /// Arithmetic operators
0096   /// @{
0097   template <std::size_t R, std::size_t C, typename S,
0098             template <typename, std::size_t> class A>
0099   DETRAY_HOST_DEVICE friend constexpr decltype(auto) operator+(
0100       const matrix<A, S, R, C> &lhs, const matrix<A, S, R, C> &rhs) noexcept;
0101 
0102   template <std::size_t R, std::size_t C, typename S,
0103             template <typename, std::size_t> class A>
0104   DETRAY_HOST_DEVICE friend constexpr decltype(auto) operator-(
0105       const matrix<A, S, R, C> &lhs, const matrix<A, S, R, C> &rhs) noexcept;
0106 
0107   template <std::size_t R, std::size_t C, typename S1, typename S2,
0108             template <typename, std::size_t> class A>
0109   DETRAY_HOST_DEVICE friend constexpr decltype(auto) operator*(
0110       const S2 a, const matrix<A, S1, R, C> &rhs) noexcept;
0111 
0112   template <std::size_t R, std::size_t C, concepts::scalar S1,
0113             concepts::scalar S2, template <typename, std::size_t> class A>
0114   DETRAY_HOST_DEVICE friend constexpr decltype(auto) operator*(
0115       const matrix<A, S1, R, C> &lhs, const S2 a) noexcept;
0116 
0117   /// Matrix-vector multiplication
0118   template <std::size_t R, std::size_t C, typename S,
0119             template <typename, std::size_t> class A>
0120   DETRAY_HOST_DEVICE friend constexpr decltype(auto) operator*(
0121       const matrix<A, S, R, C> &lhs, const vector<C, S, A> &v) noexcept;
0122 
0123   /// Matrix-matrix multiplication
0124   template <std::size_t LR, std::size_t C, std::size_t RC, typename S,
0125             template <typename, std::size_t> class A>
0126   DETRAY_HOST_DEVICE friend constexpr decltype(auto) operator*(
0127       const matrix<A, S, LR, C> &lhs, const matrix<A, S, C, RC> &rhs) noexcept;
0128   /// @}
0129 
0130   /// Matrix storage
0131   std::array<vector_type, COL> m_storage;
0132 
0133 };  // struct matrix
0134 
0135 /// Get a zero-initialized matrix
0136 template <concepts::matrix matrix_t, std::size_t COLS = matrix_t::columns()>
0137 DETRAY_HOST_DEVICE constexpr matrix_t zero() noexcept {
0138   matrix_t m;
0139 
0140   DETRAY_UNROLL_N(COLS)
0141   for (std::size_t j = 0u; j < COLS; ++j) {
0142     // Fill zero initialized vector
0143     m[j] = typename matrix_t::vector_type{};
0144   }
0145 
0146   return m;
0147 }
0148 
0149 /// Set a matrix to zero
0150 template <concepts::matrix matrix_t>
0151 DETRAY_HOST_DEVICE constexpr void set_zero(matrix_t &m) noexcept {
0152   m = zero<matrix_t>();
0153 }
0154 
0155 /// Build an identity matrix
0156 template <concepts::matrix matrix_t,
0157           std::size_t R = detray::traits::max_rank<matrix_t>>
0158 DETRAY_HOST_DEVICE constexpr matrix_t identity() noexcept {
0159   // Zero initialized
0160   matrix_t m{zero<matrix_t>()};
0161 
0162   DETRAY_UNROLL_N(R)
0163   for (std::size_t i = 0u; i < R; ++i) {
0164     m[i][i] = typename matrix_t::scalar_type(1);
0165   }
0166 
0167   return m;
0168 }
0169 
0170 /// Set a matrix to zero
0171 template <concepts::matrix matrix_t>
0172 DETRAY_HOST_DEVICE constexpr void set_identity(matrix_t &m) noexcept {
0173   m = identity<matrix_t>();
0174 }
0175 
0176 /// Transpose the matrix @param m
0177 template <std::size_t ROW, std::size_t COL, concepts::scalar scalar_t,
0178           template <typename, std::size_t> class array_t, std::size_t... I>
0179 DETRAY_HOST_DEVICE constexpr auto transpose(
0180     const matrix<array_t, scalar_t, ROW, COL> &m,
0181     std::index_sequence<I...>) noexcept {
0182   using matrix_T_t = matrix<array_t, scalar_t, COL, ROW>;
0183   using column_t = typename matrix_T_t::vector_type;
0184 
0185   matrix_T_t res_m;
0186 
0187   DETRAY_UNROLL_N(ROW)
0188   for (std::size_t j = 0u; j < ROW; ++j) {
0189     res_m[j] = column_t{m[I][j]...};
0190   }
0191 
0192   return res_m;
0193 }
0194 
0195 /// Build an identity matrix
0196 template <concepts::matrix matrix_t>
0197 DETRAY_HOST_DEVICE constexpr auto transpose(const matrix_t &m) noexcept {
0198   return transpose(m, std::make_index_sequence<matrix_t::columns()>());
0199 }
0200 
0201 /// Equality operator between two matrices
0202 template <std::size_t ROW, std::size_t COL, typename scalar_t,
0203           template <typename, std::size_t> class array_t>
0204 DETRAY_HOST_DEVICE constexpr bool operator==(
0205     const matrix<array_t, scalar_t, ROW, COL> &lhs,
0206     const matrix<array_t, scalar_t, ROW, COL> &rhs) {
0207   return lhs.equal(rhs, std::make_index_sequence<COL>());
0208 }
0209 
0210 /// Scalar multiplication
0211 template <concepts::matrix matrix_t, concepts::scalar scalar_t,
0212           std::size_t... J>
0213 DETRAY_HOST_DEVICE constexpr matrix_t matrix_scalar_mul(
0214     scalar_t a, const matrix_t &rhs, std::index_sequence<J...>) noexcept {
0215   using mat_scalar_t = detray::traits::scalar_t<matrix_t>;
0216 
0217   return matrix_t{(static_cast<mat_scalar_t>(a) * rhs[J])...};
0218 }
0219 
0220 /// Matrix addition
0221 template <concepts::matrix matrix_t, std::size_t... J>
0222 DETRAY_HOST_DEVICE constexpr matrix_t matrix_add(
0223     const matrix_t &lhs, const matrix_t &rhs,
0224     std::index_sequence<J...>) noexcept {
0225   return matrix_t{(lhs[J] + rhs[J])...};
0226 }
0227 
0228 template <concepts::matrix matrix_t, std::size_t... J>
0229 DETRAY_HOST_DEVICE constexpr decltype(auto) matrix_sub(
0230     const matrix_t &lhs, const matrix_t &rhs,
0231     std::index_sequence<J...>) noexcept {
0232   return matrix_t{(lhs[J] - rhs[J])...};
0233 }
0234 
0235 /// Arithmetic operators
0236 /// @{
0237 template <std::size_t ROW, std::size_t COL, concepts::scalar scalar_t,
0238           template <typename, std::size_t> class array_t>
0239 DETRAY_HOST_DEVICE constexpr decltype(auto) operator+(
0240     const matrix<array_t, scalar_t, ROW, COL> &lhs,
0241     const matrix<array_t, scalar_t, ROW, COL> &rhs) noexcept {
0242   using matrix_t = matrix<array_t, scalar_t, ROW, COL>;
0243 
0244   return matrix_add(lhs, rhs, std::make_index_sequence<matrix_t::columns()>());
0245 }
0246 
0247 template <std::size_t ROW, std::size_t COL, concepts::scalar scalar_t,
0248           template <typename, std::size_t> class array_t>
0249 DETRAY_HOST_DEVICE constexpr decltype(auto) operator-(
0250     const matrix<array_t, scalar_t, ROW, COL> &lhs,
0251     const matrix<array_t, scalar_t, ROW, COL> &rhs) noexcept {
0252   using matrix_t = matrix<array_t, scalar_t, ROW, COL>;
0253 
0254   return matrix_sub(lhs, rhs, std::make_index_sequence<matrix_t::columns()>());
0255 }
0256 
0257 template <std::size_t R, std::size_t C, typename S1, typename S2,
0258           template <typename, std::size_t> class A>
0259 DETRAY_HOST_DEVICE constexpr decltype(auto) operator*(
0260     const S2 a, const matrix<A, S1, R, C> &rhs) noexcept {
0261   using matrix_t = matrix<A, S2, R, C>;
0262 
0263   return matrix_scalar_mul(static_cast<S1>(a), rhs,
0264                            std::make_index_sequence<matrix_t::columns()>());
0265 }
0266 
0267 template <std::size_t R, std::size_t C, concepts::scalar S1,
0268           concepts::scalar S2, template <typename, std::size_t> class A>
0269 DETRAY_HOST_DEVICE constexpr decltype(auto) operator*(
0270     const matrix<A, S1, R, C> &lhs, const S2 a) noexcept {
0271   return static_cast<S1>(a) * lhs;
0272 }
0273 
0274 /// Matrix-vector multiplication
0275 template <std::size_t ROW, std::size_t COL, concepts::scalar scalar_t,
0276           template <typename, std::size_t> class array_t>
0277 DETRAY_HOST_DEVICE constexpr decltype(auto) operator*(
0278     const matrix<array_t, scalar_t, ROW, COL> &lhs,
0279     const vector<COL, scalar_t, array_t> &v) noexcept {
0280   // Init vector
0281   vector<ROW, scalar_t, array_t> res_v{v[0] * lhs[0]};
0282 
0283   // Add the rest per column
0284   DETRAY_UNROLL_N(COL)
0285   for (std::size_t j = 1u; j < COL; ++j) {
0286     // fma
0287     res_v = res_v + v[j] * lhs[j];
0288   }
0289 
0290   return res_v;
0291 }
0292 
0293 /// Matrix-matrix multiplication
0294 template <std::size_t LROW, std::size_t COL, std::size_t RCOL,
0295           concepts::scalar scalar_t,
0296           template <typename, std::size_t> class array_t>
0297 DETRAY_HOST_DEVICE constexpr decltype(auto) operator*(
0298     const matrix<array_t, scalar_t, LROW, COL> &lhs,
0299     const matrix<array_t, scalar_t, COL, RCOL> &rhs) noexcept {
0300   matrix<array_t, scalar_t, LROW, RCOL> res_m;
0301 
0302   DETRAY_UNROLL_N(RCOL)
0303   for (std::size_t j = 0u; j < RCOL; ++j) {
0304     // Init column j
0305     res_m[j] = rhs[j][0] * lhs[0];
0306 
0307     // Add the rest per column
0308     DETRAY_UNROLL_N(COL)
0309     for (std::size_t i = 1u; i < COL; ++i) {
0310       // fma
0311       res_m[j] = res_m[j] + rhs[j][i] * lhs[i];
0312     }
0313   }
0314 
0315   return res_m;
0316 }
0317 /// @}
0318 
0319 }  // namespace detray::algebra::storage