Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-01-30 10:23:00

0001 #ifndef TMVA_SOFIE_ROPERATOR_LSTM
0002 #define TMVA_SOFIE_ROPERATOR_LSTM
0003 
#include "TMVA/RModel.hxx"
#include "TMVA/ROperator.hxx"
#include "TMVA/SOFIE_common.hxx"

#include <cstddef>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>
0012 
0013 namespace TMVA {
0014 namespace Experimental {
0015 namespace SOFIE {
0016 
0017 /*! \brief Long Short-Term Memory operator
0018  *
0019  * Inference code generation for one-layer LSTM. Supports forward, reverse and bidirectional LSTM.
0020  * See the <a href="https://github.com/onnx/onnx/blob/master/docs/Operators.md#LSTM">ONNX documentation</a>
0021  * for details about the supported LSTM architectures.
0022  */
0023 template <typename T> class ROperator_LSTM final : public ROperator {
0024  private:
0025    std::vector<float> fAttrActivationAlpha;   ///< Sacling values used by some activation functions
0026    std::vector<float> fAttrActivationBeta;    ///< Scaling values used by some activation functions
0027    std::vector<std::string> fAttrActivations; ///< Activation functions
0028    float fAttrClip;                           ///< Clip threshold
0029    std::string fAttrDirection;                ///< Direction of processing
0030    size_t fAttrHiddenSize;                    ///< Number of the hidden layers
0031    size_t fAttrInputForget;                   ///< Forget gate
0032    size_t fAttrLayout;                        ///< Data layout
0033 
0034    std::string fNX;                           ///< Name of the input
0035    std::string fNW;                           ///< Name of the weights
0036    std::string fNR;                           ///< Name of the recurrence
0037    std::string fNB;                           ///< Name of the bias
0038    std::string fNSequence_lens;               ///< Name of length of the sequences
0039    std::string fNInitial_h;                   ///< Name of the initial value of the hidden states
0040    std::string fNInitial_c;                   ///< Name of the initial value of the cell states
0041    std::string fNP;                           ///< Name of peepholes
0042    std::string fNY;                           ///< Name of the output
0043    std::string fNY_h;                         ///< Name of the last sequence of the output
0044    std::string fNY_c;                         ///< Name of the last sequence of the cell states
0045 
0046    std::vector<size_t> fShapeX;               ///< Shape of the input
0047    std::vector<size_t> fShapeW;               ///< Shape of the weights
0048    std::vector<size_t> fShapeR;               ///< Shape of the recurrence
0049    std::vector<size_t> fShapeB;               ///< Shape of the bias
0050    std::vector<size_t> fShapeSequence_lens;   ///< Shape of the length of the sequences
0051    std::vector<size_t> fShapeInitial_h;       ///< Shape of the initial value of the hidden states
0052    std::vector<size_t> fShapeInitial_c;       ///< Shape of the initial value of the cell states
0053    std::vector<size_t> fShapeP;               ///< Shape of the peepholes
0054    std::vector<size_t> fShapeY;               ///< Shape of the output
0055    std::vector<size_t> fShapeY_h;             ///< Shape of the last sequence of the output
0056    std::vector<size_t> fShapeY_c;             ///< Shape of the last sequence of the cell states
0057 
0058    std::string fType;                         ///< Type of the tensors
0059 
0060  public:
0061    /*! Default constructor of ROperator_LSTM */
0062    ROperator_LSTM() {}
0063 
0064    /*! \brief Constructor of ROperator_LSTM from the attributes
0065     *
0066     * \param activation_alpha scaling values used by some activation functions
0067     * \param activation_beta scaling values used by some activation functions
0068     * \param activations activation functions
0069     * \param clip clip threshold
0070     * \param direction direction of processing of the sequneces
0071     * \param hidden_size number of hidden layers
0072     * \param input_forget forget gate
0073     * \param layout data layout
0074     * \param nameX name of the input tensor
0075     * \param nameW name of the weight tensor
0076     * \param nameR name of the recurrence tensor
0077     * \param nameB name of the bias tensor
0078     * \param nameSequence_lens name of the length of the sequences
0079     * \param nameInitial_h name of the initial value of the hidden states
0080     * \param nameInitial_c name of the initial value of the cell states
0081     * \param nameP name of the peepholes tensor
0082     * \param nameY name of the output
0083     * \param nameY_h name of the last sequence of the output
0084     * \param nameY_c name of the last sequence of the cell states
0085     */
0086    ROperator_LSTM(std::vector<float> activation_alpha,
0087                  std::vector<float> activation_beta,
0088                  std::vector<std::string> activations, float clip,
0089                  std::string direction, size_t hidden_size,
0090                  size_t input_forget, size_t layout,
0091                  std::string nameX, std::string nameW, std::string nameR,
0092                  std::string nameB, std::string nameSequence_lens,
0093                  std::string nameInitial_h, std::string nameInitial_c, std::string nameP,
0094                  std::string nameY, std::string nameY_h, std::string nameY_c)
0095        : fAttrActivationAlpha(activation_alpha),
0096          fAttrActivationBeta(activation_beta), fAttrActivations(activations),
0097          fAttrClip(clip), fAttrDirection(direction), fAttrHiddenSize(hidden_size),
0098          fAttrInputForget(input_forget), fAttrLayout(layout),
0099          fNX(UTILITY::Clean_name(nameX)), fNW(UTILITY::Clean_name(nameW)),
0100          fNR(UTILITY::Clean_name(nameR)), fNB(UTILITY::Clean_name(nameB)),
0101          fNSequence_lens(UTILITY::Clean_name(nameSequence_lens)),
0102          fNInitial_h(UTILITY::Clean_name(nameInitial_h)),
0103          fNInitial_c(UTILITY::Clean_name(nameInitial_c)), fNP(UTILITY::Clean_name(nameP)),
0104          fNY(UTILITY::Clean_name(nameY)), fNY_h(UTILITY::Clean_name(nameY_h)),
0105          fNY_c(UTILITY::Clean_name(nameY_c)) {
0106       if (std::is_same<T, float>::value) {
0107          fType = "float";
0108       } else {
0109          throw std::runtime_error(
0110              "TMVA SOFIE Encountered unsupported type parsing a LSTM operator");
0111       }
0112    }
0113 
0114    /*! \brief Infers the type of the output tensors
0115     *
0116     * \param input type of the input tensors
0117     */
0118    std::vector<ETensorType> TypeInference(std::vector<ETensorType> input);
0119 
0120    /*! \brief Infers the shape of the output tensors
0121     *
0122     * \param input shape of the input tensors
0123     */
0124    std::vector<std::vector<size_t>>
0125    ShapeInference(std::vector<std::vector<size_t>> input);
0126 
0127    /*! \brief Initialize the model
0128     *
0129     * \param model Model
0130     */
0131    void Initialize(RModel &model);
0132 
0133    /*! \brief Generate the inference code
0134     *
0135     * \param OpName name of the operator
0136     */
0137    std::string Generate(std::string OpName);
0138 
0139    /*! \brief Generate the code for the Session internal data vectors
0140     *
0141     * \param opName name of the operator
0142     */
0143    std::string GenerateSessionMembersCode(std::string opName);
0144 
0145    /*! \brief Returns the blas routines needed to compile the generated code
0146     */
0147    std::vector<std::string> GetBlasRoutines() { return { std::string("Gemm"), std::string("Axpy") }; }
0148 };
0149 
0150 } // namespace SOFIE
0151 } // namespace Experimental
0152 } // namespace TMVA
0153 
0154 // Implementation of the ROperator_LSTM class
0155 #include "TMVA/ROperator_LSTM.icc"
0156 
0157 #endif