Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-08-27 09:38:45

0001 // Copyright (c) Microsoft Corporation. All rights reserved.
0002 // Licensed under the MIT License.
0003 
// Summary
// This header provides APIs that spare custom op authors the trouble of defining schemas:
// the schema is inferred from the function signature, as long as every argument has a type supported here.
// An input can be:
// 1. A tensor of an ONNX data type.
// 2. A span of an ONNX data type.
// 3. A scalar of an ONNX data type.
// An input can be made optional by declaring it as std::optional<...>.
// An output must be a tensor of an ONNX data type.
// The header also provides utilities for registering a simple custom struct, in which resources can be kept, as a custom op.
// For concrete examples, search for the keyword "LiteCustomOpTest" under "<cloned_src_dir>/onnxruntime/test/".
// Note - all APIs in this header are ABI.
0016 
0017 #pragma once
0018 #include "onnxruntime_cxx_api.h"
0019 #include <optional>
0020 #include <numeric>
0021 #include <functional>
0022 #include <unordered_set>
0023 
0024 namespace Ort {
0025 namespace Custom {
0026 
0027 class ArgBase {
0028  public:
0029   ArgBase(OrtKernelContext* ctx,
0030           size_t indice,
0031           bool is_input) : ctx_(ctx), indice_(indice), is_input_(is_input) {}
0032   virtual ~ArgBase() {};
0033 
0034  protected:
0035   struct KernelContext ctx_;
0036   size_t indice_;
0037   bool is_input_;
0038 };
0039 
// Owning pointer to a single custom-op argument object (any ArgBase subclass).
using ArgPtr = std::unique_ptr<Custom::ArgBase>;
// A collection of owned argument objects.
using ArgPtrs = std::vector<ArgPtr>;
0042 
0043 class TensorBase : public ArgBase {
0044  public:
0045   TensorBase(OrtKernelContext* ctx,
0046              size_t indice,
0047              bool is_input) : ArgBase(ctx, indice, is_input) {}
0048 
0049   operator bool() const {
0050     return shape_.has_value();
0051   }
0052 
0053   const std::vector<int64_t>& Shape() const {
0054     if (!shape_.has_value()) {
0055       ORT_CXX_API_THROW("tensor shape is not yet initialized", OrtErrorCode::ORT_RUNTIME_EXCEPTION);
0056     }
0057     return shape_.value();
0058   }
0059 
0060   ONNXTensorElementDataType Type() const {
0061     return type_;
0062   }
0063 
0064   int64_t NumberOfElement() const {
0065     if (shape_.has_value()) {
0066       return std::accumulate(shape_->begin(), shape_->end(), 1LL, std::multiplies<int64_t>());
0067     } else {
0068       return 0;
0069     }
0070   }
0071 
0072   std::string Shape2Str() const {
0073     if (shape_.has_value()) {
0074       std::string shape_str;
0075       for (const auto& dim : *shape_) {
0076         shape_str.append(std::to_string(dim));
0077         shape_str.append(", ");
0078       }
0079       return shape_str;
0080     } else {
0081       return "empty";
0082     }
0083   }
0084 
0085   bool IsCpuTensor() const {
0086     return strcmp("Cpu", mem_type_) == 0;
0087   }
0088 
0089   virtual const void* DataRaw() const = 0;
0090   virtual size_t SizeInBytes() const = 0;
0091 
0092  protected:
0093   std::optional<std::vector<int64_t>> shape_;
0094   ONNXTensorElementDataType type_ = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
0095   const char* mem_type_ = "Cpu";
0096 };
0097 
/// Minimal read-only, non-owning view over `size_` contiguous elements of T.
/// The caller is responsible for keeping the viewed buffer alive.
template <typename T>
struct Span {
  const T* data_{nullptr};
  size_t size_{0};

  /// Re-targets the view at `size` elements starting at `data`.
  void Assign(const T* data, size_t size) {
    size_ = size;
    data_ = data;
  }

  /// Number of viewed elements.
  size_t size() const { return size_; }

  /// Returns a copy of element `indice`; no bounds checking is performed.
  T operator[](size_t indice) const { return *(data_ + indice); }

  /// Start of the viewed buffer (nullptr for a default-constructed view).
  const T* data() const { return data_; }
};
0112 
0113 template <typename T>
0114 class Tensor : public TensorBase {
0115  public:
0116   using TT = typename std::remove_reference<T>::type;
0117   Tensor(OrtKernelContext* ctx, size_t indice, bool is_input) : TensorBase(ctx, indice, is_input) {
0118     if (is_input_) {
0119       if (indice >= ctx_.GetInputCount()) {
0120         ORT_CXX_API_THROW("invalid indice for Ort::Custom::Tensor", OrtErrorCode::ORT_INVALID_ARGUMENT);
0121       }
0122       const_value_ = ctx_.GetInput(indice);
0123       auto type_shape_info = const_value_.GetTensorTypeAndShapeInfo();
0124       shape_ = type_shape_info.GetShape();
0125     }
0126   }
0127   const TT* Data() const {
0128     return reinterpret_cast<const TT*>(const_value_.GetTensorRawData());
0129   }
0130   TT* Allocate(const std::vector<int64_t>& shape) {
0131     shape_ = shape;
0132     if (!data_) {
0133       shape_ = shape;
0134       data_ = ctx_.GetOutput(indice_, shape).template GetTensorMutableData<TT>();
0135     }
0136     return data_;
0137   }
0138   static TT GetT() { return (TT)0; }
0139   const Span<T>& AsSpan() {
0140     if (!shape_.has_value() || shape_->size() != 1) {
0141       ORT_CXX_API_THROW("invalid shape while trying to get a span out of Ort::Custom::Tensor",
0142                         OrtErrorCode::ORT_RUNTIME_EXCEPTION);
0143     }
0144     span_.Assign(Data(), static_cast<size_t>((*shape_)[0]));
0145     return span_;
0146   }
0147   const T& AsScalar() {
0148     if (!shape_.has_value() || shape_->size() != 1 || (*shape_)[0] != 1) {
0149       ORT_CXX_API_THROW("invalid shape while trying to get a scalar from Ort::Custom::Tensor",
0150                         OrtErrorCode::ORT_RUNTIME_EXCEPTION);
0151     }
0152     return *Data();
0153   }
0154   const void* DataRaw() const override {
0155     return reinterpret_cast<const void*>(Data());
0156   }
0157 
0158   size_t SizeInBytes() const override {
0159     return sizeof(TT) * static_cast<size_t>(NumberOfElement());
0160   }
0161 
0162  private:
0163   ConstValue const_value_;  // for input
0164   TT* data_{};              // for output
0165   Span<T> span_;
0166 };
0167 
0168 template <>
0169 class Tensor<std::string> : public TensorBase {
0170  public:
0171   using strings = std::vector<std::string>;
0172 
0173   Tensor(OrtKernelContext* ctx, size_t indice, bool is_input) : TensorBase(ctx, indice, is_input) {
0174     if (is_input_) {
0175       if (indice >= ctx_.GetInputCount()) {
0176         ORT_CXX_API_THROW("invalid indice for Ort::Custom::Tensor", OrtErrorCode::ORT_INVALID_ARGUMENT);
0177       }
0178       auto const_value = ctx_.GetInput(indice);
0179       auto type_shape_info = const_value.GetTensorTypeAndShapeInfo();
0180       shape_ = type_shape_info.GetShape();
0181       auto num_chars = const_value.GetStringTensorDataLength();
0182       // note - there will be copy ...
0183       auto num_strings = static_cast<size_t>(NumberOfElement());
0184       if (num_strings) {
0185         std::vector<char> chars(num_chars + 1, '\0');
0186         std::vector<size_t> offsets(num_strings);
0187         const_value.GetStringTensorContent(static_cast<void*>(chars.data()), num_chars, offsets.data(), offsets.size());
0188         auto upper_bound = num_strings - 1;
0189         input_strings_.resize(num_strings);
0190         for (size_t i = upper_bound;; --i) {
0191           if (i < upper_bound) {
0192             chars[offsets[i + 1]] = '\0';
0193           }
0194           input_strings_[i] = chars.data() + offsets[i];
0195           if (0 == i) {
0196             break;
0197           }
0198         }
0199       }
0200     }
0201   }
0202   const strings& Data() const {
0203     return input_strings_;
0204   }
0205   const void* DataRaw() const override {
0206     if (input_strings_.size() != 1) {
0207       ORT_CXX_API_THROW("DataRaw() only applies to string scalar", ORT_RUNTIME_EXCEPTION);
0208     }
0209     return reinterpret_cast<const void*>(input_strings_[0].c_str());
0210   }
0211   size_t SizeInBytes() const override {
0212     if (input_strings_.size() != 1) {
0213       ORT_CXX_API_THROW("SizeInBytes() only applies to string scalar", ORT_RUNTIME_EXCEPTION);
0214     }
0215     return input_strings_[0].size();
0216   }
0217   void SetStringOutput(const strings& ss, const std::vector<int64_t>& dims) {
0218     shape_ = dims;
0219     std::vector<const char*> raw;
0220     for (const auto& s : ss) {
0221       raw.push_back(s.data());
0222     }
0223     auto output = ctx_.GetOutput(indice_, dims.data(), dims.size());
0224     // note - there will be copy ...
0225     output.FillStringTensor(raw.data(), raw.size());
0226   }
0227   const Span<std::string>& AsSpan() {
0228     ORT_CXX_API_THROW("span for TensorT of string not implemented", OrtErrorCode::ORT_RUNTIME_EXCEPTION);
0229   }
0230   const std::string& AsScalar() {
0231     if (input_strings_.size() != 1) {
0232       ORT_CXX_API_THROW("invalid shape while trying to get a scalar string from Ort::Custom::Tensor",
0233                         OrtErrorCode::ORT_RUNTIME_EXCEPTION);
0234     }
0235     return input_strings_[0];
0236   }
0237 
0238  private:
0239   std::vector<std::string> input_strings_;  // for input
0240 };
0241 
0242 template <>
0243 class Tensor<std::string_view> : public TensorBase {
0244  public:
0245   using strings = std::vector<std::string>;
0246   using string_views = std::vector<std::string_view>;
0247 
0248   Tensor(OrtKernelContext* ctx, size_t indice, bool is_input) : TensorBase(ctx, indice, is_input) {
0249     if (is_input_) {
0250       if (indice >= ctx_.GetInputCount()) {
0251         ORT_CXX_API_THROW("invalid indice for Ort::Custom::Tensor", OrtErrorCode::ORT_INVALID_ARGUMENT);
0252       }
0253       auto const_value = ctx_.GetInput(indice);
0254       auto type_shape_info = const_value.GetTensorTypeAndShapeInfo();
0255       shape_ = type_shape_info.GetShape();
0256       auto num_chars = const_value.GetStringTensorDataLength();
0257       chars_.resize(num_chars + 1, '\0');
0258       auto num_strings = static_cast<size_t>(NumberOfElement());
0259       if (num_strings) {
0260         std::vector<size_t> offsets(num_strings);
0261         const_value.GetStringTensorContent(static_cast<void*>(chars_.data()), num_chars, offsets.data(), offsets.size());
0262         offsets.push_back(num_chars);
0263         for (size_t i = 0; i < num_strings; ++i) {
0264           input_string_views_.emplace_back(chars_.data() + offsets[i], offsets[i + 1] - offsets[i]);
0265         }
0266       }
0267     }
0268   }
0269   const string_views& Data() const {
0270     return input_string_views_;
0271   }
0272   const void* DataRaw() const override {
0273     if (input_string_views_.size() != 1) {
0274       ORT_CXX_API_THROW("DataRaw() only applies to string scalar", ORT_RUNTIME_EXCEPTION);
0275     }
0276     return reinterpret_cast<const void*>(input_string_views_[0].data());
0277   }
0278   size_t SizeInBytes() const override {
0279     if (input_string_views_.size() != 1) {
0280       ORT_CXX_API_THROW("SizeInBytes() only applies to string scalar", ORT_RUNTIME_EXCEPTION);
0281     }
0282     return input_string_views_[0].size();
0283   }
0284   void SetStringOutput(const strings& ss, const std::vector<int64_t>& dims) {
0285     shape_ = dims;
0286     std::vector<const char*> raw;
0287     for (const auto& s : ss) {
0288       raw.push_back(s.data());
0289     }
0290     auto output = ctx_.GetOutput(indice_, dims.data(), dims.size());
0291     // note - there will be copy ...
0292     output.FillStringTensor(raw.data(), raw.size());
0293   }
0294   const Span<std::string_view>& AsSpan() {
0295     ORT_CXX_API_THROW("span for TensorT of string view not implemented", OrtErrorCode::ORT_RUNTIME_EXCEPTION);
0296   }
0297   std::string_view AsScalar() {
0298     if (input_string_views_.size() != 1) {
0299       ORT_CXX_API_THROW("invalid shape while trying to get a scalar string view from Ort::Custom::Tensor",
0300                         OrtErrorCode::ORT_RUNTIME_EXCEPTION);
0301     }
0302     return input_string_views_[0];
0303   }
0304 
0305  private:
0306   std::vector<char> chars_;                           // for input
0307   std::vector<std::string_view> input_string_views_;  // for input
0308 };
0309 
// Owning pointer to a type-erased tensor argument.
using TensorPtr = std::unique_ptr<Custom::TensorBase>;
// A collection of owned tensors (used, e.g., as the storage of TensorArray).
using TensorPtrs = std::vector<TensorPtr>;
0312 
0313 struct TensorArray : public ArgBase {
0314   TensorArray(OrtKernelContext* ctx,
0315               size_t start_indice,
0316               bool is_input) : ArgBase(ctx,
0317                                        start_indice,
0318                                        is_input) {
0319     if (is_input) {
0320       auto input_count = ctx_.GetInputCount();
0321       for (size_t ith_input = start_indice; ith_input < input_count; ++ith_input) {
0322         auto const_value = ctx_.GetInput(start_indice);
0323         auto type_shape_info = const_value.GetTensorTypeAndShapeInfo();
0324         auto type = type_shape_info.GetElementType();
0325         TensorPtr tensor;
0326         switch (type) {
0327           case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL:
0328             tensor = std::make_unique<Custom::Tensor<bool>>(ctx, ith_input, true);
0329             break;
0330           case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:
0331             tensor = std::make_unique<Custom::Tensor<float>>(ctx, ith_input, true);
0332             break;
0333           case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE:
0334             tensor = std::make_unique<Custom::Tensor<double>>(ctx, ith_input, true);
0335             break;
0336           case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8:
0337             tensor = std::make_unique<Custom::Tensor<uint8_t>>(ctx, ith_input, true);
0338             break;
0339           case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8:
0340             tensor = std::make_unique<Custom::Tensor<int8_t>>(ctx, ith_input, true);
0341             break;
0342           case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16:
0343             tensor = std::make_unique<Custom::Tensor<uint16_t>>(ctx, ith_input, true);
0344             break;
0345           case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16:
0346             tensor = std::make_unique<Custom::Tensor<int16_t>>(ctx, ith_input, true);
0347             break;
0348           case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32:
0349             tensor = std::make_unique<Custom::Tensor<uint32_t>>(ctx, ith_input, true);
0350             break;
0351           case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32:
0352             tensor = std::make_unique<Custom::Tensor<int32_t>>(ctx, ith_input, true);
0353             break;
0354           case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64:
0355             tensor = std::make_unique<Custom::Tensor<uint64_t>>(ctx, ith_input, true);
0356             break;
0357           case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64:
0358             tensor = std::make_unique<Custom::Tensor<int64_t>>(ctx, ith_input, true);
0359             break;
0360           case ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING:
0361             tensor = std::make_unique<Custom::Tensor<std::string>>(ctx, ith_input, true);
0362             break;
0363           default:
0364             ORT_CXX_API_THROW("unknow input type", ORT_RUNTIME_EXCEPTION);
0365             break;
0366         }
0367         tensors_.emplace_back(tensor.release());
0368       }  // for
0369     }
0370   }
0371   template <typename T>
0372   T* AllocateOutput(size_t ith_output, const std::vector<int64_t>& shape) {
0373     // ith_output is the indice of output relative to the tensor array
0374     // indice_ + ith_output is the indice relative to context
0375     auto tensor = std::make_unique<Tensor<T>>(ctx_.GetOrtKernelContext(), indice_ + ith_output, false);
0376     auto raw_output = tensor.get()->Allocate(shape);
0377     tensors_.emplace_back(tensor.release());
0378     return raw_output;
0379   }
0380   Tensor<std::string>& AllocateStringTensor(size_t ith_output) {
0381     // ith_output is the indice of output relative to the tensor array
0382     // indice_ + ith_output is the indice relative to context
0383     auto tensor = std::make_unique<Tensor<std::string>>(ctx_.GetOrtKernelContext(), indice_ + ith_output, false);
0384     Tensor<std::string>& output = *tensor;
0385     tensors_.emplace_back(tensor.release());
0386     return output;
0387   }
0388   size_t Size() const {
0389     return tensors_.size();
0390   }
0391   const TensorPtr& operator[](size_t ith_input) const {
0392     // ith_input is the indice of output relative to the tensor array
0393     return tensors_.at(ith_input);
0394   }
0395 
0396  private:
0397   TensorPtrs tensors_;
0398 };
0399 
// Alternative name for TensorArray, the type used for variadic arguments.
using Variadic = TensorArray;
0401 
/*
Note:
OrtLiteCustomOp inherits from OrtCustomOp to bridge between a custom func/struct and the ORT core.
The lifetime of an OrtLiteCustomOp instance is managed by customer code, not by ORT, so:
1. DO NOT cast an OrtLiteCustomOp to OrtCustomOp and release it through that pointer, since there is no virtual destructor in the hierarchy.
2. OrtLiteCustomFunc and OrtLiteCustomStruct, the two sub-structs, can be released through an OrtLiteCustomOp pointer since all members are kept in OrtLiteCustomOp,
   hence memory is still reclaimed properly.
Further, OrtCustomOp is a C struct with no v-table, so derived structs are, by design, free of virtual functions to keep such casts safe.
*/
0411 struct OrtLiteCustomOp : public OrtCustomOp {
  // Convenience aliases for optional float tensor arguments.
  // NOTE(review): std::optional of a reference type is ill-formed per the C++
  // standard, so ConstOptionalFloatTensor fails to compile as soon as it is
  // instantiated; only OptionalFloatTensor is usable as written.
  using ConstOptionalFloatTensor = std::optional<const Custom::Tensor<float>&>;
  using OptionalFloatTensor = std::optional<Custom::Tensor<float>>;
0414 
0415   // CreateTuple
0416   template <size_t ith_input, size_t ith_output, typename... Ts>
0417   static typename std::enable_if<sizeof...(Ts) == 0, std::tuple<>>::type
0418   CreateTuple(OrtKernelContext*, ArgPtrs&, size_t, size_t, const std::string&) {
0419     return std::make_tuple();
0420   }
0421 
0422   template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
0423   static typename std::enable_if<std::is_same<T, OrtKernelContext*>::value, std::tuple<T, Ts...>>::type
0424   CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) {
0425     std::tuple<T> current = std::tuple<OrtKernelContext*>{context};
0426     auto next = CreateTuple<ith_input, ith_output, Ts...>(context, args, num_input, num_output, ep);
0427     return std::tuple_cat(current, next);
0428   }
0429 
0430   template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
0431   static typename std::enable_if<std::is_same<T, OrtKernelContext&>::value, std::tuple<T, Ts...>>::type
0432   CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) {
0433     std::tuple<T> current = std::tuple<OrtKernelContext&>{*context};
0434     auto next = CreateTuple<ith_input, ith_output, Ts...>(context, args, num_input, num_output, ep);
0435     return std::tuple_cat(current, next);
0436   }
0437 
#ifdef ORT_CUDA_CTX
  // `const CudaContext&` parameter: binds a thread_local CudaContext that is
  // re-initialized from the current kernel context on every call; consumes no
  // input/output index.
  template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
  static typename std::enable_if<std::is_same<T, const CudaContext&>::value, std::tuple<T, Ts...>>::type
  CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) {
    thread_local CudaContext cuda_context;
    cuda_context.Init(*context);
    std::tuple<T> head{cuda_context};
    auto tail = CreateTuple<ith_input, ith_output, Ts...>(context, args, num_input, num_output, ep);
    return std::tuple_cat(head, tail);
  }
