File indexing completed on 2025-01-18 10:11:05
0001 #ifndef TMVA_SOFIE_ROPERATOR_CONVTRANSPOSE_HXX
0002 #define TMVA_SOFIE_ROPERATOR_CONVTRANSPOSE_HXX
0003
#include <TMVA/SOFIE_common.hxx>
#include <TMVA/ROperator.hxx>
#include <TMVA/RModel.hxx>

#include <algorithm>
#include <cassert>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>
0014
0015 namespace TMVA {
0016 namespace Experimental {
0017 namespace SOFIE {
0018
0019
0020
0021
0022
0023
0024
0025 template <typename T>
0026 class ROperator_ConvTranspose final : public ROperator {
0027 private:
0028 std::string fAttrAutopad;
0029 std::vector<size_t> fAttrDilations;
0030 size_t fAttrGroup;
0031 std::vector<size_t> fAttrKernelShape;
0032 std::vector<size_t> fAttrOutputPadding;
0033 std::vector<size_t> fAttrOutputShape;
0034 std::vector<size_t> fAttrPads;
0035 std::vector<size_t> fAttrStrides;
0036
0037 std::string fNX;
0038 std::string fNW;
0039 std::string fNB;
0040 std::string fNBroadcastedB;
0041 std::string fNY;
0042
0043 std::vector<size_t> fShapeX;
0044 std::vector<size_t> fShapeW;
0045 std::vector<size_t> fShapeB;
0046 std::vector<size_t> fShapeY;
0047
0048 std::string fType;
0049
0050 size_t fDim;
0051
0052 public:
0053
0054 ROperator_ConvTranspose() {}
0055
0056
0057
0058
0059
0060
0061
0062
0063
0064
0065
0066
0067
0068
0069
0070
0071 ROperator_ConvTranspose(std::string autopad, std::vector<size_t> dilations, size_t group,
0072 std::vector<size_t> kernelShape, std::vector<size_t> outputPadding,
0073 std::vector<size_t> outputShape, std::vector<size_t> pads, std::vector<size_t> strides,
0074 std::string nameX, std::string nameW, std::string nameB, std::string nameY)
0075 : fAttrAutopad(autopad), fAttrDilations(dilations), fAttrGroup(group), fAttrKernelShape(kernelShape),
0076 fAttrOutputPadding(outputPadding), fAttrOutputShape(outputShape), fAttrPads(pads), fAttrStrides(strides),
0077 fNX(UTILITY::Clean_name(nameX)), fNW(UTILITY::Clean_name(nameW)), fNB(UTILITY::Clean_name(nameB)),
0078 fNY(UTILITY::Clean_name(nameY))
0079 {
0080 if (std::is_same<T, float>::value) {
0081 fType = "float";
0082 } else {
0083 throw std::runtime_error("TMVA SOFIE Encountered unsupported type parsing a Conv operator");
0084 }
0085 }
0086
0087
0088
0089
0090 std::vector<ETensorType> TypeInference(std::vector<ETensorType> input) override
0091 {
0092 ETensorType out = input[0];
0093 return {out};
0094 }
0095
0096
0097
0098
0099 std::vector<std::vector<size_t>> ShapeInference(std::vector<std::vector<size_t>> ) override;
0100
0101
0102
0103
0104 void Initialize(RModel & ) override;
0105
0106
0107
0108 std::string GenerateInitCode() override;
0109
0110
0111
0112
0113 std::string GenerateSessionMembersCode(std::string ) override;
0114
0115
0116
0117
0118 std::string Generate(std::string opName) override;
0119
0120
0121
0122 std::vector<std::string> GetBlasRoutines() override { return { std::string("Gemm"), std::string("Axpy") }; }
0123 };
0124
0125 }
0126 }
0127 }
0128
0129
0130 #include "TMVA/ROperator_ConvTranspose.icc"
0131
0132 #endif