Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-12-16 10:20:34

0001 // Copyright (c) Microsoft Corporation. All rights reserved.
0002 // Licensed under the MIT License.
0003 
0004 // Summary: The Ort C++ API is a header only wrapper around the Ort C API.
0005 //
0006 // The C++ API simplifies usage by returning values directly instead of error codes, throwing exceptions on errors
0007 // and automatically releasing resources in the destructors. The primary purpose of C++ API is exception safety so
0008 // all the resources follow RAII and do not leak memory.
0009 //
0010 // Each of the C++ wrapper classes holds only a pointer to the C internal object. Treat them like smart pointers.
0011 // To create an empty object, pass 'nullptr' to the constructor (for example, Env e{nullptr};). However, you can't use them
0012 // until you assign an instance that actually holds an underlying object.
0013 //
0014 // For Ort objects only move assignment between objects is allowed, there are no copy constructors.
0015 // Some objects have explicit 'Clone' methods for this purpose.
0016 //
0017 // ConstXXXX types are copyable since they do not own the underlying C object, so you can pass them to functions as arguments
0018 // by value or by reference. ConstXXXX types are restricted to const only interfaces.
0019 //
0020 // UnownedXXXX are similar to ConstXXXX but also allow non-const interfaces.
0021 //
0022 // The lifetime of the corresponding owning object must exceed the lifetimes of the ConstXXXX/UnownedXXXX types. They exist so you do not
0023 // have to fall back to C types and the C API with its usual pitfalls. In general, do not use the C API from your C++ code.
0024 
#pragma once
#include "onnxruntime_c_api.h"
#include "onnxruntime_float16.h"

#include <array>
#include <cstddef>
#include <cstdio>
#include <cstdlib>
#include <exception>
#include <memory>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <unordered_map>
#include <utility>
#include <vector>

