Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-01-19 09:51:36

0001 // This file is part of Eigen, a lightweight C++ template library
0002 // for linear algebra.
0003 //
0004 // Copyright (C) 2014 Benoit Steiner (benoit.steiner.goog@gmail.com)
0005 //
0006 // This Source Code Form is subject to the terms of the Mozilla
0007 // Public License v. 2.0. If a copy of the MPL was not distributed
0008 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
0009 
0010 #ifndef EIGEN_COMPLEX_AVX_H
0011 #define EIGEN_COMPLEX_AVX_H
0012 
0013 namespace Eigen {
0014 
0015 namespace internal {
0016 
0017 //---------- float ----------
0018 struct Packet4cf
0019 {
0020   EIGEN_STRONG_INLINE Packet4cf() {}
0021   EIGEN_STRONG_INLINE explicit Packet4cf(const __m256& a) : v(a) {}
0022   __m256  v;
0023 };
0024 
0025 #ifndef EIGEN_VECTORIZE_AVX512
0026 template<> struct packet_traits<std::complex<float> >  : default_packet_traits
0027 {
0028   typedef Packet4cf type;
0029   typedef Packet2cf half;
0030   enum {
0031     Vectorizable = 1,
0032     AlignedOnScalar = 1,
0033     size = 4,
0034     HasHalfPacket = 1,
0035 
0036     HasAdd    = 1,
0037     HasSub    = 1,
0038     HasMul    = 1,
0039     HasDiv    = 1,
0040     HasNegate = 1,
0041     HasSqrt   = 1,
0042     HasAbs    = 0,
0043     HasAbs2   = 0,
0044     HasMin    = 0,
0045     HasMax    = 0,
0046     HasSetLinear = 0
0047   };
0048 };
0049 #endif
0050 
0051 template<> struct unpacket_traits<Packet4cf> {
0052   typedef std::complex<float> type;
0053   typedef Packet2cf half;
0054   typedef Packet8f as_real;
0055   enum {
0056     size=4,
0057     alignment=Aligned32,
0058     vectorizable=true,
0059     masked_load_available=false,
0060     masked_store_available=false
0061   };
0062 };
0063 
0064 template<> EIGEN_STRONG_INLINE Packet4cf padd<Packet4cf>(const Packet4cf& a, const Packet4cf& b) { return Packet4cf(_mm256_add_ps(a.v,b.v)); }
0065 template<> EIGEN_STRONG_INLINE Packet4cf psub<Packet4cf>(const Packet4cf& a, const Packet4cf& b) { return Packet4cf(_mm256_sub_ps(a.v,b.v)); }
0066 template<> EIGEN_STRONG_INLINE Packet4cf pnegate(const Packet4cf& a)
0067 {
0068   return Packet4cf(pnegate(a.v));
0069 }
0070 template<> EIGEN_STRONG_INLINE Packet4cf pconj(const Packet4cf& a)
0071 {
0072   const __m256 mask = _mm256_castsi256_ps(_mm256_setr_epi32(0x00000000,0x80000000,0x00000000,0x80000000,0x00000000,0x80000000,0x00000000,0x80000000));
0073   return Packet4cf(_mm256_xor_ps(a.v,mask));
0074 }
0075 
0076 template<> EIGEN_STRONG_INLINE Packet4cf pmul<Packet4cf>(const Packet4cf& a, const Packet4cf& b)
0077 {
0078   __m256 tmp1 = _mm256_mul_ps(_mm256_moveldup_ps(a.v), b.v);
0079   __m256 tmp2 = _mm256_mul_ps(_mm256_movehdup_ps(a.v), _mm256_permute_ps(b.v, _MM_SHUFFLE(2,3,0,1)));
0080   __m256 result = _mm256_addsub_ps(tmp1, tmp2);
0081   return Packet4cf(result);
0082 }
0083 
0084 template <>
0085 EIGEN_STRONG_INLINE Packet4cf pcmp_eq(const Packet4cf& a, const Packet4cf& b) {
0086   __m256 eq = _mm256_cmp_ps(a.v, b.v, _CMP_EQ_OQ);
0087   return Packet4cf(_mm256_and_ps(eq, _mm256_permute_ps(eq, 0xb1)));
0088 }
0089 
0090 template<> EIGEN_STRONG_INLINE Packet4cf ptrue<Packet4cf>(const Packet4cf& a) { return Packet4cf(ptrue(Packet8f(a.v))); }
0091 template<> EIGEN_STRONG_INLINE Packet4cf pand   <Packet4cf>(const Packet4cf& a, const Packet4cf& b) { return Packet4cf(_mm256_and_ps(a.v,b.v)); }
0092 template<> EIGEN_STRONG_INLINE Packet4cf por    <Packet4cf>(const Packet4cf& a, const Packet4cf& b) { return Packet4cf(_mm256_or_ps(a.v,b.v)); }
0093 template<> EIGEN_STRONG_INLINE Packet4cf pxor   <Packet4cf>(const Packet4cf& a, const Packet4cf& b) { return Packet4cf(_mm256_xor_ps(a.v,b.v)); }
0094 template<> EIGEN_STRONG_INLINE Packet4cf pandnot<Packet4cf>(const Packet4cf& a, const Packet4cf& b) { return Packet4cf(_mm256_andnot_ps(b.v,a.v)); }
0095 
0096 template<> EIGEN_STRONG_INLINE Packet4cf pload <Packet4cf>(const std::complex<float>* from) { EIGEN_DEBUG_ALIGNED_LOAD return Packet4cf(pload<Packet8f>(&numext::real_ref(*from))); }
0097 template<> EIGEN_STRONG_INLINE Packet4cf ploadu<Packet4cf>(const std::complex<float>* from) { EIGEN_DEBUG_UNALIGNED_LOAD return Packet4cf(ploadu<Packet8f>(&numext::real_ref(*from))); }
0098 
0099 
0100 template<> EIGEN_STRONG_INLINE Packet4cf pset1<Packet4cf>(const std::complex<float>& from)
0101 {
0102   return Packet4cf(_mm256_castpd_ps(_mm256_broadcast_sd((const double*)(const void*)&from)));
0103 }
0104 
0105 template<> EIGEN_STRONG_INLINE Packet4cf ploaddup<Packet4cf>(const std::complex<float>* from)
0106 {
0107   // FIXME The following might be optimized using _mm256_movedup_pd
0108   Packet2cf a = ploaddup<Packet2cf>(from);
0109   Packet2cf b = ploaddup<Packet2cf>(from+1);
0110   return  Packet4cf(_mm256_insertf128_ps(_mm256_castps128_ps256(a.v), b.v, 1));
0111 }
0112 
0113 template<> EIGEN_STRONG_INLINE void pstore <std::complex<float> >(std::complex<float>* to, const Packet4cf& from) { EIGEN_DEBUG_ALIGNED_STORE pstore(&numext::real_ref(*to), from.v); }
0114 template<> EIGEN_STRONG_INLINE void pstoreu<std::complex<float> >(std::complex<float>* to, const Packet4cf& from) { EIGEN_DEBUG_UNALIGNED_STORE pstoreu(&numext::real_ref(*to), from.v); }
0115 
0116 template<> EIGEN_DEVICE_FUNC inline Packet4cf pgather<std::complex<float>, Packet4cf>(const std::complex<float>* from, Index stride)
0117 {
0118   return Packet4cf(_mm256_set_ps(std::imag(from[3*stride]), std::real(from[3*stride]),
0119                                  std::imag(from[2*stride]), std::real(from[2*stride]),
0120                                  std::imag(from[1*stride]), std::real(from[1*stride]),
0121                                  std::imag(from[0*stride]), std::real(from[0*stride])));
0122 }
0123 
0124 template<> EIGEN_DEVICE_FUNC inline void pscatter<std::complex<float>, Packet4cf>(std::complex<float>* to, const Packet4cf& from, Index stride)
0125 {
0126   __m128 low = _mm256_extractf128_ps(from.v, 0);
0127   to[stride*0] = std::complex<float>(_mm_cvtss_f32(_mm_shuffle_ps(low, low, 0)),
0128                                      _mm_cvtss_f32(_mm_shuffle_ps(low, low, 1)));
0129   to[stride*1] = std::complex<float>(_mm_cvtss_f32(_mm_shuffle_ps(low, low, 2)),
0130                                      _mm_cvtss_f32(_mm_shuffle_ps(low, low, 3)));
0131 
0132   __m128 high = _mm256_extractf128_ps(from.v, 1);
0133   to[stride*2] = std::complex<float>(_mm_cvtss_f32(_mm_shuffle_ps(high, high, 0)),
0134                                      _mm_cvtss_f32(_mm_shuffle_ps(high, high, 1)));
0135   to[stride*3] = std::complex<float>(_mm_cvtss_f32(_mm_shuffle_ps(high, high, 2)),
0136                                      _mm_cvtss_f32(_mm_shuffle_ps(high, high, 3)));
0137 
0138 }
0139 
0140 template<> EIGEN_STRONG_INLINE std::complex<float>  pfirst<Packet4cf>(const Packet4cf& a)
0141 {
0142   return pfirst(Packet2cf(_mm256_castps256_ps128(a.v)));
0143 }
0144 
0145 template<> EIGEN_STRONG_INLINE Packet4cf preverse(const Packet4cf& a) {
0146   __m128 low  = _mm256_extractf128_ps(a.v, 0);
0147   __m128 high = _mm256_extractf128_ps(a.v, 1);
0148   __m128d lowd  = _mm_castps_pd(low);
0149   __m128d highd = _mm_castps_pd(high);
0150   low  = _mm_castpd_ps(_mm_shuffle_pd(lowd,lowd,0x1));
0151   high = _mm_castpd_ps(_mm_shuffle_pd(highd,highd,0x1));
0152   __m256 result = _mm256_setzero_ps();
0153   result = _mm256_insertf128_ps(result, low, 1);
0154   result = _mm256_insertf128_ps(result, high, 0);
0155   return Packet4cf(result);
0156 }
0157 
0158 template<> EIGEN_STRONG_INLINE std::complex<float> predux<Packet4cf>(const Packet4cf& a)
0159 {
0160   return predux(padd(Packet2cf(_mm256_extractf128_ps(a.v,0)),
0161                      Packet2cf(_mm256_extractf128_ps(a.v,1))));
0162 }
0163 
0164 template<> EIGEN_STRONG_INLINE std::complex<float> predux_mul<Packet4cf>(const Packet4cf& a)
0165 {
0166   return predux_mul(pmul(Packet2cf(_mm256_extractf128_ps(a.v, 0)),
0167                          Packet2cf(_mm256_extractf128_ps(a.v, 1))));
0168 }
0169 
0170 EIGEN_MAKE_CONJ_HELPER_CPLX_REAL(Packet4cf,Packet8f)
0171 
0172 template<> EIGEN_STRONG_INLINE Packet4cf pdiv<Packet4cf>(const Packet4cf& a, const Packet4cf& b)
0173 {
0174   Packet4cf num = pmul(a, pconj(b));
0175   __m256 tmp = _mm256_mul_ps(b.v, b.v);
0176   __m256 tmp2    = _mm256_shuffle_ps(tmp,tmp,0xB1);
0177   __m256 denom = _mm256_add_ps(tmp, tmp2);
0178   return Packet4cf(_mm256_div_ps(num.v, denom));
0179 }
0180 
0181 template<> EIGEN_STRONG_INLINE Packet4cf pcplxflip<Packet4cf>(const Packet4cf& x)
0182 {
0183   return Packet4cf(_mm256_shuffle_ps(x.v, x.v, _MM_SHUFFLE(2, 3, 0 ,1)));
0184 }
0185 
0186 //---------- double ----------
0187 struct Packet2cd
0188 {
0189   EIGEN_STRONG_INLINE Packet2cd() {}
0190   EIGEN_STRONG_INLINE explicit Packet2cd(const __m256d& a) : v(a) {}
0191   __m256d  v;
0192 };
0193 
0194 #ifndef EIGEN_VECTORIZE_AVX512
0195 template<> struct packet_traits<std::complex<double> >  : default_packet_traits
0196 {
0197   typedef Packet2cd type;
0198   typedef Packet1cd half;
0199   enum {
0200     Vectorizable = 1,
0201     AlignedOnScalar = 0,
0202     size = 2,
0203     HasHalfPacket = 1,
0204 
0205     HasAdd    = 1,
0206     HasSub    = 1,
0207     HasMul    = 1,
0208     HasDiv    = 1,
0209     HasNegate = 1,
0210     HasSqrt   = 1,
0211     HasAbs    = 0,
0212     HasAbs2   = 0,
0213     HasMin    = 0,
0214     HasMax    = 0,
0215     HasSetLinear = 0
0216   };
0217 };
0218 #endif
0219 
0220 template<> struct unpacket_traits<Packet2cd> {
0221   typedef std::complex<double> type;
0222   typedef Packet1cd half;
0223   typedef Packet4d as_real;
0224   enum {
0225     size=2,
0226     alignment=Aligned32,
0227     vectorizable=true,
0228     masked_load_available=false,
0229     masked_store_available=false
0230   };
0231 };
0232 
0233 template<> EIGEN_STRONG_INLINE Packet2cd padd<Packet2cd>(const Packet2cd& a, const Packet2cd& b) { return Packet2cd(_mm256_add_pd(a.v,b.v)); }
0234 template<> EIGEN_STRONG_INLINE Packet2cd psub<Packet2cd>(const Packet2cd& a, const Packet2cd& b) { return Packet2cd(_mm256_sub_pd(a.v,b.v)); }
0235 template<> EIGEN_STRONG_INLINE Packet2cd pnegate(const Packet2cd& a) { return Packet2cd(pnegate(a.v)); }
0236 template<> EIGEN_STRONG_INLINE Packet2cd pconj(const Packet2cd& a)
0237 {
0238   const __m256d mask = _mm256_castsi256_pd(_mm256_set_epi32(0x80000000,0x0,0x0,0x0,0x80000000,0x0,0x0,0x0));
0239   return Packet2cd(_mm256_xor_pd(a.v,mask));
0240 }
0241 
0242 template<> EIGEN_STRONG_INLINE Packet2cd pmul<Packet2cd>(const Packet2cd& a, const Packet2cd& b)
0243 {
0244   __m256d tmp1 = _mm256_shuffle_pd(a.v,a.v,0x0);
0245   __m256d even = _mm256_mul_pd(tmp1, b.v);
0246   __m256d tmp2 = _mm256_shuffle_pd(a.v,a.v,0xF);
0247   __m256d tmp3 = _mm256_shuffle_pd(b.v,b.v,0x5);
0248   __m256d odd  = _mm256_mul_pd(tmp2, tmp3);
0249   return Packet2cd(_mm256_addsub_pd(even, odd));
0250 }
0251 
0252 template <>
0253 EIGEN_STRONG_INLINE Packet2cd pcmp_eq(const Packet2cd& a, const Packet2cd& b) {
0254   __m256d eq = _mm256_cmp_pd(a.v, b.v, _CMP_EQ_OQ);
0255   return Packet2cd(pand(eq, _mm256_permute_pd(eq, 0x5)));
0256 }
0257 
0258 template<> EIGEN_STRONG_INLINE Packet2cd ptrue<Packet2cd>(const Packet2cd& a) { return Packet2cd(ptrue(Packet4d(a.v))); }
0259 template<> EIGEN_STRONG_INLINE Packet2cd pand   <Packet2cd>(const Packet2cd& a, const Packet2cd& b) { return Packet2cd(_mm256_and_pd(a.v,b.v)); }
0260 template<> EIGEN_STRONG_INLINE Packet2cd por    <Packet2cd>(const Packet2cd& a, const Packet2cd& b) { return Packet2cd(_mm256_or_pd(a.v,b.v)); }
0261 template<> EIGEN_STRONG_INLINE Packet2cd pxor   <Packet2cd>(const Packet2cd& a, const Packet2cd& b) { return Packet2cd(_mm256_xor_pd(a.v,b.v)); }
0262 template<> EIGEN_STRONG_INLINE Packet2cd pandnot<Packet2cd>(const Packet2cd& a, const Packet2cd& b) { return Packet2cd(_mm256_andnot_pd(b.v,a.v)); }
0263 
0264 template<> EIGEN_STRONG_INLINE Packet2cd pload <Packet2cd>(const std::complex<double>* from)
0265 { EIGEN_DEBUG_ALIGNED_LOAD return Packet2cd(pload<Packet4d>((const double*)from)); }
0266 template<> EIGEN_STRONG_INLINE Packet2cd ploadu<Packet2cd>(const std::complex<double>* from)
0267 { EIGEN_DEBUG_UNALIGNED_LOAD return Packet2cd(ploadu<Packet4d>((const double*)from)); }
0268 
0269 template<> EIGEN_STRONG_INLINE Packet2cd pset1<Packet2cd>(const std::complex<double>& from)
0270 {
0271   // in case casting to a __m128d* is really not safe, then we can still fallback to this version: (much slower though)
0272 //   return Packet2cd(_mm256_loadu2_m128d((const double*)&from,(const double*)&from));
0273     return Packet2cd(_mm256_broadcast_pd((const __m128d*)(const void*)&from));
0274 }
0275 
0276 template<> EIGEN_STRONG_INLINE Packet2cd ploaddup<Packet2cd>(const std::complex<double>* from) { return pset1<Packet2cd>(*from); }
0277 
0278 template<> EIGEN_STRONG_INLINE void pstore <std::complex<double> >(std::complex<double> *   to, const Packet2cd& from) { EIGEN_DEBUG_ALIGNED_STORE pstore((double*)to, from.v); }
0279 template<> EIGEN_STRONG_INLINE void pstoreu<std::complex<double> >(std::complex<double> *   to, const Packet2cd& from) { EIGEN_DEBUG_UNALIGNED_STORE pstoreu((double*)to, from.v); }
0280 
0281 template<> EIGEN_DEVICE_FUNC inline Packet2cd pgather<std::complex<double>, Packet2cd>(const std::complex<double>* from, Index stride)
0282 {
0283   return Packet2cd(_mm256_set_pd(std::imag(from[1*stride]), std::real(from[1*stride]),
0284                  std::imag(from[0*stride]), std::real(from[0*stride])));
0285 }
0286 
0287 template<> EIGEN_DEVICE_FUNC inline void pscatter<std::complex<double>, Packet2cd>(std::complex<double>* to, const Packet2cd& from, Index stride)
0288 {
0289   __m128d low = _mm256_extractf128_pd(from.v, 0);
0290   to[stride*0] = std::complex<double>(_mm_cvtsd_f64(low), _mm_cvtsd_f64(_mm_shuffle_pd(low, low, 1)));
0291   __m128d high = _mm256_extractf128_pd(from.v, 1);
0292   to[stride*1] = std::complex<double>(_mm_cvtsd_f64(high), _mm_cvtsd_f64(_mm_shuffle_pd(high, high, 1)));
0293 }
0294 
0295 template<> EIGEN_STRONG_INLINE std::complex<double> pfirst<Packet2cd>(const Packet2cd& a)
0296 {
0297   __m128d low = _mm256_extractf128_pd(a.v, 0);
0298   EIGEN_ALIGN16 double res[2];
0299   _mm_store_pd(res, low);
0300   return std::complex<double>(res[0],res[1]);
0301 }
0302 
0303 template<> EIGEN_STRONG_INLINE Packet2cd preverse(const Packet2cd& a) {
0304   __m256d result = _mm256_permute2f128_pd(a.v, a.v, 1);
0305   return Packet2cd(result);
0306 }
0307 
0308 template<> EIGEN_STRONG_INLINE std::complex<double> predux<Packet2cd>(const Packet2cd& a)
0309 {
0310   return predux(padd(Packet1cd(_mm256_extractf128_pd(a.v,0)),
0311                      Packet1cd(_mm256_extractf128_pd(a.v,1))));
0312 }
0313 
0314 template<> EIGEN_STRONG_INLINE std::complex<double> predux_mul<Packet2cd>(const Packet2cd& a)
0315 {
0316   return predux(pmul(Packet1cd(_mm256_extractf128_pd(a.v,0)),
0317                      Packet1cd(_mm256_extractf128_pd(a.v,1))));
0318 }
0319 
0320 EIGEN_MAKE_CONJ_HELPER_CPLX_REAL(Packet2cd,Packet4d)
0321 
0322 template<> EIGEN_STRONG_INLINE Packet2cd pdiv<Packet2cd>(const Packet2cd& a, const Packet2cd& b)
0323 {
0324   Packet2cd num = pmul(a, pconj(b));
0325   __m256d tmp = _mm256_mul_pd(b.v, b.v);
0326   __m256d denom = _mm256_hadd_pd(tmp, tmp);
0327   return Packet2cd(_mm256_div_pd(num.v, denom));
0328 }
0329 
0330 template<> EIGEN_STRONG_INLINE Packet2cd pcplxflip<Packet2cd>(const Packet2cd& x)
0331 {
0332   return Packet2cd(_mm256_shuffle_pd(x.v, x.v, 0x5));
0333 }
0334 
0335 EIGEN_DEVICE_FUNC inline void
0336 ptranspose(PacketBlock<Packet4cf,4>& kernel) {
0337   __m256d P0 = _mm256_castps_pd(kernel.packet[0].v);
0338   __m256d P1 = _mm256_castps_pd(kernel.packet[1].v);
0339   __m256d P2 = _mm256_castps_pd(kernel.packet[2].v);
0340   __m256d P3 = _mm256_castps_pd(kernel.packet[3].v);
0341 
0342   __m256d T0 = _mm256_shuffle_pd(P0, P1, 15);
0343   __m256d T1 = _mm256_shuffle_pd(P0, P1, 0);
0344   __m256d T2 = _mm256_shuffle_pd(P2, P3, 15);
0345   __m256d T3 = _mm256_shuffle_pd(P2, P3, 0);
0346 
0347   kernel.packet[1].v = _mm256_castpd_ps(_mm256_permute2f128_pd(T0, T2, 32));
0348   kernel.packet[3].v = _mm256_castpd_ps(_mm256_permute2f128_pd(T0, T2, 49));
0349   kernel.packet[0].v = _mm256_castpd_ps(_mm256_permute2f128_pd(T1, T3, 32));
0350   kernel.packet[2].v = _mm256_castpd_ps(_mm256_permute2f128_pd(T1, T3, 49));
0351 }
0352 
0353 EIGEN_DEVICE_FUNC inline void
0354 ptranspose(PacketBlock<Packet2cd,2>& kernel) {
0355   __m256d tmp = _mm256_permute2f128_pd(kernel.packet[0].v, kernel.packet[1].v, 0+(2<<4));
0356   kernel.packet[1].v = _mm256_permute2f128_pd(kernel.packet[0].v, kernel.packet[1].v, 1+(3<<4));
0357  kernel.packet[0].v = tmp;
0358 }
0359 
0360 template<> EIGEN_STRONG_INLINE Packet2cd psqrt<Packet2cd>(const Packet2cd& a) {
0361   return psqrt_complex<Packet2cd>(a);
0362 }
0363 
0364 template<> EIGEN_STRONG_INLINE Packet4cf psqrt<Packet4cf>(const Packet4cf& a) {
0365   return psqrt_complex<Packet4cf>(a);
0366 }
0367 
0368 } // end namespace internal
0369 
0370 } // end namespace Eigen
0371 
0372 #endif // EIGEN_COMPLEX_AVX_H