// /include/root/TMVA/ROperator_BasicBinary.hxx

#ifndef TMVA_SOFIE_ROperator_BasicBinary
#define TMVA_SOFIE_ROperator_BasicBinary

#include "TMVA/SOFIE_common.hxx"
#include "TMVA/ROperator.hxx"
#include "TMVA/RModel.hxx"

#include <algorithm>
#include <cmath>
#include <iostream>
#include <memory>
#include <sstream>

namespace TMVA {
namespace Experimental {
namespace SOFIE {

enum EBasicBinaryOperator {
   Add,
   Sub,
   Mul,
   Div,
   Pow
};

template <typename T, EBasicBinaryOperator Op1>
struct BinaryOperatorTrait {};

template <typename T>
struct BinaryOperatorTrait<T, Add> {
   static std::string Name() { return "Add"; }
   static std::string Op(const std::string &t1, const std::string &t2) { return t1 + " + " + t2; }
   static T Func(T t1, T t2) { return t1 + t2; }
};

template <typename T>
struct BinaryOperatorTrait<T, Sub> {
   static std::string Name() { return "Sub"; }
   static std::string Op(const std::string &t1, const std::string &t2) { return t1 + " - " + t2; }
   static T Func(T t1, T t2) { return t1 - t2; }
};

template <typename T>
struct BinaryOperatorTrait<T, Mul> {
   static std::string Name() { return "Mul"; }
   static std::string Op(const std::string &t1, const std::string &t2) { return t1 + " * " + t2; }
   static T Func(T t1, T t2) { return t1 * t2; }
};

template <typename T>
struct BinaryOperatorTrait<T, Div> {
   static std::string Name() { return "Div"; }
   static std::string Op(const std::string &t1, const std::string &t2) { return t1 + " / " + t2; }
   static T Func(T t1, T t2) { return t1 / t2; }
};

template <typename T>
struct BinaryOperatorTrait<T, Pow> {
   static std::string Name() { return "Pow"; }
   static std::string Op(const std::string &t1, const std::string &t2) { return "std::pow(" + t1 + "," + t2 + ")"; }
   static T Func(T t1, T t2) { return std::pow(t1, t2); }
};
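
// Illustrative example (added, not part of the original header): for T = float
// and Op = Add, the trait drives both code generation and constant folding:
//    BinaryOperatorTrait<float, Add>::Op("tensor_A[i]", "tensor_B[i]")
//       returns the string "tensor_A[i] + tensor_B[i]"  (used by Generate),
//    BinaryOperatorTrait<float, Add>::Func(1.f, 2.f)
//       returns 3.f  (used to fold constant tensors in Initialize).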

template <typename T, EBasicBinaryOperator Op>
class ROperator_BasicBinary final : public ROperator {
private:
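   // Broadcast flag returned by UTILITY::MultidirectionalBroadcastShape.
   // Bit meanings as inferred from the usage below (an assumption, not taken
   // from the original sources): 1 = B must be broadcast to Y, 2 = A must be
   // broadcast to Y, 4 = shapes are parametric and must be checked at run time.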
   int fBroadcastFlag = 0;
   std::string fNA;
   std::string fNB;
   std::string fNBroadcastedA;
   std::string fNBroadcastedB;
   std::string fNY;

   std::vector<size_t> fShapeA;
   std::vector<size_t> fShapeB;
   std::vector<size_t> fShapeY;

   std::vector<Dim> fDimShapeA;
   std::vector<Dim> fDimShapeB;
   std::vector<Dim> fDimShapeY;

public:
   ROperator_BasicBinary() {}
   ROperator_BasicBinary(std::string nameA, std::string nameB, std::string nameY)
      : fNA(UTILITY::Clean_name(nameA)), fNB(UTILITY::Clean_name(nameB)), fNY(UTILITY::Clean_name(nameY))
   {
      fInputTensorNames = {fNA, fNB};
      fOutputTensorNames = {fNY};
   }

   // type of output given input
   std::vector<ETensorType> TypeInference(std::vector<ETensorType> input) override { return input; }

   // shape of output tensors given input tensors
   std::vector<std::vector<size_t>> ShapeInference(std::vector<std::vector<size_t>> input) override
   {
      // for now, assume the inputs have the same shape (no broadcasting here)
      auto ret = std::vector<std::vector<size_t>>(1, input[0]); // return a vector of size 1 holding the first input shape
      return ret;
   }
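   // Note (added): actual broadcasting between A and B is resolved in
   // Initialize() below via UTILITY::MultidirectionalBroadcastShape;
   // ShapeInference above simply propagates the first input shape.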

   void Initialize(RModel &model) override
   {
      // input must be a graph input, or an already initialized intermediate tensor
      if (!model.CheckIfTensorAlreadyExist(fNA)) {
         throw std::runtime_error(std::string("TMVA SOFIE Binary Op Input Tensor ") + fNA + " is not found in model");
      }
      if (!model.CheckIfTensorAlreadyExist(fNB)) {
         throw std::runtime_error(std::string("TMVA SOFIE Binary Op Input Tensor ") + fNB + " is not found in model");
      }
      int dynamicInputs = 0;
      if (model.IsDynamicTensor(fNA)) {
         fDimShapeA = model.GetDynamicTensorShape(fNA);
         dynamicInputs |= 1;
      } else {
         fShapeA = model.GetTensorShape(fNA);
         fDimShapeA = ConvertShapeToDim(fShapeA);
      }
      if (model.IsDynamicTensor(fNB)) {
         dynamicInputs |= 2;
         fDimShapeB = model.GetDynamicTensorShape(fNB);
      } else {
         fShapeB = model.GetTensorShape(fNB);
         fDimShapeB = ConvertShapeToDim(fShapeB);
      }
      if ((dynamicInputs & 1) && model.Verbose())
         std::cout << BinaryOperatorTrait<T, Op>::Name() << " : input " << fNA << " is dynamic "
                   << ConvertShapeToString(fDimShapeA) << "  ";
      if ((dynamicInputs & 2) && model.Verbose())
         std::cout << BinaryOperatorTrait<T, Op>::Name() << " : input " << fNB << " is dynamic "
                   << ConvertShapeToString(fDimShapeB) << "  ";
      if (dynamicInputs != 0 && model.Verbose())
         std::cout << std::endl;
      // check at initialization time whether we need to broadcast, when the shapes are known and differ
      // (we could broadcast the tensor to the maximum values of the dynamic shapes - to be done)
      // case of known shapes:
      // if the shapes are known, find the output shape from broadcasting
      if (dynamicInputs == 0) {
         auto ret = UTILITY::MultidirectionalBroadcastShape(fShapeA, fShapeB);
         fBroadcastFlag = ret.first;
         fShapeY = ret.second;
         if (model.IsConstantTensor(fNA) && model.IsConstantTensor(fNB)) {
            bool broadcast = fBroadcastFlag > 0;
            if (broadcast) {
               // Y is the common shape of A and B
               bool broadcastA = fBroadcastFlag & 2;
               bool broadcastB = fBroadcastFlag & 1;
               // Broadcast A to Y
               if (broadcastA) {
                  fNBroadcastedA = "Broadcasted" + fNA + "to" + fNY;
                  auto data = model.GetInitializedTensorData(fNA);
                  std::shared_ptr<void> broadcastedData(
                     UTILITY::UnidirectionalBroadcast<T>(static_cast<T *>(data.get()), fShapeA, fShapeY),
                     std::default_delete<T[]>());
                  if (model.Verbose())
                     std::cout << "broadcasted data A " << ConvertShapeToString(fShapeY) << " : "
                               << ConvertValuesToString(ConvertShapeToLength(fShapeY),
                                                        static_cast<T *>(broadcastedData.get()))
                               << std::endl;
                  // Update the data and the shape of A
                  model.AddConstantTensor(fNBroadcastedA, model.GetTensorType(fNA), fShapeY, broadcastedData);
                  fShapeA = fShapeY;
                  fDimShapeA = ConvertShapeToDim(fShapeA);
               }
               // Broadcast B to Y
               if (broadcastB) {
                  fNBroadcastedB = "Broadcasted" + fNB + "to" + fNY;
                  auto data = model.GetInitializedTensorData(fNB);
                  if (model.Verbose())
                     std::cout << "data B " << ConvertShapeToString(fShapeB) << " : "
                               << ConvertValuesToString(ConvertShapeToLength(fShapeB), static_cast<T *>(data.get()))
                               << std::endl;
                  std::shared_ptr<void> broadcastedData(
                     UTILITY::UnidirectionalBroadcast<T>(static_cast<T *>(data.get()), fShapeB, fShapeY),
                     std::default_delete<T[]>());
                  // do not update tensor B, but add the broadcasted one (since B can be an input to other operators)
                  if (model.Verbose())
                     std::cout << "broadcasted data B " << ConvertShapeToString(fShapeY) << " : "
                               << ConvertValuesToString(ConvertShapeToLength(fShapeY),
                                                        static_cast<T *>(broadcastedData.get()))
                               << std::endl;
                  model.AddConstantTensor(fNBroadcastedB, model.GetTensorType(fNB), fShapeY, broadcastedData);
                  fShapeB = fShapeY;
                  fDimShapeB = ConvertShapeToDim(fShapeB);
               }
            } else {
               fShapeY = fShapeA;
            }
            // tensors are constant: perform the binary operation here

            const std::string &nameA = fNBroadcastedA.empty() ? fNA : fNBroadcastedA;
            const std::string &nameB = fNBroadcastedB.empty() ? fNB : fNBroadcastedB;
            auto dataA = static_cast<T *>(model.GetInitializedTensorData(nameA).get());
            auto dataB = static_cast<T *>(model.GetInitializedTensorData(nameB).get());
            std::vector<T> dataY(ConvertShapeToLength(fShapeY));
            for (size_t i = 0; i < dataY.size(); i++) {
               dataY[i] = BinaryOperatorTrait<T, Op>::Func(dataA[i], dataB[i]);
            }
            model.AddConstantTensor<T>(fNY, fShapeY, dataY.data());
            // flag the tensors so they are not written to the weight file
            model.SetNotWritableInitializedTensor(nameA);
            model.SetNotWritableInitializedTensor(nameB);
            fIsOutputConstant = true;
            if (model.Verbose()) {
               std::cout << BinaryOperatorTrait<T, Op>::Name() << " : " << fNA << "  " << ConvertShapeToString(fShapeA)
                         << " , " << fNB << "  " << ConvertShapeToString(fShapeB) << " ---> " << fNY << "  "
                         << ConvertShapeToString(fShapeY) << " : " << ConvertValuesToString(dataY) << std::endl;
            }
         } else {
            // case of defined but non-constant tensors
            model.AddIntermediateTensor(fNY, model.GetTensorType(fNA), fShapeY);
            if (model.Verbose()) {
               std::cout << BinaryOperatorTrait<T, Op>::Name() << " : " << fNA << "  " << ConvertShapeToString(fShapeA)
                         << " , " << fNB << "  " << ConvertShapeToString(fShapeB) << " ---> " << fNY << "  "
                         << ConvertShapeToString(fShapeY) << std::endl;
            }
            // convert the non-dim shape to a Dim shape
            fDimShapeY = ConvertShapeToDim(fShapeY);
         }
      } else {
         // case A or B has a dynamic shape: we need to broadcast if the shapes are not the same
         auto ret = UTILITY::MultidirectionalBroadcastShape(fDimShapeA, fDimShapeB);
         fBroadcastFlag = ret.first;
         fDimShapeY = ret.second;
         // when all shapes are parametric, MultidirectionalBroadcastShape returns the max of the two;
         // this must be resolved before we declare the output tensor shape and the broadcasted ones
         if (ret.first & 4) {
            // check if one of the parameters is an input dimension
            // define a function to find this
            auto IsInputDimParam = [&](const std::string &p) {
               auto inputNames = model.GetInputTensorNames();
               for (auto &input : inputNames) {
                  for (auto &i_s : model.GetDimTensorShape(input)) {
                     if (i_s.isParam && i_s.param == p)
                        return true;
                  }
               }
               return false;
            };
            for (size_t i = 0; i < fDimShapeY.size(); i++) {
               auto &s = fDimShapeY[i];
               if (s.isParam && s.param.find("std::max") != std::string::npos) {
                  if (IsInputDimParam(fDimShapeA[i].param)) {
                     // if the dim is 1, the input parameter is assumed to be equal to 1
                     if (fDimShapeA[i].dim != 1)
                        s = fDimShapeA[i];
                     else
                        s = fDimShapeB[i];
                  } else if (IsInputDimParam(fDimShapeB[i].param)) {
                     if (fDimShapeB[i].dim != 1)
                        s = fDimShapeB[i];
                     else
                        s = fDimShapeA[i];
                  }
               }
            }
         }
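
         // Illustrative example of the simplification above (added; assumed
         // Dim semantics): if fDimShapeA[i] is the input parameter "bs"
         // (dim != 1) and fDimShapeB[i] is 1, the broadcast result
         // "std::max(bs,1)" collapses to just "bs".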

         model.AddIntermediateTensor(fNY, model.GetTensorType(fNA), fDimShapeY);
         if (model.Verbose()) {
            std::cout << BinaryOperatorTrait<T, Op>::Name() << " : " << ConvertShapeToString(fDimShapeA) << " , "
                      << ConvertShapeToString(fDimShapeB) << " --> " << ConvertShapeToString(fDimShapeY) << std::endl;
         }
      }
   }

   std::string GenerateInitCode() override
   {
      std::stringstream out;
      return out.str();
   }

   std::string Generate(std::string opName) override
   {

      if (fIsOutputConstant)
         return "";

      opName = "op_" + opName;

      if (fDimShapeY.empty()) {
         throw std::runtime_error("TMVA SOFIE Binary Op called to Generate without being initialized first");
      }
      std::stringstream out;
      out << SP << "\n//------ " << opName << "  " << BinaryOperatorTrait<T, Op>::Name() << " --> "
          << ConvertDimShapeToString(fDimShapeY) << "\n";
      auto length = ConvertDimShapeToLength(fDimShapeY);
      std::string typeName = TensorType<T>::Name();

      // we need to check whether we can broadcast at run time (case where the flag has bit 4 set)

      if (fBroadcastFlag & 4) {
         // need to check if the shapes are the same
         auto lengthA = ConvertDimShapeToLength(fDimShapeA);
         auto lengthB = ConvertDimShapeToLength(fDimShapeB);
         out << SP << "if (" << lengthA << " != " << lengthB << ") {\n";
         // check if A->B or B->A
         for (size_t i = 0; i < fDimShapeY.size(); i++) {
            if ((fBroadcastFlag & 5) && fDimShapeY[i] == fDimShapeA[i] && fDimShapeA[i].dim > 1 &&
                fDimShapeB[i].isParam) {
               // B -> A : B[i] needs to be 1
               out << SP << SP << "if (" << fDimShapeB[i] << " != 1)\n";
               out << SP << SP << SP << "throw std::runtime_error(\"SOFIE - Cannot broadcast B->A in operator "
                   << opName << "\");\n";
            }
            if ((fBroadcastFlag & 6) && fDimShapeY[i] == fDimShapeB[i] && fDimShapeB[i].dim > 1 &&
                fDimShapeA[i].isParam) {
               // A -> B : A[i] needs to be 1
               out << SP << SP << "if (" << fDimShapeA[i] << " != 1)\n";
               out << SP << SP << SP << "throw std::runtime_error(\"SOFIE - Cannot broadcast A->B in operator "
                   << opName << "\");\n";
            } else if (fDimShapeA[i].isParam && fDimShapeB[i].isParam) {
               // both dims are parametric: emit a run-time check that they are compatible
               out << SP << SP << "if (" << fDimShapeA[i] << " != " << fDimShapeB[i] << " && (" << fDimShapeA[i]
                   << " != 1 || " << fDimShapeB[i] << " != 1))\n";
               out << SP << SP << SP << "throw std::runtime_error(\"SOFIE - Cannot broadcast shapes in operator " << opName
                   << "\");\n";
            }
         }
         out << SP << "}\n";
      }
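
      // Schematic form of the emitted guard (added; "bs" is a hypothetical
      // runtime dimension parameter and op_1 a hypothetical operator name):
      //    if (3 * 4 != 3 * bs) {
      //       if (bs != 1)
      //          throw std::runtime_error("SOFIE - Cannot broadcast B->A in operator op_1");
      //    }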

      auto stridesA = UTILITY::ComputeStrideFromShape(fDimShapeA);
      auto stridesB = UTILITY::ComputeStrideFromShape(fDimShapeB);
      auto stridesY = UTILITY::ComputeStrideFromShape(fDimShapeY);

      std::string compute_idx_A, compute_idx_B, compute_idx_Y;
      if (fDimShapeA.empty() ||
          std::all_of(fDimShapeA.begin(), fDimShapeA.end(), [](const Dim &d) { return d.dim == 1 || d.GetVal() == "1"; })) {
         compute_idx_A = "0";
      } else {
         for (size_t i = 0; i < fDimShapeA.size(); ++i) {
            if (fDimShapeA[i].dim == 1 || fDimShapeA[i].GetVal() == "1")
               continue;
            compute_idx_A += "idx_" + std::to_string(i + (fDimShapeY.size() - fDimShapeA.size()));
            if (stridesA[i].GetVal() != "1")
               compute_idx_A += " * " + stridesA[i].GetVal();
            compute_idx_A += " + ";
         }
         // remove the last 3 characters " + "
         for (int j = 0; j < 3; j++)
            compute_idx_A.pop_back();
      }
      if (fDimShapeB.empty() ||
          std::all_of(fDimShapeB.begin(), fDimShapeB.end(), [](const Dim &d) { return d.dim == 1 || d.GetVal() == "1"; })) {
         compute_idx_B = "0";
      } else {
         for (size_t i = 0; i < fDimShapeB.size(); ++i) {
            if (fDimShapeB[i].dim == 1 || fDimShapeB[i].GetVal() == "1")
               continue;
            compute_idx_B += "idx_" + std::to_string(i + (fDimShapeY.size() - fDimShapeB.size()));
            if (stridesB[i].GetVal() != "1")
               compute_idx_B += " * " + stridesB[i].GetVal();
            compute_idx_B += " + ";
         }
         // remove the last 3 characters " + "
         for (int j = 0; j < 3; j++)
            compute_idx_B.pop_back();
      }
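
      // Illustrative result of the index construction above (added; assumed
      // static shapes): for A of shape {3,1}, B of shape {1,4} and Y of shape
      // {3,4}, the size-1 axes are skipped, giving
      //    compute_idx_A = "idx_0"   (the stride of 1 is omitted)
      //    compute_idx_B = "idx_1"
      // while the Y index built below becomes "idx_0 * 4 + idx_1".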
      int nloop = 0;
      if (fDimShapeY.empty() ||
          std::all_of(fDimShapeY.begin(), fDimShapeY.end(), [](const Dim &d) { return d.dim == 1 || d.GetVal() == "1"; })) {
         compute_idx_Y = "0";
      } else {
         for (size_t i = 0; i < fDimShapeY.size(); ++i) {
            if (fDimShapeY[i].dim != 1 && fDimShapeY[i].GetVal() != "1") {
               nloop++;
               for (int j = 0; j < nloop; j++) out << SP;
               out << "for (size_t idx_" << i << " = 0; idx_" << i << " < " << fDimShapeY[i]
                   << "; ++idx_" << i << "){\n";
               compute_idx_Y += "idx_" + std::to_string(i);
               if (stridesY[i].GetVal() != "1")
                  compute_idx_Y += " * " + stridesY[i].GetVal();
               compute_idx_Y += " + ";
            }
         }
         // remove the last 3 characters " + "
         for (int j = 0; j < 3; j++)
            compute_idx_Y.pop_back();
      }
      for (int j = 0; j < nloop + 1; j++) out << SP;
      out << "tensor_" << fNY << "[" << compute_idx_Y << "] = "
          << BinaryOperatorTrait<T, Op>::Op("tensor_" + fNA + "[" + compute_idx_A + "]",
                                            "tensor_" + fNB + "[" + compute_idx_B + "]")
          << " ;\n";

      for (int i = nloop; i > 0; i--) {
         for (int j = 0; j < i; j++) out << SP;
         out << "}\n";
      }
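
      // Illustrative output of Generate (added; for Add with the shapes from
      // the example above, A {3,1}, B {1,4}, Y {3,4}, and output tensor "Y"):
      //    for (size_t idx_0 = 0; idx_0 < 3; ++idx_0){
      //       for (size_t idx_1 = 0; idx_1 < 4; ++idx_1){
      //          tensor_Y[idx_0 * 4 + idx_1] = tensor_A[idx_0] + tensor_B[idx_1] ;
      //       }
      //    }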
      return out.str();
   }

   std::vector<std::string> GetStdLibs() override
   {
      if (Op == EBasicBinaryOperator::Pow) {
         return {std::string("cmath")};
      } else {
         return {};
      }
   }
};
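
// Minimal usage sketch (added; assumes only the SOFIE RModel/ROperator API
// used above, with "A", "B" and "Y" as hypothetical tensor names):
//
//    std::unique_ptr<ROperator> op =
//       std::make_unique<ROperator_BasicBinary<float, Add>>("A", "B", "Y");
//    op->Initialize(model);          // validates inputs, folds constants,
//                                    // registers the output tensor Y
//    std::string code = op->Generate("1");   // emits the loop nest for op_1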

} // namespace SOFIE
} // namespace Experimental
} // namespace TMVA

#endif // TMVA_SOFIE_ROperator_BasicBinary