EIC code displayed by LXR

#ifndef TMVA_SOFIE_ROperator_BasicBinary
#define TMVA_SOFIE_ROperator_BasicBinary

#include "TMVA/SOFIE_common.hxx"
#include "TMVA/ROperator.hxx"
#include "TMVA/RModel.hxx"

#include <sstream>
#include <cmath>   // std::pow used by the Pow trait

namespace TMVA{
namespace Experimental{
namespace SOFIE{

enum EBasicBinaryOperator { Add, Sub, Mul, Div, Pow };

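// Trait describing each binary operator: its ONNX-style name, the C++ expression
// emitted into the generated inference code, and a direct evaluation used when
// the result can be folded into a constant tensor.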
template <typename T, EBasicBinaryOperator Op1>
struct BinaryOperatorTrait {};

template <typename T>
struct BinaryOperatorTrait<T, Add> {
   static const std::string Name() { return "Add"; }
   static std::string Op(const std::string & t1, const std::string & t2) { return t1 + " + " + t2; }
   static T Func(T t1, T t2) { return t1 + t2; }
};

template <typename T>
struct BinaryOperatorTrait<T, Sub> {
   static const std::string Name() { return "Sub"; }
   static std::string Op(const std::string & t1, const std::string & t2) { return t1 + " - " + t2; }
   static T Func(T t1, T t2) { return t1 - t2; }
};

template <typename T>
struct BinaryOperatorTrait<T, Mul> {
   static const std::string Name() { return "Mul"; }
   static std::string Op(const std::string & t1, const std::string & t2) { return t1 + " * " + t2; }
   static T Func(T t1, T t2) { return t1 * t2; }
};

template <typename T>
struct BinaryOperatorTrait<T, Div> {
   static const std::string Name() { return "Div"; }
   static std::string Op(const std::string & t1, const std::string & t2) { return t1 + " / " + t2; }
   static T Func(T t1, T t2) { return t1 / t2; }
};

template <typename T>
struct BinaryOperatorTrait<T, Pow> {
   static const std::string Name() { return "Pow"; }
   static std::string Op(const std::string & t1, const std::string & t2) { return "std::pow(" + t1 + "," + t2 + ")"; }
   static T Func(T t1, T t2) { return std::pow(t1, t2); }
};

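// SOFIE operator implementing the element-wise binary operations Add, Sub, Mul,
// Div and Pow, with unidirectional (ONNX-style) broadcasting of the two inputs
// to a common output shape.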
template<typename T, EBasicBinaryOperator Op>
class ROperator_BasicBinary final : public ROperator{
private:

   std::string fNA;            // name of the first input tensor
   std::string fNB;            // name of the second input tensor
   std::string fNBroadcastedA; // name of the broadcasted copy of A (empty if not needed)
   std::string fNBroadcastedB; // name of the broadcasted copy of B (empty if not needed)
   std::string fNY;            // name of the output tensor

   std::vector<size_t> fShapeA;
   std::vector<size_t> fShapeB;
   std::vector<size_t> fShapeY;

public:
   ROperator_BasicBinary(){}
   ROperator_BasicBinary(std::string nameA, std::string nameB, std::string nameY):
      fNA(UTILITY::Clean_name(nameA)), fNB(UTILITY::Clean_name(nameB)), fNY(UTILITY::Clean_name(nameY)){
         fInputTensorNames = { fNA, fNB };
         fOutputTensorNames = { fNY };
      }

   // type of output given input
   std::vector<ETensorType> TypeInference(std::vector<ETensorType> input) override {
      return input;
   }

   // shape of output tensors given input tensors
   std::vector<std::vector<size_t>> ShapeInference(std::vector<std::vector<size_t>> input) override {
      // assume for now that the inputs have the same shape (no broadcasting)
      auto ret = std::vector<std::vector<size_t>>(1, input[0]); // return a vector of size 1 with the first input shape
      return ret;
   }

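   // Resolve the input shapes, broadcast them to a common output shape when they
   // differ, and fold the operation into a constant output tensor if both inputs
   // are already initialized.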
   void Initialize(RModel& model) override {
      // input must be a graph input, or already initialized intermediate tensor
      if (!model.CheckIfTensorAlreadyExist(fNA)){
         throw std::runtime_error(std::string("TMVA SOFIE Binary Op Input Tensor ") + fNA + " is not found in model");
      }
      if (!model.CheckIfTensorAlreadyExist(fNB)) {
         throw std::runtime_error(std::string("TMVA SOFIE Binary Op Input Tensor ") + fNB + " is not found in model");
      }
      fShapeA = model.GetTensorShape(fNA);
      fShapeB = model.GetTensorShape(fNB);
      bool broadcast = !UTILITY::AreSameShape(fShapeA, fShapeB);
      if (broadcast) {
         // Y is the common shape of A and B
         fShapeY = UTILITY::UnidirectionalBroadcastShape(fShapeA, fShapeB);
         bool broadcastA = !UTILITY::AreSameShape(fShapeA, fShapeY);
         bool broadcastB = !UTILITY::AreSameShape(fShapeB, fShapeY);
         // Broadcast A to Y
         if (broadcastA) {
            fNBroadcastedA = "Broadcasted" + fNA + "to" + fNY;
            if (model.IsInitializedTensor(fNA)) {
               auto data = model.GetInitializedTensorData(fNA);
               std::shared_ptr<void> broadcastedData(
                  UTILITY::UnidirectionalBroadcast<T>(static_cast<T *>(data.get()), fShapeA, fShapeY),
                  std::default_delete<T[]>());
               // Update the data and the shape of A
               model.AddConstantTensor(fNBroadcastedA, model.GetTensorType(fNA), fShapeY, broadcastedData);
               fShapeA = fShapeY;
            } else {
               // Add an intermediate tensor for broadcasting A
               model.AddIntermediateTensor(fNBroadcastedA, model.GetTensorType(fNA), fShapeY);
            }
         }
         // Broadcast B to Y
         if (broadcastB) {
            fNBroadcastedB = "Broadcasted" + fNB + "to" + fNY;
            if (model.IsInitializedTensor(fNB)) {
               auto data = model.GetInitializedTensorData(fNB);
               std::cout << "data B " << ConvertShapeToString(fShapeB) << " : " <<
                  ConvertValuesToString(ConvertShapeToLength(fShapeB), static_cast<T*>(data.get())) << std::endl;
               std::shared_ptr<void> broadcastedData(
                  UTILITY::UnidirectionalBroadcast<T>(static_cast<T *>(data.get()), fShapeB, fShapeY),
                  std::default_delete<T[]>());
               // do not update tensor B but add the broadcasted one (since B can be an input to some other operator)
               std::cout << "broadcasted data B " << ConvertShapeToString(fShapeY) << " : " <<
                  ConvertValuesToString(ConvertShapeToLength(fShapeY), static_cast<T*>(broadcastedData.get())) << std::endl;
               model.AddConstantTensor(fNBroadcastedB, model.GetTensorType(fNB), fShapeY, broadcastedData);
               fShapeB = fShapeY;
            } else {
               // Add an intermediate tensor for broadcasting B
               model.AddIntermediateTensor(fNBroadcastedB, model.GetTensorType(fNB), fShapeY);
            }
         }
      } else {
         fShapeY = fShapeA;
      }
      // check the case of a constant output (if all inputs are defined)
      if (model.IsInitializedTensor(fNA) && model.IsInitializedTensor(fNB)) {
         const std::string& nameA = fNBroadcastedA.empty()? fNA : fNBroadcastedA;
         const std::string& nameB = fNBroadcastedB.empty()? fNB : fNBroadcastedB;
         auto dataA = static_cast<T *>(model.GetInitializedTensorData(nameA).get());
         auto dataB = static_cast<T *>(model.GetInitializedTensorData(nameB).get());
         std::vector<T> dataY(ConvertShapeToLength(fShapeY));
         for (size_t i = 0; i < dataY.size(); i++) {
            dataY[i] = BinaryOperatorTrait<T,Op>::Func(dataA[i], dataB[i]);
         }
         model.AddConstantTensor<T>(fNY, fShapeY, dataY.data());
         // flag the input tensors so they are not written to a weight file
         model.SetNotWritableInitializedTensor(nameA);
         model.SetNotWritableInitializedTensor(nameB);
         fIsOutputConstant = true;
         if (model.Verbose())
            std::cout << "Binary op ---> " << fNY << "  " << ConvertShapeToString(fShapeY) << " : "
               << ConvertValuesToString(dataY) << std::endl;
      }
      else {
         model.AddIntermediateTensor(fNY, model.GetTensorType(fNA), fShapeY);
      }
   }

   std::string GenerateInitCode() override {
      std::stringstream out;
      return out.str();
   }

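   // Emit the inference code for this operator: broadcast any non-constant input
   // to the output shape at run time, then apply the operation element by element.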
   std::string Generate(std::string OpName) override {

      if (fIsOutputConstant) return "";

      OpName = "op_" + OpName;

      if (fShapeY.empty()) {
         throw std::runtime_error("TMVA SOFIE Binary Op called to Generate without being initialized first");
      }
      std::stringstream out;
      out << SP << "\n//------ " << BinaryOperatorTrait<T,Op>::Name() << "\n";
      size_t length = ConvertShapeToLength(fShapeY);
      std::string typeName = TensorType<T>::Name();
      // Broadcast A if it's uninitialized
      // use broadcasting function where we pass an already allocated tensor to minimize memory allocations
      if (fShapeA != fShapeY) {
         out << SP << "// Broadcasting uninitialized tensor " << fNA << "\n";
         out << SP << "TMVA::Experimental::SOFIE::UTILITY::UnidirectionalBroadcast<" << typeName << ">(tensor_" << fNA << ", " << ConvertShapeToString(fShapeA) << ", " << ConvertShapeToString(fShapeY)
                   << ", fTensor_" << fNBroadcastedA << ");\n";
      }
      // Broadcast B if it's uninitialized
      if (fShapeB != fShapeY) {
         out << SP << "// Broadcasting uninitialized tensor " << fNB << "\n";
         out << SP << "TMVA::Experimental::SOFIE::UTILITY::UnidirectionalBroadcast<" << typeName << ">(tensor_" << fNB << ", " << ConvertShapeToString(fShapeB) << ", " << ConvertShapeToString(fShapeY)
                   << ", fTensor_" << fNBroadcastedB << ");\n";
      }
      const std::string& nameA = fNBroadcastedA.empty()? fNA : fNBroadcastedA;
      const std::string& nameB = fNBroadcastedB.empty()? fNB : fNBroadcastedB;
      out << SP << "for (size_t id = 0; id < " << length << " ; id++){\n";
      out << SP << SP << "tensor_" << fNY << "[id] = " << BinaryOperatorTrait<T,Op>::Op( "tensor_" + nameA + "[id]" , "tensor_" + nameB + "[id]") << " ;\n";
      out << SP << "}\n";
      return out.str();
   }

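   // Pow relies on std::pow, so request <cmath> in the generated code.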
   std::vector<std::string> GetStdLibs() override {
      if (Op == EBasicBinaryOperator::Pow) {
         return { std::string("cmath") };
      } else {
         return {};
      }
   }
};

}//SOFIE
}//Experimental
}//TMVA


#endif //TMVA_SOFIE_ROperator_BasicBinary
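
A minimal usage sketch (not part of the header above): how an Add operator joining two input tensors "A" and "B" into an output "Y" might be created by code that builds a SOFIE RModel. The tensor names and the AddOperator call are illustrative assumptions, simplified from the surrounding SOFIE machinery.

   // Hypothetical example, not from this file
   using namespace TMVA::Experimental::SOFIE;
   std::unique_ptr<ROperator> op;
   op.reset(new ROperator_BasicBinary<float, Add>("A", "B", "Y"));
   // model.AddOperator(std::move(op));   // assumes an RModel that already declares tensors A and B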