Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2026-05-10 08:43:10

0001 //===- TFUtils.h - utilities for TFLite -------------------------*- C++ -*-===//
0002 //
0003 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
0004 // See https://llvm.org/LICENSE.txt for license information.
0005 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
0006 //
0007 //===----------------------------------------------------------------------===//
0008 //
0009 #ifndef LLVM_ANALYSIS_UTILS_TFUTILS_H
0010 #define LLVM_ANALYSIS_UTILS_TFUTILS_H
0011 
0012 #include "llvm/Config/llvm-config.h"
0013 
0014 #ifdef LLVM_HAVE_TFLITE
#include "llvm/ADT/StringMap.h"
#include "llvm/Analysis/TensorSpec.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/Support/JSON.h"

#include <memory>
#include <optional>
#include <vector>
0022 
0023 namespace llvm {
0024 
// Opaque implementation types; their definitions are not part of this header.
class TFModelEvaluatorImpl;
class EvaluationResultImpl;

/// Loads the model found at SavedModelPath, finds the given inputs and
/// outputs, and sets up storage for the input tensors. The user is responsible
/// for correctly dimensioning the input tensors and setting their values
/// before calling evaluate().
/// To initialize:
/// - construct the object with the input and output TensorSpecs, then check
///   isValid() to confirm the model was loaded successfully.
/// To use:
/// - set input values by calling getInput<T>(Index) to get each input tensor
///   and writing its scalars, for all dimensions (tensors are row-major:
///   https://github.com/tensorflow/tensorflow/blob/r1.5/tensorflow/c/c_api.h#L205).
///   Indices are expected to correspond to positions in the InputSpecs passed
///   at construction.
/// - call evaluate(). The input tensors' values are not consumed by it, and
///   may still be read afterwards. If evaluate() returns std::nullopt, this
///   object is invalidated.
/// - read the outputs from the returned EvaluationResult via getTensorValue,
///   before that EvaluationResult is destroyed (it owns the output tensors).
0041 
0042 class TFModelEvaluator final {
0043 public:
0044   /// The result of a model evaluation. Handles the lifetime of the output
0045   /// tensors, which means that their values need to be used before
0046   /// the EvaluationResult's dtor is called.
0047   class EvaluationResult {
0048   public:
0049     EvaluationResult(const EvaluationResult &) = delete;
0050     EvaluationResult &operator=(const EvaluationResult &Other) = delete;
0051 
0052     EvaluationResult(EvaluationResult &&Other);
0053     EvaluationResult &operator=(EvaluationResult &&Other);
0054 
0055     ~EvaluationResult();
0056 
0057     /// Get a (const) pointer to the first element of the tensor at Index.
0058     template <typename T> T *getTensorValue(size_t Index) {
0059       return static_cast<T *>(getUntypedTensorValue(Index));
0060     }
0061 
0062     template <typename T> const T *getTensorValue(size_t Index) const {
0063       return static_cast<T *>(getUntypedTensorValue(Index));
0064     }
0065 
0066     /// Get a (const) pointer to the untyped data of the tensor.
0067     void *getUntypedTensorValue(size_t Index);
0068     const void *getUntypedTensorValue(size_t Index) const;
0069 
0070   private:
0071     friend class TFModelEvaluator;
0072     EvaluationResult(std::unique_ptr<EvaluationResultImpl> Impl);
0073     std::unique_ptr<EvaluationResultImpl> Impl;
0074   };
0075 
0076   TFModelEvaluator(StringRef SavedModelPath,
0077                    const std::vector<TensorSpec> &InputSpecs,
0078                    const std::vector<TensorSpec> &OutputSpecs,
0079                    const char *Tags = "serve");
0080 
0081   ~TFModelEvaluator();
0082   TFModelEvaluator(const TFModelEvaluator &) = delete;
0083   TFModelEvaluator(TFModelEvaluator &&) = delete;
0084 
0085   /// Evaluate the model, assuming it is valid. Returns std::nullopt if the
0086   /// evaluation fails or the model is invalid, or an EvaluationResult
0087   /// otherwise. The inputs are assumed to have been already provided via
0088   /// getInput(). When returning std::nullopt, it also invalidates this object.
0089   std::optional<EvaluationResult> evaluate();
0090 
0091   /// Provides access to the input vector.
0092   template <typename T> T *getInput(size_t Index) {
0093     return static_cast<T *>(getUntypedInput(Index));
0094   }
0095 
0096   /// Returns true if the model was loaded successfully, false
0097   /// otherwise.
0098   bool isValid() const { return !!Impl; }
0099 
0100   /// Untyped access to input.
0101   void *getUntypedInput(size_t Index);
0102 
0103 private:
0104   std::unique_ptr<TFModelEvaluatorImpl> Impl;
0105 };
0106 
0107 } // namespace llvm
0108 
0109 #endif // LLVM_HAVE_TFLITE
0110 #endif // LLVM_ANALYSIS_UTILS_TFUTILS_H