#endif

#ifdef ORT_ROCM_CTX
  // `const RocmContext&` parameter: binds a thread_local RocmContext that is
  // re-initialized from the current kernel context on every call; consumes no
  // input/output index.
  template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
  static typename std::enable_if<std::is_same<T, const RocmContext&>::value, std::tuple<T, Ts...>>::type
  CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) {
    thread_local RocmContext rocm_context;
    rocm_context.Init(*context);
    std::tuple<T> head{rocm_context};
    auto tail = CreateTuple<ith_input, ith_output, Ts...>(context, args, num_input, num_output, ep);
    return std::tuple_cat(head, tail);
  }
#endif
0461 
0462   template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
0463   static typename std::enable_if<std::is_same<T, const TensorArray*>::value, std::tuple<T, Ts...>>::type
0464   CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) {
0465     args.push_back(std::make_unique<TensorArray>(context, ith_input, true));
0466     std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(args.back().get())};
0467     auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep);
0468     return std::tuple_cat(current, next);
0469   }
0470 
0471   template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
0472   static typename std::enable_if<std::is_same<T, const TensorArray&>::value, std::tuple<T, Ts...>>::type
0473   CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) {
0474     args.push_back(std::make_unique<TensorArray>(context, ith_input, true));
0475     std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(*args.back().get())};
0476     auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep);
0477     return std::tuple_cat(current, next);
0478   }
0479 
0480   template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
0481   static typename std::enable_if<std::is_same<T, TensorArray*>::value, std::tuple<T, Ts...>>::type
0482   CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) {
0483     args.push_back(std::make_unique<TensorArray>(context, ith_output, false));
0484     std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(args.back().get())};
0485     auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(context, args, num_input, num_output, ep);
0486     return std::tuple_cat(current, next);
0487   }
0488 
0489   template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
0490   static typename std::enable_if<std::is_same<T, TensorArray&>::value, std::tuple<T, Ts...>>::type
0491   CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) {
0492     args.push_back(std::make_unique<TensorArray>(context, ith_output, false));
0493     std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(*args.back().get())};
0494     auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(context, args, num_input, num_output, ep);
0495     return std::tuple_cat(current, next);
0496   }
0497 
0498 #define CREATE_TUPLE_INPUT(data_type)                                                                                                 \
0499   template <size_t ith_input, size_t ith_output, typename T, typename... Ts>                                                          \
0500   static typename std::enable_if<std::is_same<T, const Custom::Tensor<data_type>*>::value, std::tuple<T, Ts...>>::type                \
0501   CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) {                 \
0502     args.push_back(std::make_unique<Custom::Tensor<data_type>>(context, ith_input, true));                                            \
0503     std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(args.back().get())};                                                    \
0504     auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep);                              \
0505     return std::tuple_cat(current, next);                                                                                             \
0506   }                                                                                                                                   \
0507   template <size_t ith_input, size_t ith_output, typename T, typename... Ts>                                                          \
0508   static typename std::enable_if<std::is_same<T, const Custom::Tensor<data_type>&>::value, std::tuple<T, Ts...>>::type                \
0509   CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) {                 \
0510     args.push_back(std::make_unique<Custom::Tensor<data_type>>(context, ith_input, true));                                            \
0511     std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(*args.back().get())};                                                   \
0512     auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep);                              \
0513     return std::tuple_cat(current, next);                                                                                             \
0514   }                                                                                                                                   \
0515   template <size_t ith_input, size_t ith_output, typename T, typename... Ts>                                                          \
0516   static typename std::enable_if<std::is_same<T, std::optional<const Custom::Tensor<data_type>*>>::value, std::tuple<T, Ts...>>::type \
0517   CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) {                 \
0518     if (ith_input < num_input) {                                                                                                      \
0519       args.push_back(std::make_unique<Custom::Tensor<data_type>>(context, ith_input, true));                                          \
0520       std::tuple<T> current = std::tuple<T>{reinterpret_cast<Custom::Tensor<data_type>*>(args.back().get())};                         \
0521       auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep);                            \
0522       return std::tuple_cat(current, next);                                                                                           \
0523     } else {                                                                                                                          \
0524       std::tuple<T> current = std::tuple<T>{};                                                                                        \
0525       auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep);                            \
0526       return std::tuple_cat(current, next);                                                                                           \
0527     }                                                                                                                                 \
0528   }                                                                                                                                   \
0529   template <size_t ith_input, size_t ith_output, typename T, typename... Ts>                                                          \
0530   static typename std::enable_if<std::is_same<T, const Custom::Span<data_type>*>::value, std::tuple<T, Ts...>>::type                  \
0531   CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) {                 \
0532     if ("CPUExecutionProvider" != ep) {                                                                                               \
0533       ORT_CXX_API_THROW("span input could only be applied to CPU EP", OrtErrorCode::ORT_RUNTIME_EXCEPTION);                           \
0534     }                                                                                                                                 \
0535     args.push_back(std::make_unique<Custom::Tensor<data_type>>(context, ith_input, true));                                            \
0536     std::tuple<T> current = std::tuple<T>{&reinterpret_cast<Custom::Tensor<data_type>*>(args.back().get())->AsSpan()};                \
0537     auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep);                              \
0538     return std::tuple_cat(current, next);                                                                                             \
0539   }                                                                                                                                   \
0540   template <size_t ith_input, size_t ith_output, typename T, typename... Ts>                                                          \
0541   static typename std::enable_if<std::is_same<T, const Custom::Span<data_type>&>::value, std::tuple<T, Ts...>>::type                  \
0542   CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) {                 \
0543     if ("CPUExecutionProvider" != ep) {                                                                                               \
0544       ORT_CXX_API_THROW("span input could only be applied to CPU EP", OrtErrorCode::ORT_RUNTIME_EXCEPTION);                           \
0545     }                                                                                                                                 \
0546     args.push_back(std::make_unique<Custom::Tensor<data_type>>(context, ith_input, true));                                            \
0547     std::tuple<T> current = std::tuple<T>{reinterpret_cast<Custom::Tensor<data_type>*>(args.back().get())->AsSpan()};                 \
0548     auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep);                              \
0549     return std::tuple_cat(current, next);                                                                                             \
0550   }                                                                                                                                   \
0551   template <size_t ith_input, size_t ith_output, typename T, typename... Ts>                                                          \
0552   static typename std::enable_if<std::is_same<T, std::optional<const Custom::Span<data_type>*>>::value, std::tuple<T, Ts...>>::type   \
0553   CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) {                 \
0554     if (ith_input < num_input) {                                                                                                      \
0555       if ("CPUExecutionProvider" != ep) {                                                                                             \
0556         ORT_CXX_API_THROW("span input could only be applied to CPU EP", OrtErrorCode::ORT_RUNTIME_EXCEPTION);                         \
0557       }                                                                                                                               \
0558       args.push_back(std::make_unique<Custom::Tensor<data_type>>(context, ith_input, true));                                          \
0559       std::tuple<T> current = std::tuple<T>{&reinterpret_cast<Custom::Tensor<data_type>*>(args.back().get())->AsSpan()};              \
0560       auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep);                            \
0561       return std::tuple_cat(current, next);                                                                                           \
0562     } else {                                                                                                                          \
0563       std::tuple<T> current = std::tuple<T>{};                                                                                        \
0564       auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep);                            \
0565       return std::tuple_cat(current, next);                                                                                           \
0566     }                                                                                                                                 \
0567   }                                                                                                                                   \
0568   template <size_t ith_input, size_t ith_output, typename T, typename... Ts>                                                          \
0569   static typename std::enable_if<std::is_same<T, data_type>::value, std::tuple<T, Ts...>>::type                                       \
0570   CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) {                 \
0571     if ("CPUExecutionProvider" != ep) {                                                                                               \
0572       ORT_CXX_API_THROW("scalar input could only be applied to CPU EP", OrtErrorCode::ORT_RUNTIME_EXCEPTION);                         \
0573     }                                                                                                                                 \
0574     args.push_back(std::make_unique<Custom::Tensor<data_type>>(context, ith_input, true));                                            \
0575     std::tuple<T> current = std::tuple<T>{reinterpret_cast<Custom::Tensor<data_type>*>(args.back().get())->AsScalar()};               \
0576     auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep);                              \
0577     return std::tuple_cat(current, next);                                                                                             \
0578   }                                                                                                                                   \
0579   template <size_t ith_input, size_t ith_output, typename T, typename... Ts>                                                          \
0580   static typename std::enable_if<std::is_same<T, std::optional<data_type>>::value, std::tuple<T, Ts...>>::type                        \
0581   CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) {                 \
0582     if (ith_input < num_input) {                                                                                                      \
0583       if ("CPUExecutionProvider" != ep) {                                                                                             \
0584         ORT_CXX_API_THROW("scalar input could only be applied to CPU EP", OrtErrorCode::ORT_RUNTIME_EXCEPTION);                       \
0585       }                                                                                                                               \
0586       args.push_back(std::make_unique<Custom::Tensor<data_type>>(context, ith_input, true));                                          \
0587       std::tuple<T> current = std::tuple<T>{reinterpret_cast<Custom::Tensor<data_type>*>(args.back().get())->AsScalar()};             \
0588       auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep);                            \
0589       return std::tuple_cat(current, next);                                                                                           \
0590     } else {                                                                                                                          \
0591       std::tuple<T> current = std::tuple<T>{};                                                                                        \
0592       auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep);                            \
0593       return std::tuple_cat(current, next);                                                                                           \
0594     }                                                                                                                                 \
0595   }
// CREATE_TUPLE_OUTPUT(data_type) declares the CreateTuple overloads that bind the next kernel
// output (index ith_output) to a Custom::Tensor<data_type> when the compute signature spells it as
//   * Custom::Tensor<data_type>*,
//   * Custom::Tensor<data_type>&, or
//   * std::optional<Custom::Tensor<data_type>*>  (left empty when ith_output >= num_output).
// Each overload appends the owning Tensor to `args`, so the wrapper lives as long as the caller's
// ArgPtrs, and then recurses over the remaining parameter types with ith_output + 1.
// The ArgBase -> Tensor<data_type> downcast uses static_cast: the object was just created as a
// Tensor<data_type>, and static_cast is the conversion the language defines for a base-to-derived
// downcast (a reinterpret_cast between a polymorphic base and derived type is not guaranteed valid).
#define CREATE_TUPLE_OUTPUT(data_type)                                                                                          \
  template <size_t ith_input, size_t ith_output, typename T, typename... Ts>                                                    \
  static typename std::enable_if<std::is_same<T, Custom::Tensor<data_type>*>::value, std::tuple<T, Ts...>>::type                \
  CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) {           \
    args.push_back(std::make_unique<Custom::Tensor<data_type>>(context, ith_output, false));                                    \
    std::tuple<T> current = std::tuple<T>{static_cast<T>(args.back().get())};                                                   \
    auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(context, args, num_input, num_output, ep);                        \
    return std::tuple_cat(current, next);                                                                                       \
  }                                                                                                                             \
  template <size_t ith_input, size_t ith_output, typename T, typename... Ts>                                                    \
  static typename std::enable_if<std::is_same<T, Custom::Tensor<data_type>&>::value, std::tuple<T, Ts...>>::type                \
  CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) {           \
    args.push_back(std::make_unique<Custom::Tensor<data_type>>(context, ith_output, false));                                    \
    std::tuple<T> current = std::tuple<T>{static_cast<T>(*args.back())};                                                        \
    auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(context, args, num_input, num_output, ep);                        \
    return std::tuple_cat(current, next);                                                                                       \
  }                                                                                                                             \
  template <size_t ith_input, size_t ith_output, typename T, typename... Ts>                                                    \
  static typename std::enable_if<std::is_same<T, std::optional<Custom::Tensor<data_type>*>>::value, std::tuple<T, Ts...>>::type \
  CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) {           \
    if (ith_output < num_output) {                                                                                              \
      args.push_back(std::make_unique<Custom::Tensor<data_type>>(context, ith_output, false));                                  \
      std::tuple<T> current = std::tuple<T>{static_cast<Custom::Tensor<data_type>*>(args.back().get())};                        \
      auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(context, args, num_input, num_output, ep);                      \
      return std::tuple_cat(current, next);                                                                                     \
    } else {                                                                                                                    \
      std::tuple<T> current = std::tuple<T>{};                                                                                  \
      auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(context, args, num_input, num_output, ep);                      \
      return std::tuple_cat(current, next);                                                                                     \
    }                                                                                                                           \
  }
