Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-02-22 10:42:44

0001 // Copyright (c) ONNX Project Contributors
0002 
0003 /*
0004  * SPDX-License-Identifier: Apache-2.0
0005  */
0006 
0007 // ATTENTION: The code in this file is highly EXPERIMENTAL.
0008 // Adventurous users should note that the APIs will probably change.
0009 
0010 #pragma once
0011 
0012 #include <cmath>
0013 #include <functional>
0014 #include <numeric>
0015 #include <string>
0016 #include <utility>
0017 #include <vector>
0018 
0019 #include "onnx/common/assertions.h"
0020 #include "onnx/onnx_pb.h"
0021 #include "onnx/string_utils.h"
0022 
0023 namespace ONNX_NAMESPACE {
0024 
0025 struct Tensor final {
0026  private:
0027   bool is_segment_;
0028   int64_t segment_begin_;
0029   int64_t segment_end_;
0030   bool has_name_;
0031   std::string name_;
0032   int32_t elem_type_;
0033   std::vector<int64_t> sizes_;
0034 
0035   std::vector<float> float_data_;
0036   std::vector<double> double_data_;
0037   std::vector<int32_t> int32_data_;
0038   std::vector<int64_t> int64_data_;
0039   std::vector<uint64_t> uint64_data_;
0040   std::vector<std::string> string_data_;
0041 
0042   bool is_raw_data_;
0043   std::string raw_data_;
0044 
0045  public:
0046   Tensor()
0047       : is_segment_(false),
0048         segment_begin_(0),
0049         segment_end_(0),
0050         has_name_(false),
0051         elem_type_(ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED),
0052         is_raw_data_(false) {}
0053 
0054   Tensor(const Tensor& other)
0055       : is_segment_(other.is_segment_),
0056         segment_begin_(other.segment_begin_),
0057         segment_end_(other.segment_end_),
0058         has_name_(other.has_name_),
0059         elem_type_(other.elem_type_),
0060         sizes_(other.sizes_),
0061         float_data_(other.float_data_),
0062         double_data_(other.double_data_),
0063         int32_data_(other.int32_data_),
0064         int64_data_(other.int64_data_),
0065         uint64_data_(other.uint64_data_),
0066         is_raw_data_(other.is_raw_data_) {
0067     // Deep copy. Avoid copy on write when using gcc<5.0
0068     string_data_.resize(other.string_data_.size());
0069     for (unsigned int i = 0; i < other.string_data_.size(); ++i) {
0070       string_data_[i] = std::string(other.string_data_[i].data(), other.string_data_[i].size());
0071     }
0072     name_ = std::string(other.name_.data(), other.name_.size());
0073     raw_data_ = std::string(other.raw_data_.data(), other.raw_data_.size());
0074   }
0075   Tensor(Tensor&&) = default;
0076   ~Tensor() = default;
0077 
0078   friend void swap(Tensor& first, Tensor& second) {
0079     using std::swap;
0080     swap(first.is_segment_, second.is_segment_);
0081     swap(first.segment_begin_, second.segment_begin_);
0082     swap(first.segment_end_, second.segment_end_);
0083     swap(first.has_name_, second.has_name_);
0084     swap(first.name_, second.name_);
0085     swap(first.elem_type_, second.elem_type_);
0086     swap(first.sizes_, second.sizes_);
0087     swap(first.float_data_, second.float_data_);
0088     swap(first.double_data_, second.double_data_);
0089     swap(first.int32_data_, second.int32_data_);
0090     swap(first.int64_data_, second.int64_data_);
0091     swap(first.uint64_data_, second.uint64_data_);
0092     swap(first.is_raw_data_, second.is_raw_data_);
0093     swap(first.string_data_, second.string_data_);
0094     swap(first.raw_data_, second.raw_data_);
0095   }
0096 
0097   Tensor& operator=(Tensor other) noexcept {
0098     swap(*this, other);
0099     return *this;
0100   }
0101 
0102   const std::vector<int64_t>& sizes() const {
0103     return sizes_;
0104   }
0105   std::vector<int64_t>& sizes() {
0106     return sizes_;
0107   }
0108   /// if tensor is a scalar, the sizes is empty, but the element number is actually 1.
0109   /// size_from_dim() cannot handle this case, while elem_num() handles it correctly
0110   int64_t elem_num() const {
0111     return std::accumulate(sizes_.begin(), sizes_.end(), (int64_t)1, std::multiplies<int64_t>{});
0112   }
0113   int64_t size_from_dim(int dim) const {
0114     if (dim < 0) {
0115       dim += (int)sizes_.size();
0116     }
0117     ONNX_ASSERT(dim >= 0 && (size_t)dim < sizes_.size());
0118     return std::accumulate(sizes_.begin() + dim, sizes_.end(), (int64_t)1, std::multiplies<int64_t>{});
0119   }
0120 
0121   int32_t elem_type() const {
0122     return elem_type_;
0123   }
0124 
0125   int32_t& elem_type() {
0126     return elem_type_;
0127   }
0128 
0129   std::vector<std::string>& strings() {
0130     return string_data_;
0131   }
0132 
0133   const std::vector<std::string>& strings() const {
0134     return string_data_;
0135   }
0136 
0137   std::vector<float>& floats() {
0138     return float_data_;
0139   }
0140 
0141   const std::vector<float>& floats() const {
0142     return float_data_;
0143   }
0144 
0145   std::vector<double>& doubles() {
0146     return double_data_;
0147   }
0148 
0149   const std::vector<double>& doubles() const {
0150     return double_data_;
0151   }
0152 
0153   std::vector<int32_t>& int32s() {
0154     return int32_data_;
0155   }
0156 
0157   const std::vector<int32_t>& int32s() const {
0158     return int32_data_;
0159   }
0160 
0161   std::vector<int64_t>& int64s() {
0162     return int64_data_;
0163   }
0164 
0165   const std::vector<int64_t>& int64s() const {
0166     return int64_data_;
0167   }
0168 
0169   std::vector<uint64_t>& uint64s() {
0170     return uint64_data_;
0171   }
0172 
0173   const std::vector<uint64_t>& uint64s() const {
0174     return uint64_data_;
0175   }
0176 
0177   const std::string& raw() const {
0178     return raw_data_;
0179   }
0180 
0181   void set_raw_data(std::string raw_data) {
0182     is_raw_data_ = true;
0183     raw_data_ = std::move(raw_data);
0184   }
0185 
0186   template <typename T>
0187   T* data();
0188 
0189   template <typename T>
0190   const T* data() const;
0191 
0192   bool is_segment() const {
0193     return is_segment_;
0194   }
0195 
0196   int64_t segment_begin() const {
0197     return segment_begin_;
0198   }
0199 
0200   int64_t segment_end() const {
0201     return segment_end_;
0202   }
0203 
0204   void set_segment_begin_and_end(int64_t begin, int64_t end) {
0205     is_segment_ = true;
0206     segment_begin_ = begin;
0207     segment_end_ = end;
0208   }
0209 
0210   bool hasName() const {
0211     return has_name_;
0212   }
0213 
0214   const std::string& name() const {
0215     return name_;
0216   }
0217 
0218   void setName(std::string name) {
0219     has_name_ = true;
0220     name_ = std::move(name);
0221   }
0222 
0223   bool is_raw_data() const {
0224     return is_raw_data_;
0225   }
0226 };
0227 
0228 template <>
0229 inline std::string* Tensor::data<std::string>() {
0230   ONNX_ASSERTM(
0231       !is_raw_data(),
0232       "data type is string. string content is required to be stored in repeated bytes string_data field."
0233       "raw_data type cannot be string.");
0234   return string_data_.data();
0235 }
0236 template <>
0237 inline const std::string* Tensor::data<std::string>() const {
0238   ONNX_ASSERTM(
0239       !is_raw_data(),
0240       "data type is string. string content is required to be stored in repeated bytes string_data field."
0241       "raw_data type cannot be string.");
0242   return string_data_.data();
0243 }
0244 
0245 #define define_data(type, field)                             \
0246   template <>                                                \
0247   inline type* Tensor::data<type>() {                        \
0248     if (is_raw_data_) {                                      \
0249       return (type*)const_cast<char*>(&raw_data_.data()[0]); \
0250     } else {                                                 \
0251       return field.data();                                   \
0252     }                                                        \
0253   }                                                          \
0254                                                              \
0255   template <>                                                \
0256   inline const type* Tensor::data<type>() const {            \
0257     if (is_raw_data_) {                                      \
0258       return (const type*)(raw_data_.data());                \
0259     } else {                                                 \
0260       return field.data();                                   \
0261     }                                                        \
0262   }
0263 
0264 define_data(float, float_data_);
0265 define_data(double, double_data_);
0266 define_data(int32_t, int32_data_);
0267 define_data(int64_t, int64_data_);
0268 define_data(uint64_t, uint64_data_);
0269 #undef define_data
0270 
0271 } // namespace ONNX_NAMESPACE