File indexing completed on 2025-01-18 10:11:05
0001 #ifndef TMVA_SOFIE_ROperator_BasicBinary
0002 #define TMVA_SOFIE_ROperator_BasicBinary
0003
0004 #include "TMVA/SOFIE_common.hxx"
0005 #include "TMVA/ROperator.hxx"
0006 #include "TMVA/RModel.hxx"
0007
0008 #include <sstream>
0009
0010 namespace TMVA{
0011 namespace Experimental{
0012 namespace SOFIE{
0013
0014 enum EBasicBinaryOperator { Add, Sub, Mul, Div, Pow };
0015
0016 template <typename T, EBasicBinaryOperator Op1>
0017 struct BinaryOperatorTrait {};
0018
0019 template <typename T>
0020 struct BinaryOperatorTrait<T, Add> {
0021 static const std::string Name() { return "Add"; }
0022 static std::string Op(const std::string & t1, const std::string t2) { return t1 + " + " + t2; }
0023 };
0024
0025 template <typename T>
0026 struct BinaryOperatorTrait<T, Sub> {
0027 static const std::string Name() { return "Sub"; }
0028 static std::string Op(const std::string & t1, const std::string t2) { return t1 + " - " + t2; }
0029 };
0030
0031 template <typename T>
0032 struct BinaryOperatorTrait<T, Mul> {
0033 static const std::string Name() { return "Mul"; }
0034 static std::string Op(const std::string & t1, const std::string t2) { return t1 + " * " + t2; }
0035 };
0036
0037 template <typename T>
0038 struct BinaryOperatorTrait<T, Div> {
0039 static const std::string Name() { return "Div"; }
0040 static std::string Op(const std::string & t1, const std::string t2) { return t1 + " / " + t2; }
0041 };
0042
0043 template <typename T>
0044 struct BinaryOperatorTrait<T, Pow> {
0045 static const std::string Name() { return "Pow"; }
0046 static std::string Op(const std::string & t1, const std::string t2) { return "std::pow(" + t1 + "," + t2 + ")"; }
0047 };
0048
0049 template<typename T, EBasicBinaryOperator Op>
0050 class ROperator_BasicBinary final : public ROperator{
0051 private:
0052
0053 std::string fNA;
0054 std::string fNB;
0055 std::string fNBroadcadstedA;
0056 std::string fNBroadcadstedB;
0057 std::string fNY;
0058
0059 std::vector<size_t> fShapeA;
0060 std::vector<size_t> fShapeB;
0061 std::vector<size_t> fShapeY;
0062
0063 public:
0064 ROperator_BasicBinary(){}
0065 ROperator_BasicBinary(std::string nameA, std::string nameB, std::string nameY):
0066 fNA(UTILITY::Clean_name(nameA)), fNB(UTILITY::Clean_name(nameB)), fNY(UTILITY::Clean_name(nameY)){}
0067
0068
0069 std::vector<ETensorType> TypeInference(std::vector<ETensorType> input) override {
0070 return input;
0071 }
0072
0073
0074 std::vector<std::vector<size_t>> ShapeInference(std::vector<std::vector<size_t>> input) override {
0075
0076 auto ret = std::vector<std::vector<size_t>>(1, input[0]);
0077 return ret;
0078 }
0079
0080 void Initialize(RModel& model) override {
0081
0082 if (!model.CheckIfTensorAlreadyExist(fNA)){
0083 throw std::runtime_error(std::string("TMVA SOFIE Binary Op Input Tensor ") + fNA + "is not found in model");
0084 }
0085 if (!model.CheckIfTensorAlreadyExist(fNB)) {
0086 throw std::runtime_error(std::string("TMVA SOFIE Binary Op Input Tensor ") + fNB + "is not found in model");
0087 }
0088 fShapeA = model.GetTensorShape(fNA);
0089 fShapeB = model.GetTensorShape(fNB);
0090 bool broadcast = !UTILITY::AreSameShape(fShapeA, fShapeB);
0091 if (broadcast) {
0092
0093 fShapeY = UTILITY::UnidirectionalBroadcastShape(fShapeA, fShapeB);
0094 bool broadcastA = !UTILITY::AreSameShape(fShapeA, fShapeY);
0095 bool broadcastB = !UTILITY::AreSameShape(fShapeB, fShapeY);
0096
0097 if (broadcastA) {
0098 if (model.IsInitializedTensor(fNA)) {
0099 auto data = model.GetInitializedTensorData(fNA);
0100 std::shared_ptr<void> broadcastedData(
0101 UTILITY::UnidirectionalBroadcast<float>(static_cast<float *>(data.get()), fShapeA, fShapeY),
0102 std::default_delete<float[]>());
0103
0104 model.UpdateInitializedTensor(fNA, model.GetTensorType(fNA), fShapeY, broadcastedData);
0105 fShapeA = fShapeY;
0106 } else {
0107
0108 fNBroadcadstedA = "Broadcasted" + fNA;
0109 model.AddIntermediateTensor(fNBroadcadstedA, model.GetTensorType(fNA), fShapeY);
0110 }
0111 }
0112
0113 if (broadcastB) {
0114 if (model.IsInitializedTensor(fNB)) {
0115 auto data = model.GetInitializedTensorData(fNB);
0116 std::shared_ptr<void> broadcastedData(
0117 UTILITY::UnidirectionalBroadcast<float>(static_cast<float *>(data.get()), fShapeB, fShapeY),
0118 std::default_delete<float[]>());
0119
0120 model.UpdateInitializedTensor(fNB, model.GetTensorType(fNB), fShapeY, broadcastedData);
0121 fShapeB = fShapeY;
0122 } else {
0123
0124 fNBroadcadstedB = "Broadcasted" + fNB;
0125 model.AddIntermediateTensor(fNBroadcadstedB, model.GetTensorType(fNB), fShapeY);
0126 }
0127 }
0128 } else {
0129 fShapeY = fShapeA;
0130 }
0131 model.AddIntermediateTensor(fNY, model.GetTensorType(fNA), fShapeY);
0132 }
0133
0134 std::string GenerateInitCode() override {
0135 std::stringstream out;
0136 return out.str();
0137 }
0138
0139 std::string Generate(std::string OpName) override {
0140 OpName = "op_" + OpName;
0141
0142 if (fShapeY.empty()) {
0143 throw std::runtime_error("TMVA SOFIE Binary Op called to Generate without being initialized first");
0144 }
0145 std::stringstream out;
0146 out << SP << "\n//------ " << BinaryOperatorTrait<T,Op>::Name() << "\n";
0147 size_t length = ConvertShapeToLength(fShapeY);
0148
0149 if (!fNBroadcadstedA.empty()) {
0150 out << SP << "// Broadcasting uninitialized tensor " << fNA << "\n";
0151 out << SP << "{\n";
0152 out << SP << SP << "float* data = TMVA::Experimental::SOFIE::UTILITY::UnidirectionalBroadcast<float>(tensor_" << fNA << ", " << ConvertShapeToString(fShapeA) << ", " << ConvertShapeToString(fShapeY) << ");\n";
0153 out << SP << SP << "std::copy(data, data + " << length << ", tensor_" << fNBroadcadstedA << ");\n";
0154 out << SP << SP << "delete[] data;\n";
0155 out << SP << "}\n";
0156 }
0157
0158 if (!fNBroadcadstedB.empty()) {
0159 out << SP << "// Broadcasting uninitialized tensor " << fNB << "\n";
0160 out << SP << "{\n";
0161 out << SP << SP << "float* data = TMVA::Experimental::SOFIE::UTILITY::UnidirectionalBroadcast<float>(tensor_" << fNB << ", " << ConvertShapeToString(fShapeB) << ", " << ConvertShapeToString(fShapeY) << ");\n";
0162 out << SP << SP << "std::copy(data, data + " << length << ", tensor_" << fNBroadcadstedB << ");\n";
0163 out << SP << SP << "delete[] data;\n";
0164 out << SP << "}\n";
0165 }
0166 const std::string& nameA = fNBroadcadstedA.empty()? fNA : fNBroadcadstedA;
0167 const std::string& nameB = fNBroadcadstedB.empty()? fNB : fNBroadcadstedB;
0168 out << SP << "for (size_t id = 0; id < " << length << " ; id++){\n";
0169 out << SP << SP << "tensor_" << fNY << "[id] = " << BinaryOperatorTrait<T,Op>::Op( "tensor_" + nameA + "[id]" , "tensor_" + nameB + "[id]") << " ;\n";
0170 out << SP << "}\n";
0171 return out.str();
0172 }
0173
0174 std::vector<std::string> GetStdLibs() override {
0175 if (Op == EBasicBinaryOperator::Pow) {
0176 return { std::string("cmath") };
0177 } else {
0178 return {};
0179 }
0180 }
0181 };
0182
0183 }
0184 }
0185 }
0186
0187
0188 #endif