Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-04-03 08:57:54

0001 // Copyright (c) ONNX Project Contributors
0002 
0003 /*
0004  * SPDX-License-Identifier: Apache-2.0
0005  */
0006 
0007 // Adapter for broadcasting ops in default domain from version 6 to 7
0008 
0009 #pragma once
0010 
0011 #include <memory>
0012 #include <string>
0013 #include <utility>
0014 #include <vector>
0015 
0016 #include "onnx/version_converter/adapters/adapter.h"
0017 
0018 namespace ONNX_NAMESPACE {
0019 namespace version_conversion {
0020 
0021 class BroadcastForwardCompatibility final : public Adapter {
0022  public:
0023   explicit BroadcastForwardCompatibility(const std::string& op_name, const OpSetID& initial, const OpSetID& target)
0024       : Adapter(op_name, initial, target) {}
0025 
0026   void adapt_broadcast_forward_compatibility(std::shared_ptr<Graph> graph, Node* node) const {
0027     // Remove axis and broadcast attributes
0028     // Assess whether axis requires reshaping
0029     if (node->hasAttribute(kbroadcast)) {
0030       const ArrayRef<Value*>& inputs = node->inputs();
0031       assertInputsAvailable(inputs, name().c_str(), 2);
0032       const std::vector<Dimension>& A_sizes = inputs[0]->sizes();
0033       const std::vector<Dimension>& B_sizes = inputs[1]->sizes();
0034       // Also assert that broadcasting syntax are correct if axis is not present
0035       if (node->hasAttribute(kaxis)) {
0036         if (node->i(kaxis) != (int)(A_sizes.size() - B_sizes.size())) {
0037           // Add a Reshape node before input B
0038           Node* n = graph->create(kUnsqueeze);
0039           n->addInput(inputs[1]);
0040           std::vector<int64_t> axes;
0041           std::vector<Dimension> new_sizes = B_sizes;
0042           auto size = A_sizes.size() > B_sizes.size() ? A_sizes.size() - B_sizes.size() : 0;
0043           axes.reserve(size);
0044           new_sizes.reserve(new_sizes.size() + size);
0045           for (size_t i = 0; i < size; i++) {
0046             axes.emplace_back(B_sizes.size() + i);
0047             new_sizes.emplace_back(Dimension(1));
0048           }
0049           if (target_version().version() >= 13) { // Unsqueeze takes 'axes' input
0050             Tensor t;
0051             t.elem_type() = TensorProto_DataType_INT64;
0052             t.sizes() = std::vector<int64_t>{static_cast<int64_t>(axes.size())};
0053             auto& data = t.int64s();
0054             for (auto a : axes) {
0055               data.emplace_back(a);
0056             }
0057             Node* constant = graph->create(kConstant);
0058             constant->insertBefore(node);
0059             constant->t_(kvalue, t);
0060             node->addInput(constant->output());
0061           } else { // Unsqueeze takes 'axes' attribute
0062             n->is_(kaxes, std::forward<const std::vector<int64_t>>(axes));
0063           }
0064           // Move n before node
0065           n->insertBefore(node);
0066           // Set 2nd input to node to 1st of n and output of n to 2nd input to node
0067           n->output()->setSizes(new_sizes);
0068           node->replaceInput(1, n->output());
0069         }
0070       }
0071       node->removeAttribute(kbroadcast);
0072     }
0073     if (node->hasAttribute(kaxis))
0074       node->removeAttribute(kaxis);
0075     // Assert multi_broadcastable on inputs
0076     const ArrayRef<Value*>& inputs = node->inputs();
0077     assert_numpy_multibroadcastable(inputs[0]->sizes(), inputs[1]->sizes());
0078   }
0079 
0080   Node* adapt(std::shared_ptr<Graph> graph, Node* node) const override {
0081     adapt_broadcast_forward_compatibility(graph, node);
0082     return node;
0083   }
0084 };
0085 
0086 } // namespace version_conversion
0087 } // namespace ONNX_NAMESPACE