// CREATE_TUPLE(data_type) emits both the input-side and the output-side CreateTuple overloads for data_type.
#define CREATE_TUPLE(data_type) CREATE_TUPLE_INPUT(data_type) CREATE_TUPLE_OUTPUT(data_type)
0630 
0631   CREATE_TUPLE(bool)
0632   CREATE_TUPLE(float)
0633   CREATE_TUPLE(Ort::Float16_t)
0634   CREATE_TUPLE(Ort::BFloat16_t)
0635   CREATE_TUPLE(double)
0636   CREATE_TUPLE(int8_t)
0637   CREATE_TUPLE(int16_t)
0638   CREATE_TUPLE(int32_t)
0639   CREATE_TUPLE(int64_t)
0640   CREATE_TUPLE(uint8_t)
0641   CREATE_TUPLE(uint16_t)
0642   CREATE_TUPLE(uint32_t)
0643   CREATE_TUPLE(uint64_t)
0644   CREATE_TUPLE(std::string)
0645   CREATE_TUPLE_INPUT(std::string_view)
0646   CREATE_TUPLE(Ort::Float8E4M3FN_t)
0647   CREATE_TUPLE(Ort::Float8E4M3FNUZ_t)
0648   CREATE_TUPLE(Ort::Float8E5M2_t)
0649   CREATE_TUPLE(Ort::Float8E5M2FNUZ_t)
0650 
0651   // ParseArgs ...
0652   template <typename... Ts>
0653   static typename std::enable_if<0 == sizeof...(Ts)>::type
0654   ParseArgs(std::vector<ONNXTensorElementDataType>&, std::vector<ONNXTensorElementDataType>&) {
0655   }
0656 
0657   template <typename T, typename... Ts>
0658   static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, OrtKernelContext*>::value>::type
0659   ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
0660     ParseArgs<Ts...>(input_types, output_types);
0661   }
0662 
0663   template <typename T, typename... Ts>
0664   static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, OrtKernelContext&>::value>::type
0665   ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
0666     ParseArgs<Ts...>(input_types, output_types);
0667   }
0668 
0669 #ifdef ORT_CUDA_CTX
0670   template <typename T, typename... Ts>
0671   static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, const CudaContext&>::value>::type
0672   ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
0673     ParseArgs<Ts...>(input_types, output_types);
0674   }
0675 #endif
0676 
0677 #ifdef ORT_ROCM_CTX
0678   template <typename T, typename... Ts>
0679   static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, const RocmContext&>::value>::type
0680   ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
0681     ParseArgs<Ts...>(input_types, output_types);
0682   }
0683 #endif
0684 
0685   template <typename T, typename... Ts>
0686   static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, const TensorArray&>::value>::type
0687   ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
0688     input_types.push_back(ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED);
0689     ParseArgs<Ts...>(input_types, output_types);
0690   }
0691 
0692   template <typename T, typename... Ts>
0693   static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, const TensorArray*>::value>::type
0694   ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
0695     input_types.push_back(ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED);
0696     ParseArgs<Ts...>(input_types, output_types);
0697   }
0698 
0699   template <typename T, typename... Ts>
0700   static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, TensorArray&>::value>::type
0701   ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
0702     output_types.push_back(ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED);
0703     ParseArgs<Ts...>(input_types, output_types);
0704   }
0705 
0706   template <typename T, typename... Ts>
0707   static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, TensorArray*>::value>::type
0708   ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
0709     output_types.push_back(ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED);
0710     ParseArgs<Ts...>(input_types, output_types);
0711   }
0712 
// PARSE_INPUT_BASE(pack_type, onnx_type) declares the ParseArgs overloads for an input parameter
// spelled as pack_type, const std::optional<pack_type> or std::optional<pack_type>. Each appends
// onnx_type to input_types and continues with the remaining parameter types. Overload selection
// depends only on the std::is_same test on the leading parameter type.
#define PARSE_INPUT_BASE(pack_type, onnx_type)                                                                           \
  template <typename T, typename... Ts>                                                                                  \
  static typename std::enable_if<std::is_same<T, pack_type>::value>::type                                               \
  ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) { \
    input_types.push_back(onnx_type);                                                                                    \
    ParseArgs<Ts...>(input_types, output_types);                                                                         \
  }                                                                                                                      \
  template <typename T, typename... Ts>                                                                                  \
  static typename std::enable_if<std::is_same<T, const std::optional<pack_type>>::value>::type                          \
  ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) { \
    input_types.push_back(onnx_type);                                                                                    \
    ParseArgs<Ts...>(input_types, output_types);                                                                         \
  }                                                                                                                      \
  template <typename T, typename... Ts>                                                                                  \
  static typename std::enable_if<std::is_same<T, std::optional<pack_type>>::value>::type                                \
  ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) { \
    input_types.push_back(onnx_type);                                                                                    \
    ParseArgs<Ts...>(input_types, output_types);                                                                         \
  }
