Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-12-06 10:01:47

0001 #ifndef TMVA_SOFIE_ROPERATOR_LSTM
0002 #define TMVA_SOFIE_ROPERATOR_LSTM
0003 
0004 #include "TMVA/RModel.hxx"
0005 #include "TMVA/ROperator.hxx"
0006 #include "TMVA/SOFIE_common.hxx"
0007 
#include <memory>
#include <sstream>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>
0012 
0013 namespace TMVA {
0014 namespace Experimental {
0015 namespace SOFIE {
0016 
0017 /*! \brief Long Short-Term Memory operator
0018  *
0019  * Inference code generation for one-layer LSTM. Supports forward, reverse and bidirectional LSTM.
0020  * See the <a href="https://github.com/onnx/onnx/blob/master/docs/Operators.md#LSTM">ONNX documentation</a>
0021  * for details about the supported LSTM architectures.
0022  */
0023 template <typename T> class ROperator_LSTM final : public ROperator {
0024  private:
0025    std::vector<float> fAttrActivationAlpha;   ///< Sacling values used by some activation functions
0026    std::vector<float> fAttrActivationBeta;    ///< Scaling values used by some activation functions
0027    std::vector<std::string> fAttrActivations; ///< Activation functions
0028    float fAttrClip;                           ///< Clip threshold
0029    std::string fAttrDirection;                ///< Direction of processing
0030    size_t fAttrHiddenSize;                    ///< Number of the hidden layers
0031    size_t fAttrInputForget;                   ///< Forget gate
0032    size_t fAttrLayout;                        ///< Data layout
0033 
0034    std::string fNX;                           ///< Name of the input
0035    std::string fNW;                           ///< Name of the weights
0036    std::string fNR;                           ///< Name of the recurrence
0037    std::string fNB;                           ///< Name of the bias
0038    std::string fNSequence_lens;               ///< Name of length of the sequences
0039    std::string fNInitial_h;                   ///< Name of the initial value of the hidden states
0040    std::string fNInitial_c;                   ///< Name of the initial value of the cell states
0041    std::string fNP;                           ///< Name of peepholes
0042    std::string fNY;                           ///< Name of the output
0043    std::string fNY_h;                         ///< Name of the last sequence of the output
0044    std::string fNY_c;                         ///< Name of the last sequence of the cell states
0045 
0046    std::vector<size_t> fShapeX;               ///< Shape of the input
0047    std::vector<size_t> fShapeW;               ///< Shape of the weights
0048    std::vector<size_t> fShapeR;               ///< Shape of the recurrence
0049    std::vector<size_t> fShapeB;               ///< Shape of the bias
0050    std::vector<size_t> fShapeSequence_lens;   ///< Shape of the length of the sequences
0051    std::vector<size_t> fShapeInitial_h;       ///< Shape of the initial value of the hidden states
0052    std::vector<size_t> fShapeInitial_c;       ///< Shape of the initial value of the cell states
0053    std::vector<size_t> fShapeP;               ///< Shape of the peepholes
0054    std::vector<size_t> fShapeY;               ///< Shape of the output
0055    std::vector<size_t> fShapeY_h;             ///< Shape of the last sequence of the output
0056    std::vector<size_t> fShapeY_c;             ///< Shape of the last sequence of the cell states
0057 
0058    std::string fType;                         ///< Type of the tensors
0059 
0060  public:
0061    /*! Default constructor of ROperator_LSTM */
0062    ROperator_LSTM() {}
0063 
0064    /*! \brief Constructor of ROperator_LSTM from the attributes
0065     *
0066     * \param activation_alpha scaling values used by some activation functions
0067     * \param activation_beta scaling values used by some activation functions
0068     * \param activations activation functions
0069     * \param clip clip threshold
0070     * \param direction direction of processing of the sequneces
0071     * \param hidden_size number of hidden layers
0072     * \param input_forget forget gate
0073     * \param layout data layout
0074     * \param nameX name of the input tensor
0075     * \param nameW name of the weight tensor
0076     * \param nameR name of the recurrence tensor
0077     * \param nameB name of the bias tensor
0078     * \param nameSequence_lens name of the length of the sequences
0079     * \param nameInitial_h name of the initial value of the hidden states
0080     * \param nameInitial_c name of the initial value of the cell states
0081     * \param nameP name of the peepholes tensor
0082     * \param nameY name of the output
0083     * \param nameY_h name of the last sequence of the output
0084     * \param nameY_c name of the last sequence of the cell states
0085     */
0086    ROperator_LSTM(std::vector<float> activation_alpha,
0087                  std::vector<float> activation_beta,
0088                  std::vector<std::string> activations, float clip,
0089                  std::string direction, size_t hidden_size,
0090                  size_t input_forget, size_t layout,
0091                  std::string nameX, std::string nameW, std::string nameR,
0092                  std::string nameB, std::string nameSequence_lens,
0093                  std::string nameInitial_h, std::string nameInitial_c, std::string nameP,
0094                  std::string nameY, std::string nameY_h, std::string nameY_c)
0095        : fAttrActivationAlpha(activation_alpha),
0096          fAttrActivationBeta(activation_beta), fAttrActivations(activations),
0097          fAttrClip(clip), fAttrDirection(direction), fAttrHiddenSize(hidden_size),
0098          fAttrInputForget(input_forget), fAttrLayout(layout),
0099          fNX(UTILITY::Clean_name(nameX)), fNW(UTILITY::Clean_name(nameW)),
0100          fNR(UTILITY::Clean_name(nameR)), fNB(UTILITY::Clean_name(nameB)),
0101          fNSequence_lens(UTILITY::Clean_name(nameSequence_lens)),
0102          fNInitial_h(UTILITY::Clean_name(nameInitial_h)),
0103          fNInitial_c(UTILITY::Clean_name(nameInitial_c)), fNP(UTILITY::Clean_name(nameP)),
0104          fNY(UTILITY::Clean_name(nameY)), fNY_h(UTILITY::Clean_name(nameY_h)),
0105          fNY_c(UTILITY::Clean_name(nameY_c)) {
0106       if (std::is_same<T, float>::value) {
0107          fType = "float";
0108       } else {
0109          throw std::runtime_error(
0110              "TMVA SOFIE Encountered unsupported type parsing a LSTM operator");
0111       }
0112       
0113       fInputTensorNames = { fNX, fNW, fNR };
0114       if (!fNB.empty()){
0115          fInputTensorNames.emplace_back(fNB);
0116       }
0117       if (!fNSequence_lens.empty()){
0118          fInputTensorNames.emplace_back(fNSequence_lens);
0119       }
0120       if (!fNInitial_h.empty()){
0121          fInputTensorNames.emplace_back(fNInitial_h);
0122       }
0123       if (!fNInitial_c.empty()){
0124          fInputTensorNames.emplace_back(fNInitial_c);
0125       }
0126       if (!fNP.empty()){
0127          fInputTensorNames.emplace_back(fNP);
0128       }
0129 
0130       fOutputTensorNames = { };
0131       if (!fNY.empty()){
0132          fOutputTensorNames.emplace_back(fNY);
0133       }
0134       if (!fNY_h.empty()){
0135          fOutputTensorNames.emplace_back(fNY_h);
0136       }
0137       if (!fNY_c.empty()){
0138          fOutputTensorNames.emplace_back(fNY_c);
0139       }
0140    }
0141 
0142    /*! \brief Infers the type of the output tensors
0143     *
0144     * \param input type of the input tensors
0145     */
0146    std::vector<ETensorType> TypeInference(std::vector<ETensorType> input);
0147 
0148    /*! \brief Infers the shape of the output tensors
0149     *
0150     * \param input shape of the input tensors
0151     */
0152    std::vector<std::vector<size_t>>
0153    ShapeInference(std::vector<std::vector<size_t>> input);
0154 
0155    /*! \brief Initialize the model
0156     *
0157     * \param model Model
0158     */
0159    void Initialize(RModel &);
0160 
0161    /*! \brief Generate the inference code
0162     *
0163     * \param OpName name of the operator
0164     */
0165    std::string Generate(std::string OpName);
0166 
0167    /*! \brief Generate the code for the Session internal data vectors
0168     *
0169     * \param opName name of the operator
0170     */
0171    std::string GenerateSessionMembersCode(std::string opName);
0172 
0173    /*! \brief Returns the blas routines needed to compile the generated code
0174     */
0175    std::vector<std::string> GetBlasRoutines() { return { std::string("Gemm"), std::string("Axpy") }; }
0176 };
0177 
0178 } // namespace SOFIE
0179 } // namespace Experimental
0180 } // namespace TMVA
0181 
0182 // Implementation of the ROperator_LSTM class
0183 #include "TMVA/ROperator_LSTM.icc"
0184 
0185 #endif