File indexing completed on 2025-12-13 10:27:19
0001
0002 #ifndef TMVA_SOFIE_ROperator_Comparision
0003 #define TMVA_SOFIE_ROperator_Comparision
0004
#include "TMVA/SOFIE_common.hxx"
#include "TMVA/ROperator.hxx"
#include "TMVA/RModel.hxx"

#include <algorithm>
#include <iostream>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>
0010
0011 namespace TMVA{
0012 namespace Experimental{
0013 namespace SOFIE{
0014
/// Element-wise comparison selected by ROperator_Comparision.
/// Kept as a plain (unscoped) enum: the enumerators are referenced unqualified,
/// e.g. in the ComparisionTrait specializations below.
enum EComparisionOperator { Eq, Less, LessEq, Greater, GreaterEq };

/// Compile-time description of a comparison operator for the value type T.
/// Every specialization provides:
///  - Name():       operator name, used in verbose output and in the comment of the generated code;
///  - Op(t1, t2):   the C++ expression comparing the two code snippets t1 and t2;
///  - Result(a, b): the comparison evaluated on values (used for constant folding).
/// The primary template is intentionally empty: only the specializations below are usable.
template <typename T, EComparisionOperator Op1>
struct ComparisionTrait{};

template <typename T>
struct ComparisionTrait<T, Eq> {
   static std::string Name() { return "Equal"; }
   static std::string Op(const std::string & t1, const std::string & t2) { return t1 + " == " + t2; }
   static bool Result(T v1, T v2) { return v1 == v2; }
};

template <typename T>
struct ComparisionTrait<T, Less> {
   static std::string Name() { return "Less"; }
   static std::string Op(const std::string & t1, const std::string & t2) { return t1 + " < " + t2; }
   static bool Result(T v1, T v2) { return v1 < v2; }
};

template <typename T>
struct ComparisionTrait<T, LessEq> {
   static std::string Name() { return "LessOrEqual"; }
   static std::string Op(const std::string & t1, const std::string & t2) { return t1 + " <= " + t2; }
   static bool Result(T v1, T v2) { return v1 <= v2; }
};

template <typename T>
struct ComparisionTrait<T, Greater> {
   static std::string Name() { return "Greater"; }
   static std::string Op(const std::string & t1, const std::string & t2) { return t1 + " > " + t2; }
   static bool Result(T v1, T v2) { return v1 > v2; }
};

template <typename T>
struct ComparisionTrait<T, GreaterEq> {
   static std::string Name() { return "GreaterOrEqual"; }
   static std::string Op(const std::string & t1, const std::string & t2) { return t1 + " >= " + t2; }
   static bool Result(T v1, T v2) { return v1 >= v2; }
};
0054
0055 template<typename T, EComparisionOperator Op>
0056 class ROperator_Comparision final : public ROperator{
0057 private:
0058
0059 bool fIsModelOutput = false;
0060 std::string fNX1;
0061 std::string fNX2;
0062 std::string fNY;
0063 std::vector<size_t> fShapeX1;
0064 std::vector<size_t> fShapeX2;
0065 std::vector<Dim> fDimShapeX1;
0066 std::vector<Dim> fDimShapeX2;
0067 std::vector<size_t> fShapeY;
0068 std::string fNBroadcastedX1;
0069 std::string fNBroadcastedX2;
0070 ETensorType fTensorType1 = ETensorType::UNDEFINED;
0071 ETensorType fTensorType2 = ETensorType::UNDEFINED;
0072 bool fBroadcast = false;
0073
0074
0075 public:
0076 ROperator_Comparision(){}
0077 ROperator_Comparision(const std::string & nameX1, const std::string & nameX2, const std::string & nameY):
0078 fNX1(UTILITY::Clean_name(nameX1)), fNX2(UTILITY::Clean_name(nameX2)), fNY(UTILITY::Clean_name(nameY)){
0079 fInputTensorNames = { fNX1, fNX2 };
0080
0081
0082 fOutputTensorNames = { fNY };
0083 }
0084
0085
0086 std::vector<ETensorType> TypeInference(std::vector<ETensorType> input) override {
0087 return input;
0088 }
0089
0090
0091 std::vector<std::vector<size_t>> ShapeInference(std::vector<std::vector<size_t>> input) override {
0092 auto ret = input;
0093 return ret;
0094 }
0095
0096 void Initialize(RModel& model) override {
0097
0098 if (!model.CheckIfTensorAlreadyExist(fNX1)){
0099 throw std::runtime_error(std::string("TMVA SOFIE Comparision Op Input Tensor ") + fNX1 + "is not found in model");
0100 }
0101 if (!model.CheckIfTensorAlreadyExist(fNX2)) {
0102 throw std::runtime_error(std::string("TMVA SOFIE Comparision Op Input Tensor ") + fNX2 + "is not found in model");
0103 }
0104 if (model.IsDynamicTensor(fNX1))
0105 fDimShapeX1 = model.GetDynamicTensorShape(fNX1);
0106 else {
0107 fShapeX1 = model.GetTensorShape(fNX1);
0108 fDimShapeX1 = ConvertShapeToDim(fShapeX1);
0109 }
0110 if (model.IsDynamicTensor(fNX2))
0111 fDimShapeX2 = model.GetDynamicTensorShape(fNX2);
0112 else {
0113 fShapeX2 = model.GetTensorShape(fNX2);
0114 fDimShapeX2 = ConvertShapeToDim(fShapeX2);
0115 }
0116 fTensorType1 = model.GetTensorType(fNX1);
0117 fTensorType2 = model.GetTensorType(fNX2);
0118 bool broadcast = !UTILITY::AreSameShape(fShapeX1, fShapeX2);
0119 if (broadcast) {
0120
0121 fShapeY = UTILITY::UnidirectionalBroadcastShape(fShapeX1, fShapeX2);
0122 bool broadcastX1 = !UTILITY::AreSameShape(fShapeX1, fShapeY);
0123 bool broadcastX2 = !UTILITY::AreSameShape(fShapeX2, fShapeY);
0124
0125 if (broadcastX1) {
0126 if (model.IsInitializedTensor(fNX1)) {
0127 auto data = model.GetInitializedTensorData(fNX1);
0128 std::shared_ptr<void> broadcastedData(
0129 UTILITY::UnidirectionalBroadcast<T>(static_cast<T *>(data.get()), fShapeX1, fShapeY),
0130 std::default_delete<T[]>());
0131
0132 model.UpdateInitializedTensor(fNX1, model.GetTensorType(fNX1), fShapeY, broadcastedData);
0133 fShapeX1 = fShapeY;
0134 } else {
0135
0136 fNBroadcastedX1 = "Broadcasted" + fNX1;
0137 model.AddIntermediateTensor(fNBroadcastedX1, model.GetTensorType(fNX1), fShapeY);
0138 }
0139 }
0140
0141 if (broadcastX2) {
0142 if (model.IsInitializedTensor(fNX2)) {
0143 auto data = model.GetInitializedTensorData(fNX2);
0144 std::shared_ptr<void> broadcastedData(
0145 UTILITY::UnidirectionalBroadcast<T>(static_cast<T *>(data.get()), fShapeX2, fShapeY),
0146 std::default_delete<T[]>());
0147
0148 model.UpdateInitializedTensor(fNX2, model.GetTensorType(fNX2), fShapeY, broadcastedData);
0149 fShapeX2 = fShapeY;
0150 } else {
0151
0152 fNBroadcastedX2 = "Broadcasted" + fNX2;
0153 model.AddIntermediateTensor(fNBroadcastedX2, model.GetTensorType(fNX2), fShapeY);
0154 }
0155 }
0156 } else {
0157 fShapeY = fShapeX1;
0158 }
0159
0160 T * data1 = nullptr;
0161 T * data2 = nullptr;
0162 std::vector<Dim> shapeData1;
0163 std::vector<Dim> shapeData2;
0164 size_t length = ConvertShapeToLength(fShapeY);
0165 bool * outData = new bool[length];
0166 if (model.IsInitializedTensor(fNX1)) {
0167 data1 = static_cast<T *>(model.GetInitializedTensorData(fNX1).get());
0168 } else if (model.IsShapeTensor(fNX1)) {
0169 shapeData1 = model.GetShapeTensorValues(fNX1);
0170 }
0171 if (model.IsInitializedTensor(fNX2)) {
0172 data2 = static_cast<T *>(model.GetInitializedTensorData(fNX2).get());
0173 } else if (model.IsShapeTensor(fNX2)) {
0174 shapeData2 = model.GetShapeTensorValues(fNX2);
0175 }
0176 if (data1 && data2) {
0177 fIsOutputConstant = true;
0178 for (size_t i = 0; i < length; i++)
0179 outData[i] = ComparisionTrait<T,Op>::Result(data1[i], data2[i]);
0180 model.AddConstantTensor(fNY, fShapeY, outData);
0181 if (model.Verbose())
0182 std::cout << ComparisionTrait<T,Op>::Name() << " op ---> " << fNY << " " << ConvertShapeToString(fShapeY) << " : "
0183 << ConvertValuesToString(length,outData) << std::endl;
0184 } else if ((data1 || !shapeData1.empty()) && (data2 || !shapeData2.empty())) {
0185 fIsOutputConstant = true;
0186 if (data1 && !data2) {
0187
0188 for (size_t i = 0; i < length; i++) {
0189 if (shapeData2[i].isParam) {
0190 if (shapeData2[i].dim == size_t(-1) || data1[i] > 0) {
0191 fIsOutputConstant = false;
0192 break;
0193 } else {
0194
0195 shapeData2[i].dim = 0;
0196 }
0197 }
0198 outData[i] = ComparisionTrait<T,Op>::Result(data1[i], static_cast<T>(shapeData2[i].dim));
0199 }
0200 } else if (!data1 && data2) {
0201
0202 for (size_t i = 0; i < length; i++) {
0203 if (shapeData1[i].isParam) {
0204 if (shapeData1[i].dim == size_t(-1) || data2[i] > 0) {
0205 fIsOutputConstant = false;
0206 break;
0207 } else {
0208
0209 shapeData1[i].dim = 0;
0210 }
0211 }
0212 outData[i] = ComparisionTrait<T,Op>::Result(static_cast<T>(shapeData1[i].dim), data2[i]);
0213 }
0214 } else if (!shapeData1.empty() && !shapeData2.empty() ) {
0215
0216 for (size_t i = 0; i < length; i++) {
0217 if (!shapeData1[i].isParam && !shapeData2[i].isParam) {
0218 outData[i] = ComparisionTrait<T,Op>::Result(shapeData1[i].dim, shapeData2[i].dim);
0219 }
0220 else if (shapeData1[i].isParam && shapeData2[i].isParam) {
0221 if (shapeData1[i].param == shapeData2[i].param)
0222 outData[i] = ComparisionTrait<int,Op>::Result(1,1);
0223 else {
0224 fIsOutputConstant = false;
0225 break;
0226 }
0227 }
0228 else {
0229 fIsOutputConstant = false;
0230 break;
0231 }
0232 }
0233 }
0234 if (fIsOutputConstant) {
0235 model.AddConstantTensor(fNY, fShapeY, outData);
0236 if (model.Verbose())
0237 std::cout << ComparisionTrait<T,Op>::Name() << " op ---> " << fNY << " " << ConvertShapeToString(fShapeY) << " : "
0238 << ConvertValuesToString(length,outData) << " (constant) " << std::endl;
0239
0240 }
0241 }
0242 delete [] outData;
0243 if (!fIsOutputConstant) {
0244 model.AddIntermediateTensor(fNY, ETensorType::BOOL , fShapeY);
0245 if (model.Verbose())
0246 std::cout << ComparisionTrait<T,Op>::Name() << " op ---> " << fNY << " " << ConvertShapeToString(fShapeY) << std::endl;
0247 }
0248
0249
0250 const auto & outputTensorNames = model.GetOutputTensorNames();
0251 fIsModelOutput = false;
0252 if (std::find(outputTensorNames.begin(), outputTensorNames.end(), fNY) != outputTensorNames.end())
0253 fIsModelOutput = true;
0254 }
0255
0256 std::string Generate(std::string opName) override {
0257 if (fIsOutputConstant) return "";
0258 opName = "op_" + opName;
0259
0260 if (fShapeY.empty()) {
0261 throw std::runtime_error("TMVA SOFIE Comparision Op called to Generate without being initialized first");
0262 }
0263 std::stringstream out;
0264 out << SP << "\n//------ " << ComparisionTrait<T,Op>::Name() << " " << opName
0265 << " --> " << ConvertShapeToString(fShapeY) << "\n";
0266 size_t length = ConvertShapeToLength(fShapeY);
0267
0268 if (!fNBroadcastedX1.empty()) {
0269 std::string type1 = ConvertTypeToString(fTensorType1);
0270 out << SP << "// Broadcasting uninitialized tensor " << fNX1 << "\n";
0271 out << SP << "{\n";
0272 out << SP << SP << type1 << "* data = TMVA::Experimental::SOFIE::UTILITY::UnidirectionalBroadcast<" << type1 << ">(tensor_" << fNX1 << ", " << ConvertShapeToString(fShapeX1) << ", " << ConvertShapeToString(fShapeY) << ");\n";
0273 out << SP << SP << "std::copy(data, data + " << length << ", tensor_" << fNBroadcastedX1 << ");\n";
0274 out << SP << SP << "delete[] data;\n";
0275 out << SP << "}\n";
0276 }
0277
0278 if (!fNBroadcastedX2.empty()) {
0279 std::string type2 = ConvertTypeToString(fTensorType2);
0280 out << SP << "// Broadcasting uninitialized tensor " << fNX2 << "\n";
0281 out << SP << "{\n";
0282 out << SP << SP << type2 << "* data = TMVA::Experimental::SOFIE::UTILITY::UnidirectionalBroadcast<" << type2 << ">(tensor_" << fNX2 << ", " << ConvertShapeToString(fShapeX2) << ", " << ConvertShapeToString(fShapeY) << ");\n";
0283 out << SP << SP << "std::copy(data, data + " << length << ", tensor_" << fNBroadcastedX2 << ");\n";
0284 out << SP << SP << "delete[] data;\n";
0285 out << SP << "}\n";
0286 }
0287 const std::string& nameX1 = fNBroadcastedX1.empty()? fNX1 : fNBroadcastedX1;
0288 const std::string& nameX2 = fNBroadcastedX2.empty()? fNX2 : fNBroadcastedX2;
0289
0290 out << SP << "for (size_t id = 0; id < " << length << " ; id++){\n";
0291 out << SP << SP << "fTensor_" << fNY << "[id] = " << ComparisionTrait<T,Op>::Op( "tensor_" + nameX1 + "[id]" , "tensor_" + nameX2 + "[id]") << " ;\n";
0292 out << SP << "}\n";
0293
0294 if (!fIsModelOutput)
0295 out << SP << "const std::vector<std::uint8_t> & tensor_" << fNY << " = fTensor_" << fNY << ";\n";
0296
0297 return out.str();
0298 }
0299
0300 };
0301
0302 }
0303 }
0304 }
0305
0306
0307 #endif