Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-04-03 08:57:55

0001 // Copyright (c) ONNX Project Contributors
0002 
0003 /*
0004  * SPDX-License-Identifier: Apache-2.0
0005  */
0006 
0007 // Adapter for Softmax and LogSoftmax in default domain from version 12 to 13
0008 
0009 #pragma once
0010 
#include <cstdint>
#include <memory>
#include <string>

#include "onnx/version_converter/adapters/adapter.h"
0015 
0016 namespace ONNX_NAMESPACE {
0017 namespace version_conversion {
0018 
0019 class Softmax_12_13 final : public Adapter {
0020  public:
0021   explicit Softmax_12_13(const std::string& op_name) : Adapter(op_name, OpSetID(12), OpSetID(13)) {}
0022 
0023   void adapt_softmax_12_13(std::shared_ptr<Graph> graph, Node* node) const {
0024     int old_axis = node->hasAttribute(kaxis) ? node->i(kaxis) : 1;
0025     int input_rank = node->inputs()[0]->sizes().size();
0026 
0027     if (old_axis < 0)
0028       old_axis = input_rank + old_axis;
0029 
0030     if (old_axis == input_rank - 1)
0031       node->i_(kaxis, -1);
0032     else {
0033       //    -- shape ------------------
0034       //   /                           |
0035       // ----- flatten -- softmax -- reshape
0036 
0037       // get original softmax's input shape
0038       Symbol kShape("Shape");
0039       Node* shape = graph->create(kShape);
0040       shape->addInput(node->inputs()[0]);
0041       shape->insertBefore(node);
0042 
0043       // Insert Flatten node before softmax
0044       Node* flatten = graph->create(kFlatten);
0045       flatten->addInput(node->inputs()[0]);
0046       flatten->insertBefore(node);
0047       flatten->i_(kaxis, old_axis);
0048       node->replaceInput(0, flatten->output());
0049 
0050       // Softmax along the last axis of the flattened 2D tensor
0051       node->i_(kaxis, -1);
0052 
0053       // Insert Reshape node after softmax
0054       const std::string original_output_name = node->output()->uniqueName();
0055       const use_list original_uses(node->output()->uses());
0056       node->output()->setUniqueName(original_output_name + "_intermediate");
0057       Node* reshape = graph->create(kReshape);
0058       reshape->addInput(node->outputs()[0]);
0059       reshape->addInput(shape->output());
0060       reshape->output()->setUniqueName(original_output_name);
0061       reshape->insertAfter(node);
0062 
0063       // Fix outputs & wiring
0064       if (node->output()->sizes().size() != 0) {
0065         reshape->output()->setSizes(node->output()->sizes());
0066       }
0067       reshape->output()->setElemType(node->output()->elemType());
0068       node->output()->wipeSizes();
0069       for (Use u : original_uses) {
0070         u.user->replaceInputWith(node->output(), reshape->output());
0071       }
0072       for (size_t i = 0; i < graph->outputs().size(); i++) {
0073         if (graph->outputs()[i]->uniqueName() == original_output_name) {
0074           graph->return_node()->replaceInput(i, reshape->output());
0075         }
0076       }
0077     }
0078   }
0079 
0080   Node* adapt(std::shared_ptr<Graph> graph, Node* node) const override {
0081     adapt_softmax_12_13(graph, node);
0082     return node;
0083   }
0084 };
0085 
0086 } // namespace version_conversion
0087 } // namespace ONNX_NAMESPACE