File indexing completed on 2025-02-22 10:34:44
0001
0002
0003
0004
0005
0006
0007
0008
0009
0010 #ifndef EIGEN_SPARSESPARSEPRODUCTWITHPRUNING_H
0011 #define EIGEN_SPARSESPARSEPRODUCTWITHPRUNING_H
0012
0013 namespace Eigen {
0014
0015 namespace internal {
0016
0017
0018
0019 template<typename Lhs, typename Rhs, typename ResultType>
0020 static void sparse_sparse_product_with_pruning_impl(const Lhs& lhs, const Rhs& rhs, ResultType& res, const typename ResultType::RealScalar& tolerance)
0021 {
0022
0023
0024 typedef typename remove_all<Rhs>::type::Scalar RhsScalar;
0025 typedef typename remove_all<ResultType>::type::Scalar ResScalar;
0026 typedef typename remove_all<Lhs>::type::StorageIndex StorageIndex;
0027
0028
0029 Index rows = lhs.innerSize();
0030 Index cols = rhs.outerSize();
0031
0032 eigen_assert(lhs.outerSize() == rhs.innerSize());
0033
0034
0035 AmbiVector<ResScalar,StorageIndex> tempVector(rows);
0036
0037
0038 if(ResultType::IsRowMajor)
0039 res.resize(cols, rows);
0040 else
0041 res.resize(rows, cols);
0042
0043 evaluator<Lhs> lhsEval(lhs);
0044 evaluator<Rhs> rhsEval(rhs);
0045
0046
0047
0048
0049
0050
0051
0052 Index estimated_nnz_prod = lhsEval.nonZerosEstimate() + rhsEval.nonZerosEstimate();
0053
0054 res.reserve(estimated_nnz_prod);
0055 double ratioColRes = double(estimated_nnz_prod)/(double(lhs.rows())*double(rhs.cols()));
0056 for (Index j=0; j<cols; ++j)
0057 {
0058
0059
0060
0061 tempVector.init(ratioColRes);
0062 tempVector.setZero();
0063 for (typename evaluator<Rhs>::InnerIterator rhsIt(rhsEval, j); rhsIt; ++rhsIt)
0064 {
0065
0066 tempVector.restart();
0067 RhsScalar x = rhsIt.value();
0068 for (typename evaluator<Lhs>::InnerIterator lhsIt(lhsEval, rhsIt.index()); lhsIt; ++lhsIt)
0069 {
0070 tempVector.coeffRef(lhsIt.index()) += lhsIt.value() * x;
0071 }
0072 }
0073 res.startVec(j);
0074 for (typename AmbiVector<ResScalar,StorageIndex>::Iterator it(tempVector,tolerance); it; ++it)
0075 res.insertBackByOuterInner(j,it.index()) = it.value();
0076 }
0077 res.finalize();
0078 }
0079
0080 template<typename Lhs, typename Rhs, typename ResultType,
0081 int LhsStorageOrder = traits<Lhs>::Flags&RowMajorBit,
0082 int RhsStorageOrder = traits<Rhs>::Flags&RowMajorBit,
0083 int ResStorageOrder = traits<ResultType>::Flags&RowMajorBit>
0084 struct sparse_sparse_product_with_pruning_selector;
0085
0086 template<typename Lhs, typename Rhs, typename ResultType>
0087 struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,ColMajor,ColMajor,ColMajor>
0088 {
0089 typedef typename ResultType::RealScalar RealScalar;
0090
0091 static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance)
0092 {
0093 typename remove_all<ResultType>::type _res(res.rows(), res.cols());
0094 internal::sparse_sparse_product_with_pruning_impl<Lhs,Rhs,ResultType>(lhs, rhs, _res, tolerance);
0095 res.swap(_res);
0096 }
0097 };
0098
0099 template<typename Lhs, typename Rhs, typename ResultType>
0100 struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,ColMajor,ColMajor,RowMajor>
0101 {
0102 typedef typename ResultType::RealScalar RealScalar;
0103 static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance)
0104 {
0105
0106 typedef SparseMatrix<typename ResultType::Scalar,ColMajor,typename ResultType::StorageIndex> SparseTemporaryType;
0107 SparseTemporaryType _res(res.rows(), res.cols());
0108 internal::sparse_sparse_product_with_pruning_impl<Lhs,Rhs,SparseTemporaryType>(lhs, rhs, _res, tolerance);
0109 res = _res;
0110 }
0111 };
0112
0113 template<typename Lhs, typename Rhs, typename ResultType>
0114 struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,RowMajor,RowMajor,RowMajor>
0115 {
0116 typedef typename ResultType::RealScalar RealScalar;
0117 static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance)
0118 {
0119
0120 typename remove_all<ResultType>::type _res(res.rows(), res.cols());
0121 internal::sparse_sparse_product_with_pruning_impl<Rhs,Lhs,ResultType>(rhs, lhs, _res, tolerance);
0122 res.swap(_res);
0123 }
0124 };
0125
0126 template<typename Lhs, typename Rhs, typename ResultType>
0127 struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,RowMajor,RowMajor,ColMajor>
0128 {
0129 typedef typename ResultType::RealScalar RealScalar;
0130 static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance)
0131 {
0132 typedef SparseMatrix<typename Lhs::Scalar,ColMajor,typename Lhs::StorageIndex> ColMajorMatrixLhs;
0133 typedef SparseMatrix<typename Rhs::Scalar,ColMajor,typename Lhs::StorageIndex> ColMajorMatrixRhs;
0134 ColMajorMatrixLhs colLhs(lhs);
0135 ColMajorMatrixRhs colRhs(rhs);
0136 internal::sparse_sparse_product_with_pruning_impl<ColMajorMatrixLhs,ColMajorMatrixRhs,ResultType>(colLhs, colRhs, res, tolerance);
0137
0138
0139
0140
0141
0142
0143 }
0144 };
0145
0146 template<typename Lhs, typename Rhs, typename ResultType>
0147 struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,ColMajor,RowMajor,RowMajor>
0148 {
0149 typedef typename ResultType::RealScalar RealScalar;
0150 static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance)
0151 {
0152 typedef SparseMatrix<typename Lhs::Scalar,RowMajor,typename Lhs::StorageIndex> RowMajorMatrixLhs;
0153 RowMajorMatrixLhs rowLhs(lhs);
0154 sparse_sparse_product_with_pruning_selector<RowMajorMatrixLhs,Rhs,ResultType,RowMajor,RowMajor>(rowLhs,rhs,res,tolerance);
0155 }
0156 };
0157
0158 template<typename Lhs, typename Rhs, typename ResultType>
0159 struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,RowMajor,ColMajor,RowMajor>
0160 {
0161 typedef typename ResultType::RealScalar RealScalar;
0162 static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance)
0163 {
0164 typedef SparseMatrix<typename Rhs::Scalar,RowMajor,typename Lhs::StorageIndex> RowMajorMatrixRhs;
0165 RowMajorMatrixRhs rowRhs(rhs);
0166 sparse_sparse_product_with_pruning_selector<Lhs,RowMajorMatrixRhs,ResultType,RowMajor,RowMajor,RowMajor>(lhs,rowRhs,res,tolerance);
0167 }
0168 };
0169
0170 template<typename Lhs, typename Rhs, typename ResultType>
0171 struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,ColMajor,RowMajor,ColMajor>
0172 {
0173 typedef typename ResultType::RealScalar RealScalar;
0174 static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance)
0175 {
0176 typedef SparseMatrix<typename Rhs::Scalar,ColMajor,typename Lhs::StorageIndex> ColMajorMatrixRhs;
0177 ColMajorMatrixRhs colRhs(rhs);
0178 internal::sparse_sparse_product_with_pruning_impl<Lhs,ColMajorMatrixRhs,ResultType>(lhs, colRhs, res, tolerance);
0179 }
0180 };
0181
0182 template<typename Lhs, typename Rhs, typename ResultType>
0183 struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,RowMajor,ColMajor,ColMajor>
0184 {
0185 typedef typename ResultType::RealScalar RealScalar;
0186 static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance)
0187 {
0188 typedef SparseMatrix<typename Lhs::Scalar,ColMajor,typename Lhs::StorageIndex> ColMajorMatrixLhs;
0189 ColMajorMatrixLhs colLhs(lhs);
0190 internal::sparse_sparse_product_with_pruning_impl<ColMajorMatrixLhs,Rhs,ResultType>(colLhs, rhs, res, tolerance);
0191 }
0192 };
0193
0194 }
0195
0196 }
0197
0198 #endif