File indexing completed on 2025-04-03 08:57:55
0001
0002
0003
0004
0005
0006
0007
0008
0009 #pragma once
0010
#include <cstdint>
#include <memory>
#include <string>
#include <utility>

#include "onnx/version_converter/adapters/adapter.h"
0015
0016 namespace ONNX_NAMESPACE {
0017 namespace version_conversion {
0018
0019 class Softmax_12_13 final : public Adapter {
0020 public:
0021 explicit Softmax_12_13(const std::string& op_name) : Adapter(op_name, OpSetID(12), OpSetID(13)) {}
0022
0023 void adapt_softmax_12_13(std::shared_ptr<Graph> graph, Node* node) const {
0024 int old_axis = node->hasAttribute(kaxis) ? node->i(kaxis) : 1;
0025 int input_rank = node->inputs()[0]->sizes().size();
0026
0027 if (old_axis < 0)
0028 old_axis = input_rank + old_axis;
0029
0030 if (old_axis == input_rank - 1)
0031 node->i_(kaxis, -1);
0032 else {
0033
0034
0035
0036
0037
0038 Symbol kShape("Shape");
0039 Node* shape = graph->create(kShape);
0040 shape->addInput(node->inputs()[0]);
0041 shape->insertBefore(node);
0042
0043
0044 Node* flatten = graph->create(kFlatten);
0045 flatten->addInput(node->inputs()[0]);
0046 flatten->insertBefore(node);
0047 flatten->i_(kaxis, old_axis);
0048 node->replaceInput(0, flatten->output());
0049
0050
0051 node->i_(kaxis, -1);
0052
0053
0054 const std::string original_output_name = node->output()->uniqueName();
0055 const use_list original_uses(node->output()->uses());
0056 node->output()->setUniqueName(original_output_name + "_intermediate");
0057 Node* reshape = graph->create(kReshape);
0058 reshape->addInput(node->outputs()[0]);
0059 reshape->addInput(shape->output());
0060 reshape->output()->setUniqueName(original_output_name);
0061 reshape->insertAfter(node);
0062
0063
0064 if (node->output()->sizes().size() != 0) {
0065 reshape->output()->setSizes(node->output()->sizes());
0066 }
0067 reshape->output()->setElemType(node->output()->elemType());
0068 node->output()->wipeSizes();
0069 for (Use u : original_uses) {
0070 u.user->replaceInputWith(node->output(), reshape->output());
0071 }
0072 for (size_t i = 0; i < graph->outputs().size(); i++) {
0073 if (graph->outputs()[i]->uniqueName() == original_output_name) {
0074 graph->return_node()->replaceInput(i, reshape->output());
0075 }
0076 }
0077 }
0078 }
0079
0080 Node* adapt(std::shared_ptr<Graph> graph, Node* node) const override {
0081 adapt_softmax_12_13(graph, node);
0082 return node;
0083 }
0084 };
0085
0086 }
0087 }