#ifdef ORT_NO_EXCEPTIONS
#include <iostream>
#endif
0043 
0044 /** \brief All C++ Onnxruntime APIs are defined inside this namespace
0045  *
0046  */
0047 namespace Ort {
0048 
0049 /** \brief All C++ methods that can fail will throw an exception of this type
0050  *
0051  * If <tt>ORT_NO_EXCEPTIONS</tt> is defined, then any error will result in a call to abort()
0052  */
0053 struct Exception : std::exception {
0054   Exception(std::string&& string, OrtErrorCode code) : message_{std::move(string)}, code_{code} {}
0055 
0056   OrtErrorCode GetOrtErrorCode() const { return code_; }
0057   const char* what() const noexcept override { return message_.c_str(); }
0058 
0059  private:
0060   std::string message_;
0061   OrtErrorCode code_;
0062 };
0063 
#ifdef ORT_NO_EXCEPTIONS
// The #ifndef is for the very special case where the user of this library wants to define their own way of handling errors.
// NOTE: This header expects control flow to not continue after calling ORT_CXX_API_THROW
#ifndef ORT_CXX_API_THROW
// No-exceptions build: print the message to std::cerr and terminate the process.
// std::abort is declared in <cstdlib>, which this header includes explicitly; the previous
// unqualified abort() only compiled when some other header happened to declare it.
#define ORT_CXX_API_THROW(string, code)       \
  do {                                        \
    std::cerr << Ort::Exception(string, code) \
                     .what()                  \
              << std::endl;                   \
    std::abort();                             \
  } while (false)
#endif
#else
// Exceptions enabled: throw an Ort::Exception built from the message and error code.
#define ORT_CXX_API_THROW(string, code) \
  throw Ort::Exception(string, code)
#endif
0080 
0081 // This is used internally by the C++ API. This class holds the global variable that points to the OrtApi,
0082 //  it's in a template so that we can define a global variable in a header and make
0083 // it transparent to the users of the API.
0084 template <typename T>
0085 struct Global {
0086   static const OrtApi* api_;
0087 };
0088 
0089 // If macro ORT_API_MANUAL_INIT is defined, no static initialization will be performed. Instead, user must call InitApi() before using it.
0090 template <typename T>
0091 #ifdef ORT_API_MANUAL_INIT
0092 const OrtApi* Global<T>::api_{};
0093 inline void InitApi() noexcept { Global<void>::api_ = OrtGetApiBase()->GetApi(ORT_API_VERSION); }
0094 
0095 // Used by custom operator libraries that are not linked to onnxruntime. Sets the global API object, which is
0096 // required by C++ APIs.
0097 //
0098 // Example mycustomop.cc:
0099 //
0100 // #define ORT_API_MANUAL_INIT
0101 // #include <onnxruntime_cxx_api.h>
0102 // #undef ORT_API_MANUAL_INIT
0103 //
0104 // OrtStatus* ORT_API_CALL RegisterCustomOps(OrtSessionOptions* options, const OrtApiBase* api_base) {
0105 //   Ort::InitApi(api_base->GetApi(ORT_API_VERSION));
0106 //   // ...
0107 // }
0108 //
0109 inline void InitApi(const OrtApi* api) noexcept { Global<void>::api_ = api; }
0110 #else
0111 #if defined(_MSC_VER) && !defined(__clang__)
0112 #pragma warning(push)
0113 // "Global initializer calls a non-constexpr function." Therefore you can't use ORT APIs in the other global initializers.
0114 // Please define ORT_API_MANUAL_INIT if it conerns you.
0115 #pragma warning(disable : 26426)
0116 #endif
0117 const OrtApi* Global<T>::api_ = OrtGetApiBase()->GetApi(ORT_API_VERSION);
0118 #if defined(_MSC_VER) && !defined(__clang__)
0119 #pragma warning(pop)
0120 #endif
0121 #endif
0122 
0123 /// This returns a reference to the OrtApi interface in use
0124 inline const OrtApi& GetApi() noexcept { return *Global<void>::api_; }
0125 
0126 /// <summary>
0127 /// This function returns the onnxruntime version string
0128 /// </summary>
0129 /// <returns>version string major.minor.rev</returns>
0130 std::string GetVersionString();
0131 
0132 /// <summary>
0133 /// This function returns the onnxruntime build information: including git branch,
0134 /// git commit id, build type(Debug/Release/RelWithDebInfo) and cmake cpp flags.
0135 /// </summary>
0136 /// <returns>string</returns>
0137 std::string GetBuildInfoString();
0138 
0139 /// <summary>
0140 /// This is a C++ wrapper for OrtApi::GetAvailableProviders() and
0141 /// returns a vector of strings representing the available execution providers.
0142 /// </summary>
0143 /// <returns>vector of strings</returns>
0144 std::vector<std::string> GetAvailableProviders();
0145 
0146 /** \brief IEEE 754 half-precision floating point data type
0147  *
0148  * \details This struct is used for converting float to float16 and back
0149  * so the user could feed inputs and fetch outputs using these type.
0150  *
0151  * The size of the structure should align with uint16_t and one can freely cast
0152  * uint16_t buffers to/from Ort::Float16_t to feed and retrieve data.
0153  *
0154  * \code{.unparsed}
0155  * // This example demonstrates converion from float to float16
0156  * constexpr float values[] = {1.f, 2.f, 3.f, 4.f, 5.f};
0157  * std::vector<Ort::Float16_t> fp16_values;
0158  * fp16_values.reserve(std::size(values));
0159  * std::transform(std::begin(values), std::end(values), std::back_inserter(fp16_values),
0160  *     [](float value) { return Ort::Float16_t(value); });
0161  *
0162  * \endcode
0163  */
0164 struct Float16_t : onnxruntime_float16::Float16Impl<Float16_t> {
0165  private:
0166   /// <summary>
0167   /// Constructor from a 16-bit representation of a float16 value
0168   /// No conversion is done here.
0169   /// </summary>
0170   /// <param name="v">16-bit representation</param>
0171   constexpr explicit Float16_t(uint16_t v) noexcept { val = v; }
0172 
0173  public:
0174   using Base = onnxruntime_float16::Float16Impl<Float16_t>;
0175 
0176   /// <summary>
0177   /// Default constructor
0178   /// </summary>
0179   Float16_t() = default;
0180 
0181   /// <summary>
0182   /// Explicit conversion to uint16_t representation of float16.
0183   /// </summary>
0184   /// <param name="v">uint16_t bit representation of float16</param>
0185   /// <returns>new instance of Float16_t</returns>
0186   constexpr static Float16_t FromBits(uint16_t v) noexcept { return Float16_t(v); }
0187 
0188   /// <summary>
0189   /// __ctor from float. Float is converted into float16 16-bit representation.
0190   /// </summary>
0191   /// <param name="v">float value</param>
0192   explicit Float16_t(float v) noexcept { val = Base::ToUint16Impl(v); }
0193 
0194   /// <summary>
0195   /// Converts float16 to float
0196   /// </summary>
0197   /// <returns>float representation of float16 value</returns>
0198   float ToFloat() const noexcept { return Base::ToFloatImpl(); }
0199 
0200   /// <summary>
0201   /// Checks if the value is negative
0202   /// </summary>
0203   /// <returns>true if negative</returns>
0204   using Base::IsNegative;
0205 
0206   /// <summary>
0207   /// Tests if the value is NaN
0208   /// </summary>
0209   /// <returns>true if NaN</returns>
0210   using Base::IsNaN;
0211 
0212   /// <summary>
0213   /// Tests if the value is finite
0214   /// </summary>
0215   /// <returns>true if finite</returns>
0216   using Base::IsFinite;
0217 
0218   /// <summary>
0219   /// Tests if the value represents positive infinity.
0220   /// </summary>
0221   /// <returns>true if positive infinity</returns>
0222   using Base::IsPositiveInfinity;
0223 
0224   /// <summary>
0225   /// Tests if the value represents negative infinity
0226   /// </summary>
0227   /// <returns>true if negative infinity</returns>
0228   using Base::IsNegativeInfinity;
0229 
0230   /// <summary>
0231   /// Tests if the value is either positive or negative infinity.
0232   /// </summary>
0233   /// <returns>True if absolute value is infinity</returns>
0234   using Base::IsInfinity;
0235 
0236   /// <summary>
0237   /// Tests if the value is NaN or zero. Useful for comparisons.
0238   /// </summary>
0239   /// <returns>True if NaN or zero.</returns>
0240   using Base::IsNaNOrZero;
0241 
0242   /// <summary>
0243   /// Tests if the value is normal (not zero, subnormal, infinite, or NaN).
0244   /// </summary>
0245   /// <returns>True if so</returns>
0246   using Base::IsNormal;
0247 
0248   /// <summary>
0249   /// Tests if the value is subnormal (denormal).
0250   /// </summary>
0251   /// <returns>True if so</returns>
0252   using Base::IsSubnormal;
0253 
0254   /// <summary>
0255   /// Creates an instance that represents absolute value.
0256   /// </summary>
0257   /// <returns>Absolute value</returns>
0258   using Base::Abs;
0259 
0260   /// <summary>
0261   /// Creates a new instance with the sign flipped.
0262   /// </summary>
0263   /// <returns>Flipped sign instance</returns>
0264   using Base::Negate;
0265 
0266   /// <summary>
0267   /// IEEE defines that positive and negative zero are equal, this gives us a quick equality check
0268   /// for two values by or'ing the private bits together and stripping the sign. They are both zero,
0269   /// and therefore equivalent, if the resulting value is still zero.
0270   /// </summary>
0271   /// <param name="lhs">first value</param>
0272   /// <param name="rhs">second value</param>
0273   /// <returns>True if both arguments represent zero</returns>
0274   using Base::AreZero;
0275 
0276   /// <summary>
0277   /// User defined conversion operator. Converts Float16_t to float.
0278   /// </summary>
0279   explicit operator float() const noexcept { return ToFloat(); }
0280 
0281   using Base::operator==;
0282   using Base::operator!=;
0283   using Base::operator<;
0284 };
0285 
0286 static_assert(sizeof(Float16_t) == sizeof(uint16_t), "Sizes must match");
0287 
0288 /** \brief bfloat16 (Brain Floating Point) data type
0289  *
0290  * \details This struct is used for converting float to bfloat16 and back
0291  * so the user could feed inputs and fetch outputs using these type.
0292  *
0293  * The size of the structure should align with uint16_t and one can freely cast
0294  * uint16_t buffers to/from Ort::BFloat16_t to feed and retrieve data.
0295  *
0296  * \code{.unparsed}
0297  * // This example demonstrates converion from float to float16
0298  * constexpr float values[] = {1.f, 2.f, 3.f, 4.f, 5.f};
0299  * std::vector<Ort::BFloat16_t> bfp16_values;
0300  * bfp16_values.reserve(std::size(values));
0301  * std::transform(std::begin(values), std::end(values), std::back_inserter(bfp16_values),
0302  *     [](float value) { return Ort::BFloat16_t(value); });
0303  *
0304  * \endcode
0305  */
0306 struct BFloat16_t : onnxruntime_float16::BFloat16Impl<BFloat16_t> {
0307  private:
0308   /// <summary>
0309   /// Constructor from a uint16_t representation of bfloat16
0310   /// used in FromBits() to escape overload resolution issue with
0311   /// constructor from float.
0312   /// No conversion is done.
0313   /// </summary>
0314   /// <param name="v">16-bit bfloat16 value</param>
0315   constexpr explicit BFloat16_t(uint16_t v) noexcept { val = v; }
0316 
0317  public:
0318   using Base = onnxruntime_float16::BFloat16Impl<BFloat16_t>;
0319 
0320   BFloat16_t() = default;
0321 
0322   /// <summary>
0323   /// Explicit conversion to uint16_t representation of bfloat16.
0324   /// </summary>
0325   /// <param name="v">uint16_t bit representation of bfloat16</param>
0326   /// <returns>new instance of BFloat16_t</returns>
0327   static constexpr BFloat16_t FromBits(uint16_t v) noexcept { return BFloat16_t(v); }
0328 
0329   /// <summary>
0330   /// __ctor from float. Float is converted into bfloat16 16-bit representation.
0331   /// </summary>
0332   /// <param name="v">float value</param>
0333   explicit BFloat16_t(float v) noexcept { val = Base::ToUint16Impl(v); }
0334 
0335   /// <summary>
0336   /// Converts bfloat16 to float
0337   /// </summary>
0338   /// <returns>float representation of bfloat16 value</returns>
0339   float ToFloat() const noexcept { return Base::ToFloatImpl(); }
0340 
0341   /// <summary>
0342   /// Checks if the value is negative
0343   /// </summary>
0344   /// <returns>true if negative</returns>
0345   using Base::IsNegative;
0346 
0347   /// <summary>
0348   /// Tests if the value is NaN
0349   /// </summary>
0350   /// <returns>true if NaN</returns>
0351   using Base::IsNaN;
0352 
0353   /// <summary>
0354   /// Tests if the value is finite
0355   /// </summary>
0356   /// <returns>true if finite</returns>
0357   using Base::IsFinite;
0358 
0359   /// <summary>
0360   /// Tests if the value represents positive infinity.
0361   /// </summary>
0362   /// <returns>true if positive infinity</returns>
0363   using Base::IsPositiveInfinity;
0364 
0365   /// <summary>
0366   /// Tests if the value represents negative infinity
0367   /// </summary>
0368   /// <returns>true if negative infinity</returns>
0369   using Base::IsNegativeInfinity;
0370 
0371   /// <summary>
0372   /// Tests if the value is either positive or negative infinity.
0373   /// </summary>
0374   /// <returns>True if absolute value is infinity</returns>
0375   using Base::IsInfinity;
0376 
0377   /// <summary>
0378   /// Tests if the value is NaN or zero. Useful for comparisons.
0379   /// </summary>
0380   /// <returns>True if NaN or zero.</returns>
0381   using Base::IsNaNOrZero;
0382 
0383   /// <summary>
0384   /// Tests if the value is normal (not zero, subnormal, infinite, or NaN).
0385   /// </summary>
0386   /// <returns>True if so</returns>
0387   using Base::IsNormal;
0388 
0389   /// <summary>
0390   /// Tests if the value is subnormal (denormal).
0391   /// </summary>
0392   /// <returns>True if so</returns>
0393   using Base::IsSubnormal;
0394 
0395   /// <summary>
0396   /// Creates an instance that represents absolute value.
0397   /// </summary>
0398   /// <returns>Absolute value</returns>
0399   using Base::Abs;
0400 
0401   /// <summary>
0402   /// Creates a new instance with the sign flipped.
0403   /// </summary>
0404   /// <returns>Flipped sign instance</returns>
0405   using Base::Negate;
0406 
0407   /// <summary>
0408   /// IEEE defines that positive and negative zero are equal, this gives us a quick equality check
0409   /// for two values by or'ing the private bits together and stripping the sign. They are both zero,
0410   /// and therefore equivalent, if the resulting value is still zero.
0411   /// </summary>
0412   /// <param name="lhs">first value</param>
0413   /// <param name="rhs">second value</param>
0414   /// <returns>True if both arguments represent zero</returns>
0415   using Base::AreZero;
0416 
0417   /// <summary>
0418   /// User defined conversion operator. Converts BFloat16_t to float.
0419   /// </summary>
0420   explicit operator float() const noexcept { return ToFloat(); }
0421 
0422   // We do not have an inherited impl for the below operators
0423   // as the internal class implements them a little differently
0424   bool operator==(const BFloat16_t& rhs) const noexcept;
0425   bool operator!=(const BFloat16_t& rhs) const noexcept { return !(*this == rhs); }
0426   bool operator<(const BFloat16_t& rhs) const noexcept;
0427 };
0428 
0429 static_assert(sizeof(BFloat16_t) == sizeof(uint16_t), "Sizes must match");
0430 
/** \brief float8e4m3fn (Float8 Floating Point) data type
 * \details Needed so the C++ API can dispatch on this element type.
 * The type is implicitly convertible to/from uint8_t; it only stores the raw 8-bit
 * pattern and performs no numeric conversion.
 * See https://onnx.ai/onnx/technical/float8.html for further details.
 */
struct Float8E4M3FN_t {
  uint8_t value;  ///< raw 8-bit encoding
  constexpr Float8E4M3FN_t() noexcept : value(0) {}
  constexpr Float8E4M3FN_t(uint8_t v) noexcept : value(v) {}  // implicit by design
  constexpr operator uint8_t() const noexcept { return value; }
  // Comparison is bitwise: nan values are treated like any other value for operator ==, !=
  constexpr bool operator==(const Float8E4M3FN_t& rhs) const noexcept { return value == rhs.value; }
  constexpr bool operator!=(const Float8E4M3FN_t& rhs) const noexcept { return value != rhs.value; }
};

static_assert(sizeof(Float8E4M3FN_t) == sizeof(uint8_t), "Sizes must match");
0447 
/** \brief float8e4m3fnuz (Float8 Floating Point) data type
 * \details Needed so the C++ API can dispatch on this element type.
 * The type is implicitly convertible to/from uint8_t; it only stores the raw 8-bit
 * pattern and performs no numeric conversion.
 * See https://onnx.ai/onnx/technical/float8.html for further details.
 */
struct Float8E4M3FNUZ_t {
  uint8_t value;  ///< raw 8-bit encoding
  constexpr Float8E4M3FNUZ_t() noexcept : value(0) {}
  constexpr Float8E4M3FNUZ_t(uint8_t v) noexcept : value(v) {}  // implicit by design
  constexpr operator uint8_t() const noexcept { return value; }
  // Comparison is bitwise: nan values are treated like any other value for operator ==, !=
  constexpr bool operator==(const Float8E4M3FNUZ_t& rhs) const noexcept { return value == rhs.value; }
  constexpr bool operator!=(const Float8E4M3FNUZ_t& rhs) const noexcept { return value != rhs.value; }
};

static_assert(sizeof(Float8E4M3FNUZ_t) == sizeof(uint8_t), "Sizes must match");
0464 
/** \brief float8e5m2 (Float8 Floating Point) data type
 * \details Needed so the C++ API can dispatch on this element type.
 * The type is implicitly convertible to/from uint8_t; it only stores the raw 8-bit
 * pattern and performs no numeric conversion.
 * See https://onnx.ai/onnx/technical/float8.html for further details.
 */
struct Float8E5M2_t {
  uint8_t value;  ///< raw 8-bit encoding
  constexpr Float8E5M2_t() noexcept : value(0) {}
  constexpr Float8E5M2_t(uint8_t v) noexcept : value(v) {}  // implicit by design
  constexpr operator uint8_t() const noexcept { return value; }
  // Comparison is bitwise: nan values are treated like any other value for operator ==, !=
  constexpr bool operator==(const Float8E5M2_t& rhs) const noexcept { return value == rhs.value; }
  constexpr bool operator!=(const Float8E5M2_t& rhs) const noexcept { return value != rhs.value; }
};

static_assert(sizeof(Float8E5M2_t) == sizeof(uint8_t), "Sizes must match");
0481 
/** \brief float8e5m2fnuz (Float8 Floating Point) data type
 * \details Needed so the C++ API can dispatch on this element type.
 * The type is implicitly convertible to/from uint8_t; it only stores the raw 8-bit
 * pattern and performs no numeric conversion.
 * See https://onnx.ai/onnx/technical/float8.html for further details.
 */
struct Float8E5M2FNUZ_t {
  uint8_t value;  ///< raw 8-bit encoding
  constexpr Float8E5M2FNUZ_t() noexcept : value(0) {}
  constexpr Float8E5M2FNUZ_t(uint8_t v) noexcept : value(v) {}  // implicit by design
  constexpr operator uint8_t() const noexcept { return value; }
  // Comparison is bitwise: nan values are treated like any other value for operator ==, !=
  constexpr bool operator==(const Float8E5M2FNUZ_t& rhs) const noexcept { return value == rhs.value; }
  constexpr bool operator!=(const Float8E5M2FNUZ_t& rhs) const noexcept { return value != rhs.value; }
};

static_assert(sizeof(Float8E5M2FNUZ_t) == sizeof(uint8_t), "Sizes must match");
0498 
0499 namespace detail {
0500 // This is used internally by the C++ API. This macro is to make it easy to generate overloaded methods for all of the various OrtRelease* functions for every Ort* type
0501 // This can't be done in the C API since C doesn't have function overloading.
0502 #define ORT_DEFINE_RELEASE(NAME) \
0503   inline void OrtRelease(Ort##NAME* ptr) { GetApi().Release##NAME(ptr); }
0504 
0505 ORT_DEFINE_RELEASE(Allocator);
0506 ORT_DEFINE_RELEASE(MemoryInfo);
0507 ORT_DEFINE_RELEASE(CustomOpDomain);
0508 ORT_DEFINE_RELEASE(ThreadingOptions);
0509 ORT_DEFINE_RELEASE(Env);
0510 ORT_DEFINE_RELEASE(RunOptions);
0511 ORT_DEFINE_RELEASE(LoraAdapter);
0512 ORT_DEFINE_RELEASE(Session);
0513 ORT_DEFINE_RELEASE(SessionOptions);
0514 ORT_DEFINE_RELEASE(TensorTypeAndShapeInfo);
0515 ORT_DEFINE_RELEASE(SequenceTypeInfo);
0516 ORT_DEFINE_RELEASE(MapTypeInfo);
0517 ORT_DEFINE_RELEASE(TypeInfo);
0518 ORT_DEFINE_RELEASE(Value);
0519 ORT_DEFINE_RELEASE(ModelMetadata);
0520 ORT_DEFINE_RELEASE(IoBinding);
0521 ORT_DEFINE_RELEASE(ArenaCfg);
0522 ORT_DEFINE_RELEASE(Status);
0523 ORT_DEFINE_RELEASE(OpAttr);
0524 ORT_DEFINE_RELEASE(Op);
0525 ORT_DEFINE_RELEASE(KernelInfo);
0526 
0527 #undef ORT_DEFINE_RELEASE
0528 
0529 /** \brief This is a tagging template type. Use it with Base<T> to indicate that the C++ interface object
0530  *   has no ownership of the underlying C object.
0531  */
0532 template <typename T>
0533 struct Unowned {
0534   using Type = T;
0535 };
0536 
0537 /** \brief Used internally by the C++ API. C++ wrapper types inherit from this.
0538  *   This is a zero cost abstraction to wrap the C API objects and delete them on destruction.
0539  *
0540  * All of the C++ classes
0541  *  a) serve as containers for pointers to objects that are created by the underlying C API.
0542  *     Their size is just a pointer size, no need to dynamically allocate them. Use them by value.
0543  *  b) Each of struct XXXX, XXX instances function as smart pointers to the underlying C API objects.
0544  *     they would release objects owned automatically when going out of scope, they are move-only.
0545  *  c) ConstXXXX and UnownedXXX structs function as non-owning, copyable containers for the above pointers.
0546  *     ConstXXXX allow calling const interfaces only. They give access to objects that are owned by somebody else
0547  *     such as Onnxruntime or instances of XXXX classes.
0548  *  d) serve convenient interfaces that return C++ objects and further enhance exception and type safety so they can be used
0549  *     in C++ code.
0550  *
0551  */
0552 
0553 /// <summary>
0554 /// This is a non-const pointer holder that is move-only. Disposes of the pointer on destruction.
0555 /// </summary>
0556 template <typename T>
0557 struct Base {
0558   using contained_type = T;
0559 
0560   constexpr Base() = default;
0561   constexpr explicit Base(contained_type* p) noexcept : p_{p} {}
0562   ~Base() { OrtRelease(p_); }
0563 
0564   Base(const Base&) = delete;
0565   Base& operator=(const Base&) = delete;
0566 
0567   Base(Base&& v) noexcept : p_{v.p_} { v.p_ = nullptr; }
0568   Base& operator=(Base&& v) noexcept {
0569     OrtRelease(p_);
0570     p_ = v.release();
0571     return *this;
0572   }
0573 
0574   constexpr operator contained_type*() const noexcept { return p_; }
0575 
0576   /// \brief Relinquishes ownership of the contained C object pointer
0577   /// The underlying object is not destroyed
0578   contained_type* release() {
0579     T* p = p_;
0580     p_ = nullptr;
0581     return p;
0582   }
0583 
0584  protected:
0585   contained_type* p_{};
0586 };
0587 
0588 // Undefined. For const types use Base<Unowned<const T>>
0589 template <typename T>
0590 struct Base<const T>;
0591 
0592 /// <summary>
0593 /// Covers unowned pointers owned by either the ORT
0594 /// or some other instance of CPP wrappers.
0595 /// Used for ConstXXX and UnownedXXXX types that are copyable.
0596 /// Also convenient to wrap raw OrtXX pointers .
0597 /// </summary>
0598 /// <typeparam name="T"></typeparam>
0599 template <typename T>
0600 struct Base<Unowned<T>> {
0601   using contained_type = typename Unowned<T>::Type;
0602 
0603   constexpr Base() = default;
0604   constexpr explicit Base(contained_type* p) noexcept : p_{p} {}
0605 
0606   ~Base() = default;
0607 
0608   Base(const Base&) = default;
0609   Base& operator=(const Base&) = default;
0610 
0611   Base(Base&& v) noexcept : p_{v.p_} { v.p_ = nullptr; }
0612   Base& operator=(Base&& v) noexcept {
0613     p_ = nullptr;
0614     std::swap(p_, v.p_);
0615     return *this;
0616   }
0617 
0618   constexpr operator contained_type*() const noexcept { return p_; }
0619 
0620  protected:
0621   contained_type* p_{};
0622 };
0623 
0624 // Light functor to release memory with OrtAllocator
0625 struct AllocatedFree {
0626   OrtAllocator* allocator_;
0627   explicit AllocatedFree(OrtAllocator* allocator)
0628       : allocator_(allocator) {}
0629   void operator()(void* ptr) const {
0630     if (ptr) allocator_->Free(allocator_, ptr);
0631   }
0632 };
0633 
0634 }  // namespace detail
0635 
0636 struct AllocatorWithDefaultOptions;
0637 struct Env;
0638 struct TypeInfo;
0639 struct Value;
0640 struct ModelMetadata;
0641 
0642 /** \brief unique_ptr typedef used to own strings allocated by OrtAllocators
0643  *  and release them at the end of the scope. The lifespan of the given allocator
0644  *  must eclipse the lifespan of AllocatedStringPtr instance
0645  */
0646 using AllocatedStringPtr = std::unique_ptr<char, detail::AllocatedFree>;
0647 
0648 /** \brief The Status that holds ownership of OrtStatus received from C API
0649  *  Use it to safely destroy OrtStatus* returned from the C API. Use appropriate
0650  *  constructors to construct an instance of a Status object from exceptions.
0651  */
0652 struct Status : detail::Base<OrtStatus> {
0653   explicit Status(std::nullptr_t) noexcept {}               ///< Create an empty object, must be assigned a valid one to be used
0654   explicit Status(OrtStatus* status) noexcept;              ///< Takes ownership of OrtStatus instance returned from the C API.
0655   explicit Status(const Exception&) noexcept;               ///< Creates status instance out of exception
0656   explicit Status(const std::exception&) noexcept;          ///< Creates status instance out of exception
0657   Status(const char* message, OrtErrorCode code) noexcept;  ///< Creates status instance out of null-terminated string message.
0658   std::string GetErrorMessage() const;
0659   OrtErrorCode GetErrorCode() const;
0660   bool IsOK() const noexcept;  ///< Returns true if instance represents an OK (non-error) status.
0661 };
0662 
0663 /** \brief The ThreadingOptions
0664  *
0665  * The ThreadingOptions used for set global threadpools' options of The Env.
0666  */
0667 struct ThreadingOptions : detail::Base<OrtThreadingOptions> {
0668   /// \brief Wraps OrtApi::CreateThreadingOptions
0669   ThreadingOptions();
0670 
0671   /// \brief Wraps OrtApi::SetGlobalIntraOpNumThreads
0672   ThreadingOptions& SetGlobalIntraOpNumThreads(int intra_op_num_threads);
0673 
0674   /// \brief Wraps OrtApi::SetGlobalInterOpNumThreads
0675   ThreadingOptions& SetGlobalInterOpNumThreads(int inter_op_num_threads);
0676 
0677   /// \brief Wraps OrtApi::SetGlobalSpinControl
0678   ThreadingOptions& SetGlobalSpinControl(int allow_spinning);
0679 
0680   /// \brief Wraps OrtApi::SetGlobalDenormalAsZero
0681   ThreadingOptions& SetGlobalDenormalAsZero();
0682 
0683   /// \brief Wraps OrtApi::SetGlobalCustomCreateThreadFn
0684   ThreadingOptions& SetGlobalCustomCreateThreadFn(OrtCustomCreateThreadFn ort_custom_create_thread_fn);
0685 
0686   /// \brief Wraps OrtApi::SetGlobalCustomThreadCreationOptions
0687   ThreadingOptions& SetGlobalCustomThreadCreationOptions(void* ort_custom_thread_creation_options);
0688 
0689   /// \brief Wraps OrtApi::SetGlobalCustomJoinThreadFn
0690   ThreadingOptions& SetGlobalCustomJoinThreadFn(OrtCustomJoinThreadFn ort_custom_join_thread_fn);
0691 };
0692 
0693 /** \brief The Env (Environment)
0694  *
0695  * The Env holds the logging state used by all other objects.
0696  * <b>Note:</b> One Env must be created before using any other Onnxruntime functionality
0697  */
0698 struct Env : detail::Base<OrtEnv> {
0699   explicit Env(std::nullptr_t) {}  ///< Create an empty Env object, must be assigned a valid one to be used
0700 
0701   /// \brief Wraps OrtApi::CreateEnv
0702   Env(OrtLoggingLevel logging_level = ORT_LOGGING_LEVEL_WARNING, _In_ const char* logid = "");
0703 
0704   /// \brief Wraps OrtApi::CreateEnvWithCustomLogger
0705   Env(OrtLoggingLevel logging_level, const char* logid, OrtLoggingFunction logging_function, void* logger_param);
0706 
0707   /// \brief Wraps OrtApi::CreateEnvWithGlobalThreadPools
0708   Env(const OrtThreadingOptions* tp_options, OrtLoggingLevel logging_level = ORT_LOGGING_LEVEL_WARNING, _In_ const char* logid = "");
0709 
0710   /// \brief Wraps OrtApi::CreateEnvWithCustomLoggerAndGlobalThreadPools
0711   Env(const OrtThreadingOptions* tp_options, OrtLoggingFunction logging_function, void* logger_param,
0712       OrtLoggingLevel logging_level = ORT_LOGGING_LEVEL_WARNING, _In_ const char* logid = "");
0713 
0714   /// \brief C Interop Helper
0715   explicit Env(OrtEnv* p) : Base<OrtEnv>{p} {}
0716 
0717   Env& EnableTelemetryEvents();   ///< Wraps OrtApi::EnableTelemetryEvents
0718   Env& DisableTelemetryEvents();  ///< Wraps OrtApi::DisableTelemetryEvents
0719 
0720   Env& UpdateEnvWithCustomLogLevel(OrtLoggingLevel log_severity_level);  ///< Wraps OrtApi::UpdateEnvWithCustomLogLevel
0721 
0722   Env& CreateAndRegisterAllocator(const OrtMemoryInfo* mem_info, const OrtArenaCfg* arena_cfg);  ///< Wraps OrtApi::CreateAndRegisterAllocator
0723 
0724   Env& CreateAndRegisterAllocatorV2(const std::string& provider_type, const OrtMemoryInfo* mem_info, const std::unordered_map<std::string, std::string>& options, const OrtArenaCfg* arena_cfg);  ///< Wraps OrtApi::CreateAndRegisterAllocatorV2
0725 };
0726 
0727 /** \brief Custom Op Domain
0728  *
0729  */
0730 struct CustomOpDomain : detail::Base<OrtCustomOpDomain> {
0731   explicit CustomOpDomain(std::nullptr_t) {}  ///< Create an empty CustomOpDomain object, must be assigned a valid one to be used
0732 
0733   /// \brief Wraps OrtApi::CreateCustomOpDomain
0734   explicit CustomOpDomain(const char* domain);
0735 
0736   // This does not take ownership of the op, simply registers it.
0737   void Add(const OrtCustomOp* op);  ///< Wraps CustomOpDomain_Add
0738 };
0739 
0740 /// \brief LoraAdapter holds a set of Lora Parameters loaded from a single file
0741 struct LoraAdapter : detail::Base<OrtLoraAdapter> {
0742   using Base = detail::Base<OrtLoraAdapter>;
0743   using Base::Base;
0744 
0745   explicit LoraAdapter(std::nullptr_t) {}  ///< Create an empty LoraAdapter object, must be assigned a valid one to be used
0746   /// \brief Wraps OrtApi::CreateLoraAdapter
0747   ///
0748   /// The function attempts to load the adapter from the specified file
0749   /// \param adapter_path The path to the Lora adapter
0750   /// \param allocator optional pointer to a device allocator. If nullptr, the data stays on CPU. It would still
0751   ///        be copied to device if required by the model at inference time.
0752   static LoraAdapter CreateLoraAdapter(const std::basic_string<ORTCHAR_T>& adapter_path,
0753                                        OrtAllocator* allocator);
0754 
0755   /// \brief Wraps OrtApi::CreateLoraAdapterFromArray
0756   ///
0757   /// The function attempts to load the adapter from the specified byte array.
0758   /// \param bytes The byte array containing file LoraAdapter format
0759   /// \param num_bytes The number of bytes in the byte array
0760   /// \param allocator optional pointer to a device allocator. If nullptr, the data stays on CPU. It would still
0761   ///        be copied to device if required by the model at inference time.
0762   static LoraAdapter CreateLoraAdapterFromArray(const void* bytes, size_t num_bytes,
0763                                                 OrtAllocator* allocator);
0764 };
0765 
0766 /** \brief RunOptions
0767  *
0768  */
0769 struct RunOptions : detail::Base<OrtRunOptions> {
0770   explicit RunOptions(std::nullptr_t) {}  ///< Create an empty RunOptions object, must be assigned a valid one to be used
0771   RunOptions();                           ///< Wraps OrtApi::CreateRunOptions
0772 
0773   RunOptions& SetRunLogVerbosityLevel(int);  ///< Wraps OrtApi::RunOptionsSetRunLogVerbosityLevel
0774   int GetRunLogVerbosityLevel() const;       ///< Wraps OrtApi::RunOptionsGetRunLogVerbosityLevel
0775 
0776   RunOptions& SetRunLogSeverityLevel(int);  ///< Wraps OrtApi::RunOptionsSetRunLogSeverityLevel
0777   int GetRunLogSeverityLevel() const;       ///< Wraps OrtApi::RunOptionsGetRunLogSeverityLevel
0778 
0779   RunOptions& SetRunTag(const char* run_tag);  ///< wraps OrtApi::RunOptionsSetRunTag
0780   const char* GetRunTag() const;               ///< Wraps OrtApi::RunOptionsGetRunTag
0781 
0782   RunOptions& AddConfigEntry(const char* config_key, const char* config_value);  ///< Wraps OrtApi::AddRunConfigEntry
0783 
0784   /** \brief Terminates all currently executing Session::Run calls that were made using this RunOptions instance
0785    *
0786    * If a currently executing session needs to be force terminated, this can be called from another thread to force it to fail with an error
0787    * Wraps OrtApi::RunOptionsSetTerminate
0788    */
0789   RunOptions& SetTerminate();
0790 
0791   /** \brief Clears the terminate flag so this RunOptions instance can be used in a new Session::Run call without it instantly terminating
0792    *
0793    * Wraps OrtApi::RunOptionsUnsetTerminate
0794    */
0795   RunOptions& UnsetTerminate();
0796 
0797   /** \brief Add the LoraAdapter to the list of active adapters.
0798    *  The setting does not affect RunWithBinding() calls.
0799    *
0800    * Wraps OrtApi::RunOptionsAddActiveLoraAdapter
0801    * \param adapter The LoraAdapter to be used as the active adapter
0802    */
0803   RunOptions& AddActiveLoraAdapter(const LoraAdapter& adapter);
0804 };
0805 
namespace detail {
/// Builds the SessionOptions config entry key for one configuration setting of a custom operator.
/// The key has the form `custom_op.[custom_op_name].[config]`.
std::string MakeCustomOpConfigEntryKey(const char* custom_op_name, const char* config);
}  // namespace detail
0811 
/// <summary>
/// Class that represents session configuration entries for one or more custom operators.
///
/// Example:
///   Ort::CustomOpConfigs op_configs;
///   op_configs.AddConfig("my_custom_op", "device_type", "CPU");
///
/// Passed to Ort::SessionOptions::RegisterCustomOpsLibrary, which applies the entries as session
/// config entries.
/// </summary>
struct CustomOpConfigs {
  CustomOpConfigs() = default;
  ~CustomOpConfigs() = default;
  CustomOpConfigs(const CustomOpConfigs&) = default;
  CustomOpConfigs& operator=(const CustomOpConfigs&) = default;
  CustomOpConfigs(CustomOpConfigs&& o) = default;
  CustomOpConfigs& operator=(CustomOpConfigs&& o) = default;

  /** \brief Adds a session configuration entry/value for a specific custom operator.
   *
   * \param custom_op_name The name of the custom operator for which to add a configuration entry.
   *                       Must match the name returned by the CustomOp's GetName() method.
   * \param config_key The name of the configuration entry.
   * \param config_value The value of the configuration entry.
   * \return A reference to this object to enable call chaining.
   */
  CustomOpConfigs& AddConfig(const char* custom_op_name, const char* config_key, const char* config_value);

  /** \brief Returns a flattened map of custom operator configuration entries and their values.
   *
   * Each key has been flattened into a complete session config entry key. It combines the custom
   * operator name and the configuration entry name in the format produced by
   * detail::MakeCustomOpConfigEntryKey: `custom_op.[custom_op_name].[config]`.
   * For example, a prior call to AddConfig("my_op", "key", "value") corresponds to the flattened
   * key/value pair {"custom_op.my_op.key", "value"}.
   *
   * \return An unordered map of flattened configurations.
   */
  const std::unordered_map<std::string, std::string>& GetFlattenedConfigs() const;

 private:
  // Flattened session config entry key -> configuration value.
  std::unordered_map<std::string, std::string> flat_configs_;
};
0852 
0853 /** \brief Options object used when creating a new Session object
0854  *
0855  * Wraps ::OrtSessionOptions object and methods
0856  */
0857 
0858 struct SessionOptions;
0859 
0860 namespace detail {
0861 // we separate const-only methods because passing const ptr to non-const methods
0862 // is only discovered when inline methods are compiled which is counter-intuitive
0863 template <typename T>
0864 struct ConstSessionOptionsImpl : Base<T> {
0865   using B = Base<T>;
0866   using B::B;
0867 
0868   SessionOptions Clone() const;  ///< Creates and returns a copy of this SessionOptions object. Wraps OrtApi::CloneSessionOptions
0869 
0870   std::string GetConfigEntry(const char* config_key) const;  ///< Wraps OrtApi::GetSessionConfigEntry
0871   bool HasConfigEntry(const char* config_key) const;         ///< Wraps OrtApi::HasSessionConfigEntry
0872   std::string GetConfigEntryOrDefault(const char* config_key, const std::string& def);
0873 };
0874 
0875 template <typename T>
0876 struct SessionOptionsImpl : ConstSessionOptionsImpl<T> {
0877   using B = ConstSessionOptionsImpl<T>;
0878   using B::B;
0879 
0880   SessionOptionsImpl& SetIntraOpNumThreads(int intra_op_num_threads);                              ///< Wraps OrtApi::SetIntraOpNumThreads
0881   SessionOptionsImpl& SetInterOpNumThreads(int inter_op_num_threads);                              ///< Wraps OrtApi::SetInterOpNumThreads
0882   SessionOptionsImpl& SetGraphOptimizationLevel(GraphOptimizationLevel graph_optimization_level);  ///< Wraps OrtApi::SetSessionGraphOptimizationLevel
0883   SessionOptionsImpl& SetDeterministicCompute(bool value);                                         ///< Wraps OrtApi::SetDeterministicCompute
0884 
0885   SessionOptionsImpl& EnableCpuMemArena();   ///< Wraps OrtApi::EnableCpuMemArena
0886   SessionOptionsImpl& DisableCpuMemArena();  ///< Wraps OrtApi::DisableCpuMemArena
0887 
0888   SessionOptionsImpl& SetOptimizedModelFilePath(const ORTCHAR_T* optimized_model_file);  ///< Wraps OrtApi::SetOptimizedModelFilePath
0889 
0890   SessionOptionsImpl& EnableProfiling(const ORTCHAR_T* profile_file_prefix);  ///< Wraps OrtApi::EnableProfiling
0891   SessionOptionsImpl& DisableProfiling();                                     ///< Wraps OrtApi::DisableProfiling
0892 
0893   SessionOptionsImpl& EnableOrtCustomOps();  ///< Wraps OrtApi::EnableOrtCustomOps
0894 
0895   SessionOptionsImpl& EnableMemPattern();   ///< Wraps OrtApi::EnableMemPattern
0896   SessionOptionsImpl& DisableMemPattern();  ///< Wraps OrtApi::DisableMemPattern
0897 
0898   SessionOptionsImpl& SetExecutionMode(ExecutionMode execution_mode);  ///< Wraps OrtApi::SetSessionExecutionMode
0899 
0900   SessionOptionsImpl& SetLogId(const char* logid);     ///< Wraps OrtApi::SetSessionLogId
0901   SessionOptionsImpl& SetLogSeverityLevel(int level);  ///< Wraps OrtApi::SetSessionLogSeverityLevel
0902 
0903   SessionOptionsImpl& Add(OrtCustomOpDomain* custom_op_domain);  ///< Wraps OrtApi::AddCustomOpDomain
0904 
0905   SessionOptionsImpl& DisablePerSessionThreads();  ///< Wraps OrtApi::DisablePerSessionThreads
0906 
0907   SessionOptionsImpl& AddConfigEntry(const char* config_key, const char* config_value);  ///< Wraps OrtApi::AddSessionConfigEntry
0908 
0909   SessionOptionsImpl& AddInitializer(const char* name, const OrtValue* ort_val);                                             ///< Wraps OrtApi::AddInitializer
0910   SessionOptionsImpl& AddExternalInitializers(const std::vector<std::string>& names, const std::vector<Value>& ort_values);  ///< Wraps OrtApi::AddExternalInitializers
0911   SessionOptionsImpl& AddExternalInitializersFromFilesInMemory(const std::vector<std::basic_string<ORTCHAR_T>>& external_initializer_file_names,
0912                                                                const std::vector<char*>& external_initializer_file_buffer_array,
0913                                                                const std::vector<size_t>& external_initializer_file_lengths);  ///< Wraps OrtApi::AddExternalInitializersFromFilesInMemory
0914 
0915   SessionOptionsImpl& AppendExecutionProvider_CUDA(const OrtCUDAProviderOptions& provider_options);          ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_CUDA
0916   SessionOptionsImpl& AppendExecutionProvider_CUDA_V2(const OrtCUDAProviderOptionsV2& provider_options);     ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_CUDA_V2
0917   SessionOptionsImpl& AppendExecutionProvider_ROCM(const OrtROCMProviderOptions& provider_options);          ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_ROCM
0918   SessionOptionsImpl& AppendExecutionProvider_OpenVINO(const OrtOpenVINOProviderOptions& provider_options);  ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_OpenVINO
0919   ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_OpenVINO_V2
0920   SessionOptionsImpl& AppendExecutionProvider_OpenVINO_V2(const std::unordered_map<std::string, std::string>& provider_options = {});
0921   SessionOptionsImpl& AppendExecutionProvider_TensorRT(const OrtTensorRTProviderOptions& provider_options);       ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_TensorRT
0922   SessionOptionsImpl& AppendExecutionProvider_TensorRT_V2(const OrtTensorRTProviderOptionsV2& provider_options);  ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_TensorRT
0923   SessionOptionsImpl& AppendExecutionProvider_MIGraphX(const OrtMIGraphXProviderOptions& provider_options);       ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_MIGraphX
0924   ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_CANN
0925   SessionOptionsImpl& AppendExecutionProvider_CANN(const OrtCANNProviderOptions& provider_options);
0926   ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_Dnnl
0927   SessionOptionsImpl& AppendExecutionProvider_Dnnl(const OrtDnnlProviderOptions& provider_options);
0928   /// Wraps OrtApi::SessionOptionsAppendExecutionProvider. Currently supports QNN, SNPE and XNNPACK.
0929   SessionOptionsImpl& AppendExecutionProvider(const std::string& provider_name,
0930                                               const std::unordered_map<std::string, std::string>& provider_options = {});
0931 
0932   SessionOptionsImpl& SetCustomCreateThreadFn(OrtCustomCreateThreadFn ort_custom_create_thread_fn);  ///< Wraps OrtApi::SessionOptionsSetCustomCreateThreadFn
0933   SessionOptionsImpl& SetCustomThreadCreationOptions(void* ort_custom_thread_creation_options);      ///< Wraps OrtApi::SessionOptionsSetCustomThreadCreationOptions
0934   SessionOptionsImpl& SetCustomJoinThreadFn(OrtCustomJoinThreadFn ort_custom_join_thread_fn);        ///< Wraps OrtApi::SessionOptionsSetCustomJoinThreadFn
0935 
0936   ///< Registers the custom operator from the specified shared library via OrtApi::RegisterCustomOpsLibrary_V2.
0937   ///< The custom operator configurations are optional. If provided, custom operator configs are set via
0938   ///< OrtApi::AddSessionConfigEntry.
0939   SessionOptionsImpl& RegisterCustomOpsLibrary(const ORTCHAR_T* library_name, const CustomOpConfigs& custom_op_configs = {});
0940 
0941   SessionOptionsImpl& RegisterCustomOpsUsingFunction(const char* function_name);  ///< Wraps OrtApi::RegisterCustomOpsUsingFunction
0942 
0943   ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_VitisAI
0944   SessionOptionsImpl& AppendExecutionProvider_VitisAI(const std::unordered_map<std::string, std::string>& provider_options = {});
0945 };
0946 }  // namespace detail
0947 
0948 using UnownedSessionOptions = detail::SessionOptionsImpl<detail::Unowned<OrtSessionOptions>>;
0949 using ConstSessionOptions = detail::ConstSessionOptionsImpl<detail::Unowned<const OrtSessionOptions>>;
0950 
0951 /** \brief Wrapper around ::OrtSessionOptions
0952  *
0953  */
0954 struct SessionOptions : detail::SessionOptionsImpl<OrtSessionOptions> {
0955   explicit SessionOptions(std::nullptr_t) {}                                                   ///< Create an empty SessionOptions object, must be assigned a valid one to be used
0956   SessionOptions();                                                                            ///< Wraps OrtApi::CreateSessionOptions
0957   explicit SessionOptions(OrtSessionOptions* p) : SessionOptionsImpl<OrtSessionOptions>{p} {}  ///< Used for interop with the C API
0958   UnownedSessionOptions GetUnowned() const { return UnownedSessionOptions{this->p_}; }
0959   ConstSessionOptions GetConst() const { return ConstSessionOptions{this->p_}; }
0960 };
0961 
0962 /** \brief Wrapper around ::OrtModelMetadata
0963  *
0964  */
0965 struct ModelMetadata : detail::Base<OrtModelMetadata> {
0966   explicit ModelMetadata(std::nullptr_t) {}                                   ///< Create an empty ModelMetadata object, must be assigned a valid one to be used
0967   explicit ModelMetadata(OrtModelMetadata* p) : Base<OrtModelMetadata>{p} {}  ///< Used for interop with the C API
0968 
0969   /** \brief Returns a copy of the producer name.
0970    *
0971    * \param allocator to allocate memory for the copy of the name returned
0972    * \return a instance of smart pointer that would deallocate the buffer when out of scope.
0973    *  The OrtAllocator instances must be valid at the point of memory release.
0974    */
0975   AllocatedStringPtr GetProducerNameAllocated(OrtAllocator* allocator) const;  ///< Wraps OrtApi::ModelMetadataGetProducerName
0976 
0977   /** \brief Returns a copy of the graph name.
0978    *
0979    * \param allocator to allocate memory for the copy of the name returned
0980    * \return a instance of smart pointer that would deallocate the buffer when out of scope.
0981    *  The OrtAllocator instances must be valid at the point of memory release.
0982    */
0983   AllocatedStringPtr GetGraphNameAllocated(OrtAllocator* allocator) const;  ///< Wraps OrtApi::ModelMetadataGetGraphName
0984 
0985   /** \brief Returns a copy of the domain name.
0986    *
0987    * \param allocator to allocate memory for the copy of the name returned
0988    * \return a instance of smart pointer that would deallocate the buffer when out of scope.
0989    *  The OrtAllocator instances must be valid at the point of memory release.
0990    */
0991   AllocatedStringPtr GetDomainAllocated(OrtAllocator* allocator) const;  ///< Wraps OrtApi::ModelMetadataGetDomain
0992 
0993   /** \brief Returns a copy of the description.
0994    *
0995    * \param allocator to allocate memory for the copy of the string returned
0996    * \return a instance of smart pointer that would deallocate the buffer when out of scope.
0997    *  The OrtAllocator instances must be valid at the point of memory release.
0998    */
0999   AllocatedStringPtr GetDescriptionAllocated(OrtAllocator* allocator) const;  ///< Wraps OrtApi::ModelMetadataGetDescription
1000 
1001   /** \brief Returns a copy of the graph description.
1002    *
1003    * \param allocator to allocate memory for the copy of the string returned
1004    * \return a instance of smart pointer that would deallocate the buffer when out of scope.
1005    *  The OrtAllocator instances must be valid at the point of memory release.
1006    */
1007   AllocatedStringPtr GetGraphDescriptionAllocated(OrtAllocator* allocator) const;  ///< Wraps OrtApi::ModelMetadataGetGraphDescription
1008 
1009   /** \brief Returns a vector of copies of the custom metadata keys.
1010    *
1011    * \param allocator to allocate memory for the copy of the string returned
1012    * \return a instance std::vector of smart pointers that would deallocate the buffers when out of scope.
1013    *  The OrtAllocator instance must be valid at the point of memory release.
1014    */
1015   std::vector<AllocatedStringPtr> GetCustomMetadataMapKeysAllocated(OrtAllocator* allocator) const;  ///< Wraps OrtApi::ModelMetadataGetCustomMetadataMapKeys
1016 
1017   /** \brief Looks up a value by a key in the Custom Metadata map
1018    *
1019    * \param key zero terminated string key to lookup
1020    * \param allocator to allocate memory for the copy of the string returned
1021    * \return a instance of smart pointer that would deallocate the buffer when out of scope.
1022    *  maybe nullptr if key is not found.
1023    *
1024    *  The OrtAllocator instances must be valid at the point of memory release.
1025    */
1026   AllocatedStringPtr LookupCustomMetadataMapAllocated(const char* key, OrtAllocator* allocator) const;  ///< Wraps OrtApi::ModelMetadataLookupCustomMetadataMap
1027 
1028   int64_t GetVersion() const;  ///< Wraps OrtApi::ModelMetadataGetVersion
1029 };
1030 
1031 struct IoBinding;
1032 
1033 namespace detail {
1034 
1035 // we separate const-only methods because passing const ptr to non-const methods
1036 // is only discovered when inline methods are compiled which is counter-intuitive
1037 template <typename T>
1038 struct ConstSessionImpl : Base<T> {
1039   using B = Base<T>;
1040   using B::B;
1041 
1042   size_t GetInputCount() const;                   ///< Returns the number of model inputs
1043   size_t GetOutputCount() const;                  ///< Returns the number of model outputs
1044   size_t GetOverridableInitializerCount() const;  ///< Returns the number of inputs that have defaults that can be overridden
1045 
1046   /** \brief Returns a copy of input name at the specified index.
1047    *
1048    * \param index must less than the value returned by GetInputCount()
1049    * \param allocator to allocate memory for the copy of the name returned
1050    * \return a instance of smart pointer that would deallocate the buffer when out of scope.
1051    *  The OrtAllocator instances must be valid at the point of memory release.
1052    */
1053   AllocatedStringPtr GetInputNameAllocated(size_t index, OrtAllocator* allocator) const;
1054 
1055   /** \brief Returns a copy of output name at then specified index.
1056    *
1057    * \param index must less than the value returned by GetOutputCount()
1058    * \param allocator to allocate memory for the copy of the name returned
1059    * \return a instance of smart pointer that would deallocate the buffer when out of scope.
1060    *  The OrtAllocator instances must be valid at the point of memory release.
1061    */
1062   AllocatedStringPtr GetOutputNameAllocated(size_t index, OrtAllocator* allocator) const;
1063 
1064   /** \brief Returns a copy of the overridable initializer name at then specified index.
1065    *
1066    * \param index must less than the value returned by GetOverridableInitializerCount()
1067    * \param allocator to allocate memory for the copy of the name returned
1068    * \return a instance of smart pointer that would deallocate the buffer when out of scope.
1069    *  The OrtAllocator instances must be valid at the point of memory release.
1070    */
1071   AllocatedStringPtr GetOverridableInitializerNameAllocated(size_t index, OrtAllocator* allocator) const;  ///< Wraps OrtApi::SessionGetOverridableInitializerName
1072 
1073   uint64_t GetProfilingStartTimeNs() const;  ///< Wraps OrtApi::SessionGetProfilingStartTimeNs
1074   ModelMetadata GetModelMetadata() const;    ///< Wraps OrtApi::SessionGetModelMetadata
1075 
1076   TypeInfo GetInputTypeInfo(size_t index) const;                   ///< Wraps OrtApi::SessionGetInputTypeInfo
1077   TypeInfo GetOutputTypeInfo(size_t index) const;                  ///< Wraps OrtApi::SessionGetOutputTypeInfo
1078   TypeInfo GetOverridableInitializerTypeInfo(size_t index) const;  ///< Wraps OrtApi::SessionGetOverridableInitializerTypeInfo
1079 };
1080 
1081 template <typename T>
1082 struct SessionImpl : ConstSessionImpl<T> {
1083   using B = ConstSessionImpl<T>;
1084   using B::B;
1085 
1086   /** \brief Run the model returning results in an Ort allocated vector.
1087    *
1088    * Wraps OrtApi::Run
1089    *
1090    * The caller provides a list of inputs and a list of the desired outputs to return.
1091    *
1092    * See the output logs for more information on warnings/errors that occur while processing the model.
1093    * Common errors are.. (TODO)
1094    *
1095    * \param[in] run_options
1096    * \param[in] input_names Array of null terminated strings of length input_count that is the list of input names
1097    * \param[in] input_values Array of Value objects of length input_count that is the list of input values
1098    * \param[in] input_count Number of inputs (the size of the input_names & input_values arrays)
1099    * \param[in] output_names Array of C style strings of length output_count that is the list of output names
1100    * \param[in] output_count Number of outputs (the size of the output_names array)
1101    * \return A std::vector of Value objects that directly maps to the output_names array (eg. output_name[0] is the first entry of the returned vector)
1102    */
1103   std::vector<Value> Run(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count,
1104                          const char* const* output_names, size_t output_count);
1105 
1106   /** \brief Run the model returning results in user provided outputs
1107    * Same as Run(const RunOptions&, const char* const*, const Value*, size_t,const char* const*, size_t)
1108    */
1109   void Run(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count,
1110            const char* const* output_names, Value* output_values, size_t output_count);
1111 
1112   void Run(const RunOptions& run_options, const IoBinding&);  ///< Wraps OrtApi::RunWithBinding
1113 
1114   /** \brief Run the model asynchronously in a thread owned by intra op thread pool
1115    *
1116    * Wraps OrtApi::RunAsync
1117    *
1118    * \param[in] run_options
1119    * \param[in] input_names Array of null terminated UTF8 encoded strings of the input names
1120    * \param[in] input_values Array of Value objects of length input_count
1121    * \param[in] input_count Number of elements in the input_names and inputs arrays
1122    * \param[in] output_names Array of null terminated UTF8 encoded strings of the output names
1123    * \param[out] output_values Array of provided Values to be filled with outputs.
1124    *             On calling RunAsync, output_values[i] could either be initialized by a null pointer or a preallocated OrtValue*.
1125    *             Later, on invoking the callback, each output_values[i] of null will be filled with an OrtValue* allocated by onnxruntime.
1126    *             Then, an OrtValue** pointer will be casted from output_values, and pass to the callback.
1127    *             NOTE: it is customer's duty to finally release output_values and each of its member,
1128    *             regardless of whether the member (Ort::Value) is allocated by onnxruntime or preallocated by the customer.
1129    * \param[in] output_count Number of elements in the output_names and outputs array
1130    * \param[in] callback Callback function on model run completion
1131    * \param[in] user_data User data that pass back to the callback
1132    */
1133   void RunAsync(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count,
1134                 const char* const* output_names, Value* output_values, size_t output_count, RunAsyncCallbackFn callback, void* user_data);
1135 
1136   /** \brief End profiling and return a copy of the profiling file name.
1137    *
1138    * \param allocator to allocate memory for the copy of the string returned
1139    * \return a instance of smart pointer that would deallocate the buffer when out of scope.
1140    *  The OrtAllocator instances must be valid at the point of memory release.
1141    */
1142   AllocatedStringPtr EndProfilingAllocated(OrtAllocator* allocator);  ///< Wraps OrtApi::SessionEndProfiling
1143 
1144   /** \brief Set DynamicOptions for EPs (Execution Providers)
1145    *
1146    * Wraps OrtApi::SetEpDynamicOptions
1147    *
1148    * Valid options can be found in `include\onnxruntime\core\session\onnxruntime_session_options_config_keys.h`
1149    * Look for `kOrtEpDynamicOptions`
1150    *
1151    * \param[in] keys Array of null terminated UTF8 encoded strings of EP dynamic option keys
1152    * \param[in] values Array of null terminated UTF8 encoded string of EP dynamic option values
1153    * \param[in] kv_len Number of elements in the keys and values arrays
1154    */
1155   void SetEpDynamicOptions(const char* const* keys, const char* const* values, size_t kv_len);
1156 };
1157 
1158 }  // namespace detail
1159 
1160 using ConstSession = detail::ConstSessionImpl<detail::Unowned<const OrtSession>>;
1161 using UnownedSession = detail::SessionImpl<detail::Unowned<OrtSession>>;
1162 
1163 /** \brief Wrapper around ::OrtSession
1164  *
1165  */
1166 struct Session : detail::SessionImpl<OrtSession> {
1167   explicit Session(std::nullptr_t) {}                                                   ///< Create an empty Session object, must be assigned a valid one to be used
1168   Session(const Env& env, const ORTCHAR_T* model_path, const SessionOptions& options);  ///< Wraps OrtApi::CreateSession
1169   Session(const Env& env, const ORTCHAR_T* model_path, const SessionOptions& options,
1170           OrtPrepackedWeightsContainer* prepacked_weights_container);                                        ///< Wraps OrtApi::CreateSessionWithPrepackedWeightsContainer
1171   Session(const Env& env, const void* model_data, size_t model_data_length, const SessionOptions& options);  ///< Wraps OrtApi::CreateSessionFromArray
1172   Session(const Env& env, const void* model_data, size_t model_data_length, const SessionOptions& options,
1173           OrtPrepackedWeightsContainer* prepacked_weights_container);  ///< Wraps OrtApi::CreateSessionFromArrayWithPrepackedWeightsContainer
1174 
1175   ConstSession GetConst() const { return ConstSession{this->p_}; }
1176   UnownedSession GetUnowned() const { return UnownedSession{this->p_}; }
1177 };
1178 
1179 namespace detail {
1180 template <typename T>
1181 struct MemoryInfoImpl : Base<T> {
1182   using B = Base<T>;
1183   using B::B;
1184 
1185   std::string GetAllocatorName() const;
1186   OrtAllocatorType GetAllocatorType() const;
1187   int GetDeviceId() const;
1188   OrtMemoryInfoDeviceType GetDeviceType() const;
1189   OrtMemType GetMemoryType() const;
1190 
1191   template <typename U>
1192   bool operator==(const MemoryInfoImpl<U>& o) const;
1193 };
1194 }  // namespace detail
1195 
1196 // Const object holder that does not own the underlying object
1197 using ConstMemoryInfo = detail::MemoryInfoImpl<detail::Unowned<const OrtMemoryInfo>>;
1198 
1199 /** \brief Wrapper around ::OrtMemoryInfo
1200  *
1201  */
1202 struct MemoryInfo : detail::MemoryInfoImpl<OrtMemoryInfo> {
1203   static MemoryInfo CreateCpu(OrtAllocatorType type, OrtMemType mem_type1);
1204   explicit MemoryInfo(std::nullptr_t) {}                                       ///< No instance is created
1205   explicit MemoryInfo(OrtMemoryInfo* p) : MemoryInfoImpl<OrtMemoryInfo>{p} {}  ///< Take ownership of a pointer created by C Api
1206   MemoryInfo(const char* name, OrtAllocatorType type, int id, OrtMemType mem_type);
1207   ConstMemoryInfo GetConst() const { return ConstMemoryInfo{this->p_}; }
1208 };
1209 
1210 namespace detail {
1211 template <typename T>
1212 struct TensorTypeAndShapeInfoImpl : Base<T> {
1213   using B = Base<T>;
1214   using B::B;
1215 
1216   ONNXTensorElementDataType GetElementType() const;  ///< Wraps OrtApi::GetTensorElementType
1217   size_t GetElementCount() const;                    ///< Wraps OrtApi::GetTensorShapeElementCount
1218 
1219   size_t GetDimensionsCount() const;  ///< Wraps OrtApi::GetDimensionsCount
1220 
1221   /** \deprecated use GetShape() returning std::vector
1222    * [[deprecated]]
1223    * This interface is unsafe to use
1224    */
1225   [[deprecated("use GetShape()")]] void GetDimensions(int64_t* values, size_t values_count) const;  ///< Wraps OrtApi::GetDimensions
1226 
1227   void GetSymbolicDimensions(const char** values, size_t values_count) const;  ///< Wraps OrtApi::GetSymbolicDimensions
1228 
1229   std::vector<int64_t> GetShape() const;  ///< Uses GetDimensionsCount & GetDimensions to return a std::vector of the shape
1230 };
1231 
1232 }  // namespace detail
1233 
1234 using ConstTensorTypeAndShapeInfo = detail::TensorTypeAndShapeInfoImpl<detail::Unowned<const OrtTensorTypeAndShapeInfo>>;
1235 
1236 /** \brief Wrapper around ::OrtTensorTypeAndShapeInfo
1237  *
1238  */
1239 struct TensorTypeAndShapeInfo : detail::TensorTypeAndShapeInfoImpl<OrtTensorTypeAndShapeInfo> {
1240   explicit TensorTypeAndShapeInfo(std::nullptr_t) {}                                                ///< Create an empty TensorTypeAndShapeInfo object, must be assigned a valid one to be used
1241   explicit TensorTypeAndShapeInfo(OrtTensorTypeAndShapeInfo* p) : TensorTypeAndShapeInfoImpl{p} {}  ///< Used for interop with the C API
1242   ConstTensorTypeAndShapeInfo GetConst() const { return ConstTensorTypeAndShapeInfo{this->p_}; }
1243 };
1244 
1245 namespace detail {
1246 template <typename T>
1247 struct SequenceTypeInfoImpl : Base<T> {
1248   using B = Base<T>;
1249   using B::B;
1250   TypeInfo GetSequenceElementType() const;  ///< Wraps OrtApi::GetSequenceElementType
1251 };
1252 
1253 }  // namespace detail
1254 
1255 using ConstSequenceTypeInfo = detail::SequenceTypeInfoImpl<detail::Unowned<const OrtSequenceTypeInfo>>;
1256 
1257 /** \brief Wrapper around ::OrtSequenceTypeInfo
1258  *
1259  */
1260 struct SequenceTypeInfo : detail::SequenceTypeInfoImpl<OrtSequenceTypeInfo> {
1261   explicit SequenceTypeInfo(std::nullptr_t) {}                                                         ///< Create an empty SequenceTypeInfo object, must be assigned a valid one to be used
1262   explicit SequenceTypeInfo(OrtSequenceTypeInfo* p) : SequenceTypeInfoImpl<OrtSequenceTypeInfo>{p} {}  ///< Used for interop with the C API
1263   ConstSequenceTypeInfo GetConst() const { return ConstSequenceTypeInfo{this->p_}; }
1264 };
1265 
1266 namespace detail {
1267 template <typename T>
1268 struct OptionalTypeInfoImpl : Base<T> {
1269   using B = Base<T>;
1270   using B::B;
1271   TypeInfo GetOptionalElementType() const;  ///< Wraps OrtApi::CastOptionalTypeToContainedTypeInfo
1272 };
1273 
1274 }  // namespace detail
1275 
1276 // This is always owned by the TypeInfo and can only be obtained from it.
1277 using ConstOptionalTypeInfo = detail::OptionalTypeInfoImpl<detail::Unowned<const OrtOptionalTypeInfo>>;
1278 
namespace detail {
/// Shared implementation behind MapTypeInfo (owning, T = OrtMapTypeInfo) and
/// ConstMapTypeInfo (non-owning, T = Unowned<const OrtMapTypeInfo>).
template <typename T>
struct MapTypeInfoImpl : detail::Base<T> {
  using B = Base<T>;
  using B::B;                                       // inherit Base<T>'s constructors
  ONNXTensorElementDataType GetMapKeyType() const;  ///< Wraps OrtApi::GetMapKeyType
  TypeInfo GetMapValueType() const;                 ///< Wraps OrtApi::GetMapValueType
};

}  // namespace detail
1289 
1290 using ConstMapTypeInfo = detail::MapTypeInfoImpl<detail::Unowned<const OrtMapTypeInfo>>;
1291 
1292 /** \brief Wrapper around ::OrtMapTypeInfo
1293  *
1294  */
1295 struct MapTypeInfo : detail::MapTypeInfoImpl<OrtMapTypeInfo> {
1296   explicit MapTypeInfo(std::nullptr_t) {}                                          ///< Create an empty MapTypeInfo object, must be assigned a valid one to be used
1297   explicit MapTypeInfo(OrtMapTypeInfo* p) : MapTypeInfoImpl<OrtMapTypeInfo>{p} {}  ///< Used for interop with the C API
1298   ConstMapTypeInfo GetConst() const { return ConstMapTypeInfo{this->p_}; }
1299 };
1300 
namespace detail {
/// Shared implementation behind TypeInfo (owning) and ConstTypeInfo (non-owning).
/// The Get*TypeInfo() accessors return non-owning views into this object; check GetONNXType()
/// to decide which accessor applies.
template <typename T>
struct TypeInfoImpl : detail::Base<T> {
  using B = Base<T>;
  using B::B;

  ConstTensorTypeAndShapeInfo GetTensorTypeAndShapeInfo() const;  ///< Wraps OrtApi::CastTypeInfoToTensorInfo
  ConstSequenceTypeInfo GetSequenceTypeInfo() const;              ///< Wraps OrtApi::CastTypeInfoToSequenceTypeInfo
  ConstMapTypeInfo GetMapTypeInfo() const;                        ///< Wraps OrtApi::CastTypeInfoToMapTypeInfo
  ConstOptionalTypeInfo GetOptionalTypeInfo() const;              ///< Wraps OrtApi::CastTypeInfoToOptionalTypeInfo

  ONNXType GetONNXType() const;  ///< Returns the kind of type described (tensor, sequence, map, optional, ...)
};
}  // namespace detail
1315 
1316 /// <summary>
1317 /// Contains a constant, unowned OrtTypeInfo that can be copied and passed around by value.
1318 /// Provides access to const OrtTypeInfo APIs.
1319 /// </summary>
1320 using ConstTypeInfo = detail::TypeInfoImpl<detail::Unowned<const OrtTypeInfo>>;
1321 
1322 /// <summary>
1323 /// Type information that may contain either TensorTypeAndShapeInfo or
1324 /// the information about contained sequence or map depending on the ONNXType.
1325 /// </summary>
1326 struct TypeInfo : detail::TypeInfoImpl<OrtTypeInfo> {
1327   explicit TypeInfo(std::nullptr_t) {}                                 ///< Create an empty TypeInfo object, must be assigned a valid one to be used
1328   explicit TypeInfo(OrtTypeInfo* p) : TypeInfoImpl<OrtTypeInfo>{p} {}  ///< C API Interop
1329 
1330   ConstTypeInfo GetConst() const { return ConstTypeInfo{this->p_}; }
1331 };
1332 
1333 namespace detail {
// Describes the non-zero values of a sparse tensor for the FillSparseTensor<Format>() APIs.
// - Numeric element types: point data.p_data at the values buffer.
// - String element types: point data.str at an array of zero-terminated const char* strings;
//   the number of strings must equal the element count implied by values_shape.
// For fully sparse tensors use shape {0} and set p_data/str to nullptr.
struct OrtSparseValuesParam {
  const int64_t* values_shape;  // dimensions of the values
  size_t values_shape_len;      // number of entries in values_shape
  union {
    const void* p_data;  // numeric values
    const char** str;    // string values
  } data;
};
1350 
// Bundles a shape (pointer + length) so it can be passed as a single argument.
// The pointed-to dimensions are not copied; they must stay valid for the duration of the call.
struct Shape {
  const int64_t* shape;  // dimension values
  size_t shape_len;      // number of dimensions
};
1357 
/// Read-only part of the OrtValue interface, shared by Value, UnownedValue and ConstValue.
template <typename T>
struct ConstValueImpl : Base<T> {
  using B = Base<T>;
  using B::B;

  /// <summary>
  /// Retrieves user defined (opaque) data for experimental purposes.
  /// The data is identified by domain and type_name (see Value::CreateOpaque()) and is
  /// written into the supplied R& argument.
  /// </summary>
  template <typename R>
  void GetOpaqueData(const char* domain, const char* type_name, R&) const;  ///< Wraps OrtApi::GetOpaqueValue

  bool IsTensor() const;  ///< Returns true if Value is a tensor, false for other types like map/sequence/etc
  bool HasValue() const;  ///< Returns true if OrtValue contains data and false if the OrtValue is a None

  size_t GetCount() const;  ///< If a non tensor, returns 2 for map and N for sequence, where N is the number of elements
  /// Returns the element at index of a non-tensor (map/sequence) value; see GetCount() for the valid range.
  /// NOTE(review): allocator is presumably used to allocate the returned Value - confirm.
  Value GetValue(int index, OrtAllocator* allocator) const;

  /// <summary>
  /// This API returns a full length of string data contained within either a tensor or a sparse Tensor.
  /// For sparse tensor it returns a full length of stored non-empty strings (values). The API is useful
  /// for allocating necessary memory and calling GetStringTensorContent().
  /// </summary>
  /// <returns>total length of UTF-8 encoded bytes contained. No zero terminators counted.</returns>
  size_t GetStringTensorDataLength() const;

  /// <summary>
  /// The API copies all of the UTF-8 encoded string data contained within a tensor or a sparse tensor
  /// into a supplied buffer. Use GetStringTensorDataLength() to find out the length of the buffer to allocate.
  /// The user must also allocate offsets buffer with the number of entries equal to that of the contained
  /// strings.
  ///
  /// Strings are always assumed to be on CPU, no X-device copy.
  /// </summary>
  /// <param name="buffer">user allocated buffer</param>
  /// <param name="buffer_length">length in bytes of the allocated buffer</param>
  /// <param name="offsets">a pointer to the offsets user allocated buffer</param>
  /// <param name="offsets_count">count of offsets, must be equal to the number of strings contained.
  ///   that can be obtained from the shape of the tensor or from GetSparseTensorValuesTypeAndShapeInfo()
  ///   for sparse tensors</param>
  void GetStringTensorContent(void* buffer, size_t buffer_length, size_t* offsets, size_t offsets_count) const;

  /// <summary>
  /// Returns a const typed pointer to the tensor contained data.
  /// No type checking is performed, the caller must ensure the type matches the tensor type.
  /// </summary>
  /// <typeparam name="R">element type the data pointer is cast to</typeparam>
  /// <returns>const pointer to data, no copies made</returns>
  template <typename R>
  const R* GetTensorData() const;  ///< Wraps OrtApi::GetTensorMutableData

  /// <summary>
  /// Returns a non-typed pointer to a tensor contained data.
  /// </summary>
  /// <returns>const pointer to data, no copies made</returns>
  const void* GetTensorRawData() const;

  /// <summary>
  /// The API returns type information for data contained in a tensor. For sparse
  /// tensors it returns type information for contained non-zero values.
  /// It returns dense shape for sparse tensors.
  /// </summary>
  /// <returns>TypeInfo</returns>
  TypeInfo GetTypeInfo() const;

  /// <summary>
  /// The API returns type information for data contained in a tensor. For sparse
  /// tensors it returns type information for contained non-zero values.
  /// It returns dense shape for sparse tensors.
  /// </summary>
  /// <returns>TensorTypeAndShapeInfo</returns>
  TensorTypeAndShapeInfo GetTensorTypeAndShapeInfo() const;

  /// <summary>
  /// This API returns information about the memory allocation used to hold data.
  /// </summary>
  /// <returns>Non owning instance of MemoryInfo</returns>
  ConstMemoryInfo GetTensorMemoryInfo() const;

  /// <summary>
  /// The API copies UTF-8 encoded bytes for the requested string element
  /// contained within a tensor or a sparse tensor into a provided buffer.
  /// Use GetStringTensorElementLength() to obtain the length of the buffer to allocate.
  /// </summary>
  /// <param name="buffer_length">length in bytes of buffer</param>
  /// <param name="element_index">flat index of the string element to copy</param>
  /// <param name="buffer">user allocated destination buffer</param>
  void GetStringTensorElement(size_t buffer_length, size_t element_index, void* buffer) const;

  /// <summary>
  /// Returns string tensor UTF-8 encoded string element.
  /// Use of this API is recommended over GetStringTensorElement() that takes void* buffer pointer.
  /// </summary>
  /// <param name="element_index">flat index of the string element</param>
  /// <returns>std::string</returns>
  std::string GetStringTensorElement(size_t element_index) const;

  /// <summary>
  /// The API returns a byte length of UTF-8 encoded string element
  /// contained in either a tensor or a sparse tensor values.
  /// </summary>
  /// <param name="element_index">flat index of the string element</param>
  /// <returns>byte length for the specified string element</returns>
  size_t GetStringTensorElementLength(size_t element_index) const;

#if !defined(DISABLE_SPARSE_TENSORS)
  /// <summary>
  /// The API returns the sparse data format this OrtValue holds in a sparse tensor.
  /// If the sparse tensor was not fully constructed, i.e. Use*() or Fill*() API were not used
  /// the value returned is ORT_SPARSE_UNDEFINED.
  /// </summary>
  /// <returns>Format enum</returns>
  OrtSparseFormat GetSparseFormat() const;

  /// <summary>
  /// The API returns type and shape information for stored non-zero values of the
  /// sparse tensor. Use GetSparseTensorValues() to obtain values buffer pointer.
  /// </summary>
  /// <returns>TensorTypeAndShapeInfo values information</returns>
  TensorTypeAndShapeInfo GetSparseTensorValuesTypeAndShapeInfo() const;

  /// <summary>
  /// The API returns type and shape information for the specified indices. Each supported
  /// indices have their own enum values even if a given format has more than one kind of indices.
  /// Use GetSparseTensorIndicesData() to obtain pointer to indices buffer.
  /// </summary>
  /// <param name="format">enum requested</param>
  /// <returns>type and shape information</returns>
  TensorTypeAndShapeInfo GetSparseTensorIndicesTypeShapeInfo(OrtSparseIndicesFormat format) const;

  /// <summary>
  /// The API retrieves a pointer to the internal indices buffer. The API merely performs
  /// a convenience data type casting on the return type pointer. Make sure you are requesting
  /// the right type, use GetSparseTensorIndicesTypeShapeInfo();
  /// </summary>
  /// <typeparam name="R">type to cast to</typeparam>
  /// <param name="indices_format">requested indices kind</param>
  /// <param name="num_indices">[out] number of indices entries</param>
  /// <returns>Pointer to the internal sparse tensor buffer containing indices. Do not free this pointer.</returns>
  template <typename R>
  const R* GetSparseTensorIndicesData(OrtSparseIndicesFormat indices_format, size_t& num_indices) const;

  /// <summary>
  /// Returns true if the OrtValue contains a sparse tensor
  /// </summary>
  /// <returns>true for a sparse tensor, false otherwise</returns>
  bool IsSparseTensor() const;

  /// <summary>
  /// The API returns a pointer to an internal buffer of the sparse tensor
  /// containing non-zero values. The API merely does casting. Make sure you
  /// are requesting the right data type by calling GetSparseTensorValuesTypeAndShapeInfo()
  /// first.
  /// </summary>
  /// <typeparam name="R">numeric data types only. Use GetStringTensor*() to retrieve strings.</typeparam>
  /// <returns>a pointer to the internal values buffer. Do not free this pointer.</returns>
  template <typename R>
  const R* GetSparseTensorValues() const;

#endif
};
1518 
/// Mutable part of the OrtValue interface, shared by Value and UnownedValue.
template <typename T>
struct ValueImpl : ConstValueImpl<T> {
  using B = ConstValueImpl<T>;
  using B::B;

  /// <summary>
  /// Returns a non-const typed pointer to an OrtValue/Tensor contained buffer
  /// No type checking is performed, the caller must ensure the type matches the tensor type.
  /// </summary>
  /// <typeparam name="R">element type the data pointer is cast to</typeparam>
  /// <returns>non-const pointer to data, no copies made</returns>
  template <typename R>
  R* GetTensorMutableData();

  /// <summary>
  /// Returns a non-typed non-const pointer to a tensor contained data.
  /// </summary>
  /// <returns>pointer to data, no copies made</returns>
  void* GetTensorMutableRawData();

  /// <summary>
  /// Obtain a reference to an element of data at the location specified
  /// by the vector of dims.
  /// </summary>
  /// <typeparam name="R">element type of the tensor</typeparam>
  /// <param name="location">[in] expressed by a vector of per-dimension offsets</param>
  /// <returns>reference to the element at location</returns>
  template <typename R>
  R& At(const std::vector<int64_t>& location);

  /// <summary>
  /// Set all strings at once in a string tensor
  /// </summary>
  /// <param name="s">[in] An array of strings. Each string in this array must be null terminated.</param>
  /// <param name="s_len">[in] Count of strings in s (Must match the element count of this tensor's shape)</param>
  void FillStringTensor(const char* const* s, size_t s_len);

  /// <summary>
  /// Set a single string in a string tensor
  /// </summary>
  /// <param name="s">[in] A null terminated UTF-8 encoded string</param>
  /// <param name="index">[in] Index of the string in the tensor to set</param>
  void FillStringTensorElement(const char* s, size_t index);

  /// <summary>
  /// Allocate if necessary and obtain a pointer to a UTF-8
  /// encoded string element buffer indexed by the flat element index,
  /// of the specified length.
  ///
  /// This API is for advanced usage. It avoids a need to construct
  /// an auxiliary array of string pointers, and allows to write data directly
  /// (do not zero terminate).
  /// </summary>
  /// <param name="index">flat index of the string element</param>
  /// <param name="buffer_length">length in bytes of the string to be written</param>
  /// <returns>a pointer to a writable buffer</returns>
  char* GetResizedStringTensorElementBuffer(size_t index, size_t buffer_length);

#if !defined(DISABLE_SPARSE_TENSORS)
  /// <summary>
  /// Supplies COO format specific indices and marks the contained sparse tensor as being a COO format tensor.
  /// Values are supplied with a CreateSparseTensor() API. The supplied indices are not copied and the user
  /// allocated buffers lifespan must eclipse that of the OrtValue.
  /// The location of the indices is assumed to be the same as specified by OrtMemoryInfo argument at the creation time.
  /// </summary>
  /// <param name="indices_data">pointer to the user allocated buffer with indices. Use nullptr for fully sparse tensors.</param>
  /// <param name="indices_num">number of indices entries. Use 0 for fully sparse tensors</param>
  void UseCooIndices(int64_t* indices_data, size_t indices_num);

  /// <summary>
  /// Supplies CSR format specific indices and marks the contained sparse tensor as being a CSR format tensor.
  /// Values are supplied with a CreateSparseTensor() API. The supplied indices are not copied and the user
  /// allocated buffers lifespan must eclipse that of the OrtValue.
  /// The location of the indices is assumed to be the same as specified by OrtMemoryInfo argument at the creation time.
  /// </summary>
  /// <param name="inner_data">pointer to the user allocated buffer with inner indices or nullptr for fully sparse tensors</param>
  /// <param name="inner_num">number of csr inner indices or 0 for fully sparse tensors</param>
  /// <param name="outer_data">pointer to the user allocated buffer with outer indices or nullptr for fully sparse tensors</param>
  /// <param name="outer_num">number of csr outer indices or 0 for fully sparse tensors</param>
  void UseCsrIndices(int64_t* inner_data, size_t inner_num, int64_t* outer_data, size_t outer_num);

  /// <summary>
  /// Supplies BlockSparse format specific indices and marks the contained sparse tensor as being a BlockSparse format tensor.
  /// Values are supplied with a CreateSparseTensor() API. The supplied indices are not copied and the user
  /// allocated buffers lifespan must eclipse that of the OrtValue.
  /// The location of the indices is assumed to be the same as specified by OrtMemoryInfo argument at the creation time.
  /// </summary>
  /// <param name="indices_shape">indices shape or a {0} for fully sparse</param>
  /// <param name="indices_data">user allocated buffer with indices or nullptr for fully sparse tensors</param>
  void UseBlockSparseIndices(const Shape& indices_shape, int32_t* indices_data);

  /// <summary>
  /// The API will allocate memory using the allocator instance supplied to the CreateSparseTensor() API
  /// and copy the values and COO indices into it. If data_mem_info specifies that the data is located
  /// on a different device than the allocator, a X-device copy will be performed if possible.
  /// </summary>
  /// <param name="data_mem_info">specified buffer memory description</param>
  /// <param name="values_param">values buffer information.</param>
  /// <param name="indices_data">coo indices buffer or nullptr for fully sparse data</param>
  /// <param name="indices_num">number of COO indices or 0 for fully sparse data</param>
  void FillSparseTensorCoo(const OrtMemoryInfo* data_mem_info, const OrtSparseValuesParam& values_param,
                           const int64_t* indices_data, size_t indices_num);

  /// <summary>
  /// The API will allocate memory using the allocator instance supplied to the CreateSparseTensor() API
  /// and copy the values and CSR indices into it. If data_mem_info specifies that the data is located
  /// on a different device than the allocator, a X-device copy will be performed if possible.
  /// </summary>
  /// <param name="data_mem_info">specified buffer memory description</param>
  /// <param name="values">values buffer information</param>
  /// <param name="inner_indices_data">csr inner indices pointer or nullptr for fully sparse tensors</param>
  /// <param name="inner_indices_num">number of csr inner indices or 0 for fully sparse tensors</param>
  /// <param name="outer_indices_data">pointer to csr outer indices data or nullptr for fully sparse tensors</param>
  /// <param name="outer_indices_num">number of csr outer indices or 0</param>
  void FillSparseTensorCsr(const OrtMemoryInfo* data_mem_info,
                           const OrtSparseValuesParam& values,
                           const int64_t* inner_indices_data, size_t inner_indices_num,
                           const int64_t* outer_indices_data, size_t outer_indices_num);

  /// <summary>
  /// The API will allocate memory using the allocator instance supplied to the CreateSparseTensor() API
  /// and copy the values and BlockSparse indices into it. If data_mem_info specifies that the data is located
  /// on a different device than the allocator, a X-device copy will be performed if possible.
  /// </summary>
  /// <param name="data_mem_info">specified buffer memory description</param>
  /// <param name="values">values buffer information</param>
  /// <param name="indices_shape">indices shape. use {0} for fully sparse tensors</param>
  /// <param name="indices_data">pointer to indices data or nullptr for fully sparse tensors</param>
  void FillSparseTensorBlockSparse(const OrtMemoryInfo* data_mem_info,
                                   const OrtSparseValuesParam& values,
                                   const Shape& indices_shape,
                                   const int32_t* indices_data);

#endif
};
1653 
1654 }  // namespace detail
1655 
/// Non-owning, copyable view of an OrtValue exposing only the const interface.
using ConstValue = detail::ConstValueImpl<detail::Unowned<const OrtValue>>;
/// Non-owning view of an OrtValue that also exposes the mutating interface.
using UnownedValue = detail::ValueImpl<detail::Unowned<OrtValue>>;

/** \brief Wrapper around ::OrtValue
 *
 * Owning, move-only wrapper. GetConst()/GetUnowned() return non-owning views that must not
 * outlive this object.
 */
struct Value : detail::ValueImpl<OrtValue> {
  using Base = detail::ValueImpl<OrtValue>;
  using OrtSparseValuesParam = detail::OrtSparseValuesParam;
  using Shape = detail::Shape;

  explicit Value(std::nullptr_t) {}         ///< Create an empty Value object, must be assigned a valid one to be used
  explicit Value(OrtValue* p) : Base{p} {}  ///< Used for interop with the C API
  Value(Value&&) = default;
  Value& operator=(Value&&) = default;

  ConstValue GetConst() const { return ConstValue{this->p_}; }        ///< Non-owning const view of this value
  UnownedValue GetUnowned() const { return UnownedValue{this->p_}; }  ///< Non-owning mutable view of this value

  /** \brief Creates a tensor with a user supplied buffer. Wraps OrtApi::CreateTensorWithDataAsOrtValue.
   * NOTE(review): the buffer is presumably not copied, so it must outlive the returned Value -
   * confirm against OrtApi::CreateTensorWithDataAsOrtValue.
   * \tparam T The numeric datatype. This API is not suitable for strings.
   * \param info Memory description of where the p_data buffer resides (CPU vs GPU etc).
   * \param p_data Pointer to the data buffer.
   * \param p_data_element_count The number of elements in the data buffer.
   * \param shape Pointer to the tensor shape dimensions.
   * \param shape_len The number of tensor shape dimensions.
   */
  template <typename T>
  static Value CreateTensor(const OrtMemoryInfo* info, T* p_data, size_t p_data_element_count, const int64_t* shape, size_t shape_len);

  /** \brief Creates a tensor with a user supplied buffer. Wraps OrtApi::CreateTensorWithDataAsOrtValue.
   *
   * \param info Memory description of where the p_data buffer resides (CPU vs GPU etc).
   * \param p_data Pointer to the data buffer.
   * \param p_data_byte_count The number of bytes in the data buffer.
   * \param shape Pointer to the tensor shape dimensions.
   * \param shape_len The number of tensor shape dimensions.
   * \param type The data type.
   */
  static Value CreateTensor(const OrtMemoryInfo* info, void* p_data, size_t p_data_byte_count, const int64_t* shape, size_t shape_len,
                            ONNXTensorElementDataType type);

  /** \brief Creates an OrtValue with a tensor using a supplied OrtAllocator. Wraps OrtApi::CreateTensorAsOrtValue.
   *         This overload will allocate the buffer for the tensor according to the supplied shape and data type.
   *         The allocated buffer will be owned by the returned OrtValue and will be freed when the OrtValue is released.
   *         The input data would need to be copied into the allocated buffer.
   *         This API is not suitable for strings.
   *
   * \tparam T The numeric datatype. This API is not suitable for strings.
   * \param allocator The allocator to use.
   * \param shape Pointer to the tensor shape dimensions.
   * \param shape_len The number of tensor shape dimensions.
   */
  template <typename T>
  static Value CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len);

  /** \brief Creates an OrtValue with a tensor using the supplied OrtAllocator.
   *   Wraps OrtApi::CreateTensorAsOrtValue.
   *   The allocated buffer will be owned by the returned OrtValue and will be freed when the OrtValue is released.
   *   The input data would need to be copied into the allocated buffer.
   *   This API is not suitable for strings.
   *
   * \param allocator The allocator to use.
   * \param shape Pointer to the tensor shape dimensions.
   * \param shape_len The number of tensor shape dimensions.
   * \param type The data type.
   */
  static Value CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len, ONNXTensorElementDataType type);

  /** \brief Creates an OrtValue with a Map Onnx type representation.
   *  The API would ref-count the supplied OrtValues and they will be released
   *  when the returned OrtValue is released. The caller may release keys and values after the call
   *  returns.
   *
   * \param keys an OrtValue containing a tensor with primitive data type keys.
   * \param values an OrtValue that may contain a tensor. Ort currently supports only primitive data type values.
   */
  static Value CreateMap(const Value& keys, const Value& values);  ///< Wraps OrtApi::CreateValue

  /** \brief Creates an OrtValue with a Sequence Onnx type representation.
   *  The API would ref-count the supplied OrtValues and they will be released
   *  when the returned OrtValue is released. The caller may release the values after the call
   *  returns.
   *
   * \param values a vector of OrtValues that must have the same Onnx value type.
   */
  static Value CreateSequence(const std::vector<Value>& values);  ///< Wraps OrtApi::CreateValue

  /** \brief Creates an OrtValue wrapping an Opaque type.
   *  This is used for experimental support of non-tensor types.
   *  Read it back with GetOpaqueData() using the same domain and type_name.
   *
   * \tparam T - the type of the value.
   * \param domain - zero terminated utf-8 string. Domain of the type.
   * \param type_name - zero terminated utf-8 string. Name of the type.
   * \param value - the value to be wrapped.
   */
  template <typename T>
  static Value CreateOpaque(const char* domain, const char* type_name, const T& value);  ///< Wraps OrtApi::CreateOpaqueValue

#if !defined(DISABLE_SPARSE_TENSORS)
  /// <summary>
  /// This is a simple forwarding method to the other overload that helps deducing
  /// data type enum value from the type of the buffer.
  /// </summary>
  /// <typeparam name="T">numeric datatype. This API is not suitable for strings.</typeparam>
  /// <param name="info">Memory description where the user buffers reside (CPU vs GPU etc)</param>
  /// <param name="p_data">pointer to the user supplied buffer, use nullptr for fully sparse tensors</param>
  /// <param name="dense_shape">the dense shape the tensor would have</param>
  /// <param name="values_shape">non zero values shape. Use a single 0 shape for fully sparse tensors.</param>
  /// <returns>Ort::Value instance containing SparseTensor</returns>
  template <typename T>
  static Value CreateSparseTensor(const OrtMemoryInfo* info, T* p_data, const Shape& dense_shape,
                                  const Shape& values_shape);

  /// <summary>
  /// Creates an OrtValue instance containing SparseTensor. This constructs
  /// a sparse tensor that makes use of user allocated buffers. It does not make copies
  /// of the user provided data and does not modify it. The lifespan of user provided buffers should
  /// eclipse the life span of the resulting OrtValue. This call constructs an instance that only contain
  /// a pointer to non-zero values. To fully populate the sparse tensor call Use<Format>Indices() API below
  /// to supply a sparse format specific indices.
  /// This API is not suitable for string data. Use CreateSparseTensor() with allocator specified so strings
  /// can be properly copied into the allocated buffer.
  /// </summary>
  /// <param name="info">Memory description where the user buffers reside (CPU vs GPU etc)</param>
  /// <param name="p_data">pointer to the user supplied buffer, use nullptr for fully sparse tensors</param>
  /// <param name="dense_shape">the dense shape the tensor would have</param>
  /// <param name="values_shape">non zero values shape. Use a single 0 shape for fully sparse tensors.</param>
  /// <param name="type">data type</param>
  /// <returns>Ort::Value instance containing SparseTensor</returns>
  static Value CreateSparseTensor(const OrtMemoryInfo* info, void* p_data, const Shape& dense_shape,
                                  const Shape& values_shape, ONNXTensorElementDataType type);

  /// <summary>
  /// This is a simple forwarding method to the below CreateSparseTensor.
  /// This helps to specify data type enum in terms of C++ data type.
  /// Use CreateSparseTensor<T>
  /// </summary>
  /// <typeparam name="T">numeric data type only. String data enum must be specified explicitly.</typeparam>
  /// <param name="allocator">allocator to use</param>
  /// <param name="dense_shape">the dense shape the tensor would have</param>
  /// <returns>Ort::Value</returns>
  template <typename T>
  static Value CreateSparseTensor(OrtAllocator* allocator, const Shape& dense_shape);

  /// <summary>
  /// Creates an instance of OrtValue containing sparse tensor. The created instance has no data.
  /// The data must be supplied by one of the FillSparseTensor<Format>() methods that take both non-zero values
  /// and indices. The data will be copied into a buffer that would be allocated using the supplied allocator.
  /// Use this API to create OrtValues that contain sparse tensors with all supported data types including
  /// strings.
  /// </summary>
  /// <param name="allocator">allocator to use. The allocator lifespan must eclipse that of the resulting OrtValue</param>
  /// <param name="dense_shape">the dense shape the tensor would have</param>
  /// <param name="type">data type</param>
  /// <returns>an instance of Ort::Value</returns>
  static Value CreateSparseTensor(OrtAllocator* allocator, const Shape& dense_shape, ONNXTensorElementDataType type);

#endif  // !defined(DISABLE_SPARSE_TENSORS)
};
1816 
1817 /// <summary>
1818 /// Represents native memory allocation coming from one of the
1819 /// OrtAllocators registered with OnnxRuntime.
1820 /// Use it to wrap an allocation made by an allocator
1821 /// so it can be automatically released when no longer needed.
1822 /// </summary>
1823 struct MemoryAllocation {
1824   MemoryAllocation(OrtAllocator* allocator, void* p, size_t size);
1825   ~MemoryAllocation();
1826   MemoryAllocation(const MemoryAllocation&) = delete;
1827   MemoryAllocation& operator=(const MemoryAllocation&) = delete;
1828   MemoryAllocation(MemoryAllocation&&) noexcept;
1829   MemoryAllocation& operator=(MemoryAllocation&&) noexcept;
1830 
1831   void* get() { return p_; }
1832   size_t size() const { return size_; }
1833 
1834  private:
1835   OrtAllocator* allocator_;
1836   void* p_;
1837   size_t size_;
1838 };
1839 
namespace detail {
/// Shared implementation behind Allocator, AllocatorWithDefaultOptions and UnownedAllocator.
template <typename T>
struct AllocatorImpl : Base<T> {
  using B = Base<T>;
  using B::B;

