File indexing completed on 2025-09-18 09:32:33
0001 #ifndef TMVA_SOFIE_ROperator_BasicBinary
0002 #define TMVA_SOFIE_ROperator_BasicBinary
0003
0004 #include "TMVA/SOFIE_common.hxx"
0005 #include "TMVA/ROperator.hxx"
0006 #include "TMVA/RModel.hxx"
0007
0008 #include <sstream>
0009
0010 namespace TMVA{
0011 namespace Experimental{
0012 namespace SOFIE{
0013
/// Element-wise binary operations supported by ROperator_BasicBinary.
/// Kept as an unscoped enum so the enumerators stay usable unqualified as
/// template arguments (e.g. BinaryOperatorTrait<T, Add>).
enum EBasicBinaryOperator { Add, Sub, Mul, Div, Pow };

/// Per-operation traits. Each specialization provides:
///  - Name(): the operator name as a string,
///  - Op(t1, t2): the C++ source expression applying the operation to the two
///    operand expressions given as strings,
///  - Func(t1, t2): the operation evaluated directly on two values of type T.
/// The primary template is intentionally empty: only the specializations below
/// are usable.
template <typename T, EBasicBinaryOperator Op1>
struct BinaryOperatorTrait {};

/// Addition: `t1 + t2`.
template <typename T>
struct BinaryOperatorTrait<T, Add> {
   static std::string Name() { return "Add"; }
   // Both operands are taken by const reference: they are only read.
   static std::string Op(const std::string & t1, const std::string & t2) { return t1 + " + " + t2; }
   static T Func(T t1, T t2) { return t1 + t2; }
};

/// Subtraction: `t1 - t2`.
template <typename T>
struct BinaryOperatorTrait<T, Sub> {
   static std::string Name() { return "Sub"; }
   static std::string Op(const std::string & t1, const std::string & t2) { return t1 + " - " + t2; }
   static T Func(T t1, T t2) { return t1 - t2; }
};

/// Multiplication: `t1 * t2`.
template <typename T>
struct BinaryOperatorTrait<T, Mul> {
   static std::string Name() { return "Mul"; }
   static std::string Op(const std::string & t1, const std::string & t2) { return t1 + " * " + t2; }
   static T Func(T t1, T t2) { return t1 * t2; }
};

/// Division: `t1 / t2` (uses T's own operator/, so integer T truncates).
template <typename T>
struct BinaryOperatorTrait<T, Div> {
   static std::string Name() { return "Div"; }
   static std::string Op(const std::string & t1, const std::string & t2) { return t1 + " / " + t2; }
   static T Func(T t1, T t2) { return t1 / t2; }
};

/// Power: `std::pow(t1, t2)`; the result of std::pow is converted back to T.
template <typename T>
struct BinaryOperatorTrait<T, Pow> {
   static std::string Name() { return "Pow"; }
   static std::string Op(const std::string & t1, const std::string & t2) { return "std::pow(" + t1 + "," + t2 + ")"; }
   static T Func(T t1, T t2) { return std::pow(t1, t2); }
};
0053
0054 template<typename T, EBasicBinaryOperator Op>
0055 class ROperator_BasicBinary final : public ROperator{
0056 private:
0057
0058 std::string fNA;
0059 std::string fNB;
0060 std::string fNBroadcastedA;
0061 std::string fNBroadcastedB;
0062 std::string fNY;
0063
0064 std::vector<size_t> fShapeA;
0065 std::vector<size_t> fShapeB;
0066 std::vector<size_t> fShapeY;
0067
0068 public:
0069 ROperator_BasicBinary(){}
0070 ROperator_BasicBinary(std::string nameA, std::string nameB, std::string nameY):
0071 fNA(UTILITY::Clean_name(nameA)), fNB(UTILITY::Clean_name(nameB)), fNY(UTILITY::Clean_name(nameY)){
0072 fInputTensorNames = { fNA, fNB };
0073 fOutputTensorNames = { fNY };
0074 }
0075
0076
0077 std::vector<ETensorType> TypeInference(std::vector<ETensorType> input) override {
0078 return input;
0079 }
0080
0081
0082 std::vector<std::vector<size_t>> ShapeInference(std::vector<std::vector<size_t>> input) override {
0083
0084 auto ret = std::vector<std::vector<size_t>>(1, input[0]);
0085 return ret;
0086 }
0087
0088 void Initialize(RModel& model) override {
0089
0090 if (!model.CheckIfTensorAlreadyExist(fNA)){
0091 throw std::runtime_error(std::string("TMVA SOFIE Binary Op Input Tensor ") + fNA + "is not found in model");
0092 }
0093 if (!model.CheckIfTensorAlreadyExist(fNB)) {
0094 throw std::runtime_error(std::string("TMVA SOFIE Binary Op Input Tensor ") + fNB + "is not found in model");
0095 }
0096 fShapeA = model.GetTensorShape(fNA);
0097 fShapeB = model.GetTensorShape(fNB);
0098 bool broadcast = !UTILITY::AreSameShape(fShapeA, fShapeB);
0099 if (broadcast) {
0100
0101 fShapeY = UTILITY::UnidirectionalBroadcastShape(fShapeA, fShapeB);
0102 bool broadcastA = !UTILITY::AreSameShape(fShapeA, fShapeY);
0103 bool broadcastB = !UTILITY::AreSameShape(fShapeB, fShapeY);
0104
0105 if (broadcastA) {
0106 fNBroadcastedA = "Broadcasted" + fNA + "to" + fNY;
0107 if (model.IsInitializedTensor(fNA)) {
0108 auto data = model.GetInitializedTensorData(fNA);
0109 std::shared_ptr<void> broadcastedData(
0110 UTILITY::UnidirectionalBroadcast<T>(static_cast<T *>(data.get()), fShapeA, fShapeY),
0111 std::default_delete<T[]>());
0112
0113 model.AddConstantTensor(fNBroadcastedA, model.GetTensorType(fNA), fShapeY, broadcastedData);
0114 fShapeA = fShapeY;
0115 } else {
0116
0117 model.AddIntermediateTensor(fNBroadcastedA, model.GetTensorType(fNA), fShapeY);
0118 }
0119 }
0120
0121 if (broadcastB) {
0122 fNBroadcastedB = "Broadcasted" + fNB + "to" + fNY;
0123 if (model.IsInitializedTensor(fNB)) {
0124 auto data = model.GetInitializedTensorData(fNB);
0125 std::cout << "data B " << ConvertShapeToString(fShapeB) << " : " <<
0126 ConvertValuesToString(ConvertShapeToLength(fShapeB), static_cast<T*>(data.get())) << std::endl;
0127 std::shared_ptr<void> broadcastedData(
0128 UTILITY::UnidirectionalBroadcast<T>(static_cast<T *>(data.get()), fShapeB, fShapeY),
0129 std::default_delete<T[]>());
0130
0131 std::cout << "broadcasted data B " << ConvertShapeToString(fShapeY) << " : " <<
0132 ConvertValuesToString(ConvertShapeToLength(fShapeY), static_cast<T*>(broadcastedData.get())) << std::endl;
0133 model.AddConstantTensor(fNBroadcastedB, model.GetTensorType(fNB), fShapeY, broadcastedData);
0134 fShapeB = fShapeY;
0135 } else {
0136
0137 model.AddIntermediateTensor(fNBroadcastedB, model.GetTensorType(fNB), fShapeY);
0138 }
0139 }
0140 } else {
0141 fShapeY = fShapeA;
0142 }
0143
0144 if (model.IsInitializedTensor(fNA) && model.IsInitializedTensor(fNB)) {
0145 const std::string& nameA = fNBroadcastedA.empty()? fNA : fNBroadcastedA;
0146 const std::string& nameB = fNBroadcastedB.empty()? fNB : fNBroadcastedB;
0147 auto dataA = static_cast<T *>(model.GetInitializedTensorData(nameA).get());
0148 auto dataB = static_cast<T *>(model.GetInitializedTensorData(nameB).get());
0149 std::vector<T> dataY(ConvertShapeToLength(fShapeY));
0150 for (size_t i = 0; i < dataY.size(); i++) {
0151 dataY[i] = BinaryOperatorTrait<T,Op>::Func(dataA[i], dataB[i]);
0152 }
0153 model.AddConstantTensor<T>(fNY, fShapeY, dataY.data());
0154
0155 model.SetNotWritableInitializedTensor(nameA);
0156 model.SetNotWritableInitializedTensor(nameB);
0157 fIsOutputConstant = true;
0158 if (model.Verbose())
0159 std::cout << "Binary op ---> " << fNY << " " << ConvertShapeToString(fShapeY) << " : "
0160 << ConvertValuesToString(dataY) << std::endl;
0161 }
0162 else {
0163 model.AddIntermediateTensor(fNY, model.GetTensorType(fNA), fShapeY);
0164 }
0165 }
0166
0167 std::string GenerateInitCode() override {
0168 std::stringstream out;
0169 return out.str();
0170 }
0171
0172 std::string Generate(std::string OpName) override {
0173
0174 if (fIsOutputConstant) return "";
0175
0176 OpName = "op_" + OpName;
0177
0178 if (fShapeY.empty()) {
0179 throw std::runtime_error("TMVA SOFIE Binary Op called to Generate without being initialized first");
0180 }
0181 std::stringstream out;
0182 out << SP << "\n//------ " << BinaryOperatorTrait<T,Op>::Name() << "\n";
0183 size_t length = ConvertShapeToLength(fShapeY);
0184 std::string typeName = TensorType<T>::Name();
0185
0186
0187 if (fShapeA != fShapeY) {
0188 out << SP << "// Broadcasting uninitialized tensor " << fNA << "\n";
0189 out << SP << "TMVA::Experimental::SOFIE::UTILITY::UnidirectionalBroadcast<" << typeName << ">(tensor_" << fNA << ", " << ConvertShapeToString(fShapeA) << ", " << ConvertShapeToString(fShapeY)
0190 << ", fTensor_" << fNBroadcastedA << ");\n";
0191 }
0192
0193 if (fShapeB != fShapeY) {
0194 out << SP << "// Broadcasting uninitialized tensor " << fNB << "\n";
0195 out << SP << "TMVA::Experimental::SOFIE::UTILITY::UnidirectionalBroadcast<" << typeName << ">(tensor_" << fNB << ", " << ConvertShapeToString(fShapeB) << ", " << ConvertShapeToString(fShapeY)
0196 << ", fTensor_" << fNBroadcastedB << ");\n";
0197 }
0198 const std::string& nameA = fNBroadcastedA.empty()? fNA : fNBroadcastedA;
0199 const std::string& nameB = fNBroadcastedB.empty()? fNB : fNBroadcastedB;
0200 out << SP << "for (size_t id = 0; id < " << length << " ; id++){\n";
0201 out << SP << SP << "tensor_" << fNY << "[id] = " << BinaryOperatorTrait<T,Op>::Op( "tensor_" + nameA + "[id]" , "tensor_" + nameB + "[id]") << " ;\n";
0202 out << SP << "}\n";
0203 return out.str();
0204 }
0205
0206 std::vector<std::string> GetStdLibs() override {
0207 if (Op == EBasicBinaryOperator::Pow) {
0208 return { std::string("cmath") };
0209 } else {
0210 return {};
0211 }
0212 }
0213 };
0214
0215 }
0216 }
0217 }
0218
0219
0220 #endif