|
|
|||
File indexing completed on 2026-05-10 08:43:10
0001 //===- TrainingLogger.h - mlgo feature/reward logging ----------*- C++ -*-===// 0002 // 0003 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 0004 // See https://llvm.org/LICENSE.txt for license information. 0005 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 0006 // 0007 //===----------------------------------------------------------------------===// 0008 // 0009 // The design goals of the logger are: 0010 // - no dependencies that llvm doesn't already have. 0011 // - support streaming, so that we don't need to buffer data during compilation 0012 // - 0-decoding tensor values. Tensor values are potentially very large buffers 0013 // of scalars. Because of their potentially large size, avoiding 0014 // serialization/deserialization overhead is preferred. 0015 // 0016 // The simple logger produces an output of the form (each line item on its line) 0017 // - header: a json object describing the data that will follow. 0018 // - context: e.g. function name, for regalloc, or "default" for module-wide 0019 // optimizations like the inliner. This is the context to which the subsequent 0020 // data corresponds. 0021 // - observation number. 0022 // - tensor values - raw bytes of the tensors, in the order given in the header. 0023 // The values are in succession, i.e. no separator is found between successive 0024 // tensor values. At the end, there is a new line character. 0025 // - [score] - this is optional, and is present if it was present in the header. 0026 // Currently, for final rewards, we output "0" scores after each observation, 0027 // except for the last one. 0028 // <repeat> 0029 // The file should be read as binary, but the reason we use newlines is mostly 0030 // ease of debugging: the log can be opened in a text editor and, while tensor 0031 // values are inscrutable, at least the sequence of data can be easily observed. 0032 // Of course, the buffer of tensor values could contain '\n' bytes. A reader 0033 // should use the header information to know how much data to read for the 0034 // tensor values, and not use line information for that. 0035 // 0036 // An example reader, used for test, is available at 0037 // Analysis/models/log_reader.py 0038 // 0039 // Example: 0040 // {"features":[list of TensorSpecs], "score":<a tensor spec>} 0041 // {"context": "aFunction"} 0042 // {"observation": 0} 0043 // <bytes> 0044 // {"outcome": 0} 0045 // <bytes for the tensor corresponding to the "score" spec in the header> 0046 // {"observation": 1} 0047 // ... 0048 // {"context": "anotherFunction"} 0049 // {"observation": 0} 0050 // ... 0051 // 0052 0053 #ifndef LLVM_ANALYSIS_UTILS_TRAININGLOGGER_H 0054 #define LLVM_ANALYSIS_UTILS_TRAININGLOGGER_H 0055 0056 #include "llvm/Config/llvm-config.h" 0057 0058 #include "llvm/ADT/StringMap.h" 0059 #include "llvm/Analysis/TensorSpec.h" 0060 #include "llvm/IR/LLVMContext.h" 0061 #include "llvm/Support/JSON.h" 0062 0063 #include <memory> 0064 #include <optional> 0065 #include <vector> 0066 0067 namespace llvm { 0068 0069 /// Logging utility - given an ordered specification of features, and assuming 0070 /// a scalar reward, allow logging feature values and rewards. 0071 /// The assumption is that, for an event to be logged (i.e. a set of feature 0072 /// values and a reward), the user calls the log* API for each feature exactly 0073 /// once, providing the index matching the position in the feature spec list 0074 /// provided at construction. The example assumes the first feature's element 0075 /// type is float, the second is int64, and the reward is float: 0076 /// 0077 /// event 0: 0078 /// logFloatValue(0, ...) 0079 /// logInt64Value(1, ...) 0080 /// ... 0081 /// logFloatReward(...) 0082 /// event 1: 0083 /// logFloatValue(0, ...) 0084 /// logInt64Value(1, ...) 0085 /// ... 0086 /// logFloatReward(...) 0087 /// 0088 /// At the end, call print to generate the log. 0089 /// Alternatively, don't call logReward at the end of each event, just 0090 /// log{Float|Int32|Int64}FinalReward at the end. 0091 class Logger final { 0092 std::unique_ptr<raw_ostream> OS; 0093 const std::vector<TensorSpec> FeatureSpecs; 0094 const TensorSpec RewardSpec; 0095 const bool IncludeReward; 0096 StringMap<size_t> ObservationIDs; 0097 std::string CurrentContext; 0098 0099 void writeHeader(std::optional<TensorSpec> AdviceSpec); 0100 void writeTensor(const TensorSpec &Spec, const char *RawData) { 0101 OS->write(RawData, Spec.getTotalTensorBufferSize()); 0102 } 0103 void logRewardImpl(const char *RawData); 0104 0105 public: 0106 /// Construct a Logger. If IncludeReward is false, then logReward or 0107 /// logFinalReward shouldn't be called, and the reward feature won't be 0108 /// printed out. 0109 /// NOTE: the FeatureSpecs are expected to be in the same order (i.e. have 0110 /// corresponding indices) with any MLModelRunner implementations 0111 /// corresponding to the model being trained/logged. 0112 Logger(std::unique_ptr<raw_ostream> OS, 0113 const std::vector<TensorSpec> &FeatureSpecs, 0114 const TensorSpec &RewardSpec, bool IncludeReward, 0115 std::optional<TensorSpec> AdviceSpec = std::nullopt); 0116 0117 void switchContext(StringRef Name); 0118 void startObservation(); 0119 void endObservation(); 0120 void flush() { OS->flush(); } 0121 0122 const std::string ¤tContext() const { return CurrentContext; } 0123 0124 /// Check if there is at least an observation for `currentContext()`. 0125 bool hasObservationInProgress() const { 0126 return hasAnyObservationForContext(CurrentContext); 0127 } 0128 0129 /// Check if there is at least an observation for the context `Ctx`. 0130 bool hasAnyObservationForContext(StringRef Ctx) const { 0131 return ObservationIDs.contains(Ctx); 0132 } 0133 0134 template <typename T> void logReward(T Value) { 0135 logRewardImpl(reinterpret_cast<const char *>(&Value)); 0136 } 0137 0138 void logTensorValue(size_t FeatureID, const char *RawData) { 0139 writeTensor(FeatureSpecs[FeatureID], RawData); 0140 } 0141 }; 0142 0143 } // namespace llvm 0144 #endif // LLVM_ANALYSIS_UTILS_TRAININGLOGGER_H
| [ Source navigation ] | [ Diff markup ] | [ Identifier search ] | [ general search ] |
|
This page was automatically generated by the 2.3.7 LXR engine. The LXR team |
|