Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-09-17 09:07:51

0001 //------------------------------- -*- C++ -*- -------------------------------//
0002 // Copyright Celeritas contributors: see top-level COPYRIGHT file for details
0003 // SPDX-License-Identifier: (Apache-2.0 OR MIT)
0004 //---------------------------------------------------------------------------//
0005 //! \file orange/MatrixUtils.hh
0006 // TODO: split into BLAS and host-only utils
0007 //---------------------------------------------------------------------------//
0008 #pragma once
0009 
0010 #include <cmath>
0011 
0012 #include "corecel/Macros.hh"
0013 #include "corecel/cont/Array.hh"
0014 #include "corecel/math/Algorithms.hh"
0015 #include "corecel/math/Turn.hh"
0016 #include "geocel/Types.hh"
0017 
0018 namespace celeritas
0019 {
//! Policy tags for matrix operations
namespace matrix
{
//---------------------------------------------------------------------------//
//! Empty tag type that selects the transposed variant of a matrix operation
struct TransposePolicy
{
};

//! Tag value: pass as the first argument to treat the input matrix as A^T
inline constexpr TransposePolicy transpose{};
}  // namespace matrix
0030 
0031 //---------------------------------------------------------------------------//
0032 // Apply a matrix to an array
0033 template<class T, size_type N>
0034 inline CELER_FUNCTION Array<T, N> gemv(T alpha,
0035                                        SquareMatrix<T, N> const& a,
0036                                        Array<T, N> const& x,
0037                                        T beta,
0038                                        Array<T, N> const& y);
0039 
0040 //---------------------------------------------------------------------------//
0041 // Apply the transpose of a matrix to an array
0042 template<class T, size_type N>
0043 inline CELER_FUNCTION Array<T, N> gemv(matrix::TransposePolicy,
0044                                        T alpha,
0045                                        SquareMatrix<T, N> const& a,
0046                                        Array<T, N> const& x,
0047                                        T beta,
0048                                        Array<T, N> const& y);
0049 
0050 //---------------------------------------------------------------------------//
0051 //!@{
0052 //! Apply a matrix or its transpose to an array, without scaling or addition
0053 template<class T, size_type N>
0054 inline CELER_FUNCTION Array<T, N>
0055 gemv(SquareMatrix<T, N> const& a, Array<T, N> const& x)
0056 {
0057     return gemv(T{1}, a, x, T{0}, x);
0058 }
0059 
0060 template<class T, size_type N>
0061 inline CELER_FUNCTION Array<T, N>
0062 gemv(matrix::TransposePolicy, SquareMatrix<T, N> const& a, Array<T, N> const& x)
0063 {
0064     return gemv(matrix::transpose, T{1}, a, x, T{0}, x);
0065 }
0066 //!@}
0067 //---------------------------------------------------------------------------//
0068 // Host-only declarations
0069 // (double and float (and some int) for N=3 are instantiated in MatrixUtils.cc)
0070 //---------------------------------------------------------------------------//
0071 
0072 // Calculate the determinant of a matrix
0073 template<class T>
0074 T determinant(SquareMatrix<T, 3> const& mat);
0075 
0076 // Calculate the trace of a matrix
0077 template<class T>
0078 T trace(SquareMatrix<T, 3> const& mat);
0079 
0080 // Perform a matrix-matrix multiply
0081 template<class T, size_type N>
0082 SquareMatrix<T, N>
0083 gemm(SquareMatrix<T, N> const& a, SquareMatrix<T, N> const& b);
0084 
0085 // Perform a matrix-matrix multiply with A transposed
0086 template<class T, size_type N>
0087 SquareMatrix<T, N> gemm(matrix::TransposePolicy,
0088                         SquareMatrix<T, N> const& a,
0089                         SquareMatrix<T, N> const& b);
0090 
0091 // Normalize and orthogonalize a small, dense matrix
0092 template<class T, size_type N>
0093 void orthonormalize(SquareMatrix<T, N>* mat);
0094 
0095 // Create a C-ordered rotation matrix about an arbitrary axis
0096 SquareMatrixReal3 make_rotation(Real3 const& ax, Turn rev);
0097 
0098 // Create a C-ordered rotation matrix about a cartesian axis
0099 SquareMatrixReal3 make_rotation(Axis ax, Turn rev);
0100 
0101 // Apply a rotation to an existing C-ordered rotation matrix
0102 SquareMatrixReal3 make_rotation(Axis ax, Turn rev, SquareMatrixReal3 const&);
0103 
0104 // Construct a transposed matrix
0105 SquareMatrixReal3 make_transpose(SquareMatrixReal3 const&);
0106 
0107 //---------------------------------------------------------------------------//
0108 // INLINE DEFINITIONS
0109 //---------------------------------------------------------------------------//
0110 /*!
0111  * Naive generalized matrix-vector multiply.
0112  *
0113  * \f[
0114  * z \gets \alpha A x + \beta y
0115  * \f]
0116  *
0117  * This should be equivalent to BLAS' GEMV without transposition. All
0118  * matrix orderings are C-style: mat[i][j] is for row i, column j .
0119  *
0120  * \warning This implementation is limited and slow.
0121  */
0122 template<class T, size_type N>
0123 CELER_FUNCTION Array<T, N> gemv(T alpha,
0124                                 SquareMatrix<T, N> const& a,
0125                                 Array<T, N> const& x,
0126                                 T beta,
0127                                 Array<T, N> const& y)
0128 {
0129     Array<T, N> result;
0130     for (size_type i = 0; i != N; ++i)
0131     {
0132         result[i] = beta * y[i];
0133         for (size_type j = 0; j != N; ++j)
0134         {
0135             result[i] = fma(alpha, a[i][j] * x[j], result[i]);
0136         }
0137     }
0138     return result;
0139 }
0140 
0141 //---------------------------------------------------------------------------//
0142 /*!
0143  * Naive transposed generalized matrix-vector multiply.
0144  *
0145  * \f[
0146  * z \gets \alpha A^T x + \beta y
0147  * \f]
0148  *
0149  * This should be equivalent to BLAS' GEMV with the 't' option. All
0150  * matrix orderings are C-style: mat[i][j] is for row i, column j .
0151  *
0152  * \warning This implementation is limited and slow.
0153  */
0154 template<class T, size_type N>
0155 CELER_FUNCTION Array<T, N> gemv(matrix::TransposePolicy,
0156                                 T alpha,
0157                                 SquareMatrix<T, N> const& a,
0158                                 Array<T, N> const& x,
0159                                 T beta,
0160                                 Array<T, N> const& y)
0161 {
0162     Array<T, N> result;
0163     for (size_type i = 0; i != N; ++i)
0164     {
0165         result[i] = beta * y[i];
0166     }
0167     for (size_type j = 0; j != N; ++j)
0168     {
0169         for (size_type i = 0; i != N; ++i)
0170         {
0171             result[i] = fma(alpha, a[j][i] * x[j], result[i]);
0172         }
0173     }
0174     return result;
0175 }
0176 
0177 //---------------------------------------------------------------------------//
0178 }  // namespace celeritas