File indexing completed on 2025-08-28 08:58:49
0001
0002
0003
0004
0005
0006
0007
0008
0009 #pragma once
0010
0011 #include <memory>
0012 #include <vector>
0013
0014 #include "onnx/version_converter/adapters/type_restriction.h"
0015
0016 namespace ONNX_NAMESPACE {
0017 namespace version_conversion {
0018
0019 static const std::vector<TensorProto_DataType> q_dq_20_unallowed_types = {
0020 TensorProto_DataType_UINT16,
0021 TensorProto_DataType_INT16,
0022 TensorProto_DataType_UINT4,
0023 TensorProto_DataType_INT4};
0024
0025 class QuantizeLinear_21_20 final : public TypeRestriction {
0026 public:
0027 explicit QuantizeLinear_21_20()
0028 : TypeRestriction("QuantizeLinear", OpSetID(21), OpSetID(20), q_dq_20_unallowed_types) {}
0029
0030 void adapt_quantize_linear_21_20(std::shared_ptr<Graph>, Node* node) const {
0031 if (node->hasAttribute(kblock_size)) {
0032 if ((node->i(kblock_size) != 0)) {
0033 ONNX_ASSERTM(false, "Blocked quantization is not supported for Opset Version %d.", target_version().version())
0034 }
0035 node->removeAttribute(kblock_size);
0036 }
0037 if (node->hasAttribute(koutput_dtype)) {
0038 if (node->i(koutput_dtype) != TensorProto_DataType_UINT8 && node->inputs().size() < 3) {
0039 ONNX_ASSERTM(
0040 false,
0041 "Attribute output_dtype is not supported for Opset Version %d, supply a zero-point tensor instead",
0042 target_version().version())
0043 }
0044 node->removeAttribute(koutput_dtype);
0045 }
0046 }
0047
0048 Node* adapt(std::shared_ptr<Graph> graph, Node* node) const override {
0049 adapt_type_restriction(graph, node);
0050 adapt_quantize_linear_21_20(graph, node);
0051 return node;
0052 }
0053 };
0054
0055 class DequantizeLinear_21_20 final : public TypeRestriction {
0056 public:
0057 explicit DequantizeLinear_21_20()
0058 : TypeRestriction("DequantizeLinear", OpSetID(21), OpSetID(20), q_dq_20_unallowed_types) {}
0059
0060 void adapt_dequantize_linear_21_20(std::shared_ptr<Graph>, Node* node) const {
0061 if (node->hasAttribute(kblock_size)) {
0062 if ((node->i(kblock_size) != 0)) {
0063 ONNX_ASSERTM(false, "Blocked quantization is not supported for Opset Version %d.", target_version().version())
0064 }
0065 node->removeAttribute(kblock_size);
0066 }
0067 }
0068
0069 Node* adapt(std::shared_ptr<Graph> graph, Node* node) const override {
0070 adapt_type_restriction(graph, node);
0071 adapt_dequantize_linear_21_20(graph, node);
0072 return node;
0073 }
0074 };
0075
0076 }
0077 }