// File indexing completed on 2025-01-18 10:11:05
0001 #ifndef TMVA_SOFIE_ROPERATOR_CONV
0002 #define TMVA_SOFIE_ROPERATOR_CONV
0003
0004 #include "TMVA/SOFIE_common.hxx"
0005 #include "TMVA/ROperator.hxx"
0006 #include "TMVA/RModel.hxx"
0007
0008 #include <memory>
0009 #include <sstream>
0010 #include <algorithm>
0011 #include <stdexcept>
0012 #include <vector>
0013 #include <cassert>
0014
0015 namespace TMVA {
0016 namespace Experimental {
0017 namespace SOFIE {
0018
0019 template<typename T>
0020 class ROperator_Conv final : public ROperator
0021 {
0022 private:
0023 std::string fAttrAutopad;
0024 std::vector<size_t> fAttrDilations;
0025 size_t fAttrGroup;
0026 std::vector<size_t> fAttrKernelShape;
0027 std::vector<size_t> fAttrPads;
0028 std::vector<size_t> fAttrStrides;
0029
0030 std::string fNX;
0031 std::string fNW;
0032 std::string fNB;
0033 std::string fNB2;
0034 std::string fNY;
0035
0036 std::vector<size_t> fShapeX;
0037 std::vector<size_t> fShapeW;
0038 std::vector<size_t> fShapeB;
0039 std::vector<size_t> fShapeY;
0040
0041 std::string fType;
0042
0043 size_t fDim;
0044
0045
0046 public:
0047
0048 ROperator_Conv() {}
0049
0050 ROperator_Conv(std::string autopad, std::vector<size_t> dilations,
0051 size_t group, std::vector<size_t> kernelShape, std::vector<size_t> pads,
0052 std::vector<size_t> strides, std::string nameX, std::string nameW,
0053 std::string nameB, std::string nameY):
0054 fAttrAutopad(autopad), fAttrDilations(dilations), fAttrGroup(group), fAttrKernelShape(kernelShape),
0055 fAttrPads(pads), fAttrStrides(strides),
0056 fNX(UTILITY::Clean_name(nameX)), fNW(UTILITY::Clean_name(nameW)),
0057 fNB(UTILITY::Clean_name(nameB)), fNY(UTILITY::Clean_name(nameY))
0058 {
0059 if(std::is_same<T, float>::value) {
0060 fType = "float";
0061 } else {
0062 throw
0063 std::runtime_error("TMVA SOFIE Encountered unsupported type parsing a Conv operator");
0064 }
0065 }
0066
0067 ROperator_Conv(std::string autopad, std::vector<size_t> dilations,
0068 size_t group, std::vector<size_t> kernelShape, std::vector<size_t> pads,
0069 std::vector<size_t> strides, std::string nameX, std::string nameW,
0070 std::string nameY):
0071 fAttrAutopad(autopad), fAttrDilations(dilations), fAttrGroup(group), fAttrKernelShape(kernelShape),
0072 fAttrPads(pads), fAttrStrides(strides),
0073 fNX(UTILITY::Clean_name(nameX)), fNW(UTILITY::Clean_name(nameW)), fNY(UTILITY::Clean_name(nameY))
0074 {
0075 if(std::is_same<T, float>::value) {
0076 fType = "float";
0077 } else {
0078 throw
0079 std::runtime_error("TMVA SOFIE Encountered unsupported type parsing a Conv operator");
0080 }
0081 }
0082
0083 std::vector<ETensorType> TypeInference(std::vector<ETensorType> input) {
0084 ETensorType out = input[0];
0085 return {out};
0086 }
0087
0088
0089 std::vector<std::vector<size_t>> ShapeInference(std::vector<std::vector<size_t>> input) {
0090
0091
0092
0093 if (input.size() > 3 ) {
0094 throw
0095 std::runtime_error("TMVA SOFIE Conv Op Shape inference need 2 or 3 input tensors");
0096 }
0097 for(size_t i = 0; i < input.size(); i++) {
0098 if (input[i].size() -2 != fDim) {
0099 throw
0100 std::runtime_error("TMVA SOFIE Conv Op Shape inference - invalid inputs ");
0101 }
0102 }
0103
0104 if (fAttrGroup == 0) {
0105 fAttrGroup = input[0][1] / input[1][1];
0106 }
0107
0108
0109 size_t k1 = ((fAttrKernelShape.empty())? input[1][2] : fAttrKernelShape[0]);
0110 size_t k2 = (fDim > 1) ? ((fAttrKernelShape.empty()) ? input[1][3] : fAttrKernelShape[1]) : 1;
0111 size_t k3 = (fDim > 2) ? ((fAttrKernelShape.empty()) ? input[1][4] : fAttrKernelShape[2]) : 1;
0112
0113
0114 size_t i1 = (fDim > 1) ? ((fDim > 2) ? 3 : 2) : 1;
0115 size_t i2 = (fDim > 2) ? 4 : 3;
0116 size_t i3 = 5;
0117
0118 if (fAttrDilations.empty()) {
0119 fAttrDilations = {1, 1, 1};
0120 }
0121 fAttrDilations.resize(3);
0122 if (fDim < 3) {
0123 fAttrDilations.resize(3, 1);
0124 }
0125
0126 fAttrKernelShape = {k1 + (fAttrDilations[0] - 1) * (k1 - 1),
0127 k2 + (fAttrDilations[1] - 1) * (k2 - 1),
0128 k3 + (fAttrDilations[2] - 1) * (k3 - 1)};
0129
0130 if (fAttrAutopad == "NOTSET") {
0131 if (fAttrPads.empty()) {
0132 fAttrPads = {1, 1, 1, 1, 1, 1};
0133 }
0134 } else if (fAttrAutopad == "SAME_UPPER" || fAttrAutopad == "SAME_LOWER") {
0135 if (fDim == 1)
0136 fAttrPads = {fAttrKernelShape[0] / 2, fAttrKernelShape[0] / 2};
0137 else if (fDim == 2)
0138 fAttrPads = {fAttrKernelShape[0] / 2, fAttrKernelShape[1] / 2, fAttrKernelShape[0] / 2, fAttrKernelShape[1] / 2};
0139 else if (fDim == 3)
0140 fAttrPads = {fAttrKernelShape[0] / 2, fAttrKernelShape[1] / 2, fAttrKernelShape[2] / 2,
0141 fAttrKernelShape[0] / 2, fAttrKernelShape[1] / 2, fAttrKernelShape[2] / 2};
0142
0143
0144 if (fAttrKernelShape[0] % 2 == 1) {
0145 (fAttrAutopad == "SAME_UPPER") ? fAttrPads[0]++ : fAttrPads[i1]++;
0146 }
0147 if (fDim > 1 && fAttrKernelShape[1] % 2 == 1) {
0148 (fAttrAutopad == "SAME_UPPER") ? fAttrPads[1]++ : fAttrPads[i2]++;
0149 }
0150 if (fDim > 2 && fAttrKernelShape[2] % 2 == 1) {
0151 (fAttrAutopad == "SAME_UPPER") ? fAttrPads[2]++ : fAttrPads[i3]++;
0152 }
0153 } else if (fAttrAutopad != "VALID") {
0154 throw
0155 std::runtime_error("TMVA SOFIE Conv Op invalid fAutopad");
0156 }
0157
0158 if (fDim < 3) fAttrPads.resize(6, 0);
0159
0160 if (fAttrStrides.empty()) {
0161 fAttrStrides = {1, 1, 1};
0162 }
0163 if (fDim < 3)
0164 fAttrStrides.resize(3, 1);
0165
0166
0167 size_t input1 = input[0][2];
0168 size_t input2 = (fDim > 1) ? input[0][3] : 1;
0169 size_t input3 = (fDim > 2) ? input[0][4] : 1;
0170
0171 size_t pad1 = fAttrPads[0] + fAttrPads[i1];
0172 size_t output1 = (input1 + pad1 - fAttrKernelShape[0]) / fAttrStrides[0] + 1;
0173
0174 size_t batch_size = input[0][0];
0175 size_t output_channels = input[1][0];
0176
0177 std::vector<std::vector<size_t>> ret({{ batch_size, output_channels, output1 }});
0178
0179 if (fDim == 1)
0180 return ret;
0181
0182 size_t pad2 = fAttrPads[1] + fAttrPads[i2];
0183 size_t output2 = (input2 + pad2 - fAttrKernelShape[1]) / fAttrStrides[1] + 1;
0184
0185 ret[0].push_back(output2);
0186 if (fDim == 2)
0187 return ret;
0188
0189 size_t pad3 = fAttrPads[2] + fAttrPads[i3];
0190 size_t output3 = (input3 + pad3 - fAttrKernelShape[2] ) / fAttrStrides[2] + 1;
0191
0192
0193 ret[0].push_back(output3);
0194 return ret;
0195 }
0196
0197 void Initialize(RModel& model) {
0198 fUseSession = model.UseSession();
0199 if (!model.CheckIfTensorAlreadyExist(fNX)) {
0200 throw
0201 std::runtime_error("TMVA SOFIE Conv op Input Tensor " + fNX + " is not found in model");
0202 }
0203 fShapeX = model.GetTensorShape(fNX);
0204 if (fShapeX.size() < 3 || fShapeX.size() > 5) {
0205 std::cout << fNX << " : " << ConvertShapeToString(fShapeX) << std::endl;
0206 throw
0207 std::runtime_error("TMVA SOFIE Conv Op input data tensor" + fNX + " is not of 3,4 or 5 dimensions");
0208 }
0209 fDim = fShapeX.size() - 2;
0210 if (!model.CheckIfTensorAlreadyExist(fNW)) {
0211 throw
0212 std::runtime_error("TMVA SOFIE Conv op Input weight Tensor " + fNW + " is not found in model");
0213 }
0214 fShapeW = model.GetTensorShape(fNW);
0215 if (fShapeW.size() < 3 || fShapeW.size() > 5) {
0216 std::cout << fNW << " : " << ConvertShapeToString(fShapeW) << std::endl;
0217 throw std::runtime_error("TMVA SOFIE Conv Op input weight tensor" + fNW + " is not of 3,4 or 5 dimensions");
0218 }
0219 fShapeY = ShapeInference({fShapeX, fShapeW})[0];
0220 model.AddIntermediateTensor(fNY, model.GetTensorType(fNX), fShapeY);
0221 if (fNB != "") {
0222 if (!model.CheckIfTensorAlreadyExist(fNB)) {
0223 throw
0224 std::runtime_error("TMVA SOFIE Conv op Input Tensor " + fNB + " is not found in model");
0225 }
0226 fShapeB = model.GetTensorShape(fNB);
0227 std::vector<size_t> targetShape(fShapeY.begin() + 1, fShapeY.end());
0228 bool broadcast_needed = !UTILITY::AreSameShape(fShapeB, targetShape);
0229 if (broadcast_needed) {
0230 auto original_data = model.GetInitializedTensorData(fNB);
0231
0232 if (fShapeB.size() < 1)
0233 throw std::runtime_error("TMVA SOFIE Conv op: Bias Tensor has empty shape");
0234
0235
0236 if (fShapeB[0] != fShapeY[1])
0237 throw std::runtime_error("TMVA SOFIE Conv op: Bias Tensor has wrong shape: " +
0238 ConvertShapeToString(fShapeB));
0239 if (fType != "float")
0240 throw std::runtime_error("TMVA SOFIE Conv op: Broadcasting for non-float type tensors is not supported");
0241
0242 if (!fUseSession) {
0243 std::vector<size_t> shape(fDim + 1, 1);
0244 shape[0] = fShapeB[0];
0245 std::shared_ptr<void> new_data_ptr(
0246 UTILITY::UnidirectionalBroadcast<float>(static_cast<float *>(original_data.get()), shape, targetShape),
0247 std::default_delete<float[]>());
0248 model.UpdateInitializedTensor(fNB, model.GetTensorType(fNB), targetShape, new_data_ptr);
0249 fShapeB = model.GetTensorShape(fNB);
0250 fNB2 = fNB;
0251 }
0252 else {
0253
0254
0255 fNB2 = fNB + "bcast";
0256 model.AddIntermediateTensor(fNB2, model.GetTensorType(fNB), targetShape);
0257 }
0258 }
0259 }
0260 }
0261
0262 std::string GenerateInitCode() {
0263 std::stringstream out;
0264
0265 if (!fNB2.empty()) {
0266
0267 std::vector<size_t> shape(fDim + 1, 1);
0268 shape[0] = fShapeB[0];
0269 std::vector<size_t> targetShape(fShapeY.begin() + 1, fShapeY.end());
0270 out << SP << "{\n";
0271 out << SP << SP << "float * data = TMVA::Experimental::SOFIE::UTILITY::UnidirectionalBroadcast<float>(tensor_"
0272 << fNB << ", " << ConvertShapeToString(shape) << ", " << ConvertShapeToString(fShapeY) << ");\n";
0273 out << SP << SP << "std::copy(data, data + " << ConvertShapeToLength(targetShape) << ", tensor_" << fNB2 << ");\n";
0274 out << SP << SP << "delete[] data;\n";
0275 out << SP << "}\n";
0276 }
0277 return out.str();
0278 }
0279
0280
0281 virtual std::string GenerateSessionMembersCode(std::string opName) {
0282
0283 size_t outputChannelSize = fShapeY[2];
0284 size_t kernelSize = fAttrKernelShape[0];
0285 for (size_t i = 1; i < fDim; i++) {
0286 outputChannelSize *= fShapeY[2 + i];
0287 kernelSize *= fAttrKernelShape[i];
0288 }
0289
0290 opName = "op_" + opName;
0291 std::stringstream out;
0292
0293 out << "std::vector<" << fType << "> fVec_" << opName << "_f = std::vector<" << fType << ">("
0294 << fShapeW[0] * fShapeW[1] * kernelSize << ");\n";
0295
0296 out << "std::vector<" << fType << "> fVec_" << opName << "_xcol = std::vector<" << fType << ">("
0297 << fShapeW[1] * kernelSize * outputChannelSize << ");\n";
0298 out << "\n";
0299
0300 return out.str();
0301 }
0302
0303 std::string Generate(std::string OpName) {
0304 OpName = "op_" + OpName;
0305
0306 if (fShapeX.empty() || fShapeW.empty() || (fNB != "" && fShapeB.empty()) || fShapeY.empty()) {
0307 throw
0308 std::runtime_error("TMVA SOFIE Conv Op called to Generate without being initialized first");
0309 }
0310
0311 std::stringstream out;
0312 size_t bsize = fShapeX[0];
0313 size_t kDepth = (fDim > 2) ? fShapeW[2] : 1;
0314 size_t kHeight = (fDim > 1) ? fShapeW[fDim] : 1;
0315 size_t kWidth = fShapeW[fDim+1];
0316 size_t iDepth = (fDim > 2) ? fShapeX[2] : 1;
0317 size_t iHeight = (fDim > 1) ? fShapeX[fDim] : 1;
0318 size_t iWidth = fShapeX[fDim+1];
0319 size_t oDepth = (fDim > 2) ? fShapeY[2] : 1;
0320 size_t oHeight = (fDim > 1) ? fShapeY[fDim] : 1;
0321 size_t oWidth = fShapeY[fDim+1];
0322
0323 out << "\n//---- operator Conv " << OpName << "\n";
0324
0325
0326 if (fUseSession)
0327 out << SP << fType << " * " << OpName << "_f = fVec_" << OpName << "_f.data();\n";
0328 else
0329 out << SP << fType << " " << OpName << "_f[" << fShapeW[0] * fShapeW[1] * fAttrKernelShape[0] * fAttrKernelShape[1] << "] = {0};\n";
0330
0331
0332
0333
0334
0335 size_t id = (fDim > 2) ? fDim-3 : 2;
0336 size_t ih = (fDim > 1) ? fDim-2 : 1;
0337 size_t iw = fDim-1;
0338
0339 size_t wstrideDil = fAttrDilations[iw];
0340 size_t hstride = kWidth;
0341 size_t hstrideDil = fAttrDilations[ih] * fAttrKernelShape[iw];
0342 size_t dstride = kHeight * kWidth;
0343 size_t dstrideDil = fAttrDilations[id] * fAttrKernelShape[ih] * fAttrKernelShape[iw];
0344 size_t icstride = kHeight * kWidth * kDepth;
0345 size_t icstrideDil = fAttrKernelShape[id] * fAttrKernelShape[ih] * fAttrKernelShape[iw];
0346 size_t ocstride = fShapeW[1] * icstride;
0347 size_t ocstrideDil = fShapeW[1] * icstrideDil;
0348
0349 out << SP << "for (std::size_t oc = 0; oc < " << fShapeW[0] << "; oc++) {\n";
0350 out << SP << SP << "for (std::size_t ic = 0; ic < " << fShapeW[1] << "; ic++) {\n";
0351 if (fDim > 2)
0352 out << SP << SP << SP << "for (std::size_t kd = 0; kd < " << kDepth << "; kd++) {\n";
0353 if (fDim > 1)
0354 out << SP << SP << SP << "for (std::size_t kh = 0; kh < " << kHeight << "; kh++) {\n";
0355 out << SP << SP << SP << SP << "for (std::size_t kw = 0; kw < " << kWidth << "; kw++) {\n";
0356
0357 out << SP << SP << SP << SP << SP << OpName << "_f[oc * "
0358 << ocstrideDil << " + ic * " << icstrideDil;
0359 if (fDim > 2) out << " + kd * " << dstrideDil;
0360 if (fDim > 1) out << " + kh * " << hstrideDil;
0361 out << " + kw * " << wstrideDil << " ] = tensor_" << fNW << "[oc * " << ocstride << " + ic * " << icstride;
0362 if (fDim > 2) out << " + kd * " << dstride;
0363 if (fDim > 1) out << " + kh * " << hstride;
0364 out << " + kw ];\n";
0365
0366 out << SP << SP << SP << SP << "}\n";
0367 if (fDim > 1) out << SP << SP << SP << "}\n";
0368 if (fDim > 2) out << SP << SP << SP << "}\n";
0369 out << SP << SP << "}\n";
0370 out << SP << "}\n";
0371
0372
0373 out << SP << "char " << OpName << "_transA = 'N';\n";
0374 out << SP << "char " << OpName << "_transB = 'N';\n";
0375 out << SP << "int " << OpName << "_m = " << oHeight * oWidth * oDepth << ";\n";
0376 assert(fShapeY[1] == fShapeW[0]);
0377 assert(fShapeW[1] == fShapeX[1] / fAttrGroup);
0378 out << SP << "int " << OpName << "_n = " << fShapeW[0] << ";\n";
0379 out << SP << "int " << OpName << "_k = " << fShapeW[1] * fAttrKernelShape[0] * fAttrKernelShape[1] * fAttrKernelShape[2] << ";\n";
0380 out << SP << "float " << OpName << "_alpha = 1.0;\n";
0381 out << SP << "float " << OpName << "_beta = 0.0;\n";
0382
0383 if (fUseSession) {
0384 out << SP << fType << " * " << OpName << "_xcol = fVec_" << OpName << "_xcol.data();\n";
0385 }
0386 else {
0387 out << SP << fType << " " << OpName << "_xcol["
0388 << fShapeX[1] * fAttrKernelShape[0] * fAttrKernelShape[1] * fAttrKernelShape[2] * oDepth * oHeight * oWidth
0389 << "] = {0};\n";
0390 }
0391
0392
0393 out << SP << "for (size_t n = 0; n < " << bsize << "; n++) {\n";
0394
0395
0396
0397
0398
0399
0400
0401
0402 if (fDim ==1) {
0403 if (fAttrPads[0] != fAttrPads[1] ) {
0404 std::cout << "TMVA SOFIE Operator Conv: asymmetric padding not supported. Assume an average padding "
0405 << std::endl;
0406 fAttrPads[0] = (fAttrPads[0] + fAttrPads[1]) / 2;
0407 }
0408 fAttrPads[1] = 0;
0409 fAttrStrides[1] = 1;
0410 }
0411 if (fDim == 2) {
0412 if (fAttrPads[0] != fAttrPads[2] || fAttrPads[1] != fAttrPads[3]) {
0413 std::cout << "TMVA SOFIE Operator Conv: asymmetric padding not supported. Assume an average padding " << std::endl;
0414 fAttrPads[0] = (fAttrPads[0] + fAttrPads[2]) / 2;
0415 fAttrPads[1] = (fAttrPads[1] + fAttrPads[3]) / 2;
0416 }
0417 }
0418 if (fDim == 3) {
0419 if (fAttrPads[0] != fAttrPads[3] || fAttrPads[1] != fAttrPads[4] || fAttrPads[2] != fAttrPads[5]) {
0420 std::cout << "TMVA SOFIE Operator Conv: asymmetric padding not supported. Assume an average padding " << std::endl;
0421 fAttrPads[0] = (fAttrPads[0] + fAttrPads[3]) / 2;
0422 fAttrPads[1] = (fAttrPads[1] + fAttrPads[4]) / 2;
0423 fAttrPads[2] = (fAttrPads[2] + fAttrPads[5]) / 2;
0424 }
0425 }
0426 out << SP << SP << "size_t out_offset = n * " << fShapeY[1] * oDepth * oHeight * oWidth << ";\n";
0427
0428 if (fAttrGroup == 1) {
0429 out << SP << SP << "size_t x_offset = n * " << fShapeX[1] * iHeight * iWidth << ";\n";
0430
0431
0432 if (fDim < 3) {
0433 out << SP << SP << "TMVA::Experimental::SOFIE::UTILITY::Im2col<float>(tensor_" << fNX
0434 << " + x_offset,"
0435
0436
0437
0438 << fShapeW[1] << "," << iHeight << "," << iWidth << ",";
0439 if (fDim == 1)
0440 out << "1, " << fAttrKernelShape[0] << ",0," << fAttrPads[0] << ",1," << fAttrStrides[0] << ",1,"
0441 << fAttrDilations[0];
0442 else
0443 out << fAttrKernelShape[0] << "," << fAttrKernelShape[1] << "," << fAttrPads[0] << "," << fAttrPads[1]
0444 << "," << fAttrStrides[0] << "," << fAttrStrides[1] << "," << fAttrDilations[0] << ","
0445 << fAttrDilations[1];
0446 out << "," << OpName << "_xcol);\n\n ";
0447 } else {
0448
0449 out << SP << SP << "TMVA::Experimental::SOFIE::UTILITY::Im2col_3d<float>(tensor_" << fNX
0450 << " + x_offset,"
0451
0452
0453
0454 << fShapeW[1] << "," << iDepth << "," << iHeight << "," << iWidth << ","
0455 << fAttrKernelShape[0] << "," << fAttrKernelShape[1] << "," << fAttrKernelShape[2] << ","
0456 << fAttrPads[0] << "," << fAttrPads[1] << "," << fAttrPads[2] << ","
0457 << fAttrStrides[0] << "," << fAttrStrides[1] << "," << fAttrStrides[2] << ","
0458 << fAttrDilations[0] << "," << fAttrDilations[1] << "," << fAttrDilations[2] << ","
0459 << OpName << "_xcol);\n\n ";
0460 }
0461
0462 out << SP << SP << "BLAS::sgemm_(&" << OpName << "_transA, &" << OpName << "_transB, &" << OpName << "_m, &"
0463 << OpName << "_n, &" << OpName << "_k, &" << OpName << "_alpha, " << OpName << "_xcol, &" << OpName
0464 << "_m,\n";
0465 out << SP << SP << SP << OpName << "_f, &" << OpName << "_k, &" << OpName << "_beta, tensor_" << fNY
0466 << " + out_offset, &" << OpName << "_m);\n";
0467 } else {
0468
0469
0470
0471
0472 out << SP << SP << "for (size_t g = 0; g < " << fAttrGroup << "; g++) {\n";
0473 out << SP << SP << "size_t x_offset = n * " << fShapeX[1] * iDepth * iHeight * iWidth << " + g * "
0474 << fShapeW[1] * iDepth * iHeight * iWidth << ";\n ";
0475 out << SP << SP << "size_t out_offset = n * " << fShapeY[1] * oDepth * oHeight * oWidth << " + g * "
0476 << fShapeW[0] * oDepth * oHeight * oWidth / fAttrGroup << ";\n ";
0477
0478 if (fDim < 3) {
0479 out << SP << SP << "TMVA::Experimental::SOFIE::UTILITY::Im2col<float>(tensor_" << fNX
0480 << " + x_offset,"
0481
0482
0483
0484 << fShapeW[1] << "," << iHeight << "," << iWidth << ",";
0485 if (fDim == 1)
0486 out << "1, " << fAttrKernelShape[0] << ",0," << fAttrPads[0] << ",1," << fAttrStrides[0] << ",1,"
0487 << fAttrDilations[0];
0488 else
0489 out << fAttrKernelShape[0] << "," << fAttrKernelShape[1] << "," << fAttrPads[0] << "," << fAttrPads[1]
0490 << "," << fAttrStrides[0] << "," << fAttrStrides[1] << "," << fAttrDilations[0] << ","
0491 << fAttrDilations[1];
0492 out << "," << OpName << "_xcol);\n\n ";
0493 } else {
0494
0495 out << SP << SP << "TMVA::Experimental::SOFIE::UTILITY::Im2col_3d<float>(tensor_" << fNX
0496 << " + x_offset,"
0497
0498
0499
0500 << fShapeW[1] << "," << iDepth << "," << iHeight << "," << iWidth << "," << fAttrKernelShape[0] << ","
0501 << fAttrKernelShape[1] << "," << fAttrKernelShape[2] << "," << fAttrPads[0] << "," << fAttrPads[1]
0502 << "," << fAttrPads[2] << "," << fAttrStrides[0] << "," << fAttrStrides[1] << "," << fAttrStrides[2]
0503 << "," << fAttrDilations[0] << "," << fAttrDilations[1] << "," << fAttrDilations[2] << "," << OpName
0504 << "_xcol);\n\n ";
0505 }
0506
0507
0508
0509 out << SP << SP << SP << OpName << "_n = " << fShapeW[0] / fAttrGroup << ";\n";
0510
0511 out << SP << SP << SP << "size_t offset_f = g * "
0512 << fShapeW[0] * fShapeW[1] * fAttrKernelShape[0] * fAttrKernelShape[1] * fAttrKernelShape[2] / fAttrGroup
0513 << ";\n";
0514 out << SP << SP << "BLAS::sgemm_(&" << OpName << "_transA, &" << OpName << "_transB, &" << OpName << "_m, &"
0515 << OpName << "_n, &" << OpName << "_k, &" << OpName << "_alpha, " << OpName << "_xcol, &" << OpName
0516 << "_m,\n";
0517 out << SP << SP << SP << OpName << "_f + offset_f, &" << OpName << "_k, &" << OpName << "_beta, tensor_" << fNY
0518 << " + out_offset"
0519 << ", &" << OpName << "_m);\n";
0520
0521 out << SP << SP << "}\n";
0522 }
0523
0524 if (fNB2 != "") {
0525 out << SP << "int " << OpName << "_size = " << fShapeY[1] * oDepth * oHeight * oWidth << ";\n";
0526 out << SP << "float " << OpName << "_gamma = 1.0;\n";
0527 out << SP << "int " << OpName << "_incx = 1;\n";
0528 out << SP << "int " << OpName << "_incy = 1;\n";
0529
0530 out << SP << "BLAS::saxpy_(&" << OpName << "_size, &" << OpName << "_gamma, tensor_" << fNB2 << ", &"
0531 << OpName << "_incx, tensor_" << fNY << " + out_offset, &" << OpName << "_incy);\n";
0532
0533 }
0534 out << SP << "}\n";
0535
0536 return out.str();
0537 }
0538
0539
0540
0541 std::vector<std::string> GetBlasRoutines() { return { std::string("Gemm"), std::string("Axpy") }; }
0542 };
0543
0544 }
0545 }
0546 }
0547
0548 #endif