File indexing completed on 2025-08-28 08:26:58
0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 #pragma once
0019
0020 #include "arrow/extension_type.h"
0021
0022 namespace arrow {
0023 namespace extension {
0024
0025 class ARROW_EXPORT FixedShapeTensorArray : public ExtensionArray {
0026 public:
0027 using ExtensionArray::ExtensionArray;
0028
0029
0030
0031
0032
0033
0034
0035
0036
0037
0038 static Result<std::shared_ptr<FixedShapeTensorArray>> FromTensor(
0039 const std::shared_ptr<Tensor>& tensor);
0040
0041
0042
0043
0044
0045
0046
0047 const Result<std::shared_ptr<Tensor>> ToTensor() const;
0048 };
0049
0050
0051
0052
0053 class ARROW_EXPORT FixedShapeTensorType : public ExtensionType {
0054 public:
0055 FixedShapeTensorType(const std::shared_ptr<DataType>& value_type, const int32_t& size,
0056 const std::vector<int64_t>& shape,
0057 const std::vector<int64_t>& permutation = {},
0058 const std::vector<std::string>& dim_names = {})
0059 : ExtensionType(fixed_size_list(value_type, size)),
0060 value_type_(value_type),
0061 shape_(shape),
0062 permutation_(permutation),
0063 dim_names_(dim_names) {}
0064
0065 std::string extension_name() const override { return "arrow.fixed_shape_tensor"; }
0066 std::string ToString(bool show_metadata = false) const override;
0067
0068
0069 size_t ndim() const { return shape_.size(); }
0070
0071
0072 const std::vector<int64_t>& shape() const { return shape_; }
0073
0074
0075 const std::shared_ptr<DataType>& value_type() const { return value_type_; }
0076
0077
0078
0079
0080 const std::vector<int64_t>& strides();
0081
0082
0083 const std::vector<int64_t>& permutation() const { return permutation_; }
0084
0085
0086 const std::vector<std::string>& dim_names() const { return dim_names_; }
0087
0088 bool ExtensionEquals(const ExtensionType& other) const override;
0089
0090 std::string Serialize() const override;
0091
0092 Result<std::shared_ptr<DataType>> Deserialize(
0093 std::shared_ptr<DataType> storage_type,
0094 const std::string& serialized_data) const override;
0095
0096
0097 std::shared_ptr<Array> MakeArray(std::shared_ptr<ArrayData> data) const override;
0098
0099
0100
0101
0102
0103
0104
0105 static Result<std::shared_ptr<Tensor>> MakeTensor(
0106 const std::shared_ptr<ExtensionScalar>& scalar);
0107
0108
0109 static Result<std::shared_ptr<DataType>> Make(
0110 const std::shared_ptr<DataType>& value_type, const std::vector<int64_t>& shape,
0111 const std::vector<int64_t>& permutation = {},
0112 const std::vector<std::string>& dim_names = {});
0113
0114 private:
0115 std::shared_ptr<DataType> storage_type_;
0116 std::shared_ptr<DataType> value_type_;
0117 std::vector<int64_t> shape_;
0118 std::vector<int64_t> strides_;
0119 std::vector<int64_t> permutation_;
0120 std::vector<std::string> dim_names_;
0121 };
0122
0123
0124 ARROW_EXPORT std::shared_ptr<DataType> fixed_shape_tensor(
0125 const std::shared_ptr<DataType>& storage_type, const std::vector<int64_t>& shape,
0126 const std::vector<int64_t>& permutation = {},
0127 const std::vector<std::string>& dim_names = {});
0128
0129 }
0130 }