Warning, file /include/root/TMVA/ROperator_BasicBinary.hxx was not indexed
or was modified since last indexation (in which case cross-reference links may be missing, inaccurate or erroneous).
0001 #ifndef TMVA_SOFIE_ROperator_BasicBinary
0002 #define TMVA_SOFIE_ROperator_BasicBinary
0003
#include "TMVA/SOFIE_common.hxx"
#include "TMVA/ROperator.hxx"
#include "TMVA/RModel.hxx"

#include <algorithm>
#include <cmath>
#include <iostream>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <vector>
0009
0010 namespace TMVA {
0011 namespace Experimental {
0012 namespace SOFIE {
0013
/// Elementwise binary operations implemented by ROperator_BasicBinary.
/// Deliberately an unscoped enum: the enumerators are used unqualified, e.g. as template
/// arguments of BinaryOperatorTrait and ROperator_BasicBinary.
enum EBasicBinaryOperator {
   Add,
   Sub,
   Mul,
   Div,
   Pow
};

/// Compile-time description of one binary operation. Every specialization provides
///  - Name():     the operator name (used in generated-code comments and verbose output);
///  - Op(a, b):   C++ source text that applies the operation to the expressions `a` and `b`;
///  - Func(a, b): the operation evaluated on values of type T (used for constant folding).
/// The primary template is intentionally empty, so using an operation that has no
/// specialization fails to compile.
template <typename T, EBasicBinaryOperator Op1>
struct BinaryOperatorTrait {};

template <typename T>
struct BinaryOperatorTrait<T, Add> {
   static std::string Name() { return "Add"; }
   static std::string Op(const std::string &t1, const std::string &t2) { return t1 + " + " + t2; }
   static T Func(T t1, T t2) { return t1 + t2; }
};

template <typename T>
struct BinaryOperatorTrait<T, Sub> {
   static std::string Name() { return "Sub"; }
   static std::string Op(const std::string &t1, const std::string &t2) { return t1 + " - " + t2; }
   static T Func(T t1, T t2) { return t1 - t2; }
};

template <typename T>
struct BinaryOperatorTrait<T, Mul> {
   static std::string Name() { return "Mul"; }
   static std::string Op(const std::string &t1, const std::string &t2) { return t1 + " * " + t2; }
   static T Func(T t1, T t2) { return t1 * t2; }
};

template <typename T>
struct BinaryOperatorTrait<T, Div> {
   static std::string Name() { return "Div"; }
   static std::string Op(const std::string &t1, const std::string &t2) { return t1 + " / " + t2; }
   /// @throws std::runtime_error when T is integral and t2 == 0: integer division by zero is
   ///         undefined behaviour in C++ (floating-point T follows IEEE rules and is not checked).
   static T Func(T t1, T t2)
   {
      if (std::is_integral<T>::value && t2 == T(0))
         throw std::runtime_error("TMVA SOFIE Div: integer division by zero");
      return t1 / t2;
   }
};

template <typename T>
struct BinaryOperatorTrait<T, Pow> {
   static std::string Name() { return "Pow"; }
   static std::string Op(const std::string &t1, const std::string &t2) { return "std::pow(" + t1 + "," + t2 + ")"; }
   // std::pow promotes integral arguments to double; the cast makes the conversion back to T
   // (truncation for integral T) explicit instead of implicit.
   static T Func(T t1, T t2) { return static_cast<T>(std::pow(t1, t2)); }
};
0059
0060 template <typename T, EBasicBinaryOperator Op>
0061 class ROperator_BasicBinary final : public ROperator {
0062 private:
0063 int fBroadcastFlag = 0;
0064 std::string fNA;
0065 std::string fNB;
0066 std::string fNBroadcastedA;
0067 std::string fNBroadcastedB;
0068 std::string fNY;
0069
0070 std::vector<size_t> fShapeA;
0071 std::vector<size_t> fShapeB;
0072 std::vector<size_t> fShapeY;
0073
0074 std::vector<Dim> fDimShapeA;
0075 std::vector<Dim> fDimShapeB;
0076 std::vector<Dim> fDimShapeY;
0077
0078 public:
0079 ROperator_BasicBinary() {}
0080 ROperator_BasicBinary(std::string nameA, std::string nameB, std::string nameY)
0081 : fNA(UTILITY::Clean_name(nameA)), fNB(UTILITY::Clean_name(nameB)), fNY(UTILITY::Clean_name(nameY))
0082 {
0083 fInputTensorNames = {fNA, fNB};
0084 fOutputTensorNames = {fNY};
0085 }
0086
0087
0088 std::vector<ETensorType> TypeInference(std::vector<ETensorType> input) override { return input; }
0089
0090
0091 std::vector<std::vector<size_t>> ShapeInference(std::vector<std::vector<size_t>> input) override
0092 {
0093
0094 auto ret = std::vector<std::vector<size_t>>(1, input[0]);
0095 return ret;
0096 }
0097
0098 void Initialize(RModel &model) override
0099 {
0100
0101 if (!model.CheckIfTensorAlreadyExist(fNA)) {
0102 throw std::runtime_error(std::string("TMVA SOFIE Binary Op Input Tensor ") + fNA + "is not found in model");
0103 }
0104 if (!model.CheckIfTensorAlreadyExist(fNB)) {
0105 throw std::runtime_error(std::string("TMVA SOFIE Binary Op Input Tensor ") + fNB + "is not found in model");
0106 }
0107 int dynamicInputs = 0;
0108 if (model.IsDynamicTensor(fNA)) {
0109 fDimShapeA = model.GetDynamicTensorShape(fNA);
0110 dynamicInputs |= 1;
0111 } else {
0112 fShapeA = model.GetTensorShape(fNA);
0113 fDimShapeA = ConvertShapeToDim(fShapeA);
0114 }
0115 if (model.IsDynamicTensor(fNB)) {
0116 dynamicInputs |= 2;
0117 fDimShapeB = model.GetDynamicTensorShape(fNB);
0118 } else {
0119 fShapeB = model.GetTensorShape(fNB);
0120 fDimShapeB = ConvertShapeToDim(fShapeB);
0121 }
0122 if (dynamicInputs & 1 && model.Verbose())
0123 std::cout << BinaryOperatorTrait<T, Op>::Name() << " : input " << fNA << " is dynamic "
0124 << ConvertShapeToString(fDimShapeA) << " ";
0125 if (dynamicInputs & 2 && model.Verbose())
0126 std::cout << BinaryOperatorTrait<T, Op>::Name() << " : input " << fNB << " is dynamic "
0127 << ConvertShapeToString(fDimShapeB) << " ";
0128 std::cout << std::endl;
0129
0130
0131
0132
0133 if (dynamicInputs == 0) {
0134 auto ret = UTILITY::MultidirectionalBroadcastShape(fShapeA, fShapeB);
0135 fBroadcastFlag = ret.first;
0136 fShapeY = ret.second;
0137 if (model.IsConstantTensor(fNA) && model.IsConstantTensor(fNB)) {
0138 bool broadcast = fBroadcastFlag > 0;
0139 if (broadcast) {
0140
0141 bool broadcastA = fBroadcastFlag & 2;
0142 bool broadcastB = fBroadcastFlag & 1;
0143
0144 if (broadcastA) {
0145 fNBroadcastedA = "Broadcasted" + fNA + "to" + fNY;
0146 auto data = model.GetInitializedTensorData(fNA);
0147 std::shared_ptr<void> broadcastedData(
0148 UTILITY::UnidirectionalBroadcast<T>(static_cast<T *>(data.get()), fShapeA, fShapeY),
0149 std::default_delete<T[]>());
0150 if (model.Verbose())
0151 std::cout << "broadcasted data A " << ConvertShapeToString(fShapeY) << " : "
0152 << ConvertValuesToString(ConvertShapeToLength(fShapeY),
0153 static_cast<T *>(broadcastedData.get()))
0154 << std::endl;
0155
0156 model.AddConstantTensor(fNBroadcastedA, model.GetTensorType(fNA), fShapeY, broadcastedData);
0157 fShapeA = fShapeY;
0158 fDimShapeA = ConvertShapeToDim(fShapeA);
0159 }
0160
0161 if (broadcastB) {
0162 fNBroadcastedB = "Broadcasted" + fNB + "to" + fNY;
0163 auto data = model.GetInitializedTensorData(fNB);
0164 if (model.Verbose())
0165 std::cout << "data B " << ConvertShapeToString(fShapeB) << " : "
0166 << ConvertValuesToString(ConvertShapeToLength(fShapeB), static_cast<T *>(data.get()))
0167 << std::endl;
0168 std::shared_ptr<void> broadcastedData(
0169 UTILITY::UnidirectionalBroadcast<T>(static_cast<T *>(data.get()), fShapeB, fShapeY),
0170 std::default_delete<T[]>());
0171
0172 if (model.Verbose())
0173 std::cout << "broadcasted data B " << ConvertShapeToString(fShapeY) << " : "
0174 << ConvertValuesToString(ConvertShapeToLength(fShapeY),
0175 static_cast<T *>(broadcastedData.get()))
0176 << std::endl;
0177 model.AddConstantTensor(fNBroadcastedB, model.GetTensorType(fNB), fShapeY, broadcastedData);
0178 fShapeB = fShapeY;
0179 fDimShapeB = ConvertShapeToDim(fShapeB);
0180 }
0181 } else {
0182 fShapeY = fShapeA;
0183 }
0184
0185
0186 const std::string &nameA = fNBroadcastedA.empty() ? fNA : fNBroadcastedA;
0187 const std::string &nameB = fNBroadcastedB.empty() ? fNB : fNBroadcastedB;
0188 auto dataA = static_cast<T *>(model.GetInitializedTensorData(nameA).get());
0189 auto dataB = static_cast<T *>(model.GetInitializedTensorData(nameB).get());
0190 std::vector<T> dataY(ConvertShapeToLength(fShapeY));
0191 for (size_t i = 0; i < dataY.size(); i++) {
0192 dataY[i] = BinaryOperatorTrait<T, Op>::Func(dataA[i], dataB[i]);
0193 }
0194 model.AddConstantTensor<T>(fNY, fShapeY, dataY.data());
0195
0196 model.SetNotWritableInitializedTensor(nameA);
0197 model.SetNotWritableInitializedTensor(nameB);
0198 fIsOutputConstant = true;
0199 if (model.Verbose()) {
0200 std::cout << BinaryOperatorTrait<T, Op>::Name() << " : " << fNA << " " << ConvertShapeToString(fShapeA)
0201 << " , " << fNB << " " << ConvertShapeToString(fShapeB) << " ---> " << fNY << " "
0202 << ConvertShapeToString(fShapeY) << " : " << ConvertValuesToString(dataY) << std::endl;
0203 }
0204 } else {
0205
0206 model.AddIntermediateTensor(fNY, model.GetTensorType(fNA), fShapeY);
0207 if (model.Verbose()) {
0208 std::cout << BinaryOperatorTrait<T, Op>::Name() << " : " << fNA << " " << ConvertShapeToString(fShapeA)
0209 << " , " << fNB << " " << ConvertShapeToString(fShapeB) << " ---> " << fNY << " "
0210 << ConvertShapeToString(fShapeY) << std::endl;
0211 }
0212
0213 fDimShapeY = ConvertShapeToDim(fShapeY);
0214 }
0215 } else {
0216
0217 auto ret = UTILITY::MultidirectionalBroadcastShape(fDimShapeA, fDimShapeB);
0218 fBroadcastFlag = ret.first;
0219 fDimShapeY = ret.second;
0220
0221
0222 if (ret.first & 4) {
0223
0224
0225 auto IsInputDimParam = [&](const std::string &p) {
0226 auto inputNames = model.GetInputTensorNames();
0227 for (auto &input : inputNames) {
0228 for (auto &i_s : model.GetDimTensorShape(input)) {
0229 if (i_s.isParam && i_s.param == p)
0230 return true;
0231 }
0232 }
0233 return false;
0234 };
0235 for (size_t i = 0; i < fDimShapeY.size(); i++) {
0236 auto &s = fDimShapeY[i];
0237 if (s.isParam && s.param.find("std::max") != std::string::npos) {
0238 if (IsInputDimParam(fDimShapeA[i].param)) {
0239
0240 if (fDimShapeA[i].dim != 1)
0241 s = fDimShapeA[i];
0242 else
0243 s = fDimShapeB[i];
0244 } else if (IsInputDimParam(fDimShapeB[i].param)) {
0245 if (fDimShapeB[i].dim != 1)
0246 s = fDimShapeB[i];
0247 else
0248 s = fDimShapeA[i];
0249 }
0250 }
0251 }
0252 }
0253
0254 model.AddIntermediateTensor(fNY, model.GetTensorType(fNA), fDimShapeY);
0255 if (model.Verbose()) {
0256 std::cout << BinaryOperatorTrait<T, Op>::Name() << " : " << ConvertShapeToString(fDimShapeA) << " , "
0257 << ConvertShapeToString(fDimShapeB) << " --> " << ConvertShapeToString(fDimShapeY) << std::endl;
0258 }
0259 }
0260 }
0261
0262 std::string GenerateInitCode() override
0263 {
0264 std::stringstream out;
0265 return out.str();
0266 }
0267
0268 std::string Generate(std::string opName) override
0269 {
0270
0271 if (fIsOutputConstant)
0272 return "";
0273
0274 opName = "op_" + opName;
0275
0276 if (fDimShapeY.empty()) {
0277 throw std::runtime_error("TMVA SOFIE Binary Op called to Generate without being initialized first");
0278 }
0279 std::stringstream out;
0280 out << SP << "\n//------ " << opName << " " << BinaryOperatorTrait<T, Op>::Name() << " --> "
0281 << ConvertDimShapeToString(fDimShapeY) << "\n";
0282 auto length = ConvertDimShapeToLength(fDimShapeY);
0283 std::string typeName = TensorType<T>::Name();
0284
0285
0286
0287 if (fBroadcastFlag & 4) {
0288
0289 auto lengthA = ConvertDimShapeToLength(fDimShapeA);
0290 auto lengthB = ConvertDimShapeToLength(fDimShapeB);
0291 out << SP << "if (" << lengthA << "!=" << lengthB << ") {\n";
0292
0293
0294 for (size_t i = 0; i < fDimShapeY.size(); i++) {
0295 if (fBroadcastFlag & 5 && fDimShapeY[i] == fDimShapeA[i] && fDimShapeA[i].dim > 1 &&
0296 fDimShapeB[i].isParam) {
0297
0298 out << SP << SP << "if (" << fDimShapeB[i] << "!= 1)\n";
0299 out << SP << SP << SP << "throw std::runtime_error(\"SOFIE - Cannot broadcast B->A in operator "
0300 << opName << "\");\n";
0301 }
0302 if (fBroadcastFlag & 6 && fDimShapeY[i] == fDimShapeB[i] && fDimShapeB[i].dim > 1 &&
0303 fDimShapeA[i].isParam) {
0304
0305 out << SP << SP << "if (" << fDimShapeA[i] << "!= 1)\n";
0306 out << SP << SP << SP << "throw std::runtime_error(\"SOFIE - Cannot broadcast A->B in operator "
0307 << opName << "\");\n";
0308 } else if (fDimShapeA[i].isParam && fDimShapeB[i].isParam) {
0309
0310
0311 out << SP << SP << "if (" << fDimShapeA[i] << " != " << fDimShapeB[i] << " && (" << fDimShapeA[i]
0312 << " != 1 || " << fDimShapeB[i] << " != 1))\n";
0313 out << SP << SP << SP << "throw std::runtime_error(\"SOFIE - Cannot broadcast shapes in operator " << opName
0314 << "\");\n";
0315 }
0316 }
0317 out << SP << "}\n";
0318 }
0319
0320 auto stridesA = UTILITY::ComputeStrideFromShape(fDimShapeA);
0321 auto stridesB = UTILITY::ComputeStrideFromShape(fDimShapeB);
0322 auto stridesY = UTILITY::ComputeStrideFromShape(fDimShapeY);
0323
0324 std::string compute_idx_A, compute_idx_B, compute_idx_Y;
0325 if (fDimShapeA.empty() ||
0326 std::all_of(fDimShapeA.begin(), fDimShapeA.end(), [](Dim d) { return d.dim == 1 || d.GetVal() == "1"; })) {
0327 compute_idx_A = "0";
0328 } else {
0329 for (size_t i = 0; i < fDimShapeA.size(); ++i) {
0330 if (fDimShapeA[i].dim == 1 || fDimShapeA[i].GetVal() == "1")
0331 continue;
0332 compute_idx_A += "idx_" + std::to_string(i + (fDimShapeY.size() - fDimShapeA.size()));
0333 if (stridesA[i].GetVal() != "1")
0334 compute_idx_A += " * " + stridesA[i].GetVal();
0335 compute_idx_A += " + ";
0336 }
0337
0338 for (int j = 0; j < 3; j++)
0339 compute_idx_A.pop_back();
0340 }
0341 if (fDimShapeB.empty() ||
0342 std::all_of(fDimShapeB.begin(), fDimShapeB.end(), [](Dim d) { return d.dim == 1 || d.GetVal() == "1"; })) {
0343 compute_idx_B = "0";
0344 } else {
0345 for (size_t i = 0; i < fDimShapeB.size(); ++i) {
0346 if (fDimShapeB[i].dim == 1 || fDimShapeB[i].GetVal() == "1")
0347 continue;
0348 compute_idx_B += "idx_" + std::to_string(i + (fDimShapeY.size() - fDimShapeB.size()));
0349 if (stridesB[i].GetVal() != "1")
0350 compute_idx_B += " * " + stridesB[i].GetVal();
0351 compute_idx_B += " + ";
0352 }
0353
0354 for (int j = 0; j < 3; j++)
0355 compute_idx_B.pop_back();
0356 }
0357 int nloop = 0;
0358 if (fDimShapeY.empty() ||
0359 std::all_of(fDimShapeY.begin(), fDimShapeY.end(), [](Dim d) { return d.dim == 1 || d.GetVal() == "1"; })) {
0360 compute_idx_Y = "0";
0361 } else {
0362 for (size_t i = 0; i < fDimShapeY.size(); ++i) {
0363 if (fDimShapeY[i].dim != 1 && fDimShapeY[i].GetVal() != "1") {
0364 nloop++;
0365 for (int j = 0; j < nloop; j++) out << SP;
0366 out << "for (size_t idx_" << i << " = 0; idx_" << i << " < " << fDimShapeY[i]
0367 << "; ++idx_" << i << "){\n";
0368 compute_idx_Y += "idx_" + std::to_string(i);
0369 if (stridesY[i].GetVal() != "1")
0370 compute_idx_Y += " * " + stridesY[i].GetVal();
0371 compute_idx_Y += " + ";
0372 }
0373 }
0374
0375 for (int j = 0; j < 3; j++)
0376 compute_idx_Y.pop_back();
0377 }
0378 for (int j = 0; j < nloop + 1; j++) out << SP;
0379 out << "tensor_" << fNY << "[" << compute_idx_Y << "] = "
0380 << BinaryOperatorTrait<T, Op>::Op("tensor_" + fNA + "[" + compute_idx_A + "]",
0381 "tensor_" + fNB + "[" + compute_idx_B + "]")
0382 << " ;\n";
0383
0384 for (int i = nloop; i > 0; i--) {
0385 for (int j = 0; j < i; j++) out << SP;
0386 out << "}\n";
0387 }
0388 return out.str();
0389 }
0390
0391 std::vector<std::string> GetStdLibs() override
0392 {
0393 if (Op == EBasicBinaryOperator::Pow) {
0394 return {std::string("cmath")};
0395 } else {
0396 return {};
0397 }
0398 }
0399 };
0400
0401 }
0402 }
0403 }
0404
0405 #endif