  void* Alloc(size_t size);                     ///< Requests size bytes from the wrapped allocator
  MemoryAllocation GetAllocation(size_t size);  ///< Like Alloc(), but returns a MemoryAllocation that releases the memory automatically
  void Free(void* p);                           ///< Returns p to the wrapped allocator (presumably memory obtained from Alloc())
  ConstMemoryInfo GetInfo() const;              ///< Non-owning memory info describing this allocator
};

}  // namespace detail
1853 
/** \brief Wrapper around ::OrtAllocator default instance that is owned by Onnxruntime
 *
 * Built on an Unowned handle, so this wrapper never releases the allocator itself.
 */
struct AllocatorWithDefaultOptions : detail::AllocatorImpl<detail::Unowned<OrtAllocator>> {
  explicit AllocatorWithDefaultOptions(std::nullptr_t) {}  ///< Convenience to create a class member and then replace with an instance
  AllocatorWithDefaultOptions();                           ///< Binds to the default allocator instance
};
1861 
/** \brief Wrapper around ::OrtAllocator
 *
 * Owning wrapper (unlike AllocatorWithDefaultOptions / UnownedAllocator).
 */
struct Allocator : detail::AllocatorImpl<OrtAllocator> {
  explicit Allocator(std::nullptr_t) {}  ///< Convenience to create a class member and then replace with an instance
  Allocator(const Session& session, const OrtMemoryInfo*);  ///< Allocator for the given session and memory info
};

