File indexing completed on 2025-04-03 08:57:54
0001
0002
0003
0004
0005
0006
0007
0008
0009 #pragma once
0010
0011 #include <memory>
0012 #include <vector>
0013
0014 #include "onnx/version_converter/adapters/adapter.h"
0015
0016 namespace ONNX_NAMESPACE {
0017 namespace version_conversion {
0018
0019 class Gemm_7_6 final : public Adapter {
0020 public:
0021 explicit Gemm_7_6() : Adapter("Gemm", OpSetID(7), OpSetID(6)) {}
0022
0023 void adapt_gemm_7_6(std::shared_ptr<Graph>, Node* node) const {
0024 const ArrayRef<Value*>& inputs = node->inputs();
0025 assertInputsAvailable(inputs, name().c_str(), 3);
0026 const auto& A_shape = inputs[0]->sizes();
0027 const auto& B_shape = inputs[1]->sizes();
0028
0029 const auto& C_shape = inputs[2]->sizes();
0030
0031
0032 std::vector<Dimension> MN;
0033 if (node->hasAttribute(ktransA) && node->i(ktransA) == 1) {
0034 MN.emplace_back(A_shape[1]);
0035 } else {
0036 MN.emplace_back(A_shape[0]);
0037 }
0038 if (node->hasAttribute(ktransB) && node->i(ktransB) == 1) {
0039 MN.emplace_back(B_shape[0]);
0040 } else {
0041 MN.emplace_back(B_shape[1]);
0042 }
0043 int req_broadcast = check_numpy_unibroadcastable_and_require_broadcast(MN, C_shape);
0044 ONNX_ASSERTM(
0045 req_broadcast != -1,
0046 "%s being converted from %d to %d does "
0047 "not have broadcastable inputs.",
0048 name().c_str(),
0049 initial_version().version(),
0050 target_version().version());
0051 if (req_broadcast == 1) {
0052 node->i_(kbroadcast, 1);
0053 }
0054 }
0055
0056 Node* adapt(std::shared_ptr<Graph> graph, Node* node) const override {
0057 adapt_gemm_7_6(graph, node);
0058 return node;
0059 }
0060 };
0061
0062 }
0063 }