Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-04-03 08:57:54

0001 // Copyright (c) ONNX Project Contributors
0002 
0003 /*
0004  * SPDX-License-Identifier: Apache-2.0
0005  */
0006 
0007 // Adapter for Gemm in default domain from version 7 to 6
0008 
0009 #pragma once
0010 
0011 #include <memory>
0012 #include <vector>
0013 
0014 #include "onnx/version_converter/adapters/adapter.h"
0015 
0016 namespace ONNX_NAMESPACE {
0017 namespace version_conversion {
0018 
0019 class Gemm_7_6 final : public Adapter {
0020  public:
0021   explicit Gemm_7_6() : Adapter("Gemm", OpSetID(7), OpSetID(6)) {}
0022 
0023   void adapt_gemm_7_6(std::shared_ptr<Graph>, Node* node) const {
0024     const ArrayRef<Value*>& inputs = node->inputs();
0025     assertInputsAvailable(inputs, name().c_str(), 3);
0026     const auto& A_shape = inputs[0]->sizes();
0027     const auto& B_shape = inputs[1]->sizes();
0028     // Determine if C is broadcastable
0029     const auto& C_shape = inputs[2]->sizes();
0030     // Create (M, N) to input to numpy_unibroadcastable
0031     // TODO: Reconcile fact that shapes aren't determined for 1st 2 inputs
0032     std::vector<Dimension> MN;
0033     if (node->hasAttribute(ktransA) && node->i(ktransA) == 1) {
0034       MN.emplace_back(A_shape[1]);
0035     } else {
0036       MN.emplace_back(A_shape[0]);
0037     }
0038     if (node->hasAttribute(ktransB) && node->i(ktransB) == 1) {
0039       MN.emplace_back(B_shape[0]);
0040     } else {
0041       MN.emplace_back(B_shape[1]);
0042     }
0043     int req_broadcast = check_numpy_unibroadcastable_and_require_broadcast(MN, C_shape);
0044     ONNX_ASSERTM(
0045         req_broadcast != -1,
0046         "%s being converted from %d to %d does "
0047         "not have broadcastable inputs.",
0048         name().c_str(),
0049         initial_version().version(),
0050         target_version().version());
0051     if (req_broadcast == 1) {
0052       node->i_(kbroadcast, 1);
0053     }
0054   }
0055 
0056   Node* adapt(std::shared_ptr<Graph> graph, Node* node) const override {
0057     adapt_gemm_7_6(graph, node);
0058     return node;
0059   }
0060 };
0061 
0062 } // namespace version_conversion
0063 } // namespace ONNX_NAMESPACE