Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2026-05-10 08:43:15

0001 //===- ModelUnderTrainingRunner.h -- 'development' mode runner --*- C++ -*-===//
0002 //
0003 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
0004 // See https://llvm.org/LICENSE.txt for license information.
0005 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
0006 //
0007 //===----------------------------------------------------------------------===//
0008 //
0009 
0010 #ifndef LLVM_ANALYSIS_MODELUNDERTRAININGRUNNER_H
0011 #define LLVM_ANALYSIS_MODELUNDERTRAININGRUNNER_H
0012 
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/iterator_range.h"
#include "llvm/Analysis/TensorSpec.h"
#include "llvm/Config/llvm-config.h"
#include <cassert>
0017 
0018 #ifdef LLVM_HAVE_TFLITE
0019 #include "llvm/Analysis/MLModelRunner.h"
0020 #include "llvm/Analysis/Utils/TFUtils.h"
0021 #include "llvm/IR/LLVMContext.h"
0022 #include "llvm/IR/PassManager.h"
0023 
0024 namespace llvm {
0025 
0026 /// ModelUnderTrainingRunner - training mode implementation. It uses TFLite
0027 /// to dynamically load and evaluate a TF SavedModel
0028 /// (https://www.tensorflow.org/guide/saved_model) converted to TFLite. see
0029 /// lib/Analysis/models/saved-model-to-tflite.py. Runtime performance is
0030 /// sacrificed for ease of use while training.
0031 class ModelUnderTrainingRunner final : public MLModelRunner {
0032 public:
0033   // Disallows copy and assign.
0034   ModelUnderTrainingRunner(const ModelUnderTrainingRunner &) = delete;
0035   ModelUnderTrainingRunner &
0036   operator=(const ModelUnderTrainingRunner &) = delete;
0037 
0038   const std::vector<TensorSpec> &extraOutputsForLoggingSpecs() const {
0039     return ExtraOutputsForLogging;
0040   }
0041 
0042   const void *getUntypedExtraOutputValue(size_t ExtraOutputIndex) const {
0043     return lastEvaluationResult()->getUntypedTensorValue(ExtraOutputIndex + 1);
0044   }
0045 
0046   const std::optional<TFModelEvaluator::EvaluationResult> &
0047   lastEvaluationResult() const {
0048     return LastEvaluationResult;
0049   }
0050   static bool classof(const MLModelRunner *R) {
0051     return R->getKind() == MLModelRunner::Kind::Development;
0052   }
0053 
0054   static std::unique_ptr<ModelUnderTrainingRunner>
0055   createAndEnsureValid(LLVMContext &Ctx, const std::string &ModelPath,
0056                        StringRef DecisionName,
0057                        const std::vector<TensorSpec> &InputSpecs,
0058                        StringRef OutputSpecsPathOverride = "");
0059 
0060   ModelUnderTrainingRunner(
0061       LLVMContext &Ctx, const std::string &ModelPath,
0062       const std::vector<TensorSpec> &InputSpecs,
0063       const std::vector<TensorSpec> &OutputSpecs,
0064       const std::vector<TensorSpec> &ExtraOutputsForLogging = {});
0065 
0066   bool isValid() const { return !!Evaluator; }
0067 
0068 private:
0069   std::unique_ptr<TFModelEvaluator> Evaluator;
0070   const std::vector<TensorSpec> OutputSpecs;
0071   const std::vector<TensorSpec> ExtraOutputsForLogging;
0072   std::optional<TFModelEvaluator::EvaluationResult> LastEvaluationResult;
0073   void *evaluateUntyped() override;
0074 };
0075 
0076 } // namespace llvm
0077 #endif // define(LLVM_HAVE_TFLITE)
0078 #endif // LLVM_ANALYSIS_MODELUNDERTRAININGRUNNER_H