File indexing completed on 2025-08-28 08:58:49
0001
0002
0003
0004
0005
0006
0007
0008
0009 #pragma once
0010
0011 #include <memory>
0012
0013 #include "onnx/version_converter/adapters/adapter.h"
0014
0015 namespace ONNX_NAMESPACE {
0016 namespace version_conversion {
0017
0018 class GroupNormalization_20_21 final : public Adapter {
0019 public:
0020 explicit GroupNormalization_20_21() : Adapter("GroupNormalization", OpSetID(20), OpSetID(21)) {}
0021
0022 void transform_input(
0023 std::shared_ptr<Graph> graph,
0024 Node* node,
0025 int64_t input_id,
0026 Value* reshape0_shape,
0027 Value* reshape1_shape,
0028 Value* expand_shape) const {
0029 Node* reshape0 = graph->create(kReshape);
0030 reshape0->addInput(node->inputs()[input_id]);
0031 reshape0->addInput(reshape0_shape);
0032 reshape0->insertBefore(node);
0033
0034 Node* expand = graph->create(kExpand);
0035 expand->addInput(reshape0->output());
0036 expand->addInput(expand_shape);
0037 expand->insertBefore(node);
0038
0039 Node* reshape1 = graph->create(kReshape);
0040 reshape1->addInput(expand->output());
0041 reshape1->addInput(reshape1_shape);
0042 reshape1->insertBefore(node);
0043
0044 node->replaceInput(input_id, reshape1->output());
0045 }
0046
0047 void adapt_group_normalization_20_21(std::shared_ptr<Graph> graph, Node* node) const {
0048
0049
0050
0051
0052
0053
0054
0055
0056
0057 Symbol kShape("Shape");
0058 Node* C = graph->create(kShape);
0059 C->i_(kstart, 1);
0060 C->i_(kend, 2);
0061 C->addInput(node->inputs()[0]);
0062 C->insertBefore(node);
0063
0064
0065 Tensor tensor_num_groups;
0066 tensor_num_groups.elem_type() = TensorProto_DataType_INT64;
0067 int64_t num_groups = node->i(knum_groups);
0068 tensor_num_groups.sizes() = {1};
0069 tensor_num_groups.int64s() = {num_groups};
0070 Node* constant_num_groups = graph->create(kConstant);
0071 constant_num_groups->t_(kvalue, tensor_num_groups);
0072 constant_num_groups->insertBefore(node);
0073
0074 Node* div = graph->create(kDiv);
0075 div->addInput(C->output());
0076 div->addInput(constant_num_groups->output());
0077 div->insertBefore(node);
0078
0079
0080 Tensor tensor_one;
0081 tensor_one.elem_type() = TensorProto_DataType_INT64;
0082 tensor_one.sizes() = {1};
0083 tensor_one.int64s() = {1};
0084 Node* constant_one = graph->create(kConstant);
0085 constant_one->t_(kvalue, tensor_one);
0086 constant_one->insertBefore(node);
0087 Node* concat = graph->create(kConcat);
0088 concat->i_(kaxis, 0);
0089 concat->addInput(constant_one->output());
0090 concat->addInput(div->output());
0091 concat->insertBefore(node);
0092
0093
0094 Tensor tensor_reshape0_shape;
0095 tensor_reshape0_shape.elem_type() = TensorProto_DataType_INT64;
0096 tensor_reshape0_shape.sizes() = {2};
0097 tensor_reshape0_shape.int64s() = {-1, 1};
0098 Node* constant_reshape0_shape = graph->create(kConstant);
0099 constant_reshape0_shape->t_(kvalue, tensor_reshape0_shape);
0100 constant_reshape0_shape->insertBefore(node);
0101
0102
0103 Tensor tensor_reshape1_shape;
0104 tensor_reshape1_shape.elem_type() = TensorProto_DataType_INT64;
0105 tensor_reshape1_shape.sizes() = {1};
0106 tensor_reshape1_shape.int64s() = {-1};
0107 Node* constant_reshape1_shape = graph->create(kConstant);
0108 constant_reshape1_shape->t_(kvalue, tensor_reshape1_shape);
0109 constant_reshape1_shape->insertBefore(node);
0110
0111
0112 transform_input(
0113 graph, node, 1, constant_reshape0_shape->output(), constant_reshape1_shape->output(), concat->output());
0114 transform_input(
0115 graph, node, 2, constant_reshape0_shape->output(), constant_reshape1_shape->output(), concat->output());
0116
0117
0118 node->i_(kstash_type, node->inputs()[0]->elemType());
0119 }
0120
0121 Node* adapt(std::shared_ptr<Graph> graph, Node* node) const override {
0122 adapt_group_normalization_20_21(graph, node);
0123 return node;
0124 }
0125 };
0126
0127 }
0128 }