File indexing completed on 2025-04-03 08:57:55
0001
0002
0003
0004
0005
0006
0007
0008
0009 #pragma once
0010
#include <cstddef>
#include <memory>
#include <vector>

#include "onnx/version_converter/adapters/adapter.h"
0015
0016 namespace ONNX_NAMESPACE {
0017 namespace version_conversion {
0018
0019 class Sum_8_7 final : public Adapter {
0020 public:
0021 explicit Sum_8_7() : Adapter("Sum", OpSetID(8), OpSetID(7)) {}
0022
0023 void adapt_sum_8_7(std::shared_ptr<Graph>, Node* node) const {
0024
0025 const ArrayRef<Value*>& inputs = node->inputs();
0026
0027 for (int i = 1; i < (int)inputs.size(); i++) {
0028 std::vector<Dimension> A_sizes = inputs[i - 1]->sizes();
0029 std::vector<Dimension> B_sizes = inputs[i]->sizes();
0030 assert_numpy_multibroadcastable(A_sizes, B_sizes);
0031 }
0032 }
0033
0034 Node* adapt(std::shared_ptr<Graph> graph, Node* node) const override {
0035 adapt_sum_8_7(graph, node);
0036 return node;
0037 }
0038 };
0039
0040 }
0041 }