0732 
// PARSE_INPUT(data_type, onnx_type) registers every input spelling supported for data_type:
// a plain scalar value, a Span (by const reference or const pointer) and a Tensor (by const
// reference or const pointer). PARSE_INPUT_BASE also adds the std::optional form of each.
#define PARSE_INPUT(data_type, onnx_type)                       \
  PARSE_INPUT_BASE(data_type, onnx_type)                        \
  PARSE_INPUT_BASE(const Custom::Span<data_type>&, onnx_type)   \
  PARSE_INPUT_BASE(const Custom::Span<data_type>*, onnx_type)   \
  PARSE_INPUT_BASE(const Custom::Tensor<data_type>&, onnx_type) \
  PARSE_INPUT_BASE(const Custom::Tensor<data_type>*, onnx_type)
0739 
// PARSE_OUTPUT(data_type, onnx_type) declares the ParseArgs overloads for an output parameter
// spelled as Custom::Tensor<data_type>*, Custom::Tensor<data_type>& or
// std::optional<Custom::Tensor<data_type>*>. Each appends onnx_type to output_types and continues
// with the remaining parameter types. Overload selection depends only on the std::is_same test.
#define PARSE_OUTPUT(data_type, onnx_type)                                                                                      \
  template <typename T, typename... Ts>                                                                                         \
  static typename std::enable_if<std::is_same<T, Custom::Tensor<data_type>*>::value>::type                                      \
  ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {        \
    output_types.push_back(onnx_type);                                                                                          \
    ParseArgs<Ts...>(input_types, output_types);                                                                                \
  }                                                                                                                             \
  template <typename T, typename... Ts>                                                                                         \
  static typename std::enable_if<std::is_same<T, Custom::Tensor<data_type>&>::value>::type                                      \
  ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {        \
    output_types.push_back(onnx_type);                                                                                          \
    ParseArgs<Ts...>(input_types, output_types);                                                                                \
  }                                                                                                                             \
  template <typename T, typename... Ts>                                                                                         \
  static typename std::enable_if<std::is_same<T, std::optional<Custom::Tensor<data_type>*>>::value>::type                       \
  ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {        \
    output_types.push_back(onnx_type);                                                                                          \
    ParseArgs<Ts...>(input_types, output_types);                                                                                \
  }
