File indexing completed on 2026-05-10 08:43:15
0001
0002
0003
0004
0005
0006
0007
0008
0009
0010 #ifndef LLVM_ANALYSIS_MLMODELRUNNER_H
0011 #define LLVM_ANALYSIS_MLMODELRUNNER_H
0012
0013 #include "llvm/Analysis/TensorSpec.h"
0014 #include "llvm/IR/PassManager.h"
0015
0016 namespace llvm {
0017 class LLVMContext;
0018
0019
0020
0021
0022
0023
0024
0025
0026 class MLModelRunner {
0027 public:
0028
0029 MLModelRunner(const MLModelRunner &) = delete;
0030 MLModelRunner &operator=(const MLModelRunner &) = delete;
0031 virtual ~MLModelRunner() = default;
0032
0033 template <typename T> T evaluate() {
0034 return *reinterpret_cast<T *>(evaluateUntyped());
0035 }
0036
0037 template <typename T, typename I> T *getTensor(I FeatureID) {
0038 return reinterpret_cast<T *>(
0039 getTensorUntyped(static_cast<size_t>(FeatureID)));
0040 }
0041
0042 template <typename T, typename I> const T *getTensor(I FeatureID) const {
0043 return reinterpret_cast<const T *>(
0044 getTensorUntyped(static_cast<size_t>(FeatureID)));
0045 }
0046
0047 void *getTensorUntyped(size_t Index) { return InputBuffers[Index]; }
0048 const void *getTensorUntyped(size_t Index) const {
0049 return (const_cast<MLModelRunner *>(this))->getTensorUntyped(Index);
0050 }
0051
0052 enum class Kind : int { Unknown, Release, Development, NoOp, Interactive };
0053 Kind getKind() const { return Type; }
0054 virtual void switchContext(StringRef Name) {}
0055
0056 protected:
0057 MLModelRunner(LLVMContext &Ctx, Kind Type, size_t NumInputs)
0058 : Ctx(Ctx), Type(Type), InputBuffers(NumInputs) {
0059 assert(Type != Kind::Unknown);
0060 }
0061 virtual void *evaluateUntyped() = 0;
0062
0063 void setUpBufferForTensor(size_t Index, const TensorSpec &Spec,
0064 void *Buffer) {
0065 if (!Buffer) {
0066 OwnedBuffers.emplace_back(Spec.getTotalTensorBufferSize());
0067 Buffer = OwnedBuffers.back().data();
0068 }
0069 InputBuffers[Index] = Buffer;
0070 }
0071
0072 LLVMContext &Ctx;
0073 const Kind Type;
0074
0075 private:
0076 std::vector<void *> InputBuffers;
0077 std::vector<std::vector<char *>> OwnedBuffers;
0078 };
0079 }
0080
0081 #endif