/// Non-owning allocator handle; the underlying OrtAllocator must outlive it.
using UnownedAllocator = detail::AllocatorImpl<detail::Unowned<OrtAllocator>>;
1871 
namespace detail {
namespace binding_utils {
// Bring these out of template: non-template helpers shared by every ConstIoBindingImpl<T>
// instantiation to read the bound output names / values using the given allocator.
std::vector<std::string> GetOutputNamesHelper(const OrtIoBinding* binding, OrtAllocator*);
std::vector<Value> GetOutputValuesHelper(const OrtIoBinding* binding, OrtAllocator*);
}  // namespace binding_utils

/// Read-only part of the IoBinding interface, shared by IoBinding, UnownedIoBinding and ConstIoBinding.
/// NOTE(review): the overloads without an allocator presumably use the default allocator - confirm.
template <typename T>
struct ConstIoBindingImpl : Base<T> {
  using B = Base<T>;
  using B::B;

  std::vector<std::string> GetOutputNames() const;
  std::vector<std::string> GetOutputNames(OrtAllocator*) const;
  std::vector<Value> GetOutputValues() const;
  std::vector<Value> GetOutputValues(OrtAllocator*) const;
};

/// Mutating part of the IoBinding interface, shared by IoBinding and UnownedIoBinding.
template <typename T>
struct IoBindingImpl : ConstIoBindingImpl<T> {
  using B = ConstIoBindingImpl<T>;
  using B::B;

