File indexing completed on 2025-09-18 09:32:34
0001 #ifndef TMVA_SOFIE_ROPERATOR_CONVTRANSPOSE_HXX
0002 #define TMVA_SOFIE_ROPERATOR_CONVTRANSPOSE_HXX
0003
#include <TMVA/SOFIE_common.hxx>
#include <TMVA/ROperator.hxx>
#include <TMVA/RModel.hxx>

#include <algorithm>
#include <cassert>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>
0014
0015 namespace TMVA {
0016 namespace Experimental {
0017 namespace SOFIE {
0018
0019
0020
0021
0022
0023
0024
0025 template <typename T>
0026 class ROperator_ConvTranspose final : public ROperator {
0027 private:
0028 std::string fAttrAutopad;
0029 std::vector<size_t> fAttrDilations;
0030 size_t fAttrGroup;
0031 std::vector<size_t> fAttrKernelShape;
0032 std::vector<size_t> fAttrOutputPadding;
0033 std::vector<size_t> fAttrOutputShape;
0034 std::vector<size_t> fAttrPads;
0035 std::vector<size_t> fAttrStrides;
0036
0037 std::string fNX;
0038 std::string fNW;
0039 std::string fNB;
0040 std::string fNBroadcastedB;
0041 std::string fNY;
0042
0043 std::string fConvK;
0044 std::string fImcol;
0045
0046 std::vector<size_t> fShapeX;
0047 std::vector<size_t> fShapeW;
0048 std::vector<size_t> fShapeB;
0049 std::vector<size_t> fShapeY;
0050
0051 std::string fType;
0052
0053 size_t fDim;
0054
0055 public:
0056
0057 ROperator_ConvTranspose() {}
0058
0059
0060
0061
0062
0063
0064
0065
0066
0067
0068
0069
0070
0071
0072
0073
0074 ROperator_ConvTranspose(std::string autopad, std::vector<size_t> dilations, size_t group,
0075 std::vector<size_t> kernelShape, std::vector<size_t> outputPadding,
0076 std::vector<size_t> outputShape, std::vector<size_t> pads, std::vector<size_t> strides,
0077 std::string nameX, std::string nameW, std::string nameB, std::string nameY)
0078 : fAttrAutopad(autopad), fAttrDilations(dilations), fAttrGroup(group), fAttrKernelShape(kernelShape),
0079 fAttrOutputPadding(outputPadding), fAttrOutputShape(outputShape), fAttrPads(pads), fAttrStrides(strides),
0080 fNX(UTILITY::Clean_name(nameX)), fNW(UTILITY::Clean_name(nameW)), fNB(UTILITY::Clean_name(nameB)),
0081 fNY(UTILITY::Clean_name(nameY))
0082 {
0083 fInputTensorNames = { fNX, fNW };
0084 fOutputTensorNames = { fNY };
0085 if (!fNB.empty()) {
0086 fInputTensorNames.emplace_back(fNB);
0087 }
0088
0089 if (std::is_same<T, float>::value) {
0090 fType = "float";
0091 } else {
0092 throw std::runtime_error("TMVA SOFIE Encountered unsupported type parsing a Conv operator");
0093 }
0094 }
0095
0096
0097
0098
0099 std::vector<ETensorType> TypeInference(std::vector<ETensorType> input) override
0100 {
0101 ETensorType out = input[0];
0102 return {out};
0103 }
0104
0105
0106
0107
0108 std::vector<std::vector<size_t>> ShapeInference(std::vector<std::vector<size_t>> ) override;
0109
0110
0111
0112
0113 void Initialize(RModel &) override;
0114
0115
0116
0117 std::string GenerateInitCode() override;
0118
0119
0120
0121
0122 std::string Generate(std::string opName) override;
0123
0124
0125
0126 std::vector<std::string> GetBlasRoutines() override { return { std::string("Gemm"), std::string("Axpy") }; }
0127 };
0128
0129 }
0130 }
0131 }
0132
0133
0134 #include "TMVA/ROperator_ConvTranspose.icc"
0135
0136 #endif