Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-08-27 08:47:20

0001 // Licensed to the Apache Software Foundation (ASF) under one
0002 // or more contributor license agreements.  See the NOTICE file
0003 // distributed with this work for additional information
0004 // regarding copyright ownership.  The ASF licenses this file
0005 // to you under the Apache License, Version 2.0 (the
0006 // "License"); you may not use this file except in compliance
0007 // with the License.  You may obtain a copy of the License at
0008 //
0009 //   http://www.apache.org/licenses/LICENSE-2.0
0010 //
0011 // Unless required by applicable law or agreed to in writing,
0012 // software distributed under the License is distributed on an
0013 // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
0014 // KIND, either express or implied.  See the License for the
0015 // specific language governing permissions and limitations
0016 // under the License.
0017 
0018 #pragma once
0019 
0020 #include <cstddef>
0021 #include <cstdint>
0022 #include <memory>
0023 #include <string>
0024 #include <utility>
0025 #include <vector>
0026 
0027 #include "arrow/buffer.h"
0028 #include "arrow/compare.h"
0029 #include "arrow/result.h"
0030 #include "arrow/status.h"
0031 #include "arrow/tensor.h"  // IWYU pragma: export
0032 #include "arrow/type.h"
0033 #include "arrow/util/checked_cast.h"
0034 #include "arrow/util/macros.h"
0035 #include "arrow/util/visibility.h"
0036 
0037 namespace arrow {
0038 
0039 class MemoryPool;
0040 
0041 namespace internal {
0042 
0043 ARROW_EXPORT
0044 Status CheckSparseIndexMaximumValue(const std::shared_ptr<DataType>& index_value_type,
0045                                     const std::vector<int64_t>& shape);
0046 
0047 }  // namespace internal
0048 
0049 // ----------------------------------------------------------------------
0050 // SparseIndex class
0051 
/// \brief Scope holder for the sparse index format enumeration
struct SparseTensorFormat {
  /// EXPERIMENTAL: The index format type of SparseTensor
  ///
  /// Values are spelled out explicitly; they are the same 0-based numbers the
  /// declaration order would assign implicitly.
  enum type {
    /// Coordinate list (COO) format.
    COO = 0,
    /// Compressed sparse row (CSR) format.
    CSR = 1,
    /// Compressed sparse column (CSC) format.
    CSC = 2,
    /// Compressed sparse fiber (CSF) format.
    CSF = 3
  };
};
0065 
0066 /// \brief EXPERIMENTAL: The base class for the index of a sparse tensor
0067 ///
0068 /// SparseIndex describes where the non-zero elements are within a SparseTensor.
0069 ///
0070 /// There are several ways to represent this.  The format_id is used to
0071 /// distinguish what kind of representation is used.  Each possible value of
0072 /// format_id must have only one corresponding concrete subclass of SparseIndex.
0073 class ARROW_EXPORT SparseIndex {
0074  public:
0075   explicit SparseIndex(SparseTensorFormat::type format_id) : format_id_(format_id) {}
0076 
0077   virtual ~SparseIndex() = default;
0078 
0079   /// \brief Return the identifier of the format type
0080   SparseTensorFormat::type format_id() const { return format_id_; }
0081 
0082   /// \brief Return the number of non zero values in the sparse tensor related
0083   /// to this sparse index
0084   virtual int64_t non_zero_length() const = 0;
0085 
0086   /// \brief Return the string representation of the sparse index
0087   virtual std::string ToString() const = 0;
0088 
0089   virtual Status ValidateShape(const std::vector<int64_t>& shape) const;
0090 
0091  protected:
0092   const SparseTensorFormat::type format_id_;
0093 };
0094 
0095 namespace internal {
0096 template <typename SparseIndexType>
0097 class SparseIndexBase : public SparseIndex {
0098  public:
0099   SparseIndexBase() : SparseIndex(SparseIndexType::format_id) {}
0100 };
0101 }  // namespace internal
0102 
0103 // ----------------------------------------------------------------------
0104 // SparseCOOIndex class
0105 
0106 /// \brief EXPERIMENTAL: The index data for a COO sparse tensor
0107 ///
0108 /// A COO sparse index manages the location of its non-zero values by their
0109 /// coordinates.
0110 class ARROW_EXPORT SparseCOOIndex : public internal::SparseIndexBase<SparseCOOIndex> {
0111  public:
0112   static constexpr SparseTensorFormat::type format_id = SparseTensorFormat::COO;
0113 
0114   /// \brief Make SparseCOOIndex from a coords tensor and canonicality
0115   static Result<std::shared_ptr<SparseCOOIndex>> Make(
0116       const std::shared_ptr<Tensor>& coords, bool is_canonical);
0117 
0118   /// \brief Make SparseCOOIndex from a coords tensor with canonicality auto-detection
0119   static Result<std::shared_ptr<SparseCOOIndex>> Make(
0120       const std::shared_ptr<Tensor>& coords);
0121 
0122   /// \brief Make SparseCOOIndex from raw properties with canonicality auto-detection
0123   static Result<std::shared_ptr<SparseCOOIndex>> Make(
0124       const std::shared_ptr<DataType>& indices_type,
0125       const std::vector<int64_t>& indices_shape,
0126       const std::vector<int64_t>& indices_strides, std::shared_ptr<Buffer> indices_data);
0127 
0128   /// \brief Make SparseCOOIndex from raw properties
0129   static Result<std::shared_ptr<SparseCOOIndex>> Make(
0130       const std::shared_ptr<DataType>& indices_type,
0131       const std::vector<int64_t>& indices_shape,
0132       const std::vector<int64_t>& indices_strides, std::shared_ptr<Buffer> indices_data,
0133       bool is_canonical);
0134 
0135   /// \brief Make SparseCOOIndex from sparse tensor's shape properties and data
0136   /// with canonicality auto-detection
0137   ///
0138   /// The indices_data should be in row-major (C-like) order.  If not,
0139   /// use the raw properties constructor.
0140   static Result<std::shared_ptr<SparseCOOIndex>> Make(
0141       const std::shared_ptr<DataType>& indices_type, const std::vector<int64_t>& shape,
0142       int64_t non_zero_length, std::shared_ptr<Buffer> indices_data);
0143 
0144   /// \brief Make SparseCOOIndex from sparse tensor's shape properties and data
0145   ///
0146   /// The indices_data should be in row-major (C-like) order.  If not,
0147   /// use the raw properties constructor.
0148   static Result<std::shared_ptr<SparseCOOIndex>> Make(
0149       const std::shared_ptr<DataType>& indices_type, const std::vector<int64_t>& shape,
0150       int64_t non_zero_length, std::shared_ptr<Buffer> indices_data, bool is_canonical);
0151 
0152   /// \brief Construct SparseCOOIndex from column-major NumericTensor
0153   explicit SparseCOOIndex(const std::shared_ptr<Tensor>& coords, bool is_canonical);
0154 
0155   /// \brief Return a tensor that has the coordinates of the non-zero values
0156   ///
0157   /// The returned tensor is a N x D tensor where N is the number of non-zero
0158   /// values and D is the number of dimensions in the logical data.
0159   /// The column at index `i` is a D-tuple of coordinates indicating that the
0160   /// logical value at those coordinates should be found at physical index `i`.
0161   const std::shared_ptr<Tensor>& indices() const { return coords_; }
0162 
0163   /// \brief Return the number of non zero values in the sparse tensor related
0164   /// to this sparse index
0165   int64_t non_zero_length() const override { return coords_->shape()[0]; }
0166 
0167   /// \brief Return whether a sparse tensor index is canonical, or not.
0168   /// If a sparse tensor index is canonical, it is sorted in the lexicographical order,
0169   /// and the corresponding sparse tensor doesn't have duplicated entries.
0170   bool is_canonical() const { return is_canonical_; }
0171 
0172   /// \brief Return a string representation of the sparse index
0173   std::string ToString() const override;
0174 
0175   /// \brief Return whether the COO indices are equal
0176   bool Equals(const SparseCOOIndex& other) const {
0177     return indices()->Equals(*other.indices());
0178   }
0179 
0180   inline Status ValidateShape(const std::vector<int64_t>& shape) const override {
0181     ARROW_RETURN_NOT_OK(SparseIndex::ValidateShape(shape));
0182 
0183     if (static_cast<size_t>(coords_->shape()[1]) == shape.size()) {
0184       return Status::OK();
0185     }
0186 
0187     return Status::Invalid(
0188         "shape length is inconsistent with the coords matrix in COO index");
0189   }
0190 
0191  protected:
0192   std::shared_ptr<Tensor> coords_;
0193   bool is_canonical_;
0194 };
0195 
0196 namespace internal {
0197 
/// EXPERIMENTAL: The axis to be compressed
///
/// The underlying value doubles as the index of the compressed dimension in a
/// 2-D matrix shape (ROW -> 0, COLUMN -> 1).
enum class SparseMatrixCompressedAxis : char {
  /// The value for CSR matrix
  ROW = 0,
  /// The value for CSC matrix
  COLUMN = 1
};
0205 
0206 ARROW_EXPORT
0207 Status ValidateSparseCSXIndex(const std::shared_ptr<DataType>& indptr_type,
0208                               const std::shared_ptr<DataType>& indices_type,
0209                               const std::vector<int64_t>& indptr_shape,
0210                               const std::vector<int64_t>& indices_shape,
0211                               char const* type_name);
0212 
0213 ARROW_EXPORT
0214 void CheckSparseCSXIndexValidity(const std::shared_ptr<DataType>& indptr_type,
0215                                  const std::shared_ptr<DataType>& indices_type,
0216                                  const std::vector<int64_t>& indptr_shape,
0217                                  const std::vector<int64_t>& indices_shape,
0218                                  char const* type_name);
0219 
0220 template <typename SparseIndexType, SparseMatrixCompressedAxis COMPRESSED_AXIS>
0221 class SparseCSXIndex : public SparseIndexBase<SparseIndexType> {
0222  public:
0223   static constexpr SparseMatrixCompressedAxis kCompressedAxis = COMPRESSED_AXIS;
0224 
0225   /// \brief Make a subclass of SparseCSXIndex from raw properties
0226   static Result<std::shared_ptr<SparseIndexType>> Make(
0227       const std::shared_ptr<DataType>& indptr_type,
0228       const std::shared_ptr<DataType>& indices_type,
0229       const std::vector<int64_t>& indptr_shape, const std::vector<int64_t>& indices_shape,
0230       std::shared_ptr<Buffer> indptr_data, std::shared_ptr<Buffer> indices_data) {
0231     ARROW_RETURN_NOT_OK(ValidateSparseCSXIndex(indptr_type, indices_type, indptr_shape,
0232                                                indices_shape,
0233                                                SparseIndexType::kTypeName));
0234     return std::make_shared<SparseIndexType>(
0235         std::make_shared<Tensor>(indptr_type, indptr_data, indptr_shape),
0236         std::make_shared<Tensor>(indices_type, indices_data, indices_shape));
0237   }
0238 
0239   /// \brief Make a subclass of SparseCSXIndex from raw properties
0240   static Result<std::shared_ptr<SparseIndexType>> Make(
0241       const std::shared_ptr<DataType>& indices_type,
0242       const std::vector<int64_t>& indptr_shape, const std::vector<int64_t>& indices_shape,
0243       std::shared_ptr<Buffer> indptr_data, std::shared_ptr<Buffer> indices_data) {
0244     return Make(indices_type, indices_type, indptr_shape, indices_shape, indptr_data,
0245                 indices_data);
0246   }
0247 
0248   /// \brief Make a subclass of SparseCSXIndex from sparse tensor's shape properties and
0249   /// data
0250   static Result<std::shared_ptr<SparseIndexType>> Make(
0251       const std::shared_ptr<DataType>& indptr_type,
0252       const std::shared_ptr<DataType>& indices_type, const std::vector<int64_t>& shape,
0253       int64_t non_zero_length, std::shared_ptr<Buffer> indptr_data,
0254       std::shared_ptr<Buffer> indices_data) {
0255     std::vector<int64_t> indptr_shape({shape[0] + 1});
0256     std::vector<int64_t> indices_shape({non_zero_length});
0257     return Make(indptr_type, indices_type, indptr_shape, indices_shape, indptr_data,
0258                 indices_data);
0259   }
0260 
0261   /// \brief Make a subclass of SparseCSXIndex from sparse tensor's shape properties and
0262   /// data
0263   static Result<std::shared_ptr<SparseIndexType>> Make(
0264       const std::shared_ptr<DataType>& indices_type, const std::vector<int64_t>& shape,
0265       int64_t non_zero_length, std::shared_ptr<Buffer> indptr_data,
0266       std::shared_ptr<Buffer> indices_data) {
0267     return Make(indices_type, indices_type, shape, non_zero_length, indptr_data,
0268                 indices_data);
0269   }
0270 
0271   /// \brief Construct SparseCSXIndex from two index vectors
0272   explicit SparseCSXIndex(const std::shared_ptr<Tensor>& indptr,
0273                           const std::shared_ptr<Tensor>& indices)
0274       : SparseIndexBase<SparseIndexType>(), indptr_(indptr), indices_(indices) {
0275     CheckSparseCSXIndexValidity(indptr_->type(), indices_->type(), indptr_->shape(),
0276                                 indices_->shape(), SparseIndexType::kTypeName);
0277   }
0278 
0279   /// \brief Return a 1D tensor of indptr vector
0280   const std::shared_ptr<Tensor>& indptr() const { return indptr_; }
0281 
0282   /// \brief Return a 1D tensor of indices vector
0283   const std::shared_ptr<Tensor>& indices() const { return indices_; }
0284 
0285   /// \brief Return the number of non zero values in the sparse tensor related
0286   /// to this sparse index
0287   int64_t non_zero_length() const override { return indices_->shape()[0]; }
0288 
0289   /// \brief Return a string representation of the sparse index
0290   std::string ToString() const override {
0291     return std::string(SparseIndexType::kTypeName);
0292   }
0293 
0294   /// \brief Return whether the CSR indices are equal
0295   bool Equals(const SparseIndexType& other) const {
0296     return indptr()->Equals(*other.indptr()) && indices()->Equals(*other.indices());
0297   }
0298 
0299   inline Status ValidateShape(const std::vector<int64_t>& shape) const override {
0300     ARROW_RETURN_NOT_OK(SparseIndex::ValidateShape(shape));
0301 
0302     if (shape.size() < 2) {
0303       return Status::Invalid("shape length is too short");
0304     }
0305 
0306     if (shape.size() > 2) {
0307       return Status::Invalid("shape length is too long");
0308     }
0309 
0310     if (indptr_->shape()[0] == shape[static_cast<int64_t>(kCompressedAxis)] + 1) {
0311       return Status::OK();
0312     }
0313 
0314     return Status::Invalid("shape length is inconsistent with the ", ToString());
0315   }
0316 
0317  protected:
0318   std::shared_ptr<Tensor> indptr_;
0319   std::shared_ptr<Tensor> indices_;
0320 };
0321 
0322 }  // namespace internal
0323 
0324 // ----------------------------------------------------------------------
0325 // SparseCSRIndex class
0326 
0327 /// \brief EXPERIMENTAL: The index data for a CSR sparse matrix
0328 ///
0329 /// A CSR sparse index manages the location of its non-zero values by two
0330 /// vectors.
0331 ///
0332 /// The first vector, called indptr, represents the range of the rows; the i-th
0333 /// row spans from indptr[i] to indptr[i+1] in the corresponding value vector.
0334 /// So the length of an indptr vector is the number of rows + 1.
0335 ///
0336 /// The other vector, called indices, represents the column indices of the
0337 /// corresponding non-zero values.  So the length of an indices vector is same
0338 /// as the number of non-zero-values.
0339 class ARROW_EXPORT SparseCSRIndex
0340     : public internal::SparseCSXIndex<SparseCSRIndex,
0341                                       internal::SparseMatrixCompressedAxis::ROW> {
0342  public:
0343   using BaseClass =
0344       internal::SparseCSXIndex<SparseCSRIndex, internal::SparseMatrixCompressedAxis::ROW>;
0345 
0346   static constexpr SparseTensorFormat::type format_id = SparseTensorFormat::CSR;
0347   static constexpr char const* kTypeName = "SparseCSRIndex";
0348 
0349   using SparseCSXIndex::kCompressedAxis;
0350   using SparseCSXIndex::Make;
0351   using SparseCSXIndex::SparseCSXIndex;
0352 };
0353 
0354 // ----------------------------------------------------------------------
0355 // SparseCSCIndex class
0356 
0357 /// \brief EXPERIMENTAL: The index data for a CSC sparse matrix
0358 ///
0359 /// A CSC sparse index manages the location of its non-zero values by two
0360 /// vectors.
0361 ///
0362 /// The first vector, called indptr, represents the range of the column; the i-th
0363 /// column spans from indptr[i] to indptr[i+1] in the corresponding value vector.
0364 /// So the length of an indptr vector is the number of columns + 1.
0365 ///
0366 /// The other vector, called indices, represents the row indices of the
0367 /// corresponding non-zero values.  So the length of an indices vector is same
0368 /// as the number of non-zero-values.
0369 class ARROW_EXPORT SparseCSCIndex
0370     : public internal::SparseCSXIndex<SparseCSCIndex,
0371                                       internal::SparseMatrixCompressedAxis::COLUMN> {
0372  public:
0373   using BaseClass =
0374       internal::SparseCSXIndex<SparseCSCIndex,
0375                                internal::SparseMatrixCompressedAxis::COLUMN>;
0376 
0377   static constexpr SparseTensorFormat::type format_id = SparseTensorFormat::CSC;
0378   static constexpr char const* kTypeName = "SparseCSCIndex";
0379 
0380   using SparseCSXIndex::kCompressedAxis;
0381   using SparseCSXIndex::Make;
0382   using SparseCSXIndex::SparseCSXIndex;
0383 };
0384 
0385 // ----------------------------------------------------------------------
0386 // SparseCSFIndex class
0387 
0388 /// \brief EXPERIMENTAL: The index data for a CSF sparse tensor
0389 ///
0390 /// A CSF sparse index manages the location of its non-zero values by set of
0391 /// prefix trees. Each path from a root to leaf forms one tensor non-zero index.
0392 /// CSF is implemented with three vectors.
0393 ///
0394 /// Vectors inptr and indices contain N-1 and N buffers respectively, where N is the
0395 /// number of dimensions. Axis_order is a vector of integers of length N. Indptr and
0396 /// indices describe the set of prefix trees. Trees traverse dimensions in order given by
0397 /// axis_order.
0398 class ARROW_EXPORT SparseCSFIndex : public internal::SparseIndexBase<SparseCSFIndex> {
0399  public:
0400   static constexpr SparseTensorFormat::type format_id = SparseTensorFormat::CSF;
0401   static constexpr char const* kTypeName = "SparseCSFIndex";
0402 
0403   /// \brief Make SparseCSFIndex from raw properties
0404   static Result<std::shared_ptr<SparseCSFIndex>> Make(
0405       const std::shared_ptr<DataType>& indptr_type,
0406       const std::shared_ptr<DataType>& indices_type,
0407       const std::vector<int64_t>& indices_shapes, const std::vector<int64_t>& axis_order,
0408       const std::vector<std::shared_ptr<Buffer>>& indptr_data,
0409       const std::vector<std::shared_ptr<Buffer>>& indices_data);
0410 
0411   /// \brief Make SparseCSFIndex from raw properties
0412   static Result<std::shared_ptr<SparseCSFIndex>> Make(
0413       const std::shared_ptr<DataType>& indices_type,
0414       const std::vector<int64_t>& indices_shapes, const std::vector<int64_t>& axis_order,
0415       const std::vector<std::shared_ptr<Buffer>>& indptr_data,
0416       const std::vector<std::shared_ptr<Buffer>>& indices_data) {
0417     return Make(indices_type, indices_type, indices_shapes, axis_order, indptr_data,
0418                 indices_data);
0419   }
0420 
0421   /// \brief Construct SparseCSFIndex from two index vectors
0422   explicit SparseCSFIndex(const std::vector<std::shared_ptr<Tensor>>& indptr,
0423                           const std::vector<std::shared_ptr<Tensor>>& indices,
0424                           const std::vector<int64_t>& axis_order);
0425 
0426   /// \brief Return a 1D vector of indptr tensors
0427   const std::vector<std::shared_ptr<Tensor>>& indptr() const { return indptr_; }
0428 
0429   /// \brief Return a 1D vector of indices tensors
0430   const std::vector<std::shared_ptr<Tensor>>& indices() const { return indices_; }
0431 
0432   /// \brief Return a 1D vector specifying the order of axes
0433   const std::vector<int64_t>& axis_order() const { return axis_order_; }
0434 
0435   /// \brief Return the number of non zero values in the sparse tensor related
0436   /// to this sparse index
0437   int64_t non_zero_length() const override { return indices_.back()->shape()[0]; }
0438 
0439   /// \brief Return a string representation of the sparse index
0440   std::string ToString() const override;
0441 
0442   /// \brief Return whether the CSF indices are equal
0443   bool Equals(const SparseCSFIndex& other) const;
0444 
0445  protected:
0446   std::vector<std::shared_ptr<Tensor>> indptr_;
0447   std::vector<std::shared_ptr<Tensor>> indices_;
0448   std::vector<int64_t> axis_order_;
0449 };
0450 
0451 // ----------------------------------------------------------------------
0452 // SparseTensor class
0453 
0454 /// \brief EXPERIMENTAL: The base class of sparse tensor container
0455 class ARROW_EXPORT SparseTensor {
0456  public:
0457   virtual ~SparseTensor() = default;
0458 
0459   SparseTensorFormat::type format_id() const { return sparse_index_->format_id(); }
0460 
0461   /// \brief Return a value type of the sparse tensor
0462   std::shared_ptr<DataType> type() const { return type_; }
0463 
0464   /// \brief Return a buffer that contains the value vector of the sparse tensor
0465   std::shared_ptr<Buffer> data() const { return data_; }
0466 
0467   /// \brief Return an immutable raw data pointer
0468   const uint8_t* raw_data() const { return data_->data(); }
0469 
0470   /// \brief Return a mutable raw data pointer
0471   uint8_t* raw_mutable_data() const { return data_->mutable_data(); }
0472 
0473   /// \brief Return a shape vector of the sparse tensor
0474   const std::vector<int64_t>& shape() const { return shape_; }
0475 
0476   /// \brief Return a sparse index of the sparse tensor
0477   const std::shared_ptr<SparseIndex>& sparse_index() const { return sparse_index_; }
0478 
0479   /// \brief Return a number of dimensions of the sparse tensor
0480   int ndim() const { return static_cast<int>(shape_.size()); }
0481 
0482   /// \brief Return a vector of dimension names
0483   const std::vector<std::string>& dim_names() const { return dim_names_; }
0484 
0485   /// \brief Return the name of the i-th dimension
0486   const std::string& dim_name(int i) const;
0487 
0488   /// \brief Total number of value cells in the sparse tensor
0489   int64_t size() const;
0490 
0491   /// \brief Return true if the underlying data buffer is mutable
0492   bool is_mutable() const { return data_->is_mutable(); }
0493 
0494   /// \brief Total number of non-zero cells in the sparse tensor
0495   int64_t non_zero_length() const {
0496     return sparse_index_ ? sparse_index_->non_zero_length() : 0;
0497   }
0498 
0499   /// \brief Return whether sparse tensors are equal
0500   bool Equals(const SparseTensor& other,
0501               const EqualOptions& = EqualOptions::Defaults()) const;
0502 
0503   /// \brief Return dense representation of sparse tensor as tensor
0504   ///
0505   /// The returned Tensor has row-major order (C-like).
0506   Result<std::shared_ptr<Tensor>> ToTensor(MemoryPool* pool) const;
0507   Result<std::shared_ptr<Tensor>> ToTensor() const {
0508     return ToTensor(default_memory_pool());
0509   }
0510 
0511  protected:
0512   // Constructor with all attributes
0513   SparseTensor(const std::shared_ptr<DataType>& type, const std::shared_ptr<Buffer>& data,
0514                const std::vector<int64_t>& shape,
0515                const std::shared_ptr<SparseIndex>& sparse_index,
0516                const std::vector<std::string>& dim_names);
0517 
0518   std::shared_ptr<DataType> type_;
0519   std::shared_ptr<Buffer> data_;
0520   std::vector<int64_t> shape_;
0521   std::shared_ptr<SparseIndex> sparse_index_;
0522 
0523   // These names are optional
0524   std::vector<std::string> dim_names_;
0525 };
0526 
0527 // ----------------------------------------------------------------------
0528 // SparseTensorImpl class
0529 
0530 namespace internal {
0531 
0532 ARROW_EXPORT
0533 Status MakeSparseTensorFromTensor(const Tensor& tensor,
0534                                   SparseTensorFormat::type sparse_format_id,
0535                                   const std::shared_ptr<DataType>& index_value_type,
0536                                   MemoryPool* pool,
0537                                   std::shared_ptr<SparseIndex>* out_sparse_index,
0538                                   std::shared_ptr<Buffer>* out_data);
0539 
0540 }  // namespace internal
0541 
0542 /// \brief EXPERIMENTAL: Concrete sparse tensor implementation classes with sparse index
0543 /// type
0544 template <typename SparseIndexType>
0545 class SparseTensorImpl : public SparseTensor {
0546  public:
0547   virtual ~SparseTensorImpl() = default;
0548 
0549   /// \brief Construct a sparse tensor from physical data buffer and logical index
0550   SparseTensorImpl(const std::shared_ptr<SparseIndexType>& sparse_index,
0551                    const std::shared_ptr<DataType>& type,
0552                    const std::shared_ptr<Buffer>& data, const std::vector<int64_t>& shape,
0553                    const std::vector<std::string>& dim_names)
0554       : SparseTensor(type, data, shape, sparse_index, dim_names) {}
0555 
0556   /// \brief Construct an empty sparse tensor
0557   SparseTensorImpl(const std::shared_ptr<DataType>& type,
0558                    const std::vector<int64_t>& shape,
0559                    const std::vector<std::string>& dim_names = {})
0560       : SparseTensorImpl(NULLPTR, type, NULLPTR, shape, dim_names) {}
0561 
0562   /// \brief Create a SparseTensor with full parameters
0563   static inline Result<std::shared_ptr<SparseTensorImpl<SparseIndexType>>> Make(
0564       const std::shared_ptr<SparseIndexType>& sparse_index,
0565       const std::shared_ptr<DataType>& type, const std::shared_ptr<Buffer>& data,
0566       const std::vector<int64_t>& shape, const std::vector<std::string>& dim_names) {
0567     if (!is_tensor_supported(type->id())) {
0568       return Status::Invalid(type->ToString(),
0569                              " is not valid data type for a sparse tensor");
0570     }
0571     ARROW_RETURN_NOT_OK(sparse_index->ValidateShape(shape));
0572     if (dim_names.size() > 0 && dim_names.size() != shape.size()) {
0573       return Status::Invalid("dim_names length is inconsistent with shape");
0574     }
0575     return std::make_shared<SparseTensorImpl<SparseIndexType>>(sparse_index, type, data,
0576                                                                shape, dim_names);
0577   }
0578 
0579   /// \brief Create a sparse tensor from a dense tensor
0580   ///
0581   /// The dense tensor is re-encoded as a sparse index and a physical
0582   /// data buffer for the non-zero value.
0583   static inline Result<std::shared_ptr<SparseTensorImpl<SparseIndexType>>> Make(
0584       const Tensor& tensor, const std::shared_ptr<DataType>& index_value_type,
0585       MemoryPool* pool = default_memory_pool()) {
0586     std::shared_ptr<SparseIndex> sparse_index;
0587     std::shared_ptr<Buffer> data;
0588     ARROW_RETURN_NOT_OK(internal::MakeSparseTensorFromTensor(
0589         tensor, SparseIndexType::format_id, index_value_type, pool, &sparse_index,
0590         &data));
0591     return std::make_shared<SparseTensorImpl<SparseIndexType>>(
0592         internal::checked_pointer_cast<SparseIndexType>(sparse_index), tensor.type(),
0593         data, tensor.shape(), tensor.dim_names_);
0594   }
0595 
0596   static inline Result<std::shared_ptr<SparseTensorImpl<SparseIndexType>>> Make(
0597       const Tensor& tensor, MemoryPool* pool = default_memory_pool()) {
0598     return Make(tensor, int64(), pool);
0599   }
0600 
0601  private:
0602   ARROW_DISALLOW_COPY_AND_ASSIGN(SparseTensorImpl);
0603 };
0604 
0605 /// \brief EXPERIMENTAL: Type alias for COO sparse tensor
0606 using SparseCOOTensor = SparseTensorImpl<SparseCOOIndex>;
0607 
0608 /// \brief EXPERIMENTAL: Type alias for CSR sparse matrix
0609 using SparseCSRMatrix = SparseTensorImpl<SparseCSRIndex>;
0610 
0611 /// \brief EXPERIMENTAL: Type alias for CSC sparse matrix
0612 using SparseCSCMatrix = SparseTensorImpl<SparseCSCIndex>;
0613 
0614 /// \brief EXPERIMENTAL: Type alias for CSF sparse matrix
0615 using SparseCSFTensor = SparseTensorImpl<SparseCSFIndex>;
0616 
0617 }  // namespace arrow