  void BindInput(const char* name, const Value&);             ///< Binds the given Value to the input called name
  void BindOutput(const char* name, const Value&);            ///< Binds the given Value to the output called name
  void BindOutput(const char* name, const OrtMemoryInfo*);    ///< Binds the output called name to the device/memory described by the OrtMemoryInfo
  void ClearBoundInputs();
  void ClearBoundOutputs();
  void SynchronizeInputs();
  void SynchronizeOutputs();
};

}  // namespace detail
1905 
/// Non-owning, read-only view of an ::OrtIoBinding.
using ConstIoBinding = detail::ConstIoBindingImpl<detail::Unowned<const OrtIoBinding>>;
/// Non-owning view of an ::OrtIoBinding that also allows binding/clearing operations.
using UnownedIoBinding = detail::IoBindingImpl<detail::Unowned<OrtIoBinding>>;
1908 
/** \brief Wrapper around ::OrtIoBinding
 *
 * Owning wrapper. Use GetConst()/GetUnowned() to pass non-owning views; those views must not
 * outlive this object.
 */
struct IoBinding : detail::IoBindingImpl<OrtIoBinding> {
  explicit IoBinding(std::nullptr_t) {}  ///< Create an empty object for convenience. Sometimes, we want to initialize members later.
  explicit IoBinding(Session& session);
  ConstIoBinding GetConst() const { return ConstIoBinding{this->p_}; }
  // NOTE(review): callable on a const IoBinding yet returns a view with mutating methods.
  UnownedIoBinding GetUnowned() const { return UnownedIoBinding{this->p_}; }
};
1918 
/*! \struct Ort::ArenaCfg
 * \brief Configuration of an arena-based allocator.
 * \details Please see docs/C_API.md for details
 */
