File indexing completed on 2025-04-03 08:57:54
0001
0002
0003
0004
0005
0006
0007
0008
0009 #pragma once
0010
0011 #include <memory>
0012 #include <string>
0013 #include <utility>
0014 #include <vector>
0015
0016 #include "onnx/version_converter/adapters/adapter.h"
0017
0018 namespace ONNX_NAMESPACE {
0019 namespace version_conversion {
0020
0021 class BroadcastForwardCompatibility final : public Adapter {
0022 public:
0023 explicit BroadcastForwardCompatibility(const std::string& op_name, const OpSetID& initial, const OpSetID& target)
0024 : Adapter(op_name, initial, target) {}
0025
0026 void adapt_broadcast_forward_compatibility(std::shared_ptr<Graph> graph, Node* node) const {
0027
0028
0029 if (node->hasAttribute(kbroadcast)) {
0030 const ArrayRef<Value*>& inputs = node->inputs();
0031 assertInputsAvailable(inputs, name().c_str(), 2);
0032 const std::vector<Dimension>& A_sizes = inputs[0]->sizes();
0033 const std::vector<Dimension>& B_sizes = inputs[1]->sizes();
0034
0035 if (node->hasAttribute(kaxis)) {
0036 if (node->i(kaxis) != (int)(A_sizes.size() - B_sizes.size())) {
0037
0038 Node* n = graph->create(kUnsqueeze);
0039 n->addInput(inputs[1]);
0040 std::vector<int64_t> axes;
0041 std::vector<Dimension> new_sizes = B_sizes;
0042 auto size = A_sizes.size() > B_sizes.size() ? A_sizes.size() - B_sizes.size() : 0;
0043 axes.reserve(size);
0044 new_sizes.reserve(new_sizes.size() + size);
0045 for (size_t i = 0; i < size; i++) {
0046 axes.emplace_back(B_sizes.size() + i);
0047 new_sizes.emplace_back(Dimension(1));
0048 }
0049 if (target_version().version() >= 13) {
0050 Tensor t;
0051 t.elem_type() = TensorProto_DataType_INT64;
0052 t.sizes() = std::vector<int64_t>{static_cast<int64_t>(axes.size())};
0053 auto& data = t.int64s();
0054 for (auto a : axes) {
0055 data.emplace_back(a);
0056 }
0057 Node* constant = graph->create(kConstant);
0058 constant->insertBefore(node);
0059 constant->t_(kvalue, t);
0060 node->addInput(constant->output());
0061 } else {
0062 n->is_(kaxes, std::forward<const std::vector<int64_t>>(axes));
0063 }
0064
0065 n->insertBefore(node);
0066
0067 n->output()->setSizes(new_sizes);
0068 node->replaceInput(1, n->output());
0069 }
0070 }
0071 node->removeAttribute(kbroadcast);
0072 }
0073 if (node->hasAttribute(kaxis))
0074 node->removeAttribute(kaxis);
0075
0076 const ArrayRef<Value*>& inputs = node->inputs();
0077 assert_numpy_multibroadcastable(inputs[0]->sizes(), inputs[1]->sizes());
0078 }
0079
0080 Node* adapt(std::shared_ptr<Graph> graph, Node* node) const override {
0081 adapt_broadcast_forward_compatibility(graph, node);
0082 return node;
0083 }
0084 };
0085
0086 }
0087 }