Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-12-16 10:20:24

0001 /*
0002  * SPDX-License-Identifier: Apache-2.0
0003  */
0004 
0005 // Experimental language syntax and parser for ONNX. Please note that the syntax as formalized
0006 // by this parser is preliminary and may change.
0007 
0008 #pragma once
0009 
0010 #include <ctype.h>
0011 
0012 #include <iostream>
0013 #include <stdexcept>
0014 #include <string>
0015 #include <unordered_map>
0016 
0017 #include "onnx/common/status.h"
0018 #include "onnx/onnx_pb.h"
0019 #include "onnx/string_utils.h"
0020 
0021 namespace ONNX_NAMESPACE {
0022 
// Bring the names declared in ONNX_NAMESPACE::Common into this namespace so they can be
// written unqualified below (presumably this is where Status, NONE and FAIL come from).
// NOTE(review): this using-directive sits in a header, so it also applies to every
// translation unit that includes this file.
using namespace ONNX_NAMESPACE::Common;

// Short names for the protobuf repeated-field containers that appear in the parser
// signatures declared further down in this header.

// Repeated field of plain strings.
using IdList = google::protobuf::RepeatedPtrField<std::string>;

// Repeated field of NodeProto messages.
using NodeList = google::protobuf::RepeatedPtrField<NodeProto>;

// Repeated field of AttributeProto messages.
using AttrList = google::protobuf::RepeatedPtrField<AttributeProto>;

// Repeated field of ValueInfoProto messages.
using ValueInfoList = google::protobuf::RepeatedPtrField<ValueInfoProto>;

// Repeated field of TensorProto messages.
using TensorList = google::protobuf::RepeatedPtrField<TensorProto>;

// Repeated field of OperatorSetIdProto messages.
using OpsetIdList = google::protobuf::RepeatedPtrField<OperatorSetIdProto>;

// Repeated field of StringStringEntryProto messages.
using StringStringList = google::protobuf::RepeatedPtrField<StringStringEntryProto>;
0038 
// Evaluates `status` exactly once and stores the result in a block-local variable.
// When that result reports failure (IsOK() returns false) it is returned from the
// function in which the macro is expanded; otherwise execution continues after the macro.
#define CHECK_PARSER_STATUS(status)       \
  {                                       \
    auto parser_status_ = status;         \
    if (parser_status_.IsOK()) {          \
      /* success: nothing to propagate */ \
    } else {                              \
      return parser_status_;              \
    }                                     \
  }
0045 
// Base class for string <-> int32 lookup tables.
// `Map` is the derived table type: it must derive from StringIntMap<Map>, be
// default-constructible, and populate `map_` in its constructor. A single `Map` object is
// created on first use as a function-local static and shared by all the static helpers.
template <typename Map>
class StringIntMap {
 public:
  // Returns the name -> value table held by the shared `Map` object.
  static const std::unordered_map<std::string, int32_t>& Instance() {
    static Map singleton;
    return singleton.map_;
  }

  // Returns the value registered for `dtype`, or 0 when the name is not in the table.
  static int32_t Lookup(const std::string& dtype) {
    const auto& table = Instance();
    const auto found = table.find(dtype);
    return (found == table.end()) ? 0 : found->second;
  }

  // Returns the name of the first table entry (in the table's iteration order) whose value
  // equals `dtype`, or the string "undefined" when no entry has that value.
  static const std::string& ToString(int32_t dtype) {
    static std::string undefined("undefined");
    const auto& table = Instance();
    for (auto entry = table.begin(); entry != table.end(); ++entry) {
      if (entry->second == dtype) {
        return entry->first;
      }
    }
    return undefined;
  }

 protected:
  // Filled in by the derived class constructor.
  std::unordered_map<std::string, int32_t> map_;
};
0073 
0074 class PrimitiveTypeNameMap : public StringIntMap<PrimitiveTypeNameMap> {
0075  public:
0076   PrimitiveTypeNameMap() : StringIntMap() {
0077     map_["float"] = TensorProto_DataType_FLOAT;
0078     map_["uint8"] = TensorProto_DataType_UINT8;
0079     map_["int8"] = TensorProto_DataType_INT8;
0080     map_["uint16"] = TensorProto_DataType_UINT16;
0081     map_["int16"] = TensorProto_DataType_INT16;
0082     map_["int32"] = TensorProto_DataType_INT32;
0083     map_["int64"] = TensorProto_DataType_INT64;
0084     map_["string"] = TensorProto_DataType_STRING;
0085     map_["bool"] = TensorProto_DataType_BOOL;
0086     map_["float16"] = TensorProto_DataType_FLOAT16;
0087     map_["double"] = TensorProto_DataType_DOUBLE;
0088     map_["uint32"] = TensorProto_DataType_UINT32;
0089     map_["uint64"] = TensorProto_DataType_UINT64;
0090     map_["complex64"] = TensorProto_DataType_COMPLEX64;
0091     map_["complex128"] = TensorProto_DataType_COMPLEX128;
0092     map_["bfloat16"] = TensorProto_DataType_BFLOAT16;
0093     map_["float8e4m3fn"] = TensorProto_DataType_FLOAT8E4M3FN;
0094     map_["float8e4m3fnuz"] = TensorProto_DataType_FLOAT8E4M3FNUZ;
0095     map_["float8e5m2"] = TensorProto_DataType_FLOAT8E5M2;
0096     map_["float8e5m2fnuz"] = TensorProto_DataType_FLOAT8E5M2FNUZ;
0097     map_["uint4"] = TensorProto_DataType_UINT4;
0098     map_["int4"] = TensorProto_DataType_INT4;
0099   }
0100 
0101   static bool IsTypeName(const std::string& dtype) {
0102     return Lookup(dtype) != 0;
0103   }
0104 };
0105 
0106 class AttributeTypeNameMap : public StringIntMap<AttributeTypeNameMap> {
0107  public:
0108   AttributeTypeNameMap() : StringIntMap() {
0109     map_["float"] = AttributeProto_AttributeType_FLOAT;
0110     map_["int"] = AttributeProto_AttributeType_INT;
0111     map_["string"] = AttributeProto_AttributeType_STRING;
0112     map_["tensor"] = AttributeProto_AttributeType_TENSOR;
0113     map_["graph"] = AttributeProto_AttributeType_GRAPH;
0114     map_["sparse_tensor"] = AttributeProto_AttributeType_SPARSE_TENSOR;
0115     map_["type_proto"] = AttributeProto_AttributeType_TYPE_PROTO;
0116     map_["floats"] = AttributeProto_AttributeType_FLOATS;
0117     map_["ints"] = AttributeProto_AttributeType_INTS;
0118     map_["strings"] = AttributeProto_AttributeType_STRINGS;
0119     map_["tensors"] = AttributeProto_AttributeType_TENSORS;
0120     map_["graphs"] = AttributeProto_AttributeType_GRAPHS;
0121     map_["sparse_tensors"] = AttributeProto_AttributeType_SPARSE_TENSORS;
0122     map_["type_protos"] = AttributeProto_AttributeType_TYPE_PROTOS;
0123   }
0124 };
0125 
// Table of the identifiers that the parser treats as keywords.
// A single instance is created on first use and shared by the static helpers.
class KeyWordMap {
 public:
  // One enumerator per keyword; NONE stands for "not a keyword".
  enum class KeyWord {
    NONE,
    IR_VERSION,
    OPSET_IMPORT,
    PRODUCER_NAME,
    PRODUCER_VERSION,
    DOMAIN_KW,
    MODEL_VERSION,
    DOC_STRING,
    METADATA_PROPS,
    SEQ_TYPE,
    MAP_TYPE,
    OPTIONAL_TYPE,
    SPARSE_TENSOR_TYPE,
    OVERLOAD_KW
  };

  // Registers the spelling of every keyword, in the order listed below.
  KeyWordMap() {
    struct Entry {
      const char* text;
      KeyWord keyword;
    };
    const Entry entries[] = {
        {"ir_version", KeyWord::IR_VERSION},
        {"opset_import", KeyWord::OPSET_IMPORT},
        {"producer_name", KeyWord::PRODUCER_NAME},
        {"producer_version", KeyWord::PRODUCER_VERSION},
        {"domain", KeyWord::DOMAIN_KW},
        {"model_version", KeyWord::MODEL_VERSION},
        {"doc_string", KeyWord::DOC_STRING},
        {"metadata_props", KeyWord::METADATA_PROPS},
        {"seq", KeyWord::SEQ_TYPE},
        {"map", KeyWord::MAP_TYPE},
        {"optional", KeyWord::OPTIONAL_TYPE},
        {"sparse_tensor", KeyWord::SPARSE_TENSOR_TYPE},
        {"overload", KeyWord::OVERLOAD_KW},
    };
    for (const Entry& entry : entries) {
      map_[entry.text] = entry.keyword;
    }
  }

  // Returns the spelling -> keyword table of the shared, lazily constructed instance.
  static const std::unordered_map<std::string, KeyWord>& Instance() {
    static KeyWordMap singleton;
    return singleton.map_;
  }

  // Returns the keyword spelled `id`, or KeyWord::NONE when `id` is not a keyword.
  static KeyWord Lookup(const std::string& id) {
    const auto& table = Instance();
    const auto found = table.find(id);
    return (found == table.end()) ? KeyWord::NONE : found->second;
  }

  // Returns the spelling registered for `kw`, or the string "undefined" when no entry
  // maps to it (which is the case for KeyWord::NONE).
  static const std::string& ToString(KeyWord kw) {
    static std::string undefined("undefined");
    const auto& table = Instance();
    for (auto entry = table.begin(); entry != table.end(); ++entry) {
      if (entry->second == kw) {
        return entry->first;
      }
    }
    return undefined;
  }

 private:
  std::unordered_map<std::string, KeyWord> map_;
};
0185 
0186 class ParserBase {
0187  public:
0188   ParserBase(const std::string& str)
0189       : start_(str.data()), next_(str.data()), end_(str.data() + str.length()), saved_pos_(next_) {}
0190 
0191   ParserBase(const char* cstr) : start_(cstr), next_(cstr), end_(cstr + strlen(cstr)), saved_pos_(next_) {}
0192 
0193   void SavePos() {
0194     saved_pos_ = next_;
0195   }
0196 
0197   void RestorePos() {
0198     next_ = saved_pos_;
0199   }
0200 
0201   std::string GetCurrentPos() {
0202     uint32_t line = 1, col = 1;
0203     for (const char* p = start_; p < next_; ++p) {
0204       if (*p == '\n') {
0205         ++line;
0206         col = 1;
0207       } else {
0208         ++col;
0209       }
0210     }
0211     return ONNX_NAMESPACE::MakeString("(line: ", line, " column: ", col, ")");
0212   }
0213 
0214   // Return a suitable suffix of what has been parsed to provide error message context:
0215   // return the line containing the last non-space character preceding the error (if it exists).
0216   std::string GetErrorContext() {
0217     // Special cases: empty input string, and parse-error at first character.
0218     const char* p = next_ < end_ ? next_ : next_ - 1;
0219     while ((p > start_) && isspace(*p))
0220       --p;
0221     while ((p > start_) && (*p != '\n'))
0222       --p;
0223     // Start at character after '\n' unless we are at start of input
0224     const char* context_start = (p > start_) ? (p + 1) : start_;
0225     for (p = context_start; (p < end_) && (*p != '\n'); ++p)
0226       ;
0227     return std::string(context_start, p - context_start);
0228   }
0229 
0230   template <typename... Args>
0231   Status ParseError(const Args&... args) {
0232     return Status(
0233         NONE,
0234         FAIL,
0235         ONNX_NAMESPACE::MakeString(
0236             "[ParseError at position ", GetCurrentPos(), "]\n", "Error context: ", GetErrorContext(), "\n", args...));
0237   }
0238 
0239   void SkipWhiteSpace() {
0240     do {
0241       while ((next_ < end_) && (isspace(*next_)))
0242         ++next_;
0243       if ((next_ >= end_) || ((*next_) != '#'))
0244         return;
0245       // Skip rest of the line:
0246       while ((next_ < end_) && ((*next_) != '\n'))
0247         ++next_;
0248     } while (true);
0249   }
0250 
0251   int NextChar(bool skipspace = true) {
0252     if (skipspace)
0253       SkipWhiteSpace();
0254     return (next_ < end_) ? *next_ : 0;
0255   }
0256 
0257   bool Matches(char ch, bool skipspace = true) {
0258     if (skipspace)
0259       SkipWhiteSpace();
0260     if ((next_ < end_) && (*next_ == ch)) {
0261       ++next_;
0262       return true;
0263     }
0264     return false;
0265   }
0266 
0267   Status Match(char ch, bool skipspace = true) {
0268     if (!Matches(ch, skipspace))
0269       return ParseError("Expected character ", ch, " not found.");
0270     return Status::OK();
0271   }
0272 
0273   bool EndOfInput() {
0274     SkipWhiteSpace();
0275     return (next_ >= end_);
0276   }
0277 
0278   enum class LiteralType { INT_LITERAL, FLOAT_LITERAL, STRING_LITERAL };
0279 
0280   struct Literal {
0281     LiteralType type;
0282     std::string value;
0283   };
0284 
0285   Status Parse(Literal& result);
0286 
0287   Status Parse(int64_t& val) {
0288     Literal literal;
0289     CHECK_PARSER_STATUS(Parse(literal));
0290     if (literal.type != LiteralType::INT_LITERAL)
0291       return ParseError("Integer value expected, but not found.");
0292     std::string s = literal.value;
0293     val = std::stoll(s);
0294     return Status::OK();
0295   }
0296 
0297   Status Parse(uint64_t& val) {
0298     Literal literal;
0299     CHECK_PARSER_STATUS(Parse(literal));
0300     if (literal.type != LiteralType::INT_LITERAL)
0301       return ParseError("Integer value expected, but not found.");
0302     std::string s = literal.value;
0303     val = std::stoull(s);
0304     return Status::OK();
0305   }
0306 
0307   Status Parse(float& val) {
0308     Literal literal;
0309     CHECK_PARSER_STATUS(Parse(literal));
0310     switch (literal.type) {
0311       case LiteralType::INT_LITERAL:
0312       case LiteralType::FLOAT_LITERAL:
0313         val = std::stof(literal.value);
0314         break;
0315       default:
0316         return ParseError("Unexpected literal type.");
0317     }
0318     return Status::OK();
0319   }
0320 
0321   Status Parse(double& val) {
0322     Literal literal;
0323     CHECK_PARSER_STATUS(Parse(literal));
0324     switch (literal.type) {
0325       case LiteralType::INT_LITERAL:
0326       case LiteralType::FLOAT_LITERAL:
0327         val = std::stod(literal.value);
0328         break;
0329       default:
0330         return ParseError("Unexpected literal type.");
0331     }
0332     return Status::OK();
0333   }
0334 
0335   // Parse a string-literal enclosed within doube-quotes.
0336   Status Parse(std::string& val) {
0337     Literal literal;
0338     CHECK_PARSER_STATUS(Parse(literal));
0339     if (literal.type != LiteralType::STRING_LITERAL)
0340       return ParseError("String value expected, but not found.");
0341     val = literal.value;
0342     return Status::OK();
0343   }
0344 
0345   // Parse an identifier, including keywords. If none found, this will
0346   // return an empty-string identifier.
0347   Status ParseOptionalIdentifier(std::string& id) {
0348     SkipWhiteSpace();
0349     auto from = next_;
0350     if ((next_ < end_) && (isalpha(*next_) || (*next_ == '_'))) {
0351       ++next_;
0352       while ((next_ < end_) && (isalnum(*next_) || (*next_ == '_')))
0353         ++next_;
0354     }
0355     id = std::string(from, next_ - from);
0356     return Status::OK();
0357   }
0358 
0359   Status ParseIdentifier(std::string& id) {
0360     ParseOptionalIdentifier(id);
0361     if (id.empty())
0362       return ParseError("Identifier expected but not found.");
0363     return Status::OK();
0364   }
0365 
0366   Status PeekIdentifier(std::string& id) {
0367     SavePos();
0368     ParseOptionalIdentifier(id);
0369     RestorePos();
0370     return Status::OK();
0371   }
0372 
0373   Status Parse(KeyWordMap::KeyWord& keyword) {
0374     std::string id;
0375     CHECK_PARSER_STATUS(ParseIdentifier(id));
0376     keyword = KeyWordMap::Lookup(id);
0377     return Status::OK();
0378   }
0379 
0380  protected:
0381   const char* start_;
0382   const char* next_;
0383   const char* end_;
0384   const char* saved_pos_;
0385 
0386   bool NextIsValidFloatString();
0387 };
0388 
0389 class OnnxParser : public ParserBase {
0390  public:
0391   OnnxParser(const char* cstr) : ParserBase(cstr) {}
0392 
0393   Status Parse(TensorShapeProto& shape);
0394 
0395   Status Parse(TypeProto& typeProto);
0396 
0397   Status Parse(StringStringList& stringStringList);
0398 
0399   Status Parse(TensorProto& tensorProto);
0400 
0401   Status Parse(AttributeProto& attr);
0402 
0403   Status Parse(AttributeProto& attr, std::string& name);
0404 
0405   Status Parse(AttrList& attrlist);
0406 
0407   Status Parse(NodeProto& node);
0408 
0409   Status Parse(NodeList& nodelist);
0410 
0411   Status Parse(GraphProto& graph);
0412 
0413   Status Parse(FunctionProto& fn);
0414 
0415   Status Parse(ModelProto& model);
0416 
0417   template <typename T>
0418   static Status Parse(T& parsedData, const char* input) {
0419     OnnxParser parser(input);
0420     return parser.Parse(parsedData);
0421   }
0422 
0423  private:
0424   Status Parse(std::string name, GraphProto& graph);
0425 
0426   Status Parse(IdList& idlist);
0427 
0428   Status Parse(char open, IdList& idlist, char close);
0429 
0430   Status Parse(IdList& idlist, AttrList& attrlist);
0431 
0432   Status Parse(char open, IdList& idlist, AttrList& attrlist, char close);
0433 
0434   Status ParseSingleAttributeValue(AttributeProto& attr, AttributeProto_AttributeType expected);
0435 
0436   Status Parse(ValueInfoProto& valueinfo);
0437 
0438   Status ParseGraphInputOutput(ValueInfoList& vilist);
0439 
0440   Status ParseFunctionInputOutput(IdList& idlist, ValueInfoList& vilist);
0441 
0442   Status Parse(char open, ValueInfoList& vilist, char close);
0443 
0444   Status ParseInput(ValueInfoList& vilist, TensorList& initializers);
0445 
0446   Status ParseValueInfo(ValueInfoList& vilist, TensorList& initializers);
0447 
0448   Status Parse(TensorProto& tensorProto, const TypeProto& tensorTypeProto);
0449 
0450   Status Parse(OpsetIdList& opsets);
0451 
0452   bool NextIsType();
0453 
0454   bool NextIsIdentifier();
0455 };
0456 
0457 } // namespace ONNX_NAMESPACE