Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-01-18 10:11:05

0001 #ifndef TMVA_SOFIE_ROPERATOR_CONVTRANSPOSE_HXX
0002 #define TMVA_SOFIE_ROPERATOR_CONVTRANSPOSE_HXX
0003 
#include <TMVA/SOFIE_common.hxx>
#include <TMVA/ROperator.hxx>
#include <TMVA/RModel.hxx>

#include <algorithm>
#include <cassert>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>
0014 
0015 namespace TMVA {
0016 namespace Experimental {
0017 namespace SOFIE {
0018 
0019 /*! \brief Transposed Convolution operator
0020  *
0021  * Inference code generation for a transposed convolution layer.
0022  * See the <a href="https://github.com/onnx/onnx/blob/main/docs/Operators.md#convtranspose">ONNX documentation</a> for
0023  * details about the transposed conv layer.
0024  */
0025 template <typename T>
0026 class ROperator_ConvTranspose final : public ROperator {
0027 private:
0028    std::string fAttrAutopad;
0029    std::vector<size_t> fAttrDilations;
0030    size_t fAttrGroup;
0031    std::vector<size_t> fAttrKernelShape;
0032    std::vector<size_t> fAttrOutputPadding;
0033    std::vector<size_t> fAttrOutputShape;
0034    std::vector<size_t> fAttrPads;
0035    std::vector<size_t> fAttrStrides;
0036 
0037    std::string fNX;
0038    std::string fNW;
0039    std::string fNB;
0040    std::string fNBroadcastedB;
0041    std::string fNY;
0042 
0043    std::vector<size_t> fShapeX;
0044    std::vector<size_t> fShapeW;
0045    std::vector<size_t> fShapeB;
0046    std::vector<size_t> fShapeY;
0047 
0048    std::string fType;
0049 
0050    size_t fDim; // dimension of the convolution
0051 
0052 public:
0053    /*! Default constructor of ROperator_ConvTranspose */
0054    ROperator_ConvTranspose() {}
0055 
0056    /*! \brief Constructor of ROperator_ConvTranspose from the attributes
0057     *
0058     * \param autopad padding
0059     * \param dilations dilations of the kernel
0060     * \param group number of groups
0061     * \param kernelShape shape of the kernel
0062     * \param outputPadding padding of the output
0063     * \param outputShape shape of the output
0064     * \param pads padding of the input
0065     * \param strides strides
0066     * \param nameX name of the input
0067     * \param nameW name of the weight
0068     * \param nameB name of the bias
0069     * \param nameY name of the output
0070     */
0071    ROperator_ConvTranspose(std::string autopad, std::vector<size_t> dilations, size_t group,
0072                            std::vector<size_t> kernelShape, std::vector<size_t> outputPadding,
0073                            std::vector<size_t> outputShape, std::vector<size_t> pads, std::vector<size_t> strides,
0074                            std::string nameX, std::string nameW, std::string nameB, std::string nameY)
0075       : fAttrAutopad(autopad), fAttrDilations(dilations), fAttrGroup(group), fAttrKernelShape(kernelShape),
0076         fAttrOutputPadding(outputPadding), fAttrOutputShape(outputShape), fAttrPads(pads), fAttrStrides(strides),
0077         fNX(UTILITY::Clean_name(nameX)), fNW(UTILITY::Clean_name(nameW)), fNB(UTILITY::Clean_name(nameB)),
0078         fNY(UTILITY::Clean_name(nameY))
0079    {
0080       if (std::is_same<T, float>::value) {
0081          fType = "float";
0082       } else {
0083          throw std::runtime_error("TMVA SOFIE Encountered unsupported type parsing a Conv operator");
0084       }
0085    }
0086 
0087    /*! \brief Infers the type of the output tensor
0088     * \param input type of the input tensors
0089     */
0090    std::vector<ETensorType> TypeInference(std::vector<ETensorType> input) override
0091    {
0092       ETensorType out = input[0];
0093       return {out};
0094    }
0095 
0096    /*! \brief Infers the shape of the input tensors
0097     * \param input shape of the input tensors
0098     */
0099    std::vector<std::vector<size_t>> ShapeInference(std::vector<std::vector<size_t>> /*input*/) override;
0100 
0101    /*! \brief Initialize the model
0102     * \param model Model
0103     */
0104    void Initialize(RModel & /*model*/) override;
0105 
0106    /*! \brief Generate code for initializing the op
0107     */
0108    std::string GenerateInitCode() override;
0109 
0110    /*! \brief Generate code for Session data members (e.g. internal vectors)
0111     * \param opName name of the operator
0112     */
0113    std::string GenerateSessionMembersCode(std::string /*opName*/) override;
0114 
0115    /*! \brief Generate the inference code
0116     * \param opName name of the operator
0117     */
0118    std::string Generate(std::string opName) override;
0119 
0120    /*! \brief Returns the blas routines needed to compile the generated code
0121     */
0122    std::vector<std::string> GetBlasRoutines() override { return { std::string("Gemm"), std::string("Axpy") }; }
0123 };
0124 
0125 } // namespace SOFIE
0126 } // namespace Experimental
0127 } // namespace TMVA
0128 
0129 // Implementation of the ROperator_ConvTranspose class
0130 #include "TMVA/ROperator_ConvTranspose.icc"
0131 
0132 #endif