Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-01-18 10:11:05

0001 #ifndef TMVA_SOFIE_ROPERATOR_CONV
0002 #define TMVA_SOFIE_ROPERATOR_CONV
0003 
#include "TMVA/SOFIE_common.hxx"
#include "TMVA/ROperator.hxx"
#include "TMVA/RModel.hxx"

#include <algorithm>
#include <cassert>
#include <iostream>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <vector>
0014 
0015 namespace TMVA {
0016 namespace Experimental {
0017 namespace SOFIE {
0018 
0019 template<typename T>
0020 class ROperator_Conv final : public ROperator
0021 {
0022 private:
0023    std::string fAttrAutopad;
0024    std::vector<size_t> fAttrDilations;
0025    size_t fAttrGroup;
0026    std::vector<size_t> fAttrKernelShape;
0027    std::vector<size_t> fAttrPads;
0028    std::vector<size_t> fAttrStrides;
0029 
0030    std::string fNX;
0031    std::string fNW;
0032    std::string fNB;
0033    std::string fNB2; // bias tensor name after broadcasting
0034    std::string fNY;
0035 
0036    std::vector<size_t> fShapeX;
0037    std::vector<size_t> fShapeW;
0038    std::vector<size_t> fShapeB;
0039    std::vector<size_t> fShapeY;
0040 
0041    std::string fType;
0042 
0043    size_t fDim;   // dimension of the convolution
0044 
0045 
0046 public:
0047 
0048    ROperator_Conv() {}
0049 
0050    ROperator_Conv(std::string autopad, std::vector<size_t> dilations,
0051       size_t group, std::vector<size_t> kernelShape, std::vector<size_t> pads,
0052       std::vector<size_t> strides, std::string nameX, std::string nameW,
0053       std::string nameB, std::string nameY):
0054       fAttrAutopad(autopad), fAttrDilations(dilations), fAttrGroup(group), fAttrKernelShape(kernelShape),
0055       fAttrPads(pads), fAttrStrides(strides),
0056       fNX(UTILITY::Clean_name(nameX)), fNW(UTILITY::Clean_name(nameW)),
0057       fNB(UTILITY::Clean_name(nameB)), fNY(UTILITY::Clean_name(nameY))
0058    {
0059       if(std::is_same<T, float>::value) {
0060          fType = "float";
0061       } else {
0062          throw
0063             std::runtime_error("TMVA SOFIE Encountered unsupported type parsing a Conv operator");
0064       }
0065    }
0066 
0067    ROperator_Conv(std::string autopad, std::vector<size_t> dilations,
0068       size_t group, std::vector<size_t> kernelShape, std::vector<size_t> pads,
0069       std::vector<size_t> strides, std::string nameX, std::string nameW,
0070       std::string nameY):
0071       fAttrAutopad(autopad), fAttrDilations(dilations), fAttrGroup(group), fAttrKernelShape(kernelShape),
0072       fAttrPads(pads), fAttrStrides(strides),
0073       fNX(UTILITY::Clean_name(nameX)), fNW(UTILITY::Clean_name(nameW)), fNY(UTILITY::Clean_name(nameY))
0074    {
0075       if(std::is_same<T, float>::value) {
0076          fType = "float";
0077       } else {
0078          throw
0079             std::runtime_error("TMVA SOFIE Encountered unsupported type parsing a Conv operator");
0080       }
0081    }
0082 
0083    std::vector<ETensorType> TypeInference(std::vector<ETensorType> input) {
0084       ETensorType out = input[0];
0085       return {out};
0086    }
0087 
0088    // function returning output shape given input
0089    std::vector<std::vector<size_t>> ShapeInference(std::vector<std::vector<size_t>> input) {
0090       // shape of convolution input has to be (according to ONNX): N x C x H x W
0091       // Where N : batch size, C : input  channels, H : input height, W : input width
0092 
0093       if (input.size() > 3 ) {
0094          throw
0095             std::runtime_error("TMVA SOFIE Conv Op Shape inference need 2 or 3 input tensors");
0096       }
0097       for(size_t i = 0; i < input.size(); i++) {
0098          if (input[i].size() -2 != fDim) {
0099             throw
0100                std::runtime_error("TMVA SOFIE Conv Op Shape inference - invalid inputs ");
0101          }
0102       }
0103 
0104       if (fAttrGroup == 0) {
0105          fAttrGroup = input[0][1] / input[1][1];
0106       }
0107 
0108       // kernel shape
0109       size_t k1 = ((fAttrKernelShape.empty())? input[1][2] : fAttrKernelShape[0]);
0110       size_t k2 = (fDim > 1) ? ((fAttrKernelShape.empty()) ? input[1][3] : fAttrKernelShape[1]) : 1;
0111       size_t k3 = (fDim > 2) ? ((fAttrKernelShape.empty()) ? input[1][4] : fAttrKernelShape[2]) : 1;
0112 
0113 
0114       size_t i1 = (fDim > 1) ? ((fDim > 2) ? 3 : 2) : 1;
0115       size_t i2 = (fDim > 2) ? 4 : 3;
0116       size_t i3 = 5;
0117 
0118       if (fAttrDilations.empty()) {
0119          fAttrDilations = {1, 1, 1};
0120       }
0121       fAttrDilations.resize(3);
0122       if (fDim < 3) {
0123          fAttrDilations.resize(3, 1);
0124       }
0125       // Shape of the kernel
0126       fAttrKernelShape = {k1 + (fAttrDilations[0] - 1) * (k1 - 1),
0127                           k2 + (fAttrDilations[1] - 1) * (k2 - 1),
0128                           k3 + (fAttrDilations[2] - 1) * (k3 - 1)};
0129 
0130       if (fAttrAutopad == "NOTSET") {
0131          if (fAttrPads.empty()) {
0132             fAttrPads = {1, 1, 1, 1, 1, 1};
0133          }
0134       } else if (fAttrAutopad == "SAME_UPPER" || fAttrAutopad == "SAME_LOWER") {
0135          if (fDim == 1)
0136             fAttrPads = {fAttrKernelShape[0] / 2, fAttrKernelShape[0] / 2};
0137          else if (fDim == 2)
0138             fAttrPads = {fAttrKernelShape[0] / 2, fAttrKernelShape[1] / 2, fAttrKernelShape[0] / 2, fAttrKernelShape[1] / 2};
0139          else if (fDim == 3)
0140             fAttrPads = {fAttrKernelShape[0] / 2, fAttrKernelShape[1] / 2, fAttrKernelShape[2] / 2,
0141                          fAttrKernelShape[0] / 2, fAttrKernelShape[1] / 2, fAttrKernelShape[2] / 2};
0142          // add extra padding at beginning or end (depending if SAME_UPPER or SAME_LOWER)
0143          // need to check this!
0144          if (fAttrKernelShape[0] % 2 == 1) {
0145             (fAttrAutopad == "SAME_UPPER") ? fAttrPads[0]++ : fAttrPads[i1]++;
0146          }
0147          if (fDim > 1 && fAttrKernelShape[1] % 2 == 1) {
0148             (fAttrAutopad == "SAME_UPPER") ? fAttrPads[1]++ : fAttrPads[i2]++;
0149          }
0150          if (fDim > 2 && fAttrKernelShape[2] % 2 == 1) {
0151             (fAttrAutopad == "SAME_UPPER") ? fAttrPads[2]++ : fAttrPads[i3]++;
0152          }
0153       } else if (fAttrAutopad != "VALID") {
0154          throw
0155             std::runtime_error("TMVA SOFIE Conv Op invalid fAutopad");
0156       }
0157       // to be sure pad is vector of size 6
0158       if (fDim < 3) fAttrPads.resize(6, 0);
0159 
0160       if (fAttrStrides.empty()) {
0161          fAttrStrides = {1, 1, 1};
0162       }
0163       if (fDim < 3)
0164          fAttrStrides.resize(3, 1);
0165 
0166 
0167       size_t input1 = input[0][2];
0168       size_t input2 = (fDim > 1) ? input[0][3] : 1;
0169       size_t input3 = (fDim > 2) ? input[0][4] : 1;
0170 
0171       size_t pad1 = fAttrPads[0] + fAttrPads[i1];
0172       size_t output1 = (input1 + pad1 - fAttrKernelShape[0]) / fAttrStrides[0] + 1;
0173 
0174       size_t batch_size = input[0][0];        // first element in input tensor
0175       size_t output_channels = input[1][0];   // first element in weight tensor
0176 
0177       std::vector<std::vector<size_t>> ret({{ batch_size, output_channels, output1 }});
0178 
0179       if (fDim == 1)
0180          return ret;
0181 
0182       size_t pad2 = fAttrPads[1] + fAttrPads[i2];
0183       size_t output2 = (input2 + pad2 - fAttrKernelShape[1]) / fAttrStrides[1] + 1;
0184       // output is N x M x OH x OW
0185       ret[0].push_back(output2);
0186       if (fDim == 2)
0187          return ret;
0188 
0189       size_t pad3 = fAttrPads[2] + fAttrPads[i3];
0190       size_t output3 = (input3 + pad3 - fAttrKernelShape[2] ) / fAttrStrides[2] + 1;
0191 
0192       // output is N x M x OH x OW x OD
0193       ret[0].push_back(output3);
0194       return ret;
0195    }
0196 
0197    void Initialize(RModel& model) {
0198       fUseSession = model.UseSession();
0199       if (!model.CheckIfTensorAlreadyExist(fNX)) {
0200          throw
0201             std::runtime_error("TMVA SOFIE Conv op Input Tensor " + fNX + " is not found in model");
0202       }
0203       fShapeX = model.GetTensorShape(fNX);
0204       if (fShapeX.size() < 3 || fShapeX.size()  > 5) {
0205          std::cout << fNX << " : " << ConvertShapeToString(fShapeX) << std::endl;
0206          throw
0207             std::runtime_error("TMVA SOFIE Conv Op input data tensor" + fNX + " is not of 3,4 or 5 dimensions");
0208       }
0209       fDim = fShapeX.size() - 2;
0210       if (!model.CheckIfTensorAlreadyExist(fNW)) {
0211          throw
0212             std::runtime_error("TMVA SOFIE Conv op Input weight Tensor " + fNW + " is not found in model");
0213       }
0214       fShapeW = model.GetTensorShape(fNW);
0215       if (fShapeW.size() < 3 || fShapeW.size()  > 5) {
0216          std::cout << fNW << " : " << ConvertShapeToString(fShapeW) << std::endl;
0217          throw std::runtime_error("TMVA SOFIE Conv Op input weight tensor" + fNW + " is not of 3,4 or 5 dimensions");
0218       }
0219       fShapeY = ShapeInference({fShapeX, fShapeW})[0];
0220       model.AddIntermediateTensor(fNY, model.GetTensorType(fNX), fShapeY);
0221       if (fNB != "") {
0222          if (!model.CheckIfTensorAlreadyExist(fNB)) {
0223             throw
0224                std::runtime_error("TMVA SOFIE Conv op Input Tensor " + fNB + " is not found in model");
0225          }
0226          fShapeB = model.GetTensorShape(fNB);
0227          std::vector<size_t> targetShape(fShapeY.begin() + 1, fShapeY.end());
0228          bool broadcast_needed = !UTILITY::AreSameShape(fShapeB, targetShape);
0229          if (broadcast_needed) {
0230             auto original_data = model.GetInitializedTensorData(fNB);
0231             // make bias shape equal to Y shape by adding 1
0232             if (fShapeB.size() < 1)
0233                throw std::runtime_error("TMVA SOFIE Conv op: Bias Tensor has empty shape");
0234             // we assume bias tensor dimension is equal to number of filters that is the second dimension in
0235             // the output tensor
0236             if (fShapeB[0] != fShapeY[1])
0237                throw std::runtime_error("TMVA SOFIE Conv op: Bias Tensor has wrong shape: " +
0238                                            ConvertShapeToString(fShapeB));
0239             if (fType != "float")
0240                throw std::runtime_error("TMVA SOFIE Conv op: Broadcasting for non-float type tensors is not supported");
0241             // here is the actual broadcasting
0242             if (!fUseSession) {
0243                std::vector<size_t> shape(fDim + 1, 1);
0244                shape[0] = fShapeB[0];
0245                std::shared_ptr<void> new_data_ptr(
0246                   UTILITY::UnidirectionalBroadcast<float>(static_cast<float *>(original_data.get()), shape, targetShape),
0247                   std::default_delete<float[]>());
0248                model.UpdateInitializedTensor(fNB, model.GetTensorType(fNB), targetShape, new_data_ptr);
0249                fShapeB = model.GetTensorShape(fNB);
0250                fNB2 = fNB;   // use same name
0251             }
0252             else {
0253                // In case of session add broadcasting code in Session constructor and in GenerateInitCode
0254                // we need to add a new intermediate tensor for broadcasted bias tensor
0255                fNB2 = fNB + "bcast";
0256                model.AddIntermediateTensor(fNB2, model.GetTensorType(fNB), targetShape);
0257             }
0258          }
0259       }
0260    }
0261 
0262    std::string GenerateInitCode() {
0263       std::stringstream out;
0264       // Generate initialization code for broadcasting of bias tensor
0265       if (!fNB2.empty()) {
0266          // include a separate scope to avoid defining unique operator temp variables
0267          std::vector<size_t> shape(fDim + 1, 1);
0268          shape[0] = fShapeB[0];
0269          std::vector<size_t> targetShape(fShapeY.begin() + 1, fShapeY.end());
0270          out << SP << "{\n";
0271          out << SP << SP << "float * data = TMVA::Experimental::SOFIE::UTILITY::UnidirectionalBroadcast<float>(tensor_"
0272              << fNB << ", " << ConvertShapeToString(shape) << ", " << ConvertShapeToString(fShapeY) << ");\n";
0273          out << SP << SP << "std::copy(data, data + " << ConvertShapeToLength(targetShape) << ", tensor_" << fNB2 << ");\n";
0274          out << SP << SP << "delete[] data;\n";
0275          out << SP << "}\n";
0276       }
0277       return out.str();
0278    }
0279 
0280    // Generate code for Session data members (e.g. internal vectors)
0281    virtual std::string GenerateSessionMembersCode(std::string opName) {
0282 
0283       size_t outputChannelSize = fShapeY[2];  // size/channel = D * H * W
0284       size_t kernelSize = fAttrKernelShape[0];
0285       for (size_t i = 1; i < fDim; i++) {
0286          outputChannelSize *= fShapeY[2 + i];
0287          kernelSize *= fAttrKernelShape[i];
0288       }
0289 
0290       opName = "op_" + opName;
0291       std::stringstream out;
0292       // matrix with convolution kernels
0293       out << "std::vector<" << fType << "> fVec_" << opName << "_f = std::vector<" << fType << ">("
0294           << fShapeW[0] * fShapeW[1] * kernelSize << ");\n";
0295       // output matrix of im2col
0296       out << "std::vector<" << fType << "> fVec_" << opName << "_xcol = std::vector<" << fType << ">("
0297           << fShapeW[1] * kernelSize * outputChannelSize << ");\n";
0298       out << "\n";
0299 
0300       return out.str();
0301    }
0302 
0303    std::string Generate(std::string OpName) {
0304       OpName = "op_" + OpName;
0305 
0306       if (fShapeX.empty() || fShapeW.empty() || (fNB != "" && fShapeB.empty()) || fShapeY.empty()) {
0307          throw
0308             std::runtime_error("TMVA SOFIE Conv Op called to Generate without being initialized first");
0309       }
0310 
0311       std::stringstream out;
0312       size_t bsize = fShapeX[0];
0313       size_t kDepth = (fDim > 2) ?  fShapeW[2] : 1;  // kernel depth
0314       size_t kHeight = (fDim > 1) ? fShapeW[fDim] : 1;  // kernel height
0315       size_t kWidth = fShapeW[fDim+1]; // kernel width
0316       size_t iDepth = (fDim > 2) ?  fShapeX[2] : 1;  // input depth
0317       size_t iHeight = (fDim > 1) ? fShapeX[fDim] : 1; // input height
0318       size_t iWidth = fShapeX[fDim+1]; // input width
0319       size_t oDepth = (fDim > 2) ? fShapeY[2] : 1; // output depth
0320       size_t oHeight = (fDim > 1) ? fShapeY[fDim] : 1;  // ouput height
0321       size_t oWidth = fShapeY[fDim+1]; // output width
0322 
0323       out << "\n//----  operator Conv " << OpName << "\n";
0324 
0325       // create first matrix with convolution kernels
0326       if (fUseSession)
0327          out << SP << fType << " * " << OpName << "_f = fVec_" << OpName << "_f.data();\n";
0328       else
0329          out << SP << fType << " " << OpName << "_f[" << fShapeW[0] * fShapeW[1] * fAttrKernelShape[0] * fAttrKernelShape[1] << "] = {0};\n";
0330 
0331       // vectorize the (dilated)convolution kernels into a matrix
0332       // no need to transpose the matrix
0333       // to fix for 1d and 3d
0334 
0335       size_t id = (fDim > 2) ? fDim-3 : 2;
0336       size_t ih = (fDim > 1) ? fDim-2 : 1;
0337       size_t iw = fDim-1;
0338 
0339       size_t wstrideDil = fAttrDilations[iw];
0340       size_t hstride = kWidth;
0341       size_t hstrideDil = fAttrDilations[ih] * fAttrKernelShape[iw];  // stride dilated in the height
0342       size_t dstride = kHeight * kWidth;
0343       size_t dstrideDil = fAttrDilations[id] * fAttrKernelShape[ih] * fAttrKernelShape[iw];
0344       size_t icstride = kHeight * kWidth * kDepth;
0345       size_t icstrideDil = fAttrKernelShape[id] * fAttrKernelShape[ih] * fAttrKernelShape[iw];
0346       size_t ocstride = fShapeW[1] * icstride;
0347       size_t ocstrideDil = fShapeW[1] * icstrideDil;
0348 
0349       out << SP << "for (std::size_t oc = 0; oc < " << fShapeW[0] << "; oc++) {\n";
0350       out << SP << SP << "for (std::size_t ic = 0; ic < " << fShapeW[1] << "; ic++) {\n";
0351       if (fDim > 2)
0352          out << SP << SP << SP << "for (std::size_t kd = 0; kd < " << kDepth << "; kd++) {\n";
0353       if (fDim > 1)
0354          out << SP << SP << SP << "for (std::size_t kh = 0; kh < " << kHeight << "; kh++) {\n";
0355       out << SP << SP << SP << SP << "for (std::size_t kw = 0; kw < " << kWidth << "; kw++) {\n";
0356 
0357       out << SP << SP << SP << SP << SP << OpName <<  "_f[oc * "
0358           << ocstrideDil << " + ic * " << icstrideDil;
0359       if (fDim > 2) out << " + kd * " << dstrideDil;
0360       if (fDim > 1) out << " + kh * " << hstrideDil;
0361       out << " + kw * " << wstrideDil  << "  ] = tensor_" << fNW << "[oc * " << ocstride << " + ic * " << icstride;
0362       if (fDim > 2) out << " + kd * " << dstride;
0363       if (fDim > 1) out << " + kh * " << hstride;
0364       out  << " + kw ];\n";
0365 
0366       out << SP << SP << SP << SP << "}\n";
0367       if (fDim > 1) out << SP << SP << SP << "}\n";
0368       if (fDim > 2) out << SP << SP << SP << "}\n";
0369       out << SP << SP << "}\n";
0370       out << SP << "}\n";
0371 
0372       //out << SP << "char " << OpName << "_transA = 'T';\n";
0373       out << SP << "char " << OpName << "_transA = 'N';\n";
0374       out << SP << "char " << OpName << "_transB = 'N';\n";
0375       out << SP << "int " << OpName << "_m = " << oHeight * oWidth * oDepth << ";\n"; // output h*w
0376       assert(fShapeY[1] == fShapeW[0]);
0377       assert(fShapeW[1] == fShapeX[1] / fAttrGroup);
0378       out << SP << "int " << OpName << "_n = " << fShapeW[0] << ";\n"; // output channels
0379       out << SP << "int " << OpName << "_k = " << fShapeW[1] * fAttrKernelShape[0] * fAttrKernelShape[1] * fAttrKernelShape[2] << ";\n";
0380       out << SP << "float " << OpName << "_alpha = 1.0;\n";
0381       out << SP << "float " << OpName << "_beta = 0.0;\n";
0382 
0383       if (fUseSession) {
0384          out << SP << fType << " * " << OpName << "_xcol = fVec_" << OpName << "_xcol.data();\n";
0385       }
0386       else {
0387          out << SP << fType << " " << OpName << "_xcol["
0388              << fShapeX[1] * fAttrKernelShape[0] * fAttrKernelShape[1] * fAttrKernelShape[2] * oDepth * oHeight * oWidth
0389              << "] = {0};\n";
0390       }
0391 
0392       // Loop on batch size
0393       out << SP << "for (size_t n = 0; n < " << bsize << "; n++) {\n";
0394 
0395       // IM2COL: Unroll the input tensor
0396       // order input data as  (e.g. kernel 2x2)  and (xa,ya) is channel 1 and (xb,yb) is channel 2
0397       //   (xa1,..,xak,ya1,..yak)(xb1,...,xbk,yb1,..,ybk)
0398       //   (xa2,...xak+1,ya1,...yak)(......)
0399       // trick for speed is using caffe im2col and output a matrix which contains filtered values as rows.
0400       // By doing this one has consecutive memory reads and writes
0401       // Resulting matrix op_xcol is (input channels * filter_h * filter_w , output_h * output_w)
0402       if (fDim ==1) {
0403          if (fAttrPads[0] != fAttrPads[1] ) {
0404             std::cout << "TMVA SOFIE Operator Conv:  asymmetric padding not supported. Assume an average padding "
0405                       << std::endl;
0406             fAttrPads[0] = (fAttrPads[0] + fAttrPads[1]) / 2;
0407          }
0408          fAttrPads[1] = 0;
0409          fAttrStrides[1] = 1;
0410       }
0411       if (fDim == 2) {
0412          if (fAttrPads[0] != fAttrPads[2] || fAttrPads[1] != fAttrPads[3]) {
0413             std::cout << "TMVA SOFIE Operator Conv:  asymmetric padding not supported. Assume an average padding " << std::endl;
0414             fAttrPads[0] = (fAttrPads[0] + fAttrPads[2]) / 2;
0415             fAttrPads[1] = (fAttrPads[1] + fAttrPads[3]) / 2;
0416          }
0417       }
0418       if (fDim == 3) {
0419          if (fAttrPads[0] != fAttrPads[3] || fAttrPads[1] != fAttrPads[4] || fAttrPads[2] != fAttrPads[5]) {
0420             std::cout << "TMVA SOFIE Operator Conv:  asymmetric padding not supported. Assume an average padding " << std::endl;
0421             fAttrPads[0] = (fAttrPads[0] + fAttrPads[3]) / 2;
0422             fAttrPads[1] = (fAttrPads[1] + fAttrPads[4]) / 2;
0423             fAttrPads[2] = (fAttrPads[2] + fAttrPads[5]) / 2;
0424          }
0425       }
0426       out << SP << SP << "size_t out_offset = n * " << fShapeY[1] * oDepth * oHeight * oWidth << ";\n";
0427 
0428       if (fAttrGroup == 1) {
0429          out << SP << SP << "size_t x_offset = n * " << fShapeX[1] * iHeight * iWidth << ";\n";
0430          // when using im2col - resulting matrix is transposed, the dimension is (input_c * filter_h * filter_y,  output_h *
0431          // output_w)
0432          if (fDim < 3) {
0433             out << SP << SP << "TMVA::Experimental::SOFIE::UTILITY::Im2col<float>(tensor_" << fNX
0434                 << " + x_offset,"
0435                 //  channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h,
0436                 //  dilation_w,
0437                 //
0438                 << fShapeW[1] << "," << iHeight << "," << iWidth << ",";
0439             if (fDim == 1)
0440                out << "1, " << fAttrKernelShape[0] << ",0," << fAttrPads[0] << ",1," << fAttrStrides[0] << ",1,"
0441                    << fAttrDilations[0];
0442             else // dim ==2
0443                out << fAttrKernelShape[0] << "," << fAttrKernelShape[1] << "," << fAttrPads[0] << "," << fAttrPads[1]
0444                    << "," << fAttrStrides[0] << "," << fAttrStrides[1] << "," << fAttrDilations[0] << ","
0445                    << fAttrDilations[1];
0446             out << "," << OpName << "_xcol);\n\n ";
0447          } else {
0448             // 3d im2col
0449             out << SP << SP << "TMVA::Experimental::SOFIE::UTILITY::Im2col_3d<float>(tensor_" << fNX
0450                 << " + x_offset,"
0451                 //  channels, d, h, w, k_d, k_h, k_w, pad_d, pad_h, pad_w, stride_d, stride_h, stride_w,
0452                 //  dilation_d, dilation_h, dilation_w,
0453                 //
0454                 << fShapeW[1] << "," << iDepth << "," << iHeight << "," << iWidth << ","
0455                 << fAttrKernelShape[0] << "," << fAttrKernelShape[1] << "," << fAttrKernelShape[2] << ","
0456                 << fAttrPads[0] << "," << fAttrPads[1] << "," << fAttrPads[2] << ","
0457                 << fAttrStrides[0] << "," << fAttrStrides[1] << "," << fAttrStrides[2] << ","
0458                 << fAttrDilations[0] << "," << fAttrDilations[1] << "," << fAttrDilations[2] << ","
0459                 << OpName << "_xcol);\n\n ";
0460          }
0461          // BLAS
0462          out << SP << SP << "BLAS::sgemm_(&" << OpName << "_transA, &" << OpName << "_transB, &" << OpName << "_m, &"
0463              << OpName << "_n, &" << OpName << "_k, &" << OpName << "_alpha, " << OpName << "_xcol, &" << OpName
0464              << "_m,\n"; // use m if op_xcol is not transpose , otherwise k
0465          out << SP << SP << SP << OpName << "_f, &" << OpName << "_k, &" << OpName << "_beta, tensor_" << fNY
0466              << " + out_offset, &" << OpName << "_m);\n";
0467       } else {
0468          // case of group convolution
0469          // Unroll (IM2COL) the input tensor- make loop on groups and repeat operations (IM2COL + GEMM for each
0470          // group)
0471          // out << SP << SP << "size_t out_offset = n * " << fShapeY[1] * oDepth * oHeight * oWidth << ";\n";
0472          out << SP << SP << "for (size_t g = 0; g < " << fAttrGroup << "; g++) {\n";
0473          out << SP << SP << "size_t x_offset = n * " << fShapeX[1] * iDepth * iHeight * iWidth << " + g * "
0474              << fShapeW[1] * iDepth * iHeight * iWidth << ";\n ";
0475          out << SP << SP << "size_t out_offset = n * " << fShapeY[1] * oDepth * oHeight * oWidth << " + g * "
0476              << fShapeW[0] * oDepth * oHeight * oWidth / fAttrGroup << ";\n ";
0477 
0478          if (fDim < 3) {
0479             out << SP << SP << "TMVA::Experimental::SOFIE::UTILITY::Im2col<float>(tensor_" << fNX
0480                 << " + x_offset,"
0481                 //  channels, height, width, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h,
0482                 //  dilation_w,
0483                 //
0484                 << fShapeW[1] << "," << iHeight << "," << iWidth << ",";
0485             if (fDim == 1)
0486                out << "1, " << fAttrKernelShape[0] << ",0," << fAttrPads[0] << ",1," << fAttrStrides[0] << ",1,"
0487                    << fAttrDilations[0];
0488             else // dim ==2
0489                out << fAttrKernelShape[0] << "," << fAttrKernelShape[1] << "," << fAttrPads[0] << "," << fAttrPads[1]
0490                    << "," << fAttrStrides[0] << "," << fAttrStrides[1] << "," << fAttrDilations[0] << ","
0491                    << fAttrDilations[1];
0492             out << "," << OpName << "_xcol);\n\n ";
0493          } else {
0494             // 3d im2col
0495             out << SP << SP << "TMVA::Experimental::SOFIE::UTILITY::Im2col_3d<float>(tensor_" << fNX
0496                 << " + x_offset,"
0497                 //  channels, d, h, w, k_d, k_h, k_w, pad_d, pad_h, pad_w, stride_d, stride_h, stride_w,
0498                 //  dilation_d, dilation_h, dilation_w,
0499                 //
0500                 << fShapeW[1] << "," << iDepth << "," << iHeight << "," << iWidth << "," << fAttrKernelShape[0] << ","
0501                 << fAttrKernelShape[1] << "," << fAttrKernelShape[2] << "," << fAttrPads[0] << "," << fAttrPads[1]
0502                 << "," << fAttrPads[2] << "," << fAttrStrides[0] << "," << fAttrStrides[1] << "," << fAttrStrides[2]
0503                 << "," << fAttrDilations[0] << "," << fAttrDilations[1] << "," << fAttrDilations[2] << "," << OpName
0504                 << "_xcol);\n\n ";
0505          }
0506 
0507          // BLAS
0508          // n must be divided by the number of groups
0509          out << SP << SP << SP << OpName << "_n = " << fShapeW[0] / fAttrGroup << ";\n";
0510          // offset g must be  g * k * n
0511          out << SP << SP << SP << "size_t offset_f = g * "
0512              << fShapeW[0] * fShapeW[1] * fAttrKernelShape[0] * fAttrKernelShape[1] * fAttrKernelShape[2] / fAttrGroup
0513              << ";\n";
0514          out << SP << SP << "BLAS::sgemm_(&" << OpName << "_transA, &" << OpName << "_transB, &" << OpName << "_m, &"
0515              << OpName << "_n, &" << OpName << "_k, &" << OpName << "_alpha, " << OpName << "_xcol, &" << OpName
0516              << "_m,\n"; // use m if op_xcol is not transpose , otherwise k
0517          out << SP << SP << SP << OpName << "_f + offset_f, &" << OpName << "_k, &" << OpName << "_beta, tensor_" << fNY
0518              << " + out_offset"
0519              << ", &" << OpName << "_m);\n";
0520 
0521          out << SP << SP << "}\n"; // end of group loop
0522       }
0523 
0524       if (fNB2 != "") {
0525          out << SP << "int " << OpName << "_size = " << fShapeY[1] * oDepth * oHeight * oWidth << ";\n";
0526          out << SP << "float " << OpName << "_gamma = 1.0;\n";
0527          out << SP << "int " << OpName << "_incx = 1;\n";
0528          out << SP << "int " << OpName << "_incy = 1;\n";
0529 
0530          out << SP << "BLAS::saxpy_(&" << OpName << "_size, &" << OpName << "_gamma, tensor_" << fNB2 << ", &"
0531              << OpName << "_incx, tensor_" << fNY << " + out_offset, &" << OpName << "_incy);\n";
0532 
0533       }
0534       out << SP << "}\n"; // end of batch size loop
0535 
0536       return out.str();
0537       }
0538 
0539    /*! \brief Returns the blas routines needed to compile the generated code
0540     */
0541    std::vector<std::string> GetBlasRoutines() { return { std::string("Gemm"), std::string("Axpy") }; }
0542 };
0543 
0544 } // namespace SOFIE
0545 } // namespace Experimental
0546 } // namespace TMVA
0547 
0548 #endif