Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-12-16 10:20:23

0001 /*
0002  * SPDX-License-Identifier: Apache-2.0
0003  */
0004 
0005 #pragma once
0006 
0007 #include "onnx/defs/shape_inference.h"
0008 #include "onnx/defs/tensor_proto_util.h"
0009 #include "onnx/onnx_pb.h"
0010 
0011 namespace ONNX_NAMESPACE {
0012 namespace defs {
0013 namespace math {
0014 namespace utils {
0015 
0016 template <typename T>
0017 T GetScalarValueFromTensor(const ONNX_NAMESPACE::TensorProto* t) {
0018   if (t == nullptr) {
0019     return T{};
0020   }
0021 
0022   auto data_type = t->data_type();
0023   switch (data_type) {
0024     case ONNX_NAMESPACE::TensorProto::FLOAT:
0025       return static_cast<T>(ONNX_NAMESPACE::ParseData<float>(t).at(0));
0026     case ONNX_NAMESPACE::TensorProto::DOUBLE:
0027       return static_cast<T>(ONNX_NAMESPACE::ParseData<double>(t).at(0));
0028     case ONNX_NAMESPACE::TensorProto::INT32:
0029       return static_cast<T>(ONNX_NAMESPACE::ParseData<int32_t>(t).at(0));
0030     case ONNX_NAMESPACE::TensorProto::INT64:
0031       return static_cast<T>(ONNX_NAMESPACE::ParseData<int64_t>(t).at(0));
0032     default:
0033       fail_shape_inference("Unsupported input data type of ", data_type);
0034   }
0035 }
0036 
// Declaration only; the definition is not part of this header.
// Takes the inference context and two integer input indices (input1Idx,
// input2Idx) and returns nothing.
// NOTE(review): presumably infers the MatMul output shape from the inputs of
// `ctx` at those two indices and records it on `ctx` — confirm against the
// definition.
0037 void MatMulShapeInference(ONNX_NAMESPACE::InferenceContext& ctx, int input1Idx, int input2Idx);
0038 
// Declaration only; the definition is not part of this header.
// Takes the inference context by non-const reference and returns nothing.
// NOTE(review): presumably the shape-inference routine for the QLinearMatMul
// operator — confirm against the definition.
0039 void QLinearMatMulShapeInference(ONNX_NAMESPACE::InferenceContext& ctx);
0040 
// Declaration only; the definition is not part of this header.
// Returns a pointer to a C string.
// NOTE(review): presumably the documentation text for the QLinearMatMul
// operator; ownership and lifetime of the returned string are not visible
// here — confirm against the definition.
0041 const char* QLinearMatMulDoc();
0042 
// Declaration only; the definition is not part of this header.
// Takes an operator name (`op_type`, passed by value) and two ints, and
// returns an int.
// NOTE(review): presumably applies the binary operation named by `op_type` to
// `a` and `b`; the set of accepted names and the behaviour for unknown names
// are not visible here — confirm against the definition.
// NOTE(review): `std::string` is used without a direct `#include <string>` in
// this header, so it relies on one of the included ONNX headers providing it.
0043 int MathOpTwoIntegers(std::string op_type, int a, int b);
0044 
0045 } // namespace utils
0046 } // namespace math
0047 } // namespace defs
0048 } // namespace ONNX_NAMESPACE