|
|
|||
File indexing completed on 2025-12-16 10:20:34
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

// Summary: The Ort C++ API is a header-only wrapper around the Ort C API.
//
// The C++ API simplifies usage by returning values directly instead of error codes, throwing exceptions on errors
// and automatically releasing resources in the destructors. The primary purpose of the C++ API is exception safety, so
// all the resources follow RAII and do not leak memory.
//
// Each of the C++ wrapper classes holds only a pointer to the C internal object. Treat them like smart pointers.
// To create an empty object, pass 'nullptr' to the constructor (for example, Env e{nullptr};). However, you can't use them
// until you assign an instance that actually holds an underlying object.
//
// For Ort objects only move assignment between objects is allowed, there are no copy constructors.
// Some objects have explicit 'Clone' methods for this purpose.
//
// ConstXXXX types are copyable since they do not own the underlying C object, so you can pass them to functions as arguments
// by value or by reference. ConstXXXX types are restricted to const-only interfaces.
//
// UnownedXXXX are similar to ConstXXXX but also allow non-const interfaces.
//
// The lifetime of the corresponding owning object must exceed the lifetimes of the ConstXXXX/UnownedXXXX types. They exist so you do not
// have to fall back to C types and the API with the usual pitfalls. In general, do not use the C API from your C++ code.
0024 0025 #pragma once 0026 #include "onnxruntime_c_api.h" 0027 #include "onnxruntime_float16.h" 0028 0029 #include <cstddef> 0030 #include <cstdio> 0031 #include <array> 0032 #include <memory> 0033 #include <stdexcept> 0034 #include <string> 0035 #include <vector> 0036 #include <unordered_map> 0037 #include <utility> 0038 #include <type_traits> 0039 0040 #ifdef ORT_NO_EXCEPTIONS 0041 #include <iostream> 0042 #endif 0043 0044 /** \brief All C++ Onnxruntime APIs are defined inside this namespace 0045 * 0046 */ 0047 namespace Ort { 0048 0049 /** \brief All C++ methods that can fail will throw an exception of this type 0050 * 0051 * If <tt>ORT_NO_EXCEPTIONS</tt> is defined, then any error will result in a call to abort() 0052 */ 0053 struct Exception : std::exception { 0054 Exception(std::string&& string, OrtErrorCode code) : message_{std::move(string)}, code_{code} {} 0055 0056 OrtErrorCode GetOrtErrorCode() const { return code_; } 0057 const char* what() const noexcept override { return message_.c_str(); } 0058 0059 private: 0060 std::string message_; 0061 OrtErrorCode code_; 0062 }; 0063 0064 #ifdef ORT_NO_EXCEPTIONS 0065 // The #ifndef is for the very special case where the user of this library wants to define their own way of handling errors. 0066 // NOTE: This header expects control flow to not continue after calling ORT_CXX_API_THROW 0067 #ifndef ORT_CXX_API_THROW 0068 #define ORT_CXX_API_THROW(string, code) \ 0069 do { \ 0070 std::cerr << Ort::Exception(string, code) \ 0071 .what() \ 0072 << std::endl; \ 0073 abort(); \ 0074 } while (false) 0075 #endif 0076 #else 0077 #define ORT_CXX_API_THROW(string, code) \ 0078 throw Ort::Exception(string, code) 0079 #endif 0080 0081 // This is used internally by the C++ API. This class holds the global variable that points to the OrtApi, 0082 // it's in a template so that we can define a global variable in a header and make 0083 // it transparent to the users of the API. 
0084 template <typename T> 0085 struct Global { 0086 static const OrtApi* api_; 0087 }; 0088 0089 // If macro ORT_API_MANUAL_INIT is defined, no static initialization will be performed. Instead, user must call InitApi() before using it. 0090 template <typename T> 0091 #ifdef ORT_API_MANUAL_INIT 0092 const OrtApi* Global<T>::api_{}; 0093 inline void InitApi() noexcept { Global<void>::api_ = OrtGetApiBase()->GetApi(ORT_API_VERSION); } 0094 0095 // Used by custom operator libraries that are not linked to onnxruntime. Sets the global API object, which is 0096 // required by C++ APIs. 0097 // 0098 // Example mycustomop.cc: 0099 // 0100 // #define ORT_API_MANUAL_INIT 0101 // #include <onnxruntime_cxx_api.h> 0102 // #undef ORT_API_MANUAL_INIT 0103 // 0104 // OrtStatus* ORT_API_CALL RegisterCustomOps(OrtSessionOptions* options, const OrtApiBase* api_base) { 0105 // Ort::InitApi(api_base->GetApi(ORT_API_VERSION)); 0106 // // ... 0107 // } 0108 // 0109 inline void InitApi(const OrtApi* api) noexcept { Global<void>::api_ = api; } 0110 #else 0111 #if defined(_MSC_VER) && !defined(__clang__) 0112 #pragma warning(push) 0113 // "Global initializer calls a non-constexpr function." Therefore you can't use ORT APIs in the other global initializers. 0114 // Please define ORT_API_MANUAL_INIT if it conerns you. 
0115 #pragma warning(disable : 26426) 0116 #endif 0117 const OrtApi* Global<T>::api_ = OrtGetApiBase()->GetApi(ORT_API_VERSION); 0118 #if defined(_MSC_VER) && !defined(__clang__) 0119 #pragma warning(pop) 0120 #endif 0121 #endif 0122 0123 /// This returns a reference to the OrtApi interface in use 0124 inline const OrtApi& GetApi() noexcept { return *Global<void>::api_; } 0125 0126 /// <summary> 0127 /// This function returns the onnxruntime version string 0128 /// </summary> 0129 /// <returns>version string major.minor.rev</returns> 0130 std::string GetVersionString(); 0131 0132 /// <summary> 0133 /// This function returns the onnxruntime build information: including git branch, 0134 /// git commit id, build type(Debug/Release/RelWithDebInfo) and cmake cpp flags. 0135 /// </summary> 0136 /// <returns>string</returns> 0137 std::string GetBuildInfoString(); 0138 0139 /// <summary> 0140 /// This is a C++ wrapper for OrtApi::GetAvailableProviders() and 0141 /// returns a vector of strings representing the available execution providers. 0142 /// </summary> 0143 /// <returns>vector of strings</returns> 0144 std::vector<std::string> GetAvailableProviders(); 0145 0146 /** \brief IEEE 754 half-precision floating point data type 0147 * 0148 * \details This struct is used for converting float to float16 and back 0149 * so the user could feed inputs and fetch outputs using these type. 0150 * 0151 * The size of the structure should align with uint16_t and one can freely cast 0152 * uint16_t buffers to/from Ort::Float16_t to feed and retrieve data. 
0153 * 0154 * \code{.unparsed} 0155 * // This example demonstrates converion from float to float16 0156 * constexpr float values[] = {1.f, 2.f, 3.f, 4.f, 5.f}; 0157 * std::vector<Ort::Float16_t> fp16_values; 0158 * fp16_values.reserve(std::size(values)); 0159 * std::transform(std::begin(values), std::end(values), std::back_inserter(fp16_values), 0160 * [](float value) { return Ort::Float16_t(value); }); 0161 * 0162 * \endcode 0163 */ 0164 struct Float16_t : onnxruntime_float16::Float16Impl<Float16_t> { 0165 private: 0166 /// <summary> 0167 /// Constructor from a 16-bit representation of a float16 value 0168 /// No conversion is done here. 0169 /// </summary> 0170 /// <param name="v">16-bit representation</param> 0171 constexpr explicit Float16_t(uint16_t v) noexcept { val = v; } 0172 0173 public: 0174 using Base = onnxruntime_float16::Float16Impl<Float16_t>; 0175 0176 /// <summary> 0177 /// Default constructor 0178 /// </summary> 0179 Float16_t() = default; 0180 0181 /// <summary> 0182 /// Explicit conversion to uint16_t representation of float16. 0183 /// </summary> 0184 /// <param name="v">uint16_t bit representation of float16</param> 0185 /// <returns>new instance of Float16_t</returns> 0186 constexpr static Float16_t FromBits(uint16_t v) noexcept { return Float16_t(v); } 0187 0188 /// <summary> 0189 /// __ctor from float. Float is converted into float16 16-bit representation. 
0190 /// </summary> 0191 /// <param name="v">float value</param> 0192 explicit Float16_t(float v) noexcept { val = Base::ToUint16Impl(v); } 0193 0194 /// <summary> 0195 /// Converts float16 to float 0196 /// </summary> 0197 /// <returns>float representation of float16 value</returns> 0198 float ToFloat() const noexcept { return Base::ToFloatImpl(); } 0199 0200 /// <summary> 0201 /// Checks if the value is negative 0202 /// </summary> 0203 /// <returns>true if negative</returns> 0204 using Base::IsNegative; 0205 0206 /// <summary> 0207 /// Tests if the value is NaN 0208 /// </summary> 0209 /// <returns>true if NaN</returns> 0210 using Base::IsNaN; 0211 0212 /// <summary> 0213 /// Tests if the value is finite 0214 /// </summary> 0215 /// <returns>true if finite</returns> 0216 using Base::IsFinite; 0217 0218 /// <summary> 0219 /// Tests if the value represents positive infinity. 0220 /// </summary> 0221 /// <returns>true if positive infinity</returns> 0222 using Base::IsPositiveInfinity; 0223 0224 /// <summary> 0225 /// Tests if the value represents negative infinity 0226 /// </summary> 0227 /// <returns>true if negative infinity</returns> 0228 using Base::IsNegativeInfinity; 0229 0230 /// <summary> 0231 /// Tests if the value is either positive or negative infinity. 0232 /// </summary> 0233 /// <returns>True if absolute value is infinity</returns> 0234 using Base::IsInfinity; 0235 0236 /// <summary> 0237 /// Tests if the value is NaN or zero. Useful for comparisons. 0238 /// </summary> 0239 /// <returns>True if NaN or zero.</returns> 0240 using Base::IsNaNOrZero; 0241 0242 /// <summary> 0243 /// Tests if the value is normal (not zero, subnormal, infinite, or NaN). 0244 /// </summary> 0245 /// <returns>True if so</returns> 0246 using Base::IsNormal; 0247 0248 /// <summary> 0249 /// Tests if the value is subnormal (denormal). 
0250 /// </summary> 0251 /// <returns>True if so</returns> 0252 using Base::IsSubnormal; 0253 0254 /// <summary> 0255 /// Creates an instance that represents absolute value. 0256 /// </summary> 0257 /// <returns>Absolute value</returns> 0258 using Base::Abs; 0259 0260 /// <summary> 0261 /// Creates a new instance with the sign flipped. 0262 /// </summary> 0263 /// <returns>Flipped sign instance</returns> 0264 using Base::Negate; 0265 0266 /// <summary> 0267 /// IEEE defines that positive and negative zero are equal, this gives us a quick equality check 0268 /// for two values by or'ing the private bits together and stripping the sign. They are both zero, 0269 /// and therefore equivalent, if the resulting value is still zero. 0270 /// </summary> 0271 /// <param name="lhs">first value</param> 0272 /// <param name="rhs">second value</param> 0273 /// <returns>True if both arguments represent zero</returns> 0274 using Base::AreZero; 0275 0276 /// <summary> 0277 /// User defined conversion operator. Converts Float16_t to float. 0278 /// </summary> 0279 explicit operator float() const noexcept { return ToFloat(); } 0280 0281 using Base::operator==; 0282 using Base::operator!=; 0283 using Base::operator<; 0284 }; 0285 0286 static_assert(sizeof(Float16_t) == sizeof(uint16_t), "Sizes must match"); 0287 0288 /** \brief bfloat16 (Brain Floating Point) data type 0289 * 0290 * \details This struct is used for converting float to bfloat16 and back 0291 * so the user could feed inputs and fetch outputs using these type. 0292 * 0293 * The size of the structure should align with uint16_t and one can freely cast 0294 * uint16_t buffers to/from Ort::BFloat16_t to feed and retrieve data. 
0295 * 0296 * \code{.unparsed} 0297 * // This example demonstrates converion from float to float16 0298 * constexpr float values[] = {1.f, 2.f, 3.f, 4.f, 5.f}; 0299 * std::vector<Ort::BFloat16_t> bfp16_values; 0300 * bfp16_values.reserve(std::size(values)); 0301 * std::transform(std::begin(values), std::end(values), std::back_inserter(bfp16_values), 0302 * [](float value) { return Ort::BFloat16_t(value); }); 0303 * 0304 * \endcode 0305 */ 0306 struct BFloat16_t : onnxruntime_float16::BFloat16Impl<BFloat16_t> { 0307 private: 0308 /// <summary> 0309 /// Constructor from a uint16_t representation of bfloat16 0310 /// used in FromBits() to escape overload resolution issue with 0311 /// constructor from float. 0312 /// No conversion is done. 0313 /// </summary> 0314 /// <param name="v">16-bit bfloat16 value</param> 0315 constexpr explicit BFloat16_t(uint16_t v) noexcept { val = v; } 0316 0317 public: 0318 using Base = onnxruntime_float16::BFloat16Impl<BFloat16_t>; 0319 0320 BFloat16_t() = default; 0321 0322 /// <summary> 0323 /// Explicit conversion to uint16_t representation of bfloat16. 0324 /// </summary> 0325 /// <param name="v">uint16_t bit representation of bfloat16</param> 0326 /// <returns>new instance of BFloat16_t</returns> 0327 static constexpr BFloat16_t FromBits(uint16_t v) noexcept { return BFloat16_t(v); } 0328 0329 /// <summary> 0330 /// __ctor from float. Float is converted into bfloat16 16-bit representation. 
0331 /// </summary> 0332 /// <param name="v">float value</param> 0333 explicit BFloat16_t(float v) noexcept { val = Base::ToUint16Impl(v); } 0334 0335 /// <summary> 0336 /// Converts bfloat16 to float 0337 /// </summary> 0338 /// <returns>float representation of bfloat16 value</returns> 0339 float ToFloat() const noexcept { return Base::ToFloatImpl(); } 0340 0341 /// <summary> 0342 /// Checks if the value is negative 0343 /// </summary> 0344 /// <returns>true if negative</returns> 0345 using Base::IsNegative; 0346 0347 /// <summary> 0348 /// Tests if the value is NaN 0349 /// </summary> 0350 /// <returns>true if NaN</returns> 0351 using Base::IsNaN; 0352 0353 /// <summary> 0354 /// Tests if the value is finite 0355 /// </summary> 0356 /// <returns>true if finite</returns> 0357 using Base::IsFinite; 0358 0359 /// <summary> 0360 /// Tests if the value represents positive infinity. 0361 /// </summary> 0362 /// <returns>true if positive infinity</returns> 0363 using Base::IsPositiveInfinity; 0364 0365 /// <summary> 0366 /// Tests if the value represents negative infinity 0367 /// </summary> 0368 /// <returns>true if negative infinity</returns> 0369 using Base::IsNegativeInfinity; 0370 0371 /// <summary> 0372 /// Tests if the value is either positive or negative infinity. 0373 /// </summary> 0374 /// <returns>True if absolute value is infinity</returns> 0375 using Base::IsInfinity; 0376 0377 /// <summary> 0378 /// Tests if the value is NaN or zero. Useful for comparisons. 0379 /// </summary> 0380 /// <returns>True if NaN or zero.</returns> 0381 using Base::IsNaNOrZero; 0382 0383 /// <summary> 0384 /// Tests if the value is normal (not zero, subnormal, infinite, or NaN). 0385 /// </summary> 0386 /// <returns>True if so</returns> 0387 using Base::IsNormal; 0388 0389 /// <summary> 0390 /// Tests if the value is subnormal (denormal). 
0391 /// </summary> 0392 /// <returns>True if so</returns> 0393 using Base::IsSubnormal; 0394 0395 /// <summary> 0396 /// Creates an instance that represents absolute value. 0397 /// </summary> 0398 /// <returns>Absolute value</returns> 0399 using Base::Abs; 0400 0401 /// <summary> 0402 /// Creates a new instance with the sign flipped. 0403 /// </summary> 0404 /// <returns>Flipped sign instance</returns> 0405 using Base::Negate; 0406 0407 /// <summary> 0408 /// IEEE defines that positive and negative zero are equal, this gives us a quick equality check 0409 /// for two values by or'ing the private bits together and stripping the sign. They are both zero, 0410 /// and therefore equivalent, if the resulting value is still zero. 0411 /// </summary> 0412 /// <param name="lhs">first value</param> 0413 /// <param name="rhs">second value</param> 0414 /// <returns>True if both arguments represent zero</returns> 0415 using Base::AreZero; 0416 0417 /// <summary> 0418 /// User defined conversion operator. Converts BFloat16_t to float. 0419 /// </summary> 0420 explicit operator float() const noexcept { return ToFloat(); } 0421 0422 // We do not have an inherited impl for the below operators 0423 // as the internal class implements them a little differently 0424 bool operator==(const BFloat16_t& rhs) const noexcept; 0425 bool operator!=(const BFloat16_t& rhs) const noexcept { return !(*this == rhs); } 0426 bool operator<(const BFloat16_t& rhs) const noexcept; 0427 }; 0428 0429 static_assert(sizeof(BFloat16_t) == sizeof(uint16_t), "Sizes must match"); 0430 0431 /** \brief float8e4m3fn (Float8 Floating Point) data type 0432 * \details It is necessary for type dispatching to make use of C++ API 0433 * The type is implicitly convertible to/from uint8_t. 0434 * See https://onnx.ai/onnx/technical/float8.html for further details. 
0435 */ 0436 struct Float8E4M3FN_t { 0437 uint8_t value; 0438 constexpr Float8E4M3FN_t() noexcept : value(0) {} 0439 constexpr Float8E4M3FN_t(uint8_t v) noexcept : value(v) {} 0440 constexpr operator uint8_t() const noexcept { return value; } 0441 // nan values are treated like any other value for operator ==, != 0442 constexpr bool operator==(const Float8E4M3FN_t& rhs) const noexcept { return value == rhs.value; }; 0443 constexpr bool operator!=(const Float8E4M3FN_t& rhs) const noexcept { return value != rhs.value; }; 0444 }; 0445 0446 static_assert(sizeof(Float8E4M3FN_t) == sizeof(uint8_t), "Sizes must match"); 0447 0448 /** \brief float8e4m3fnuz (Float8 Floating Point) data type 0449 * \details It is necessary for type dispatching to make use of C++ API 0450 * The type is implicitly convertible to/from uint8_t. 0451 * See https://onnx.ai/onnx/technical/float8.html for further details. 0452 */ 0453 struct Float8E4M3FNUZ_t { 0454 uint8_t value; 0455 constexpr Float8E4M3FNUZ_t() noexcept : value(0) {} 0456 constexpr Float8E4M3FNUZ_t(uint8_t v) noexcept : value(v) {} 0457 constexpr operator uint8_t() const noexcept { return value; } 0458 // nan values are treated like any other value for operator ==, != 0459 constexpr bool operator==(const Float8E4M3FNUZ_t& rhs) const noexcept { return value == rhs.value; }; 0460 constexpr bool operator!=(const Float8E4M3FNUZ_t& rhs) const noexcept { return value != rhs.value; }; 0461 }; 0462 0463 static_assert(sizeof(Float8E4M3FNUZ_t) == sizeof(uint8_t), "Sizes must match"); 0464 0465 /** \brief float8e5m2 (Float8 Floating Point) data type 0466 * \details It is necessary for type dispatching to make use of C++ API 0467 * The type is implicitly convertible to/from uint8_t. 0468 * See https://onnx.ai/onnx/technical/float8.html for further details. 
0469 */ 0470 struct Float8E5M2_t { 0471 uint8_t value; 0472 constexpr Float8E5M2_t() noexcept : value(0) {} 0473 constexpr Float8E5M2_t(uint8_t v) noexcept : value(v) {} 0474 constexpr operator uint8_t() const noexcept { return value; } 0475 // nan values are treated like any other value for operator ==, != 0476 constexpr bool operator==(const Float8E5M2_t& rhs) const noexcept { return value == rhs.value; }; 0477 constexpr bool operator!=(const Float8E5M2_t& rhs) const noexcept { return value != rhs.value; }; 0478 }; 0479 0480 static_assert(sizeof(Float8E5M2_t) == sizeof(uint8_t), "Sizes must match"); 0481 0482 /** \brief float8e5m2fnuz (Float8 Floating Point) data type 0483 * \details It is necessary for type dispatching to make use of C++ API 0484 * The type is implicitly convertible to/from uint8_t. 0485 * See https://onnx.ai/onnx/technical/float8.html for further details. 0486 */ 0487 struct Float8E5M2FNUZ_t { 0488 uint8_t value; 0489 constexpr Float8E5M2FNUZ_t() noexcept : value(0) {} 0490 constexpr Float8E5M2FNUZ_t(uint8_t v) noexcept : value(v) {} 0491 constexpr operator uint8_t() const noexcept { return value; } 0492 // nan values are treated like any other value for operator ==, != 0493 constexpr bool operator==(const Float8E5M2FNUZ_t& rhs) const noexcept { return value == rhs.value; }; 0494 constexpr bool operator!=(const Float8E5M2FNUZ_t& rhs) const noexcept { return value != rhs.value; }; 0495 }; 0496 0497 static_assert(sizeof(Float8E5M2FNUZ_t) == sizeof(uint8_t), "Sizes must match"); 0498 0499 namespace detail { 0500 // This is used internally by the C++ API. This macro is to make it easy to generate overloaded methods for all of the various OrtRelease* functions for every Ort* type 0501 // This can't be done in the C API since C doesn't have function overloading. 
0502 #define ORT_DEFINE_RELEASE(NAME) \ 0503 inline void OrtRelease(Ort##NAME* ptr) { GetApi().Release##NAME(ptr); } 0504 0505 ORT_DEFINE_RELEASE(Allocator); 0506 ORT_DEFINE_RELEASE(MemoryInfo); 0507 ORT_DEFINE_RELEASE(CustomOpDomain); 0508 ORT_DEFINE_RELEASE(ThreadingOptions); 0509 ORT_DEFINE_RELEASE(Env); 0510 ORT_DEFINE_RELEASE(RunOptions); 0511 ORT_DEFINE_RELEASE(LoraAdapter); 0512 ORT_DEFINE_RELEASE(Session); 0513 ORT_DEFINE_RELEASE(SessionOptions); 0514 ORT_DEFINE_RELEASE(TensorTypeAndShapeInfo); 0515 ORT_DEFINE_RELEASE(SequenceTypeInfo); 0516 ORT_DEFINE_RELEASE(MapTypeInfo); 0517 ORT_DEFINE_RELEASE(TypeInfo); 0518 ORT_DEFINE_RELEASE(Value); 0519 ORT_DEFINE_RELEASE(ModelMetadata); 0520 ORT_DEFINE_RELEASE(IoBinding); 0521 ORT_DEFINE_RELEASE(ArenaCfg); 0522 ORT_DEFINE_RELEASE(Status); 0523 ORT_DEFINE_RELEASE(OpAttr); 0524 ORT_DEFINE_RELEASE(Op); 0525 ORT_DEFINE_RELEASE(KernelInfo); 0526 0527 #undef ORT_DEFINE_RELEASE 0528 0529 /** \brief This is a tagging template type. Use it with Base<T> to indicate that the C++ interface object 0530 * has no ownership of the underlying C object. 0531 */ 0532 template <typename T> 0533 struct Unowned { 0534 using Type = T; 0535 }; 0536 0537 /** \brief Used internally by the C++ API. C++ wrapper types inherit from this. 0538 * This is a zero cost abstraction to wrap the C API objects and delete them on destruction. 0539 * 0540 * All of the C++ classes 0541 * a) serve as containers for pointers to objects that are created by the underlying C API. 0542 * Their size is just a pointer size, no need to dynamically allocate them. Use them by value. 0543 * b) Each of struct XXXX, XXX instances function as smart pointers to the underlying C API objects. 0544 * they would release objects owned automatically when going out of scope, they are move-only. 0545 * c) ConstXXXX and UnownedXXX structs function as non-owning, copyable containers for the above pointers. 0546 * ConstXXXX allow calling const interfaces only. 
They give access to objects that are owned by somebody else 0547 * such as Onnxruntime or instances of XXXX classes. 0548 * d) serve convenient interfaces that return C++ objects and further enhance exception and type safety so they can be used 0549 * in C++ code. 0550 * 0551 */ 0552 0553 /// <summary> 0554 /// This is a non-const pointer holder that is move-only. Disposes of the pointer on destruction. 0555 /// </summary> 0556 template <typename T> 0557 struct Base { 0558 using contained_type = T; 0559 0560 constexpr Base() = default; 0561 constexpr explicit Base(contained_type* p) noexcept : p_{p} {} 0562 ~Base() { OrtRelease(p_); } 0563 0564 Base(const Base&) = delete; 0565 Base& operator=(const Base&) = delete; 0566 0567 Base(Base&& v) noexcept : p_{v.p_} { v.p_ = nullptr; } 0568 Base& operator=(Base&& v) noexcept { 0569 OrtRelease(p_); 0570 p_ = v.release(); 0571 return *this; 0572 } 0573 0574 constexpr operator contained_type*() const noexcept { return p_; } 0575 0576 /// \brief Relinquishes ownership of the contained C object pointer 0577 /// The underlying object is not destroyed 0578 contained_type* release() { 0579 T* p = p_; 0580 p_ = nullptr; 0581 return p; 0582 } 0583 0584 protected: 0585 contained_type* p_{}; 0586 }; 0587 0588 // Undefined. For const types use Base<Unowned<const T>> 0589 template <typename T> 0590 struct Base<const T>; 0591 0592 /// <summary> 0593 /// Covers unowned pointers owned by either the ORT 0594 /// or some other instance of CPP wrappers. 0595 /// Used for ConstXXX and UnownedXXXX types that are copyable. 0596 /// Also convenient to wrap raw OrtXX pointers . 
0597 /// </summary> 0598 /// <typeparam name="T"></typeparam> 0599 template <typename T> 0600 struct Base<Unowned<T>> { 0601 using contained_type = typename Unowned<T>::Type; 0602 0603 constexpr Base() = default; 0604 constexpr explicit Base(contained_type* p) noexcept : p_{p} {} 0605 0606 ~Base() = default; 0607 0608 Base(const Base&) = default; 0609 Base& operator=(const Base&) = default; 0610 0611 Base(Base&& v) noexcept : p_{v.p_} { v.p_ = nullptr; } 0612 Base& operator=(Base&& v) noexcept { 0613 p_ = nullptr; 0614 std::swap(p_, v.p_); 0615 return *this; 0616 } 0617 0618 constexpr operator contained_type*() const noexcept { return p_; } 0619 0620 protected: 0621 contained_type* p_{}; 0622 }; 0623 0624 // Light functor to release memory with OrtAllocator 0625 struct AllocatedFree { 0626 OrtAllocator* allocator_; 0627 explicit AllocatedFree(OrtAllocator* allocator) 0628 : allocator_(allocator) {} 0629 void operator()(void* ptr) const { 0630 if (ptr) allocator_->Free(allocator_, ptr); 0631 } 0632 }; 0633 0634 } // namespace detail 0635 0636 struct AllocatorWithDefaultOptions; 0637 struct Env; 0638 struct TypeInfo; 0639 struct Value; 0640 struct ModelMetadata; 0641 0642 /** \brief unique_ptr typedef used to own strings allocated by OrtAllocators 0643 * and release them at the end of the scope. The lifespan of the given allocator 0644 * must eclipse the lifespan of AllocatedStringPtr instance 0645 */ 0646 using AllocatedStringPtr = std::unique_ptr<char, detail::AllocatedFree>; 0647 0648 /** \brief The Status that holds ownership of OrtStatus received from C API 0649 * Use it to safely destroy OrtStatus* returned from the C API. Use appropriate 0650 * constructors to construct an instance of a Status object from exceptions. 
0651 */ 0652 struct Status : detail::Base<OrtStatus> { 0653 explicit Status(std::nullptr_t) noexcept {} ///< Create an empty object, must be assigned a valid one to be used 0654 explicit Status(OrtStatus* status) noexcept; ///< Takes ownership of OrtStatus instance returned from the C API. 0655 explicit Status(const Exception&) noexcept; ///< Creates status instance out of exception 0656 explicit Status(const std::exception&) noexcept; ///< Creates status instance out of exception 0657 Status(const char* message, OrtErrorCode code) noexcept; ///< Creates status instance out of null-terminated string message. 0658 std::string GetErrorMessage() const; 0659 OrtErrorCode GetErrorCode() const; 0660 bool IsOK() const noexcept; ///< Returns true if instance represents an OK (non-error) status. 0661 }; 0662 0663 /** \brief The ThreadingOptions 0664 * 0665 * The ThreadingOptions used for set global threadpools' options of The Env. 0666 */ 0667 struct ThreadingOptions : detail::Base<OrtThreadingOptions> { 0668 /// \brief Wraps OrtApi::CreateThreadingOptions 0669 ThreadingOptions(); 0670 0671 /// \brief Wraps OrtApi::SetGlobalIntraOpNumThreads 0672 ThreadingOptions& SetGlobalIntraOpNumThreads(int intra_op_num_threads); 0673 0674 /// \brief Wraps OrtApi::SetGlobalInterOpNumThreads 0675 ThreadingOptions& SetGlobalInterOpNumThreads(int inter_op_num_threads); 0676 0677 /// \brief Wraps OrtApi::SetGlobalSpinControl 0678 ThreadingOptions& SetGlobalSpinControl(int allow_spinning); 0679 0680 /// \brief Wraps OrtApi::SetGlobalDenormalAsZero 0681 ThreadingOptions& SetGlobalDenormalAsZero(); 0682 0683 /// \brief Wraps OrtApi::SetGlobalCustomCreateThreadFn 0684 ThreadingOptions& SetGlobalCustomCreateThreadFn(OrtCustomCreateThreadFn ort_custom_create_thread_fn); 0685 0686 /// \brief Wraps OrtApi::SetGlobalCustomThreadCreationOptions 0687 ThreadingOptions& SetGlobalCustomThreadCreationOptions(void* ort_custom_thread_creation_options); 0688 0689 /// \brief Wraps 
OrtApi::SetGlobalCustomJoinThreadFn 0690 ThreadingOptions& SetGlobalCustomJoinThreadFn(OrtCustomJoinThreadFn ort_custom_join_thread_fn); 0691 }; 0692 0693 /** \brief The Env (Environment) 0694 * 0695 * The Env holds the logging state used by all other objects. 0696 * <b>Note:</b> One Env must be created before using any other Onnxruntime functionality 0697 */ 0698 struct Env : detail::Base<OrtEnv> { 0699 explicit Env(std::nullptr_t) {} ///< Create an empty Env object, must be assigned a valid one to be used 0700 0701 /// \brief Wraps OrtApi::CreateEnv 0702 Env(OrtLoggingLevel logging_level = ORT_LOGGING_LEVEL_WARNING, _In_ const char* logid = ""); 0703 0704 /// \brief Wraps OrtApi::CreateEnvWithCustomLogger 0705 Env(OrtLoggingLevel logging_level, const char* logid, OrtLoggingFunction logging_function, void* logger_param); 0706 0707 /// \brief Wraps OrtApi::CreateEnvWithGlobalThreadPools 0708 Env(const OrtThreadingOptions* tp_options, OrtLoggingLevel logging_level = ORT_LOGGING_LEVEL_WARNING, _In_ const char* logid = ""); 0709 0710 /// \brief Wraps OrtApi::CreateEnvWithCustomLoggerAndGlobalThreadPools 0711 Env(const OrtThreadingOptions* tp_options, OrtLoggingFunction logging_function, void* logger_param, 0712 OrtLoggingLevel logging_level = ORT_LOGGING_LEVEL_WARNING, _In_ const char* logid = ""); 0713 0714 /// \brief C Interop Helper 0715 explicit Env(OrtEnv* p) : Base<OrtEnv>{p} {} 0716 0717 Env& EnableTelemetryEvents(); ///< Wraps OrtApi::EnableTelemetryEvents 0718 Env& DisableTelemetryEvents(); ///< Wraps OrtApi::DisableTelemetryEvents 0719 0720 Env& UpdateEnvWithCustomLogLevel(OrtLoggingLevel log_severity_level); ///< Wraps OrtApi::UpdateEnvWithCustomLogLevel 0721 0722 Env& CreateAndRegisterAllocator(const OrtMemoryInfo* mem_info, const OrtArenaCfg* arena_cfg); ///< Wraps OrtApi::CreateAndRegisterAllocator 0723 0724 Env& CreateAndRegisterAllocatorV2(const std::string& provider_type, const OrtMemoryInfo* mem_info, const std::unordered_map<std::string, 
std::string>& options, const OrtArenaCfg* arena_cfg); ///< Wraps OrtApi::CreateAndRegisterAllocatorV2 0725 }; 0726 0727 /** \brief Custom Op Domain 0728 * 0729 */ 0730 struct CustomOpDomain : detail::Base<OrtCustomOpDomain> { 0731 explicit CustomOpDomain(std::nullptr_t) {} ///< Create an empty CustomOpDomain object, must be assigned a valid one to be used 0732 0733 /// \brief Wraps OrtApi::CreateCustomOpDomain 0734 explicit CustomOpDomain(const char* domain); 0735 0736 // This does not take ownership of the op, simply registers it. 0737 void Add(const OrtCustomOp* op); ///< Wraps CustomOpDomain_Add 0738 }; 0739 0740 /// \brief LoraAdapter holds a set of Lora Parameters loaded from a single file 0741 struct LoraAdapter : detail::Base<OrtLoraAdapter> { 0742 using Base = detail::Base<OrtLoraAdapter>; 0743 using Base::Base; 0744 0745 explicit LoraAdapter(std::nullptr_t) {} ///< Create an empty LoraAdapter object, must be assigned a valid one to be used 0746 /// \brief Wraps OrtApi::CreateLoraAdapter 0747 /// 0748 /// The function attempts to load the adapter from the specified file 0749 /// \param adapter_path The path to the Lora adapter 0750 /// \param allocator optional pointer to a device allocator. If nullptr, the data stays on CPU. It would still 0751 /// be copied to device if required by the model at inference time. 0752 static LoraAdapter CreateLoraAdapter(const std::basic_string<ORTCHAR_T>& adapter_path, 0753 OrtAllocator* allocator); 0754 0755 /// \brief Wraps OrtApi::CreateLoraAdapterFromArray 0756 /// 0757 /// The function attempts to load the adapter from the specified byte array. 0758 /// \param bytes The byte array containing file LoraAdapter format 0759 /// \param num_bytes The number of bytes in the byte array 0760 /// \param allocator optional pointer to a device allocator. If nullptr, the data stays on CPU. It would still 0761 /// be copied to device if required by the model at inference time. 
0762 static LoraAdapter CreateLoraAdapterFromArray(const void* bytes, size_t num_bytes, 0763 OrtAllocator* allocator); 0764 }; 0765 0766 /** \brief RunOptions 0767 * 0768 */ 0769 struct RunOptions : detail::Base<OrtRunOptions> { 0770 explicit RunOptions(std::nullptr_t) {} ///< Create an empty RunOptions object, must be assigned a valid one to be used 0771 RunOptions(); ///< Wraps OrtApi::CreateRunOptions 0772 0773 RunOptions& SetRunLogVerbosityLevel(int); ///< Wraps OrtApi::RunOptionsSetRunLogVerbosityLevel 0774 int GetRunLogVerbosityLevel() const; ///< Wraps OrtApi::RunOptionsGetRunLogVerbosityLevel 0775 0776 RunOptions& SetRunLogSeverityLevel(int); ///< Wraps OrtApi::RunOptionsSetRunLogSeverityLevel 0777 int GetRunLogSeverityLevel() const; ///< Wraps OrtApi::RunOptionsGetRunLogSeverityLevel 0778 0779 RunOptions& SetRunTag(const char* run_tag); ///< wraps OrtApi::RunOptionsSetRunTag 0780 const char* GetRunTag() const; ///< Wraps OrtApi::RunOptionsGetRunTag 0781 0782 RunOptions& AddConfigEntry(const char* config_key, const char* config_value); ///< Wraps OrtApi::AddRunConfigEntry 0783 0784 /** \brief Terminates all currently executing Session::Run calls that were made using this RunOptions instance 0785 * 0786 * If a currently executing session needs to be force terminated, this can be called from another thread to force it to fail with an error 0787 * Wraps OrtApi::RunOptionsSetTerminate 0788 */ 0789 RunOptions& SetTerminate(); 0790 0791 /** \brief Clears the terminate flag so this RunOptions instance can be used in a new Session::Run call without it instantly terminating 0792 * 0793 * Wraps OrtApi::RunOptionsUnsetTerminate 0794 */ 0795 RunOptions& UnsetTerminate(); 0796 0797 /** \brief Add the LoraAdapter to the list of active adapters. 0798 * The setting does not affect RunWithBinding() calls. 
0799 * 0800 * Wraps OrtApi::RunOptionsAddActiveLoraAdapter 0801 * \param adapter The LoraAdapter to be used as the active adapter 0802 */ 0803 RunOptions& AddActiveLoraAdapter(const LoraAdapter& adapter); 0804 }; 0805 0806 namespace detail { 0807 // Utility function that returns a SessionOptions config entry key for a specific custom operator. 0808 // Ex: custom_op.[custom_op_name].[config] 0809 std::string MakeCustomOpConfigEntryKey(const char* custom_op_name, const char* config); 0810 } // namespace detail 0811 0812 /// <summary> 0813 /// Class that represents session configuration entries for one or more custom operators. 0814 /// 0815 /// Example: 0816 /// Ort::CustomOpConfigs op_configs; 0817 /// op_configs.AddConfig("my_custom_op", "device_type", "CPU"); 0818 /// 0819 /// Passed to Ort::SessionOptions::RegisterCustomOpsLibrary. 0820 /// </summary> 0821 struct CustomOpConfigs { 0822 CustomOpConfigs() = default; 0823 ~CustomOpConfigs() = default; 0824 CustomOpConfigs(const CustomOpConfigs&) = default; 0825 CustomOpConfigs& operator=(const CustomOpConfigs&) = default; 0826 CustomOpConfigs(CustomOpConfigs&& o) = default; 0827 CustomOpConfigs& operator=(CustomOpConfigs&& o) = default; 0828 0829 /** \brief Adds a session configuration entry/value for a specific custom operator. 0830 * 0831 * \param custom_op_name The name of the custom operator for which to add a configuration entry. 0832 * Must match the name returned by the CustomOp's GetName() method. 0833 * \param config_key The name of the configuration entry. 0834 * \param config_value The value of the configuration entry. 0835 * \return A reference to this object to enable call chaining. * NOTE(review): the stored key is presumably built with detail::MakeCustomOpConfigEntryKey, i.e. custom_op.[custom_op_name].[config_key] - confirm against the inline definition. 0836 */ 0837 CustomOpConfigs& AddConfig(const char* custom_op_name, const char* config_key, const char* config_value); 0838 0839 /** \brief Returns a flattened map of custom operator configuration entries and their values.
0840 * 0841 * The keys has been flattened to include both the custom operator name and the configuration entry key name. 0842 * For example, a prior call to AddConfig("my_op", "key", "value") corresponds to the flattened key/value pair 0843 * {"my_op.key", "value"}. 0844 * 0845 * \return An unordered map of flattened configurations. 0846 */ 0847 const std::unordered_map<std::string, std::string>& GetFlattenedConfigs() const; 0848 0849 private: 0850 std::unordered_map<std::string, std::string> flat_configs_; 0851 }; 0852 0853 /** \brief Options object used when creating a new Session object 0854 * 0855 * Wraps ::OrtSessionOptions object and methods 0856 */ 0857 0858 struct SessionOptions; 0859 0860 namespace detail { 0861 // We separate const-only methods because passing a const ptr to non-const methods 0862 // is only discovered when the inline methods are compiled, which is counter-intuitive. 0863 template <typename T> 0864 struct ConstSessionOptionsImpl : Base<T> { 0865 using B = Base<T>; 0866 using B::B; 0867 0868 SessionOptions Clone() const; ///< Creates and returns a copy of this SessionOptions object.
Wraps OrtApi::CloneSessionOptions 0869 0870 std::string GetConfigEntry(const char* config_key) const; ///< Wraps OrtApi::GetSessionConfigEntry 0871 bool HasConfigEntry(const char* config_key) const; ///< Wraps OrtApi::HasSessionConfigEntry 0872 std::string GetConfigEntryOrDefault(const char* config_key, const std::string& def); ///< NOTE(review): undocumented and, unlike its siblings, not declared const; presumably returns the entry for config_key, or def when the key is absent - confirm against the inline definition 0873 }; 0874 0875 template <typename T> 0876 struct SessionOptionsImpl : ConstSessionOptionsImpl<T> { 0877 using B = ConstSessionOptionsImpl<T>; 0878 using B::B; 0879 0880 SessionOptionsImpl& SetIntraOpNumThreads(int intra_op_num_threads); ///< Wraps OrtApi::SetIntraOpNumThreads 0881 SessionOptionsImpl& SetInterOpNumThreads(int inter_op_num_threads); ///< Wraps OrtApi::SetInterOpNumThreads 0882 SessionOptionsImpl& SetGraphOptimizationLevel(GraphOptimizationLevel graph_optimization_level); ///< Wraps OrtApi::SetSessionGraphOptimizationLevel 0883 SessionOptionsImpl& SetDeterministicCompute(bool value); ///< Wraps OrtApi::SetDeterministicCompute 0884 0885 SessionOptionsImpl& EnableCpuMemArena(); ///< Wraps OrtApi::EnableCpuMemArena 0886 SessionOptionsImpl& DisableCpuMemArena(); ///< Wraps OrtApi::DisableCpuMemArena 0887 0888 SessionOptionsImpl& SetOptimizedModelFilePath(const ORTCHAR_T* optimized_model_file); ///< Wraps OrtApi::SetOptimizedModelFilePath 0889 0890 SessionOptionsImpl& EnableProfiling(const ORTCHAR_T* profile_file_prefix); ///< Wraps OrtApi::EnableProfiling 0891 SessionOptionsImpl& DisableProfiling(); ///< Wraps OrtApi::DisableProfiling 0892 0893 SessionOptionsImpl& EnableOrtCustomOps(); ///< Wraps OrtApi::EnableOrtCustomOps 0894 0895 SessionOptionsImpl& EnableMemPattern(); ///< Wraps OrtApi::EnableMemPattern 0896 SessionOptionsImpl& DisableMemPattern(); ///< Wraps OrtApi::DisableMemPattern 0897 0898 SessionOptionsImpl& SetExecutionMode(ExecutionMode execution_mode); ///< Wraps OrtApi::SetSessionExecutionMode 0899 0900 SessionOptionsImpl& SetLogId(const char* logid); ///< Wraps OrtApi::SetSessionLogId 0901 SessionOptionsImpl&
SetLogSeverityLevel(int level); ///< Wraps OrtApi::SetSessionLogSeverityLevel 0902 0903 SessionOptionsImpl& Add(OrtCustomOpDomain* custom_op_domain); ///< Wraps OrtApi::AddCustomOpDomain 0904 0905 SessionOptionsImpl& DisablePerSessionThreads(); ///< Wraps OrtApi::DisablePerSessionThreads 0906 0907 SessionOptionsImpl& AddConfigEntry(const char* config_key, const char* config_value); ///< Wraps OrtApi::AddSessionConfigEntry 0908 0909 SessionOptionsImpl& AddInitializer(const char* name, const OrtValue* ort_val); ///< Wraps OrtApi::AddInitializer 0910 SessionOptionsImpl& AddExternalInitializers(const std::vector<std::string>& names, const std::vector<Value>& ort_values); ///< Wraps OrtApi::AddExternalInitializers 0911 SessionOptionsImpl& AddExternalInitializersFromFilesInMemory(const std::vector<std::basic_string<ORTCHAR_T>>& external_initializer_file_names, 0912 const std::vector<char*>& external_initializer_file_buffer_array, 0913 const std::vector<size_t>& external_initializer_file_lengths); ///< Wraps OrtApi::AddExternalInitializersFromFilesInMemory 0914 0915 SessionOptionsImpl& AppendExecutionProvider_CUDA(const OrtCUDAProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_CUDA 0916 SessionOptionsImpl& AppendExecutionProvider_CUDA_V2(const OrtCUDAProviderOptionsV2& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_CUDA_V2 0917 SessionOptionsImpl& AppendExecutionProvider_ROCM(const OrtROCMProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_ROCM 0918 SessionOptionsImpl& AppendExecutionProvider_OpenVINO(const OrtOpenVINOProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_OpenVINO 0919 /// Wraps OrtApi::SessionOptionsAppendExecutionProvider_OpenVINO_V2 0920 SessionOptionsImpl& AppendExecutionProvider_OpenVINO_V2(const std::unordered_map<std::string, std::string>& provider_options = {}); 0921 SessionOptionsImpl&
AppendExecutionProvider_TensorRT(const OrtTensorRTProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_TensorRT 0922 SessionOptionsImpl& AppendExecutionProvider_TensorRT_V2(const OrtTensorRTProviderOptionsV2& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_TensorRT_V2 0923 SessionOptionsImpl& AppendExecutionProvider_MIGraphX(const OrtMIGraphXProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_MIGraphX 0924 /// Wraps OrtApi::SessionOptionsAppendExecutionProvider_CANN 0925 SessionOptionsImpl& AppendExecutionProvider_CANN(const OrtCANNProviderOptions& provider_options); 0926 /// Wraps OrtApi::SessionOptionsAppendExecutionProvider_Dnnl 0927 SessionOptionsImpl& AppendExecutionProvider_Dnnl(const OrtDnnlProviderOptions& provider_options); 0928 /// Wraps OrtApi::SessionOptionsAppendExecutionProvider. Currently supports QNN, SNPE and XNNPACK. 0929 SessionOptionsImpl& AppendExecutionProvider(const std::string& provider_name, 0930 const std::unordered_map<std::string, std::string>& provider_options = {}); 0931 0932 SessionOptionsImpl& SetCustomCreateThreadFn(OrtCustomCreateThreadFn ort_custom_create_thread_fn); ///< Wraps OrtApi::SessionOptionsSetCustomCreateThreadFn 0933 SessionOptionsImpl& SetCustomThreadCreationOptions(void* ort_custom_thread_creation_options); ///< Wraps OrtApi::SessionOptionsSetCustomThreadCreationOptions 0934 SessionOptionsImpl& SetCustomJoinThreadFn(OrtCustomJoinThreadFn ort_custom_join_thread_fn); ///< Wraps OrtApi::SessionOptionsSetCustomJoinThreadFn 0935 0936 /// Registers the custom operator from the specified shared library via OrtApi::RegisterCustomOpsLibrary_V2. 0937 /// The custom operator configurations are optional. If provided, custom operator configs are set via 0938 /// OrtApi::AddSessionConfigEntry.
0939 SessionOptionsImpl& RegisterCustomOpsLibrary(const ORTCHAR_T* library_name, const CustomOpConfigs& custom_op_configs = {}); 0940 0941 SessionOptionsImpl& RegisterCustomOpsUsingFunction(const char* function_name); ///< Wraps OrtApi::RegisterCustomOpsUsingFunction 0942 0943 /// Wraps OrtApi::SessionOptionsAppendExecutionProvider_VitisAI 0944 SessionOptionsImpl& AppendExecutionProvider_VitisAI(const std::unordered_map<std::string, std::string>& provider_options = {}); 0945 }; 0946 } // namespace detail 0947 0948 using UnownedSessionOptions = detail::SessionOptionsImpl<detail::Unowned<OrtSessionOptions>>; 0949 using ConstSessionOptions = detail::ConstSessionOptionsImpl<detail::Unowned<const OrtSessionOptions>>; 0950 0951 /** \brief Wrapper around ::OrtSessionOptions 0952 * 0953 */ 0954 struct SessionOptions : detail::SessionOptionsImpl<OrtSessionOptions> { 0955 explicit SessionOptions(std::nullptr_t) {} ///< Create an empty SessionOptions object, must be assigned a valid one to be used 0956 SessionOptions(); ///< Wraps OrtApi::CreateSessionOptions 0957 explicit SessionOptions(OrtSessionOptions* p) : SessionOptionsImpl<OrtSessionOptions>{p} {} ///< Used for interop with the C API 0958 UnownedSessionOptions GetUnowned() const { return UnownedSessionOptions{this->p_}; } ///< Returns a non-owning view over the same underlying pointer; this object must outlive it 0959 ConstSessionOptions GetConst() const { return ConstSessionOptions{this->p_}; } ///< Returns a non-owning, const-only view over the same underlying pointer; this object must outlive it 0960 }; 0961 0962 /** \brief Wrapper around ::OrtModelMetadata 0963 * 0964 */ 0965 struct ModelMetadata : detail::Base<OrtModelMetadata> { 0966 explicit ModelMetadata(std::nullptr_t) {} ///< Create an empty ModelMetadata object, must be assigned a valid one to be used 0967 explicit ModelMetadata(OrtModelMetadata* p) : Base<OrtModelMetadata>{p} {} ///< Used for interop with the C API 0968 0969 /** \brief Returns a copy of the producer name. 0970 * 0971 * \param allocator to allocate memory for the copy of the name returned 0972 * \return an instance of smart pointer that would deallocate the buffer when out of scope.
0973 * The OrtAllocator instances must be valid at the point of memory release. 0974 */ 0975 AllocatedStringPtr GetProducerNameAllocated(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetProducerName 0976 0977 /** \brief Returns a copy of the graph name. 0978 * 0979 * \param allocator to allocate memory for the copy of the name returned 0980 * \return an instance of smart pointer that would deallocate the buffer when out of scope. 0981 * The OrtAllocator instances must be valid at the point of memory release. 0982 */ 0983 AllocatedStringPtr GetGraphNameAllocated(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetGraphName 0984 0985 /** \brief Returns a copy of the domain name. 0986 * 0987 * \param allocator to allocate memory for the copy of the name returned 0988 * \return an instance of smart pointer that would deallocate the buffer when out of scope. 0989 * The OrtAllocator instances must be valid at the point of memory release. 0990 */ 0991 AllocatedStringPtr GetDomainAllocated(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetDomain 0992 0993 /** \brief Returns a copy of the description. 0994 * 0995 * \param allocator to allocate memory for the copy of the string returned 0996 * \return an instance of smart pointer that would deallocate the buffer when out of scope. 0997 * The OrtAllocator instances must be valid at the point of memory release. 0998 */ 0999 AllocatedStringPtr GetDescriptionAllocated(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetDescription 1000 1001 /** \brief Returns a copy of the graph description. 1002 * 1003 * \param allocator to allocate memory for the copy of the string returned 1004 * \return an instance of smart pointer that would deallocate the buffer when out of scope. 1005 * The OrtAllocator instances must be valid at the point of memory release.
1006 */ 1007 AllocatedStringPtr GetGraphDescriptionAllocated(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetGraphDescription 1008 1009 /** \brief Returns a vector of copies of the custom metadata keys. 1010 * 1011 * \param allocator to allocate memory for the copy of the string returned 1012 * \return a std::vector of smart pointers that would deallocate the buffers when out of scope. 1013 * The OrtAllocator instance must be valid at the point of memory release. 1014 */ 1015 std::vector<AllocatedStringPtr> GetCustomMetadataMapKeysAllocated(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetCustomMetadataMapKeys 1016 1017 /** \brief Looks up a value by a key in the Custom Metadata map 1018 * 1019 * \param key zero terminated string key to look up 1020 * \param allocator to allocate memory for the copy of the string returned 1021 * \return an instance of smart pointer that would deallocate the buffer when out of scope. 1022 * May be nullptr if key is not found. 1023 * 1024 * The OrtAllocator instances must be valid at the point of memory release.
1025 */ 1026 AllocatedStringPtr LookupCustomMetadataMapAllocated(const char* key, OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataLookupCustomMetadataMap 1027 1028 int64_t GetVersion() const; ///< Wraps OrtApi::ModelMetadataGetVersion 1029 }; 1030 1031 struct IoBinding; 1032 1033 namespace detail { 1034 1035 // We separate const-only methods because passing a const ptr to non-const methods 1036 // is only discovered when the inline methods are compiled, which is counter-intuitive. 1037 template <typename T> 1038 struct ConstSessionImpl : Base<T> { 1039 using B = Base<T>; 1040 using B::B; 1041 1042 size_t GetInputCount() const; ///< Returns the number of model inputs 1043 size_t GetOutputCount() const; ///< Returns the number of model outputs 1044 size_t GetOverridableInitializerCount() const; ///< Returns the number of inputs that have defaults that can be overridden 1045 1046 /** \brief Returns a copy of the input name at the specified index. 1047 * 1048 * \param index must be less than the value returned by GetInputCount() 1049 * \param allocator to allocate memory for the copy of the name returned 1050 * \return an instance of smart pointer that would deallocate the buffer when out of scope. 1051 * The OrtAllocator instances must be valid at the point of memory release. 1052 */ 1053 AllocatedStringPtr GetInputNameAllocated(size_t index, OrtAllocator* allocator) const; 1054 1055 /** \brief Returns a copy of the output name at the specified index. 1056 * 1057 * \param index must be less than the value returned by GetOutputCount() 1058 * \param allocator to allocate memory for the copy of the name returned 1059 * \return an instance of smart pointer that would deallocate the buffer when out of scope. 1060 * The OrtAllocator instances must be valid at the point of memory release.
1061 */ 1062 AllocatedStringPtr GetOutputNameAllocated(size_t index, OrtAllocator* allocator) const; 1063 1064 /** \brief Returns a copy of the overridable initializer name at the specified index. 1065 * 1066 * \param index must be less than the value returned by GetOverridableInitializerCount() 1067 * \param allocator to allocate memory for the copy of the name returned 1068 * \return an instance of smart pointer that would deallocate the buffer when out of scope. 1069 * The OrtAllocator instances must be valid at the point of memory release. 1070 */ 1071 AllocatedStringPtr GetOverridableInitializerNameAllocated(size_t index, OrtAllocator* allocator) const; ///< Wraps OrtApi::SessionGetOverridableInitializerName 1072 1073 uint64_t GetProfilingStartTimeNs() const; ///< Wraps OrtApi::SessionGetProfilingStartTimeNs 1074 ModelMetadata GetModelMetadata() const; ///< Wraps OrtApi::SessionGetModelMetadata 1075 1076 TypeInfo GetInputTypeInfo(size_t index) const; ///< Wraps OrtApi::SessionGetInputTypeInfo 1077 TypeInfo GetOutputTypeInfo(size_t index) const; ///< Wraps OrtApi::SessionGetOutputTypeInfo 1078 TypeInfo GetOverridableInitializerTypeInfo(size_t index) const; ///< Wraps OrtApi::SessionGetOverridableInitializerTypeInfo 1079 }; 1080 1081 template <typename T> 1082 struct SessionImpl : ConstSessionImpl<T> { 1083 using B = ConstSessionImpl<T>; 1084 using B::B; 1085 1086 /** \brief Run the model returning results in an Ort allocated vector. 1087 * 1088 * Wraps OrtApi::Run 1089 * 1090 * The caller provides a list of inputs and a list of the desired outputs to return. 1091 * 1092 * See the output logs for more information on warnings/errors that occur while processing the model. 1093 * Common errors are..
(TODO) 1094 * 1095 * \param[in] run_options 1096 * \param[in] input_names Array of null terminated strings of length input_count that is the list of input names 1097 * \param[in] input_values Array of Value objects of length input_count that is the list of input values 1098 * \param[in] input_count Number of inputs (the size of the input_names & input_values arrays) 1099 * \param[in] output_names Array of C style strings of length output_count that is the list of output names 1100 * \param[in] output_count Number of outputs (the size of the output_names array) 1101 * \return A std::vector of Value objects that directly maps to the output_names array (eg. output_name[0] is the first entry of the returned vector) 1102 */ 1103 std::vector<Value> Run(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count, 1104 const char* const* output_names, size_t output_count); 1105 1106 /** \brief Run the model returning results in user provided outputs 1107 * Same as Run(const RunOptions&, const char* const*, const Value*, size_t, const char* const*, size_t), * except that the results are returned through the caller-supplied output_values array of output_count entries. 1108 */ 1109 void Run(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count, 1110 const char* const* output_names, Value* output_values, size_t output_count); 1111 1112 void Run(const RunOptions& run_options, const IoBinding&); ///< Wraps OrtApi::RunWithBinding 1113 1114 /** \brief Run the model asynchronously in a thread owned by the intra op thread pool 1115 * 1116 * Wraps OrtApi::RunAsync 1117 * 1118 * \param[in] run_options 1119 * \param[in] input_names Array of null terminated UTF8 encoded strings of the input names 1120 * \param[in] input_values Array of Value objects of length input_count 1121 * \param[in] input_count Number of elements in the input_names and input_values arrays 1122 * \param[in] output_names Array of null terminated UTF8 encoded strings of the output names 1123 * \param[out] output_values Array of provided
Values to be filled with outputs. 1124 * On calling RunAsync, output_values[i] could either be initialized by a null pointer or a preallocated OrtValue*. 1125 * Later, on invoking the callback, each output_values[i] of null will be filled with an OrtValue* allocated by onnxruntime. 1126 * Then, an OrtValue** pointer will be casted from output_values, and pass to the callback. 1127 * NOTE: it is customer's duty to finally release output_values and each of its member, 1128 * regardless of whether the member (Ort::Value) is allocated by onnxruntime or preallocated by the customer. 1129 * \param[in] output_count Number of elements in the output_names and outputs array 1130 * \param[in] callback Callback function on model run completion 1131 * \param[in] user_data User data that pass back to the callback 1132 */ 1133 void RunAsync(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count, 1134 const char* const* output_names, Value* output_values, size_t output_count, RunAsyncCallbackFn callback, void* user_data); 1135 1136 /** \brief End profiling and return a copy of the profiling file name. 1137 * 1138 * \param allocator to allocate memory for the copy of the string returned 1139 * \return an instance of smart pointer that would deallocate the buffer when out of scope. 1140 * The OrtAllocator instances must be valid at the point of memory release.
1141 */ 1142 AllocatedStringPtr EndProfilingAllocated(OrtAllocator* allocator); ///< Wraps OrtApi::SessionEndProfiling 1143 1144 /** \brief Set DynamicOptions for EPs (Execution Providers) 1145 * 1146 * Wraps OrtApi::SetEpDynamicOptions 1147 * 1148 * Valid options can be found in `include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h` 1149 * Look for `kOrtEpDynamicOptions` 1150 * 1151 * \param[in] keys Array of null terminated UTF8 encoded strings of EP dynamic option keys 1152 * \param[in] values Array of null terminated UTF8 encoded strings of EP dynamic option values 1153 * \param[in] kv_len Number of elements in the keys and values arrays 1154 */ 1155 void SetEpDynamicOptions(const char* const* keys, const char* const* values, size_t kv_len); 1156 }; 1157 1158 } // namespace detail 1159 1160 using ConstSession = detail::ConstSessionImpl<detail::Unowned<const OrtSession>>; 1161 using UnownedSession = detail::SessionImpl<detail::Unowned<OrtSession>>; 1162 1163 /** \brief Wrapper around ::OrtSession 1164 * 1165 */ 1166 struct Session : detail::SessionImpl<OrtSession> { 1167 explicit Session(std::nullptr_t) {} ///< Create an empty Session object, must be assigned a valid one to be used 1168 Session(const Env& env, const ORTCHAR_T* model_path, const SessionOptions& options); ///< Wraps OrtApi::CreateSession 1169 Session(const Env& env, const ORTCHAR_T* model_path, const SessionOptions& options, 1170 OrtPrepackedWeightsContainer* prepacked_weights_container); ///< Wraps OrtApi::CreateSessionWithPrepackedWeightsContainer 1171 Session(const Env& env, const void* model_data, size_t model_data_length, const SessionOptions& options); ///< Wraps OrtApi::CreateSessionFromArray 1172 Session(const Env& env, const void* model_data, size_t model_data_length, const SessionOptions& options, 1173 OrtPrepackedWeightsContainer* prepacked_weights_container); ///< Wraps OrtApi::CreateSessionFromArrayWithPrepackedWeightsContainer 1174 1175 ConstSession GetConst()
const { return ConstSession{this->p_}; } 1176 UnownedSession GetUnowned() const { return UnownedSession{this->p_}; } 1177 }; 1178 1179 namespace detail { 1180 template <typename T> 1181 struct MemoryInfoImpl : Base<T> { 1182 using B = Base<T>; 1183 using B::B; 1184 1185 std::string GetAllocatorName() const; 1186 OrtAllocatorType GetAllocatorType() const; 1187 int GetDeviceId() const; 1188 OrtMemoryInfoDeviceType GetDeviceType() const; 1189 OrtMemType GetMemoryType() const; 1190 1191 template <typename U> 1192 bool operator==(const MemoryInfoImpl<U>& o) const; 1193 }; 1194 } // namespace detail 1195 1196 // Const object holder that does not own the underlying object 1197 using ConstMemoryInfo = detail::MemoryInfoImpl<detail::Unowned<const OrtMemoryInfo>>; 1198 1199 /** \brief Wrapper around ::OrtMemoryInfo 1200 * 1201 */ 1202 struct MemoryInfo : detail::MemoryInfoImpl<OrtMemoryInfo> { 1203 static MemoryInfo CreateCpu(OrtAllocatorType type, OrtMemType mem_type1); ///< NOTE(review): presumably wraps OrtApi::CreateCpuMemoryInfo - confirm against the inline definition 1204 explicit MemoryInfo(std::nullptr_t) {} ///< No instance is created 1205 explicit MemoryInfo(OrtMemoryInfo* p) : MemoryInfoImpl<OrtMemoryInfo>{p} {} ///< Take ownership of a pointer created by the C API 1206 MemoryInfo(const char* name, OrtAllocatorType type, int id, OrtMemType mem_type); ///< NOTE(review): presumably wraps OrtApi::CreateMemoryInfo - confirm against the inline definition 1207 ConstMemoryInfo GetConst() const { return ConstMemoryInfo{this->p_}; } ///< Returns a non-owning, const-only view over the same underlying pointer; this object must outlive it 1208 }; 1209 1210 namespace detail { 1211 template <typename T> 1212 struct TensorTypeAndShapeInfoImpl : Base<T> { 1213 using B = Base<T>; 1214 using B::B; 1215 1216 ONNXTensorElementDataType GetElementType() const; ///< Wraps OrtApi::GetTensorElementType 1217 size_t GetElementCount() const; ///< Wraps OrtApi::GetTensorShapeElementCount 1218 1219 size_t GetDimensionsCount() const; ///< Wraps OrtApi::GetDimensionsCount 1220 1221 /** \deprecated use GetShape() returning std::vector 1222 * [[deprecated]] 1223 * This interface is unsafe to use 1224 */ 1225 [[deprecated("use GetShape()")]] void GetDimensions(int64_t* values, size_t values_count) const;
///< Wraps OrtApi::GetDimensions 1226 1227 void GetSymbolicDimensions(const char** values, size_t values_count) const; ///< Wraps OrtApi::GetSymbolicDimensions 1228 1229 std::vector<int64_t> GetShape() const; ///< Uses GetDimensionsCount & GetDimensions to return a std::vector of the shape 1230 }; 1231 1232 } // namespace detail 1233 1234 using ConstTensorTypeAndShapeInfo = detail::TensorTypeAndShapeInfoImpl<detail::Unowned<const OrtTensorTypeAndShapeInfo>>; 1235 1236 /** \brief Wrapper around ::OrtTensorTypeAndShapeInfo 1237 * 1238 */ 1239 struct TensorTypeAndShapeInfo : detail::TensorTypeAndShapeInfoImpl<OrtTensorTypeAndShapeInfo> { 1240 explicit TensorTypeAndShapeInfo(std::nullptr_t) {} ///< Create an empty TensorTypeAndShapeInfo object, must be assigned a valid one to be used 1241 explicit TensorTypeAndShapeInfo(OrtTensorTypeAndShapeInfo* p) : TensorTypeAndShapeInfoImpl{p} {} ///< Used for interop with the C API 1242 ConstTensorTypeAndShapeInfo GetConst() const { return ConstTensorTypeAndShapeInfo{this->p_}; } 1243 }; 1244 1245 namespace detail { 1246 template <typename T> 1247 struct SequenceTypeInfoImpl : Base<T> { 1248 using B = Base<T>; 1249 using B::B; 1250 TypeInfo GetSequenceElementType() const; ///< Wraps OrtApi::GetSequenceElementType 1251 }; 1252 1253 } // namespace detail 1254 1255 using ConstSequenceTypeInfo = detail::SequenceTypeInfoImpl<detail::Unowned<const OrtSequenceTypeInfo>>; 1256 1257 /** \brief Wrapper around ::OrtSequenceTypeInfo 1258 * 1259 */ 1260 struct SequenceTypeInfo : detail::SequenceTypeInfoImpl<OrtSequenceTypeInfo> { 1261 explicit SequenceTypeInfo(std::nullptr_t) {} ///< Create an empty SequenceTypeInfo object, must be assigned a valid one to be used 1262 explicit SequenceTypeInfo(OrtSequenceTypeInfo* p) : SequenceTypeInfoImpl<OrtSequenceTypeInfo>{p} {} ///< Used for interop with the C API 1263 ConstSequenceTypeInfo GetConst() const { return ConstSequenceTypeInfo{this->p_}; } 1264 }; 1265 1266 namespace detail { 1267 template 
<typename T> 1268 struct OptionalTypeInfoImpl : Base<T> { 1269 using B = Base<T>; 1270 using B::B; 1271 TypeInfo GetOptionalElementType() const; ///< Wraps OrtApi::CastOptionalTypeToContainedTypeInfo 1272 }; 1273 1274 } // namespace detail 1275 1276 // This is always owned by the TypeInfo and can only be obtained from it. 1277 using ConstOptionalTypeInfo = detail::OptionalTypeInfoImpl<detail::Unowned<const OrtOptionalTypeInfo>>; 1278 1279 namespace detail { 1280 template <typename T> 1281 struct MapTypeInfoImpl : detail::Base<T> { 1282 using B = Base<T>; 1283 using B::B; 1284 ONNXTensorElementDataType GetMapKeyType() const; ///< Wraps OrtApi::GetMapKeyType 1285 TypeInfo GetMapValueType() const; ///< Wraps OrtApi::GetMapValueType 1286 }; 1287 1288 } // namespace detail 1289 1290 using ConstMapTypeInfo = detail::MapTypeInfoImpl<detail::Unowned<const OrtMapTypeInfo>>; 1291 1292 /** \brief Wrapper around ::OrtMapTypeInfo 1293 * 1294 */ 1295 struct MapTypeInfo : detail::MapTypeInfoImpl<OrtMapTypeInfo> { 1296 explicit MapTypeInfo(std::nullptr_t) {} ///< Create an empty MapTypeInfo object, must be assigned a valid one to be used 1297 explicit MapTypeInfo(OrtMapTypeInfo* p) : MapTypeInfoImpl<OrtMapTypeInfo>{p} {} ///< Used for interop with the C API 1298 ConstMapTypeInfo GetConst() const { return ConstMapTypeInfo{this->p_}; } 1299 }; 1300 1301 namespace detail { 1302 template <typename T> 1303 struct TypeInfoImpl : detail::Base<T> { 1304 using B = Base<T>; 1305 using B::B; 1306 1307 ConstTensorTypeAndShapeInfo GetTensorTypeAndShapeInfo() const; ///< Wraps OrtApi::CastTypeInfoToTensorInfo 1308 ConstSequenceTypeInfo GetSequenceTypeInfo() const; ///< Wraps OrtApi::CastTypeInfoToSequenceTypeInfo 1309 ConstMapTypeInfo GetMapTypeInfo() const; ///< Wraps OrtApi::CastTypeInfoToMapTypeInfo 1310 ConstOptionalTypeInfo GetOptionalTypeInfo() const; ///< wraps OrtApi::CastTypeInfoToOptionalTypeInfo 1311 1312 ONNXType GetONNXType() const; 1313 }; 1314 } // namespace detail 1315 1316 /// 
<summary> 1317 /// Contains a constant, unowned OrtTypeInfo that can be copied and passed around by value. 1318 /// Provides access to const OrtTypeInfo APIs. 1319 /// </summary> 1320 using ConstTypeInfo = detail::TypeInfoImpl<detail::Unowned<const OrtTypeInfo>>; 1321 1322 /// <summary> 1323 /// Type information that may contain either TensorTypeAndShapeInfo or 1324 /// the information about contained sequence or map depending on the ONNXType. 1325 /// </summary> 1326 struct TypeInfo : detail::TypeInfoImpl<OrtTypeInfo> { 1327 explicit TypeInfo(std::nullptr_t) {} ///< Create an empty TypeInfo object, must be assigned a valid one to be used 1328 explicit TypeInfo(OrtTypeInfo* p) : TypeInfoImpl<OrtTypeInfo>{p} {} ///< C API Interop 1329 1330 ConstTypeInfo GetConst() const { return ConstTypeInfo{this->p_}; } 1331 }; 1332 1333 namespace detail { 1334 // This structure is used to feed sparse tensor values 1335 // information for use with FillSparseTensor<Format>() API 1336 // if the data type for the sparse tensor values is numeric 1337 // use data.p_data, otherwise, use data.str pointer to feed 1338 // values. data.str is an array of const char* that are zero terminated. 1339 // number of strings in the array must match shape size. 1340 // For fully sparse tensors use shape {0} and set p_data/str 1341 // to nullptr. 
1342 struct OrtSparseValuesParam { 1343 const int64_t* values_shape; 1344 size_t values_shape_len; 1345 union { 1346 const void* p_data; 1347 const char** str; 1348 } data; 1349 }; 1350 1351 // Provides a way to pass shape in a single 1352 // argument 1353 struct Shape { 1354 const int64_t* shape; 1355 size_t shape_len; 1356 }; 1357 1358 template <typename T> 1359 struct ConstValueImpl : Base<T> { 1360 using B = Base<T>; 1361 using B::B; 1362 1363 /// <summary> 1364 /// Obtains a pointer to a user defined data for experimental purposes 1365 /// </summary> 1366 template <typename R> 1367 void GetOpaqueData(const char* domain, const char* type_name, R&) const; ///< Wraps OrtApi::GetOpaqueValue 1368 1369 bool IsTensor() const; ///< Returns true if Value is a tensor, false for other types like map/sequence/etc 1370 bool HasValue() const; /// < Return true if OrtValue contains data and returns false if the OrtValue is a None 1371 1372 size_t GetCount() const; // If a non tensor, returns 2 for map and N for sequence, where N is the number of elements 1373 Value GetValue(int index, OrtAllocator* allocator) const; 1374 1375 /// <summary> 1376 /// This API returns a full length of string data contained within either a tensor or a sparse Tensor. 1377 /// For sparse tensor it returns a full length of stored non-empty strings (values). The API is useful 1378 /// for allocating necessary memory and calling GetStringTensorContent(). 1379 /// </summary> 1380 /// <returns>total length of UTF-8 encoded bytes contained. No zero terminators counted.</returns> 1381 size_t GetStringTensorDataLength() const; 1382 1383 /// <summary> 1384 /// The API copies all of the UTF-8 encoded string data contained within a tensor or a sparse tensor 1385 /// into a supplied buffer. Use GetStringTensorDataLength() to find out the length of the buffer to allocate. 1386 /// The user must also allocate offsets buffer with the number of entries equal to that of the contained 1387 /// strings. 
1388 /// 1389 /// Strings are always assumed to be on CPU, no X-device copy. 1390 /// </summary> 1391 /// <param name="buffer">user allocated buffer</param> 1392 /// <param name="buffer_length">length in bytes of the allocated buffer</param> 1393 /// <param name="offsets">a pointer to the offsets user allocated buffer</param> 1394 /// <param name="offsets_count">count of offsets, must be equal to the number of strings contained. 1395 /// that can be obtained from the shape of the tensor or from GetSparseTensorValuesTypeAndShapeInfo() 1396 /// for sparse tensors</param> 1397 void GetStringTensorContent(void* buffer, size_t buffer_length, size_t* offsets, size_t offsets_count) const; 1398 1399 /// <summary> 1400 /// Returns a const typed pointer to the tensor contained data. 1401 /// No type checking is performed, the caller must ensure the type matches the tensor type. 1402 /// </summary> 1403 /// <typeparam name="T"></typeparam> 1404 /// <returns>const pointer to data, no copies made</returns> 1405 template <typename R> 1406 const R* GetTensorData() const; ///< Wraps OrtApi::GetTensorMutableData /// <summary> 1407 1408 /// <summary> 1409 /// Returns a non-typed pointer to a tensor contained data. 1410 /// </summary> 1411 /// <returns>const pointer to data, no copies made</returns> 1412 const void* GetTensorRawData() const; 1413 1414 /// <summary> 1415 /// The API returns type information for data contained in a tensor. For sparse 1416 /// tensors it returns type information for contained non-zero values. 1417 /// It returns dense shape for sparse tensors. 1418 /// </summary> 1419 /// <returns>TypeInfo</returns> 1420 TypeInfo GetTypeInfo() const; 1421 1422 /// <summary> 1423 /// The API returns type information for data contained in a tensor. For sparse 1424 /// tensors it returns type information for contained non-zero values. 1425 /// It returns dense shape for sparse tensors. 
1426 /// </summary> 1427 /// <returns>TensorTypeAndShapeInfo</returns> 1428 TensorTypeAndShapeInfo GetTensorTypeAndShapeInfo() const; 1429 1430 /// <summary> 1431 /// This API returns information about the memory allocation used to hold data. 1432 /// </summary> 1433 /// <returns>Non owning instance of MemoryInfo</returns> 1434 ConstMemoryInfo GetTensorMemoryInfo() const; 1435 1436 /// <summary> 1437 /// The API copies UTF-8 encoded bytes for the requested string element 1438 /// contained within a tensor or a sparse tensor into a provided buffer. 1439 /// Use GetStringTensorElementLength() to obtain the length of the buffer to allocate. 1440 /// </summary> 1441 /// <param name="buffer_length"></param> 1442 /// <param name="element_index"></param> 1443 /// <param name="buffer"></param> 1444 void GetStringTensorElement(size_t buffer_length, size_t element_index, void* buffer) const; 1445 1446 /// <summary> 1447 /// Returns string tensor UTF-8 encoded string element. 1448 /// Use of this API is recommended over GetStringTensorElement() that takes void* buffer pointer. 1449 /// </summary> 1450 /// <param name="element_index"></param> 1451 /// <returns>std::string</returns> 1452 std::string GetStringTensorElement(size_t element_index) const; 1453 1454 /// <summary> 1455 /// The API returns a byte length of UTF-8 encoded string element 1456 /// contained in either a tensor or a spare tensor values. 1457 /// </summary> 1458 /// <param name="element_index"></param> 1459 /// <returns>byte length for the specified string element</returns> 1460 size_t GetStringTensorElementLength(size_t element_index) const; 1461 1462 #if !defined(DISABLE_SPARSE_TENSORS) 1463 /// <summary> 1464 /// The API returns the sparse data format this OrtValue holds in a sparse tensor. 1465 /// If the sparse tensor was not fully constructed, i.e. Use*() or Fill*() API were not used 1466 /// the value returned is ORT_SPARSE_UNDEFINED. 
1467 /// </summary> 1468 /// <returns>Format enum</returns> 1469 OrtSparseFormat GetSparseFormat() const; 1470 1471 /// <summary> 1472 /// The API returns type and shape information for stored non-zero values of the 1473 /// sparse tensor. Use GetSparseTensorValues() to obtain values buffer pointer. 1474 /// </summary> 1475 /// <returns>TensorTypeAndShapeInfo values information</returns> 1476 TensorTypeAndShapeInfo GetSparseTensorValuesTypeAndShapeInfo() const; 1477 1478 /// <summary> 1479 /// The API returns type and shape information for the specified indices. Each supported 1480 /// indices have their own enum values even if a give format has more than one kind of indices. 1481 /// Use GetSparseTensorIndicesData() to obtain pointer to indices buffer. 1482 /// </summary> 1483 /// <param name="format">enum requested</param> 1484 /// <returns>type and shape information</returns> 1485 TensorTypeAndShapeInfo GetSparseTensorIndicesTypeShapeInfo(OrtSparseIndicesFormat format) const; 1486 1487 /// <summary> 1488 /// The API retrieves a pointer to the internal indices buffer. The API merely performs 1489 /// a convenience data type casting on the return type pointer. Make sure you are requesting 1490 /// the right type, use GetSparseTensorIndicesTypeShapeInfo(); 1491 /// </summary> 1492 /// <typeparam name="T">type to cast to</typeparam> 1493 /// <param name="indices_format">requested indices kind</param> 1494 /// <param name="num_indices">number of indices entries</param> 1495 /// <returns>Pinter to the internal sparse tensor buffer containing indices. 
Do not free this pointer.</returns> 1496 template <typename R> 1497 const R* GetSparseTensorIndicesData(OrtSparseIndicesFormat indices_format, size_t& num_indices) const; 1498 1499 /// <summary> 1500 /// Returns true if the OrtValue contains a sparse tensor 1501 /// </summary> 1502 /// <returns></returns> 1503 bool IsSparseTensor() const; 1504 1505 /// <summary> 1506 /// The API returns a pointer to an internal buffer of the sparse tensor 1507 /// containing non-zero values. The API merely does casting. Make sure you 1508 /// are requesting the right data type by calling GetSparseTensorValuesTypeAndShapeInfo() 1509 /// first. 1510 /// </summary> 1511 /// <typeparam name="T">numeric data types only. Use GetStringTensor*() to retrieve strings.</typeparam> 1512 /// <returns>a pointer to the internal values buffer. Do not free this pointer.</returns> 1513 template <typename R> 1514 const R* GetSparseTensorValues() const; 1515 1516 #endif 1517 }; 1518 1519 template <typename T> 1520 struct ValueImpl : ConstValueImpl<T> { 1521 using B = ConstValueImpl<T>; 1522 using B::B; 1523 1524 /// <summary> 1525 /// Returns a non-const typed pointer to an OrtValue/Tensor contained buffer 1526 /// No type checking is performed, the caller must ensure the type matches the tensor type. 1527 /// </summary> 1528 /// <returns>non-const pointer to data, no copies made</returns> 1529 template <typename R> 1530 R* GetTensorMutableData(); 1531 1532 /// <summary> 1533 /// Returns a non-typed non-const pointer to a tensor contained data. 1534 /// </summary> 1535 /// <returns>pointer to data, no copies made</returns> 1536 void* GetTensorMutableRawData(); 1537 1538 /// <summary> 1539 // Obtain a reference to an element of data at the location specified 1540 /// by the vector of dims. 
1541 /// </summary> 1542 /// <typeparam name="R"></typeparam> 1543 /// <param name="location">[in] expressed by a vecotr of dimensions offsets</param> 1544 /// <returns></returns> 1545 template <typename R> 1546 R& At(const std::vector<int64_t>& location); 1547 1548 /// <summary> 1549 /// Set all strings at once in a string tensor 1550 /// </summary> 1551 /// <param name="s">[in] An array of strings. Each string in this array must be null terminated.</param> 1552 /// <param name="s_len">[in] Count of strings in s (Must match the size of \p value's tensor shape)</param> 1553 void FillStringTensor(const char* const* s, size_t s_len); 1554 1555 /// <summary> 1556 /// Set a single string in a string tensor 1557 /// </summary> 1558 /// <param name="s">[in] A null terminated UTF-8 encoded string</param> 1559 /// <param name="index">[in] Index of the string in the tensor to set</param> 1560 void FillStringTensorElement(const char* s, size_t index); 1561 1562 /// <summary> 1563 /// Allocate if necessary and obtain a pointer to a UTF-8 1564 /// encoded string element buffer indexed by the flat element index, 1565 /// of the specified length. 1566 /// 1567 /// This API is for advanced usage. It avoids a need to construct 1568 /// an auxiliary array of string pointers, and allows to write data directly 1569 /// (do not zero terminate). 1570 /// </summary> 1571 /// <param name="index"></param> 1572 /// <param name="buffer_length"></param> 1573 /// <returns>a pointer to a writable buffer</returns> 1574 char* GetResizedStringTensorElementBuffer(size_t index, size_t buffer_length); 1575 1576 #if !defined(DISABLE_SPARSE_TENSORS) 1577 /// <summary> 1578 /// Supplies COO format specific indices and marks the contained sparse tensor as being a COO format tensor. 1579 /// Values are supplied with a CreateSparseTensor() API. The supplied indices are not copied and the user 1580 /// allocated buffers lifespan must eclipse that of the OrtValue. 
1581 /// The location of the indices is assumed to be the same as specified by OrtMemoryInfo argument at the creation time. 1582 /// </summary> 1583 /// <param name="indices_data">pointer to the user allocated buffer with indices. Use nullptr for fully sparse tensors.</param> 1584 /// <param name="indices_num">number of indices entries. Use 0 for fully sparse tensors</param> 1585 void UseCooIndices(int64_t* indices_data, size_t indices_num); 1586 1587 /// <summary> 1588 /// Supplies CSR format specific indices and marks the contained sparse tensor as being a CSR format tensor. 1589 /// Values are supplied with a CreateSparseTensor() API. The supplied indices are not copied and the user 1590 /// allocated buffers lifespan must eclipse that of the OrtValue. 1591 /// The location of the indices is assumed to be the same as specified by OrtMemoryInfo argument at the creation time. 1592 /// </summary> 1593 /// <param name="inner_data">pointer to the user allocated buffer with inner indices or nullptr for fully sparse tensors</param> 1594 /// <param name="inner_num">number of csr inner indices or 0 for fully sparse tensors</param> 1595 /// <param name="outer_data">pointer to the user allocated buffer with outer indices or nullptr for fully sparse tensors</param> 1596 /// <param name="outer_num">number of csr outer indices or 0 for fully sparse tensors</param> 1597 void UseCsrIndices(int64_t* inner_data, size_t inner_num, int64_t* outer_data, size_t outer_num); 1598 1599 /// <summary> 1600 /// Supplies BlockSparse format specific indices and marks the contained sparse tensor as being a BlockSparse format tensor. 1601 /// Values are supplied with a CreateSparseTensor() API. The supplied indices are not copied and the user 1602 /// allocated buffers lifespan must eclipse that of the OrtValue. 1603 /// The location of the indices is assumed to be the same as specified by OrtMemoryInfo argument at the creation time. 
1604 /// </summary> 1605 /// <param name="indices_shape">indices shape or a {0} for fully sparse</param> 1606 /// <param name="indices_data">user allocated buffer with indices or nullptr for fully spare tensors</param> 1607 void UseBlockSparseIndices(const Shape& indices_shape, int32_t* indices_data); 1608 1609 /// <summary> 1610 /// The API will allocate memory using the allocator instance supplied to the CreateSparseTensor() API 1611 /// and copy the values and COO indices into it. If data_mem_info specifies that the data is located 1612 /// at difference device than the allocator, a X-device copy will be performed if possible. 1613 /// </summary> 1614 /// <param name="data_mem_info">specified buffer memory description</param> 1615 /// <param name="values_param">values buffer information.</param> 1616 /// <param name="indices_data">coo indices buffer or nullptr for fully sparse data</param> 1617 /// <param name="indices_num">number of COO indices or 0 for fully sparse data</param> 1618 void FillSparseTensorCoo(const OrtMemoryInfo* data_mem_info, const OrtSparseValuesParam& values_param, 1619 const int64_t* indices_data, size_t indices_num); 1620 1621 /// <summary> 1622 /// The API will allocate memory using the allocator instance supplied to the CreateSparseTensor() API 1623 /// and copy the values and CSR indices into it. If data_mem_info specifies that the data is located 1624 /// at difference device than the allocator, a X-device copy will be performed if possible. 
1625 /// </summary> 1626 /// <param name="data_mem_info">specified buffer memory description</param> 1627 /// <param name="values">values buffer information</param> 1628 /// <param name="inner_indices_data">csr inner indices pointer or nullptr for fully sparse tensors</param> 1629 /// <param name="inner_indices_num">number of csr inner indices or 0 for fully sparse tensors</param> 1630 /// <param name="outer_indices_data">pointer to csr indices data or nullptr for fully sparse tensors</param> 1631 /// <param name="outer_indices_num">number of csr outer indices or 0</param> 1632 void FillSparseTensorCsr(const OrtMemoryInfo* data_mem_info, 1633 const OrtSparseValuesParam& values, 1634 const int64_t* inner_indices_data, size_t inner_indices_num, 1635 const int64_t* outer_indices_data, size_t outer_indices_num); 1636 1637 /// <summary> 1638 /// The API will allocate memory using the allocator instance supplied to the CreateSparseTensor() API 1639 /// and copy the values and BlockSparse indices into it. If data_mem_info specifies that the data is located 1640 /// at difference device than the allocator, a X-device copy will be performed if possible. 1641 /// </summary> 1642 /// <param name="data_mem_info">specified buffer memory description</param> 1643 /// <param name="values">values buffer information</param> 1644 /// <param name="indices_shape">indices shape. 
use {0} for fully sparse tensors</param> 1645 /// <param name="indices_data">pointer to indices data or nullptr for fully sparse tensors</param> 1646 void FillSparseTensorBlockSparse(const OrtMemoryInfo* data_mem_info, 1647 const OrtSparseValuesParam& values, 1648 const Shape& indices_shape, 1649 const int32_t* indices_data); 1650 1651 #endif 1652 }; 1653 1654 } // namespace detail 1655 1656 using ConstValue = detail::ConstValueImpl<detail::Unowned<const OrtValue>>; 1657 using UnownedValue = detail::ValueImpl<detail::Unowned<OrtValue>>; 1658 1659 /** \brief Wrapper around ::OrtValue 1660 * 1661 */ 1662 struct Value : detail::ValueImpl<OrtValue> { 1663 using Base = detail::ValueImpl<OrtValue>; 1664 using OrtSparseValuesParam = detail::OrtSparseValuesParam; 1665 using Shape = detail::Shape; 1666 1667 explicit Value(std::nullptr_t) {} ///< Create an empty Value object, must be assigned a valid one to be used 1668 explicit Value(OrtValue* p) : Base{p} {} ///< Used for interop with the C API 1669 Value(Value&&) = default; 1670 Value& operator=(Value&&) = default; 1671 1672 ConstValue GetConst() const { return ConstValue{this->p_}; } 1673 UnownedValue GetUnowned() const { return UnownedValue{this->p_}; } 1674 1675 /** \brief Creates a tensor with a user supplied buffer. Wraps OrtApi::CreateTensorWithDataAsOrtValue. 1676 * \tparam T The numeric datatype. This API is not suitable for strings. 1677 * \param info Memory description of where the p_data buffer resides (CPU vs GPU etc). 1678 * \param p_data Pointer to the data buffer. 1679 * \param p_data_element_count The number of elements in the data buffer. 1680 * \param shape Pointer to the tensor shape dimensions. 1681 * \param shape_len The number of tensor shape dimensions. 1682 */ 1683 template <typename T> 1684 static Value CreateTensor(const OrtMemoryInfo* info, T* p_data, size_t p_data_element_count, const int64_t* shape, size_t shape_len); 1685 1686 /** \brief Creates a tensor with a user supplied buffer. 
Wraps OrtApi::CreateTensorWithDataAsOrtValue. 1687 * 1688 * \param info Memory description of where the p_data buffer resides (CPU vs GPU etc). 1689 * \param p_data Pointer to the data buffer. 1690 * \param p_data_byte_count The number of bytes in the data buffer. 1691 * \param shape Pointer to the tensor shape dimensions. 1692 * \param shape_len The number of tensor shape dimensions. 1693 * \param type The data type. 1694 */ 1695 static Value CreateTensor(const OrtMemoryInfo* info, void* p_data, size_t p_data_byte_count, const int64_t* shape, size_t shape_len, 1696 ONNXTensorElementDataType type); 1697 1698 /** \brief Creates an OrtValue with a tensor using a supplied OrtAllocator. Wraps OrtApi::CreateTensorAsOrtValue. 1699 * This overload will allocate the buffer for the tensor according to the supplied shape and data type. 1700 * The allocated buffer will be owned by the returned OrtValue and will be freed when the OrtValue is released. 1701 * The input data would need to be copied into the allocated buffer. 1702 * This API is not suitable for strings. 1703 * 1704 * \tparam T The numeric datatype. This API is not suitable for strings. 1705 * \param allocator The allocator to use. 1706 * \param shape Pointer to the tensor shape dimensions. 1707 * \param shape_len The number of tensor shape dimensions. 1708 */ 1709 template <typename T> 1710 static Value CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len); 1711 1712 /** \brief Creates an OrtValue with a tensor using the supplied OrtAllocator. 1713 * Wraps OrtApi::CreateTensorAsOrtValue. 1714 * The allocated buffer will be owned by the returned OrtValue and will be freed when the OrtValue is released. 1715 * The input data would need to be copied into the allocated buffer. 1716 * This API is not suitable for strings. 1717 * 1718 * \param allocator The allocator to use. 1719 * \param shape Pointer to the tensor shape dimensions. 1720 * \param shape_len The number of tensor shape dimensions. 
1721 * \param type The data type. 1722 */ 1723 static Value CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len, ONNXTensorElementDataType type); 1724 1725 /** \brief Creates an OrtValue with a Map Onnx type representation. 1726 * The API would ref-count the supplied OrtValues and they will be released 1727 * when the returned OrtValue is released. The caller may release keys and values after the call 1728 * returns. 1729 * 1730 * \param keys an OrtValue containing a tensor with primitive data type keys. 1731 * \param values an OrtValue that may contain a tensor. Ort currently supports only primitive data type values. 1732 */ 1733 static Value CreateMap(const Value& keys, const Value& values); ///< Wraps OrtApi::CreateValue 1734 1735 /** \brief Creates an OrtValue with a Sequence Onnx type representation. 1736 * The API would ref-count the supplied OrtValues and they will be released 1737 * when the returned OrtValue is released. The caller may release the values after the call 1738 * returns. 1739 * 1740 * \param values a vector of OrtValues that must have the same Onnx value type. 1741 */ 1742 static Value CreateSequence(const std::vector<Value>& values); ///< Wraps OrtApi::CreateValue 1743 1744 /** \brief Creates an OrtValue wrapping an Opaque type. 1745 * This is used for experimental support of non-tensor types. 1746 * 1747 * \tparam T - the type of the value. 1748 * \param domain - zero terminated utf-8 string. Domain of the type. 1749 * \param type_name - zero terminated utf-8 string. Name of the type. 1750 * \param value - the value to be wrapped. 1751 */ 1752 template <typename T> 1753 static Value CreateOpaque(const char* domain, const char* type_name, const T& value); ///< Wraps OrtApi::CreateOpaqueValue 1754 1755 #if !defined(DISABLE_SPARSE_TENSORS) 1756 /// <summary> 1757 /// This is a simple forwarding method to the other overload that helps deducing 1758 /// data type enum value from the type of the buffer. 
1759 /// </summary> 1760 /// <typeparam name="T">numeric datatype. This API is not suitable for strings.</typeparam> 1761 /// <param name="info">Memory description where the user buffers reside (CPU vs GPU etc)</param> 1762 /// <param name="p_data">pointer to the user supplied buffer, use nullptr for fully sparse tensors</param> 1763 /// <param name="dense_shape">a would be dense shape of the tensor</param> 1764 /// <param name="values_shape">non zero values shape. Use a single 0 shape for fully sparse tensors.</param> 1765 /// <returns></returns> 1766 template <typename T> 1767 static Value CreateSparseTensor(const OrtMemoryInfo* info, T* p_data, const Shape& dense_shape, 1768 const Shape& values_shape); 1769 1770 /// <summary> 1771 /// Creates an OrtValue instance containing SparseTensor. This constructs 1772 /// a sparse tensor that makes use of user allocated buffers. It does not make copies 1773 /// of the user provided data and does not modify it. The lifespan of user provided buffers should 1774 /// eclipse the life span of the resulting OrtValue. This call constructs an instance that only contain 1775 /// a pointer to non-zero values. To fully populate the sparse tensor call Use<Format>Indices() API below 1776 /// to supply a sparse format specific indices. 1777 /// This API is not suitable for string data. Use CreateSparseTensor() with allocator specified so strings 1778 /// can be properly copied into the allocated buffer. 1779 /// </summary> 1780 /// <param name="info">Memory description where the user buffers reside (CPU vs GPU etc)</param> 1781 /// <param name="p_data">pointer to the user supplied buffer, use nullptr for fully sparse tensors</param> 1782 /// <param name="dense_shape">a would be dense shape of the tensor</param> 1783 /// <param name="values_shape">non zero values shape. 
Use a single 0 shape for fully sparse tensors.</param> 1784 /// <param name="type">data type</param> 1785 /// <returns>Ort::Value instance containing SparseTensor</returns> 1786 static Value CreateSparseTensor(const OrtMemoryInfo* info, void* p_data, const Shape& dense_shape, 1787 const Shape& values_shape, ONNXTensorElementDataType type); 1788 1789 /// <summary> 1790 /// This is a simple forwarding method to the below CreateSparseTensor. 1791 /// This helps to specify data type enum in terms of C++ data type. 1792 /// Use CreateSparseTensor<T> 1793 /// </summary> 1794 /// <typeparam name="T">numeric data type only. String data enum must be specified explicitly.</typeparam> 1795 /// <param name="allocator">allocator to use</param> 1796 /// <param name="dense_shape">a would be dense shape of the tensor</param> 1797 /// <returns>Ort::Value</returns> 1798 template <typename T> 1799 static Value CreateSparseTensor(OrtAllocator* allocator, const Shape& dense_shape); 1800 1801 /// <summary> 1802 /// Creates an instance of OrtValue containing sparse tensor. The created instance has no data. 1803 /// The data must be supplied by on of the FillSparseTensor<Format>() methods that take both non-zero values 1804 /// and indices. The data will be copied into a buffer that would be allocated using the supplied allocator. 1805 /// Use this API to create OrtValues that contain sparse tensors with all supported data types including 1806 /// strings. 1807 /// </summary> 1808 /// <param name="allocator">allocator to use. 
The allocator lifespan must eclipse that of the resulting OrtValue</param> 1809 /// <param name="dense_shape">a would be dense shape of the tensor</param> 1810 /// <param name="type">data type</param> 1811 /// <returns>an instance of Ort::Value</returns> 1812 static Value CreateSparseTensor(OrtAllocator* allocator, const Shape& dense_shape, ONNXTensorElementDataType type); 1813 1814 #endif // !defined(DISABLE_SPARSE_TENSORS) 1815 }; 1816 1817 /// <summary> 1818 /// Represents native memory allocation coming from one of the 1819 /// OrtAllocators registered with OnnxRuntime. 1820 /// Use it to wrap an allocation made by an allocator 1821 /// so it can be automatically released when no longer needed. 1822 /// </summary> 1823 struct MemoryAllocation { 1824 MemoryAllocation(OrtAllocator* allocator, void* p, size_t size); 1825 ~MemoryAllocation(); 1826 MemoryAllocation(const MemoryAllocation&) = delete; 1827 MemoryAllocation& operator=(const MemoryAllocation&) = delete; 1828 MemoryAllocation(MemoryAllocation&&) noexcept; 1829 MemoryAllocation& operator=(MemoryAllocation&&) noexcept; 1830 1831 void* get() { return p_; } 1832 size_t size() const { return size_; } 1833 1834 private: 1835 OrtAllocator* allocator_; 1836 void* p_; 1837 size_t size_; 1838 }; 1839 1840 namespace detail { 1841 template <typename T> 1842 struct AllocatorImpl : Base<T> { 1843 using B = Base<T>; 1844 using B::B; 1845 1846 void* Alloc(size_t size); 1847 MemoryAllocation GetAllocation(size_t size); 1848 void Free(void* p); 1849 ConstMemoryInfo GetInfo() const; 1850 }; 1851 1852 } // namespace detail 1853 1854 /** \brief Wrapper around ::OrtAllocator default instance that is owned by Onnxruntime 1855 * 1856 */ 1857 struct AllocatorWithDefaultOptions : detail::AllocatorImpl<detail::Unowned<OrtAllocator>> { 1858 explicit AllocatorWithDefaultOptions(std::nullptr_t) {} ///< Convenience to create a class member and then replace with an instance 1859 AllocatorWithDefaultOptions(); 1860 }; 1861 1862 /** 
\brief Wrapper around ::OrtAllocator 1863 * 1864 */ 1865 struct Allocator : detail::AllocatorImpl<OrtAllocator> { 1866 explicit Allocator(std::nullptr_t) {} ///< Convenience to create a class member and then replace with an instance 1867 Allocator(const Session& session, const OrtMemoryInfo*); 1868 }; 1869 1870 using UnownedAllocator = detail::AllocatorImpl<detail::Unowned<OrtAllocator>>; 1871 1872 namespace detail { 1873 namespace binding_utils { 1874 // Bring these out of template 1875 std::vector<std::string> GetOutputNamesHelper(const OrtIoBinding* binding, OrtAllocator*); 1876 std::vector<Value> GetOutputValuesHelper(const OrtIoBinding* binding, OrtAllocator*); 1877 } // namespace binding_utils 1878 1879 template <typename T> 1880 struct ConstIoBindingImpl : Base<T> { 1881 using B = Base<T>; 1882 using B::B; 1883 1884 std::vector<std::string> GetOutputNames() const; 1885 std::vector<std::string> GetOutputNames(OrtAllocator*) const; 1886 std::vector<Value> GetOutputValues() const; 1887 std::vector<Value> GetOutputValues(OrtAllocator*) const; 1888 }; 1889 1890 template <typename T> 1891 struct IoBindingImpl : ConstIoBindingImpl<T> { 1892 using B = ConstIoBindingImpl<T>; 1893 using B::B; 1894 1895 void BindInput(const char* name, const Value&); 1896 void BindOutput(const char* name, const Value&); 1897 void BindOutput(const char* name, const OrtMemoryInfo*); 1898 void ClearBoundInputs(); 1899 void ClearBoundOutputs(); 1900 void SynchronizeInputs(); 1901 void SynchronizeOutputs(); 1902 }; 1903 1904 } // namespace detail 1905 1906 using ConstIoBinding = detail::ConstIoBindingImpl<detail::Unowned<const OrtIoBinding>>; 1907 using UnownedIoBinding = detail::IoBindingImpl<detail::Unowned<OrtIoBinding>>; 1908 1909 /** \brief Wrapper around ::OrtIoBinding 1910 * 1911 */ 1912 struct IoBinding : detail::IoBindingImpl<OrtIoBinding> { 1913 explicit IoBinding(std::nullptr_t) {} ///< Create an empty object for convenience. Sometimes, we want to initialize members later. 
1914 explicit IoBinding(Session& session); 1915 ConstIoBinding GetConst() const { return ConstIoBinding{this->p_}; } 1916 UnownedIoBinding GetUnowned() const { return UnownedIoBinding{this->p_}; } 1917 }; 1918 1919 /*! \struct Ort::ArenaCfg 1920 * \brief it is a structure that represents the configuration of an arena based allocator 1921 * \details Please see docs/C_API.md for details 1922 */ 1923 struct ArenaCfg : detail::Base<OrtArenaCfg> { 1924 explicit ArenaCfg(std::nullptr_t) {} ///< Create an empty ArenaCfg object, must be assigned a valid one to be used 1925 /** 1926 * Wraps OrtApi::CreateArenaCfg 1927 * \param max_mem - use 0 to allow ORT to choose the default 1928 * \param arena_extend_strategy - use -1 to allow ORT to choose the default, 0 = kNextPowerOfTwo, 1 = kSameAsRequested 1929 * \param initial_chunk_size_bytes - use -1 to allow ORT to choose the default 1930 * \param max_dead_bytes_per_chunk - use -1 to allow ORT to choose the default 1931 * See docs/C_API.md for details on what the following parameters mean and how to choose these values 1932 */ 1933 ArenaCfg(size_t max_mem, int arena_extend_strategy, int initial_chunk_size_bytes, int max_dead_bytes_per_chunk); 1934 }; 1935 1936 // 1937 // Custom OPs (only needed to implement custom OPs) 1938 // 1939 1940 /// <summary> 1941 /// This struct provides life time management for custom op attribute 1942 /// </summary> 1943 struct OpAttr : detail::Base<OrtOpAttr> { 1944 OpAttr(const char* name, const void* data, int len, OrtOpAttrType type); 1945 }; 1946 1947 /** 1948 * Macro that logs a message using the provided logger. Throws an exception if OrtApi::Logger_LogMessage fails. 1949 * Example: ORT_CXX_LOG(logger, ORT_LOGGING_LEVEL_INFO, "Log a message"); 1950 * 1951 * \param logger The Ort::Logger instance to use. Must be a value or reference. 1952 * \param message_severity The logging severity level of the message. 1953 * \param message A null-terminated UTF-8 message to log. 
1954 */ 1955 #define ORT_CXX_LOG(logger, message_severity, message) \ 1956 do { \ 1957 if (message_severity >= logger.GetLoggingSeverityLevel()) { \ 1958 Ort::ThrowOnError(logger.LogMessage(message_severity, ORT_FILE, __LINE__, \ 1959 static_cast<const char*>(__FUNCTION__), message)); \ 1960 } \ 1961 } while (false) 1962 1963 /** 1964 * Macro that logs a message using the provided logger. Can be used in noexcept code since errors are silently ignored. 1965 * Example: ORT_CXX_LOG_NOEXCEPT(logger, ORT_LOGGING_LEVEL_INFO, "Log a message"); 1966 * 1967 * \param logger The Ort::Logger instance to use. Must be a value or reference. 1968 * \param message_severity The logging severity level of the message. 1969 * \param message A null-terminated UTF-8 message to log. 1970 */ 1971 #define ORT_CXX_LOG_NOEXCEPT(logger, message_severity, message) \ 1972 do { \ 1973 if (message_severity >= logger.GetLoggingSeverityLevel()) { \ 1974 static_cast<void>(logger.LogMessage(message_severity, ORT_FILE, __LINE__, \ 1975 static_cast<const char*>(__FUNCTION__), message)); \ 1976 } \ 1977 } while (false) 1978 1979 /** 1980 * Macro that logs a printf-like formatted message using the provided logger. Throws an exception if 1981 * OrtApi::Logger_LogMessage fails or if a formatting error occurs. 1982 * Example: ORT_CXX_LOGF(logger, ORT_LOGGING_LEVEL_INFO, "Log an int: %d", 12); 1983 * 1984 * \param logger The Ort::Logger instance to use. Must be a value or reference. 1985 * \param message_severity The logging severity level of the message. 1986 * \param format A null-terminated UTF-8 format string forwarded to a printf-like function. 1987 * Refer to https://en.cppreference.com/w/cpp/io/c/fprintf for information on valid formats. 1988 * \param ... Zero or more variadic arguments referenced by the format string. 1989 */ 1990 #define ORT_CXX_LOGF(logger, message_severity, /*format,*/...) 
\ 1991 do { \ 1992 if (message_severity >= logger.GetLoggingSeverityLevel()) { \ 1993 Ort::ThrowOnError(logger.LogFormattedMessage(message_severity, ORT_FILE, __LINE__, \ 1994 static_cast<const char*>(__FUNCTION__), __VA_ARGS__)); \ 1995 } \ 1996 } while (false) 1997 1998 /** 1999 * Macro that logs a printf-like formatted message using the provided logger. Can be used in noexcept code since errors 2000 * are silently ignored. 2001 * Example: ORT_CXX_LOGF_NOEXCEPT(logger, ORT_LOGGING_LEVEL_INFO, "Log an int: %d", 12); 2002 * 2003 * \param logger The Ort::Logger instance to use. Must be a value or reference. 2004 * \param message_severity The logging severity level of the message. 2005 * \param format A null-terminated UTF-8 format string forwarded to a printf-like function. 2006 * Refer to https://en.cppreference.com/w/cpp/io/c/fprintf for information on valid formats. 2007 * \param ... Zero or more variadic arguments referenced by the format string. 2008 */ 2009 #define ORT_CXX_LOGF_NOEXCEPT(logger, message_severity, /*format,*/...) \ 2010 do { \ 2011 if (message_severity >= logger.GetLoggingSeverityLevel()) { \ 2012 static_cast<void>(logger.LogFormattedMessage(message_severity, ORT_FILE, __LINE__, \ 2013 static_cast<const char*>(__FUNCTION__), __VA_ARGS__)); \ 2014 } \ 2015 } while (false) 2016 2017 /// <summary> 2018 /// This class represents an ONNX Runtime logger that can be used to log information with an 2019 /// associated severity level and source code location (file path, line number, function name). 2020 /// 2021 /// A Logger can be obtained from within custom operators by calling Ort::KernelInfo::GetLogger(). 2022 /// Instances of Ort::Logger are the size of two pointers and can be passed by value. 2023 /// 2024 /// Use the ORT_CXX_LOG macros to ensure the source code location is set properly from the callsite 2025 /// and to take advantage of a cached logging severity level that can bypass calls to the underlying C API. 
2026 /// </summary> 2027 struct Logger { 2028 /** 2029 * Creates an empty Ort::Logger. Must be initialized from a valid Ort::Logger before use. 2030 */ 2031 Logger() = default; 2032 2033 /** 2034 * Creates an empty Ort::Logger. Must be initialized from a valid Ort::Logger before use. 2035 */ 2036 explicit Logger(std::nullptr_t) {} 2037 2038 /** 2039 * Creates a logger from an ::OrtLogger instance. Caches the logger's current severity level by calling 2040 * OrtApi::Logger_GetLoggingSeverityLevel. Throws an exception if OrtApi::Logger_GetLoggingSeverityLevel fails. 2041 * 2042 * \param logger The ::OrtLogger to wrap. 2043 */ 2044 explicit Logger(const OrtLogger* logger); 2045 2046 ~Logger() = default; 2047 2048 Logger(const Logger&) = default; 2049 Logger& operator=(const Logger&) = default; 2050 2051 Logger(Logger&& v) noexcept = default; 2052 Logger& operator=(Logger&& v) noexcept = default; 2053 2054 /** 2055 * Returns the logger's current severity level from the cached member. 2056 * 2057 * \return The current ::OrtLoggingLevel. 2058 */ 2059 OrtLoggingLevel GetLoggingSeverityLevel() const noexcept; 2060 2061 /** 2062 * Logs the provided message via OrtApi::Logger_LogMessage. Use the ORT_CXX_LOG or ORT_CXX_LOG_NOEXCEPT 2063 * macros to properly set the source code location and to use the cached severity level to potentially bypass 2064 * calls to the underlying C API. 2065 * 2066 * \param log_severity_level The message's logging severity level. 2067 * \param file_path The filepath of the file in which the message is logged. Usually the value of ORT_FILE. 2068 * \param line_number The file line number in which the message is logged. Usually the value of __LINE__. 2069 * \param func_name The name of the function in which the message is logged. Usually the value of __FUNCTION__. 2070 * \param message The message to log. 2071 * \return A Ort::Status value to indicate error or success. 
2072 */ 2073 Status LogMessage(OrtLoggingLevel log_severity_level, const ORTCHAR_T* file_path, int line_number, 2074 const char* func_name, const char* message) const noexcept; 2075 2076 /** 2077 * Logs a printf-like formatted message via OrtApi::Logger_LogMessage. Use the ORT_CXX_LOGF or ORT_CXX_LOGF_NOEXCEPT 2078 * macros to properly set the source code location and to use the cached severity level to potentially bypass 2079 * calls to the underlying C API. Returns an error status if a formatting error occurs. 2080 * 2081 * \param log_severity_level The message's logging severity level. 2082 * \param file_path The filepath of the file in which the message is logged. Usually the value of ORT_FILE. 2083 * \param line_number The file line number in which the message is logged. Usually the value of __LINE__. 2084 * \param func_name The name of the function in which the message is logged. Usually the value of __FUNCTION__. 2085 * \param format A null-terminated UTF-8 format string forwarded to a printf-like function. 2086 * Refer to https://en.cppreference.com/w/cpp/io/c/fprintf for information on valid formats. 2087 * \param args Zero or more variadic arguments referenced by the format string. 2088 * \return A Ort::Status value to indicate error or success. 2089 */ 2090 template <typename... Args> 2091 Status LogFormattedMessage(OrtLoggingLevel log_severity_level, const ORTCHAR_T* file_path, int line_number, 2092 const char* func_name, const char* format, Args&&... args) const noexcept; 2093 2094 private: 2095 const OrtLogger* logger_{}; 2096 OrtLoggingLevel cached_severity_level_{}; 2097 }; 2098 2099 /// <summary> 2100 /// This class wraps a raw pointer OrtKernelContext* that is being passed 2101 /// to the custom kernel Compute() method. Use it to safely access context 2102 /// attributes, input and output parameters with exception safety guarantees. 
2103 /// See usage example in onnxruntime/test/testdata/custom_op_library/custom_op_library.cc 2104 /// </summary> 2105 struct KernelContext { 2106 explicit KernelContext(OrtKernelContext* context); 2107 size_t GetInputCount() const; 2108 size_t GetOutputCount() const; 2109 // If input is optional and is not present, the method returns en empty ConstValue 2110 // which can be compared to nullptr. 2111 ConstValue GetInput(size_t index) const; 2112 // If outout is optional and is not present, the method returns en empty UnownedValue 2113 // which can be compared to nullptr. 2114 UnownedValue GetOutput(size_t index, const int64_t* dim_values, size_t dim_count) const; 2115 UnownedValue GetOutput(size_t index, const std::vector<int64_t>& dims) const; 2116 void* GetGPUComputeStream() const; 2117 Logger GetLogger() const; 2118 OrtAllocator* GetAllocator(const OrtMemoryInfo& memory_info) const; 2119 OrtKernelContext* GetOrtKernelContext() const { return ctx_; } 2120 void ParallelFor(void (*fn)(void*, size_t), size_t total, size_t num_batch, void* usr_data) const; 2121 2122 private: 2123 OrtKernelContext* ctx_; 2124 }; 2125 2126 struct KernelInfo; 2127 2128 namespace detail { 2129 namespace attr_utils { 2130 void GetAttr(const OrtKernelInfo* p, const char* name, float&); 2131 void GetAttr(const OrtKernelInfo* p, const char* name, int64_t&); 2132 void GetAttr(const OrtKernelInfo* p, const char* name, std::string&); 2133 void GetAttrs(const OrtKernelInfo* p, const char* name, std::vector<float>&); 2134 void GetAttrs(const OrtKernelInfo* p, const char* name, std::vector<int64_t>&); 2135 } // namespace attr_utils 2136 2137 template <typename T> 2138 struct KernelInfoImpl : Base<T> { 2139 using B = Base<T>; 2140 using B::B; 2141 2142 KernelInfo Copy() const; 2143 2144 template <typename R> // R is only implemented for float, int64_t, and string 2145 R GetAttribute(const char* name) const { 2146 R val; 2147 attr_utils::GetAttr(this->p_, name, val); 2148 return val; 2149 } 2150 
2151 template <typename R> // R is only implemented for std::vector<float>, std::vector<int64_t> 2152 std::vector<R> GetAttributes(const char* name) const { 2153 std::vector<R> result; 2154 attr_utils::GetAttrs(this->p_, name, result); 2155 return result; 2156 } 2157 2158 Value GetTensorAttribute(const char* name, OrtAllocator* allocator) const; 2159 2160 size_t GetInputCount() const; 2161 size_t GetOutputCount() const; 2162 2163 std::string GetInputName(size_t index) const; 2164 std::string GetOutputName(size_t index) const; 2165 2166 TypeInfo GetInputTypeInfo(size_t index) const; 2167 TypeInfo GetOutputTypeInfo(size_t index) const; 2168 2169 ConstValue GetTensorConstantInput(size_t index, int* is_constant) const; 2170 2171 std::string GetNodeName() const; 2172 Logger GetLogger() const; 2173 }; 2174 2175 } // namespace detail 2176 2177 using ConstKernelInfo = detail::KernelInfoImpl<detail::Unowned<const OrtKernelInfo>>; 2178 2179 /// <summary> 2180 /// This struct owns the OrtKernInfo* pointer when a copy is made. 2181 /// For convenient wrapping of OrtKernelInfo* passed to kernel constructor 2182 /// and query attributes, warp the pointer with Ort::Unowned<KernelInfo> instance 2183 /// so it does not destroy the pointer the kernel does not own. 2184 /// </summary> 2185 struct KernelInfo : detail::KernelInfoImpl<OrtKernelInfo> { 2186 explicit KernelInfo(std::nullptr_t) {} ///< Create an empty instance to initialize later 2187 explicit KernelInfo(OrtKernelInfo* info); ///< Take ownership of the instance 2188 ConstKernelInfo GetConst() const { return ConstKernelInfo{this->p_}; } 2189 }; 2190 2191 /// <summary> 2192 /// Create and own custom defined operation. 
2193 /// </summary> 2194 struct Op : detail::Base<OrtOp> { 2195 explicit Op(std::nullptr_t) {} ///< Create an empty Operator object, must be assigned a valid one to be used 2196 2197 explicit Op(OrtOp*); ///< Take ownership of the OrtOp 2198 2199 static Op Create(const OrtKernelInfo* info, const char* op_name, const char* domain, 2200 int version, const char** type_constraint_names, 2201 const ONNXTensorElementDataType* type_constraint_values, 2202 size_t type_constraint_count, 2203 const OpAttr* attr_values, 2204 size_t attr_count, 2205 size_t input_count, size_t output_count); 2206 2207 void Invoke(const OrtKernelContext* context, 2208 const Value* input_values, 2209 size_t input_count, 2210 Value* output_values, 2211 size_t output_count); 2212 2213 // For easier refactoring 2214 void Invoke(const OrtKernelContext* context, 2215 const OrtValue* const* input_values, 2216 size_t input_count, 2217 OrtValue* const* output_values, 2218 size_t output_count); 2219 }; 2220 2221 /// <summary> 2222 /// Provide access to per-node attributes and input shapes, so one could compute and set output shapes. 
2223 /// </summary> 2224 struct ShapeInferContext { 2225 struct SymbolicInteger { 2226 SymbolicInteger(int64_t i) : i_(i), is_int_(true) {}; 2227 SymbolicInteger(const char* s) : s_(s), is_int_(false) {}; 2228 SymbolicInteger(const SymbolicInteger&) = default; 2229 SymbolicInteger(SymbolicInteger&&) = default; 2230 2231 SymbolicInteger& operator=(const SymbolicInteger&) = default; 2232 SymbolicInteger& operator=(SymbolicInteger&&) = default; 2233 2234 bool operator==(const SymbolicInteger& dim) const { 2235 if (is_int_ == dim.is_int_) { 2236 if (is_int_) { 2237 return i_ == dim.i_; 2238 } else { 2239 return std::string{s_} == std::string{dim.s_}; 2240 } 2241 } 2242 return false; 2243 } 2244 2245 bool IsInt() const { return is_int_; } 2246 int64_t AsInt() const { return i_; } 2247 const char* AsSym() const { return s_; } 2248 2249 static constexpr int INVALID_INT_DIM = -2; 2250 2251 private: 2252 union { 2253 int64_t i_; 2254 const char* s_; 2255 }; 2256 bool is_int_; 2257 }; 2258 2259 using Shape = std::vector<SymbolicInteger>; 2260 2261 ShapeInferContext(const OrtApi* ort_api, OrtShapeInferContext* ctx); 2262 2263 const Shape& GetInputShape(size_t indice) const { return input_shapes_.at(indice); } 2264 2265 size_t GetInputCount() const { return input_shapes_.size(); } 2266 2267 Status SetOutputShape(size_t indice, const Shape& shape, ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT); 2268 2269 int64_t GetAttrInt(const char* attr_name); 2270 2271 using Ints = std::vector<int64_t>; 2272 Ints GetAttrInts(const char* attr_name); 2273 2274 float GetAttrFloat(const char* attr_name); 2275 2276 using Floats = std::vector<float>; 2277 Floats GetAttrFloats(const char* attr_name); 2278 2279 std::string GetAttrString(const char* attr_name); 2280 2281 using Strings = std::vector<std::string>; 2282 Strings GetAttrStrings(const char* attr_name); 2283 2284 private: 2285 const OrtOpAttr* GetAttrHdl(const char* attr_name) const; 2286 const OrtApi* ort_api_; 2287 
OrtShapeInferContext* ctx_; 2288 std::vector<Shape> input_shapes_; 2289 }; 2290 2291 using ShapeInferFn = Ort::Status (*)(Ort::ShapeInferContext&); 2292 2293 #define MAX_CUSTOM_OP_END_VER (1UL << 31) - 1 2294 2295 template <typename TOp, typename TKernel, bool WithStatus = false> 2296 struct CustomOpBase : OrtCustomOp { 2297 CustomOpBase() { 2298 OrtCustomOp::version = ORT_API_VERSION; 2299 OrtCustomOp::GetName = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetName(); }; 2300 2301 OrtCustomOp::GetExecutionProviderType = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetExecutionProviderType(); }; 2302 2303 OrtCustomOp::GetInputTypeCount = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetInputTypeCount(); }; 2304 OrtCustomOp::GetInputType = [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetInputType(index); }; 2305 OrtCustomOp::GetInputMemoryType = [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetInputMemoryType(index); }; 2306 2307 OrtCustomOp::GetOutputTypeCount = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetOutputTypeCount(); }; 2308 OrtCustomOp::GetOutputType = [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetOutputType(index); }; 2309 2310 #if defined(_MSC_VER) && !defined(__clang__) 2311 #pragma warning(push) 2312 #pragma warning(disable : 26409) 2313 #endif 2314 OrtCustomOp::KernelDestroy = [](void* op_kernel) { delete static_cast<TKernel*>(op_kernel); }; 2315 #if defined(_MSC_VER) && !defined(__clang__) 2316 #pragma warning(pop) 2317 #endif 2318 OrtCustomOp::GetInputCharacteristic = [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetInputCharacteristic(index); }; 2319 OrtCustomOp::GetOutputCharacteristic = [](const OrtCustomOp* this_, size_t index) { return static_cast<const 
TOp*>(this_)->GetOutputCharacteristic(index); }; 2320 2321 OrtCustomOp::GetVariadicInputMinArity = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetVariadicInputMinArity(); }; 2322 OrtCustomOp::GetVariadicInputHomogeneity = [](const OrtCustomOp* this_) { return static_cast<int>(static_cast<const TOp*>(this_)->GetVariadicInputHomogeneity()); }; 2323 OrtCustomOp::GetVariadicOutputMinArity = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetVariadicOutputMinArity(); }; 2324 OrtCustomOp::GetVariadicOutputHomogeneity = [](const OrtCustomOp* this_) { return static_cast<int>(static_cast<const TOp*>(this_)->GetVariadicOutputHomogeneity()); }; 2325 #ifdef __cpp_if_constexpr 2326 if constexpr (WithStatus) { 2327 #else 2328 if (WithStatus) { 2329 #endif 2330 OrtCustomOp::CreateKernelV2 = [](const OrtCustomOp* this_, const OrtApi* api, const OrtKernelInfo* info, void** op_kernel) -> OrtStatusPtr { 2331 return static_cast<const TOp*>(this_)->CreateKernelV2(*api, info, op_kernel); 2332 }; 2333 OrtCustomOp::KernelComputeV2 = [](void* op_kernel, OrtKernelContext* context) -> OrtStatusPtr { 2334 return static_cast<TKernel*>(op_kernel)->ComputeV2(context); 2335 }; 2336 } else { 2337 OrtCustomOp::CreateKernelV2 = nullptr; 2338 OrtCustomOp::KernelComputeV2 = nullptr; 2339 2340 OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* api, const OrtKernelInfo* info) { return static_cast<const TOp*>(this_)->CreateKernel(*api, info); }; 2341 OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) { 2342 static_cast<TKernel*>(op_kernel)->Compute(context); 2343 }; 2344 } 2345 2346 SetShapeInferFn<TOp>(0); 2347 2348 OrtCustomOp::GetStartVersion = [](const OrtCustomOp* this_) { 2349 return static_cast<const TOp*>(this_)->start_ver_; 2350 }; 2351 2352 OrtCustomOp::GetEndVersion = [](const OrtCustomOp* this_) { 2353 return static_cast<const TOp*>(this_)->end_ver_; 2354 }; 2355 2356 OrtCustomOp::GetMayInplace = 
nullptr; 2357 OrtCustomOp::ReleaseMayInplace = nullptr; 2358 OrtCustomOp::GetAliasMap = nullptr; 2359 OrtCustomOp::ReleaseAliasMap = nullptr; 2360 } 2361 2362 // Default implementation of GetExecutionProviderType that returns nullptr to default to the CPU provider 2363 const char* GetExecutionProviderType() const { return nullptr; } 2364 2365 // Default implementations of GetInputCharacteristic() and GetOutputCharacteristic() below 2366 // (inputs and outputs are required by default) 2367 OrtCustomOpInputOutputCharacteristic GetInputCharacteristic(size_t /*index*/) const { 2368 return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED; 2369 } 2370 2371 OrtCustomOpInputOutputCharacteristic GetOutputCharacteristic(size_t /*index*/) const { 2372 return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED; 2373 } 2374 2375 // Default implemention of GetInputMemoryType() that returns OrtMemTypeDefault 2376 OrtMemType GetInputMemoryType(size_t /*index*/) const { 2377 return OrtMemTypeDefault; 2378 } 2379 2380 // Default implementation of GetVariadicInputMinArity() returns 1 to specify that a variadic input 2381 // should expect at least 1 argument. 2382 int GetVariadicInputMinArity() const { 2383 return 1; 2384 } 2385 2386 // Default implementation of GetVariadicInputHomegeneity() returns true to specify that all arguments 2387 // to a variadic input should be of the same type. 2388 bool GetVariadicInputHomogeneity() const { 2389 return true; 2390 } 2391 2392 // Default implementation of GetVariadicOutputMinArity() returns 1 to specify that a variadic output 2393 // should produce at least 1 output value. 2394 int GetVariadicOutputMinArity() const { 2395 return 1; 2396 } 2397 2398 // Default implementation of GetVariadicOutputHomegeneity() returns true to specify that all output values 2399 // produced by a variadic output should be of the same type. 
2400 bool GetVariadicOutputHomogeneity() const { 2401 return true; 2402 } 2403 2404 // Declare list of session config entries used by this Custom Op. 2405 // Implement this function in order to get configs from CustomOpBase::GetSessionConfigs(). 2406 // This default implementation returns an empty vector of config entries. 2407 std::vector<std::string> GetSessionConfigKeys() const { 2408 return std::vector<std::string>{}; 2409 } 2410 2411 template <typename C> 2412 decltype(&C::InferOutputShape) SetShapeInferFn(decltype(&C::InferOutputShape)) { 2413 OrtCustomOp::InferOutputShapeFn = [](const OrtCustomOp*, OrtShapeInferContext* ort_ctx) -> OrtStatusPtr { 2414 ShapeInferContext ctx(&GetApi(), ort_ctx); 2415 return C::InferOutputShape(ctx); 2416 }; 2417 return {}; 2418 } 2419 2420 template <typename C> 2421 void SetShapeInferFn(...) { 2422 OrtCustomOp::InferOutputShapeFn = {}; 2423 } 2424 2425 protected: 2426 // Helper function that returns a map of session config entries specified by CustomOpBase::GetSessionConfigKeys. 2427 void GetSessionConfigs(std::unordered_map<std::string, std::string>& out, ConstSessionOptions options) const; 2428 2429 int start_ver_ = 1; 2430 int end_ver_ = MAX_CUSTOM_OP_END_VER; 2431 }; 2432 2433 } // namespace Ort 2434 2435 #include "onnxruntime_cxx_inline.h"
| [ Source navigation ] | [ Diff markup ] | [ Identifier search ] | [ general search ] |
|
This page was automatically generated by the 2.3.7 LXR engine. The LXR team |
|