// Copyright (c) ONNX Project Contributors

/*
 * SPDX-License-Identifier: Apache-2.0
 */

// Adapter for GroupNormalization in default domain from version 20 to 21
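//
// Opset 21 expects the scale and bias inputs to have C elements (one per
// channel) instead of num_groups elements, and it introduces the stash_type
// attribute. This adapter therefore inserts ops that repeat each per-group
// scale/bias value across the channels of its group, and pins stash_type to
// the input's element type.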

#pragma once

#include <memory>

#include "onnx/version_converter/adapters/adapter.h"

namespace ONNX_NAMESPACE {
namespace version_conversion {

class GroupNormalization_20_21 final : public Adapter {
 public:
  explicit GroupNormalization_20_21() : Adapter("GroupNormalization", OpSetID(20), OpSetID(21)) {}

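  // Rewires input `input_id` of `node` (scale or bias) through a
  // Reshape([-1, 1]) -> Expand([1, C / num_groups]) -> Reshape([-1]) chain so
  // that each per-group element is repeated once per channel in its group.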
  void transform_input(
      std::shared_ptr<Graph> graph,
      Node* node,
      int64_t input_id,
      Value* reshape0_shape,
      Value* reshape1_shape,
      Value* expand_shape) const {
    Node* reshape0 = graph->create(kReshape);
    reshape0->addInput(node->inputs()[input_id]);
    reshape0->addInput(reshape0_shape);
    reshape0->insertBefore(node);

    Node* expand = graph->create(kExpand);
    expand->addInput(reshape0->output());
    expand->addInput(expand_shape);
    expand->insertBefore(node);

    Node* reshape1 = graph->create(kReshape);
    reshape1->addInput(expand->output());
    reshape1->addInput(reshape1_shape);
    reshape1->insertBefore(node);

    node->replaceInput(input_id, reshape1->output());
  }

  void adapt_group_normalization_20_21(std::shared_ptr<Graph> graph, Node* node) const {
    // Perform the following sequence of ops on scale/bias; the effect is similar to numpy.repeat()
    //
    //   Shape<start=1,end=2>(input0) -- Div(Shape_out (C), num_groups)
    //                                                           |
    // Reshape(input1/2, [-1, 1]) ----------- Expand(Reshape_out, [1, Div_out]) -- Reshape(Expand_out, [-1])
    //
    // The helper function transform_input() implements the bottom row of the diagram.
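    //
    // For example, with C = 6 and num_groups = 2, a scale of shape [2] is
    // reshaped to [2, 1], expanded to [2, 3], and flattened back to [6]:
    // [s0, s1] becomes [s0, s0, s0, s1, s1, s1].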

    // Get number of channels: C
    Symbol kShape("Shape");
    Node* C = graph->create(kShape);
    C->i_(kstart, 1);
    C->i_(kend, 2);
    C->addInput(node->inputs()[0]);
    C->insertBefore(node);

    // Get number of channels per group
    Tensor tensor_num_groups;
    tensor_num_groups.elem_type() = TensorProto_DataType_INT64;
    int64_t num_groups = node->i(knum_groups);
    tensor_num_groups.sizes() = {1};
    tensor_num_groups.int64s() = {num_groups};
    Node* constant_num_groups = graph->create(kConstant);
    constant_num_groups->t_(kvalue, tensor_num_groups);
    constant_num_groups->insertBefore(node);

    Node* div = graph->create(kDiv);
    div->addInput(C->output());
    div->addInput(constant_num_groups->output());
    div->insertBefore(node);

    // Get Expand shape: [1, Div_out]
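    // (assembled at runtime with Concat because C comes from a Shape op and is
    // generally not known statically)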
    Tensor tensor_one;
    tensor_one.elem_type() = TensorProto_DataType_INT64;
    tensor_one.sizes() = {1};
    tensor_one.int64s() = {1};
    Node* constant_one = graph->create(kConstant);
    constant_one->t_(kvalue, tensor_one);
    constant_one->insertBefore(node);
    Node* concat = graph->create(kConcat);
    concat->i_(kaxis, 0);
    concat->addInput(constant_one->output());
    concat->addInput(div->output());
    concat->insertBefore(node);

    // Get shape of first reshape: [-1, 1]
    Tensor tensor_reshape0_shape;
    tensor_reshape0_shape.elem_type() = TensorProto_DataType_INT64;
    tensor_reshape0_shape.sizes() = {2};
    tensor_reshape0_shape.int64s() = {-1, 1};
    Node* constant_reshape0_shape = graph->create(kConstant);
    constant_reshape0_shape->t_(kvalue, tensor_reshape0_shape);
    constant_reshape0_shape->insertBefore(node);

    // Get shape of last reshape: [-1]
    Tensor tensor_reshape1_shape;
    tensor_reshape1_shape.elem_type() = TensorProto_DataType_INT64;
    tensor_reshape1_shape.sizes() = {1};
    tensor_reshape1_shape.int64s() = {-1};
    Node* constant_reshape1_shape = graph->create(kConstant);
    constant_reshape1_shape->t_(kvalue, tensor_reshape1_shape);
    constant_reshape1_shape->insertBefore(node);

    // Transform scale (input 1) and bias (input 2) using the shared shape constants built above
    transform_input(
        graph, node, 1, constant_reshape0_shape->output(), constant_reshape1_shape->output(), concat->output());
    transform_input(
        graph, node, 2, constant_reshape0_shape->output(), constant_reshape1_shape->output(), concat->output());

    // Set stash_type
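    // Opset 21 adds the stash_type attribute (defaulting to float); setting it
    // to the input's element type keeps the intermediate computation in the
    // input's precision, preserving the pre-21 behavior.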
    node->i_(kstash_type, node->inputs()[0]->elemType());
  }

  Node* adapt(std::shared_ptr<Graph> graph, Node* node) const override {
    adapt_group_normalization_20_21(graph, node);
    return node;
  }
};

} // namespace version_conversion
} // namespace ONNX_NAMESPACE
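
// Usage sketch, assuming the C++ converter entry point ConvertVersion()
// declared in onnx/version_converter/convert.h: converting a model to opset 21
// routes each GroupNormalization node through this adapter's adapt().
//
//   #include "onnx/version_converter/convert.h"
//
//   ONNX_NAMESPACE::ModelProto upgraded =
//       ONNX_NAMESPACE::version_conversion::ConvertVersion(model, /*target_version=*/21);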