#pragma once
#include "onnxruntime_cxx_api.h"
#include <cstring>  // strcmp
#include <optional>
#include <numeric>
#include <functional>
#include <unordered_set>

namespace Ort {
namespace Custom {

class ArgBase {
 public:
  ArgBase(OrtKernelContext* ctx,
          size_t indice,
          bool is_input) : ctx_(ctx), indice_(indice), is_input_(is_input) {}
  virtual ~ArgBase() {}

 protected:
  struct KernelContext ctx_;  // Ort::Custom::KernelContext, constructed from the raw OrtKernelContext*
  size_t indice_;
  bool is_input_;
};

using ArgPtr = std::unique_ptr<Custom::ArgBase>;
using ArgPtrs = std::vector<ArgPtr>;

class TensorBase : public ArgBase {
 public:
  TensorBase(OrtKernelContext* ctx,
             size_t indice,
             bool is_input) : ArgBase(ctx, indice, is_input) {}

  operator bool() const {
    return shape_.has_value();
  }

  const std::vector<int64_t>& Shape() const {
    if (!shape_.has_value()) {
      ORT_CXX_API_THROW("tensor shape is not yet initialized", OrtErrorCode::ORT_RUNTIME_EXCEPTION);
    }
    return shape_.value();
  }

  ONNXTensorElementDataType Type() const {
    return type_;
  }

  int64_t NumberOfElement() const {
    if (shape_.has_value()) {
      return std::accumulate(shape_->begin(), shape_->end(), 1LL, std::multiplies<int64_t>());
    } else {
      return 0;
    }
  }

  std::string Shape2Str() const {
    if (shape_.has_value()) {
      std::string shape_str;
      for (const auto& dim : *shape_) {
        shape_str.append(std::to_string(dim));
        shape_str.append(", ");
      }
      return shape_str;
    } else {
      return "empty";
    }
  }

  bool IsCpuTensor() const {
    return strcmp("Cpu", mem_type_) == 0;
  }

  virtual const void* DataRaw() const = 0;
  virtual size_t SizeInBytes() const = 0;

 protected:
  std::optional<std::vector<int64_t>> shape_;
  ONNXTensorElementDataType type_ = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
  const char* mem_type_ = "Cpu";
};

template <typename T>
struct Span {
  const T* data_ = {};
  size_t size_ = {};
  void Assign(const T* data, size_t size) {
    data_ = data;
    size_ = size;
  }
  size_t size() const { return size_; }
  T operator[](size_t indice) const {
    return data_[indice];
  }
  const T* data() const { return data_; }
};
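
// Note: Span and scalar arguments are materialized from raw tensor memory, so the
// CreateTuple helpers further below reject span/scalar arguments on any execution
// provider other than CPU.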

template <typename T>
class Tensor : public TensorBase {
 public:
  using TT = typename std::remove_reference<T>::type;
  Tensor(OrtKernelContext* ctx, size_t indice, bool is_input) : TensorBase(ctx, indice, is_input) {
    if (is_input_) {
      if (indice >= ctx_.GetInputCount()) {
        ORT_CXX_API_THROW("invalid indice for Ort::Custom::Tensor", OrtErrorCode::ORT_INVALID_ARGUMENT);
      }
      const_value_ = ctx_.GetInput(indice);
      auto type_shape_info = const_value_.GetTensorTypeAndShapeInfo();
      shape_ = type_shape_info.GetShape();
    }
  }
  const TT* Data() const {
    return reinterpret_cast<const TT*>(const_value_.GetTensorRawData());
  }
  // Allocates the output at indice_ with the given shape and caches the shape for
  // NumberOfElement() and SizeInBytes(); repeated calls return the same buffer.
  TT* Allocate(const std::vector<int64_t>& shape) {
    shape_ = shape;
    if (!data_) {
      data_ = ctx_.GetOutput(indice_, shape).template GetTensorMutableData<TT>();
    }
    return data_;
  }
  static TT GetT() { return (TT)0; }
  const Span<T>& AsSpan() {
    if (!shape_.has_value() || shape_->size() != 1) {
      ORT_CXX_API_THROW("invalid shape while trying to get a span out of Ort::Custom::Tensor",
                        OrtErrorCode::ORT_RUNTIME_EXCEPTION);
    }
    span_.Assign(Data(), static_cast<size_t>((*shape_)[0]));
    return span_;
  }
  const T& AsScalar() {
    if (!shape_.has_value() || shape_->size() != 1 || (*shape_)[0] != 1) {
      ORT_CXX_API_THROW("invalid shape while trying to get a scalar from Ort::Custom::Tensor",
                        OrtErrorCode::ORT_RUNTIME_EXCEPTION);
    }
    return *Data();
  }
  const void* DataRaw() const override {
    return reinterpret_cast<const void*>(Data());
  }

  size_t SizeInBytes() const override {
    return sizeof(TT) * static_cast<size_t>(NumberOfElement());
  }

 private:
  ConstValue const_value_;  // for input
  TT* data_{};              // for output
  Span<T> span_;
};
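
// As a rough usage sketch (names here are illustrative, not part of the API): a kernel
// consuming a float tensor and producing one of the same shape might look like this:
//
//   void Square(const Ort::Custom::Tensor<float>& in, Ort::Custom::Tensor<float>& out) {
//     const float* x = in.Data();
//     float* y = out.Allocate(in.Shape());  // allocates the output and returns its buffer
//     for (int64_t i = 0; i < in.NumberOfElement(); ++i) {
//       y[i] = x[i] * x[i];
//     }
//   }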

template <>
class Tensor<std::string> : public TensorBase {
 public:
  using strings = std::vector<std::string>;

  Tensor(OrtKernelContext* ctx, size_t indice, bool is_input) : TensorBase(ctx, indice, is_input) {
    if (is_input_) {
      if (indice >= ctx_.GetInputCount()) {
        ORT_CXX_API_THROW("invalid indice for Ort::Custom::Tensor", OrtErrorCode::ORT_INVALID_ARGUMENT);
      }
      auto const_value = ctx_.GetInput(indice);
      auto type_shape_info = const_value.GetTensorTypeAndShapeInfo();
      shape_ = type_shape_info.GetShape();
      auto num_chars = const_value.GetStringTensorDataLength();

      auto num_strings = static_cast<size_t>(NumberOfElement());
      if (num_strings) {
        // The strings are packed back to back in chars; offsets[i] marks where the ith one starts.
        std::vector<char> chars(num_chars + 1, '\0');
        std::vector<size_t> offsets(num_strings);
        const_value.GetStringTensorContent(static_cast<void*>(chars.data()), num_chars, offsets.data(), offsets.size());
        auto upper_bound = num_strings - 1;
        input_strings_.resize(num_strings);
        // Walk backwards so each string can be null-terminated in place (by writing '\0'
        // at the start of its already-copied successor) before it is copied out.
        for (size_t i = upper_bound;; --i) {
          if (i < upper_bound) {
            chars[offsets[i + 1]] = '\0';
          }
          input_strings_[i] = chars.data() + offsets[i];
          if (0 == i) {
            break;
          }
        }
      }
    }
  }
  const strings& Data() const {
    return input_strings_;
  }
  const void* DataRaw() const override {
    if (input_strings_.size() != 1) {
      ORT_CXX_API_THROW("DataRaw() only applies to string scalar", ORT_RUNTIME_EXCEPTION);
    }
    return reinterpret_cast<const void*>(input_strings_[0].c_str());
  }
  size_t SizeInBytes() const override {
    if (input_strings_.size() != 1) {
      ORT_CXX_API_THROW("SizeInBytes() only applies to string scalar", ORT_RUNTIME_EXCEPTION);
    }
    return input_strings_[0].size();
  }
  void SetStringOutput(const strings& ss, const std::vector<int64_t>& dims) {
    shape_ = dims;
    std::vector<const char*> raw;
    for (const auto& s : ss) {
      raw.push_back(s.data());
    }
    auto output = ctx_.GetOutput(indice_, dims.data(), dims.size());
    output.FillStringTensor(raw.data(), raw.size());
  }
  const Span<std::string>& AsSpan() {
    ORT_CXX_API_THROW("span for TensorT of string not implemented", OrtErrorCode::ORT_RUNTIME_EXCEPTION);
  }
  const std::string& AsScalar() {
    if (input_strings_.size() != 1) {
      ORT_CXX_API_THROW("invalid shape while trying to get a scalar string from Ort::Custom::Tensor",
                        OrtErrorCode::ORT_RUNTIME_EXCEPTION);
    }
    return input_strings_[0];
  }

 private:
  std::vector<std::string> input_strings_;  // for input
};
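
// A string-producing kernel hands a vector of std::string plus the output dims to
// SetStringOutput; a minimal hedged sketch (illustrative name):
//
//   void Echo(const Ort::Custom::Tensor<std::string>& in, Ort::Custom::Tensor<std::string>& out) {
//     out.SetStringOutput(in.Data(), in.Shape());
//   }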

template <>
class Tensor<std::string_view> : public TensorBase {
 public:
  using strings = std::vector<std::string>;
  using string_views = std::vector<std::string_view>;

  Tensor(OrtKernelContext* ctx, size_t indice, bool is_input) : TensorBase(ctx, indice, is_input) {
    if (is_input_) {
      if (indice >= ctx_.GetInputCount()) {
        ORT_CXX_API_THROW("invalid indice for Ort::Custom::Tensor", OrtErrorCode::ORT_INVALID_ARGUMENT);
      }
      auto const_value = ctx_.GetInput(indice);
      auto type_shape_info = const_value.GetTensorTypeAndShapeInfo();
      shape_ = type_shape_info.GetShape();
      auto num_chars = const_value.GetStringTensorDataLength();
      chars_.resize(num_chars + 1, '\0');
      auto num_strings = static_cast<size_t>(NumberOfElement());
      if (num_strings) {
        std::vector<size_t> offsets(num_strings);
        const_value.GetStringTensorContent(static_cast<void*>(chars_.data()), num_chars, offsets.data(), offsets.size());
        offsets.push_back(num_chars);  // sentinel, so the last view's length is offsets[i + 1] - offsets[i]
        for (size_t i = 0; i < num_strings; ++i) {
          input_string_views_.emplace_back(chars_.data() + offsets[i], offsets[i + 1] - offsets[i]);
        }
      }
    }
  }
  const string_views& Data() const {
    return input_string_views_;
  }
  const void* DataRaw() const override {
    if (input_string_views_.size() != 1) {
      ORT_CXX_API_THROW("DataRaw() only applies to string scalar", ORT_RUNTIME_EXCEPTION);
    }
    return reinterpret_cast<const void*>(input_string_views_[0].data());
  }
  size_t SizeInBytes() const override {
    if (input_string_views_.size() != 1) {
      ORT_CXX_API_THROW("SizeInBytes() only applies to string scalar", ORT_RUNTIME_EXCEPTION);
    }
    return input_string_views_[0].size();
  }
  void SetStringOutput(const strings& ss, const std::vector<int64_t>& dims) {
    shape_ = dims;
    std::vector<const char*> raw;
    for (const auto& s : ss) {
      raw.push_back(s.data());
    }
    auto output = ctx_.GetOutput(indice_, dims.data(), dims.size());
    output.FillStringTensor(raw.data(), raw.size());
  }
  const Span<std::string_view>& AsSpan() {
    ORT_CXX_API_THROW("span for TensorT of string view not implemented", OrtErrorCode::ORT_RUNTIME_EXCEPTION);
  }
  std::string_view AsScalar() {
    if (input_string_views_.size() != 1) {
      ORT_CXX_API_THROW("invalid shape while trying to get a scalar string view from Ort::Custom::Tensor",
                        OrtErrorCode::ORT_RUNTIME_EXCEPTION);
    }
    return input_string_views_[0];
  }

 private:
  std::vector<char> chars_;  // backing storage for the views
  std::vector<std::string_view> input_string_views_;
};

using TensorPtr = std::unique_ptr<Custom::TensorBase>;
using TensorPtrs = std::vector<TensorPtr>;

struct TensorArray : public ArgBase {
  TensorArray(OrtKernelContext* ctx,
              size_t start_indice,
              bool is_input) : ArgBase(ctx,
                                       start_indice,
                                       is_input) {
    if (is_input) {
      auto input_count = ctx_.GetInputCount();
      for (size_t ith_input = start_indice; ith_input < input_count; ++ith_input) {
        auto const_value = ctx_.GetInput(ith_input);
        auto type_shape_info = const_value.GetTensorTypeAndShapeInfo();
        auto type = type_shape_info.GetElementType();
        TensorPtr tensor;
        switch (type) {
          case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL:
            tensor = std::make_unique<Custom::Tensor<bool>>(ctx, ith_input, true);
            break;
          case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:
            tensor = std::make_unique<Custom::Tensor<float>>(ctx, ith_input, true);
            break;
          case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE:
            tensor = std::make_unique<Custom::Tensor<double>>(ctx, ith_input, true);
            break;
          case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8:
            tensor = std::make_unique<Custom::Tensor<uint8_t>>(ctx, ith_input, true);
            break;
          case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8:
            tensor = std::make_unique<Custom::Tensor<int8_t>>(ctx, ith_input, true);
            break;
          case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16:
            tensor = std::make_unique<Custom::Tensor<uint16_t>>(ctx, ith_input, true);
            break;
          case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16:
            tensor = std::make_unique<Custom::Tensor<int16_t>>(ctx, ith_input, true);
            break;
          case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32:
            tensor = std::make_unique<Custom::Tensor<uint32_t>>(ctx, ith_input, true);
            break;
          case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32:
            tensor = std::make_unique<Custom::Tensor<int32_t>>(ctx, ith_input, true);
            break;
          case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64:
            tensor = std::make_unique<Custom::Tensor<uint64_t>>(ctx, ith_input, true);
            break;
          case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64:
            tensor = std::make_unique<Custom::Tensor<int64_t>>(ctx, ith_input, true);
            break;
          case ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING:
            tensor = std::make_unique<Custom::Tensor<std::string>>(ctx, ith_input, true);
            break;
          default:
            ORT_CXX_API_THROW("unknown input type", ORT_RUNTIME_EXCEPTION);
            break;
        }
        tensors_.emplace_back(tensor.release());
      }
    }
  }
  template <typename T>
  T* AllocateOutput(size_t ith_output, const std::vector<int64_t>& shape) {
    // ith_output is relative to this tensor array;
    // indice_ + ith_output is the output indice relative to the kernel context.
    auto tensor = std::make_unique<Tensor<T>>(ctx_.GetOrtKernelContext(), indice_ + ith_output, false);
    auto raw_output = tensor.get()->Allocate(shape);
    tensors_.emplace_back(tensor.release());
    return raw_output;
  }
  Tensor<std::string>& AllocateStringTensor(size_t ith_output) {
    // ith_output is relative to this tensor array;
    // indice_ + ith_output is the output indice relative to the kernel context.
    auto tensor = std::make_unique<Tensor<std::string>>(ctx_.GetOrtKernelContext(), indice_ + ith_output, false);
    Tensor<std::string>& output = *tensor;
    tensors_.emplace_back(tensor.release());
    return output;
  }
  size_t Size() const {
    return tensors_.size();
  }
  const TensorPtr& operator[](size_t ith_input) const {
    // ith_input is relative to this tensor array
    return tensors_.at(ith_input);
  }

 private:
  TensorPtrs tensors_;
};

using Variadic = TensorArray;
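
// A variadic kernel consumes all remaining inputs through a single TensorArray and may
// allocate any number of outputs through another; a hedged sketch (illustrative name):
//
//   void Sum(const Ort::Custom::Variadic& inputs, Ort::Custom::Variadic& outputs) {
//     for (size_t i = 0; i < inputs.Size(); ++i) {
//       const auto& tensor = inputs[i];  // TensorPtr, i.e. unique_ptr<TensorBase>
//       // ... inspect tensor->Shape(), tensor->DataRaw(), tensor->Type(), etc.
//     }
//     float* out = outputs.AllocateOutput<float>(0, {1});
//     out[0] = 0.f;
//   }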

// OrtLiteCustomOp bridges a plain C++ function or struct to the OrtCustomOp ABI:
// input/output types are inferred from the compute signature by ParseArgs, and the
// argument tuple handed to the user callable is materialized by CreateTuple.
struct OrtLiteCustomOp : public OrtCustomOp {
  using ConstOptionalFloatTensor = std::optional<const Custom::Tensor<float>&>;
  using OptionalFloatTensor = std::optional<Custom::Tensor<float>>;

  template <size_t ith_input, size_t ith_output, typename... Ts>
  static typename std::enable_if<sizeof...(Ts) == 0, std::tuple<>>::type
  CreateTuple(OrtKernelContext*, ArgPtrs&, size_t, size_t, const std::string&) {
    return std::make_tuple();
  }

  template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
  static typename std::enable_if<std::is_same<T, OrtKernelContext*>::value, std::tuple<T, Ts...>>::type
  CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) {
    std::tuple<T> current = std::tuple<OrtKernelContext*>{context};
    auto next = CreateTuple<ith_input, ith_output, Ts...>(context, args, num_input, num_output, ep);
    return std::tuple_cat(current, next);
  }

  template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
  static typename std::enable_if<std::is_same<T, OrtKernelContext&>::value, std::tuple<T, Ts...>>::type
  CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) {
    std::tuple<T> current = std::tuple<OrtKernelContext&>{*context};
    auto next = CreateTuple<ith_input, ith_output, Ts...>(context, args, num_input, num_output, ep);
    return std::tuple_cat(current, next);
  }

#ifdef ORT_CUDA_CTX
  template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
  static typename std::enable_if<std::is_same<T, const CudaContext&>::value, std::tuple<T, Ts...>>::type
  CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) {
    thread_local CudaContext cuda_context;
    cuda_context.Init(*context);
    std::tuple<T> current = std::tuple<const CudaContext&>{cuda_context};
    auto next = CreateTuple<ith_input, ith_output, Ts...>(context, args, num_input, num_output, ep);
    return std::tuple_cat(current, next);
  }
#endif

#ifdef ORT_ROCM_CTX
  template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
  static typename std::enable_if<std::is_same<T, const RocmContext&>::value, std::tuple<T, Ts...>>::type
  CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) {
    thread_local RocmContext rocm_context;
    rocm_context.Init(*context);
    std::tuple<T> current = std::tuple<const RocmContext&>{rocm_context};
    auto next = CreateTuple<ith_input, ith_output, Ts...>(context, args, num_input, num_output, ep);
    return std::tuple_cat(current, next);
  }
#endif

  template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
  static typename std::enable_if<std::is_same<T, const TensorArray*>::value, std::tuple<T, Ts...>>::type
  CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) {
    args.push_back(std::make_unique<TensorArray>(context, ith_input, true));
    std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(args.back().get())};
    auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep);
    return std::tuple_cat(current, next);
  }

  template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
  static typename std::enable_if<std::is_same<T, const TensorArray&>::value, std::tuple<T, Ts...>>::type
  CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) {
    args.push_back(std::make_unique<TensorArray>(context, ith_input, true));
    std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(*args.back().get())};
    auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep);
    return std::tuple_cat(current, next);
  }

  template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
  static typename std::enable_if<std::is_same<T, TensorArray*>::value, std::tuple<T, Ts...>>::type
  CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) {
    args.push_back(std::make_unique<TensorArray>(context, ith_output, false));
    std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(args.back().get())};
    auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(context, args, num_input, num_output, ep);
    return std::tuple_cat(current, next);
  }

  template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
  static typename std::enable_if<std::is_same<T, TensorArray&>::value, std::tuple<T, Ts...>>::type
  CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) {
    args.push_back(std::make_unique<TensorArray>(context, ith_output, false));
    std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(*args.back().get())};
    auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(context, args, num_input, num_output, ep);
    return std::tuple_cat(current, next);
  }

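// The CREATE_TUPLE_INPUT/CREATE_TUPLE_OUTPUT macros below stamp out CreateTuple
// specializations per element type, covering pointer, reference, optional, span, and
// scalar argument forms; span and scalar forms are restricted to the CPU EP.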
#define CREATE_TUPLE_INPUT(data_type) \
  template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
  static typename std::enable_if<std::is_same<T, const Custom::Tensor<data_type>*>::value, std::tuple<T, Ts...>>::type \
  CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \
    args.push_back(std::make_unique<Custom::Tensor<data_type>>(context, ith_input, true)); \
    std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(args.back().get())}; \
    auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep); \
    return std::tuple_cat(current, next); \
  } \
  template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
  static typename std::enable_if<std::is_same<T, const Custom::Tensor<data_type>&>::value, std::tuple<T, Ts...>>::type \
  CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \
    args.push_back(std::make_unique<Custom::Tensor<data_type>>(context, ith_input, true)); \
    std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(*args.back().get())}; \
    auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep); \
    return std::tuple_cat(current, next); \
  } \
  template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
  static typename std::enable_if<std::is_same<T, std::optional<const Custom::Tensor<data_type>*>>::value, std::tuple<T, Ts...>>::type \
  CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \
    if (ith_input < num_input) { \
      args.push_back(std::make_unique<Custom::Tensor<data_type>>(context, ith_input, true)); \
      std::tuple<T> current = std::tuple<T>{reinterpret_cast<Custom::Tensor<data_type>*>(args.back().get())}; \
      auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep); \
      return std::tuple_cat(current, next); \
    } else { \
      std::tuple<T> current = std::tuple<T>{}; \
      auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep); \
      return std::tuple_cat(current, next); \
    } \
  } \
  template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
  static typename std::enable_if<std::is_same<T, const Custom::Span<data_type>*>::value, std::tuple<T, Ts...>>::type \
  CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \
    if ("CPUExecutionProvider" != ep) { \
      ORT_CXX_API_THROW("span input could only be applied to CPU EP", OrtErrorCode::ORT_RUNTIME_EXCEPTION); \
    } \
    args.push_back(std::make_unique<Custom::Tensor<data_type>>(context, ith_input, true)); \
    std::tuple<T> current = std::tuple<T>{&reinterpret_cast<Custom::Tensor<data_type>*>(args.back().get())->AsSpan()}; \
    auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep); \
    return std::tuple_cat(current, next); \
  } \
  template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
  static typename std::enable_if<std::is_same<T, const Custom::Span<data_type>&>::value, std::tuple<T, Ts...>>::type \
  CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \
    if ("CPUExecutionProvider" != ep) { \
      ORT_CXX_API_THROW("span input could only be applied to CPU EP", OrtErrorCode::ORT_RUNTIME_EXCEPTION); \
    } \
    args.push_back(std::make_unique<Custom::Tensor<data_type>>(context, ith_input, true)); \
    std::tuple<T> current = std::tuple<T>{reinterpret_cast<Custom::Tensor<data_type>*>(args.back().get())->AsSpan()}; \
    auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep); \
    return std::tuple_cat(current, next); \
  } \
  template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
  static typename std::enable_if<std::is_same<T, std::optional<const Custom::Span<data_type>*>>::value, std::tuple<T, Ts...>>::type \
  CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \
    if (ith_input < num_input) { \
      if ("CPUExecutionProvider" != ep) { \
        ORT_CXX_API_THROW("span input could only be applied to CPU EP", OrtErrorCode::ORT_RUNTIME_EXCEPTION); \
      } \
      args.push_back(std::make_unique<Custom::Tensor<data_type>>(context, ith_input, true)); \
      std::tuple<T> current = std::tuple<T>{&reinterpret_cast<Custom::Tensor<data_type>*>(args.back().get())->AsSpan()}; \
      auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep); \
      return std::tuple_cat(current, next); \
    } else { \
      std::tuple<T> current = std::tuple<T>{}; \
      auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep); \
      return std::tuple_cat(current, next); \
    } \
  } \
  template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
  static typename std::enable_if<std::is_same<T, data_type>::value, std::tuple<T, Ts...>>::type \
  CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \
    if ("CPUExecutionProvider" != ep) { \
      ORT_CXX_API_THROW("scalar input could only be applied to CPU EP", OrtErrorCode::ORT_RUNTIME_EXCEPTION); \
    } \
    args.push_back(std::make_unique<Custom::Tensor<data_type>>(context, ith_input, true)); \
    std::tuple<T> current = std::tuple<T>{reinterpret_cast<Custom::Tensor<data_type>*>(args.back().get())->AsScalar()}; \
    auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep); \
    return std::tuple_cat(current, next); \
  } \
  template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
  static typename std::enable_if<std::is_same<T, std::optional<data_type>>::value, std::tuple<T, Ts...>>::type \
  CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \
    if (ith_input < num_input) { \
      if ("CPUExecutionProvider" != ep) { \
        ORT_CXX_API_THROW("scalar input could only be applied to CPU EP", OrtErrorCode::ORT_RUNTIME_EXCEPTION); \
      } \
      args.push_back(std::make_unique<Custom::Tensor<data_type>>(context, ith_input, true)); \
      std::tuple<T> current = std::tuple<T>{reinterpret_cast<Custom::Tensor<data_type>*>(args.back().get())->AsScalar()}; \
      auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep); \
      return std::tuple_cat(current, next); \
    } else { \
      std::tuple<T> current = std::tuple<T>{}; \
      auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, args, num_input, num_output, ep); \
      return std::tuple_cat(current, next); \
    } \
  }
#define CREATE_TUPLE_OUTPUT(data_type) \
  template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
  static typename std::enable_if<std::is_same<T, Custom::Tensor<data_type>*>::value, std::tuple<T, Ts...>>::type \
  CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \
    args.push_back(std::make_unique<Custom::Tensor<data_type>>(context, ith_output, false)); \
    std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(args.back().get())}; \
    auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(context, args, num_input, num_output, ep); \
    return std::tuple_cat(current, next); \
  } \
  template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
  static typename std::enable_if<std::is_same<T, Custom::Tensor<data_type>&>::value, std::tuple<T, Ts...>>::type \
  CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \
    args.push_back(std::make_unique<Custom::Tensor<data_type>>(context, ith_output, false)); \
    std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(*args.back().get())}; \
    auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(context, args, num_input, num_output, ep); \
    return std::tuple_cat(current, next); \
  } \
  template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
  static typename std::enable_if<std::is_same<T, std::optional<Custom::Tensor<data_type>*>>::value, std::tuple<T, Ts...>>::type \
  CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \
    if (ith_output < num_output) { \
      args.push_back(std::make_unique<Custom::Tensor<data_type>>(context, ith_output, false)); \
      std::tuple<T> current = std::tuple<T>{reinterpret_cast<Custom::Tensor<data_type>*>(args.back().get())}; \
      auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(context, args, num_input, num_output, ep); \
      return std::tuple_cat(current, next); \
    } else { \
      std::tuple<T> current = std::tuple<T>{}; \
      auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(context, args, num_input, num_output, ep); \
      return std::tuple_cat(current, next); \
    } \
  }
#define CREATE_TUPLE(data_type) \
  CREATE_TUPLE_INPUT(data_type) \
  CREATE_TUPLE_OUTPUT(data_type)

  CREATE_TUPLE(bool)
  CREATE_TUPLE(float)
  CREATE_TUPLE(Ort::Float16_t)
  CREATE_TUPLE(Ort::BFloat16_t)
  CREATE_TUPLE(double)
  CREATE_TUPLE(int8_t)
  CREATE_TUPLE(int16_t)
  CREATE_TUPLE(int32_t)
  CREATE_TUPLE(int64_t)
  CREATE_TUPLE(uint8_t)
  CREATE_TUPLE(uint16_t)
  CREATE_TUPLE(uint32_t)
  CREATE_TUPLE(uint64_t)
  CREATE_TUPLE(std::string)
  CREATE_TUPLE_INPUT(std::string_view)
  CREATE_TUPLE(Ort::Float8E4M3FN_t)
  CREATE_TUPLE(Ort::Float8E4M3FNUZ_t)
  CREATE_TUPLE(Ort::Float8E5M2_t)
  CREATE_TUPLE(Ort::Float8E5M2FNUZ_t)

  template <typename... Ts>
  static typename std::enable_if<0 == sizeof...(Ts)>::type
  ParseArgs(std::vector<ONNXTensorElementDataType>&, std::vector<ONNXTensorElementDataType>&) {
  }

  template <typename T, typename... Ts>
  static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, OrtKernelContext*>::value>::type
  ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
    ParseArgs<Ts...>(input_types, output_types);
  }

  template <typename T, typename... Ts>
  static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, OrtKernelContext&>::value>::type
  ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
    ParseArgs<Ts...>(input_types, output_types);
  }

#ifdef ORT_CUDA_CTX
  template <typename T, typename... Ts>
  static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, const CudaContext&>::value>::type
  ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
    ParseArgs<Ts...>(input_types, output_types);
  }
#endif

#ifdef ORT_ROCM_CTX
  template <typename T, typename... Ts>
  static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, const RocmContext&>::value>::type
  ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
    ParseArgs<Ts...>(input_types, output_types);
  }
#endif

  template <typename T, typename... Ts>
  static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, const TensorArray&>::value>::type
  ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
    input_types.push_back(ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED);
    ParseArgs<Ts...>(input_types, output_types);
  }

  template <typename T, typename... Ts>
  static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, const TensorArray*>::value>::type
  ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
    input_types.push_back(ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED);
    ParseArgs<Ts...>(input_types, output_types);
  }

  template <typename T, typename... Ts>
  static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, TensorArray&>::value>::type
  ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
    output_types.push_back(ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED);
    ParseArgs<Ts...>(input_types, output_types);
  }

  template <typename T, typename... Ts>
  static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, TensorArray*>::value>::type
  ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
    output_types.push_back(ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED);
    ParseArgs<Ts...>(input_types, output_types);
  }

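// The PARSE_INPUT/PARSE_OUTPUT macros below stamp out ParseArgs specializations per
// element type, recording each argument's ONNX element type in input_types_ or
// output_types_ so the op's schema can be inferred from the compute signature.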
#define PARSE_INPUT_BASE(pack_type, onnx_type) \
  template <typename T, typename... Ts> \
  static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, pack_type>::value>::type \
  ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) { \
    input_types.push_back(onnx_type); \
    ParseArgs<Ts...>(input_types, output_types); \
  } \
  template <typename T, typename... Ts> \
  static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, const std::optional<pack_type>>::value>::type \
  ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) { \
    input_types.push_back(onnx_type); \
    ParseArgs<Ts...>(input_types, output_types); \
  } \
  template <typename T, typename... Ts> \
  static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, std::optional<pack_type>>::value>::type \
  ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) { \
    input_types.push_back(onnx_type); \
    ParseArgs<Ts...>(input_types, output_types); \
  }

#define PARSE_INPUT(data_type, onnx_type) \
  PARSE_INPUT_BASE(const Custom::Tensor<data_type>*, onnx_type) \
  PARSE_INPUT_BASE(const Custom::Tensor<data_type>&, onnx_type) \
  PARSE_INPUT_BASE(const Custom::Span<data_type>*, onnx_type) \
  PARSE_INPUT_BASE(const Custom::Span<data_type>&, onnx_type) \
  PARSE_INPUT_BASE(data_type, onnx_type)

#define PARSE_OUTPUT(data_type, onnx_type) \
  template <typename T, typename... Ts> \
  static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, Custom::Tensor<data_type>*>::value>::type \
  ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) { \
    output_types.push_back(onnx_type); \
    ParseArgs<Ts...>(input_types, output_types); \
  } \
  template <typename T, typename... Ts> \
  static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, Custom::Tensor<data_type>&>::value>::type \
  ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) { \
    output_types.push_back(onnx_type); \
    ParseArgs<Ts...>(input_types, output_types); \
  } \
  template <typename T, typename... Ts> \
  static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, std::optional<Custom::Tensor<data_type>*>>::value>::type \
  ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) { \
    output_types.push_back(onnx_type); \
    ParseArgs<Ts...>(input_types, output_types); \
  }

#define PARSE_ARGS(data_type, onnx_type) \
  PARSE_INPUT(data_type, onnx_type) \
  PARSE_OUTPUT(data_type, onnx_type)

  PARSE_ARGS(bool, ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL)
  PARSE_ARGS(float, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT)
  PARSE_ARGS(Ort::Float16_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16)
  PARSE_ARGS(Ort::BFloat16_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16)
  PARSE_ARGS(double, ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE)
  PARSE_ARGS(int8_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8)
  PARSE_ARGS(int16_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16)
  PARSE_ARGS(int32_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32)
  PARSE_ARGS(int64_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64)
  PARSE_ARGS(uint8_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8)
  PARSE_ARGS(uint16_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16)
  PARSE_ARGS(uint32_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32)
  PARSE_ARGS(uint64_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64)
  PARSE_ARGS(std::string, ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING)
  PARSE_ARGS(std::string_view, ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING)
  PARSE_ARGS(Ort::Float8E4M3FN_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN)
  PARSE_ARGS(Ort::Float8E4M3FNUZ_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FNUZ)
  PARSE_ARGS(Ort::Float8E5M2_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2)
  PARSE_ARGS(Ort::Float8E5M2FNUZ_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2FNUZ)

  OrtLiteCustomOp(const char* op_name,
                  const char* execution_provider,
                  ShapeInferFn shape_infer_fn,
                  int start_ver = 1,
                  int end_ver = MAX_CUSTOM_OP_END_VER) : op_name_(op_name),
                                                         execution_provider_(execution_provider),
                                                         shape_infer_fn_(shape_infer_fn),
                                                         start_ver_(start_ver),
                                                         end_ver_(end_ver) {
    OrtCustomOp::version = ORT_API_VERSION;

    OrtCustomOp::GetName = [](const OrtCustomOp* op) { return static_cast<const OrtLiteCustomOp*>(op)->op_name_.c_str(); };
    OrtCustomOp::GetExecutionProviderType = [](const OrtCustomOp* op) { return static_cast<const OrtLiteCustomOp*>(op)->execution_provider_.c_str(); };
    OrtCustomOp::GetInputMemoryType = [](const OrtCustomOp*, size_t) { return OrtMemTypeDefault; };

    OrtCustomOp::GetInputTypeCount = [](const OrtCustomOp* op) {
      auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
      return self->input_types_.size();
    };

    OrtCustomOp::GetInputType = [](const OrtCustomOp* op, size_t indice) {
      auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
      return self->input_types_[indice];
    };

    OrtCustomOp::GetOutputTypeCount = [](const OrtCustomOp* op) {
      auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
      return self->output_types_.size();
    };

    OrtCustomOp::GetOutputType = [](const OrtCustomOp* op, size_t indice) {
      auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
      return self->output_types_[indice];
    };

    // A type recorded as UNDEFINED marks a variadic (TensorArray) argument;
    // everything else is reported as optional.
    OrtCustomOp::GetInputCharacteristic = [](const OrtCustomOp* op, size_t indice) {
      auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
      return self->input_types_[indice] == ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED ? INPUT_OUTPUT_VARIADIC : INPUT_OUTPUT_OPTIONAL;
    };

    OrtCustomOp::GetOutputCharacteristic = [](const OrtCustomOp* op, size_t indice) {
      auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
      return self->output_types_[indice] == ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED ? INPUT_OUTPUT_VARIADIC : INPUT_OUTPUT_OPTIONAL;
    };

    OrtCustomOp::GetVariadicInputMinArity = [](const OrtCustomOp*) {
      return 1;
    };

    OrtCustomOp::GetVariadicInputHomogeneity = [](const OrtCustomOp*) {
      return 0;
    };

    OrtCustomOp::GetVariadicOutputMinArity = [](const OrtCustomOp*) {
      return 1;
    };

    OrtCustomOp::GetVariadicOutputHomogeneity = [](const OrtCustomOp*) {
      return 0;
    };

    OrtCustomOp::CreateKernelV2 = {};
    OrtCustomOp::KernelComputeV2 = {};
    OrtCustomOp::KernelCompute = {};

    OrtCustomOp::InferOutputShapeFn = {};

    OrtCustomOp::GetStartVersion = [](const OrtCustomOp* op) {
      auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
      return self->start_ver_;
    };

    OrtCustomOp::GetEndVersion = [](const OrtCustomOp* op) {
      auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
      return self->end_ver_;
    };

    OrtCustomOp::GetMayInplace = {};
    OrtCustomOp::ReleaseMayInplace = {};
    OrtCustomOp::GetAliasMap = {};
    OrtCustomOp::ReleaseAliasMap = {};
  }

  const std::string op_name_;
  const std::string execution_provider_;

  std::vector<ONNXTensorElementDataType> input_types_;
  std::vector<ONNXTensorElementDataType> output_types_;

  ShapeInferFn shape_infer_fn_ = {};

  int start_ver_ = 1;
  int end_ver_ = MAX_CUSTOM_OP_END_VER;

  void* compute_fn_ = {};
  void* compute_fn_return_status_ = {};
};
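
// OrtLiteCustomFunc implements "function as op": a free function becomes the kernel and
// its signature determines the schema. A hedged usage sketch (names and the "v2" domain
// are illustrative, not prescribed):
//
//   void Filter(const Ort::Custom::Tensor<float>& floats_in,
//               Ort::Custom::Tensor<float>& floats_out) { /* ... */ }
//
//   Ort::CustomOpDomain v2_domain{"v2"};
//   std::unique_ptr<Ort::Custom::OrtLiteCustomOp> fil_op_ptr{
//       Ort::Custom::CreateLiteCustomOp("Filter", "CPUExecutionProvider", Filter)};
//   v2_domain.Add(fil_op_ptr.get());
//   session_options.Add(v2_domain);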
template <typename... Args>
struct OrtLiteCustomFunc : public OrtLiteCustomOp {
  using ComputeFn = void (*)(Args...);
  using ComputeFnReturnStatus = Status (*)(Args...);
  using MyType = OrtLiteCustomFunc<Args...>;

  struct Kernel {
    size_t num_input_{};
    size_t num_output_{};
    ComputeFn compute_fn_{};
    ComputeFnReturnStatus compute_fn_return_status_{};
    std::string ep_{};
  };

  OrtLiteCustomFunc(const char* op_name,
                    const char* execution_provider,
                    ComputeFn compute_fn,
                    ShapeInferFn shape_infer_fn = {},
                    int start_ver = 1,
                    int end_ver = MAX_CUSTOM_OP_END_VER) : OrtLiteCustomOp(op_name, execution_provider, shape_infer_fn, start_ver, end_ver) {
    compute_fn_ = reinterpret_cast<void*>(compute_fn);
    ParseArgs<Args...>(input_types_, output_types_);

    OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) {
      auto kernel = reinterpret_cast<Kernel*>(op_kernel);
      std::vector<ArgPtr> args;
      auto t = CreateTuple<0, 0, Args...>(context, args, kernel->num_input_, kernel->num_output_, kernel->ep_);
      std::apply([kernel](Args const&... t_args) { kernel->compute_fn_(t_args...); }, t);
    };

    OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* ort_api, const OrtKernelInfo* info) {
      auto kernel = std::make_unique<Kernel>();
      auto me = static_cast<const MyType*>(this_);
      kernel->compute_fn_ = reinterpret_cast<ComputeFn>(me->compute_fn_);
      Ort::ThrowOnError(ort_api->KernelInfo_GetInputCount(info, &kernel->num_input_));
      Ort::ThrowOnError(ort_api->KernelInfo_GetOutputCount(info, &kernel->num_output_));
      kernel->ep_ = me->execution_provider_;
      return reinterpret_cast<void*>(kernel.release());
    };

    OrtCustomOp::KernelDestroy = [](void* op_kernel) {
      delete reinterpret_cast<Kernel*>(op_kernel);
    };

    if (shape_infer_fn_) {
      OrtCustomOp::InferOutputShapeFn = [](const OrtCustomOp* op, OrtShapeInferContext* ort_ctx) -> OrtStatusPtr {
        auto shape_info_fn = static_cast<const MyType*>(op)->shape_infer_fn_;
        ShapeInferContext ctx(&GetApi(), ort_ctx);
        return shape_info_fn(ctx);
      };
    }
  }

  OrtLiteCustomFunc(const char* op_name,
                    const char* execution_provider,
                    ComputeFnReturnStatus compute_fn_return_status,
                    ShapeInferFn shape_infer_fn = {},
                    int start_ver = 1,
                    int end_ver = MAX_CUSTOM_OP_END_VER) : OrtLiteCustomOp(op_name, execution_provider, shape_infer_fn, start_ver, end_ver) {
    compute_fn_return_status_ = reinterpret_cast<void*>(compute_fn_return_status);
    ParseArgs<Args...>(input_types_, output_types_);

    OrtCustomOp::KernelComputeV2 = [](void* op_kernel, OrtKernelContext* context) -> OrtStatusPtr {
      auto kernel = reinterpret_cast<Kernel*>(op_kernel);
      std::vector<ArgPtr> args;
      auto t = CreateTuple<0, 0, Args...>(context, args, kernel->num_input_, kernel->num_output_, kernel->ep_);
      return std::apply([kernel](Args const&... t_args) { Status status = kernel->compute_fn_return_status_(t_args...); return status.release(); }, t);
    };

    OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* ort_api, const OrtKernelInfo* info) {
      auto kernel = std::make_unique<Kernel>();
      auto me = static_cast<const MyType*>(this_);
      kernel->compute_fn_return_status_ = reinterpret_cast<ComputeFnReturnStatus>(me->compute_fn_return_status_);
      Ort::ThrowOnError(ort_api->KernelInfo_GetInputCount(info, &kernel->num_input_));
      Ort::ThrowOnError(ort_api->KernelInfo_GetOutputCount(info, &kernel->num_output_));
      kernel->ep_ = me->execution_provider_;
      return reinterpret_cast<void*>(kernel.release());
    };

    OrtCustomOp::KernelDestroy = [](void* op_kernel) {
      delete reinterpret_cast<Kernel*>(op_kernel);
    };

    if (shape_infer_fn_) {
      OrtCustomOp::InferOutputShapeFn = [](const OrtCustomOp* op, OrtShapeInferContext* ort_ctx) -> OrtStatusPtr {
        auto shape_info_fn = static_cast<const MyType*>(op)->shape_infer_fn_;
        ShapeInferContext ctx(&GetApi(), ort_ctx);
        return shape_info_fn(ctx);
      };
    }
  }
};
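
// OrtLiteCustomStruct implements "struct as op": the struct constructs from
// (const OrtApi*, const OrtKernelInfo*) and exposes a Compute member whose signature
// determines the schema. A hedged sketch (names are illustrative):
//
//   struct Merge {
//     Merge(const OrtApi* ort_api, const OrtKernelInfo* info) { /* read attributes */ }
//     void Compute(const Ort::Custom::Tensor<std::string_view>& strings_in,
//                  std::string_view string_in,
//                  Ort::Custom::Tensor<std::string>* strings_out) { /* ... */ }
//   };
//
//   Ort::CustomOpDomain v2_domain{"v2"};
//   std::unique_ptr<Ort::Custom::OrtLiteCustomOp> mrg_op_ptr{
//       Ort::Custom::CreateLiteCustomOp<Merge>("Merge", "CPUExecutionProvider")};
//   v2_domain.Add(mrg_op_ptr.get());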
template <typename CustomOp>
struct OrtLiteCustomStruct : public OrtLiteCustomOp {
  template <typename... Args>
  using CustomComputeFn = void (CustomOp::*)(Args...);

  template <typename... Args>
  using CustomComputeFnReturnStatus = Status (CustomOp::*)(Args...);

  using MyType = OrtLiteCustomStruct<CustomOp>;

  struct Kernel {
    size_t num_input_{};
    size_t num_output_{};
    std::unique_ptr<CustomOp> custom_op_;
    std::string ep_{};
  };

  OrtLiteCustomStruct(const char* op_name,
                      const char* execution_provider,
                      int start_ver = 1,
                      int end_ver = MAX_CUSTOM_OP_END_VER) : OrtLiteCustomOp(op_name, execution_provider, {}, start_ver, end_ver) {
    SetCompute(&CustomOp::Compute);

    OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* ort_api, const OrtKernelInfo* info) {
      auto kernel = std::make_unique<Kernel>();
      Ort::ThrowOnError(ort_api->KernelInfo_GetInputCount(info, &kernel->num_input_));
      Ort::ThrowOnError(ort_api->KernelInfo_GetOutputCount(info, &kernel->num_output_));
      kernel->custom_op_ = std::make_unique<CustomOp>(ort_api, info);
      auto self = static_cast<const OrtLiteCustomStruct*>(this_);
      kernel->ep_ = self->execution_provider_;
      return reinterpret_cast<void*>(kernel.release());
    };

    OrtCustomOp::KernelDestroy = [](void* op_kernel) {
      delete reinterpret_cast<Kernel*>(op_kernel);
    };

    // The literal 0 converts to a member-function pointer when CustomOp defines
    // InferOutputShape, selecting the overload that installs InferOutputShapeFn;
    // otherwise the variadic fallback below is chosen.
    SetShapeInfer<CustomOp>(0);
  }

  template <typename... Args>
  void SetCompute(CustomComputeFn<Args...>) {
    ParseArgs<Args...>(input_types_, output_types_);
    OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) {
      auto kernel = reinterpret_cast<Kernel*>(op_kernel);
      ArgPtrs args;
      auto t = CreateTuple<0, 0, Args...>(context, args, kernel->num_input_, kernel->num_output_, kernel->ep_);
      std::apply([kernel](Args const&... t_args) { kernel->custom_op_->Compute(t_args...); }, t);
    };
  }

  template <typename... Args>
  void SetCompute(CustomComputeFnReturnStatus<Args...>) {
    ParseArgs<Args...>(input_types_, output_types_);
    OrtCustomOp::KernelComputeV2 = [](void* op_kernel, OrtKernelContext* context) -> OrtStatusPtr {
      auto kernel = reinterpret_cast<Kernel*>(op_kernel);
      ArgPtrs args;
      auto t = CreateTuple<0, 0, Args...>(context, args, kernel->num_input_, kernel->num_output_, kernel->ep_);
      return std::apply([kernel](Args const&... t_args) { Status status = kernel->custom_op_->Compute(t_args...); return status.release(); }, t);
    };
  }

  template <typename C>
  decltype(&C::InferOutputShape) SetShapeInfer(decltype(&C::InferOutputShape)) {
    OrtCustomOp::InferOutputShapeFn = [](const OrtCustomOp*, OrtShapeInferContext* ort_ctx) -> OrtStatusPtr {
      ShapeInferContext ctx(&GetApi(), ort_ctx);
      return C::InferOutputShape(ctx);
    };
    return {};
  }

  template <typename C>
  void SetShapeInfer(...) {
    OrtCustomOp::InferOutputShapeFn = {};
  }
};

template <typename... Args>
OrtLiteCustomOp* CreateLiteCustomOp(const char* op_name,
                                    const char* execution_provider,
                                    void (*custom_compute_fn)(Args...),
                                    Status (*shape_infer_fn)(ShapeInferContext&) = {},
                                    int start_ver = 1,
                                    int end_ver = MAX_CUSTOM_OP_END_VER) {
  using LiteOp = OrtLiteCustomFunc<Args...>;
  return std::make_unique<LiteOp>(op_name, execution_provider, custom_compute_fn, shape_infer_fn, start_ver, end_ver).release();
}

template <typename... Args>
OrtLiteCustomOp* CreateLiteCustomOp(const char* op_name,
                                    const char* execution_provider,
                                    Status (*custom_compute_fn_v2)(Args...),
                                    Status (*shape_infer_fn)(ShapeInferContext&) = {},
                                    int start_ver = 1,
                                    int end_ver = MAX_CUSTOM_OP_END_VER) {
  using LiteOp = OrtLiteCustomFunc<Args...>;
  return std::make_unique<LiteOp>(op_name, execution_provider, custom_compute_fn_v2, shape_infer_fn, start_ver, end_ver).release();
}

template <typename CustomOp>
OrtLiteCustomOp* CreateLiteCustomOp(const char* op_name,
                                    const char* execution_provider,
                                    int start_ver = 1,
                                    int end_ver = MAX_CUSTOM_OP_END_VER) {
  using LiteOp = OrtLiteCustomStruct<CustomOp>;
  return std::make_unique<LiteOp>(op_name, execution_provider, start_ver, end_ver).release();
}
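
// Note: these factories release() the owning unique_ptr, so the returned pointer is
// caller-owned; keep it alive (e.g. in a std::unique_ptr, as in the sketches above)
// for as long as any session registered with it may run.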

}  // namespace Custom
}  // namespace Ort