#pragma once

#include <cstdint>
#include <string>

#include "onnx/defs/shape_inference.h"
#include "onnx/defs/tensor_proto_util.h"
#include "onnx/onnx_pb.h"

namespace ONNX_NAMESPACE {
namespace defs {
namespace math {
namespace utils {

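// Reads the first element of tensor `t` as a scalar of type T. Returns a
// value-initialized T when `t` is null, and fails shape inference for element
// types other than FLOAT, DOUBLE, INT32, and INT64.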
template <typename T>
T GetScalarValueFromTensor(const ONNX_NAMESPACE::TensorProto* t) {
  if (t == nullptr) {
    return T{};
  }

  auto data_type = t->data_type();
  switch (data_type) {
    case ONNX_NAMESPACE::TensorProto::FLOAT:
      return static_cast<T>(ONNX_NAMESPACE::ParseData<float>(t).at(0));
    case ONNX_NAMESPACE::TensorProto::DOUBLE:
      return static_cast<T>(ONNX_NAMESPACE::ParseData<double>(t).at(0));
    case ONNX_NAMESPACE::TensorProto::INT32:
      return static_cast<T>(ONNX_NAMESPACE::ParseData<int32_t>(t).at(0));
    case ONNX_NAMESPACE::TensorProto::INT64:
      return static_cast<T>(ONNX_NAMESPACE::ParseData<int64_t>(t).at(0));
    default:
      fail_shape_inference("Unsupported input data type of ", data_type);
  }
}
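
// Example (illustrative sketch only; the input index and variable names below
// are assumptions, not part of this header): during shape inference, a scalar
// constant input can be read as
//
//   const ONNX_NAMESPACE::TensorProto* k_initializer = ctx.getInputData(1);
//   int64_t k = GetScalarValueFromTensor<int64_t>(k_initializer);
//
// ctx.getInputData returns nullptr when the input is not statically known, in
// which case the call above yields a value-initialized int64_t (0).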

// Shape inference for numpy-style matrix multiplication (with broadcasting)
// between the inputs at positions input1Idx and input2Idx.
void MatMulShapeInference(ONNX_NAMESPACE::InferenceContext& ctx, int input1Idx, int input2Idx);
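
// Example (illustrative sketch, not a definition from this header): an operator
// schema with MatMul-like semantics could reuse this helper in its inference
// function roughly as
//
//   .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
//     propagateElemTypeFromInputToOutput(ctx, 0, 0);
//     defs::math::utils::MatMulShapeInference(ctx, 0, 1);
//   })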

// Shape inference for the QLinearMatMul (quantized MatMul) operator.
void QLinearMatMulShapeInference(ONNX_NAMESPACE::InferenceContext& ctx);

// Returns the documentation string used by the QLinearMatMul operator schema.
const char* QLinearMatMulDoc();

// Applies the integer binary operation named by op_type to a and b and returns the result.
int MathOpTwoIntegers(std::string op_type, int a, int b);

} // namespace utils
} // namespace math
} // namespace defs
} // namespace ONNX_NAMESPACE