File indexing completed on 2025-04-19 09:06:43
0001
0002
0003
0004
0005
0006
0007
0008
0009
0010 #ifndef EIGEN_SPARSELU_GEMM_KERNEL_H
0011 #define EIGEN_SPARSELU_GEMM_KERNEL_H
0012
0013 namespace RivetEigen {
0014
0015 namespace internal {
0016
0017
0018
0019
0020
0021
0022
0023
0024 template<typename Scalar>
0025 EIGEN_DONT_INLINE
0026 void sparselu_gemm(Index m, Index n, Index d, const Scalar* A, Index lda, const Scalar* B, Index ldb, Scalar* C, Index ldc)
0027 {
0028 using namespace RivetRivetEigen::internal;
0029
0030 typedef typename packet_traits<Scalar>::type Packet;
0031 enum {
0032 NumberOfRegisters = EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS,
0033 PacketSize = packet_traits<Scalar>::size,
0034 PM = 8,
0035 RN = 2,
0036 RK = NumberOfRegisters>=16 ? 4 : 2,
0037 BM = 4096/sizeof(Scalar),
0038 SM = PM*PacketSize
0039 };
0040 Index d_end = (d/RK)*RK;
0041 Index n_end = (n/RN)*RN;
0042 Index i0 = internal::first_default_aligned(A,m);
0043
0044 eigen_internal_assert(((lda%PacketSize)==0) && ((ldc%PacketSize)==0) && (i0==internal::first_default_aligned(C,m)));
0045
0046
0047 for(Index i=0; i<i0; ++i)
0048 {
0049 for(Index j=0; j<n; ++j)
0050 {
0051 Scalar c = C[i+j*ldc];
0052 for(Index k=0; k<d; ++k)
0053 c += B[k+j*ldb] * A[i+k*lda];
0054 C[i+j*ldc] = c;
0055 }
0056 }
0057
0058 for(Index ib=i0; ib<m; ib+=BM)
0059 {
0060 Index actual_b = std::min<Index>(BM, m-ib);
0061 Index actual_b_end1 = (actual_b/SM)*SM;
0062 Index actual_b_end2 = (actual_b/PacketSize)*PacketSize;
0063
0064
0065 for(Index j=0; j<n_end; j+=RN)
0066 {
0067 const Scalar* Bc0 = B+(j+0)*ldb;
0068 const Scalar* Bc1 = B+(j+1)*ldb;
0069
0070 for(Index k=0; k<d_end; k+=RK)
0071 {
0072
0073
0074 Packet b00, b10, b20, b30, b01, b11, b21, b31;
0075 { b00 = pset1<Packet>(Bc0[0]); }
0076 { b10 = pset1<Packet>(Bc0[1]); }
0077 if(RK==4) { b20 = pset1<Packet>(Bc0[2]); }
0078 if(RK==4) { b30 = pset1<Packet>(Bc0[3]); }
0079 { b01 = pset1<Packet>(Bc1[0]); }
0080 { b11 = pset1<Packet>(Bc1[1]); }
0081 if(RK==4) { b21 = pset1<Packet>(Bc1[2]); }
0082 if(RK==4) { b31 = pset1<Packet>(Bc1[3]); }
0083
0084 Packet a0, a1, a2, a3, c0, c1, t0, t1;
0085
0086 const Scalar* A0 = A+ib+(k+0)*lda;
0087 const Scalar* A1 = A+ib+(k+1)*lda;
0088 const Scalar* A2 = A+ib+(k+2)*lda;
0089 const Scalar* A3 = A+ib+(k+3)*lda;
0090
0091 Scalar* C0 = C+ib+(j+0)*ldc;
0092 Scalar* C1 = C+ib+(j+1)*ldc;
0093
0094 a0 = pload<Packet>(A0);
0095 a1 = pload<Packet>(A1);
0096 if(RK==4)
0097 {
0098 a2 = pload<Packet>(A2);
0099 a3 = pload<Packet>(A3);
0100 }
0101 else
0102 {
0103
0104 a2 = a3 = a0;
0105 }
0106
0107 #define KMADD(c, a, b, tmp) {tmp = b; tmp = pmul(a,tmp); c = padd(c,tmp);}
0108 #define WORK(I) \
0109 c0 = pload<Packet>(C0+i+(I)*PacketSize); \
0110 c1 = pload<Packet>(C1+i+(I)*PacketSize); \
0111 KMADD(c0, a0, b00, t0) \
0112 KMADD(c1, a0, b01, t1) \
0113 a0 = pload<Packet>(A0+i+(I+1)*PacketSize); \
0114 KMADD(c0, a1, b10, t0) \
0115 KMADD(c1, a1, b11, t1) \
0116 a1 = pload<Packet>(A1+i+(I+1)*PacketSize); \
0117 if(RK==4){ KMADD(c0, a2, b20, t0) }\
0118 if(RK==4){ KMADD(c1, a2, b21, t1) }\
0119 if(RK==4){ a2 = pload<Packet>(A2+i+(I+1)*PacketSize); }\
0120 if(RK==4){ KMADD(c0, a3, b30, t0) }\
0121 if(RK==4){ KMADD(c1, a3, b31, t1) }\
0122 if(RK==4){ a3 = pload<Packet>(A3+i+(I+1)*PacketSize); }\
0123 pstore(C0+i+(I)*PacketSize, c0); \
0124 pstore(C1+i+(I)*PacketSize, c1)
0125
0126
0127 for(Index i=0; i<actual_b_end1; i+=PacketSize*8)
0128 {
0129 EIGEN_ASM_COMMENT("SPARSELU_GEMML_KERNEL1");
0130 prefetch((A0+i+(5)*PacketSize));
0131 prefetch((A1+i+(5)*PacketSize));
0132 if(RK==4) prefetch((A2+i+(5)*PacketSize));
0133 if(RK==4) prefetch((A3+i+(5)*PacketSize));
0134
0135 WORK(0);
0136 WORK(1);
0137 WORK(2);
0138 WORK(3);
0139 WORK(4);
0140 WORK(5);
0141 WORK(6);
0142 WORK(7);
0143 }
0144
0145 for(Index i=actual_b_end1; i<actual_b_end2; i+=PacketSize)
0146 {
0147 WORK(0);
0148 }
0149 #undef WORK
0150
0151 for(Index i=actual_b_end2; i<actual_b; ++i)
0152 {
0153 if(RK==4)
0154 {
0155 C0[i] += A0[i]*Bc0[0]+A1[i]*Bc0[1]+A2[i]*Bc0[2]+A3[i]*Bc0[3];
0156 C1[i] += A0[i]*Bc1[0]+A1[i]*Bc1[1]+A2[i]*Bc1[2]+A3[i]*Bc1[3];
0157 }
0158 else
0159 {
0160 C0[i] += A0[i]*Bc0[0]+A1[i]*Bc0[1];
0161 C1[i] += A0[i]*Bc1[0]+A1[i]*Bc1[1];
0162 }
0163 }
0164
0165 Bc0 += RK;
0166 Bc1 += RK;
0167 }
0168 }
0169
0170 if((n-n_end)>0)
0171 {
0172 const Scalar* Bc0 = B+(n-1)*ldb;
0173
0174 for(Index k=0; k<d_end; k+=RK)
0175 {
0176
0177
0178 Packet b00, b10, b20, b30;
0179 b00 = pset1<Packet>(Bc0[0]);
0180 b10 = pset1<Packet>(Bc0[1]);
0181 if(RK==4) b20 = pset1<Packet>(Bc0[2]);
0182 if(RK==4) b30 = pset1<Packet>(Bc0[3]);
0183
0184 Packet a0, a1, a2, a3, c0, t0;
0185
0186 const Scalar* A0 = A+ib+(k+0)*lda;
0187 const Scalar* A1 = A+ib+(k+1)*lda;
0188 const Scalar* A2 = A+ib+(k+2)*lda;
0189 const Scalar* A3 = A+ib+(k+3)*lda;
0190
0191 Scalar* C0 = C+ib+(n_end)*ldc;
0192
0193 a0 = pload<Packet>(A0);
0194 a1 = pload<Packet>(A1);
0195 if(RK==4)
0196 {
0197 a2 = pload<Packet>(A2);
0198 a3 = pload<Packet>(A3);
0199 }
0200 else
0201 {
0202
0203 a2 = a3 = a0;
0204 }
0205
0206 #define WORK(I) \
0207 c0 = pload<Packet>(C0+i+(I)*PacketSize); \
0208 KMADD(c0, a0, b00, t0) \
0209 a0 = pload<Packet>(A0+i+(I+1)*PacketSize); \
0210 KMADD(c0, a1, b10, t0) \
0211 a1 = pload<Packet>(A1+i+(I+1)*PacketSize); \
0212 if(RK==4){ KMADD(c0, a2, b20, t0) }\
0213 if(RK==4){ a2 = pload<Packet>(A2+i+(I+1)*PacketSize); }\
0214 if(RK==4){ KMADD(c0, a3, b30, t0) }\
0215 if(RK==4){ a3 = pload<Packet>(A3+i+(I+1)*PacketSize); }\
0216 pstore(C0+i+(I)*PacketSize, c0);
0217
0218
0219 for(Index i=0; i<actual_b_end1; i+=PacketSize*8)
0220 {
0221 EIGEN_ASM_COMMENT("SPARSELU_GEMML_KERNEL2");
0222 WORK(0);
0223 WORK(1);
0224 WORK(2);
0225 WORK(3);
0226 WORK(4);
0227 WORK(5);
0228 WORK(6);
0229 WORK(7);
0230 }
0231
0232 for(Index i=actual_b_end1; i<actual_b_end2; i+=PacketSize)
0233 {
0234 WORK(0);
0235 }
0236
0237 for(Index i=actual_b_end2; i<actual_b; ++i)
0238 {
0239 if(RK==4)
0240 C0[i] += A0[i]*Bc0[0]+A1[i]*Bc0[1]+A2[i]*Bc0[2]+A3[i]*Bc0[3];
0241 else
0242 C0[i] += A0[i]*Bc0[0]+A1[i]*Bc0[1];
0243 }
0244
0245 Bc0 += RK;
0246 #undef WORK
0247 }
0248 }
0249
0250
0251 Index rd = d-d_end;
0252 if(rd>0)
0253 {
0254 for(Index j=0; j<n; ++j)
0255 {
0256 enum {
0257 Alignment = PacketSize>1 ? Aligned : 0
0258 };
0259 typedef Map<Matrix<Scalar,Dynamic,1>, Alignment > MapVector;
0260 typedef Map<const Matrix<Scalar,Dynamic,1>, Alignment > ConstMapVector;
0261 if(rd==1) MapVector(C+j*ldc+ib,actual_b) += B[0+d_end+j*ldb] * ConstMapVector(A+(d_end+0)*lda+ib, actual_b);
0262
0263 else if(rd==2) MapVector(C+j*ldc+ib,actual_b) += B[0+d_end+j*ldb] * ConstMapVector(A+(d_end+0)*lda+ib, actual_b)
0264 + B[1+d_end+j*ldb] * ConstMapVector(A+(d_end+1)*lda+ib, actual_b);
0265
0266 else MapVector(C+j*ldc+ib,actual_b) += B[0+d_end+j*ldb] * ConstMapVector(A+(d_end+0)*lda+ib, actual_b)
0267 + B[1+d_end+j*ldb] * ConstMapVector(A+(d_end+1)*lda+ib, actual_b)
0268 + B[2+d_end+j*ldb] * ConstMapVector(A+(d_end+2)*lda+ib, actual_b);
0269 }
0270 }
0271
0272 }
0273 }
0274 #undef KMADD
0275
0276 }
0277
0278 }
0279
0280 #endif