Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2026-05-10 08:43:15

0001 //===- MLModelRunner.h ---- ML model runner interface -----------*- C++ -*-===//
0002 //
0003 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
0004 // See https://llvm.org/LICENSE.txt for license information.
0005 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
0006 //
0007 //===----------------------------------------------------------------------===//
0008 //
0009 
0010 #ifndef LLVM_ANALYSIS_MLMODELRUNNER_H
0011 #define LLVM_ANALYSIS_MLMODELRUNNER_H
0012 
0013 #include "llvm/Analysis/TensorSpec.h"
0014 #include "llvm/IR/PassManager.h"
0015 
0016 namespace llvm {
0017 class LLVMContext;
0018 
0019 /// MLModelRunner interface: abstraction of a mechanism for evaluating a
0020 /// ML model. More abstractly, evaluating a function that has as tensors as
0021 /// arguments, described via TensorSpecs, and returns a tensor. Currently, the
0022 /// latter is assumed to be a scalar, in absence of more elaborate scenarios.
0023 /// NOTE: feature indices are expected to be consistent all accross
0024 /// MLModelRunners (pertaining to the same model), and also Loggers (see
0025 /// TFUtils.h)
0026 class MLModelRunner {
0027 public:
0028   // Disallows copy and assign.
0029   MLModelRunner(const MLModelRunner &) = delete;
0030   MLModelRunner &operator=(const MLModelRunner &) = delete;
0031   virtual ~MLModelRunner() = default;
0032 
0033   template <typename T> T evaluate() {
0034     return *reinterpret_cast<T *>(evaluateUntyped());
0035   }
0036 
0037   template <typename T, typename I> T *getTensor(I FeatureID) {
0038     return reinterpret_cast<T *>(
0039         getTensorUntyped(static_cast<size_t>(FeatureID)));
0040   }
0041 
0042   template <typename T, typename I> const T *getTensor(I FeatureID) const {
0043     return reinterpret_cast<const T *>(
0044         getTensorUntyped(static_cast<size_t>(FeatureID)));
0045   }
0046 
0047   void *getTensorUntyped(size_t Index) { return InputBuffers[Index]; }
0048   const void *getTensorUntyped(size_t Index) const {
0049     return (const_cast<MLModelRunner *>(this))->getTensorUntyped(Index);
0050   }
0051 
0052   enum class Kind : int { Unknown, Release, Development, NoOp, Interactive };
0053   Kind getKind() const { return Type; }
0054   virtual void switchContext(StringRef Name) {}
0055 
0056 protected:
0057   MLModelRunner(LLVMContext &Ctx, Kind Type, size_t NumInputs)
0058       : Ctx(Ctx), Type(Type), InputBuffers(NumInputs) {
0059     assert(Type != Kind::Unknown);
0060   }
0061   virtual void *evaluateUntyped() = 0;
0062 
0063   void setUpBufferForTensor(size_t Index, const TensorSpec &Spec,
0064                             void *Buffer) {
0065     if (!Buffer) {
0066       OwnedBuffers.emplace_back(Spec.getTotalTensorBufferSize());
0067       Buffer = OwnedBuffers.back().data();
0068     }
0069     InputBuffers[Index] = Buffer;
0070   }
0071 
0072   LLVMContext &Ctx;
0073   const Kind Type;
0074 
0075 private:
0076   std::vector<void *> InputBuffers;
0077   std::vector<std::vector<char *>> OwnedBuffers;
0078 };
0079 } // namespace llvm
0080 
0081 #endif // LLVM_ANALYSIS_MLMODELRUNNER_H