File indexing completed on 2025-01-30 10:22:58
0001
0002 #ifndef TMVA_SOFIE_ROperator_Comparision
0003 #define TMVA_SOFIE_ROperator_Comparision
0004
0005 #include "TMVA/SOFIE_common.hxx"
0006 #include "TMVA/ROperator.hxx"
0007 #include "TMVA/RModel.hxx"
0008
0009 #include <sstream>
0010
0011 namespace TMVA{
0012 namespace Experimental{
0013 namespace SOFIE{
0014
0015 enum EComparisionOperator { Eq, Less, LessEq, Greater, GreaterEq };
0016
0017 template <typename T, EComparisionOperator Op1>
0018 struct ComparisionTrait{};
0019
0020 template <typename T>
0021 struct ComparisionTrait<T, Eq> {
0022 static const std::string Name() { return "Equal"; }
0023 static std::string Op(const std::string & t1, const std::string t2) { return t1 + "==" + t2 + "? true : false "; }
0024 };
0025
0026 template <typename T>
0027 struct ComparisionTrait<T, Less> {
0028 static const std::string Name() { return "Less"; }
0029 static std::string Op(const std::string & t1, const std::string t2) { return t1 + "<" + t2 + "? true : false "; }
0030 };
0031
0032 template <typename T>
0033 struct ComparisionTrait<T, LessEq> {
0034 static const std::string Name() { return "LessOrEqual"; }
0035 static std::string Op(const std::string & t1, const std::string t2) { return t1 + "<=" + t2 + "? true : false "; }
0036 };
0037
0038 template <typename T>
0039 struct ComparisionTrait<T, Greater> {
0040 static const std::string Name() { return "Greater"; }
0041 static std::string Op(const std::string & t1, const std::string t2) { return t1 + ">" + t2 + "? true : false "; }
0042 };
0043
0044 template <typename T>
0045 struct ComparisionTrait<T, GreaterEq> {
0046 static const std::string Name() { return "GreaterOrEqual"; }
0047 static std::string Op(const std::string & t1, const std::string t2) { return t1 + ">=" + t2 + "? true : false " ; }
0048 };
0049
0050 template<typename T, EComparisionOperator Op>
0051 class ROperator_Comparision final : public ROperator{
0052 private:
0053
0054 std::string fNX1;
0055 std::string fNX2;
0056 std::string fNY;
0057 std::vector<size_t> fShapeX1;
0058 std::vector<size_t> fShapeX2;
0059 std::vector<size_t> fShapeY;
0060 std::string fNBroadcastedX1;
0061 std::string fNBroadcastedX2;
0062 bool fBroadcast = false;
0063
0064
0065 public:
0066 ROperator_Comparision(){}
0067 ROperator_Comparision(std::string nameX1, std::string nameX2, std::string nameY):
0068 fNX1(UTILITY::Clean_name(nameX1)), fNX2(UTILITY::Clean_name(nameX2)), fNY(UTILITY::Clean_name(nameY)){}
0069
0070
0071 std::vector<ETensorType> TypeInference(std::vector<ETensorType> input) override {
0072 return input;
0073 }
0074
0075
0076 std::vector<std::vector<size_t>> ShapeInference(std::vector<std::vector<size_t>> input) override {
0077 auto ret = input;
0078 return ret;
0079 }
0080
0081 void Initialize(RModel& model) override {
0082
0083 if (!model.CheckIfTensorAlreadyExist(fNX1)){
0084 throw std::runtime_error(std::string("TMVA SOFIE Comparision Op Input Tensor ") + fNX1 + "is not found in model");
0085 }
0086 if (!model.CheckIfTensorAlreadyExist(fNX2)) {
0087 throw std::runtime_error(std::string("TMVA SOFIE Comparision Op Input Tensor ") + fNX2 + "is not found in model");
0088 }
0089 fShapeX1 = model.GetTensorShape(fNX1);
0090 fShapeX2 = model.GetTensorShape(fNX2);
0091 bool broadcast = !UTILITY::AreSameShape(fShapeX1, fShapeX2);
0092 if (broadcast) {
0093
0094 fShapeY = UTILITY::UnidirectionalBroadcastShape(fShapeX1, fShapeX2);
0095 bool broadcastX1 = !UTILITY::AreSameShape(fShapeX1, fShapeY);
0096 bool broadcastX2 = !UTILITY::AreSameShape(fShapeX2, fShapeY);
0097
0098 if (broadcastX1) {
0099 if (model.IsInitializedTensor(fNX1)) {
0100 auto data = model.GetInitializedTensorData(fNX1);
0101 std::shared_ptr<void> broadcastedData(
0102 UTILITY::UnidirectionalBroadcast<float>(static_cast<float *>(data.get()), fShapeX1, fShapeY),
0103 std::default_delete<float[]>());
0104
0105 model.UpdateInitializedTensor(fNX1, model.GetTensorType(fNX1), fShapeY, broadcastedData);
0106 fShapeX1 = fShapeY;
0107 } else {
0108
0109 fNBroadcastedX1 = "Broadcasted" + fNX1;
0110 model.AddIntermediateTensor(fNBroadcastedX1, model.GetTensorType(fNX1), fShapeY);
0111 }
0112 }
0113
0114 if (broadcastX2) {
0115 if (model.IsInitializedTensor(fNX2)) {
0116 auto data = model.GetInitializedTensorData(fNX2);
0117 std::shared_ptr<void> broadcastedData(
0118 UTILITY::UnidirectionalBroadcast<float>(static_cast<float *>(data.get()), fShapeX2, fShapeY),
0119 std::default_delete<float[]>());
0120
0121 model.UpdateInitializedTensor(fNX2, model.GetTensorType(fNX2), fShapeY, broadcastedData);
0122 fShapeX2 = fShapeY;
0123 } else {
0124
0125 fNBroadcastedX2 = "Broadcasted" + fNX2;
0126 model.AddIntermediateTensor(fNBroadcastedX2, model.GetTensorType(fNX2), fShapeY);
0127 }
0128 }
0129 } else {
0130 fShapeY = fShapeX1;
0131 }
0132 model.AddIntermediateTensor(fNY, ETensorType::BOOL , fShapeY);
0133 }
0134
0135 std::string Generate(std::string OpName) override {
0136 OpName = "op_" + OpName;
0137
0138 if (fShapeY.empty()) {
0139 throw std::runtime_error("TMVA SOFIE Comparision Op called to Generate without being initialized first");
0140 }
0141 std::stringstream out;
0142 out << SP << "\n//------ " << ComparisionTrait<T,Op>::Name() << "\n";
0143 size_t length = ConvertShapeToLength(fShapeY);
0144
0145 if (!fNBroadcastedX1.empty()) {
0146 out << SP << "// Broadcasting uninitialized tensor " << fNX1 << "\n";
0147 out << SP << "{\n";
0148 out << SP << SP << "float* data = TMVA::Experimental::SOFIE::UTILITY::UnidirectionalBroadcast<float>(tensor_" << fNX1 << ", " << ConvertShapeToString(fShapeX1) << ", " << ConvertShapeToString(fShapeY) << ");\n";
0149 out << SP << SP << "std::copy(data, data + " << length << ", tensor_" << fNBroadcastedX1 << ");\n";
0150 out << SP << SP << "delete[] data;\n";
0151 out << SP << "}\n";
0152 }
0153
0154 if (!fNBroadcastedX2.empty()) {
0155 out << SP << "// Broadcasting uninitialized tensor " << fNX2 << "\n";
0156 out << SP << "{\n";
0157 out << SP << SP << "float* data = TMVA::Experimental::SOFIE::UTILITY::UnidirectionalBroadcast<float>(tensor_" << fNX2 << ", " << ConvertShapeToString(fShapeX2) << ", " << ConvertShapeToString(fShapeY) << ");\n";
0158 out << SP << SP << "std::copy(data, data + " << length << ", tensor_" << fNBroadcastedX2 << ");\n";
0159 out << SP << SP << "delete[] data;\n";
0160 out << SP << "}\n";
0161 }
0162 const std::string& nameX1 = fNBroadcastedX1.empty()? fNX1 : fNBroadcastedX1;
0163 const std::string& nameX2 = fNBroadcastedX2.empty()? fNX2 : fNBroadcastedX2;
0164
0165 out << SP << "for (size_t id = 0; id < " << length << " ; id++){\n";
0166 out << SP << SP << "fTensor_" << fNY << "[id] = " << ComparisionTrait<T,Op>::Op( "tensor_" + nameX1 + "[id]" , "tensor_" + nameX2 + "[id]") << " ;\n";
0167 out << SP << "}\n";
0168
0169 return out.str();
0170 }
0171
0172 };
0173
0174 }
0175 }
0176 }
0177
0178
0179 #endif