Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-04-19 09:06:43

0001 // This file is part of Eigen, a lightweight C++ template library
0002 // for linear algebra.
0003 //
0004 // Copyright (C) 2012 Gael Guennebaud <gael.guennebaud@inria.fr>
0005 //
0006 // This Source Code Form is subject to the terms of the Mozilla
0007 // Public License v. 2.0. If a copy of the MPL was not distributed
0008 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
0009 
0010 #ifndef EIGEN_SPARSELU_GEMM_KERNEL_H
0011 #define EIGEN_SPARSELU_GEMM_KERNEL_H
0012 
0013 namespace RivetEigen {
0014 
0015 namespace internal {
0016 
0017 
0018 /** \internal
0019   * A general matrix-matrix product kernel optimized for the SparseLU factorization.
0020   *  - A, B, and C must be column major
0021   *  - lda and ldc must be multiples of the respective packet size
0022   *  - C must have the same alignment as A
0023   */
0024 template<typename Scalar>
0025 EIGEN_DONT_INLINE
0026 void sparselu_gemm(Index m, Index n, Index d, const Scalar* A, Index lda, const Scalar* B, Index ldb, Scalar* C, Index ldc)
0027 {
0028   using namespace RivetRivetEigen::internal;
0029   
0030   typedef typename packet_traits<Scalar>::type Packet;
0031   enum {
0032     NumberOfRegisters = EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS,
0033     PacketSize = packet_traits<Scalar>::size,
0034     PM = 8,                             // peeling in M
0035     RN = 2,                             // register blocking
0036     RK = NumberOfRegisters>=16 ? 4 : 2, // register blocking
0037     BM = 4096/sizeof(Scalar),           // number of rows of A-C per chunk
0038     SM = PM*PacketSize                  // step along M
0039   };
0040   Index d_end = (d/RK)*RK;    // number of columns of A (rows of B) suitable for full register blocking
0041   Index n_end = (n/RN)*RN;    // number of columns of B-C suitable for processing RN columns at once
0042   Index i0 = internal::first_default_aligned(A,m);
0043   
0044   eigen_internal_assert(((lda%PacketSize)==0) && ((ldc%PacketSize)==0) && (i0==internal::first_default_aligned(C,m)));
0045   
0046   // handle the non aligned rows of A and C without any optimization:
0047   for(Index i=0; i<i0; ++i)
0048   {
0049     for(Index j=0; j<n; ++j)
0050     {
0051       Scalar c = C[i+j*ldc];
0052       for(Index k=0; k<d; ++k)
0053         c += B[k+j*ldb] * A[i+k*lda];
0054       C[i+j*ldc] = c;
0055     }
0056   }
0057   // process the remaining rows per chunk of BM rows
0058   for(Index ib=i0; ib<m; ib+=BM)
0059   {
0060     Index actual_b = std::min<Index>(BM, m-ib);                 // actual number of rows
0061     Index actual_b_end1 = (actual_b/SM)*SM;                   // actual number of rows suitable for peeling
0062     Index actual_b_end2 = (actual_b/PacketSize)*PacketSize;   // actual number of rows suitable for vectorization
0063     
0064     // Let's process two columns of B-C at once
0065     for(Index j=0; j<n_end; j+=RN)
0066     {
0067       const Scalar* Bc0 = B+(j+0)*ldb;
0068       const Scalar* Bc1 = B+(j+1)*ldb;
0069       
0070       for(Index k=0; k<d_end; k+=RK)
0071       {
0072         
0073         // load and expand a RN x RK block of B
0074         Packet b00, b10, b20, b30, b01, b11, b21, b31;
0075                   { b00 = pset1<Packet>(Bc0[0]); }
0076                   { b10 = pset1<Packet>(Bc0[1]); }
0077         if(RK==4) { b20 = pset1<Packet>(Bc0[2]); }
0078         if(RK==4) { b30 = pset1<Packet>(Bc0[3]); }
0079                   { b01 = pset1<Packet>(Bc1[0]); }
0080                   { b11 = pset1<Packet>(Bc1[1]); }
0081         if(RK==4) { b21 = pset1<Packet>(Bc1[2]); }
0082         if(RK==4) { b31 = pset1<Packet>(Bc1[3]); }
0083         
0084         Packet a0, a1, a2, a3, c0, c1, t0, t1;
0085         
0086         const Scalar* A0 = A+ib+(k+0)*lda;
0087         const Scalar* A1 = A+ib+(k+1)*lda;
0088         const Scalar* A2 = A+ib+(k+2)*lda;
0089         const Scalar* A3 = A+ib+(k+3)*lda;
0090         
0091         Scalar* C0 = C+ib+(j+0)*ldc;
0092         Scalar* C1 = C+ib+(j+1)*ldc;
0093         
0094                   a0 = pload<Packet>(A0);
0095                   a1 = pload<Packet>(A1);
0096         if(RK==4)
0097         {
0098           a2 = pload<Packet>(A2);
0099           a3 = pload<Packet>(A3);
0100         }
0101         else
0102         {
0103           // workaround "may be used uninitialized in this function" warning
0104           a2 = a3 = a0;
0105         }
0106         
0107 #define KMADD(c, a, b, tmp) {tmp = b; tmp = pmul(a,tmp); c = padd(c,tmp);}
0108 #define WORK(I)  \
0109                      c0 = pload<Packet>(C0+i+(I)*PacketSize);    \
0110                      c1 = pload<Packet>(C1+i+(I)*PacketSize);    \
0111                      KMADD(c0, a0, b00, t0)                      \
0112                      KMADD(c1, a0, b01, t1)                      \
0113                      a0 = pload<Packet>(A0+i+(I+1)*PacketSize);  \
0114                      KMADD(c0, a1, b10, t0)                      \
0115                      KMADD(c1, a1, b11, t1)                      \
0116                      a1 = pload<Packet>(A1+i+(I+1)*PacketSize);  \
0117           if(RK==4){ KMADD(c0, a2, b20, t0)                     }\
0118           if(RK==4){ KMADD(c1, a2, b21, t1)                     }\
0119           if(RK==4){ a2 = pload<Packet>(A2+i+(I+1)*PacketSize); }\
0120           if(RK==4){ KMADD(c0, a3, b30, t0)                     }\
0121           if(RK==4){ KMADD(c1, a3, b31, t1)                     }\
0122           if(RK==4){ a3 = pload<Packet>(A3+i+(I+1)*PacketSize); }\
0123                      pstore(C0+i+(I)*PacketSize, c0);            \
0124                      pstore(C1+i+(I)*PacketSize, c1)
0125         
0126         // process rows of A' - C' with aggressive vectorization and peeling 
0127         for(Index i=0; i<actual_b_end1; i+=PacketSize*8)
0128         {
0129           EIGEN_ASM_COMMENT("SPARSELU_GEMML_KERNEL1");
0130                     prefetch((A0+i+(5)*PacketSize));
0131                     prefetch((A1+i+(5)*PacketSize));
0132           if(RK==4) prefetch((A2+i+(5)*PacketSize));
0133           if(RK==4) prefetch((A3+i+(5)*PacketSize));
0134 
0135           WORK(0);
0136           WORK(1);
0137           WORK(2);
0138           WORK(3);
0139           WORK(4);
0140           WORK(5);
0141           WORK(6);
0142           WORK(7);
0143         }
0144         // process the remaining rows with vectorization only
0145         for(Index i=actual_b_end1; i<actual_b_end2; i+=PacketSize)
0146         {
0147           WORK(0);
0148         }
0149 #undef WORK
0150         // process the remaining rows without vectorization
0151         for(Index i=actual_b_end2; i<actual_b; ++i)
0152         {
0153           if(RK==4)
0154           {
0155             C0[i] += A0[i]*Bc0[0]+A1[i]*Bc0[1]+A2[i]*Bc0[2]+A3[i]*Bc0[3];
0156             C1[i] += A0[i]*Bc1[0]+A1[i]*Bc1[1]+A2[i]*Bc1[2]+A3[i]*Bc1[3];
0157           }
0158           else
0159           {
0160             C0[i] += A0[i]*Bc0[0]+A1[i]*Bc0[1];
0161             C1[i] += A0[i]*Bc1[0]+A1[i]*Bc1[1];
0162           }
0163         }
0164         
0165         Bc0 += RK;
0166         Bc1 += RK;
0167       } // peeled loop on k
0168     } // peeled loop on the columns j
0169     // process the last column (we now perform a matrix-vector product)
0170     if((n-n_end)>0)
0171     {
0172       const Scalar* Bc0 = B+(n-1)*ldb;
0173       
0174       for(Index k=0; k<d_end; k+=RK)
0175       {
0176         
0177         // load and expand a 1 x RK block of B
0178         Packet b00, b10, b20, b30;
0179                   b00 = pset1<Packet>(Bc0[0]);
0180                   b10 = pset1<Packet>(Bc0[1]);
0181         if(RK==4) b20 = pset1<Packet>(Bc0[2]);
0182         if(RK==4) b30 = pset1<Packet>(Bc0[3]);
0183         
0184         Packet a0, a1, a2, a3, c0, t0/*, t1*/;
0185         
0186         const Scalar* A0 = A+ib+(k+0)*lda;
0187         const Scalar* A1 = A+ib+(k+1)*lda;
0188         const Scalar* A2 = A+ib+(k+2)*lda;
0189         const Scalar* A3 = A+ib+(k+3)*lda;
0190         
0191         Scalar* C0 = C+ib+(n_end)*ldc;
0192         
0193                   a0 = pload<Packet>(A0);
0194                   a1 = pload<Packet>(A1);
0195         if(RK==4)
0196         {
0197           a2 = pload<Packet>(A2);
0198           a3 = pload<Packet>(A3);
0199         }
0200         else
0201         {
0202           // workaround "may be used uninitialized in this function" warning
0203           a2 = a3 = a0;
0204         }
0205         
0206 #define WORK(I) \
0207                    c0 = pload<Packet>(C0+i+(I)*PacketSize);     \
0208                    KMADD(c0, a0, b00, t0)                       \
0209                    a0 = pload<Packet>(A0+i+(I+1)*PacketSize);   \
0210                    KMADD(c0, a1, b10, t0)                       \
0211                    a1 = pload<Packet>(A1+i+(I+1)*PacketSize);   \
0212         if(RK==4){ KMADD(c0, a2, b20, t0)                      }\
0213         if(RK==4){ a2 = pload<Packet>(A2+i+(I+1)*PacketSize);  }\
0214         if(RK==4){ KMADD(c0, a3, b30, t0)                      }\
0215         if(RK==4){ a3 = pload<Packet>(A3+i+(I+1)*PacketSize);  }\
0216                    pstore(C0+i+(I)*PacketSize, c0);
0217         
0218         // aggressive vectorization and peeling
0219         for(Index i=0; i<actual_b_end1; i+=PacketSize*8)
0220         {
0221           EIGEN_ASM_COMMENT("SPARSELU_GEMML_KERNEL2");
0222           WORK(0);
0223           WORK(1);
0224           WORK(2);
0225           WORK(3);
0226           WORK(4);
0227           WORK(5);
0228           WORK(6);
0229           WORK(7);
0230         }
0231         // vectorization only
0232         for(Index i=actual_b_end1; i<actual_b_end2; i+=PacketSize)
0233         {
0234           WORK(0);
0235         }
0236         // remaining scalars
0237         for(Index i=actual_b_end2; i<actual_b; ++i)
0238         {
0239           if(RK==4) 
0240             C0[i] += A0[i]*Bc0[0]+A1[i]*Bc0[1]+A2[i]*Bc0[2]+A3[i]*Bc0[3];
0241           else
0242             C0[i] += A0[i]*Bc0[0]+A1[i]*Bc0[1];
0243         }
0244         
0245         Bc0 += RK;
0246 #undef WORK
0247       }
0248     }
0249     
0250     // process the last columns of A, corresponding to the last rows of B
0251     Index rd = d-d_end;
0252     if(rd>0)
0253     {
0254       for(Index j=0; j<n; ++j)
0255       {
0256         enum {
0257           Alignment = PacketSize>1 ? Aligned : 0
0258         };
0259         typedef Map<Matrix<Scalar,Dynamic,1>, Alignment > MapVector;
0260         typedef Map<const Matrix<Scalar,Dynamic,1>, Alignment > ConstMapVector;
0261         if(rd==1)       MapVector(C+j*ldc+ib,actual_b) += B[0+d_end+j*ldb] * ConstMapVector(A+(d_end+0)*lda+ib, actual_b);
0262         
0263         else if(rd==2)  MapVector(C+j*ldc+ib,actual_b) += B[0+d_end+j*ldb] * ConstMapVector(A+(d_end+0)*lda+ib, actual_b)
0264                                                         + B[1+d_end+j*ldb] * ConstMapVector(A+(d_end+1)*lda+ib, actual_b);
0265         
0266         else            MapVector(C+j*ldc+ib,actual_b) += B[0+d_end+j*ldb] * ConstMapVector(A+(d_end+0)*lda+ib, actual_b)
0267                                                         + B[1+d_end+j*ldb] * ConstMapVector(A+(d_end+1)*lda+ib, actual_b)
0268                                                         + B[2+d_end+j*ldb] * ConstMapVector(A+(d_end+2)*lda+ib, actual_b);
0269       }
0270     }
0271   
0272   } // blocking on the rows of A and C
0273 }
0274 #undef KMADD
0275 
0276 } // namespace internal
0277 
0278 } // namespace RivetEigen
0279 
0280 #endif // EIGEN_SPARSELU_GEMM_KERNEL_H