File indexing completed on 2025-02-22 10:42:44
0001
0002
0003
0004
0005
0006
0007
0008
0009
0010 #pragma once
0011
0012 #include <cmath>
0013 #include <functional>
0014 #include <numeric>
0015 #include <string>
0016 #include <utility>
0017 #include <vector>
0018
0019 #include "onnx/common/assertions.h"
0020 #include "onnx/onnx_pb.h"
0021 #include "onnx/string_utils.h"
0022
0023 namespace ONNX_NAMESPACE {
0024
0025 struct Tensor final {
0026 private:
0027 bool is_segment_;
0028 int64_t segment_begin_;
0029 int64_t segment_end_;
0030 bool has_name_;
0031 std::string name_;
0032 int32_t elem_type_;
0033 std::vector<int64_t> sizes_;
0034
0035 std::vector<float> float_data_;
0036 std::vector<double> double_data_;
0037 std::vector<int32_t> int32_data_;
0038 std::vector<int64_t> int64_data_;
0039 std::vector<uint64_t> uint64_data_;
0040 std::vector<std::string> string_data_;
0041
0042 bool is_raw_data_;
0043 std::string raw_data_;
0044
0045 public:
0046 Tensor()
0047 : is_segment_(false),
0048 segment_begin_(0),
0049 segment_end_(0),
0050 has_name_(false),
0051 elem_type_(ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED),
0052 is_raw_data_(false) {}
0053
0054 Tensor(const Tensor& other)
0055 : is_segment_(other.is_segment_),
0056 segment_begin_(other.segment_begin_),
0057 segment_end_(other.segment_end_),
0058 has_name_(other.has_name_),
0059 elem_type_(other.elem_type_),
0060 sizes_(other.sizes_),
0061 float_data_(other.float_data_),
0062 double_data_(other.double_data_),
0063 int32_data_(other.int32_data_),
0064 int64_data_(other.int64_data_),
0065 uint64_data_(other.uint64_data_),
0066 is_raw_data_(other.is_raw_data_) {
0067
0068 string_data_.resize(other.string_data_.size());
0069 for (unsigned int i = 0; i < other.string_data_.size(); ++i) {
0070 string_data_[i] = std::string(other.string_data_[i].data(), other.string_data_[i].size());
0071 }
0072 name_ = std::string(other.name_.data(), other.name_.size());
0073 raw_data_ = std::string(other.raw_data_.data(), other.raw_data_.size());
0074 }
0075 Tensor(Tensor&&) = default;
0076 ~Tensor() = default;
0077
0078 friend void swap(Tensor& first, Tensor& second) {
0079 using std::swap;
0080 swap(first.is_segment_, second.is_segment_);
0081 swap(first.segment_begin_, second.segment_begin_);
0082 swap(first.segment_end_, second.segment_end_);
0083 swap(first.has_name_, second.has_name_);
0084 swap(first.name_, second.name_);
0085 swap(first.elem_type_, second.elem_type_);
0086 swap(first.sizes_, second.sizes_);
0087 swap(first.float_data_, second.float_data_);
0088 swap(first.double_data_, second.double_data_);
0089 swap(first.int32_data_, second.int32_data_);
0090 swap(first.int64_data_, second.int64_data_);
0091 swap(first.uint64_data_, second.uint64_data_);
0092 swap(first.is_raw_data_, second.is_raw_data_);
0093 swap(first.string_data_, second.string_data_);
0094 swap(first.raw_data_, second.raw_data_);
0095 }
0096
0097 Tensor& operator=(Tensor other) noexcept {
0098 swap(*this, other);
0099 return *this;
0100 }
0101
0102 const std::vector<int64_t>& sizes() const {
0103 return sizes_;
0104 }
0105 std::vector<int64_t>& sizes() {
0106 return sizes_;
0107 }
0108
0109
0110 int64_t elem_num() const {
0111 return std::accumulate(sizes_.begin(), sizes_.end(), (int64_t)1, std::multiplies<int64_t>{});
0112 }
0113 int64_t size_from_dim(int dim) const {
0114 if (dim < 0) {
0115 dim += (int)sizes_.size();
0116 }
0117 ONNX_ASSERT(dim >= 0 && (size_t)dim < sizes_.size());
0118 return std::accumulate(sizes_.begin() + dim, sizes_.end(), (int64_t)1, std::multiplies<int64_t>{});
0119 }
0120
0121 int32_t elem_type() const {
0122 return elem_type_;
0123 }
0124
0125 int32_t& elem_type() {
0126 return elem_type_;
0127 }
0128
0129 std::vector<std::string>& strings() {
0130 return string_data_;
0131 }
0132
0133 const std::vector<std::string>& strings() const {
0134 return string_data_;
0135 }
0136
0137 std::vector<float>& floats() {
0138 return float_data_;
0139 }
0140
0141 const std::vector<float>& floats() const {
0142 return float_data_;
0143 }
0144
0145 std::vector<double>& doubles() {
0146 return double_data_;
0147 }
0148
0149 const std::vector<double>& doubles() const {
0150 return double_data_;
0151 }
0152
0153 std::vector<int32_t>& int32s() {
0154 return int32_data_;
0155 }
0156
0157 const std::vector<int32_t>& int32s() const {
0158 return int32_data_;
0159 }
0160
0161 std::vector<int64_t>& int64s() {
0162 return int64_data_;
0163 }
0164
0165 const std::vector<int64_t>& int64s() const {
0166 return int64_data_;
0167 }
0168
0169 std::vector<uint64_t>& uint64s() {
0170 return uint64_data_;
0171 }
0172
0173 const std::vector<uint64_t>& uint64s() const {
0174 return uint64_data_;
0175 }
0176
0177 const std::string& raw() const {
0178 return raw_data_;
0179 }
0180
0181 void set_raw_data(std::string raw_data) {
0182 is_raw_data_ = true;
0183 raw_data_ = std::move(raw_data);
0184 }
0185
0186 template <typename T>
0187 T* data();
0188
0189 template <typename T>
0190 const T* data() const;
0191
0192 bool is_segment() const {
0193 return is_segment_;
0194 }
0195
0196 int64_t segment_begin() const {
0197 return segment_begin_;
0198 }
0199
0200 int64_t segment_end() const {
0201 return segment_end_;
0202 }
0203
0204 void set_segment_begin_and_end(int64_t begin, int64_t end) {
0205 is_segment_ = true;
0206 segment_begin_ = begin;
0207 segment_end_ = end;
0208 }
0209
0210 bool hasName() const {
0211 return has_name_;
0212 }
0213
0214 const std::string& name() const {
0215 return name_;
0216 }
0217
0218 void setName(std::string name) {
0219 has_name_ = true;
0220 name_ = std::move(name);
0221 }
0222
0223 bool is_raw_data() const {
0224 return is_raw_data_;
0225 }
0226 };
0227
0228 template <>
0229 inline std::string* Tensor::data<std::string>() {
0230 ONNX_ASSERTM(
0231 !is_raw_data(),
0232 "data type is string. string content is required to be stored in repeated bytes string_data field."
0233 "raw_data type cannot be string.");
0234 return string_data_.data();
0235 }
0236 template <>
0237 inline const std::string* Tensor::data<std::string>() const {
0238 ONNX_ASSERTM(
0239 !is_raw_data(),
0240 "data type is string. string content is required to be stored in repeated bytes string_data field."
0241 "raw_data type cannot be string.");
0242 return string_data_.data();
0243 }
0244
0245 #define define_data(type, field) \
0246 template <> \
0247 inline type* Tensor::data<type>() { \
0248 if (is_raw_data_) { \
0249 return (type*)const_cast<char*>(&raw_data_.data()[0]); \
0250 } else { \
0251 return field.data(); \
0252 } \
0253 } \
0254 \
0255 template <> \
0256 inline const type* Tensor::data<type>() const { \
0257 if (is_raw_data_) { \
0258 return (const type*)(raw_data_.data()); \
0259 } else { \
0260 return field.data(); \
0261 } \
0262 }
0263
0264 define_data(float, float_data_);
0265 define_data(double, double_data_);
0266 define_data(int32_t, int32_data_);
0267 define_data(int64_t, int64_data_);
0268 define_data(uint64_t, uint64_data_);
0269 #undef define_data
0270
0271 }