File indexing completed on 2025-02-22 10:42:44
0001
0002
0003
0004
0005 #pragma once
0006
0007 #include "onnx/defs/shape_inference.h"
0008 #include "onnx/defs/tensor_proto_util.h"
0009 #include "onnx/onnx_pb.h"
0010
0011 namespace ONNX_NAMESPACE {
0012 namespace defs {
0013 namespace math {
0014 namespace utils {
0015 template <typename T>
0016 T GetScalarValueFromTensor(const ONNX_NAMESPACE::TensorProto* t) {
0017 if (t == nullptr) {
0018 return T{};
0019 }
0020
0021 auto data_type = t->data_type();
0022 switch (data_type) {
0023 case ONNX_NAMESPACE::TensorProto::FLOAT:
0024 return static_cast<T>(ONNX_NAMESPACE::ParseData<float>(t).at(0));
0025 case ONNX_NAMESPACE::TensorProto::DOUBLE:
0026 return static_cast<T>(ONNX_NAMESPACE::ParseData<double>(t).at(0));
0027 case ONNX_NAMESPACE::TensorProto::INT32:
0028 return static_cast<T>(ONNX_NAMESPACE::ParseData<int32_t>(t).at(0));
0029 case ONNX_NAMESPACE::TensorProto::INT64:
0030 return static_cast<T>(ONNX_NAMESPACE::ParseData<int64_t>(t).at(0));
0031 default:
0032 fail_shape_inference("Unsupported input data type of ", data_type);
0033 }
0034 }
0035 }
0036 }
0037 }
0038 }