Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-08-28 08:58:49

0001 // Copyright (c) ONNX Project Contributors
0002 
0003 /*
0004  * SPDX-License-Identifier: Apache-2.0
0005  */
0006 
0007 // Adapters for QuantizeLinear and DequantizeLinear in default domain from version 21 to 20
0008 
0009 #pragma once
0010 
0011 #include <memory>
0012 #include <vector>
0013 
0014 #include "onnx/version_converter/adapters/type_restriction.h"
0015 
0016 namespace ONNX_NAMESPACE {
0017 namespace version_conversion {
0018 
0019 static const std::vector<TensorProto_DataType> q_dq_20_unallowed_types = {
0020     TensorProto_DataType_UINT16,
0021     TensorProto_DataType_INT16,
0022     TensorProto_DataType_UINT4,
0023     TensorProto_DataType_INT4};
0024 
0025 class QuantizeLinear_21_20 final : public TypeRestriction {
0026  public:
0027   explicit QuantizeLinear_21_20()
0028       : TypeRestriction("QuantizeLinear", OpSetID(21), OpSetID(20), q_dq_20_unallowed_types) {}
0029 
0030   void adapt_quantize_linear_21_20(std::shared_ptr<Graph>, Node* node) const {
0031     if (node->hasAttribute(kblock_size)) {
0032       if ((node->i(kblock_size) != 0)) {
0033         ONNX_ASSERTM(false, "Blocked quantization is not supported for Opset Version %d.", target_version().version())
0034       }
0035       node->removeAttribute(kblock_size);
0036     }
0037     if (node->hasAttribute(koutput_dtype)) {
0038       if (node->i(koutput_dtype) != TensorProto_DataType_UINT8 && node->inputs().size() < 3) {
0039         ONNX_ASSERTM(
0040             false,
0041             "Attribute output_dtype is not supported for Opset Version %d, supply a zero-point tensor instead",
0042             target_version().version())
0043       }
0044       node->removeAttribute(koutput_dtype);
0045     }
0046   }
0047 
0048   Node* adapt(std::shared_ptr<Graph> graph, Node* node) const override {
0049     adapt_type_restriction(graph, node);
0050     adapt_quantize_linear_21_20(graph, node);
0051     return node;
0052   }
0053 };
0054 
0055 class DequantizeLinear_21_20 final : public TypeRestriction {
0056  public:
0057   explicit DequantizeLinear_21_20()
0058       : TypeRestriction("DequantizeLinear", OpSetID(21), OpSetID(20), q_dq_20_unallowed_types) {}
0059 
0060   void adapt_dequantize_linear_21_20(std::shared_ptr<Graph>, Node* node) const {
0061     if (node->hasAttribute(kblock_size)) {
0062       if ((node->i(kblock_size) != 0)) {
0063         ONNX_ASSERTM(false, "Blocked quantization is not supported for Opset Version %d.", target_version().version())
0064       }
0065       node->removeAttribute(kblock_size);
0066     }
0067   }
0068 
0069   Node* adapt(std::shared_ptr<Graph> graph, Node* node) const override {
0070     adapt_type_restriction(graph, node);
0071     adapt_dequantize_linear_21_20(graph, node);
0072     return node;
0073   }
0074 };
0075 
0076 } // namespace version_conversion
0077 } // namespace ONNX_NAMESPACE