0759 
// PARSE_ARGS(data_type, onnx_type) registers both the input and the output spellings of data_type.
#define PARSE_ARGS(data_type, onnx_type) PARSE_INPUT(data_type, onnx_type) PARSE_OUTPUT(data_type, onnx_type)
0763 
0764   PARSE_ARGS(bool, ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL)
0765   PARSE_ARGS(float, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT)
0766   PARSE_ARGS(Ort::Float16_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16)
0767   PARSE_ARGS(Ort::BFloat16_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16)
0768   PARSE_ARGS(double, ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE)
0769   PARSE_ARGS(int8_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8)
0770   PARSE_ARGS(int16_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16)
0771   PARSE_ARGS(int32_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32)
0772   PARSE_ARGS(int64_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64)
0773   PARSE_ARGS(uint8_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8)
0774   PARSE_ARGS(uint16_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16)
0775   PARSE_ARGS(uint32_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32)
0776   PARSE_ARGS(uint64_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64)
0777   PARSE_ARGS(std::string, ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING)
0778   PARSE_ARGS(std::string_view, ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING)  // todo - remove string_view output
0779   PARSE_ARGS(Ort::Float8E4M3FN_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN)
0780   PARSE_ARGS(Ort::Float8E4M3FNUZ_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FNUZ)
0781   PARSE_ARGS(Ort::Float8E5M2_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2)
0782   PARSE_ARGS(Ort::Float8E5M2FNUZ_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2FNUZ)
0783 
0784   OrtLiteCustomOp(const char* op_name,
0785                   const char* execution_provider,
0786                   ShapeInferFn shape_infer_fn,
0787                   int start_ver = 1,
0788                   int end_ver = MAX_CUSTOM_OP_END_VER) : op_name_(op_name),
0789                                                          execution_provider_(execution_provider),
0790                                                          shape_infer_fn_(shape_infer_fn),
0791                                                          start_ver_(start_ver),
0792                                                          end_ver_(end_ver) {
0793     OrtCustomOp::version = ORT_API_VERSION;
0794 
0795     OrtCustomOp::GetName = [](const OrtCustomOp* op) { return static_cast<const OrtLiteCustomOp*>(op)->op_name_.c_str(); };
0796     OrtCustomOp::GetExecutionProviderType = [](const OrtCustomOp* op) { return ((OrtLiteCustomOp*)op)->execution_provider_.c_str(); };
0797     OrtCustomOp::GetInputMemoryType = [](const OrtCustomOp*, size_t) { return OrtMemTypeDefault; };
0798 
0799     OrtCustomOp::GetInputTypeCount = [](const OrtCustomOp* op) {
0800       auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
0801       return self->input_types_.size();
0802     };
0803 
0804     OrtCustomOp::GetInputType = [](const OrtCustomOp* op, size_t indice) {
0805       auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
0806       return self->input_types_[indice];
0807     };
0808 
0809     OrtCustomOp::GetOutputTypeCount = [](const OrtCustomOp* op) {
0810       auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
0811       return self->output_types_.size();
0812     };
0813 
0814     OrtCustomOp::GetOutputType = [](const OrtCustomOp* op, size_t indice) {
0815       auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
0816       return self->output_types_[indice];
0817     };
0818 
0819     OrtCustomOp::GetInputCharacteristic = [](const OrtCustomOp* op, size_t indice) {
0820       auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
0821       return self->input_types_[indice] == ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED ? INPUT_OUTPUT_VARIADIC : INPUT_OUTPUT_OPTIONAL;
0822     };
0823 
0824     OrtCustomOp::GetOutputCharacteristic = [](const OrtCustomOp* op, size_t indice) {
0825       auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
0826       return self->output_types_[indice] == ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED ? INPUT_OUTPUT_VARIADIC : INPUT_OUTPUT_OPTIONAL;
0827     };
0828 
0829     OrtCustomOp::GetVariadicInputMinArity = [](const OrtCustomOp*) {
0830       return 1;
0831     };
0832 
0833     OrtCustomOp::GetVariadicInputHomogeneity = [](const OrtCustomOp*) {
0834       return 0;
0835     };
0836 
0837     OrtCustomOp::GetVariadicOutputMinArity = [](const OrtCustomOp*) {
0838       return 1;
0839     };
0840 
0841     OrtCustomOp::GetVariadicOutputHomogeneity = [](const OrtCustomOp*) {
0842       return 0;
0843     };
0844 
0845     OrtCustomOp::GetVariadicInputMinArity = [](const OrtCustomOp*) { return 0; };
0846     OrtCustomOp::GetVariadicInputHomogeneity = [](const OrtCustomOp*) { return 0; };
0847     OrtCustomOp::GetVariadicOutputMinArity = [](const OrtCustomOp*) { return 0; };
0848     OrtCustomOp::GetVariadicOutputHomogeneity = [](const OrtCustomOp*) { return 0; };
0849 
0850     OrtCustomOp::CreateKernelV2 = {};
0851     OrtCustomOp::KernelComputeV2 = {};
0852     OrtCustomOp::KernelCompute = {};
0853 
0854     OrtCustomOp::InferOutputShapeFn = {};
0855 
0856     OrtCustomOp::GetStartVersion = [](const OrtCustomOp* op) {
0857       auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
0858       return self->start_ver_;
0859     };
0860 
0861     OrtCustomOp::GetEndVersion = [](const OrtCustomOp* op) {
0862       auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
0863       return self->end_ver_;
0864     };
0865 
0866     OrtCustomOp::GetMayInplace = {};
0867     OrtCustomOp::ReleaseMayInplace = {};
0868     OrtCustomOp::GetAliasMap = {};
0869     OrtCustomOp::ReleaseAliasMap = {};
0870   }
0871 
  // Op name returned through OrtCustomOp::GetName.
  const std::string op_name_;
  // Execution provider name returned through GetExecutionProviderType. The derived classes copy it
  // into each kernel, and CreateTuple uses it to reject span/scalar inputs on non-CPU providers.
  const std::string execution_provider_;

  // Element type of each ONNX input/output slot, appended by ParseArgs from the compute signature.
  // ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED marks a variadic (TensorArray) slot.
  std::vector<ONNXTensorElementDataType> input_types_;
  std::vector<ONNXTensorElementDataType> output_types_;

  // Optional shape-inference callback; null when none was supplied.
  ShapeInferFn shape_infer_fn_ = {};

  // Opset version range returned through GetStartVersion / GetEndVersion.
  int start_ver_ = 1;
  int end_ver_ = MAX_CUSTOM_OP_END_VER;

  // Type-erased free-function pointers set by OrtLiteCustomFunc (one or the other, depending on the
  // constructor used) and cast back to their real types in its CreateKernel.
  // NOTE(review): function-pointer <-> void* casts are only conditionally supported by the C++
  // standard; fine on mainstream compilers, but confirm when targeting a new toolchain.
  void* compute_fn_ = {};
  void* compute_fn_return_status_ = {};
