//===- TensorSpec.h ---------------------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-Exception
//
//===----------------------------------------------------------------------===//
0009 #ifndef LLVM_ANALYSIS_TENSORSPEC_H
0010 #define LLVM_ANALYSIS_TENSORSPEC_H
0011
0012 #include "llvm/Config/llvm-config.h"
0013
0014 #include "llvm/ADT/StringMap.h"
0015 #include "llvm/IR/LLVMContext.h"
0016
0017 #include <memory>
0018 #include <optional>
0019 #include <vector>
0020
0021 namespace llvm {
0022 namespace json {
0023 class OStream;
0024 class Value;
0025 }
0026
0027
0028
0029
0030
0031
0032
0033
0034
0035
0036
0037
0038
0039
0040
0041
/// X-macro listing the scalar element types a TensorSpec may describe.
/// M is invoked as M(<C++ type>, <CamelCase name>); the name becomes the
/// TensorType enumerator and tags the getDataType<T> specialization declared
/// below. Add new supported element types here so every expansion site picks
/// them up.
#define SUPPORTED_TENSOR_TYPES(M) \
  M(float, Float) \
  M(double, Double) \
  M(int8_t, Int8) \
  M(uint8_t, UInt8) \
  M(int16_t, Int16) \
  M(uint16_t, UInt16) \
  M(int32_t, Int32) \
  M(uint32_t, UInt32) \
  M(int64_t, Int64) \
  M(uint64_t, UInt64)
0053
/// Element type of a tensor: one enumerator per SUPPORTED_TENSOR_TYPES entry,
/// bracketed by Invalid (the default/unset state, see TensorSpec::Type) and
/// Total (a count sentinel, not a real type).
enum class TensorType {
  Invalid,
#define _TENSOR_TYPE_ENUM_MEMBERS(_, Name) Name,
  SUPPORTED_TENSOR_TYPES(_TENSOR_TYPE_ENUM_MEMBERS)
#undef _TENSOR_TYPE_ENUM_MEMBERS
  Total
};
0061
0062 class TensorSpec final {
0063 public:
0064 template <typename T>
0065 static TensorSpec createSpec(const std::string &Name,
0066 const std::vector<int64_t> &Shape,
0067 int Port = 0) {
0068 return TensorSpec(Name, Port, getDataType<T>(), sizeof(T), Shape);
0069 }
0070
0071 const std::string &name() const { return Name; }
0072 int port() const { return Port; }
0073 TensorType type() const { return Type; }
0074 const std::vector<int64_t> &shape() const { return Shape; }
0075
0076 bool operator==(const TensorSpec &Other) const {
0077 return Name == Other.Name && Port == Other.Port && Type == Other.Type &&
0078 Shape == Other.Shape;
0079 }
0080
0081 bool operator!=(const TensorSpec &Other) const { return !(*this == Other); }
0082
0083
0084 size_t getElementCount() const { return ElementCount; }
0085
0086 size_t getElementByteSize() const { return ElementSize; }
0087
0088 size_t getTotalTensorBufferSize() const { return ElementCount * ElementSize; }
0089
0090 template <typename T> bool isElementType() const {
0091 return getDataType<T>() == Type;
0092 }
0093
0094 TensorSpec(const std::string &NewName, const TensorSpec &Other)
0095 : TensorSpec(NewName, Other.Port, Other.Type, Other.ElementSize,
0096 Other.Shape) {}
0097
0098 void toJSON(json::OStream &OS) const;
0099
0100 private:
0101 TensorSpec(const std::string &Name, int Port, TensorType Type,
0102 size_t ElementSize, const std::vector<int64_t> &Shape);
0103
0104 template <typename T> static TensorType getDataType();
0105
0106 std::string Name;
0107 int Port = 0;
0108 TensorType Type = TensorType::Invalid;
0109 std::vector<int64_t> Shape;
0110 size_t ElementCount = 0;
0111 size_t ElementSize = 0;
0112 };
0113
0114
0115 std::string tensorValueToString(const char *Buffer, const TensorSpec &Spec);
0116
0117
0118
0119
0120
0121
0122
0123
/// Construct a TensorSpec from a JSON description (the counterpart of
/// TensorSpec::toJSON). Returns std::nullopt when Value does not describe a
/// valid spec; Ctx is presumably used to report parse diagnostics — confirm
/// in the implementation file.
std::optional<TensorSpec> getTensorSpecFromJSON(LLVMContext &Ctx,
                                                const json::Value &Value);
0126
// Declare one TensorSpec::getDataType<T> specialization per supported element
// type; the definitions live in the implementation file.
#define TFUTILS_GETDATATYPE_DEF(T, Name) \
  template <> TensorType TensorSpec::getDataType<T>();
SUPPORTED_TENSOR_TYPES(TFUTILS_GETDATATYPE_DEF)

#undef TFUTILS_GETDATATYPE_DEF
0132 }
0133
0134 #endif