struct ArenaCfg : detail::Base<OrtArenaCfg> {
  explicit ArenaCfg(std::nullptr_t) {}  ///< Create an empty ArenaCfg object, must be assigned a valid one to be used
  /**
   * Wraps OrtApi::CreateArenaCfg
   * \param max_mem - use 0 to allow ORT to choose the default
   * \param arena_extend_strategy -  use -1 to allow ORT to choose the default, 0 = kNextPowerOfTwo, 1 = kSameAsRequested
   * \param initial_chunk_size_bytes - use -1 to allow ORT to choose the default
   * \param max_dead_bytes_per_chunk - use -1 to allow ORT to choose the default
   * See docs/C_API.md for details on what the parameters above mean and how to choose their values.
   */
  ArenaCfg(size_t max_mem, int arena_extend_strategy, int initial_chunk_size_bytes, int max_dead_bytes_per_chunk);
};
1935 
1936 //
1937 // Custom OPs (only needed to implement custom OPs)
1938 //
1939 
1940 /// <summary>
1941 /// This struct provides life time management for custom op attribute
1942 /// </summary>
1943 struct OpAttr : detail::Base<OrtOpAttr> {
1944   OpAttr(const char* name, const void* data, int len, OrtOpAttrType type);
1945 };
1946 
/**
 * Macro that logs a message using the provided logger. Throws an exception if OrtApi::Logger_LogMessage fails.
 * Example: ORT_CXX_LOG(logger, ORT_LOGGING_LEVEL_INFO, "Log a message");
 *
 * \param logger The Ort::Logger instance to use. Must be a value or reference. Because the argument is
 *               parenthesized, expressions such as `*logger_ptr` work too.
 * \param message_severity The logging severity level of the message.
 * \param message A null-terminated UTF-8 message to log.
 *
 * Note: `logger` and `message_severity` are each evaluated twice; do not pass expressions with side effects.
 */
