Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-09-18 09:32:34

0001 #ifndef TMVA_SOFIE_ROPERATOR_CONVTRANSPOSE_HXX
0002 #define TMVA_SOFIE_ROPERATOR_CONVTRANSPOSE_HXX
0003 
#include <TMVA/SOFIE_common.hxx>
#include <TMVA/ROperator.hxx>
#include <TMVA/RModel.hxx>

#include <algorithm>
#include <cassert>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>
0014 
0015 namespace TMVA {
0016 namespace Experimental {
0017 namespace SOFIE {
0018 
0019 /*! \brief Transposed Convolution operator
0020  *
0021  * Inference code generation for a transposed convolution layer.
0022  * See the <a href="https://github.com/onnx/onnx/blob/main/docs/Operators.md#convtranspose">ONNX documentation</a> for
0023  * details about the transposed conv layer.
0024  */
0025 template <typename T>
0026 class ROperator_ConvTranspose final : public ROperator {
0027 private:
0028    std::string fAttrAutopad;
0029    std::vector<size_t> fAttrDilations;
0030    size_t fAttrGroup;
0031    std::vector<size_t> fAttrKernelShape;
0032    std::vector<size_t> fAttrOutputPadding;
0033    std::vector<size_t> fAttrOutputShape;
0034    std::vector<size_t> fAttrPads;
0035    std::vector<size_t> fAttrStrides;
0036 
0037    std::string fNX;
0038    std::string fNW;
0039    std::string fNB;
0040    std::string fNBroadcastedB;
0041    std::string fNY;
0042 
0043    std::string fConvK;
0044    std::string fImcol;
0045 
0046    std::vector<size_t> fShapeX;
0047    std::vector<size_t> fShapeW;
0048    std::vector<size_t> fShapeB;
0049    std::vector<size_t> fShapeY;
0050 
0051    std::string fType;
0052 
0053    size_t fDim; // dimension of the convolution
0054 
0055 public:
0056    /*! Default constructor of ROperator_ConvTranspose */
0057    ROperator_ConvTranspose() {}
0058 
0059    /*! \brief Constructor of ROperator_ConvTranspose from the attributes
0060     *
0061     * \param autopad padding
0062     * \param dilations dilations of the kernel
0063     * \param group number of groups
0064     * \param kernelShape shape of the kernel
0065     * \param outputPadding padding of the output
0066     * \param outputShape shape of the output
0067     * \param pads padding of the input
0068     * \param strides strides
0069     * \param nameX name of the input
0070     * \param nameW name of the weight
0071     * \param nameB name of the bias
0072     * \param nameY name of the output
0073     */
0074    ROperator_ConvTranspose(std::string autopad, std::vector<size_t> dilations, size_t group,
0075                            std::vector<size_t> kernelShape, std::vector<size_t> outputPadding,
0076                            std::vector<size_t> outputShape, std::vector<size_t> pads, std::vector<size_t> strides,
0077                            std::string nameX, std::string nameW, std::string nameB, std::string nameY)
0078       : fAttrAutopad(autopad), fAttrDilations(dilations), fAttrGroup(group), fAttrKernelShape(kernelShape),
0079         fAttrOutputPadding(outputPadding), fAttrOutputShape(outputShape), fAttrPads(pads), fAttrStrides(strides),
0080         fNX(UTILITY::Clean_name(nameX)), fNW(UTILITY::Clean_name(nameW)), fNB(UTILITY::Clean_name(nameB)),
0081         fNY(UTILITY::Clean_name(nameY))
0082    {
0083       fInputTensorNames = { fNX, fNW };
0084       fOutputTensorNames = { fNY };
0085       if (!fNB.empty()) {
0086          fInputTensorNames.emplace_back(fNB);
0087       }
0088 
0089       if (std::is_same<T, float>::value) {
0090          fType = "float";
0091       } else {
0092          throw std::runtime_error("TMVA SOFIE Encountered unsupported type parsing a Conv operator");
0093       }
0094    }
0095 
0096    /*! \brief Infers the type of the output tensor
0097     * \param input type of the input tensors
0098     */
0099    std::vector<ETensorType> TypeInference(std::vector<ETensorType> input) override
0100    {
0101       ETensorType out = input[0];
0102       return {out};
0103    }
0104 
0105    /*! \brief Infers the shape of the input tensors
0106     * \param input shape of the input tensors
0107     */
0108    std::vector<std::vector<size_t>> ShapeInference(std::vector<std::vector<size_t>> /*input*/) override;
0109 
0110    /*! \brief Initialize the model
0111     * \param model Model
0112     */
0113    void Initialize(RModel &) override;
0114 
0115    /*! \brief Generate code for initializing the op
0116     */
0117    std::string GenerateInitCode() override;
0118 
0119    /*! \brief Generate the inference code
0120     * \param opName name of the operator
0121     */
0122    std::string Generate(std::string opName) override;
0123 
0124    /*! \brief Returns the blas routines needed to compile the generated code
0125     */
0126    std::vector<std::string> GetBlasRoutines() override { return { std::string("Gemm"), std::string("Axpy") }; }
0127 };
0128 
0129 } // namespace SOFIE
0130 } // namespace Experimental
0131 } // namespace TMVA
0132 
0133 // Implementation of the ROperator_ConvTranspose class
0134 #include "TMVA/ROperator_ConvTranspose.icc"
0135 
0136 #endif