File indexing completed on 2025-12-16 10:20:24
0001
0002
0003
0004
0005
0006
0007
0008 #pragma once
0009
#include <ctype.h>

#include <cerrno>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <stdexcept>
#include <string>
#include <unordered_map>

#include "onnx/common/status.h"
#include "onnx/onnx_pb.h"
#include "onnx/string_utils.h"
0020
0021 namespace ONNX_NAMESPACE {
0022
0023 using namespace ONNX_NAMESPACE::Common;
0024
0025 using IdList = google::protobuf::RepeatedPtrField<std::string>;
0026
0027 using NodeList = google::protobuf::RepeatedPtrField<NodeProto>;
0028
0029 using AttrList = google::protobuf::RepeatedPtrField<AttributeProto>;
0030
0031 using ValueInfoList = google::protobuf::RepeatedPtrField<ValueInfoProto>;
0032
0033 using TensorList = google::protobuf::RepeatedPtrField<TensorProto>;
0034
0035 using OpsetIdList = google::protobuf::RepeatedPtrField<OperatorSetIdProto>;
0036
0037 using StringStringList = google::protobuf::RepeatedPtrField<StringStringEntryProto>;
0038
// Evaluates `status` exactly once and, if the result is not OK, returns it from the
// enclosing function. The do { ... } while (0) wrapper makes the expansion a single
// statement that needs a trailing semicolon, so it composes with unbraced if/else
// (a bare { ... } block followed by ';' would end the if and orphan a following else).
#define CHECK_PARSER_STATUS(status)  \
  do {                               \
    auto local_status_ = (status);   \
    if (!local_status_.IsOK())       \
      return local_status_;          \
  } while (0)
0045
/// Base for a lazily built, shared string <-> int32 lookup table (CRTP).
///
/// `Map` must derive from StringIntMap<Map>, be default-constructible, and fill
/// `map_` in its constructor. A single `Map` is constructed on first use (a
/// function-local static) and its table is shared by all later calls.
template <typename Map>
class StringIntMap {
 public:
  /// Returns the table, constructing the one `Map` instance on the first call.
  static const std::unordered_map<std::string, int32_t>& Instance() {
    static Map instance;
    return instance.map_;
  }

  /// Returns the value registered for `dtype`, or 0 when `dtype` is not a key.
  static int32_t Lookup(const std::string& dtype) {
    const auto& table = Instance();
    const auto found = table.find(dtype);
    return (found == table.end()) ? 0 : found->second;
  }

  /// Reverse lookup by linear scan: returns a key whose value equals `dtype`
  /// (the first one met in the table's iteration order), or "undefined".
  static const std::string& ToString(int32_t dtype) {
    static std::string undefined("undefined");
    const auto& table = Instance();
    for (auto entry = table.begin(); entry != table.end(); ++entry) {
      if (entry->second == dtype)
        return entry->first;
    }
    return undefined;
  }

 protected:
  /// Filled in by the derived class's constructor.
  std::unordered_map<std::string, int32_t> map_;
};
0073
0074 class PrimitiveTypeNameMap : public StringIntMap<PrimitiveTypeNameMap> {
0075 public:
0076 PrimitiveTypeNameMap() : StringIntMap() {
0077 map_["float"] = TensorProto_DataType_FLOAT;
0078 map_["uint8"] = TensorProto_DataType_UINT8;
0079 map_["int8"] = TensorProto_DataType_INT8;
0080 map_["uint16"] = TensorProto_DataType_UINT16;
0081 map_["int16"] = TensorProto_DataType_INT16;
0082 map_["int32"] = TensorProto_DataType_INT32;
0083 map_["int64"] = TensorProto_DataType_INT64;
0084 map_["string"] = TensorProto_DataType_STRING;
0085 map_["bool"] = TensorProto_DataType_BOOL;
0086 map_["float16"] = TensorProto_DataType_FLOAT16;
0087 map_["double"] = TensorProto_DataType_DOUBLE;
0088 map_["uint32"] = TensorProto_DataType_UINT32;
0089 map_["uint64"] = TensorProto_DataType_UINT64;
0090 map_["complex64"] = TensorProto_DataType_COMPLEX64;
0091 map_["complex128"] = TensorProto_DataType_COMPLEX128;
0092 map_["bfloat16"] = TensorProto_DataType_BFLOAT16;
0093 map_["float8e4m3fn"] = TensorProto_DataType_FLOAT8E4M3FN;
0094 map_["float8e4m3fnuz"] = TensorProto_DataType_FLOAT8E4M3FNUZ;
0095 map_["float8e5m2"] = TensorProto_DataType_FLOAT8E5M2;
0096 map_["float8e5m2fnuz"] = TensorProto_DataType_FLOAT8E5M2FNUZ;
0097 map_["uint4"] = TensorProto_DataType_UINT4;
0098 map_["int4"] = TensorProto_DataType_INT4;
0099 }
0100
0101 static bool IsTypeName(const std::string& dtype) {
0102 return Lookup(dtype) != 0;
0103 }
0104 };
0105
0106 class AttributeTypeNameMap : public StringIntMap<AttributeTypeNameMap> {
0107 public:
0108 AttributeTypeNameMap() : StringIntMap() {
0109 map_["float"] = AttributeProto_AttributeType_FLOAT;
0110 map_["int"] = AttributeProto_AttributeType_INT;
0111 map_["string"] = AttributeProto_AttributeType_STRING;
0112 map_["tensor"] = AttributeProto_AttributeType_TENSOR;
0113 map_["graph"] = AttributeProto_AttributeType_GRAPH;
0114 map_["sparse_tensor"] = AttributeProto_AttributeType_SPARSE_TENSOR;
0115 map_["type_proto"] = AttributeProto_AttributeType_TYPE_PROTO;
0116 map_["floats"] = AttributeProto_AttributeType_FLOATS;
0117 map_["ints"] = AttributeProto_AttributeType_INTS;
0118 map_["strings"] = AttributeProto_AttributeType_STRINGS;
0119 map_["tensors"] = AttributeProto_AttributeType_TENSORS;
0120 map_["graphs"] = AttributeProto_AttributeType_GRAPHS;
0121 map_["sparse_tensors"] = AttributeProto_AttributeType_SPARSE_TENSORS;
0122 map_["type_protos"] = AttributeProto_AttributeType_TYPE_PROTOS;
0123 }
0124 };
0125
/// Reserved words of the text format, looked up by spelling.
class KeyWordMap {
 public:
  /// NONE is what Lookup() returns for identifiers that are not keywords.
  enum class KeyWord {
    NONE,
    IR_VERSION,
    OPSET_IMPORT,
    PRODUCER_NAME,
    PRODUCER_VERSION,
    DOMAIN_KW,
    MODEL_VERSION,
    DOC_STRING,
    METADATA_PROPS,
    SEQ_TYPE,
    MAP_TYPE,
    OPTIONAL_TYPE,
    SPARSE_TENSOR_TYPE,
    OVERLOAD_KW
  };

  /// Registers the spelling of every keyword except NONE.
  KeyWordMap() {
    const std::pair<const char*, KeyWord> spellings[] = {
        {"ir_version", KeyWord::IR_VERSION},
        {"opset_import", KeyWord::OPSET_IMPORT},
        {"producer_name", KeyWord::PRODUCER_NAME},
        {"producer_version", KeyWord::PRODUCER_VERSION},
        {"domain", KeyWord::DOMAIN_KW},
        {"model_version", KeyWord::MODEL_VERSION},
        {"doc_string", KeyWord::DOC_STRING},
        {"metadata_props", KeyWord::METADATA_PROPS},
        {"seq", KeyWord::SEQ_TYPE},
        {"map", KeyWord::MAP_TYPE},
        {"optional", KeyWord::OPTIONAL_TYPE},
        {"sparse_tensor", KeyWord::SPARSE_TENSOR_TYPE},
        {"overload", KeyWord::OVERLOAD_KW},
    };
    for (const auto& spelling : spellings)
      map_[spelling.first] = spelling.second;
  }

  /// Returns the shared spelling -> keyword table, built on first use.
  static const std::unordered_map<std::string, KeyWord>& Instance() {
    static KeyWordMap instance;
    return instance.map_;
  }

  /// Returns the keyword spelled `id`, or KeyWord::NONE if `id` is not a keyword.
  static KeyWord Lookup(const std::string& id) {
    const auto& table = Instance();
    const auto hit = table.find(id);
    if (hit == table.end())
      return KeyWord::NONE;
    return hit->second;
  }

  /// Returns the spelling of `kw` (linear scan), or "undefined" for NONE.
  static const std::string& ToString(KeyWord kw) {
    static std::string undefined("undefined");
    const auto& table = Instance();
    for (auto entry = table.begin(); entry != table.end(); ++entry) {
      if (entry->second == kw)
        return entry->first;
    }
    return undefined;
  }

 private:
  std::unordered_map<std::string, KeyWord> map_;
};
0185
0186 class ParserBase {
0187 public:
0188 ParserBase(const std::string& str)
0189 : start_(str.data()), next_(str.data()), end_(str.data() + str.length()), saved_pos_(next_) {}
0190
0191 ParserBase(const char* cstr) : start_(cstr), next_(cstr), end_(cstr + strlen(cstr)), saved_pos_(next_) {}
0192
0193 void SavePos() {
0194 saved_pos_ = next_;
0195 }
0196
0197 void RestorePos() {
0198 next_ = saved_pos_;
0199 }
0200
0201 std::string GetCurrentPos() {
0202 uint32_t line = 1, col = 1;
0203 for (const char* p = start_; p < next_; ++p) {
0204 if (*p == '\n') {
0205 ++line;
0206 col = 1;
0207 } else {
0208 ++col;
0209 }
0210 }
0211 return ONNX_NAMESPACE::MakeString("(line: ", line, " column: ", col, ")");
0212 }
0213
0214
0215
0216 std::string GetErrorContext() {
0217
0218 const char* p = next_ < end_ ? next_ : next_ - 1;
0219 while ((p > start_) && isspace(*p))
0220 --p;
0221 while ((p > start_) && (*p != '\n'))
0222 --p;
0223
0224 const char* context_start = (p > start_) ? (p + 1) : start_;
0225 for (p = context_start; (p < end_) && (*p != '\n'); ++p)
0226 ;
0227 return std::string(context_start, p - context_start);
0228 }
0229
0230 template <typename... Args>
0231 Status ParseError(const Args&... args) {
0232 return Status(
0233 NONE,
0234 FAIL,
0235 ONNX_NAMESPACE::MakeString(
0236 "[ParseError at position ", GetCurrentPos(), "]\n", "Error context: ", GetErrorContext(), "\n", args...));
0237 }
0238
0239 void SkipWhiteSpace() {
0240 do {
0241 while ((next_ < end_) && (isspace(*next_)))
0242 ++next_;
0243 if ((next_ >= end_) || ((*next_) != '#'))
0244 return;
0245
0246 while ((next_ < end_) && ((*next_) != '\n'))
0247 ++next_;
0248 } while (true);
0249 }
0250
0251 int NextChar(bool skipspace = true) {
0252 if (skipspace)
0253 SkipWhiteSpace();
0254 return (next_ < end_) ? *next_ : 0;
0255 }
0256
0257 bool Matches(char ch, bool skipspace = true) {
0258 if (skipspace)
0259 SkipWhiteSpace();
0260 if ((next_ < end_) && (*next_ == ch)) {
0261 ++next_;
0262 return true;
0263 }
0264 return false;
0265 }
0266
0267 Status Match(char ch, bool skipspace = true) {
0268 if (!Matches(ch, skipspace))
0269 return ParseError("Expected character ", ch, " not found.");
0270 return Status::OK();
0271 }
0272
0273 bool EndOfInput() {
0274 SkipWhiteSpace();
0275 return (next_ >= end_);
0276 }
0277
0278 enum class LiteralType { INT_LITERAL, FLOAT_LITERAL, STRING_LITERAL };
0279
0280 struct Literal {
0281 LiteralType type;
0282 std::string value;
0283 };
0284
0285 Status Parse(Literal& result);
0286
0287 Status Parse(int64_t& val) {
0288 Literal literal;
0289 CHECK_PARSER_STATUS(Parse(literal));
0290 if (literal.type != LiteralType::INT_LITERAL)
0291 return ParseError("Integer value expected, but not found.");
0292 std::string s = literal.value;
0293 val = std::stoll(s);
0294 return Status::OK();
0295 }
0296
0297 Status Parse(uint64_t& val) {
0298 Literal literal;
0299 CHECK_PARSER_STATUS(Parse(literal));
0300 if (literal.type != LiteralType::INT_LITERAL)
0301 return ParseError("Integer value expected, but not found.");
0302 std::string s = literal.value;
0303 val = std::stoull(s);
0304 return Status::OK();
0305 }
0306
0307 Status Parse(float& val) {
0308 Literal literal;
0309 CHECK_PARSER_STATUS(Parse(literal));
0310 switch (literal.type) {
0311 case LiteralType::INT_LITERAL:
0312 case LiteralType::FLOAT_LITERAL:
0313 val = std::stof(literal.value);
0314 break;
0315 default:
0316 return ParseError("Unexpected literal type.");
0317 }
0318 return Status::OK();
0319 }
0320
0321 Status Parse(double& val) {
0322 Literal literal;
0323 CHECK_PARSER_STATUS(Parse(literal));
0324 switch (literal.type) {
0325 case LiteralType::INT_LITERAL:
0326 case LiteralType::FLOAT_LITERAL:
0327 val = std::stod(literal.value);
0328 break;
0329 default:
0330 return ParseError("Unexpected literal type.");
0331 }
0332 return Status::OK();
0333 }
0334
0335
0336 Status Parse(std::string& val) {
0337 Literal literal;
0338 CHECK_PARSER_STATUS(Parse(literal));
0339 if (literal.type != LiteralType::STRING_LITERAL)
0340 return ParseError("String value expected, but not found.");
0341 val = literal.value;
0342 return Status::OK();
0343 }
0344
0345
0346
0347 Status ParseOptionalIdentifier(std::string& id) {
0348 SkipWhiteSpace();
0349 auto from = next_;
0350 if ((next_ < end_) && (isalpha(*next_) || (*next_ == '_'))) {
0351 ++next_;
0352 while ((next_ < end_) && (isalnum(*next_) || (*next_ == '_')))
0353 ++next_;
0354 }
0355 id = std::string(from, next_ - from);
0356 return Status::OK();
0357 }
0358
0359 Status ParseIdentifier(std::string& id) {
0360 ParseOptionalIdentifier(id);
0361 if (id.empty())
0362 return ParseError("Identifier expected but not found.");
0363 return Status::OK();
0364 }
0365
0366 Status PeekIdentifier(std::string& id) {
0367 SavePos();
0368 ParseOptionalIdentifier(id);
0369 RestorePos();
0370 return Status::OK();
0371 }
0372
0373 Status Parse(KeyWordMap::KeyWord& keyword) {
0374 std::string id;
0375 CHECK_PARSER_STATUS(ParseIdentifier(id));
0376 keyword = KeyWordMap::Lookup(id);
0377 return Status::OK();
0378 }
0379
0380 protected:
0381 const char* start_;
0382 const char* next_;
0383 const char* end_;
0384 const char* saved_pos_;
0385
0386 bool NextIsValidFloatString();
0387 };
0388
0389 class OnnxParser : public ParserBase {
0390 public:
0391 OnnxParser(const char* cstr) : ParserBase(cstr) {}
0392
0393 Status Parse(TensorShapeProto& shape);
0394
0395 Status Parse(TypeProto& typeProto);
0396
0397 Status Parse(StringStringList& stringStringList);
0398
0399 Status Parse(TensorProto& tensorProto);
0400
0401 Status Parse(AttributeProto& attr);
0402
0403 Status Parse(AttributeProto& attr, std::string& name);
0404
0405 Status Parse(AttrList& attrlist);
0406
0407 Status Parse(NodeProto& node);
0408
0409 Status Parse(NodeList& nodelist);
0410
0411 Status Parse(GraphProto& graph);
0412
0413 Status Parse(FunctionProto& fn);
0414
0415 Status Parse(ModelProto& model);
0416
0417 template <typename T>
0418 static Status Parse(T& parsedData, const char* input) {
0419 OnnxParser parser(input);
0420 return parser.Parse(parsedData);
0421 }
0422
0423 private:
0424 Status Parse(std::string name, GraphProto& graph);
0425
0426 Status Parse(IdList& idlist);
0427
0428 Status Parse(char open, IdList& idlist, char close);
0429
0430 Status Parse(IdList& idlist, AttrList& attrlist);
0431
0432 Status Parse(char open, IdList& idlist, AttrList& attrlist, char close);
0433
0434 Status ParseSingleAttributeValue(AttributeProto& attr, AttributeProto_AttributeType expected);
0435
0436 Status Parse(ValueInfoProto& valueinfo);
0437
0438 Status ParseGraphInputOutput(ValueInfoList& vilist);
0439
0440 Status ParseFunctionInputOutput(IdList& idlist, ValueInfoList& vilist);
0441
0442 Status Parse(char open, ValueInfoList& vilist, char close);
0443
0444 Status ParseInput(ValueInfoList& vilist, TensorList& initializers);
0445
0446 Status ParseValueInfo(ValueInfoList& vilist, TensorList& initializers);
0447
0448 Status Parse(TensorProto& tensorProto, const TypeProto& tensorTypeProto);
0449
0450 Status Parse(OpsetIdList& opsets);
0451
0452 bool NextIsType();
0453
0454 bool NextIsIdentifier();
0455 };
0456
0457 }