Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-08-28 08:26:58

0001 // Licensed to the Apache Software Foundation (ASF) under one
0002 // or more contributor license agreements.  See the NOTICE file
0003 // distributed with this work for additional information
0004 // regarding copyright ownership.  The ASF licenses this file
0005 // to you under the Apache License, Version 2.0 (the
0006 // "License"); you may not use this file except in compliance
0007 // with the License.  You may obtain a copy of the License at
0008 //
0009 //   http://www.apache.org/licenses/LICENSE-2.0
0010 //
0011 // Unless required by applicable law or agreed to in writing,
0012 // software distributed under the License is distributed on an
0013 // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
0014 // KIND, either express or implied.  See the License for the
0015 // specific language governing permissions and limitations
0016 // under the License.
0017 
0018 #pragma once
0019 
0020 #include "arrow/extension_type.h"
0021 
0022 namespace arrow {
0023 namespace extension {
0024 
0025 class ARROW_EXPORT FixedShapeTensorArray : public ExtensionArray {
0026  public:
0027   using ExtensionArray::ExtensionArray;
0028 
0029   /// \brief Create a FixedShapeTensorArray from a Tensor
0030   ///
0031   /// This method will create a FixedShapeTensorArray from a Tensor, taking its first
0032   /// dimension as the number of elements in the resulting array and the remaining
0033   /// dimensions as the shape of the individual tensors. If Tensor provides strides,
0034   /// they will be used to determine dimension permutation. Otherwise, row-major layout
0035   /// (i.e. no permutation) will be assumed.
0036   ///
0037   /// \param[in] tensor The Tensor to convert to a FixedShapeTensorArray
0038   static Result<std::shared_ptr<FixedShapeTensorArray>> FromTensor(
0039       const std::shared_ptr<Tensor>& tensor);
0040 
0041   /// \brief Create a Tensor from FixedShapeTensorArray
0042   ///
0043   /// This method will create a Tensor from a FixedShapeTensorArray, setting its first
0044   /// dimension as length equal to the FixedShapeTensorArray's length and the remaining
0045   /// dimensions as the FixedShapeTensorType's shape. Shape and dim_names will be
0046   /// permuted according to permutation stored in the FixedShapeTensorType metadata.
0047   const Result<std::shared_ptr<Tensor>> ToTensor() const;
0048 };
0049 
0050 /// \brief Concrete type class for constant-size Tensor data.
0051 /// This is a canonical arrow extension type.
0052 /// See: https://arrow.apache.org/docs/format/CanonicalExtensions.html
0053 class ARROW_EXPORT FixedShapeTensorType : public ExtensionType {
0054  public:
0055   FixedShapeTensorType(const std::shared_ptr<DataType>& value_type, const int32_t& size,
0056                        const std::vector<int64_t>& shape,
0057                        const std::vector<int64_t>& permutation = {},
0058                        const std::vector<std::string>& dim_names = {})
0059       : ExtensionType(fixed_size_list(value_type, size)),
0060         value_type_(value_type),
0061         shape_(shape),
0062         permutation_(permutation),
0063         dim_names_(dim_names) {}
0064 
0065   std::string extension_name() const override { return "arrow.fixed_shape_tensor"; }
0066   std::string ToString(bool show_metadata = false) const override;
0067 
0068   /// Number of dimensions of tensor elements
0069   size_t ndim() const { return shape_.size(); }
0070 
0071   /// Shape of tensor elements
0072   const std::vector<int64_t>& shape() const { return shape_; }
0073 
0074   /// Value type of tensor elements
0075   const std::shared_ptr<DataType>& value_type() const { return value_type_; }
0076 
0077   /// Strides of tensor elements. Strides state offset in bytes between adjacent
0078   /// elements along each dimension. In case permutation is non-empty strides are
0079   /// computed from permuted tensor element's shape.
0080   const std::vector<int64_t>& strides();
0081 
0082   /// Permutation mapping from logical to physical memory layout of tensor elements
0083   const std::vector<int64_t>& permutation() const { return permutation_; }
0084 
0085   /// Dimension names of tensor elements. Dimensions are ordered physically.
0086   const std::vector<std::string>& dim_names() const { return dim_names_; }
0087 
0088   bool ExtensionEquals(const ExtensionType& other) const override;
0089 
0090   std::string Serialize() const override;
0091 
0092   Result<std::shared_ptr<DataType>> Deserialize(
0093       std::shared_ptr<DataType> storage_type,
0094       const std::string& serialized_data) const override;
0095 
0096   /// Create a FixedShapeTensorArray from ArrayData
0097   std::shared_ptr<Array> MakeArray(std::shared_ptr<ArrayData> data) const override;
0098 
0099   /// \brief Create a Tensor from an ExtensionScalar from a FixedShapeTensorArray
0100   ///
0101   /// This method will return a Tensor from ExtensionScalar with strides
0102   /// derived from shape and permutation of FixedShapeTensorType. Shape and
0103   /// dim_names will be permuted according to permutation stored in the
0104   /// FixedShapeTensorType metadata.
0105   static Result<std::shared_ptr<Tensor>> MakeTensor(
0106       const std::shared_ptr<ExtensionScalar>& scalar);
0107 
0108   /// \brief Create a FixedShapeTensorType instance
0109   static Result<std::shared_ptr<DataType>> Make(
0110       const std::shared_ptr<DataType>& value_type, const std::vector<int64_t>& shape,
0111       const std::vector<int64_t>& permutation = {},
0112       const std::vector<std::string>& dim_names = {});
0113 
0114  private:
0115   std::shared_ptr<DataType> storage_type_;
0116   std::shared_ptr<DataType> value_type_;
0117   std::vector<int64_t> shape_;
0118   std::vector<int64_t> strides_;
0119   std::vector<int64_t> permutation_;
0120   std::vector<std::string> dim_names_;
0121 };
0122 
0123 /// \brief Return a FixedShapeTensorType instance.
0124 ARROW_EXPORT std::shared_ptr<DataType> fixed_shape_tensor(
0125     const std::shared_ptr<DataType>& storage_type, const std::vector<int64_t>& shape,
0126     const std::vector<int64_t>& permutation = {},
0127     const std::vector<std::string>& dim_names = {});
0128 
0129 }  // namespace extension
0130 }  // namespace arrow