#define ORT_CXX_LOG(logger, message_severity, message)                                          \
  do {                                                                                          \
    if ((message_severity) >= (logger).GetLoggingSeverityLevel()) {                             \
      Ort::ThrowOnError((logger).LogMessage((message_severity), ORT_FILE, __LINE__,             \
                                            static_cast<const char*>(__FUNCTION__), (message))); \
    }                                                                                           \
  } while (false)
1962 
/**
 * Macro that logs a message using the provided logger. Can be used in noexcept code since errors are silently ignored.
 * Example: ORT_CXX_LOG_NOEXCEPT(logger, ORT_LOGGING_LEVEL_INFO, "Log a message");
 *
 * \param logger The Ort::Logger instance to use. Must be a value or reference. Because the argument is
 *               parenthesized, expressions such as `*logger_ptr` work too.
 * \param message_severity The logging severity level of the message.
 * \param message A null-terminated UTF-8 message to log.
 *
 * Note: `logger` and `message_severity` are each evaluated twice; do not pass expressions with side effects.
 */
#define ORT_CXX_LOG_NOEXCEPT(logger, message_severity, message)                                 \
  do {                                                                                          \
    if ((message_severity) >= (logger).GetLoggingSeverityLevel()) {                             \
      static_cast<void>((logger).LogMessage((message_severity), ORT_FILE, __LINE__,             \
                                            static_cast<const char*>(__FUNCTION__), (message))); \
    }                                                                                           \
  } while (false)
1978 
/**
 * Macro that logs a printf-like formatted message using the provided logger. Throws an exception if
 * OrtApi::Logger_LogMessage fails or if a formatting error occurs.
 * Example: ORT_CXX_LOGF(logger, ORT_LOGGING_LEVEL_INFO, "Log an int: %d", 12);
 *
 * \param logger The Ort::Logger instance to use. Must be a value or reference. Because the argument is
 *               parenthesized, expressions such as `*logger_ptr` work too.
 * \param message_severity The logging severity level of the message.
 * \param format A null-terminated UTF-8 format string forwarded to a printf-like function.
 *               Refer to https://en.cppreference.com/w/cpp/io/c/fprintf for information on valid formats.
 * \param ... Zero or more variadic arguments referenced by the format string.
 *
 * Note: `logger` and `message_severity` are each evaluated twice; do not pass expressions with side effects.
 */
#define ORT_CXX_LOGF(logger, message_severity, /*format,*/...)                                              \
  do {                                                                                                      \
    if ((message_severity) >= (logger).GetLoggingSeverityLevel()) {                                         \
      Ort::ThrowOnError((logger).LogFormattedMessage((message_severity), ORT_FILE, __LINE__,                \
                                                     static_cast<const char*>(__FUNCTION__), __VA_ARGS__)); \
    }                                                                                                       \
  } while (false)
1997 
/**
 * Macro that logs a printf-like formatted message using the provided logger. Can be used in noexcept code since errors
 * are silently ignored.
 * Example: ORT_CXX_LOGF_NOEXCEPT(logger, ORT_LOGGING_LEVEL_INFO, "Log an int: %d", 12);
 *
 * \param logger The Ort::Logger instance to use. Must be a value or reference. Because the argument is
 *               parenthesized, expressions such as `*logger_ptr` work too.
 * \param message_severity The logging severity level of the message.
 * \param format A null-terminated UTF-8 format string forwarded to a printf-like function.
 *               Refer to https://en.cppreference.com/w/cpp/io/c/fprintf for information on valid formats.
 * \param ... Zero or more variadic arguments referenced by the format string.
 *
 * Note: `logger` and `message_severity` are each evaluated twice; do not pass expressions with side effects.
 */
#define ORT_CXX_LOGF_NOEXCEPT(logger, message_severity, /*format,*/...)                                     \
  do {                                                                                                      \
    if ((message_severity) >= (logger).GetLoggingSeverityLevel()) {                                         \
      static_cast<void>((logger).LogFormattedMessage((message_severity), ORT_FILE, __LINE__,                \
                                                     static_cast<const char*>(__FUNCTION__), __VA_ARGS__)); \
    }                                                                                                       \
  } while (false)
2016 
/// <summary>
/// This class represents an ONNX Runtime logger that can be used to log information with an
/// associated severity level and source code location (file path, line number, function name).
///
/// A Logger can be obtained from within custom operators by calling Ort::KernelInfo::GetLogger().
/// An Ort::Logger holds only a non-owning ::OrtLogger pointer and a cached severity level, so it is
/// cheap to copy and can be passed by value.
///
/// Use the ORT_CXX_LOG macros to ensure the source code location is set properly from the callsite
/// and to take advantage of a cached logging severity level that can bypass calls to the underlying C API.
/// </summary>
struct Logger {
  /**
   * Creates an empty Ort::Logger. Must be initialized from a valid Ort::Logger before use.
   */
  Logger() = default;

  /**
   * Creates an empty Ort::Logger. Must be initialized from a valid Ort::Logger before use.
   */
  explicit Logger(std::nullptr_t) {}

  /**
   * Creates a logger from an ::OrtLogger instance. Caches the logger's current severity level by calling
   * OrtApi::Logger_GetLoggingSeverityLevel. Throws an exception if OrtApi::Logger_GetLoggingSeverityLevel fails.
   *
   * \param logger The ::OrtLogger to wrap. It is not owned; it must outlive this object.
   */
  explicit Logger(const OrtLogger* logger);

  // Defaulted destructor: the wrapped ::OrtLogger is never released by this class.
  ~Logger() = default;

  Logger(const Logger&) = default;
  Logger& operator=(const Logger&) = default;

  Logger(Logger&& v) noexcept = default;
  Logger& operator=(Logger&& v) noexcept = default;

  /**
   * Returns the logger's current severity level from the cached member.
   * The value is captured at construction, so later changes to the underlying logger are not reflected.
   *
   * \return The current ::OrtLoggingLevel.
   */
  OrtLoggingLevel GetLoggingSeverityLevel() const noexcept;

  /**
   * Logs the provided message via OrtApi::Logger_LogMessage. Use the ORT_CXX_LOG or ORT_CXX_LOG_NOEXCEPT
   * macros to properly set the source code location and to use the cached severity level to potentially bypass
   * calls to the underlying C API.
   *
   * \param log_severity_level The message's logging severity level.
   * \param file_path The filepath of the file in which the message is logged. Usually the value of ORT_FILE.
   * \param line_number The file line number in which the message is logged. Usually the value of __LINE__.
   * \param func_name The name of the function in which the message is logged. Usually the value of __FUNCTION__.
   * \param message The message to log.
   * \return A Ort::Status value to indicate error or success.
   */
  Status LogMessage(OrtLoggingLevel log_severity_level, const ORTCHAR_T* file_path, int line_number,
                    const char* func_name, const char* message) const noexcept;

  /**
   * Logs a printf-like formatted message via OrtApi::Logger_LogMessage. Use the ORT_CXX_LOGF or ORT_CXX_LOGF_NOEXCEPT
   * macros to properly set the source code location and to use the cached severity level to potentially bypass
   * calls to the underlying C API. Returns an error status if a formatting error occurs.
   *
   * \param log_severity_level The message's logging severity level.
   * \param file_path The filepath of the file in which the message is logged. Usually the value of ORT_FILE.
   * \param line_number The file line number in which the message is logged. Usually the value of __LINE__.
   * \param func_name The name of the function in which the message is logged. Usually the value of __FUNCTION__.
   * \param format A null-terminated UTF-8 format string forwarded to a printf-like function.
   *               Refer to https://en.cppreference.com/w/cpp/io/c/fprintf for information on valid formats.
   * \param args Zero or more variadic arguments referenced by the format string.
   * \return A Ort::Status value to indicate error or success.
   */
  template <typename... Args>
  Status LogFormattedMessage(OrtLoggingLevel log_severity_level, const ORTCHAR_T* file_path, int line_number,
                             const char* func_name, const char* format, Args&&... args) const noexcept;

