File indexing completed on 2025-04-03 08:57:54
0001
0002
0003
0004
0005
0006
0007
0008
0009 #pragma once
0010
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "onnx/version_converter/adapters/adapter.h"
0016
0017 namespace ONNX_NAMESPACE {
0018 namespace version_conversion {
0019
0020 class BroadcastBackwardCompatibility final : public Adapter {
0021 public:
0022 explicit BroadcastBackwardCompatibility(const std::string& op_name, const OpSetID& initial, const OpSetID& target)
0023 : Adapter(op_name, initial, target) {}
0024
0025 void adapt_broadcast_backward_compatibility(std::shared_ptr<Graph>, Node* node) const {
0026
0027
0028
0029
0030 const ArrayRef<Value*>& inputs = node->inputs();
0031 assertInputsAvailable(inputs, name().c_str(), 2);
0032 const std::vector<Dimension>& A_sizes = inputs[0]->sizes();
0033 const std::vector<Dimension>& B_sizes = inputs[1]->sizes();
0034
0035
0036
0037
0038 int req_broadcast = check_numpy_unibroadcastable_and_require_broadcast(A_sizes, B_sizes);
0039 ONNX_ASSERTM(
0040 req_broadcast != -1,
0041 "%s being converted from %d to %d does "
0042 "not have broadcastable inputs.",
0043 name().c_str(),
0044 initial_version().version(),
0045 target_version().version());
0046 if (req_broadcast == 1) {
0047
0048
0049 node->i_(kbroadcast, 1);
0050 }
0051 }
0052
0053 Node* adapt(std::shared_ptr<Graph> graph, Node* node) const override {
0054 adapt_broadcast_backward_compatibility(graph, node);
0055 return node;
0056 }
0057 };
0058
0059 }
0060 }