File indexing completed on 2026-05-10 08:43:10
0001
0002
0003
0004
0005
0006
0007
0008
0009 #ifndef LLVM_ANALYSIS_UTILS_TFUTILS_H
0010 #define LLVM_ANALYSIS_UTILS_TFUTILS_H
0011
0012 #include "llvm/Config/llvm-config.h"
0013
0014 #ifdef LLVM_HAVE_TFLITE
0015 #include "llvm/ADT/StringMap.h"
0016 #include "llvm/Analysis/TensorSpec.h"
0017 #include "llvm/IR/LLVMContext.h"
0018 #include "llvm/Support/JSON.h"
0019
#include <memory>
#include <optional>
#include <vector>
0022
0023 namespace llvm {
0024
0025
0026
0027
0028
0029
0030
0031
0032
0033
0034
0035
0036
0037
0038
0039 class TFModelEvaluatorImpl;
0040 class EvaluationResultImpl;
0041
0042 class TFModelEvaluator final {
0043 public:
0044
0045
0046
0047 class EvaluationResult {
0048 public:
0049 EvaluationResult(const EvaluationResult &) = delete;
0050 EvaluationResult &operator=(const EvaluationResult &Other) = delete;
0051
0052 EvaluationResult(EvaluationResult &&Other);
0053 EvaluationResult &operator=(EvaluationResult &&Other);
0054
0055 ~EvaluationResult();
0056
0057
0058 template <typename T> T *getTensorValue(size_t Index) {
0059 return static_cast<T *>(getUntypedTensorValue(Index));
0060 }
0061
0062 template <typename T> const T *getTensorValue(size_t Index) const {
0063 return static_cast<T *>(getUntypedTensorValue(Index));
0064 }
0065
0066
0067 void *getUntypedTensorValue(size_t Index);
0068 const void *getUntypedTensorValue(size_t Index) const;
0069
0070 private:
0071 friend class TFModelEvaluator;
0072 EvaluationResult(std::unique_ptr<EvaluationResultImpl> Impl);
0073 std::unique_ptr<EvaluationResultImpl> Impl;
0074 };
0075
0076 TFModelEvaluator(StringRef SavedModelPath,
0077 const std::vector<TensorSpec> &InputSpecs,
0078 const std::vector<TensorSpec> &OutputSpecs,
0079 const char *Tags = "serve");
0080
0081 ~TFModelEvaluator();
0082 TFModelEvaluator(const TFModelEvaluator &) = delete;
0083 TFModelEvaluator(TFModelEvaluator &&) = delete;
0084
0085
0086
0087
0088
0089 std::optional<EvaluationResult> evaluate();
0090
0091
0092 template <typename T> T *getInput(size_t Index) {
0093 return static_cast<T *>(getUntypedInput(Index));
0094 }
0095
0096
0097
0098 bool isValid() const { return !!Impl; }
0099
0100
0101 void *getUntypedInput(size_t Index);
0102
0103 private:
0104 std::unique_ptr<TFModelEvaluatorImpl> Impl;
0105 };
0106
0107 }
0108
0109 #endif
0110 #endif