File indexing completed on 2025-01-18 09:43:03
0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013 #ifndef BOOST_UBLAS_TENSOR_MULTIPLICATION
0014 #define BOOST_UBLAS_TENSOR_MULTIPLICATION
0015
0016 #include <cassert>
0017
0018 namespace boost {
0019 namespace numeric {
0020 namespace ublas {
0021 namespace detail {
0022 namespace recursive {
0023
0024
0025
0026
0027
0028
0029
0030
0031
0032
0033
0034
0035
0036
0037
0038
0039
0040
0041
0042
0043
0044
0045
0046
0047
0048
0049
0050
0051
/** @brief Recursive kernel of a tensor-times-tensor contraction with mode permutations.
 *
 * Each recursion level k handles one loop of the contraction:
 *   - k in [0, r)       : a free mode of A that also indexes C,
 *   - k in [r, r+s)     : a free mode of B that also indexes C,
 *   - k in [r+s, r+s+q) : a mode shared by A and B that is summed over;
 *                         the last such level multiplies and accumulates into C.
 *
 * @param k    current recursion level
 * @param r    number of free modes of A
 * @param s    number of free modes of B
 * @param q    number of contracted modes
 * @param phia one-based mode permutation of A: entries [0, r) are the free modes, entries [r, r+q) the contracted ones
 * @param phib one-based mode permutation of B: entries [0, s) are the free modes, entries [s, s+q) the contracted ones
 * @param c    cursor into the result; products are added with +=
 * @param nc   extents of C (loop bounds of the free levels)
 * @param wc   strides of C (cursor increments)
 * @param a    cursor into the left operand
 * @param na   extents of A
 * @param wa   strides of A
 * @param b    cursor into the right operand
 * @param nb   extents of B
 * @param wb   strides of B
 */
template <class PointerOut, class PointerIn1, class PointerIn2, class SizeType>
void ttt(SizeType const k,
         SizeType const r, SizeType const s, SizeType const q,
         SizeType const*const phia, SizeType const*const phib,
         PointerOut c, SizeType const*const nc, SizeType const*const wc,
         PointerIn1 a, SizeType const*const na, SizeType const*const wa,
         PointerIn2 b, SizeType const*const nb, SizeType const*const wb)
{
    // Level belongs to a free mode of A: advance A and C together.
    if(k < r)
    {
        const auto mode_a = phia[k]-1;
        assert(nc[k] == na[mode_a]);
        size_t step = 0u;
        while(step < nc[k])
        {
            ttt(k+1, r, s, q, phia, phib, c, nc, wc, a, na, wa, b, nb, wb);
            a += wa[mode_a];
            c += wc[k];
            ++step;
        }
        return;
    }

    // Level belongs to a free mode of B: advance B and C together.
    if(k < r+s)
    {
        const auto mode_b = phib[k-r]-1;
        assert(nc[k] == nb[mode_b]);
        size_t step = 0u;
        while(step < nc[k])
        {
            ttt(k+1, r, s, q, phia, phib, c, nc, wc, a, na, wa, b, nb, wb);
            b += wb[mode_b];
            c += wc[k];
            ++step;
        }
        return;
    }

    // Level belongs to a contracted mode: advance A and B, C stays put.
    const auto mode_a = phia[k-s]-1;
    const auto mode_b = phib[k-r]-1;
    assert(na[mode_a] == nb[mode_b]);

    // Only the last contracted level performs the multiply-accumulate.
    const bool descend = (k < r+s+q-1);
    size_t step = 0u;
    while(step < na[mode_a])
    {
        if(descend)
            ttt(k+1, r, s, q, phia, phib, c, nc, wc, a, na, wa, b, nb, wb);
        else
            *c += *a * *b;
        a += wa[mode_a];
        b += wb[mode_b];
        ++step;
    }
}
0085
0086
0087
0088
0089
0090
0091
0092
0093
0094
0095
0096
0097
0098
0099
0100
0101
0102
0103
0104
0105
0106
0107
0108
0109
0110
0111
0112
0113
0114
0115
/** @brief Recursive kernel of a tensor-times-tensor contraction without mode permutations.
 *
 * The first r modes of A and the first s modes of B are free and index C in that order;
 * the last q modes of A are summed against the last q modes of B.
 * Each recursion level k handles one loop:
 *   - k in [0, r)       : free mode k of A,
 *   - k in [r, r+s)     : free mode k-r of B,
 *   - k in [r+s, r+s+q) : contracted mode k-s of A paired with mode k-r of B;
 *                         the last such level multiplies and accumulates into C.
 *
 * @param k  current recursion level
 * @param r  number of free modes of A
 * @param s  number of free modes of B
 * @param q  number of contracted modes
 * @param c  cursor into the result; products are added with +=
 * @param nc extents of C (loop bounds of the free levels)
 * @param wc strides of C (cursor increments)
 * @param a  cursor into the left operand
 * @param na extents of A
 * @param wa strides of A
 * @param b  cursor into the right operand
 * @param nb extents of B
 * @param wb strides of B
 */
template <class PointerOut, class PointerIn1, class PointerIn2, class SizeType>
void ttt(SizeType const k,
         SizeType const r, SizeType const s, SizeType const q,
         PointerOut c, SizeType const*const nc, SizeType const*const wc,
         PointerIn1 a, SizeType const*const na, SizeType const*const wa,
         PointerIn2 b, SizeType const*const nb, SizeType const*const wb)
{
    // Free mode of A: advance A and C together.
    if(k < r)
    {
        assert(nc[k] == na[k]);
        size_t step = 0u;
        while(step < nc[k])
        {
            ttt(k+1, r, s, q, c, nc, wc, a, na, wa, b, nb, wb);
            a += wa[k];
            c += wc[k];
            ++step;
        }
        return;
    }

    // Free mode of B: advance B and C together.
    if(k < r+s)
    {
        assert(nc[k] == nb[k-r]);
        size_t step = 0u;
        while(step < nc[k])
        {
            ttt(k+1, r, s, q, c, nc, wc, a, na, wa, b, nb, wb);
            b += wb[k-r];
            c += wc[k];
            ++step;
        }
        return;
    }

    // Contracted mode: advance A and B, C stays put.
    assert(na[k-s] == nb[k-r]);

    // Only the last contracted level performs the multiply-accumulate.
    const bool descend = (k < r+s+q-1);
    size_t step = 0u;
    while(step < na[k-s])
    {
        if(descend)
            ttt(k+1, r, s, q, c, nc, wc, a, na, wa, b, nb, wb);
        else
            *c += *a * *b;
        a += wa[k-s];
        b += wb[k-r];
        ++step;
    }
}
0148
0149
0150
0151
0152
0153
0154
0155
0156
0157
0158
0159
0160
0161
0162
0163
0164
0165
0166
0167
0168
/** @brief Recursive kernel of the tensor-times-matrix product for a contraction mode other than the first.
 *
 * Accumulates C(i_0, .., j, .., i_r) += sum_l A(i_0, .., l, .., i_r) * B(j, l),
 * where j and l sit at the zero-based mode m. The recursion walks the modes from
 * r down to 0, skips mode m, and runs the multiply-accumulate kernel at r == 0.
 * The public wrapper only dispatches here for m >= 1 (zero-based).
 *
 * All loop counters use SizeType so that they have the same width as the extents
 * they are compared with (a fixed `unsigned long`/`unsigned` literal type may be
 * narrower than SizeType on some data models).
 *
 * @param m  zero-based contraction mode
 * @param r  zero-based mode handled by this recursion level
 * @param c  cursor into the result; products are added with +=
 * @param nc extents of C
 * @param wc strides of C
 * @param a  cursor into the input tensor
 * @param na extents of A
 * @param wa strides of A
 * @param b  cursor into the matrix
 * @param nb extents of B; nb[1] bounds the summation loop
 * @param wb strides of B
 */
template <class PointerOut, class PointerIn1, class PointerIn2, class SizeType>
void ttm(SizeType const m, SizeType const r,
         PointerOut c, SizeType const*const nc, SizeType const*const wc,
         PointerIn1 a, SizeType const*const na, SizeType const*const wa,
         PointerIn2 b, SizeType const*const nb, SizeType const*const wb)
{
    if(r == m) {
        // The contraction mode is iterated inside the kernel, not here.
        ttm(m, r-1, c, nc, wc, a, na, wa, b, nb, wb);
    }
    else if(r == 0) {
        for(SizeType i0 = 0; i0 < nc[0]; c += wc[0], a += wa[0], ++i0) {
            auto cm = c;
            auto b0 = b;
            // rows of B <-> mode m of C
            for(SizeType j = 0; j < nc[m]; cm += wc[m], b0 += wb[0], ++j) {
                auto am = a;
                auto b1 = b0;
                // columns of B <-> mode m of A
                for(SizeType l = 0; l < nb[1]; am += wa[m], b1 += wb[1], ++l)
                    *cm += *am * *b1;
            }
        }
    }
    else {
        for(SizeType i = 0; i < na[r]; c += wc[r], a += wa[r], ++i)
            ttm(m, r-1, c, nc, wc, a, na, wa, b, nb, wb);
    }
}
0197
0198
0199
0200
0201
0202
0203
0204
0205
0206
0207
0208
0209
0210
0211
0212
0213
0214
0215
/** @brief Recursive kernel of the tensor-times-matrix product contracting the first mode.
 *
 * Accumulates C(j, i_1, .., i_r) += sum_l A(l, i_1, .., i_r) * B(j, l).
 * Modes r, r-1, .., 2 are handled by recursion; modes 1 and 0 are handled by the
 * multiply-accumulate kernel.
 *
 * All loop counters use SizeType so that they have the same width as the extents
 * they are compared with; the innermost counter also no longer shadows the
 * counter of the enclosing loop.
 *
 * @param r  zero-based mode handled by this recursion level
 * @param c  cursor into the result; products are added with +=
 * @param nc extents of C
 * @param wc strides of C
 * @param a  cursor into the input tensor
 * @param na extents of A
 * @param wa strides of A
 * @param b  cursor into the matrix
 * @param nb extents of B; nb[1] bounds the summation loop
 * @param wb strides of B
 */
template <class PointerOut, class PointerIn1, class PointerIn2, class SizeType>
void ttm0(SizeType const r,
          PointerOut c, SizeType const*const nc, SizeType const*const wc,
          PointerIn1 a, SizeType const*const na, SizeType const*const wa,
          PointerIn2 b, SizeType const*const nb, SizeType const*const wb)
{
    if(r > 1) {
        for(SizeType i = 0; i < na[r]; c += wc[r], a += wa[r], ++i)
            ttm0(r-1, c, nc, wc, a, na, wa, b, nb, wb);
    }
    else {
        for(SizeType i1 = 0; i1 < nc[1]; c += wc[1], a += wa[1], ++i1) {
            auto cm = c;
            auto b0 = b;
            // rows of B <-> mode 0 of C
            for(SizeType j = 0; j < nc[0]; cm += wc[0], b0 += wb[0], ++j) {
                auto am = a;
                auto b1 = b0;
                // columns of B <-> mode 0 of A
                for(SizeType l = 0; l < nb[1]; am += wa[0], b1 += wb[1], ++l)
                    *cm += *am * *b1;
            }
        }
    }
}
0244
0245
0246
0247
0248
0249
0250
0251
0252
0253
0254
0255
0256
0257
0258
0259
0260
0261
0262
0263
0264
0265
0266
0267
0268
0269
/** @brief Recursive kernel of the tensor-times-vector product for a contraction mode other than the first.
 *
 * Accumulates C(i_0, .., i_{m-1}, i_{m+1}, .., i_r) += sum_l A(i_0, .., l, .., i_r) * b(l),
 * with l at the zero-based mode m. The recursion walks the modes of A from r down
 * to 0, skips mode m, and runs the multiply-accumulate kernel at r == 0.
 * The vector b is read with unit stride.
 *
 * Loop counters use SizeType so that they have the same width as the extents
 * they are compared with.
 *
 * @param m  zero-based contraction mode of A
 * @param r  zero-based mode of A handled by this recursion level
 * @param q  zero-based mode of C that corresponds to mode r of A
 * @param c  cursor into the result; products are added with +=
 * @param nc extents of C (not read by this kernel)
 * @param wc strides of C
 * @param a  cursor into the input tensor
 * @param na extents of A
 * @param wa strides of A
 * @param b  pointer to the first vector element
 */
template <class PointerOut, class PointerIn1, class PointerIn2, class SizeType>
void ttv(SizeType const m, SizeType const r, SizeType const q,
         PointerOut c, SizeType const*const nc, SizeType const*const wc,
         PointerIn1 a, SizeType const*const na, SizeType const*const wa,
         PointerIn2 b)
{
    if(r == m) {
        // The contraction mode has no counterpart in C, so q is not decremented.
        ttv(m, r-1, q, c, nc, wc, a, na, wa, b);
    }
    else if(r == 0) {
        for(SizeType i0 = 0; i0 < na[0]; c += wc[0], a += wa[0], ++i0) {
            auto a1 = a;
            auto b1 = b;
            for(SizeType im = 0; im < na[m]; a1 += wa[m], ++b1, ++im)
                *c += *a1 * *b1;
        }
    }
    else {
        for(SizeType i = 0; i < na[r]; c += wc[q], a += wa[r], ++i)
            ttv(m, r-1, q-1, c, nc, wc, a, na, wa, b);
    }
}
0292
0293
0294
0295
0296
0297
0298
0299
0300
0301
0302
0303
0304
0305
0306
0307
0308
0309
/** @brief Recursive kernel of the tensor-times-vector product contracting the first mode.
 *
 * Accumulates C(i_1, .., i_r) += sum_l A(l, i_1, .., i_r) * b(l).
 * Mode r of A corresponds to mode r-1 of C. Modes r, .., 2 are handled by
 * recursion; modes 1 and 0 by the multiply-accumulate kernel.
 * The vector b is read with unit stride.
 *
 * Loop counters use SizeType so that they have the same width as the extents
 * they are compared with.
 *
 * @param r  zero-based mode of A handled by this recursion level
 * @param c  cursor into the result; products are added with +=
 * @param nc extents of C (not read by this kernel)
 * @param wc strides of C
 * @param a  cursor into the input tensor
 * @param na extents of A
 * @param wa strides of A
 * @param b  pointer to the first vector element
 */
template <class PointerOut, class PointerIn1, class PointerIn2, class SizeType>
void ttv0(SizeType const r,
          PointerOut c, SizeType const*const nc, SizeType const*const wc,
          PointerIn1 a, SizeType const*const na, SizeType const*const wa,
          PointerIn2 b)
{
    if(r > 1) {
        for(SizeType i = 0; i < na[r]; c += wc[r-1], a += wa[r], ++i)
            ttv0(r-1, c, nc, wc, a, na, wa, b);
    }
    else {
        for(SizeType i1 = 0; i1 < na[1]; c += wc[0], a += wa[1], ++i1) {
            auto a1 = a;
            auto b1 = b;
            for(SizeType i0 = 0; i0 < na[0]; a1 += wa[0], ++b1, ++i0)
                *c += *a1 * *b1;
        }
    }
}
0330
0331
0332
0333
0334
0335
0336
0337
0338
0339
0340
0341
0342
0343
0344
0345
0346
/** @brief Matrix-times-vector kernel (vector-times-matrix when m == 0).
 *
 * With o denoting the mode of A that is not contracted (o = 1 if m == 0, else 0),
 * accumulates C(i_o) += sum_l A(.., l at mode m, ..) * b(l).
 * The result cursor advances by wc[o] per output element; the vector b is read
 * with unit stride.
 *
 * Loop counters and the mode index o use SizeType so that they have the same
 * width as the extents and indices they are used with.
 *
 * @param m  zero-based contraction mode of A (0 or 1)
 * @param c  cursor into the result; products are added with +=
 * @param wc strides of C
 * @param a  cursor into the matrix
 * @param na extents of A
 * @param wa strides of A
 * @param b  pointer to the first vector element
 */
template <class PointerOut, class PointerIn1, class PointerIn2, class SizeType>
void mtv(SizeType const m,
         PointerOut c, SizeType const*const   , SizeType const*const wc,
         PointerIn1 a, SizeType const*const na, SizeType const*const wa,
         PointerIn2 b)
{
    // The mode that survives in the result is the one that is not contracted.
    const SizeType o = (m == 0) ? 1 : 0;

    for(SizeType io = 0; io < na[o]; c += wc[o], a += wa[o], ++io) {
        auto a1 = a;
        auto b1 = b;
        for(SizeType im = 0; im < na[m]; a1 += wa[m], ++b1, ++im)
            *c += *a1 * *b1;
    }
}
0362
0363
0364
0365
0366
0367
0368
0369
0370
0371
0372
0373
0374
0375
0376
0377
0378
0379
/** @brief Matrix-times-matrix kernel: C(i,j) += sum_k A(i,k) * B(k,j).
 *
 * The loops are ordered j (outer), k, i (inner). Arbitrary strides are supported
 * for all three operands.
 *
 * Loop counters use SizeType so that they have the same width as the extents
 * they are compared with.
 *
 * @param c  pointer to the first element of the result; products are added with +=
 * @param nc extents of C; asserted to be {na[0], nb[1]}
 * @param wc strides of C
 * @param a  pointer to the first element of the left matrix
 * @param na extents of A
 * @param wa strides of A
 * @param b  pointer to the first element of the right matrix
 * @param nb extents of B; nb[0] is asserted to equal na[1]
 * @param wb strides of B
 */
template <class PointerOut, class PointerIn1, class PointerIn2, class SizeType>
void mtm(PointerOut c, SizeType const*const nc, SizeType const*const wc,
         PointerIn1 a, SizeType const*const na, SizeType const*const wa,
         PointerIn2 b, SizeType const*const nb, SizeType const*const wb)
{
    assert(nc[0] == na[0]);
    assert(nc[1] == nb[1]);
    assert(na[1] == nb[0]);

    auto cj = c;
    auto bj = b;
    for(SizeType j = 0; j < nc[1]; cj += wc[1], bj += wb[1], ++j) {

        auto bk = bj;
        auto ak = a;
        for(SizeType k = 0; k < na[1]; ak += wa[1], bk += wb[0], ++k) {

            auto ci = cj;
            auto ai = ak;
            for(SizeType i = 0; i < na[0]; ai += wa[0], ci += wc[0], ++i)
                *ci += *ai * *bk;
        }
    }
}
0407
0408
0409
0410
0411
0412
0413
0414
0415
0416
0417
0418
0419
0420
0421
0422
0423
0424
/** @brief Recursive inner product of two tensors with identical extents.
 *
 * Returns v + sum over all index tuples of A(i_0, .., i_r) * B(i_0, .., i_r).
 * Both operands may have their own strides.
 *
 * Loop counters use SizeType so that they have the same width as the extents
 * they are compared with.
 *
 * @param r  zero-based mode handled by this recursion level
 * @param n  extents shared by A and B
 * @param a  cursor into the left operand
 * @param wa strides of A
 * @param b  cursor into the right operand
 * @param wb strides of B
 * @param v  running sum carried through the recursion
 * @return   the running sum after all elements covered by this level were added
 */
template <class PointerIn1, class PointerIn2, class value_t, class SizeType>
value_t inner(SizeType const r, SizeType const*const n,
              PointerIn1 a, SizeType const*const wa,
              PointerIn2 b, SizeType const*const wb,
              value_t v)
{
    if(r == 0) {
        for(SizeType i0 = 0; i0 < n[0]; a += wa[0], b += wb[0], ++i0)
            v += *a * *b;
    }
    else {
        for(SizeType ir = 0; ir < n[r]; a += wa[r], b += wb[r], ++ir)
            v = inner(r-1, n, a, wa, b, wb, v);
    }
    return v;
}
0439
0440
/** @brief Outer-product kernel for the two lowest modes of both operands.
 *
 * Assigns C(i0, i1, .., j0, j1) = A(i0, i1) * B(j0, j1), where i0/i1 use modes 0/1
 * of C and j0/j1 use modes pa/pa+1 of C. The result is written with '=', not
 * accumulated.
 *
 * Loop counters use SizeType so that they have the same width as the extents
 * they are compared with; the cached element of B no longer shadows the
 * parameter b.
 *
 * @param pa zero-based mode of C at which the modes of B start
 * @param c  cursor into the result
 * @param wc strides of C
 * @param a  cursor into the left operand
 * @param na extents of A
 * @param wa strides of A
 * @param b  cursor into the right operand
 * @param nb extents of B
 * @param wb strides of B
 */
template <class PointerOut, class PointerIn1, class PointerIn2, class SizeType>
void outer_2x2(SizeType const pa,
               PointerOut c, SizeType const*const   , SizeType const*const wc,
               PointerIn1 a, SizeType const*const na, SizeType const*const wa,
               PointerIn2 b, SizeType const*const nb, SizeType const*const wb)
{
    for(SizeType ib1 = 0; ib1 < nb[1]; b += wb[1], c += wc[pa+1], ++ib1) {
        auto c2 = c;
        auto b0 = b;
        for(SizeType ib0 = 0; ib0 < nb[0]; b0 += wb[0], c2 += wc[pa], ++ib0) {
            // The element of B is constant for the whole A sweep below.
            const auto b_elem = *b0;
            auto c1 = c2;
            auto a1 = a;
            for(SizeType ia1 = 0; ia1 < na[1]; a1 += wa[1], c1 += wc[1], ++ia1) {
                auto a0 = a1;
                auto c0 = c1;
                for(SizeType ia0 = 0; ia0 < na[0]; a0 += wa[0], c0 += wc[0], ++ia0)
                    *c0 = *a0 * b_elem;
            }
        }
    }
}
0467
0468
0469
0470
0471
0472
0473
0474
0475
0476
0477
0478
0479
0480
0481
0482
0483
0484
0485
0486
0487
0488
0489
0490
0491
/** @brief Recursive outer product of two tensors without mode permutations.
 *
 * The modes of B above 1 are peeled off first (rb counts down), then the modes of
 * A above 1 (ra counts down); the remaining two modes of each operand are handled
 * by outer_2x2. The result cursor advances by wc[rc] while iterating B and by
 * wc[ra] while iterating A.
 *
 * Loop counters use SizeType so that they have the same width as the extents
 * they are compared with.
 *
 * @param pa zero-based mode of C at which the modes of B start (forwarded to outer_2x2)
 * @param rc zero-based mode of C handled by this recursion level
 * @param c  cursor into the result
 * @param nc extents of C
 * @param wc strides of C
 * @param ra zero-based mode of A handled by this recursion level
 * @param a  cursor into the left operand
 * @param na extents of A
 * @param wa strides of A
 * @param rb zero-based mode of B handled by this recursion level
 * @param b  cursor into the right operand
 * @param nb extents of B
 * @param wb strides of B
 */
template<class PointerOut, class PointerIn1, class PointerIn2, class SizeType>
void outer(SizeType const pa,
           SizeType const rc, PointerOut c, SizeType const*const nc, SizeType const*const wc,
           SizeType const ra, PointerIn1 a, SizeType const*const na, SizeType const*const wa,
           SizeType const rb, PointerIn2 b, SizeType const*const nb, SizeType const*const wb)
{
    if(rb > 1) {
        for(SizeType ib = 0; ib < nb[rb]; b += wb[rb], c += wc[rc], ++ib)
            outer(pa, rc-1, c, nc, wc, ra, a, na, wa, rb-1, b, nb, wb);
    }
    else if(ra > 1) {
        for(SizeType ia = 0; ia < na[ra]; a += wa[ra], c += wc[ra], ++ia)
            outer(pa, rc-1, c, nc, wc, ra-1, a, na, wa, rb, b, nb, wb);
    }
    else {
        outer_2x2(pa, c, nc, wc, a, na, wa, b, nb, wb);
    }
}
0507
0508
0509
0510
0511
0512
0513
0514
0515
0516
0517
0518
0519
0520
0521
0522
0523
0524
0525
0526
0527
0528
0529
0530
0531
0532
0533
0534
0535
/** @brief Recursive outer product of two tensors with mode permutations.
 *
 * Each recursion level k handles one mode of C:
 *   - k in [0, r)   : mode phia[k] of A,
 *   - k in [r, r+s) : mode phib[k-r] of B; the last of these levels assigns
 *                     C = A * B element-wise with '=' (no accumulation).
 *
 * @param k    current recursion level
 * @param r    number of modes of A
 * @param s    number of modes of B
 * @param phia one-based mode permutation of A
 * @param phib one-based mode permutation of B
 * @param c    cursor into the result
 * @param nc   extents of C (loop bounds)
 * @param wc   strides of C
 * @param a    cursor into the left operand
 * @param na   extents of A
 * @param wa   strides of A
 * @param b    cursor into the right operand
 * @param nb   extents of B
 * @param wb   strides of B
 */
template <class PointerOut, class PointerIn1, class PointerIn2, class SizeType>
void outer(SizeType const k,
           SizeType const r, SizeType const s,
           SizeType const*const phia, SizeType const*const phib,
           PointerOut c, SizeType const*const nc, SizeType const*const wc,
           PointerIn1 a, SizeType const*const na, SizeType const*const wa,
           PointerIn2 b, SizeType const*const nb, SizeType const*const wb)
{
    // Level belongs to a mode of A: advance A and C together.
    if(k < r)
    {
        const auto mode_a = phia[k]-1;
        assert(nc[k] == na[mode_a]);
        size_t step = 0u;
        while(step < nc[k])
        {
            outer(k+1, r, s, phia, phib, c, nc, wc, a, na, wa, b, nb, wb);
            a += wa[mode_a];
            c += wc[k];
            ++step;
        }
        return;
    }

    // Level belongs to a mode of B: advance B and C together.
    const auto mode_b = phib[k-r]-1;
    assert(nc[k] == nb[mode_b]);

    // Only the last level writes the product.
    const bool descend = (k < r+s-1);
    size_t step = 0u;
    while(step < nc[k])
    {
        if(descend)
            outer(k+1, r, s, phia, phib, c, nc, wc, a, na, wa, b, nb, wb);
        else
            *c = *a * *b;
        b += wb[mode_b];
        c += wc[k];
        ++step;
    }
}
0563
0564
0565 }
0566 }
0567 }
0568 }
0569 }
0570
0571
0572
0573
0574
0575
0576
0577
0578
0579
0580
0581
0582
0583
0584
0585 #include <stdexcept>
0586
0587 namespace boost {
0588 namespace numeric {
0589 namespace ublas {
0590
0591
0592
0593
0594
0595
0596
0597
0598
0599
0600
0601
0602
0603
0604
0605
0606
0607
0608
0609
0610
0611 template <class PointerOut, class PointerIn1, class PointerIn2, class SizeType>
0612 void ttv(SizeType const m, SizeType const p,
0613 PointerOut c, SizeType const*const nc, SizeType const*const wc,
0614 const PointerIn1 a, SizeType const*const na, SizeType const*const wa,
0615 const PointerIn2 b, SizeType const*const nb, SizeType const*const wb)
0616 {
0617 static_assert( std::is_pointer<PointerOut>::value & std::is_pointer<PointerIn1>::value & std::is_pointer<PointerIn2>::value,
0618 "Static error in boost::numeric::ublas::ttv: Argument types for pointers are not pointer types.");
0619
0620 if( m == 0)
0621 throw std::length_error("Error in boost::numeric::ublas::ttv: Contraction mode must be greater than zero.");
0622
0623 if( p < m )
0624 throw std::length_error("Error in boost::numeric::ublas::ttv: Rank must be greater equal the modus.");
0625
0626 if( p == 0)
0627 throw std::length_error("Error in boost::numeric::ublas::ttv: Rank must be greater than zero.");
0628
0629 if(c == nullptr || a == nullptr || b == nullptr)
0630 throw std::length_error("Error in boost::numeric::ublas::ttv: Pointers shall not be null pointers.");
0631
0632 for(auto i = 0u; i < m-1; ++i)
0633 if(na[i] != nc[i])
0634 throw std::length_error("Error in boost::numeric::ublas::ttv: Extents (except of dimension mode) of A and C must be equal.");
0635
0636 for(auto i = m; i < p; ++i)
0637 if(na[i] != nc[i-1])
0638 throw std::length_error("Error in boost::numeric::ublas::ttv: Extents (except of dimension mode) of A and C must be equal.");
0639
0640 const auto max = std::max(nb[0], nb[1]);
0641 if( na[m-1] != max)
0642 throw std::length_error("Error in boost::numeric::ublas::ttv: Extent of dimension mode of A and b must be equal.");
0643
0644
0645 if((m != 1) && (p > 2))
0646 detail::recursive::ttv(m-1, p-1, p-2, c, nc, wc, a, na, wa, b);
0647 else if ((m == 1) && (p > 2))
0648 detail::recursive::ttv0(p-1, c, nc, wc, a, na, wa, b);
0649 else if( p == 2 )
0650 detail::recursive::mtv(m-1, c, nc, wc, a, na, wa, b);
0651 else {
0652 auto v = std::remove_pointer_t<std::remove_cv_t<PointerOut>>{};
0653 *c = detail::recursive::inner(SizeType(0), na, a, wa, b, wb, v);
0654 }
0655
0656 }
0657
0658
0659
0660
0661
0662
0663
0664
0665
0666
0667
0668
0669
0670
0671
0672
0673
0674
0675
0676
0677
0678
0679
0680
0681 template <class PointerIn1, class PointerIn2, class PointerOut, class SizeType>
0682 void ttm(SizeType const m, SizeType const p,
0683 PointerOut c, SizeType const*const nc, SizeType const*const wc,
0684 const PointerIn1 a, SizeType const*const na, SizeType const*const wa,
0685 const PointerIn2 b, SizeType const*const nb, SizeType const*const wb)
0686 {
0687
0688 static_assert( std::is_pointer<PointerOut>::value & std::is_pointer<PointerIn1>::value & std::is_pointer<PointerIn2>::value,
0689 "Static error in boost::numeric::ublas::ttm: Argument types for pointers are not pointer types.");
0690
0691 if( m == 0 )
0692 throw std::length_error("Error in boost::numeric::ublas::ttm: Contraction mode must be greater than zero.");
0693
0694 if( p < m )
0695 throw std::length_error("Error in boost::numeric::ublas::ttm: Rank must be greater equal than the specified mode.");
0696
0697 if( p == 0)
0698 throw std::length_error("Error in boost::numeric::ublas::ttm:Rank must be greater than zero.");
0699
0700 if(c == nullptr || a == nullptr || b == nullptr)
0701 throw std::length_error("Error in boost::numeric::ublas::ttm: Pointers shall not be null pointers.");
0702
0703 for(auto i = 0u; i < m-1; ++i)
0704 if(na[i] != nc[i])
0705 throw std::length_error("Error in boost::numeric::ublas::ttm: Extents (except of dimension mode) of A and C must be equal.");
0706
0707 for(auto i = m; i < p; ++i)
0708 if(na[i] != nc[i])
0709 throw std::length_error("Error in boost::numeric::ublas::ttm: Extents (except of dimension mode) of A and C must be equal.");
0710
0711 if(na[m-1] != nb[1])
0712 throw std::length_error("Error in boost::numeric::ublas::ttm: 2nd Extent of B and M-th Extent of A must be the equal.");
0713
0714 if(nc[m-1] != nb[0])
0715 throw std::length_error("Error in boost::numeric::ublas::ttm: 1nd Extent of B and M-th Extent of C must be the equal.");
0716
0717 if ( m != 1 )
0718 detail::recursive::ttm (m-1, p-1, c, nc, wc, a, na, wa, b, nb, wb);
0719 else
0720 detail::recursive::ttm0( p-1, c, nc, wc, a, na, wa, b, nb, wb);
0721
0722 }
0723
0724
0725
0726
0727
0728
0729
0730
0731
0732
0733
0734
0735
0736
0737
0738
0739
0740
0741
0742
0743
0744
0745
0746
0747
0748
0749
0750
0751 template <class PointerIn1, class PointerIn2, class PointerOut, class SizeType>
0752 void ttt(SizeType const pa, SizeType const pb, SizeType const q,
0753 SizeType const*const phia, SizeType const*const phib,
0754 PointerOut c, SizeType const*const nc, SizeType const*const wc,
0755 PointerIn1 a, SizeType const*const na, SizeType const*const wa,
0756 PointerIn2 b, SizeType const*const nb, SizeType const*const wb)
0757 {
0758 static_assert( std::is_pointer<PointerOut>::value & std::is_pointer<PointerIn1>::value & std::is_pointer<PointerIn2>::value,
0759 "Static error in boost::numeric::ublas::ttm: Argument types for pointers are not pointer types.");
0760
0761 if( pa == 0 || pb == 0)
0762 throw std::length_error("Error in boost::numeric::ublas::ttt: tensor order must be greater zero.");
0763
0764 if( q > pa && q > pb)
0765 throw std::length_error("Error in boost::numeric::ublas::ttt: number of contraction must be smaller than or equal to the tensor order.");
0766
0767
0768 SizeType const r = pa - q;
0769 SizeType const s = pb - q;
0770
0771 if(c == nullptr || a == nullptr || b == nullptr)
0772 throw std::length_error("Error in boost::numeric::ublas::ttm: Pointers shall not be null pointers.");
0773
0774 for(auto i = 0ul; i < r; ++i)
0775 if( na[phia[i]-1] != nc[i] )
0776 throw std::length_error("Error in boost::numeric::ublas::ttt: dimensions of lhs and res tensor not correct.");
0777
0778 for(auto i = 0ul; i < s; ++i)
0779 if( nb[phib[i]-1] != nc[r+i] )
0780 throw std::length_error("Error in boost::numeric::ublas::ttt: dimensions of rhs and res not correct.");
0781
0782 for(auto i = 0ul; i < q; ++i)
0783 if( nb[phib[s+i]-1] != na[phia[r+i]-1] )
0784 throw std::length_error("Error in boost::numeric::ublas::ttt: dimensions of lhs and rhs not correct.");
0785
0786
0787 if(q == 0ul)
0788 detail::recursive::outer(SizeType{0},r,s, phia,phib, c,nc,wc, a,na,wa, b,nb,wb);
0789 else
0790 detail::recursive::ttt(SizeType{0},r,s,q, phia,phib, c,nc,wc, a,na,wa, b,nb,wb);
0791 }
0792
0793
0794
0795
0796
0797
0798
0799
0800
0801
0802
0803
0804
0805
0806
0807
0808
0809
0810
0811
0812
0813
0814
0815
0816
0817
0818
0819 template <class PointerIn1, class PointerIn2, class PointerOut, class SizeType>
0820 void ttt(SizeType const pa, SizeType const pb, SizeType const q,
0821 PointerOut c, SizeType const*const nc, SizeType const*const wc,
0822 PointerIn1 a, SizeType const*const na, SizeType const*const wa,
0823 PointerIn2 b, SizeType const*const nb, SizeType const*const wb)
0824 {
0825 static_assert( std::is_pointer<PointerOut>::value & std::is_pointer<PointerIn1>::value & std::is_pointer<PointerIn2>::value,
0826 "Static error in boost::numeric::ublas::ttm: Argument types for pointers are not pointer types.");
0827
0828 if( pa == 0 || pb == 0)
0829 throw std::length_error("Error in boost::numeric::ublas::ttt: tensor order must be greater zero.");
0830
0831 if( q > pa && q > pb)
0832 throw std::length_error("Error in boost::numeric::ublas::ttt: number of contraction must be smaller than or equal to the tensor order.");
0833
0834
0835 SizeType const r = pa - q;
0836 SizeType const s = pb - q;
0837 SizeType const pc = r+s;
0838
0839 if(c == nullptr || a == nullptr || b == nullptr)
0840 throw std::length_error("Error in boost::numeric::ublas::ttm: Pointers shall not be null pointers.");
0841
0842 for(auto i = 0ul; i < r; ++i)
0843 if( na[i] != nc[i] )
0844 throw std::length_error("Error in boost::numeric::ublas::ttt: dimensions of lhs and res tensor not correct.");
0845
0846 for(auto i = 0ul; i < s; ++i)
0847 if( nb[i] != nc[r+i] )
0848 throw std::length_error("Error in boost::numeric::ublas::ttt: dimensions of rhs and res not correct.");
0849
0850 for(auto i = 0ul; i < q; ++i)
0851 if( nb[s+i] != na[r+i] )
0852 throw std::length_error("Error in boost::numeric::ublas::ttt: dimensions of lhs and rhs not correct.");
0853
0854 using value_type = std::decay_t<decltype(*c)>;
0855
0856
0857
0858 if(q == 0ul)
0859 detail::recursive::outer(pa, pc-1, c,nc,wc, pa-1, a,na,wa, pb-1, b,nb,wb);
0860 else if(r == 0ul && s == 0ul)
0861 *c = detail::recursive::inner(q-1, na, a,wa, b,wb, value_type(0) );
0862 else
0863 detail::recursive::ttt(SizeType{0},r,s,q, c,nc,wc, a,na,wa, b,nb,wb);
0864 }
0865
0866
0867
0868
0869
0870
0871
0872
0873
0874
0875
0876
0877
0878
0879
0880
0881
0882
0883 template <class PointerIn1, class PointerIn2, class value_t, class SizeType>
0884 auto inner(const SizeType p, SizeType const*const n,
0885 const PointerIn1 a, SizeType const*const wa,
0886 const PointerIn2 b, SizeType const*const wb,
0887 value_t v)
0888 {
0889 static_assert( std::is_pointer<PointerIn1>::value && std::is_pointer<PointerIn2>::value,
0890 "Static error in boost::numeric::ublas::inner: Argument types for pointers must be pointer types.");
0891 if(p<2)
0892 throw std::length_error("Error in boost::numeric::ublas::inner: Rank must be greater than zero.");
0893 if(a == nullptr || b == nullptr)
0894 throw std::length_error("Error in boost::numeric::ublas::inner: Pointers shall not be null pointers.");
0895
0896 return detail::recursive::inner(p-1, n, a, wa, b, wb, v);
0897
0898 }
0899
0900
0901
0902
0903
0904
0905
0906
0907
0908
0909
0910
0911
0912
0913
0914
0915
0916
0917
0918
0919
0920 template <class PointerOut, class PointerIn1, class PointerIn2, class SizeType>
0921 void outer(PointerOut c, SizeType const pc, SizeType const*const nc, SizeType const*const wc,
0922 const PointerIn1 a, SizeType const pa, SizeType const*const na, SizeType const*const wa,
0923 const PointerIn2 b, SizeType const pb, SizeType const*const nb, SizeType const*const wb)
0924 {
0925 static_assert( std::is_pointer<PointerIn1>::value & std::is_pointer<PointerIn2>::value & std::is_pointer<PointerOut>::value,
0926 "Static error in boost::numeric::ublas::outer: argument types for pointers must be pointer types.");
0927 if(pa < 2u || pb < 2u)
0928 throw std::length_error("Error in boost::numeric::ublas::outer: number of extents of lhs and rhs tensor must be equal or greater than two.");
0929 if((pa + pb) != pc)
0930 throw std::length_error("Error in boost::numeric::ublas::outer: number of extents of lhs plus rhs tensor must be equal to the number of extents of C.");
0931 if(a == nullptr || b == nullptr || c == nullptr)
0932 throw std::length_error("Error in boost::numeric::ublas::outer: pointers shall not be null pointers.");
0933
0934 detail::recursive::outer(pa, pc-1, c, nc, wc, pa-1, a, na, wa, pb-1, b, nb, wb);
0935
0936 }
0937
0938
0939
0940
0941 }
0942 }
0943 }
0944
0945 #endif