Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-01-18 10:05:58

0001 //----------------------------------*-C++-*----------------------------------//
0002 // Copyright 2023-2024 UT-Battelle, LLC, and other Celeritas developers.
0003 // See the top-level COPYRIGHT file for details.
0004 // SPDX-License-Identifier: (Apache-2.0 OR MIT)
0005 //---------------------------------------------------------------------------//
0006 //! \file orange/MatrixUtils.hh
0007 // TODO: split into BLAS and host-only utils
0008 //---------------------------------------------------------------------------//
0009 #pragma once
0010 
0011 #include <cmath>
0012 
0013 #include "corecel/Macros.hh"
0014 #include "corecel/cont/Array.hh"
0015 #include "corecel/math/Algorithms.hh"
0016 #include "corecel/math/Turn.hh"
0017 #include "geocel/Types.hh"
0018 
0019 namespace celeritas
0020 {
//! Policy tags for matrix operations
namespace matrix
{
//---------------------------------------------------------------------------//
/*!
 * Tag type that selects the transposed overload of a matrix operation.
 *
 * The default constructor is explicit (the same idiom used by
 * \c std::in_place_t) so that an empty brace initializer \c {} passed as an
 * argument cannot silently select a transposed overload: callers must write
 * \c matrix::transpose (or \c TransposePolicy{}) explicitly.
 */
struct TransposePolicy
{
    explicit TransposePolicy() = default;
};

//! Indicate that the input matrix is transposed
inline constexpr TransposePolicy transpose{};
}  // namespace matrix
0031 
0032 //---------------------------------------------------------------------------//
0033 // Apply a matrix to an array
0034 template<class T, size_type N>
0035 inline CELER_FUNCTION Array<T, N> gemv(T alpha,
0036                                        SquareMatrix<T, N> const& a,
0037                                        Array<T, N> const& x,
0038                                        T beta,
0039                                        Array<T, N> const& y);
0040 
0041 //---------------------------------------------------------------------------//
0042 // Apply the transpose of a matrix to an array
0043 template<class T, size_type N>
0044 inline CELER_FUNCTION Array<T, N> gemv(matrix::TransposePolicy,
0045                                        T alpha,
0046                                        SquareMatrix<T, N> const& a,
0047                                        Array<T, N> const& x,
0048                                        T beta,
0049                                        Array<T, N> const& y);
0050 
0051 //---------------------------------------------------------------------------//
0052 //!@{
0053 //! Apply a matrix or its transpose to an array, without scaling or addition
0054 template<class T, size_type N>
0055 inline CELER_FUNCTION Array<T, N>
0056 gemv(SquareMatrix<T, N> const& a, Array<T, N> const& x)
0057 {
0058     return gemv(T{1}, a, x, T{0}, x);
0059 }
0060 
0061 template<class T, size_type N>
0062 inline CELER_FUNCTION Array<T, N>
0063 gemv(matrix::TransposePolicy, SquareMatrix<T, N> const& a, Array<T, N> const& x)
0064 {
0065     return gemv(matrix::transpose, T{1}, a, x, T{0}, x);
0066 }
0067 //!@}
0068 //---------------------------------------------------------------------------//
0069 // Host-only declarations
0070 // (double and float (and some int) for N=3 are instantiated in MatrixUtils.cc)
0071 //---------------------------------------------------------------------------//
0072 
0073 // Calculate the determinant of a matrix
0074 template<class T>
0075 T determinant(SquareMatrix<T, 3> const& mat);
0076 
0077 // Calculate the trace of a matrix
0078 template<class T>
0079 T trace(SquareMatrix<T, 3> const& mat);
0080 
0081 // Perform a matrix-matrix multiply
0082 template<class T, size_type N>
0083 SquareMatrix<T, N>
0084 gemm(SquareMatrix<T, N> const& a, SquareMatrix<T, N> const& b);
0085 
0086 // Perform a matrix-matrix multiply with A transposed
0087 template<class T, size_type N>
0088 SquareMatrix<T, N> gemm(matrix::TransposePolicy,
0089                         SquareMatrix<T, N> const& a,
0090                         SquareMatrix<T, N> const& b);
0091 
0092 // Normalize and orthogonalize a small, dense matrix
0093 template<class T, size_type N>
0094 void orthonormalize(SquareMatrix<T, N>* mat);
0095 
0096 // Create a C-ordered rotation matrix about an arbitrary axis
0097 SquareMatrixReal3 make_rotation(Real3 const& ax, Turn rev);
0098 
0099 // Create a C-ordered rotation matrix about a cartesian axis
0100 SquareMatrixReal3 make_rotation(Axis ax, Turn rev);
0101 
0102 // Apply a rotation to an existing C-ordered rotation matrix
0103 SquareMatrixReal3 make_rotation(Axis ax, Turn rev, SquareMatrixReal3 const&);
0104 
0105 // Construct a transposed matrix
0106 SquareMatrixReal3 make_transpose(SquareMatrixReal3 const&);
0107 
0108 //---------------------------------------------------------------------------//
0109 // INLINE DEFINITIONS
0110 //---------------------------------------------------------------------------//
0111 /*!
0112  * Naive generalized matrix-vector multiply.
0113  *
0114  * \f[
0115  * z \gets \alpha A x + \beta y
0116  * \f]
0117  *
0118  * This should be equivalent to BLAS' GEMV without transposition. All
0119  * matrix orderings are C-style: mat[i][j] is for row i, column j .
0120  *
0121  * \warning This implementation is limited and slow.
0122  */
0123 template<class T, size_type N>
0124 CELER_FUNCTION Array<T, N> gemv(T alpha,
0125                                 SquareMatrix<T, N> const& a,
0126                                 Array<T, N> const& x,
0127                                 T beta,
0128                                 Array<T, N> const& y)
0129 {
0130     Array<T, N> result;
0131     for (size_type i = 0; i != N; ++i)
0132     {
0133         result[i] = beta * y[i];
0134         for (size_type j = 0; j != N; ++j)
0135         {
0136             result[i] = fma(alpha, a[i][j] * x[j], result[i]);
0137         }
0138     }
0139     return result;
0140 }
0141 
0142 //---------------------------------------------------------------------------//
0143 /*!
0144  * Naive transposed generalized matrix-vector multiply.
0145  *
0146  * \f[
0147  * z \gets \alpha A^T x + \beta y
0148  * \f]
0149  *
0150  * This should be equivalent to BLAS' GEMV with the 't' option. All
0151  * matrix orderings are C-style: mat[i][j] is for row i, column j .
0152  *
0153  * \warning This implementation is limited and slow.
0154  */
0155 template<class T, size_type N>
0156 CELER_FUNCTION Array<T, N> gemv(matrix::TransposePolicy,
0157                                 T alpha,
0158                                 SquareMatrix<T, N> const& a,
0159                                 Array<T, N> const& x,
0160                                 T beta,
0161                                 Array<T, N> const& y)
0162 {
0163     Array<T, N> result;
0164     for (size_type i = 0; i != N; ++i)
0165     {
0166         result[i] = beta * y[i];
0167     }
0168     for (size_type j = 0; j != N; ++j)
0169     {
0170         for (size_type i = 0; i != N; ++i)
0171         {
0172             result[i] = fma(alpha, a[j][i] * x[j], result[i]);
0173         }
0174     }
0175     return result;
0176 }
0177 
0178 //---------------------------------------------------------------------------//
0179 }  // namespace celeritas