Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-04-03 08:57:54

0001 // Copyright (c) ONNX Project Contributors
0002 
0003 /*
0004  * SPDX-License-Identifier: Apache-2.0
0005  */
0006 
0007 // Adapter for broadcasting ops in default domain from version 7 to 6
0008 
0009 #pragma once
0010 
0011 #include <memory>
0012 #include <string>
0013 #include <vector>
0014 
0015 #include "onnx/version_converter/adapters/adapter.h"
0016 
0017 namespace ONNX_NAMESPACE {
0018 namespace version_conversion {
0019 
0020 class BroadcastBackwardCompatibility final : public Adapter {
0021  public:
0022   explicit BroadcastBackwardCompatibility(const std::string& op_name, const OpSetID& initial, const OpSetID& target)
0023       : Adapter(op_name, initial, target) {}
0024 
0025   void adapt_broadcast_backward_compatibility(std::shared_ptr<Graph>, Node* node) const {
0026     // Verify that broadcasts are allowed in limited spec of opset version 6
0027     // Multidirectional broadcasting, as defined in Broadcasting.md
0028     // MathDocGenerator provides differences
0029     // Main change: encode broadcasting commands as explicit attribute
0030     const ArrayRef<Value*>& inputs = node->inputs();
0031     assertInputsAvailable(inputs, name().c_str(), 2);
0032     const std::vector<Dimension>& A_sizes = inputs[0]->sizes();
0033     const std::vector<Dimension>& B_sizes = inputs[1]->sizes();
0034     // Ensure that first input is larger than or equal to the second
0035     // numpy_unibroadcastable here is considered to be equivalent to opset1_broadcastable
0036     // This is because backwards conversion does not allow for an axis that is not
0037     // suffix matching
0038     int req_broadcast = check_numpy_unibroadcastable_and_require_broadcast(A_sizes, B_sizes);
0039     ONNX_ASSERTM(
0040         req_broadcast != -1,
0041         "%s being converted from %d to %d does "
0042         "not have broadcastable inputs.",
0043         name().c_str(),
0044         initial_version().version(),
0045         target_version().version());
0046     if (req_broadcast == 1) {
0047       // If conditional is not fulfilled, we have a default broadcast
0048       // Add broadcast attribute
0049       node->i_(kbroadcast, 1);
0050     }
0051   }
0052 
0053   Node* adapt(std::shared_ptr<Graph> graph, Node* node) const override {
0054     adapt_broadcast_backward_compatibility(graph, node);
0055     return node;
0056   }
0057 };
0058 
0059 } // namespace version_conversion
0060 } // namespace ONNX_NAMESPACE