 private:
  const OrtLogger* logger_{};                  // non-owning
  OrtLoggingLevel cached_severity_level_{};    // snapshot taken by the OrtLogger* constructor
};
2098 
/// <summary>
/// This class wraps a raw pointer OrtKernelContext* that is being passed
/// to the custom kernel Compute() method. Use it to safely access context
/// attributes, input and output parameters with exception safety guarantees.
/// The wrapped pointer is not owned: this class declares no destructor, so it never releases it.
/// See usage example in onnxruntime/test/testdata/custom_op_library/custom_op_library.cc
/// </summary>
struct KernelContext {
  explicit KernelContext(OrtKernelContext* context);
  size_t GetInputCount() const;
  size_t GetOutputCount() const;
  // If the input is optional and not present, the method returns an empty ConstValue
  // which can be compared to nullptr.
  ConstValue GetInput(size_t index) const;
  // If the output is optional and not present, the method returns an empty UnownedValue
  // which can be compared to nullptr.
  UnownedValue GetOutput(size_t index, const int64_t* dim_values, size_t dim_count) const;
  // Convenience overload taking the output shape as a vector.
  UnownedValue GetOutput(size_t index, const std::vector<int64_t>& dims) const;
  // NOTE(review): returns an opaque, provider-specific stream handle — confirm type per execution provider.
  void* GetGPUComputeStream() const;
  Logger GetLogger() const;
  OrtAllocator* GetAllocator(const OrtMemoryInfo& memory_info) const;
  // Escape hatch to the raw C context for APIs not wrapped here.
  OrtKernelContext* GetOrtKernelContext() const { return ctx_; }
  // NOTE(review): presumably invokes fn(usr_data, i) across `total` iterations split into `num_batch`
  // batches — confirm against the C API documentation.
  void ParallelFor(void (*fn)(void*, size_t), size_t total, size_t num_batch, void* usr_data) const;

 private:
  OrtKernelContext* ctx_;
};
2125 
// Forward declaration: KernelInfoImpl::Copy() below returns a KernelInfo, which is defined after it.
struct KernelInfo;
2127 
namespace detail {
namespace attr_utils {
// Non-template overloads used by KernelInfoImpl::GetAttribute / GetAttributes to read an attribute
// named `name` into the output parameter.
void GetAttr(const OrtKernelInfo* p, const char* name, float&);
void GetAttr(const OrtKernelInfo* p, const char* name, int64_t&);
void GetAttr(const OrtKernelInfo* p, const char* name, std::string&);
void GetAttrs(const OrtKernelInfo* p, const char* name, std::vector<float>&);
void GetAttrs(const OrtKernelInfo* p, const char* name, std::vector<int64_t>&);
}  // namespace attr_utils

/// Query interface shared by KernelInfo (owning) and ConstKernelInfo (non-owning).
template <typename T>
struct KernelInfoImpl : Base<T> {
  using B = Base<T>;
  using B::B;

  /// Returns an owning copy of the wrapped kernel info.
  KernelInfo Copy() const;

  template <typename R>  // R is only implemented for float, int64_t, and std::string
  R GetAttribute(const char* name) const {
    R val;
    attr_utils::GetAttr(this->p_, name, val);
    return val;
  }

  // R is the element type: only float and int64_t are implemented
  // (i.e. the result is std::vector<float> or std::vector<int64_t>).
  template <typename R>
  std::vector<R> GetAttributes(const char* name) const {
    std::vector<R> result;
    attr_utils::GetAttrs(this->p_, name, result);
    return result;
  }

  Value GetTensorAttribute(const char* name, OrtAllocator* allocator) const;

  size_t GetInputCount() const;
  size_t GetOutputCount() const;

  std::string GetInputName(size_t index) const;
  std::string GetOutputName(size_t index) const;

  TypeInfo GetInputTypeInfo(size_t index) const;
  TypeInfo GetOutputTypeInfo(size_t index) const;

  // NOTE(review): `is_constant` is an out-parameter, presumably set non-zero when input `index` is a constant
  // initializer — confirm.
  ConstValue GetTensorConstantInput(size_t index, int* is_constant) const;

  std::string GetNodeName() const;
  Logger GetLogger() const;
};

}  // namespace detail
2176 
/// Non-owning, read-only view of an ::OrtKernelInfo (e.g. the pointer passed to a kernel constructor).
using ConstKernelInfo = detail::KernelInfoImpl<detail::Unowned<const OrtKernelInfo>>;
2178 
/// <summary>
/// This struct owns the OrtKernelInfo* pointer it wraps (for example, the result of Copy()).
/// To conveniently query attributes from an OrtKernelInfo* passed to a kernel constructor,
/// wrap the pointer in a non-owning ConstKernelInfo instead, so the pointer the kernel does
/// not own is not destroyed.
/// </summary>
struct KernelInfo : detail::KernelInfoImpl<OrtKernelInfo> {
  explicit KernelInfo(std::nullptr_t) {}     ///< Create an empty instance to initialize later
  explicit KernelInfo(OrtKernelInfo* info);  ///< Take ownership of the instance
  ConstKernelInfo GetConst() const { return ConstKernelInfo{this->p_}; }  ///< Non-owning view; must not outlive *this
};
2190 
/// <summary>
/// Create and own custom defined operation.
/// </summary>
struct Op : detail::Base<OrtOp> {
  explicit Op(std::nullptr_t) {}  ///< Create an empty Operator object, must be assigned a valid one to be used

  explicit Op(OrtOp*);  ///< Take ownership of the OrtOp

  /// Creates an operator `op_name` in `domain` at opset `version`.
  /// `type_constraint_names`/`type_constraint_values` are parallel arrays of `type_constraint_count` entries;
  /// `attr_values` holds `attr_count` attributes; `input_count`/`output_count` give the operator's arity.
  static Op Create(const OrtKernelInfo* info, const char* op_name, const char* domain,
                   int version, const char** type_constraint_names,
                   const ONNXTensorElementDataType* type_constraint_values,
                   size_t type_constraint_count,
                   const OpAttr* attr_values,
                   size_t attr_count,
                   size_t input_count, size_t output_count);

  /// Runs the operator inside `context` with `input_count` inputs and `output_count` outputs.
  void Invoke(const OrtKernelContext* context,
              const Value* input_values,
              size_t input_count,
              Value* output_values,
              size_t output_count);

  // For easier refactoring
  // Same as above, but takes raw C OrtValue pointers.
  void Invoke(const OrtKernelContext* context,
              const OrtValue* const* input_values,
              size_t input_count,
              OrtValue* const* output_values,
              size_t output_count);
};
2220 
2221 /// <summary>
2222 /// Provide access to per-node attributes and input shapes, so one could compute and set output shapes.
2223 /// </summary>
2224 struct ShapeInferContext {
2225   struct SymbolicInteger {
2226     SymbolicInteger(int64_t i) : i_(i), is_int_(true) {};
2227     SymbolicInteger(const char* s) : s_(s), is_int_(false) {};
2228     SymbolicInteger(const SymbolicInteger&) = default;
2229     SymbolicInteger(SymbolicInteger&&) = default;
2230 
2231     SymbolicInteger& operator=(const SymbolicInteger&) = default;
2232     SymbolicInteger& operator=(SymbolicInteger&&) = default;
2233 
2234     bool operator==(const SymbolicInteger& dim) const {
2235       if (is_int_ == dim.is_int_) {
2236         if (is_int_) {
2237           return i_ == dim.i_;
2238         } else {
2239           return std::string{s_} == std::string{dim.s_};
2240         }
2241       }
2242       return false;
2243     }
2244 
2245     bool IsInt() const { return is_int_; }
2246     int64_t AsInt() const { return i_; }
2247     const char* AsSym() const { return s_; }
2248 
2249     static constexpr int INVALID_INT_DIM = -2;
2250 
2251    private:
2252     union {
2253       int64_t i_;
2254       const char* s_;
2255     };
2256     bool is_int_;
2257   };
2258 
2259   using Shape = std::vector<SymbolicInteger>;
2260 
2261   ShapeInferContext(const OrtApi* ort_api, OrtShapeInferContext* ctx);
2262 
2263   const Shape& GetInputShape(size_t indice) const { return input_shapes_.at(indice); }
2264 
2265   size_t GetInputCount() const { return input_shapes_.size(); }
2266 
2267   Status SetOutputShape(size_t indice, const Shape& shape, ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT);
2268 
2269   int64_t GetAttrInt(const char* attr_name);
2270 
2271   using Ints = std::vector<int64_t>;
2272   Ints GetAttrInts(const char* attr_name);
2273 
2274   float GetAttrFloat(const char* attr_name);
2275 
2276   using Floats = std::vector<float>;
2277   Floats GetAttrFloats(const char* attr_name);
2278 
2279   std::string GetAttrString(const char* attr_name);
2280 
2281   using Strings = std::vector<std::string>;
2282   Strings GetAttrStrings(const char* attr_name);
2283 
2284  private:
2285   const OrtOpAttr* GetAttrHdl(const char* attr_name) const;
2286   const OrtApi* ort_api_;
2287   OrtShapeInferContext* ctx_;
2288   std::vector<Shape> input_shapes_;
2289 };
2290 
/// Signature of a shape-inference callback: reads inputs/attributes from the context and sets output shapes.
using ShapeInferFn = Ort::Status (*)(Ort::ShapeInferContext&);
2292 
// Largest end version a custom op may declare (2^31 - 1). Fully parenthesized so the macro
// expands correctly inside larger expressions (e.g. `2 * MAX_CUSTOM_OP_END_VER`).
#define MAX_CUSTOM_OP_END_VER ((1UL << 31) - 1)
2294 
2295 template <typename TOp, typename TKernel, bool WithStatus = false>
2296 struct CustomOpBase : OrtCustomOp {
2297   CustomOpBase() {
2298     OrtCustomOp::version = ORT_API_VERSION;
2299     OrtCustomOp::GetName = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetName(); };
2300 
2301     OrtCustomOp::GetExecutionProviderType = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetExecutionProviderType(); };
2302 
2303     OrtCustomOp::GetInputTypeCount = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetInputTypeCount(); };
2304     OrtCustomOp::GetInputType = [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetInputType(index); };
2305     OrtCustomOp::GetInputMemoryType = [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetInputMemoryType(index); };
2306 
2307     OrtCustomOp::GetOutputTypeCount = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetOutputTypeCount(); };
2308     OrtCustomOp::GetOutputType = [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetOutputType(index); };
2309 
2310 #if defined(_MSC_VER) && !defined(__clang__)
2311 #pragma warning(push)
2312 #pragma warning(disable : 26409)
2313 #endif
2314     OrtCustomOp::KernelDestroy = [](void* op_kernel) { delete static_cast<TKernel*>(op_kernel); };
2315 #if defined(_MSC_VER) && !defined(__clang__)
2316 #pragma warning(pop)
2317 #endif
2318     OrtCustomOp::GetInputCharacteristic = [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetInputCharacteristic(index); };
2319     OrtCustomOp::GetOutputCharacteristic = [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetOutputCharacteristic(index); };
2320 
2321     OrtCustomOp::GetVariadicInputMinArity = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetVariadicInputMinArity(); };
2322     OrtCustomOp::GetVariadicInputHomogeneity = [](const OrtCustomOp* this_) { return static_cast<int>(static_cast<const TOp*>(this_)->GetVariadicInputHomogeneity()); };
2323     OrtCustomOp::GetVariadicOutputMinArity = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetVariadicOutputMinArity(); };
2324     OrtCustomOp::GetVariadicOutputHomogeneity = [](const OrtCustomOp* this_) { return static_cast<int>(static_cast<const TOp*>(this_)->GetVariadicOutputHomogeneity()); };
2325 #ifdef __cpp_if_constexpr
2326     if constexpr (WithStatus) {
2327 #else
2328     if (WithStatus) {
2329 #endif
2330       OrtCustomOp::CreateKernelV2 = [](const OrtCustomOp* this_, const OrtApi* api, const OrtKernelInfo* info, void** op_kernel) -> OrtStatusPtr {
2331         return static_cast<const TOp*>(this_)->CreateKernelV2(*api, info, op_kernel);
2332       };
2333       OrtCustomOp::KernelComputeV2 = [](void* op_kernel, OrtKernelContext* context) -> OrtStatusPtr {
2334         return static_cast<TKernel*>(op_kernel)->ComputeV2(context);
2335       };
2336     } else {
2337       OrtCustomOp::CreateKernelV2 = nullptr;
2338       OrtCustomOp::KernelComputeV2 = nullptr;
2339 
2340       OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* api, const OrtKernelInfo* info) { return static_cast<const TOp*>(this_)->CreateKernel(*api, info); };
2341       OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) {
2342         static_cast<TKernel*>(op_kernel)->Compute(context);
2343       };
2344     }
2345 
2346     SetShapeInferFn<TOp>(0);
2347 
2348     OrtCustomOp::GetStartVersion = [](const OrtCustomOp* this_) {
2349       return static_cast<const TOp*>(this_)->start_ver_;
2350     };
2351 
2352     OrtCustomOp::GetEndVersion = [](const OrtCustomOp* this_) {
2353       return static_cast<const TOp*>(this_)->end_ver_;
2354     };
2355 
2356     OrtCustomOp::GetMayInplace = nullptr;
2357     OrtCustomOp::ReleaseMayInplace = nullptr;
2358     OrtCustomOp::GetAliasMap = nullptr;
2359     OrtCustomOp::ReleaseAliasMap = nullptr;
2360   }
2361 
2362   // Default implementation of GetExecutionProviderType that returns nullptr to default to the CPU provider
2363   const char* GetExecutionProviderType() const { return nullptr; }
2364 
2365   // Default implementations of GetInputCharacteristic() and GetOutputCharacteristic() below
2366   // (inputs and outputs are required by default)
2367   OrtCustomOpInputOutputCharacteristic GetInputCharacteristic(size_t /*index*/) const {
2368     return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED;
2369   }
2370 
2371   OrtCustomOpInputOutputCharacteristic GetOutputCharacteristic(size_t /*index*/) const {
2372     return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED;
2373   }
2374 
2375   // Default implemention of GetInputMemoryType() that returns OrtMemTypeDefault
2376   OrtMemType GetInputMemoryType(size_t /*index*/) const {
2377     return OrtMemTypeDefault;
2378   }
2379 
2380   // Default implementation of GetVariadicInputMinArity() returns 1 to specify that a variadic input
2381   // should expect at least 1 argument.
2382   int GetVariadicInputMinArity() const {
2383     return 1;
2384   }
2385 
2386   // Default implementation of GetVariadicInputHomegeneity() returns true to specify that all arguments
2387   // to a variadic input should be of the same type.
2388   bool GetVariadicInputHomogeneity() const {
2389     return true;
2390   }
2391 
2392   // Default implementation of GetVariadicOutputMinArity() returns 1 to specify that a variadic output
2393   // should produce at least 1 output value.
2394   int GetVariadicOutputMinArity() const {
2395     return 1;
2396   }
2397 
2398   // Default implementation of GetVariadicOutputHomegeneity() returns true to specify that all output values
2399   // produced by a variadic output should be of the same type.
2400   bool GetVariadicOutputHomogeneity() const {
2401     return true;
2402   }
2403 
2404   // Declare list of session config entries used by this Custom Op.
2405   // Implement this function in order to get configs from CustomOpBase::GetSessionConfigs().
2406   // This default implementation returns an empty vector of config entries.
2407   std::vector<std::string> GetSessionConfigKeys() const {
2408     return std::vector<std::string>{};
2409   }
2410 
2411   template <typename C>
2412   decltype(&C::InferOutputShape) SetShapeInferFn(decltype(&C::InferOutputShape)) {
2413     OrtCustomOp::InferOutputShapeFn = [](const OrtCustomOp*, OrtShapeInferContext* ort_ctx) -> OrtStatusPtr {
2414       ShapeInferContext ctx(&GetApi(), ort_ctx);
2415       return C::InferOutputShape(ctx);
2416     };
2417     return {};
2418   }
2419 
2420   template <typename C>
2421   void SetShapeInferFn(...) {
2422     OrtCustomOp::InferOutputShapeFn = {};
2423   }
2424 
2425  protected:
2426   // Helper function that returns a map of session config entries specified by CustomOpBase::GetSessionConfigKeys.
2427   void GetSessionConfigs(std::unordered_map<std::string, std::string>& out, ConstSessionOptions options) const;
2428 
2429   int start_ver_ = 1;
2430   int end_ver_ = MAX_CUSTOM_OP_END_VER;
2431 };
2432 
2433 }  // namespace Ort
2434 
2435 #include "onnxruntime_cxx_inline.h"