0885 };
0886 
0887 //////////////////////////// OrtLiteCustomFunc ////////////////////////////////
0888 // The struct is to implement function-as-op.
0889 // E.g. a function might be defined as:
0890 //   void Filter(const Ort::Custom::Tensor<float>& floats_in, Ort::Custom::Tensor<float>& floats_out) { ... }
0891 // It could be registered this way:
0892 //   Ort::CustomOpDomain v2_domain{"v2"};
0893 //   std::unique_ptr<OrtLiteCustomOp> fil_op_ptr{Ort::Custom::CreateLiteCustomOp("Filter", "CPUExecutionProvider", Filter)};
0894 //   v2_domain.Add(fil_op_ptr.get());
0895 //   session_options.Add(v2_domain);
0896 // For the complete example, please search keyword "LiteCustomOpTest" under "<cloned_src_dir>/onnxruntime/test/".
0897 template <typename... Args>
0898 struct OrtLiteCustomFunc : public OrtLiteCustomOp {
0899   using ComputeFn = void (*)(Args...);
0900   using ComputeFnReturnStatus = Status (*)(Args...);
0901   using MyType = OrtLiteCustomFunc<Args...>;
0902 
0903   struct Kernel {
0904     size_t num_input_{};
0905     size_t num_output_{};
0906     ComputeFn compute_fn_{};
0907     ComputeFnReturnStatus compute_fn_return_status_{};
0908     std::string ep_{};
0909   };
0910 
0911   OrtLiteCustomFunc(const char* op_name,
0912                     const char* execution_provider,
0913                     ComputeFn compute_fn,
0914                     ShapeInferFn shape_infer_fn = {},
0915                     int start_ver = 1,
0916                     int end_ver = MAX_CUSTOM_OP_END_VER) : OrtLiteCustomOp(op_name, execution_provider, shape_infer_fn, start_ver, end_ver) {
0917     compute_fn_ = reinterpret_cast<void*>(compute_fn);
0918     ParseArgs<Args...>(input_types_, output_types_);
0919 
0920     OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) {
0921       auto kernel = reinterpret_cast<Kernel*>(op_kernel);
0922       std::vector<ArgPtr> args;
0923       auto t = CreateTuple<0, 0, Args...>(context, args, kernel->num_input_, kernel->num_output_, kernel->ep_);
0924       std::apply([kernel](Args const&... t_args) { kernel->compute_fn_(t_args...); }, t);
0925     };
0926 
0927     OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* ort_api, const OrtKernelInfo* info) {
0928       auto kernel = std::make_unique<Kernel>();
0929       auto me = static_cast<const MyType*>(this_);
0930       kernel->compute_fn_ = reinterpret_cast<ComputeFn>(me->compute_fn_);
0931       Ort::ThrowOnError(ort_api->KernelInfo_GetInputCount(info, &kernel->num_input_));
0932       Ort::ThrowOnError(ort_api->KernelInfo_GetOutputCount(info, &kernel->num_output_));
0933       auto self = static_cast<const OrtLiteCustomFunc*>(this_);
0934       kernel->ep_ = self->execution_provider_;
0935       return reinterpret_cast<void*>(kernel.release());
0936     };
0937 
0938     OrtCustomOp::KernelDestroy = [](void* op_kernel) {
0939       delete reinterpret_cast<Kernel*>(op_kernel);
0940     };
0941 
0942     if (shape_infer_fn_) {
0943       OrtCustomOp::InferOutputShapeFn = [](const OrtCustomOp* op, OrtShapeInferContext* ort_ctx) -> OrtStatusPtr {
0944         auto shape_info_fn = static_cast<const MyType*>(op)->shape_infer_fn_;
0945         ShapeInferContext ctx(&GetApi(), ort_ctx);
0946         return shape_info_fn(ctx);
0947       };
0948     }
0949   }
0950 
0951   OrtLiteCustomFunc(const char* op_name,
0952                     const char* execution_provider,
0953                     ComputeFnReturnStatus compute_fn_return_status,
0954                     ShapeInferFn shape_infer_fn = {},
0955                     int start_ver = 1,
0956                     int end_ver = MAX_CUSTOM_OP_END_VER) : OrtLiteCustomOp(op_name, execution_provider, shape_infer_fn, start_ver, end_ver) {
0957     compute_fn_return_status_ = reinterpret_cast<void*>(compute_fn_return_status);
0958     ParseArgs<Args...>(input_types_, output_types_);
0959 
0960     OrtCustomOp::KernelComputeV2 = [](void* op_kernel, OrtKernelContext* context) -> OrtStatusPtr {
0961       auto kernel = reinterpret_cast<Kernel*>(op_kernel);
0962       std::vector<ArgPtr> args;
0963       auto t = CreateTuple<0, 0, Args...>(context, args, kernel->num_input_, kernel->num_output_, kernel->ep_);
0964       return std::apply([kernel](Args const&... t_args) { Status status = kernel->compute_fn_return_status_(t_args...); return status.release(); }, t);
0965     };
0966 
0967     OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* ort_api, const OrtKernelInfo* info) {
0968       auto kernel = std::make_unique<Kernel>();
0969       auto me = static_cast<const MyType*>(this_);
0970       kernel->compute_fn_return_status_ = reinterpret_cast<ComputeFnReturnStatus>(me->compute_fn_return_status_);
0971       Ort::ThrowOnError(ort_api->KernelInfo_GetInputCount(info, &kernel->num_input_));
0972       Ort::ThrowOnError(ort_api->KernelInfo_GetOutputCount(info, &kernel->num_output_));
0973       auto self = static_cast<const OrtLiteCustomFunc*>(this_);
0974       kernel->ep_ = self->execution_provider_;
0975       return reinterpret_cast<void*>(kernel.release());
0976     };
0977 
0978     OrtCustomOp::KernelDestroy = [](void* op_kernel) {
0979       delete reinterpret_cast<Kernel*>(op_kernel);
0980     };
0981 
0982     if (shape_infer_fn_) {
0983       OrtCustomOp::InferOutputShapeFn = [](const OrtCustomOp* op, OrtShapeInferContext* ort_ctx) -> OrtStatusPtr {
0984         auto shape_info_fn = static_cast<const MyType*>(op)->shape_infer_fn_;
0985         ShapeInferContext ctx(&GetApi(), ort_ctx);
0986         return shape_info_fn(ctx);
0987       };
0988     }
0989   }
0990 };  // struct OrtLiteCustomFunc
0991 
0992 /////////////////////////// OrtLiteCustomStruct ///////////////////////////
0993 // The struct is to implement struct-as-op.
0994 // E.g. a struct might be defined as:
0995 //   struct Merge {
0996 //      Merge(const OrtApi* ort_api, const OrtKernelInfo* info) {...}
0997 //      void Compute(const Ort::Custom::Tensor<std::string_view>& strings_in,
0998 //                   std::string_view string_in,
0999 //                   Ort::Custom::Tensor<std::string>* strings_out) {...}
1000 //      bool reverse_ = false;
1001 //   };
1002 // It could be registered this way:
1003 //   Ort::CustomOpDomain v2_domain{"v2"};
1004 //   std::unique_ptr<OrtLiteCustomOp> mrg_op_ptr{Ort::Custom::CreateLiteCustomOp<Merge>("Merge", "CPUExecutionProvider")};
1005 //   v2_domain.Add(mrg_op_ptr.get());
1006 //   session_options.Add(v2_domain);
1007 // For the complete example, please search keyword "LiteCustomOpTest" under "<cloned_src_dir>/onnxruntime/test/".
1008 template <typename CustomOp>
1009 struct OrtLiteCustomStruct : public OrtLiteCustomOp {
1010   template <typename... Args>
1011   using CustomComputeFn = void (CustomOp::*)(Args...);
1012 
1013   template <typename... Args>
1014   using CustomComputeFnReturnStatus = Status (CustomOp::*)(Args...);
1015 
1016   using MyType = OrtLiteCustomStruct<CustomOp>;
1017 
1018   struct Kernel {
1019     size_t num_input_{};
1020     size_t num_output_{};
1021     std::unique_ptr<CustomOp> custom_op_;
1022     std::string ep_{};
1023   };
1024 
1025   OrtLiteCustomStruct(const char* op_name,
1026                       const char* execution_provider,
1027                       int start_ver = 1,
1028                       int end_ver = MAX_CUSTOM_OP_END_VER) : OrtLiteCustomOp(op_name, execution_provider, {}, start_ver, end_ver) {
1029     SetCompute(&CustomOp::Compute);
1030 
1031     OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* ort_api, const OrtKernelInfo* info) {
1032       auto kernel = std::make_unique<Kernel>();
1033       Ort::ThrowOnError(ort_api->KernelInfo_GetInputCount(info, &kernel->num_input_));
1034       Ort::ThrowOnError(ort_api->KernelInfo_GetOutputCount(info, &kernel->num_output_));
1035       kernel->custom_op_ = std::make_unique<CustomOp>(ort_api, info);
1036       auto self = static_cast<const OrtLiteCustomStruct*>(this_);
1037       kernel->ep_ = self->execution_provider_;
1038       return reinterpret_cast<void*>(kernel.release());
1039     };
1040 
1041     OrtCustomOp::KernelDestroy = [](void* op_kernel) {
1042       delete reinterpret_cast<Kernel*>(op_kernel);
1043     };
1044 
1045     SetShapeInfer<CustomOp>(0);
1046   }
1047 
1048   template <typename... Args>
1049   void SetCompute(CustomComputeFn<Args...>) {
1050     ParseArgs<Args...>(input_types_, output_types_);
1051     OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) {
1052       auto kernel = reinterpret_cast<Kernel*>(op_kernel);
1053       ArgPtrs args;
1054       auto t = CreateTuple<0, 0, Args...>(context, args, kernel->num_input_, kernel->num_output_, kernel->ep_);
1055       std::apply([kernel](Args const&... t_args) { kernel->custom_op_->Compute(t_args...); }, t);
1056     };
1057   }
1058 
1059   template <typename... Args>
1060   void SetCompute(CustomComputeFnReturnStatus<Args...>) {
1061     ParseArgs<Args...>(input_types_, output_types_);
1062     OrtCustomOp::KernelComputeV2 = [](void* op_kernel, OrtKernelContext* context) -> OrtStatusPtr {
1063       auto kernel = reinterpret_cast<Kernel*>(op_kernel);
1064       ArgPtrs args;
1065       auto t = CreateTuple<0, 0, Args...>(context, args, kernel->num_input_, kernel->num_output_, kernel->ep_);
1066       return std::apply([kernel](Args const&... t_args) { Status status = kernel->custom_op_->Compute(t_args...); return status.release(); }, t);
1067     };
1068   }
1069 
1070   template <typename C>
1071   decltype(&C::InferOutputShape) SetShapeInfer(decltype(&C::InferOutputShape)) {
1072     OrtCustomOp::InferOutputShapeFn = [](const OrtCustomOp*, OrtShapeInferContext* ort_ctx) -> OrtStatusPtr {
1073       ShapeInferContext ctx(&GetApi(), ort_ctx);
1074       return C::InferOutputShape(ctx);
1075     };
1076     return {};
1077   }
1078 
1079   template <typename C>
1080   void SetShapeInfer(...) {
1081     OrtCustomOp::InferOutputShapeFn = {};
1082   }
1083 };  // struct OrtLiteCustomStruct
1084 
1085 /////////////////////////// CreateLiteCustomOp ////////////////////////////
1086 
1087 template <typename... Args>
1088 OrtLiteCustomOp* CreateLiteCustomOp(const char* op_name,
1089                                     const char* execution_provider,
1090                                     void (*custom_compute_fn)(Args...),
1091                                     Status (*shape_infer_fn)(ShapeInferContext&) = {},
1092                                     int start_ver = 1,
1093                                     int end_ver = MAX_CUSTOM_OP_END_VER) {
1094   using LiteOp = OrtLiteCustomFunc<Args...>;
1095   return std::make_unique<LiteOp>(op_name, execution_provider, custom_compute_fn, shape_infer_fn, start_ver, end_ver).release();
1096 }
1097 
1098 template <typename... Args>
1099 OrtLiteCustomOp* CreateLiteCustomOp(const char* op_name,
1100                                     const char* execution_provider,
1101                                     Status (*custom_compute_fn_v2)(Args...),
1102                                     Status (*shape_infer_fn)(ShapeInferContext&) = {},
1103                                     int start_ver = 1,
1104                                     int end_ver = MAX_CUSTOM_OP_END_VER) {
1105   using LiteOp = OrtLiteCustomFunc<Args...>;
1106   return std::make_unique<LiteOp>(op_name, execution_provider, custom_compute_fn_v2, shape_infer_fn, start_ver, end_ver).release();
1107 }
1108 
1109 template <typename CustomOp>
1110 OrtLiteCustomOp* CreateLiteCustomOp(const char* op_name,
1111                                     const char* execution_provider,
1112                                     int start_ver = 1,
1113                                     int end_ver = MAX_CUSTOM_OP_END_VER) {
1114   using LiteOp = OrtLiteCustomStruct<CustomOp>;
1115   return std::make_unique<LiteOp>(op_name, execution_provider, start_ver, end_ver).release();
1116 }
1117 
1118 }  // namespace Custom
1119 }  // namespace Ort