Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-02-22 10:42:45

0001 /*
0002  * SPDX-License-Identifier: Apache-2.0
0003  */
0004 
0005 #pragma once
0006 
0007 #include "onnx/defs/schema.h"
0008 
0009 namespace ONNX_NAMESPACE {
0010 
0011 // Declare training operators.
0012 
0013 class ONNX_PREVIEW_OPERATOR_SET_SCHEMA_CLASS_NAME(1, Gradient);
0014 class ONNX_PREVIEW_OPERATOR_SET_SCHEMA_CLASS_NAME(1, Momentum);
0015 class ONNX_PREVIEW_OPERATOR_SET_SCHEMA_CLASS_NAME(1, Adagrad);
0016 class ONNX_PREVIEW_OPERATOR_SET_SCHEMA_CLASS_NAME(1, Adam);
0017 
0018 // Iterate over schema from ai.onnx.training version 1
0019 class OpSet_OnnxPreview_ver1 {
0020  public:
0021   static void ForEachSchema(std::function<void(OpSchema&&)> fn) {
0022     fn(GetOpSchema<ONNX_PREVIEW_OPERATOR_SET_SCHEMA_CLASS_NAME(1, Gradient)>());
0023     fn(GetOpSchema<ONNX_PREVIEW_OPERATOR_SET_SCHEMA_CLASS_NAME(1, Momentum)>());
0024     fn(GetOpSchema<ONNX_PREVIEW_OPERATOR_SET_SCHEMA_CLASS_NAME(1, Adagrad)>());
0025     fn(GetOpSchema<ONNX_PREVIEW_OPERATOR_SET_SCHEMA_CLASS_NAME(1, Adam)>());
0026   }
0027 };
0028 
0029 // Register preview operators.
0030 inline void RegisterOnnxPreviewOperatorSetSchema() {
0031   // Preview operators should have only one version.
0032   // If changes are needed for a specific preview operator,
0033   // its spec should be modified without increasing its version.
0034   RegisterOpSetSchema<OpSet_OnnxPreview_ver1>();
0035 }
0036 
0037 } // namespace ONNX_NAMESPACE