File indexing completed on 2025-10-30 08:55:11
0001
0002 #ifndef TMVA_SOFIE_ROperator_Comparision
0003 #define TMVA_SOFIE_ROperator_Comparision
0004
#include "TMVA/SOFIE_common.hxx"
#include "TMVA/ROperator.hxx"
#include "TMVA/RModel.hxx"

#include <algorithm>
#include <iostream>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>
0010
0011 namespace TMVA{
0012 namespace Experimental{
0013 namespace SOFIE{
0014
0015 enum EComparisionOperator { Eq, Less, LessEq, Greater, GreaterEq };
0016
0017 template <typename T, EComparisionOperator Op1>
0018 struct ComparisionTrait{};
0019
0020 template <typename T>
0021 struct ComparisionTrait<T, Eq> {
0022 static const std::string Name() { return "Equal"; }
0023 static std::string Op(const std::string & t1, const std::string t2) { return t1 + " == " + t2 + " ? true : false "; }
0024 static bool Result(T v1, T v2) { return v1 == v2;}
0025 };
0026
0027 template <typename T>
0028 struct ComparisionTrait<T, Less> {
0029 static const std::string Name() { return "Less"; }
0030 static std::string Op(const std::string & t1, const std::string t2) { return t1 + " < " + t2 + " ? true : false "; }
0031 static bool Result(T v1, T v2) { return v1 < v2;}
0032 };
0033
0034 template <typename T>
0035 struct ComparisionTrait<T, LessEq> {
0036 static const std::string Name() { return "LessOrEqual"; }
0037 static std::string Op(const std::string & t1, const std::string t2) { return t1 + " <= " + t2 + " ? true : false "; }
0038 static bool Result(T v1, T v2) { return v1 <= v2;}
0039 };
0040
0041 template <typename T>
0042 struct ComparisionTrait<T, Greater> {
0043 static const std::string Name() { return "Greater"; }
0044 static std::string Op(const std::string & t1, const std::string t2) { return t1 + " > " + t2 + " ? true : false "; }
0045 static bool Result(T v1, T v2) { return v1 > v2;}
0046 };
0047
0048 template <typename T>
0049 struct ComparisionTrait<T, GreaterEq> {
0050 static const std::string Name() { return "GreaterOrEqual"; }
0051 static std::string Op(const std::string & t1, const std::string t2) { return t1 + " >= " + t2 + " ? true : false " ; }
0052 static bool Result(T v1, T v2) { return v1 >= v2;}
0053 };
0054
0055 template<typename T, EComparisionOperator Op>
0056 class ROperator_Comparision final : public ROperator{
0057 private:
0058
0059 bool fIsModelOutput = false;
0060 std::string fNX1;
0061 std::string fNX2;
0062 std::string fNY;
0063 std::vector<size_t> fShapeX1;
0064 std::vector<size_t> fShapeX2;
0065 std::vector<size_t> fShapeY;
0066 std::string fNBroadcastedX1;
0067 std::string fNBroadcastedX2;
0068 ETensorType fTensorType1 = ETensorType::UNDEFINED;
0069 ETensorType fTensorType2 = ETensorType::UNDEFINED;
0070 bool fBroadcast = false;
0071
0072
0073 public:
0074 ROperator_Comparision(){}
0075 ROperator_Comparision(const std::string & nameX1, const std::string & nameX2, const std::string & nameY):
0076 fNX1(UTILITY::Clean_name(nameX1)), fNX2(UTILITY::Clean_name(nameX2)), fNY(UTILITY::Clean_name(nameY)){
0077 fInputTensorNames = { fNX1, fNX2 };
0078
0079
0080 fOutputTensorNames = { fNY };
0081 }
0082
0083
0084 std::vector<ETensorType> TypeInference(std::vector<ETensorType> input) override {
0085 return input;
0086 }
0087
0088
0089 std::vector<std::vector<size_t>> ShapeInference(std::vector<std::vector<size_t>> input) override {
0090 auto ret = input;
0091 return ret;
0092 }
0093
0094 void Initialize(RModel& model) override {
0095
0096 if (!model.CheckIfTensorAlreadyExist(fNX1)){
0097 throw std::runtime_error(std::string("TMVA SOFIE Comparision Op Input Tensor ") + fNX1 + "is not found in model");
0098 }
0099 if (!model.CheckIfTensorAlreadyExist(fNX2)) {
0100 throw std::runtime_error(std::string("TMVA SOFIE Comparision Op Input Tensor ") + fNX2 + "is not found in model");
0101 }
0102 fShapeX1 = model.GetTensorShape(fNX1);
0103 fShapeX2 = model.GetTensorShape(fNX2);
0104 fTensorType1 = model.GetTensorType(fNX1);
0105 fTensorType2 = model.GetTensorType(fNX2);
0106 bool broadcast = !UTILITY::AreSameShape(fShapeX1, fShapeX2);
0107 if (broadcast) {
0108
0109 fShapeY = UTILITY::UnidirectionalBroadcastShape(fShapeX1, fShapeX2);
0110 bool broadcastX1 = !UTILITY::AreSameShape(fShapeX1, fShapeY);
0111 bool broadcastX2 = !UTILITY::AreSameShape(fShapeX2, fShapeY);
0112
0113 if (broadcastX1) {
0114 if (model.IsInitializedTensor(fNX1)) {
0115 auto data = model.GetInitializedTensorData(fNX1);
0116 std::shared_ptr<void> broadcastedData(
0117 UTILITY::UnidirectionalBroadcast<T>(static_cast<T *>(data.get()), fShapeX1, fShapeY),
0118 std::default_delete<T[]>());
0119
0120 model.UpdateInitializedTensor(fNX1, model.GetTensorType(fNX1), fShapeY, broadcastedData);
0121 fShapeX1 = fShapeY;
0122 } else {
0123
0124 fNBroadcastedX1 = "Broadcasted" + fNX1;
0125 model.AddIntermediateTensor(fNBroadcastedX1, model.GetTensorType(fNX1), fShapeY);
0126 }
0127 }
0128
0129 if (broadcastX2) {
0130 if (model.IsInitializedTensor(fNX2)) {
0131 auto data = model.GetInitializedTensorData(fNX2);
0132 std::shared_ptr<void> broadcastedData(
0133 UTILITY::UnidirectionalBroadcast<T>(static_cast<T *>(data.get()), fShapeX2, fShapeY),
0134 std::default_delete<T[]>());
0135
0136 model.UpdateInitializedTensor(fNX2, model.GetTensorType(fNX2), fShapeY, broadcastedData);
0137 fShapeX2 = fShapeY;
0138 } else {
0139
0140 fNBroadcastedX2 = "Broadcasted" + fNX2;
0141 model.AddIntermediateTensor(fNBroadcastedX2, model.GetTensorType(fNX2), fShapeY);
0142 }
0143 }
0144 } else {
0145 fShapeY = fShapeX1;
0146 }
0147
0148 if (model.IsInitializedTensor(fNX1) && model.IsInitializedTensor(fNX2) ) {
0149 fIsOutputConstant = true;
0150 auto data1 = static_cast<T *>(model.GetInitializedTensorData(fNX1).get());
0151 auto data2 = static_cast<T *>(model.GetInitializedTensorData(fNX2).get());
0152 size_t length = ConvertShapeToLength(fShapeY);
0153 bool * outData = new bool[length];
0154 for (size_t i = 0; i < length; i++)
0155 outData[i] = ComparisionTrait<T,Op>::Result(data1[i], data2[i]);
0156 model.AddConstantTensor(fNY, fShapeY, outData);
0157 if (model.Verbose())
0158 std::cout << ComparisionTrait<T,Op>::Name() << " op ---> " << fNY << " " << ConvertShapeToString(fShapeY) << " : "
0159 << ConvertValuesToString(length,outData) << std::endl;
0160 delete [] outData;
0161 } else {
0162 model.AddIntermediateTensor(fNY, ETensorType::BOOL , fShapeY);
0163 }
0164
0165 const auto & outputTensorNames = model.GetOutputTensorNames();
0166 fIsModelOutput = false;
0167 if (std::find(outputTensorNames.begin(), outputTensorNames.end(), fNY) != outputTensorNames.end())
0168 fIsModelOutput = true;
0169 }
0170
0171 std::string Generate(std::string OpName) override {
0172 if (fIsOutputConstant) return "";
0173 OpName = "op_" + OpName;
0174
0175 if (fShapeY.empty()) {
0176 throw std::runtime_error("TMVA SOFIE Comparision Op called to Generate without being initialized first");
0177 }
0178 std::stringstream out;
0179 out << SP << "\n//------ " << ComparisionTrait<T,Op>::Name() << "\n";
0180 size_t length = ConvertShapeToLength(fShapeY);
0181
0182 if (!fNBroadcastedX1.empty()) {
0183 std::string type1 = ConvertTypeToString(fTensorType1);
0184 out << SP << "// Broadcasting uninitialized tensor " << fNX1 << "\n";
0185 out << SP << "{\n";
0186 out << SP << SP << type1 << "* data = TMVA::Experimental::SOFIE::UTILITY::UnidirectionalBroadcast<" << type1 << ">(tensor_" << fNX1 << ", " << ConvertShapeToString(fShapeX1) << ", " << ConvertShapeToString(fShapeY) << ");\n";
0187 out << SP << SP << "std::copy(data, data + " << length << ", tensor_" << fNBroadcastedX1 << ");\n";
0188 out << SP << SP << "delete[] data;\n";
0189 out << SP << "}\n";
0190 }
0191
0192 if (!fNBroadcastedX2.empty()) {
0193 std::string type2 = ConvertTypeToString(fTensorType2);
0194 out << SP << "// Broadcasting uninitialized tensor " << fNX2 << "\n";
0195 out << SP << "{\n";
0196 out << SP << SP << type2 << "* data = TMVA::Experimental::SOFIE::UTILITY::UnidirectionalBroadcast<" << type2 << ">(tensor_" << fNX2 << ", " << ConvertShapeToString(fShapeX2) << ", " << ConvertShapeToString(fShapeY) << ");\n";
0197 out << SP << SP << "std::copy(data, data + " << length << ", tensor_" << fNBroadcastedX2 << ");\n";
0198 out << SP << SP << "delete[] data;\n";
0199 out << SP << "}\n";
0200 }
0201 const std::string& nameX1 = fNBroadcastedX1.empty()? fNX1 : fNBroadcastedX1;
0202 const std::string& nameX2 = fNBroadcastedX2.empty()? fNX2 : fNBroadcastedX2;
0203
0204 out << SP << "for (size_t id = 0; id < " << length << " ; id++){\n";
0205 out << SP << SP << "fTensor_" << fNY << "[id] = " << ComparisionTrait<T,Op>::Op( "tensor_" + nameX1 + "[id]" , "tensor_" + nameX2 + "[id]") << " ;\n";
0206 out << SP << "}\n";
0207
0208 if (!fIsModelOutput)
0209 out << SP << "const std::vector<bool> & tensor_" << fNY << " = fTensor_" << fNY << ";\n";
0210
0211 return out.str();
0212 }
0213
0214 };
0215
0216 }
0217 }
0218 }
0219
0220
0221 #endif