File indexing completed on 2025-01-18 09:57:10
0001
0002
0003
0004
0005
0006
0007
0008
0009
0010 #ifndef EIGEN_SPECIAL_FUNCTIONS_H
0011 #define EIGEN_SPECIAL_FUNCTIONS_H
0012
0013 namespace Eigen {
0014 namespace internal {
0015
0016
0017
0018
0019
0020
0021
0022
0023
0024
0025
0026
0027
0028
0029
0030
0031
0032
0033
0034
0035
0036
0037
0038
0039
0040
0041
0042
0043
0044 template <typename Scalar>
0045 struct lgamma_impl {
0046 EIGEN_DEVICE_FUNC
0047 static EIGEN_STRONG_INLINE Scalar run(const Scalar) {
0048 EIGEN_STATIC_ASSERT((internal::is_same<Scalar, Scalar>::value == false),
0049 THIS_TYPE_IS_NOT_SUPPORTED);
0050 return Scalar(0);
0051 }
0052 };
0053
/** \internal Result type of lgamma: the scalar type itself. */
template <class Scalar>
struct lgamma_retval { typedef Scalar type; };
0058
0059 #if EIGEN_HAS_C99_MATH
0060
// Decide whether the reentrant ::lgamma_r / ::lgammaf_r can be used. They
// report the sign of Gamma(x) through an out-parameter instead of a global.
// On glibc this header relies on them only when these feature-test macros
// are defined:
//   * glibc >= 2.19: _DEFAULT_SOURCE, _BSD_SOURCE or _SVID_SOURCE;
//   * glibc <  2.19: _BSD_SOURCE or _SVID_SOURCE.
#if defined(__GLIBC__)
#  if (__GLIBC__ > 2) || (__GLIBC__ >= 2 && __GLIBC_MINOR__ >= 19)
#    if defined(_DEFAULT_SOURCE) || defined(_BSD_SOURCE) || defined(_SVID_SOURCE)
#      define EIGEN_HAS_LGAMMA_R
#    endif
#  else
#    if defined(_BSD_SOURCE) || defined(_SVID_SOURCE)
#      define EIGEN_HAS_LGAMMA_R
#    endif
#  endif
#endif
0071
0072 template <>
0073 struct lgamma_impl<float> {
0074 EIGEN_DEVICE_FUNC
0075 static EIGEN_STRONG_INLINE float run(float x) {
0076 #if !defined(EIGEN_GPU_COMPILE_PHASE) && defined (EIGEN_HAS_LGAMMA_R) && !defined(__APPLE__)
0077 int dummy;
0078 return ::lgammaf_r(x, &dummy);
0079 #elif defined(SYCL_DEVICE_ONLY)
0080 return cl::sycl::lgamma(x);
0081 #else
0082 return ::lgammaf(x);
0083 #endif
0084 }
0085 };
0086
0087 template <>
0088 struct lgamma_impl<double> {
0089 EIGEN_DEVICE_FUNC
0090 static EIGEN_STRONG_INLINE double run(double x) {
0091 #if !defined(EIGEN_GPU_COMPILE_PHASE) && defined(EIGEN_HAS_LGAMMA_R) && !defined(__APPLE__)
0092 int dummy;
0093 return ::lgamma_r(x, &dummy);
0094 #elif defined(SYCL_DEVICE_ONLY)
0095 return cl::sycl::lgamma(x);
0096 #else
0097 return ::lgamma(x);
0098 #endif
0099 }
0100 };
0101
0102 #undef EIGEN_HAS_LGAMMA_R
0103 #endif
0104
0105
0106
0107
0108
/** \internal Result type of digamma: the scalar type itself. */
template <class Scalar>
struct digamma_retval { typedef Scalar type; };
0113
0114
0115
0116
0117
0118
0119
0120
0121
0122
0123
0124
0125
0126
0127 template <typename Scalar>
0128 struct digamma_impl_maybe_poly {
0129 EIGEN_DEVICE_FUNC
0130 static EIGEN_STRONG_INLINE Scalar run(const Scalar) {
0131 EIGEN_STATIC_ASSERT((internal::is_same<Scalar, Scalar>::value == false),
0132 THIS_TYPE_IS_NOT_SUPPORTED);
0133 return Scalar(0);
0134 }
0135 };
0136
0137
0138 template <>
0139 struct digamma_impl_maybe_poly<float> {
0140 EIGEN_DEVICE_FUNC
0141 static EIGEN_STRONG_INLINE float run(const float s) {
0142 const float A[] = {
0143 -4.16666666666666666667E-3f,
0144 3.96825396825396825397E-3f,
0145 -8.33333333333333333333E-3f,
0146 8.33333333333333333333E-2f
0147 };
0148
0149 float z;
0150 if (s < 1.0e8f) {
0151 z = 1.0f / (s * s);
0152 return z * internal::ppolevl<float, 3>::run(z, A);
0153 } else return 0.0f;
0154 }
0155 };
0156
0157 template <>
0158 struct digamma_impl_maybe_poly<double> {
0159 EIGEN_DEVICE_FUNC
0160 static EIGEN_STRONG_INLINE double run(const double s) {
0161 const double A[] = {
0162 8.33333333333333333333E-2,
0163 -2.10927960927960927961E-2,
0164 7.57575757575757575758E-3,
0165 -4.16666666666666666667E-3,
0166 3.96825396825396825397E-3,
0167 -8.33333333333333333333E-3,
0168 8.33333333333333333333E-2
0169 };
0170
0171 double z;
0172 if (s < 1.0e17) {
0173 z = 1.0 / (s * s);
0174 return z * internal::ppolevl<double, 6>::run(z, A);
0175 }
0176 else return 0.0;
0177 }
0178 };
0179
/** \internal Digamma function psi(x) = d/dx log(Gamma(x)) for real x.
 *
 * Algorithm:
 *  1. x <= 0:
 *     - Non-positive integers are poles and yield NaN.
 *     - Otherwise the reflection formula psi(x) = psi(1 - x) - pi / tan(pi * x)
 *       is used. The pi / tan(pi * x) term is kept in nz, and x is replaced
 *       by 1 - x.
 *  2. The recurrence psi(x + 1) = psi(x) + 1/x raises the argument to
 *     s >= 10, accumulating the 1/x terms in w.
 *  3. psi(s) comes from the asymptotic expansion
 *     log(s) - 1/(2s) - (correction from digamma_impl_maybe_poly).
 *     w is then subtracted to undo step 2.
 */
template <typename Scalar>
struct digamma_impl {
  EIGEN_DEVICE_FUNC
  static Scalar run(Scalar x) {
    Scalar p, q, nz, s, w, y;
    bool negative = false;

    const Scalar nan = NumTraits<Scalar>::quiet_NaN();
    const Scalar m_pi = Scalar(EIGEN_PI);

    const Scalar zero = Scalar(0);
    const Scalar one = Scalar(1);
    const Scalar half = Scalar(0.5);
    nz = zero;

    if (x <= zero) {
      negative = true;
      q = x;
      p = numext::floor(q);
      // Non-positive integers are poles of psi.
      if (p == q) {
        return nan;
      }

      // Remove the zeros of tan(pi * x) by subtracting the nearest integer
      // from x. tan has period pi, so tan(pi * nz) == tan(pi * x).
      nz = q - p;
      if (nz != half) {
        if (nz > half) {
          p += one;
          nz = q - p;
        }
        nz = m_pi / numext::tan(m_pi * nz);
      }
      else {
        // pi / tan(pi / 2) is 0; avoid evaluating tan at its pole.
        nz = zero;
      }
      x = one - x;
    }

    // Use the recurrence psi(x + 1) = psi(x) + 1/x until s >= 10.
    s = x;
    w = zero;
    while (s < Scalar(10)) {
      w += one / s;
      s += one;
    }

    // Series correction of the asymptotic expansion (0 for very large s).
    y = digamma_impl_maybe_poly<Scalar>::run(s);

    y = numext::log(s) - (half / s) - y - w;

    // Finish the reflection for x <= 0: psi(x) = psi(1 - x) - pi / tan(pi * x).
    return (negative) ? y - nz : y;
  }
};
0291
0292
0293
0294
0295
0296
0297
0298
0299
0300
0301
0302
0303 template <typename T>
0304 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T generic_fast_erf_float(const T& a_x) {
0305
0306
0307 const T plus_4 = pset1<T>(4.f);
0308 const T minus_4 = pset1<T>(-4.f);
0309 const T x = pmax(pmin(a_x, plus_4), minus_4);
0310
0311 const T alpha_1 = pset1<T>(-1.60960333262415e-02f);
0312 const T alpha_3 = pset1<T>(-2.95459980854025e-03f);
0313 const T alpha_5 = pset1<T>(-7.34990630326855e-04f);
0314 const T alpha_7 = pset1<T>(-5.69250639462346e-05f);
0315 const T alpha_9 = pset1<T>(-2.10102402082508e-06f);
0316 const T alpha_11 = pset1<T>(2.77068142495902e-08f);
0317 const T alpha_13 = pset1<T>(-2.72614225801306e-10f);
0318
0319
0320 const T beta_0 = pset1<T>(-1.42647390514189e-02f);
0321 const T beta_2 = pset1<T>(-7.37332916720468e-03f);
0322 const T beta_4 = pset1<T>(-1.68282697438203e-03f);
0323 const T beta_6 = pset1<T>(-2.13374055278905e-04f);
0324 const T beta_8 = pset1<T>(-1.45660718464996e-05f);
0325
0326
0327 const T x2 = pmul(x, x);
0328
0329
0330 T p = pmadd(x2, alpha_13, alpha_11);
0331 p = pmadd(x2, p, alpha_9);
0332 p = pmadd(x2, p, alpha_7);
0333 p = pmadd(x2, p, alpha_5);
0334 p = pmadd(x2, p, alpha_3);
0335 p = pmadd(x2, p, alpha_1);
0336 p = pmul(x, p);
0337
0338
0339 T q = pmadd(x2, beta_8, beta_6);
0340 q = pmadd(x2, q, beta_4);
0341 q = pmadd(x2, q, beta_2);
0342 q = pmadd(x2, q, beta_0);
0343
0344
0345 return pdiv(p, q);
0346 }
0347
0348 template <typename T>
0349 struct erf_impl {
0350 EIGEN_DEVICE_FUNC
0351 static EIGEN_STRONG_INLINE T run(const T& x) {
0352 return generic_fast_erf_float(x);
0353 }
0354 };
0355
/** \internal Result type of erf: the scalar type itself. */
template <class Scalar>
struct erf_retval { typedef Scalar type; };
0361 #if EIGEN_HAS_C99_MATH
0362 template <>
0363 struct erf_impl<float> {
0364 EIGEN_DEVICE_FUNC
0365 static EIGEN_STRONG_INLINE float run(float x) {
0366 #if defined(SYCL_DEVICE_ONLY)
0367 return cl::sycl::erf(x);
0368 #else
0369 return generic_fast_erf_float(x);
0370 #endif
0371 }
0372 };
0373
0374 template <>
0375 struct erf_impl<double> {
0376 EIGEN_DEVICE_FUNC
0377 static EIGEN_STRONG_INLINE double run(double x) {
0378 #if defined(SYCL_DEVICE_ONLY)
0379 return cl::sycl::erf(x);
0380 #else
0381 return ::erf(x);
0382 #endif
0383 }
0384 };
0385 #endif
0386
0387
0388
0389
0390
0391 template <typename Scalar>
0392 struct erfc_impl {
0393 EIGEN_DEVICE_FUNC
0394 static EIGEN_STRONG_INLINE Scalar run(const Scalar) {
0395 EIGEN_STATIC_ASSERT((internal::is_same<Scalar, Scalar>::value == false),
0396 THIS_TYPE_IS_NOT_SUPPORTED);
0397 return Scalar(0);
0398 }
0399 };
0400
/** \internal Result type of erfc: the scalar type itself. */
template <class Scalar>
struct erfc_retval { typedef Scalar type; };
0405
0406 #if EIGEN_HAS_C99_MATH
0407 template <>
0408 struct erfc_impl<float> {
0409 EIGEN_DEVICE_FUNC
0410 static EIGEN_STRONG_INLINE float run(const float x) {
0411 #if defined(SYCL_DEVICE_ONLY)
0412 return cl::sycl::erfc(x);
0413 #else
0414 return ::erfcf(x);
0415 #endif
0416 }
0417 };
0418
0419 template <>
0420 struct erfc_impl<double> {
0421 EIGEN_DEVICE_FUNC
0422 static EIGEN_STRONG_INLINE double run(const double x) {
0423 #if defined(SYCL_DEVICE_ONLY)
0424 return cl::sycl::erfc(x);
0425 #else
0426 return ::erfc(x);
0427 #endif
0428 }
0429 };
0430 #endif
0431
0432
0433
0434
0435
0436
0437
0438
0439
0440
0441
0442
0443
0444
0445
0446
0447
0448
0449
0450
0451
0452
0453
0454
0455
0456
0457
0458
0459
0460
0461
0462
0463
0464
0465
0466
0467
0468
0469
0470
0471
0472
0473
0474
0475
0476
0477
0478
0479
0480
0481
0482
0483
0484
0485
0486
0487
0488
0489
0490 template<typename T>
0491 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T flipsign(
0492 const T& should_flipsign, const T& x) {
0493 typedef typename unpacket_traits<T>::type Scalar;
0494 const T sign_mask = pset1<T>(Scalar(-0.0));
0495 T sign_bit = pand<T>(should_flipsign, sign_mask);
0496 return pxor<T>(sign_bit, x);
0497 }
0498
0499 template<>
0500 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double flipsign<double>(
0501 const double& should_flipsign, const double& x) {
0502 return should_flipsign == 0 ? x : -x;
0503 }
0504
0505 template<>
0506 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float flipsign<float>(
0507 const float& should_flipsign, const float& x) {
0508 return should_flipsign == 0 ? x : -x;
0509 }
0510
0511
0512
0513
0514
/** \internal Central-region inverse of the standard normal CDF.
 *
 * generic_ndtri uses this result for lanes where b > exp(-2).
 * With c = b - 0.5 it evaluates
 *   sqrt(2 * pi) * (c + c * c^2 * P0(c^2) / Q0(c^2)),
 * where P0 has 5 coefficients (degree 4) and Q0 has 9 (degree 8, leading 1).
 */
template <typename T, typename ScalarType>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T generic_ndtri_gt_exp_neg_two(const T& b) {
  const ScalarType p0[] = {
    ScalarType(-5.99633501014107895267e1),
    ScalarType(9.80010754185999661536e1),
    ScalarType(-5.66762857469070293439e1),
    ScalarType(1.39312609387279679503e1),
    ScalarType(-1.23916583867381258016e0)
  };
  const ScalarType q0[] = {
    ScalarType(1.0),
    ScalarType(1.95448858338141759834e0),
    ScalarType(4.67627912898881538453e0),
    ScalarType(8.63602421390890590575e1),
    ScalarType(-2.25462687854119370527e2),
    ScalarType(2.00260212380060660359e2),
    ScalarType(-8.20372256168333339912e1),
    ScalarType(1.59056225126211695515e1),
    ScalarType(-1.18331621121330003142e0)
  };
  // sqrt(2 * pi)
  const T sqrt2pi = pset1<T>(ScalarType(2.50662827463100050242e0));
  const T half = pset1<T>(ScalarType(0.5));
  T c, c2, ndtri_gt_exp_neg_two;

  c = psub(b, half);
  c2 = pmul(c, c);
  // c + c * c2 * P0(c2) / Q0(c2)
  ndtri_gt_exp_neg_two = pmadd(c, pmul(
      c2, pdiv(
          internal::ppolevl<T, 4>::run(c2, p0),
          internal::ppolevl<T, 8>::run(c2, q0))), c);
  return pmul(ndtri_gt_exp_neg_two, sqrt2pi);
}
0547
/** \internal Tail-region inverse of the standard normal CDF.
 *
 * generic_ndtri uses this result for lanes where b <= exp(-2). Steps:
 *  - x  = sqrt(-2 * log(b))
 *  - x0 = x - log(x) / x
 *  - x1 = z * P(z) / Q(z) with z = 1 / x. Coefficient set (p1, q1) is used
 *    while x < 8 (i.e. b > exp(-32)); (p2, q2) is used further out.
 *  - The result is x0 - x1, sign-flipped in the lanes selected by
 *    should_flipsign.
 */
template <typename T, typename ScalarType>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T generic_ndtri_lt_exp_neg_two(
    const T& b, const T& should_flipsign) {

  // Rational approximation coefficients used when x < 8.
  const ScalarType p1[] = {
    ScalarType(4.05544892305962419923e0),
    ScalarType(3.15251094599893866154e1),
    ScalarType(5.71628192246421288162e1),
    ScalarType(4.40805073893200834700e1),
    ScalarType(1.46849561928858024014e1),
    ScalarType(2.18663306850790267539e0),
    ScalarType(-1.40256079171354495875e-1),
    ScalarType(-3.50424626827848203418e-2),
    ScalarType(-8.57456785154685413611e-4)
  };
  const ScalarType q1[] = {
    ScalarType(1.0),
    ScalarType(1.57799883256466749731e1),
    ScalarType(4.53907635128879210584e1),
    ScalarType(4.13172038254672030440e1),
    ScalarType(1.50425385692907503408e1),
    ScalarType(2.50464946208309415979e0),
    ScalarType(-1.42182922854787788574e-1),
    ScalarType(-3.80806407691578277194e-2),
    ScalarType(-9.33259480895457427372e-4)
  };

  // Rational approximation coefficients used when x >= 8.
  const ScalarType p2[] = {
    ScalarType(3.23774891776946035970e0),
    ScalarType(6.91522889068984211695e0),
    ScalarType(3.93881025292474443415e0),
    ScalarType(1.33303460815807542389e0),
    ScalarType(2.01485389549179081538e-1),
    ScalarType(1.23716634817820021358e-2),
    ScalarType(3.01581553508235416007e-4),
    ScalarType(2.65806974686737550832e-6),
    ScalarType(6.23974539184983293730e-9)
  };
  const ScalarType q2[] = {
    ScalarType(1.0),
    ScalarType(6.02427039364742014255e0),
    ScalarType(3.67983563856160859403e0),
    ScalarType(1.37702099489081330271e0),
    ScalarType(2.16236993594496635890e-1),
    ScalarType(1.34204006088543189037e-2),
    ScalarType(3.28014464682127739104e-4),
    ScalarType(2.89247864745380683936e-6),
    ScalarType(6.79019408009981274425e-9)
  };
  const T eight = pset1<T>(ScalarType(8.0));
  const T one = pset1<T>(ScalarType(1));
  const T neg_two = pset1<T>(ScalarType(-2));
  T x, x0, x1, z;

  x = psqrt(pmul(neg_two, plog(b)));
  x0 = psub(x, pdiv(plog(x), x));
  z = pdiv(one, x);
  // Both coefficient sets are evaluated for every lane; pselect picks one.
  x1 = pmul(
      z, pselect(
          pcmp_lt(x, eight),
          pdiv(internal::ppolevl<T, 8>::run(z, p1),
               internal::ppolevl<T, 8>::run(z, q1)),
          pdiv(internal::ppolevl<T, 8>::run(z, p2),
               internal::ppolevl<T, 8>::run(z, q2))));
  return flipsign(should_flipsign, psub(x0, x1));
}
0618
0619 template <typename T, typename ScalarType>
0620 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
0621 T generic_ndtri(const T& a) {
0622 const T maxnum = pset1<T>(NumTraits<ScalarType>::infinity());
0623 const T neg_maxnum = pset1<T>(-NumTraits<ScalarType>::infinity());
0624
0625 const T zero = pset1<T>(ScalarType(0));
0626 const T one = pset1<T>(ScalarType(1));
0627
0628 const T exp_neg_two = pset1<T>(ScalarType(0.13533528323661269189));
0629 T b, ndtri, should_flipsign;
0630
0631 should_flipsign = pcmp_le(a, psub(one, exp_neg_two));
0632 b = pselect(should_flipsign, a, psub(one, a));
0633
0634 ndtri = pselect(
0635 pcmp_lt(exp_neg_two, b),
0636 generic_ndtri_gt_exp_neg_two<T, ScalarType>(b),
0637 generic_ndtri_lt_exp_neg_two<T, ScalarType>(b, should_flipsign));
0638
0639 return pselect(
0640 pcmp_le(a, zero), neg_maxnum,
0641 pselect(pcmp_le(one, a), maxnum, ndtri));
0642 }
0643
/** \internal Result type of ndtri: the scalar type itself. */
template <class Scalar>
struct ndtri_retval { typedef Scalar type; };
0648
0649 #if !EIGEN_HAS_C99_MATH
0650
0651 template <typename Scalar>
0652 struct ndtri_impl {
0653 EIGEN_DEVICE_FUNC
0654 static EIGEN_STRONG_INLINE Scalar run(const Scalar) {
0655 EIGEN_STATIC_ASSERT((internal::is_same<Scalar, Scalar>::value == false),
0656 THIS_TYPE_IS_NOT_SUPPORTED);
0657 return Scalar(0);
0658 }
0659 };
0660
0661 # else
0662
0663 template <typename Scalar>
0664 struct ndtri_impl {
0665 EIGEN_DEVICE_FUNC
0666 static EIGEN_STRONG_INLINE Scalar run(const Scalar x) {
0667 return generic_ndtri<Scalar, Scalar>(x);
0668 }
0669 };
0670
0671 #endif
0672
0673
0674
0675
0676
0677
/** \internal Result type of igammac: the scalar type itself. */
template <class Scalar>
struct igammac_retval { typedef Scalar type; };
0682
0683
0684 template <typename Scalar>
0685 struct cephes_helper {
0686 EIGEN_DEVICE_FUNC
0687 static EIGEN_STRONG_INLINE Scalar machep() { assert(false && "machep not supported for this type"); return 0.0; }
0688 EIGEN_DEVICE_FUNC
0689 static EIGEN_STRONG_INLINE Scalar big() { assert(false && "big not supported for this type"); return 0.0; }
0690 EIGEN_DEVICE_FUNC
0691 static EIGEN_STRONG_INLINE Scalar biginv() { assert(false && "biginv not supported for this type"); return 0.0; }
0692 };
0693
0694 template <>
0695 struct cephes_helper<float> {
0696 EIGEN_DEVICE_FUNC
0697 static EIGEN_STRONG_INLINE float machep() {
0698 return NumTraits<float>::epsilon() / 2;
0699 }
0700 EIGEN_DEVICE_FUNC
0701 static EIGEN_STRONG_INLINE float big() {
0702
0703 return 1.0f / (NumTraits<float>::epsilon() / 2);
0704 }
0705 EIGEN_DEVICE_FUNC
0706 static EIGEN_STRONG_INLINE float biginv() {
0707
0708 return machep();
0709 }
0710 };
0711
0712 template <>
0713 struct cephes_helper<double> {
0714 EIGEN_DEVICE_FUNC
0715 static EIGEN_STRONG_INLINE double machep() {
0716 return NumTraits<double>::epsilon() / 2;
0717 }
0718 EIGEN_DEVICE_FUNC
0719 static EIGEN_STRONG_INLINE double big() {
0720 return 1.0 / NumTraits<double>::epsilon();
0721 }
0722 EIGEN_DEVICE_FUNC
0723 static EIGEN_STRONG_INLINE double biginv() {
0724
0725 return NumTraits<double>::epsilon();
0726 }
0727 };
0728
/** \internal Selects what the incomplete-gamma kernels
 * (igammac_cf_impl, igamma_series_impl, igamma_generic_impl) return.
 */
enum IgammaComputationMode {
  VALUE,             // the function value
  DERIVATIVE,        // derivative with respect to the shape parameter a
  SAMPLE_DERIVATIVE  // quantity used by gamma_sample_der_alpha_impl
};
0730
0731 template <typename Scalar>
0732 EIGEN_DEVICE_FUNC
0733 static EIGEN_STRONG_INLINE Scalar main_igamma_term(Scalar a, Scalar x) {
0734
0735 Scalar logax = a * numext::log(x) - x - lgamma_impl<Scalar>::run(a);
0736 if (logax < -numext::log(NumTraits<Scalar>::highest()) ||
0737
0738 (numext::isnan)(logax)) {
0739 return Scalar(0);
0740 }
0741 return numext::exp(logax);
0742 }
0743
0744 template <typename Scalar, IgammaComputationMode mode>
0745 EIGEN_DEVICE_FUNC
0746 int igamma_num_iterations() {
0747
0748
0749 if (mode == VALUE) {
0750 return 2000;
0751 }
0752
0753 if (internal::is_same<Scalar, float>::value) {
0754 return 200;
0755 } else if (internal::is_same<Scalar, double>::value) {
0756 return 500;
0757 } else {
0758 return 2000;
0759 }
0760 }
0761
/** \internal Continued-fraction evaluation of the regularized upper incomplete
 * gamma function Q(a, x), plus its derivatives with respect to a.
 *
 * The convergents pk / qk are advanced by a three-term recurrence. Their
 * a-derivatives (dpk_da, dqk_da) come from differentiating the same
 * recurrence, so dans_da, the derivative of the fraction ans, is produced in
 * the same pass.
 *
 * Return value by mode, with ax = x^a e^{-x} / Gamma(a) from
 * main_igamma_term:
 *  - VALUE:             ans * ax
 *  - DERIVATIVE:        d(ans * ax) / da
 *  - SAMPLE_DERIVATIVE: -(dans_da + ans * dlogax_da) * x, i.e. the negated
 *                       a-derivative of ans * ax divided by ax / x.
 *
 * Callers: igammac_impl (VALUE) and igamma_generic_impl when x > 1 and x > a.
 */
template <typename Scalar, IgammaComputationMode mode>
struct igammac_cf_impl {
  EIGEN_DEVICE_FUNC
  static Scalar run(Scalar a, Scalar x) {
    const Scalar zero = 0;
    const Scalar one = 1;
    const Scalar two = 2;
    const Scalar machep = cephes_helper<Scalar>::machep();
    const Scalar big = cephes_helper<Scalar>::big();
    const Scalar biginv = cephes_helper<Scalar>::biginv();

    // Infinite x: return 0 without running the recurrence (Q(a, +inf) = 0).
    if ((numext::isinf)(x)) {
      return zero;
    }

    Scalar ax = main_igamma_term<Scalar>(a, x);
    // Independent of mode: if the prefactor underflowed to zero, the value is
    // zero in this neighborhood, and so are its derivatives.
    if (ax == zero) {
      return zero;
    }

    // Initial convergents of the continued fraction.
    Scalar y = one - a;
    Scalar z = x + y + one;
    Scalar c = zero;
    Scalar pkm2 = one;
    Scalar qkm2 = x;
    Scalar pkm1 = x + one;
    Scalar qkm1 = z * x;
    Scalar ans = pkm1 / qkm1;

    // a-derivatives of the initial convergents; qkm1 = (x + 2 - a) * x.
    Scalar dpkm2_da = zero;
    Scalar dqkm2_da = zero;
    Scalar dpkm1_da = zero;
    Scalar dqkm1_da = -x;
    Scalar dans_da = (dpkm1_da - ans * dqkm1_da) / qkm1;

    for (int i = 0; i < igamma_num_iterations<Scalar, mode>(); i++) {
      c += one;
      y += one;
      z += two;

      Scalar yc = y * c;
      Scalar pk = pkm1 * z - pkm2 * yc;
      Scalar qk = qkm1 * z - qkm2 * yc;

      // Differentiate the recurrence (note d(yc)/da = -c, d(z)/da = -1).
      Scalar dpk_da = dpkm1_da * z - pkm1 - dpkm2_da * yc + pkm2 * c;
      Scalar dqk_da = dqkm1_da * z - qkm1 - dqkm2_da * yc + qkm2 * c;

      if (qk != zero) {
        Scalar ans_prev = ans;
        ans = pk / qk;

        Scalar dans_da_prev = dans_da;
        dans_da = (dpk_da - ans * dqk_da) / qk;

        // VALUE: stop on a small relative change of ans.
        // Derivative modes: stop on a small absolute change of dans_da.
        if (mode == VALUE) {
          if (numext::abs(ans_prev - ans) <= machep * numext::abs(ans)) {
            break;
          }
        } else {
          if (numext::abs(dans_da - dans_da_prev) <= machep) {
            break;
          }
        }
      }

      pkm2 = pkm1;
      pkm1 = pk;
      qkm2 = qkm1;
      qkm1 = qk;

      dpkm2_da = dpkm1_da;
      dpkm1_da = dpk_da;
      dqkm2_da = dqkm1_da;
      dqkm1_da = dqk_da;

      // Rescale to avoid overflow. Everything shares one factor, so ans and
      // dans_da are unaffected.
      if (numext::abs(pk) > big) {
        pkm2 *= biginv;
        pkm1 *= biginv;
        qkm2 *= biginv;
        qkm1 *= biginv;

        dpkm2_da *= biginv;
        dpkm1_da *= biginv;
        dqkm2_da *= biginv;
        dqkm1_da *= biginv;
      }
    }

    // d/da log(x^a e^{-x} / Gamma(a)) = log(x) - digamma(a).
    Scalar dlogax_da = numext::log(x) - digamma_impl<Scalar>::run(a);
    Scalar dax_da = ax * dlogax_da;

    switch (mode) {
      case VALUE:
        return ans * ax;
      case DERIVATIVE:
        return ans * dax_da + dans_da * ax;
      case SAMPLE_DERIVATIVE:
      default:  // also gives every control path a return value
        return -(dans_da + ans * dlogax_da) * x;
    }
  }
};
0879
/** \internal Power-series evaluation of the regularized lower incomplete
 * gamma function P(a, x), plus its derivatives with respect to a.
 *
 * P(a, x) = ax * sum_{n >= 0} c_n, where:
 *  - ax  = x^a e^{-x} / Gamma(a + 1);
 *  - c_0 = 1 and c_n = c_{n-1} * x / (a + n).
 * The a-derivative of each term (dc_da) is propagated alongside and summed
 * into dans_da.
 *
 * Return value by mode:
 *  - VALUE:             ans * ax
 *  - DERIVATIVE:        d(ans * ax) / da
 *  - SAMPLE_DERIVATIVE: -(dans_da + ans * dlogax_da) * x / a, i.e. the
 *                       negated a-derivative of ans * ax divided by a * ax / x.
 */
template <typename Scalar, IgammaComputationMode mode>
struct igamma_series_impl {
  EIGEN_DEVICE_FUNC
  static Scalar run(Scalar a, Scalar x) {
    const Scalar zero = 0;
    const Scalar one = 1;
    const Scalar machep = cephes_helper<Scalar>::machep();

    Scalar ax = main_igamma_term<Scalar>(a, x);

    // Independent of mode: if the prefactor underflowed to zero, the value is
    // zero in this neighborhood, and so are its derivatives.
    if (ax == zero) {
      return zero;
    }

    // x^a e^{-x} / Gamma(a) / a == x^a e^{-x} / Gamma(a + 1).
    ax /= a;

    // Power series.
    Scalar r = a;
    Scalar c = one;
    Scalar ans = one;

    Scalar dc_da = zero;
    Scalar dans_da = zero;

    for (int i = 0; i < igamma_num_iterations<Scalar, mode>(); i++) {
      r += one;
      Scalar term = x / r;
      Scalar dterm_da = -x / (r * r);
      // Product rule for c_n = c_{n-1} * term.
      dc_da = term * dc_da + dterm_da * c;
      dans_da += dc_da;
      c *= term;
      ans += c;

      // VALUE: stop once the new term is negligible relative to the sum.
      // Derivative modes: same test on the derivative series.
      if (mode == VALUE) {
        if (c <= machep * ans) {
          break;
        }
      } else {
        if (numext::abs(dc_da) <= machep * numext::abs(dans_da)) {
          break;
        }
      }
    }

    // d/da log(x^a e^{-x} / Gamma(a + 1)) = log(x) - digamma(a + 1).
    Scalar dlogax_da = numext::log(x) - digamma_impl<Scalar>::run(a + one);
    Scalar dax_da = ax * dlogax_da;

    switch (mode) {
      case VALUE:
        return ans * ax;
      case DERIVATIVE:
        return ans * dax_da + dans_da * ax;
      case SAMPLE_DERIVATIVE:
      default:  // also gives every control path a return value
        return -(dans_da + ans * dlogax_da) * x / a;
    }
  }
};
0950
0951 #if !EIGEN_HAS_C99_MATH
0952
0953 template <typename Scalar>
0954 struct igammac_impl {
0955 EIGEN_DEVICE_FUNC
0956 static Scalar run(Scalar a, Scalar x) {
0957 EIGEN_STATIC_ASSERT((internal::is_same<Scalar, Scalar>::value == false),
0958 THIS_TYPE_IS_NOT_SUPPORTED);
0959 return Scalar(0);
0960 }
0961 };
0962
0963 #else
0964
0965 template <typename Scalar>
0966 struct igammac_impl {
0967 EIGEN_DEVICE_FUNC
0968 static Scalar run(Scalar a, Scalar x) {
0969
0970
0971
0972
0973
0974
0975
0976
0977
0978
0979
0980
0981
0982
0983
0984
0985
0986
0987
0988
0989
0990
0991
0992
0993
0994
0995
0996
0997
0998
0999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023 const Scalar zero = 0;
1024 const Scalar one = 1;
1025 const Scalar nan = NumTraits<Scalar>::quiet_NaN();
1026
1027 if ((x < zero) || (a <= zero)) {
1028
1029 return nan;
1030 }
1031
1032 if ((numext::isnan)(a) || (numext::isnan)(x)) {
1033 return nan;
1034 }
1035
1036 if ((x < one) || (x < a)) {
1037 return (one - igamma_series_impl<Scalar, VALUE>::run(a, x));
1038 }
1039
1040 return igammac_cf_impl<Scalar, VALUE>::run(a, x);
1041 }
1042 };
1043
1044 #endif
1045
1046
1047
1048
1049
1050 #if !EIGEN_HAS_C99_MATH
1051
1052 template <typename Scalar, IgammaComputationMode mode>
1053 struct igamma_generic_impl {
1054 EIGEN_DEVICE_FUNC
1055 static EIGEN_STRONG_INLINE Scalar run(Scalar a, Scalar x) {
1056 EIGEN_STATIC_ASSERT((internal::is_same<Scalar, Scalar>::value == false),
1057 THIS_TYPE_IS_NOT_SUPPORTED);
1058 return Scalar(0);
1059 }
1060 };
1061
1062 #else
1063
1064 template <typename Scalar, IgammaComputationMode mode>
1065 struct igamma_generic_impl {
1066 EIGEN_DEVICE_FUNC
1067 static Scalar run(Scalar a, Scalar x) {
1068
1069
1070
1071
1072
1073
1074
1075
1076 const Scalar zero = 0;
1077 const Scalar one = 1;
1078 const Scalar nan = NumTraits<Scalar>::quiet_NaN();
1079
1080 if (x == zero) return zero;
1081
1082 if ((x < zero) || (a <= zero)) {
1083 return nan;
1084 }
1085
1086 if ((numext::isnan)(a) || (numext::isnan)(x)) {
1087 return nan;
1088 }
1089
1090 if ((x > one) && (x > a)) {
1091 Scalar ret = igammac_cf_impl<Scalar, mode>::run(a, x);
1092 if (mode == VALUE) {
1093 return one - ret;
1094 } else {
1095 return -ret;
1096 }
1097 }
1098
1099 return igamma_series_impl<Scalar, mode>::run(a, x);
1100 }
1101 };
1102
1103 #endif
1104
/** \internal Result type of igamma: the scalar type itself. */
template <class Scalar>
struct igamma_retval { typedef Scalar type; };
1109
1110 template <typename Scalar>
1111 struct igamma_impl : igamma_generic_impl<Scalar, VALUE> {
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179 };
1180
1181 template <typename Scalar>
1182 struct igamma_der_a_retval : igamma_retval<Scalar> {};
1183
1184 template <typename Scalar>
1185 struct igamma_der_a_impl : igamma_generic_impl<Scalar, DERIVATIVE> {
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200 };
1201
1202 template <typename Scalar>
1203 struct gamma_sample_der_alpha_retval : igamma_retval<Scalar> {};
1204
1205 template <typename Scalar>
1206 struct gamma_sample_der_alpha_impl
1207 : igamma_generic_impl<Scalar, SAMPLE_DERIVATIVE> {
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245 };
1246
1247
1248
1249
1250
/** \internal Result type of zeta: the scalar type itself. */
template <class Scalar>
struct zeta_retval { typedef Scalar type; };
1255
1256 template <typename Scalar>
1257 struct zeta_impl_series {
1258 EIGEN_DEVICE_FUNC
1259 static EIGEN_STRONG_INLINE Scalar run(const Scalar) {
1260 EIGEN_STATIC_ASSERT((internal::is_same<Scalar, Scalar>::value == false),
1261 THIS_TYPE_IS_NOT_SUPPORTED);
1262 return Scalar(0);
1263 }
1264 };
1265
1266 template <>
1267 struct zeta_impl_series<float> {
1268 EIGEN_DEVICE_FUNC
1269 static EIGEN_STRONG_INLINE bool run(float& a, float& b, float& s, const float x, const float machep) {
1270 int i = 0;
1271 while(i < 9)
1272 {
1273 i += 1;
1274 a += 1.0f;
1275 b = numext::pow( a, -x );
1276 s += b;
1277 if( numext::abs(b/s) < machep )
1278 return true;
1279 }
1280
1281
1282 return false;
1283 }
1284 };
1285
1286 template <>
1287 struct zeta_impl_series<double> {
1288 EIGEN_DEVICE_FUNC
1289 static EIGEN_STRONG_INLINE bool run(double& a, double& b, double& s, const double x, const double machep) {
1290 int i = 0;
1291 while( (i < 9) || (a <= 9.0) )
1292 {
1293 i += 1;
1294 a += 1.0;
1295 b = numext::pow( a, -x );
1296 s += b;
1297 if( numext::abs(b/s) < machep )
1298 return true;
1299 }
1300
1301
1302 return false;
1303 }
1304 };
1305
1306 template <typename Scalar>
1307 struct zeta_impl {
1308 EIGEN_DEVICE_FUNC
1309 static Scalar run(Scalar x, Scalar q) {
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371 int i;
1372 Scalar p, r, a, b, k, s, t, w;
1373
1374 const Scalar A[] = {
1375 Scalar(12.0),
1376 Scalar(-720.0),
1377 Scalar(30240.0),
1378 Scalar(-1209600.0),
1379 Scalar(47900160.0),
1380 Scalar(-1.8924375803183791606e9),
1381 Scalar(7.47242496e10),
1382 Scalar(-2.950130727918164224e12),
1383 Scalar(1.1646782814350067249e14),
1384 Scalar(-4.5979787224074726105e15),
1385 Scalar(1.8152105401943546773e17),
1386 Scalar(-7.1661652561756670113e18)
1387 };
1388
1389 const Scalar maxnum = NumTraits<Scalar>::infinity();
1390 const Scalar zero = 0.0, half = 0.5, one = 1.0;
1391 const Scalar machep = cephes_helper<Scalar>::machep();
1392 const Scalar nan = NumTraits<Scalar>::quiet_NaN();
1393
1394 if( x == one )
1395 return maxnum;
1396
1397 if( x < one )
1398 {
1399 return nan;
1400 }
1401
1402 if( q <= zero )
1403 {
1404 if(q == numext::floor(q))
1405 {
1406 if (x == numext::floor(x) && long(x) % 2 == 0) {
1407 return maxnum;
1408 }
1409 else {
1410 return nan;
1411 }
1412 }
1413 p = x;
1414 r = numext::floor(p);
1415 if (p != r)
1416 return nan;
1417 }
1418
1419
1420
1421
1422
1423
1424 s = numext::pow( q, -x );
1425 a = q;
1426 b = zero;
1427
1428 if (zeta_impl_series<Scalar>::run(a, b, s, x, machep)) {
1429 return s;
1430 }
1431
1432 w = a;
1433 s += b*w/(x-one);
1434 s -= half * b;
1435 a = one;
1436 k = zero;
1437 for( i=0; i<12; i++ )
1438 {
1439 a *= x + k;
1440 b /= w;
1441 t = a*b/A[i];
1442 s = s + t;
1443 t = numext::abs(t/s);
1444 if( t < machep ) {
1445 break;
1446 }
1447 k += one;
1448 a *= x + k;
1449 b /= w;
1450 k += one;
1451 }
1452 return s;
1453 }
1454 };
1455
1456
1457
1458
1459
/** \internal Result type of polygamma: the scalar type itself. */
template <class Scalar>
struct polygamma_retval { typedef Scalar type; };
1464
1465 #if !EIGEN_HAS_C99_MATH
1466
1467 template <typename Scalar>
1468 struct polygamma_impl {
1469 EIGEN_DEVICE_FUNC
1470 static EIGEN_STRONG_INLINE Scalar run(Scalar n, Scalar x) {
1471 EIGEN_STATIC_ASSERT((internal::is_same<Scalar, Scalar>::value == false),
1472 THIS_TYPE_IS_NOT_SUPPORTED);
1473 return Scalar(0);
1474 }
1475 };
1476
1477 #else
1478
1479 template <typename Scalar>
1480 struct polygamma_impl {
1481 EIGEN_DEVICE_FUNC
1482 static Scalar run(Scalar n, Scalar x) {
1483 Scalar zero = 0.0, one = 1.0;
1484 Scalar nplus = n + one;
1485 const Scalar nan = NumTraits<Scalar>::quiet_NaN();
1486
1487
1488 if (numext::floor(n) != n || n < zero) {
1489 return nan;
1490 }
1491
1492 else if (n == zero) {
1493 return digamma_impl<Scalar>::run(x);
1494 }
1495
1496 else {
1497 Scalar factorial = numext::exp(lgamma_impl<Scalar>::run(nplus));
1498 return numext::pow(-one, nplus) * factorial * zeta_impl<Scalar>::run(nplus, x);
1499 }
1500 }
1501 };
1502
1503 #endif
1504
1505
1506
1507
1508
/** \internal Result type of betainc: the scalar type itself. */
template <class Scalar>
struct betainc_retval { typedef Scalar type; };
1513
1514 #if !EIGEN_HAS_C99_MATH
1515
1516 template <typename Scalar>
1517 struct betainc_impl {
1518 EIGEN_DEVICE_FUNC
1519 static EIGEN_STRONG_INLINE Scalar run(Scalar a, Scalar b, Scalar x) {
1520 EIGEN_STATIC_ASSERT((internal::is_same<Scalar, Scalar>::value == false),
1521 THIS_TYPE_IS_NOT_SUPPORTED);
1522 return Scalar(0);
1523 }
1524 };
1525
1526 #else
1527
1528 template <typename Scalar>
1529 struct betainc_impl {
1530 EIGEN_DEVICE_FUNC
1531 static EIGEN_STRONG_INLINE Scalar run(Scalar, Scalar, Scalar) {
1532
1533
1534
1535
1536
1537
1538
1539
1540
1541
1542
1543
1544
1545
1546
1547
1548
1549
1550
1551
1552
1553
1554
1555
1556
1557
1558
1559
1560
1561
1562
1563
1564
1565
1566
1567
1568
1569
1570
1571
1572
1573
1574
1575
1576
1577
1578
1579
1580
1581
1582
1583
1584
1585
1586
1587
1588
1589
1590
1591
1592
1593
1594
1595
1596
1597
1598
1599
1600
1601 EIGEN_STATIC_ASSERT((internal::is_same<Scalar, Scalar>::value == false),
1602 THIS_TYPE_IS_NOT_SUPPORTED);
1603 return Scalar(0);
1604 }
1605 };
1606
1607
1608
1609
/** \internal Continued-fraction expansions of the incomplete beta integral
 * (the two Cephes expansions incbcf / incbd).
 *
 * Branches:
 *  - small_branch == true: first expansion, evaluated directly in x.
 *  - small_branch == false: second expansion, evaluated in x / (1 - x).
 * Both share one recurrence. They differ only in the initial values of
 * k1..k8 and in the sign of k26update, which steps k2 and k6 in opposite
 * directions.
 *
 * Each iteration:
 *  - applies two recurrence steps: -(x k1 k2) / (k3 k4), then
 *    (x k5 k6) / (k7 k8);
 *  - returns as soon as the convergent pk / qk changes by less than thresh
 *    relative to itself (thresh = machep for float, 3 * machep for double);
 *  - rescales numerators and denominators by biginv / big when they grow
 *    too large or too small.
 * After 100 (float) or 300 (double) iterations the last convergent is
 * returned.
 *
 * Only float and double are accepted (static assertion).
 */
template <typename Scalar>
struct incbeta_cfe {
  EIGEN_DEVICE_FUNC
  static EIGEN_STRONG_INLINE Scalar run(Scalar a, Scalar b, Scalar x, bool small_branch) {
    EIGEN_STATIC_ASSERT((internal::is_same<Scalar, float>::value ||
                         internal::is_same<Scalar, double>::value),
                        THIS_TYPE_IS_NOT_SUPPORTED);
    const Scalar big = cephes_helper<Scalar>::big();
    const Scalar machep = cephes_helper<Scalar>::machep();
    const Scalar biginv = cephes_helper<Scalar>::biginv();

    const Scalar zero = 0;
    const Scalar one = 1;
    const Scalar two = 2;

    Scalar xk, pk, pkm1, pkm2, qk, qkm1, qkm2;
    Scalar k1, k2, k3, k4, k5, k6, k7, k8, k26update;
    Scalar ans;
    int n;

    // Precision-dependent iteration cap and convergence threshold.
    const int num_iters = (internal::is_same<Scalar, float>::value) ? 100 : 300;
    const Scalar thresh =
        (internal::is_same<Scalar, float>::value) ? machep : Scalar(3) * machep;
    // NOTE(review): this initial value of r is never read; r is always
    // assigned pk / qk before use.
    Scalar r = (internal::is_same<Scalar, float>::value) ? zero : one;

    if (small_branch) {
      // First expansion (incbcf).
      k1 = a;
      k2 = a + b;
      k3 = a;
      k4 = a + one;
      k5 = one;
      k6 = b - one;
      k7 = k4;
      k8 = a + two;
      k26update = one;
    } else {
      // Second expansion (incbd), expressed in terms of x / (1 - x).
      k1 = a;
      k2 = b - one;
      k3 = a;
      k4 = a + one;
      k5 = one;
      k6 = a + b;
      k7 = a + one;
      k8 = a + two;
      k26update = -one;
      x = x / (one - x);
    }

    pkm2 = zero;
    qkm2 = one;
    pkm1 = one;
    qkm1 = one;
    ans = one;
    n = 0;

    do {
      // Odd step of the recurrence.
      xk = -(x * k1 * k2) / (k3 * k4);
      pk = pkm1 + pkm2 * xk;
      qk = qkm1 + qkm2 * xk;
      pkm2 = pkm1;
      pkm1 = pk;
      qkm2 = qkm1;
      qkm1 = qk;

      // Even step of the recurrence.
      xk = (x * k5 * k6) / (k7 * k8);
      pk = pkm1 + pkm2 * xk;
      qk = qkm1 + qkm2 * xk;
      pkm2 = pkm1;
      pkm1 = pk;
      qkm2 = qkm1;
      qkm1 = qk;

      // Relative-change convergence test on the current convergent.
      if (qk != zero) {
        r = pk / qk;
        if (numext::abs(ans - r) < numext::abs(r) * thresh) {
          return r;
        }
        ans = r;
      }

      k1 += one;
      k2 += k26update;
      k3 += two;
      k4 += two;
      k5 += one;
      k6 -= k26update;
      k7 += two;
      k8 += two;

      // Rescale to keep pk and qk in range. Numerators and denominators
      // share one factor, so the convergent is unchanged.
      if ((numext::abs(qk) + numext::abs(pk)) > big) {
        pkm2 *= biginv;
        pkm1 *= biginv;
        qkm2 *= biginv;
        qkm1 *= biginv;
      }
      if ((numext::abs(qk) < biginv) || (numext::abs(pk) < biginv)) {
        pkm2 *= big;
        pkm1 *= big;
        qkm2 *= big;
        qkm1 *= big;
      }
    } while (++n < num_iters);

    // Iteration cap reached: return the last convergent.
    return ans;
  }
};
1716
1717
// Primary template, deliberately empty. The incomplete-beta helper routines
// are provided only by the explicit float and double specializations.
template <class Scalar>
struct betainc_helper {
};
1720
1721 template <>
1722 struct betainc_helper<float> {
1723
1724 EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE float incbsa(float aa, float bb,
1725 float xx) {
1726 float ans, a, b, t, x, onemx;
1727 bool reversed_a_b = false;
1728
1729 onemx = 1.0f - xx;
1730
1731
1732 if (xx > (aa / (aa + bb))) {
1733 reversed_a_b = true;
1734 a = bb;
1735 b = aa;
1736 t = xx;
1737 x = onemx;
1738 } else {
1739 a = aa;
1740 b = bb;
1741 t = onemx;
1742 x = xx;
1743 }
1744
1745
1746 if (b > 10.0f) {
1747 if (numext::abs(b * x / a) < 0.3f) {
1748 t = betainc_helper<float>::incbps(a, b, x);
1749 if (reversed_a_b) t = 1.0f - t;
1750 return t;
1751 }
1752 }
1753
1754 ans = x * (a + b - 2.0f) / (a - 1.0f);
1755 if (ans < 1.0f) {
1756 ans = incbeta_cfe<float>::run(a, b, x, true );
1757 t = b * numext::log(t);
1758 } else {
1759 ans = incbeta_cfe<float>::run(a, b, x, false );
1760 t = (b - 1.0f) * numext::log(t);
1761 }
1762
1763 t += a * numext::log(x) + lgamma_impl<float>::run(a + b) -
1764 lgamma_impl<float>::run(a) - lgamma_impl<float>::run(b);
1765 t += numext::log(ans / a);
1766 t = numext::exp(t);
1767
1768 if (reversed_a_b) t = 1.0f - t;
1769 return t;
1770 }
1771
1772 EIGEN_DEVICE_FUNC
1773 static EIGEN_STRONG_INLINE float incbps(float a, float b, float x) {
1774 float t, u, y, s;
1775 const float machep = cephes_helper<float>::machep();
1776
1777 y = a * numext::log(x) + (b - 1.0f) * numext::log1p(-x) - numext::log(a);
1778 y -= lgamma_impl<float>::run(a) + lgamma_impl<float>::run(b);
1779 y += lgamma_impl<float>::run(a + b);
1780
1781 t = x / (1.0f - x);
1782 s = 0.0f;
1783 u = 1.0f;
1784 do {
1785 b -= 1.0f;
1786 if (b == 0.0f) {
1787 break;
1788 }
1789 a += 1.0f;
1790 u *= t * b / a;
1791 s += u;
1792 } while (numext::abs(u) > machep);
1793
1794 return numext::exp(y) * (1.0f + s);
1795 }
1796 };
1797
1798 template <>
1799 struct betainc_impl<float> {
1800 EIGEN_DEVICE_FUNC
1801 static float run(float a, float b, float x) {
1802 const float nan = NumTraits<float>::quiet_NaN();
1803 float ans, t;
1804
1805 if (a <= 0.0f) return nan;
1806 if (b <= 0.0f) return nan;
1807 if ((x <= 0.0f) || (x >= 1.0f)) {
1808 if (x == 0.0f) return 0.0f;
1809 if (x == 1.0f) return 1.0f;
1810
1811 return nan;
1812 }
1813
1814
1815 if (a <= 1.0f) {
1816 ans = betainc_helper<float>::incbsa(a + 1.0f, b, x);
1817 t = a * numext::log(x) + b * numext::log1p(-x) +
1818 lgamma_impl<float>::run(a + b) - lgamma_impl<float>::run(a + 1.0f) -
1819 lgamma_impl<float>::run(b);
1820 return (ans + numext::exp(t));
1821 } else {
1822 return betainc_helper<float>::incbsa(a, b, x);
1823 }
1824 }
1825 };
1826
1827 template <>
1828 struct betainc_helper<double> {
1829 EIGEN_DEVICE_FUNC
1830 static EIGEN_STRONG_INLINE double incbps(double a, double b, double x) {
1831 const double machep = cephes_helper<double>::machep();
1832
1833 double s, t, u, v, n, t1, z, ai;
1834
1835 ai = 1.0 / a;
1836 u = (1.0 - b) * x;
1837 v = u / (a + 1.0);
1838 t1 = v;
1839 t = u;
1840 n = 2.0;
1841 s = 0.0;
1842 z = machep * ai;
1843 while (numext::abs(v) > z) {
1844 u = (n - b) * x / n;
1845 t *= u;
1846 v = t / (a + n);
1847 s += v;
1848 n += 1.0;
1849 }
1850 s += t1;
1851 s += ai;
1852
1853 u = a * numext::log(x);
1854
1855
1856
1857
1858
1859
1860
1861 t = lgamma_impl<double>::run(a + b) - lgamma_impl<double>::run(a) -
1862 lgamma_impl<double>::run(b) + u + numext::log(s);
1863 return s = numext::exp(t);
1864 }
1865 };
1866
1867 template <>
1868 struct betainc_impl<double> {
1869 EIGEN_DEVICE_FUNC
1870 static double run(double aa, double bb, double xx) {
1871 const double nan = NumTraits<double>::quiet_NaN();
1872 const double machep = cephes_helper<double>::machep();
1873
1874
1875 double a, b, t, x, xc, w, y;
1876 bool reversed_a_b = false;
1877
1878 if (aa <= 0.0 || bb <= 0.0) {
1879 return nan;
1880 }
1881
1882 if ((xx <= 0.0) || (xx >= 1.0)) {
1883 if (xx == 0.0) return (0.0);
1884 if (xx == 1.0) return (1.0);
1885
1886 return nan;
1887 }
1888
1889 if ((bb * xx) <= 1.0 && xx <= 0.95) {
1890 return betainc_helper<double>::incbps(aa, bb, xx);
1891 }
1892
1893 w = 1.0 - xx;
1894
1895
1896 if (xx > (aa / (aa + bb))) {
1897 reversed_a_b = true;
1898 a = bb;
1899 b = aa;
1900 xc = xx;
1901 x = w;
1902 } else {
1903 a = aa;
1904 b = bb;
1905 xc = w;
1906 x = xx;
1907 }
1908
1909 if (reversed_a_b && (b * x) <= 1.0 && x <= 0.95) {
1910 t = betainc_helper<double>::incbps(a, b, x);
1911 if (t <= machep) {
1912 t = 1.0 - machep;
1913 } else {
1914 t = 1.0 - t;
1915 }
1916 return t;
1917 }
1918
1919
1920 y = x * (a + b - 2.0) - (a - 1.0);
1921 if (y < 0.0) {
1922 w = incbeta_cfe<double>::run(a, b, x, true );
1923 } else {
1924 w = incbeta_cfe<double>::run(a, b, x, false ) / xc;
1925 }
1926
1927
1928
1929
1930
1931 y = a * numext::log(x);
1932 t = b * numext::log(xc);
1933
1934
1935
1936
1937
1938
1939
1940
1941
1942
1943
1944
1945 y += t + lgamma_impl<double>::run(a + b) - lgamma_impl<double>::run(a) -
1946 lgamma_impl<double>::run(b);
1947 y += numext::log(w / a);
1948 t = numext::exp(y);
1949
1950
1951
1952
1953 if (reversed_a_b) {
1954 if (t <= machep) {
1955 t = 1.0 - machep;
1956 } else {
1957 t = 1.0 - t;
1958 }
1959 }
1960 return t;
1961 }
1962 };
1963
1964 #endif
1965
1966 }
1967
1968 namespace numext {
1969
1970 template <typename Scalar>
1971 EIGEN_DEVICE_FUNC inline EIGEN_MATHFUNC_RETVAL(lgamma, Scalar)
1972 lgamma(const Scalar& x) {
1973 return EIGEN_MATHFUNC_IMPL(lgamma, Scalar)::run(x);
1974 }
1975
1976 template <typename Scalar>
1977 EIGEN_DEVICE_FUNC inline EIGEN_MATHFUNC_RETVAL(digamma, Scalar)
1978 digamma(const Scalar& x) {
1979 return EIGEN_MATHFUNC_IMPL(digamma, Scalar)::run(x);
1980 }
1981
1982 template <typename Scalar>
1983 EIGEN_DEVICE_FUNC inline EIGEN_MATHFUNC_RETVAL(zeta, Scalar)
1984 zeta(const Scalar& x, const Scalar& q) {
1985 return EIGEN_MATHFUNC_IMPL(zeta, Scalar)::run(x, q);
1986 }
1987
1988 template <typename Scalar>
1989 EIGEN_DEVICE_FUNC inline EIGEN_MATHFUNC_RETVAL(polygamma, Scalar)
1990 polygamma(const Scalar& n, const Scalar& x) {
1991 return EIGEN_MATHFUNC_IMPL(polygamma, Scalar)::run(n, x);
1992 }
1993
1994 template <typename Scalar>
1995 EIGEN_DEVICE_FUNC inline EIGEN_MATHFUNC_RETVAL(erf, Scalar)
1996 erf(const Scalar& x) {
1997 return EIGEN_MATHFUNC_IMPL(erf, Scalar)::run(x);
1998 }
1999
2000 template <typename Scalar>
2001 EIGEN_DEVICE_FUNC inline EIGEN_MATHFUNC_RETVAL(erfc, Scalar)
2002 erfc(const Scalar& x) {
2003 return EIGEN_MATHFUNC_IMPL(erfc, Scalar)::run(x);
2004 }
2005
2006 template <typename Scalar>
2007 EIGEN_DEVICE_FUNC inline EIGEN_MATHFUNC_RETVAL(ndtri, Scalar)
2008 ndtri(const Scalar& x) {
2009 return EIGEN_MATHFUNC_IMPL(ndtri, Scalar)::run(x);
2010 }
2011
2012 template <typename Scalar>
2013 EIGEN_DEVICE_FUNC inline EIGEN_MATHFUNC_RETVAL(igamma, Scalar)
2014 igamma(const Scalar& a, const Scalar& x) {
2015 return EIGEN_MATHFUNC_IMPL(igamma, Scalar)::run(a, x);
2016 }
2017
2018 template <typename Scalar>
2019 EIGEN_DEVICE_FUNC inline EIGEN_MATHFUNC_RETVAL(igamma_der_a, Scalar)
2020 igamma_der_a(const Scalar& a, const Scalar& x) {
2021 return EIGEN_MATHFUNC_IMPL(igamma_der_a, Scalar)::run(a, x);
2022 }
2023
2024 template <typename Scalar>
2025 EIGEN_DEVICE_FUNC inline EIGEN_MATHFUNC_RETVAL(gamma_sample_der_alpha, Scalar)
2026 gamma_sample_der_alpha(const Scalar& a, const Scalar& x) {
2027 return EIGEN_MATHFUNC_IMPL(gamma_sample_der_alpha, Scalar)::run(a, x);
2028 }
2029
2030 template <typename Scalar>
2031 EIGEN_DEVICE_FUNC inline EIGEN_MATHFUNC_RETVAL(igammac, Scalar)
2032 igammac(const Scalar& a, const Scalar& x) {
2033 return EIGEN_MATHFUNC_IMPL(igammac, Scalar)::run(a, x);
2034 }
2035
2036 template <typename Scalar>
2037 EIGEN_DEVICE_FUNC inline EIGEN_MATHFUNC_RETVAL(betainc, Scalar)
2038 betainc(const Scalar& a, const Scalar& b, const Scalar& x) {
2039 return EIGEN_MATHFUNC_IMPL(betainc, Scalar)::run(a, b, x);
2040 }
2041
2042 }
2043 }
2044
2045 #endif