Back to home page

EIC code displayed by LXR



File indexing completed on 2025-01-18 10:11:05

0001 #ifndef TMVA_SOFIE_ROperator_BasicBinary
0002 #define TMVA_SOFIE_ROperator_BasicBinary
0004 #include "TMVA/SOFIE_common.hxx"
0005 #include "TMVA/ROperator.hxx"
0006 #include "TMVA/RModel.hxx"
0008 #include <sstream>
0010 namespace TMVA{
0011 namespace Experimental{
0012 namespace SOFIE{
/// Element-wise binary operations supported by ROperator_BasicBinary.
enum EBasicBinaryOperator { Add, Sub, Mul, Div, Pow };

/// Compile-time traits mapping an EBasicBinaryOperator to
///  - Name(): the operator name written into the generated code's comment header, and
///  - Op(t1, t2): the C++ expression (as source text) combining the operand expressions t1 and t2.
/// The primary template is intentionally empty; only the specializations below are usable.
/// The element type T is not used by the traits themselves (presumably kept so the traits are
/// parameterized like ROperator_BasicBinary<T, Op>).
template <typename T, EBasicBinaryOperator Op1>
struct BinaryOperatorTrait {};

template <typename T>
struct BinaryOperatorTrait<T, Add> {
   static std::string Name() { return "Add"; }
   /// @return source text "t1 + t2"
   static std::string Op(const std::string & t1, const std::string & t2) { return t1 + " + " + t2; }
};

template <typename T>
struct BinaryOperatorTrait<T, Sub> {
   static std::string Name() { return "Sub"; }
   /// @return source text "t1 - t2"
   static std::string Op(const std::string & t1, const std::string & t2) { return t1 + " - " + t2; }
};

template <typename T>
struct BinaryOperatorTrait<T, Mul> {
   static std::string Name() { return "Mul"; }
   /// @return source text "t1 * t2"
   static std::string Op(const std::string & t1, const std::string & t2) { return t1 + " * " + t2; }
};

template <typename T>
struct BinaryOperatorTrait<T, Div> {
   static std::string Name() { return "Div"; }
   /// @return source text "t1 / t2"
   static std::string Op(const std::string & t1, const std::string & t2) { return t1 + " / " + t2; }
};

template <typename T>
struct BinaryOperatorTrait<T, Pow> {
   static std::string Name() { return "Pow"; }
   /// @return source text "std::pow(t1,t2)"; the generated code therefore needs <cmath>
   static std::string Op(const std::string & t1, const std::string & t2) { return "std::pow(" + t1 + "," + t2 + ")"; }
};
0049 template<typename T, EBasicBinaryOperator Op>
0050 class ROperator_BasicBinary final : public ROperator{
0051 private:
0053    std::string fNA;
0054    std::string fNB;
0055    std::string fNBroadcadstedA;
0056    std::string fNBroadcadstedB;
0057    std::string fNY;
0059    std::vector<size_t> fShapeA;
0060    std::vector<size_t> fShapeB;
0061    std::vector<size_t> fShapeY;
0063 public:
0064    ROperator_BasicBinary(){}
0065    ROperator_BasicBinary(std::string nameA, std::string nameB, std::string nameY):
0066       fNA(UTILITY::Clean_name(nameA)), fNB(UTILITY::Clean_name(nameB)), fNY(UTILITY::Clean_name(nameY)){}
0068    // type of output given input
0069    std::vector<ETensorType> TypeInference(std::vector<ETensorType> input) override {
0070       return input;
0071    }
0073    // shape of output tensors given input tensors
0074    std::vector<std::vector<size_t>> ShapeInference(std::vector<std::vector<size_t>> input) override {
0075       // assume now inputs have same shape (no broadcasting)
0076       auto ret = std::vector<std::vector<size_t>>(1, input[0]); // return vector size 1 with first input
0077       return ret;
0078    }
0080    void Initialize(RModel& model) override {
0081       // input must be a graph input, or already initialized intermediate tensor
0082       if (!model.CheckIfTensorAlreadyExist(fNA)){
0083          throw std::runtime_error(std::string("TMVA SOFIE Binary Op Input Tensor ") + fNA + "is not found in model");
0084       }
0085       if (!model.CheckIfTensorAlreadyExist(fNB)) {
0086          throw std::runtime_error(std::string("TMVA SOFIE Binary Op Input Tensor ") + fNB + "is not found in model");
0087       }
0088       fShapeA = model.GetTensorShape(fNA);
0089       fShapeB = model.GetTensorShape(fNB);
0090       bool broadcast = !UTILITY::AreSameShape(fShapeA, fShapeB);
0091       if (broadcast) {
0092          // Y is the common shape of A and B
0093          fShapeY = UTILITY::UnidirectionalBroadcastShape(fShapeA, fShapeB);
0094          bool broadcastA = !UTILITY::AreSameShape(fShapeA, fShapeY);
0095          bool broadcastB = !UTILITY::AreSameShape(fShapeB, fShapeY);
0096          // Broadcast A to Y
0097          if (broadcastA) {
0098             if (model.IsInitializedTensor(fNA)) {
0099                auto data = model.GetInitializedTensorData(fNA);
0100                std::shared_ptr<void> broadcastedData(
0101                   UTILITY::UnidirectionalBroadcast<float>(static_cast<float *>(data.get()), fShapeA, fShapeY),
0102                   std::default_delete<float[]>());
0103                // Update the data and the shape of A
0104                model.UpdateInitializedTensor(fNA, model.GetTensorType(fNA), fShapeY, broadcastedData);
0105                fShapeA = fShapeY;
0106             } else {
0107                // Add an intermediate tensor for broadcasting A
0108                fNBroadcadstedA = "Broadcasted" + fNA;
0109                model.AddIntermediateTensor(fNBroadcadstedA, model.GetTensorType(fNA), fShapeY);
0110             }
0111          }
0112          // Broadcast B to Y
0113          if (broadcastB) {
0114             if (model.IsInitializedTensor(fNB)) {
0115                auto data = model.GetInitializedTensorData(fNB);
0116                std::shared_ptr<void> broadcastedData(
0117                   UTILITY::UnidirectionalBroadcast<float>(static_cast<float *>(data.get()), fShapeB, fShapeY),
0118                   std::default_delete<float[]>());
0119                // Update the data and the shape of B
0120                model.UpdateInitializedTensor(fNB, model.GetTensorType(fNB), fShapeY, broadcastedData);
0121                fShapeB = fShapeY;
0122             } else {
0123                // Add an intermediate tensor for broadcasting B
0124                fNBroadcadstedB = "Broadcasted" + fNB;
0125                model.AddIntermediateTensor(fNBroadcadstedB, model.GetTensorType(fNB), fShapeY);
0126             }
0127          }
0128       } else {
0129          fShapeY = fShapeA;
0130       }
0131       model.AddIntermediateTensor(fNY, model.GetTensorType(fNA), fShapeY);
0132    }
0134    std::string GenerateInitCode() override {
0135       std::stringstream out;
0136       return out.str();
0137    }
0139    std::string Generate(std::string OpName) override {
0140       OpName = "op_" + OpName;
0142       if (fShapeY.empty()) {
0143          throw std::runtime_error("TMVA SOFIE Binary Op called to Generate without being initialized first");
0144       }
0145       std::stringstream out;
0146       out << SP << "\n//------ " << BinaryOperatorTrait<T,Op>::Name() << "\n";
0147       size_t length = ConvertShapeToLength(fShapeY);
0148       // Broadcast A if it's uninitialized
0149       if (!fNBroadcadstedA.empty()) {
0150          out << SP << "// Broadcasting uninitialized tensor " << fNA << "\n";
0151          out << SP << "{\n";
0152          out << SP << SP << "float* data = TMVA::Experimental::SOFIE::UTILITY::UnidirectionalBroadcast<float>(tensor_" << fNA << ", " << ConvertShapeToString(fShapeA) << ", " << ConvertShapeToString(fShapeY) << ");\n";
0153          out << SP << SP << "std::copy(data, data + " << length << ", tensor_" << fNBroadcadstedA << ");\n";
0154          out << SP << SP << "delete[] data;\n";
0155          out << SP << "}\n";
0156       }
0157       // Broadcast B if it's uninitialized
0158       if (!fNBroadcadstedB.empty()) {
0159          out << SP << "// Broadcasting uninitialized tensor " << fNB << "\n";
0160          out << SP << "{\n";
0161          out << SP << SP << "float* data = TMVA::Experimental::SOFIE::UTILITY::UnidirectionalBroadcast<float>(tensor_" << fNB << ", " << ConvertShapeToString(fShapeB) << ", " << ConvertShapeToString(fShapeY) << ");\n";
0162          out << SP << SP << "std::copy(data, data + " << length << ", tensor_" << fNBroadcadstedB << ");\n";
0163          out << SP << SP << "delete[] data;\n";
0164          out << SP << "}\n";
0165       }
0166       const std::string& nameA = fNBroadcadstedA.empty()? fNA : fNBroadcadstedA;
0167       const std::string& nameB = fNBroadcadstedB.empty()? fNB : fNBroadcadstedB;
0168       out << SP << "for (size_t id = 0; id < " << length << " ; id++){\n";
0169       out << SP << SP << "tensor_" << fNY << "[id] = "  << BinaryOperatorTrait<T,Op>::Op( "tensor_" + nameA + "[id]" , "tensor_" + nameB + "[id]") <<  " ;\n";
0170       out << SP << "}\n";
0171       return out.str();
0172    }
0174    std::vector<std::string> GetStdLibs() override {
0175       if (Op == EBasicBinaryOperator::Pow) {
0176          return { std::string("cmath") };
0177       } else {
0178          return {};
0179       }
0180    }
0181 };
0183 }//SOFIE
0184 }//Experimental
0185 }//TMVA
0188 #endif //TMVA_SOFIE_ROperator_BasicBinary