Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-04-03 08:57:55

0001 // Copyright (c) ONNX Project Contributors
0002 
0003 /*
0004  * SPDX-License-Identifier: Apache-2.0
0005  */
0006 
0007 // Adapter for Sum in default domain from version 8 to 7
0008 
0009 #pragma once
0010 
#include <memory>
#include <stdexcept>
#include <string>
#include <vector>
0013 
0014 #include "onnx/version_converter/adapters/adapter.h"
0015 
0016 namespace ONNX_NAMESPACE {
0017 namespace version_conversion {
0018 
0019 class Sum_8_7 final : public Adapter {
0020  public:
0021   explicit Sum_8_7() : Adapter("Sum", OpSetID(8), OpSetID(7)) {}
0022 
0023   void adapt_sum_8_7(std::shared_ptr<Graph>, Node* node) const {
0024     // Throw an exception if any broadcasting occurs
0025     const ArrayRef<Value*>& inputs = node->inputs();
0026     // Determine if inputs are of different sizes
0027     for (int i = 1; i < (int)inputs.size(); i++) {
0028       std::vector<Dimension> A_sizes = inputs[i - 1]->sizes();
0029       std::vector<Dimension> B_sizes = inputs[i]->sizes();
0030       assert_numpy_multibroadcastable(A_sizes, B_sizes);
0031     }
0032   }
0033 
0034   Node* adapt(std::shared_ptr<Graph> graph, Node* node) const override {
0035     adapt_sum_8_7(graph, node);
0036     return node;
0037   }
0038 };
0039 
0040 } // namespace version_conversion
0041 } // namespace ONNX_NAMESPACE