Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-02-22 10:42:44

0001 /*
0002  * SPDX-License-Identifier: Apache-2.0
0003  */
0004 
0005 #pragma once
0006 
0007 #include "onnx/defs/shape_inference.h"
0008 #include "onnx/defs/tensor_proto_util.h"
0009 #include "onnx/onnx_pb.h"
0010 
0011 namespace ONNX_NAMESPACE {
0012 namespace defs {
0013 namespace math {
0014 namespace utils {
0015 template <typename T>
0016 T GetScalarValueFromTensor(const ONNX_NAMESPACE::TensorProto* t) {
0017   if (t == nullptr) {
0018     return T{};
0019   }
0020 
0021   auto data_type = t->data_type();
0022   switch (data_type) {
0023     case ONNX_NAMESPACE::TensorProto::FLOAT:
0024       return static_cast<T>(ONNX_NAMESPACE::ParseData<float>(t).at(0));
0025     case ONNX_NAMESPACE::TensorProto::DOUBLE:
0026       return static_cast<T>(ONNX_NAMESPACE::ParseData<double>(t).at(0));
0027     case ONNX_NAMESPACE::TensorProto::INT32:
0028       return static_cast<T>(ONNX_NAMESPACE::ParseData<int32_t>(t).at(0));
0029     case ONNX_NAMESPACE::TensorProto::INT64:
0030       return static_cast<T>(ONNX_NAMESPACE::ParseData<int64_t>(t).at(0));
0031     default:
0032       fail_shape_inference("Unsupported input data type of ", data_type);
0033   }
0034 }
0035 } // namespace utils
0036 } // namespace math
0037 } // namespace defs
0038 } // namespace ONNX_NAMESPACE