Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-01-30 10:22:58

0001 
0002 #ifndef TMVA_SOFIE_ROperator_Comparision
0003 #define TMVA_SOFIE_ROperator_Comparision
0004 
0005 #include "TMVA/SOFIE_common.hxx"
0006 #include "TMVA/ROperator.hxx"
0007 #include "TMVA/RModel.hxx"
0008 
0009 #include <sstream>
0010 
0011 namespace TMVA{
0012 namespace Experimental{
0013 namespace SOFIE{
0014 
0015 enum EComparisionOperator { Eq, Less, LessEq, Greater, GreaterEq };
0016 
0017 template <typename T, EComparisionOperator Op1>
0018 struct ComparisionTrait{};
0019 
0020 template <typename T>
0021 struct ComparisionTrait<T, Eq> {
0022    static const std::string Name() { return "Equal"; }
0023    static std::string Op(const std::string & t1, const std::string t2) { return t1 + "==" +  t2 + "? true : false "; }
0024 };
0025 
0026 template <typename T>
0027 struct ComparisionTrait<T, Less> {
0028    static const std::string Name() { return "Less"; }
0029    static std::string Op(const std::string & t1, const std::string t2) { return t1 + "<" + t2 + "? true : false "; }
0030 };
0031 
0032 template <typename T>
0033 struct ComparisionTrait<T, LessEq> {
0034    static const std::string Name() { return "LessOrEqual"; }
0035    static std::string Op(const std::string & t1, const std::string t2) { return t1 + "<=" +  t2 + "? true : false ";  }
0036 };
0037 
0038 template <typename T>
0039 struct ComparisionTrait<T, Greater> {
0040    static const std::string Name() { return "Greater"; }
0041    static std::string Op(const std::string & t1, const std::string t2) { return  t1 + ">" +  t2 + "? true : false "; }
0042 };
0043 
0044 template <typename T>
0045 struct ComparisionTrait<T, GreaterEq> {
0046    static const std::string Name() { return "GreaterOrEqual"; }
0047    static std::string Op(const std::string & t1, const std::string t2) { return t1 + ">=" +  t2 + "? true : false " ; }
0048 };
0049 
0050 template<typename T, EComparisionOperator Op>
0051 class ROperator_Comparision final : public ROperator{
0052 private:
0053 
0054    std::string fNX1;
0055    std::string fNX2;
0056    std::string fNY;
0057    std::vector<size_t> fShapeX1;
0058    std::vector<size_t> fShapeX2;
0059    std::vector<size_t> fShapeY;
0060    std::string fNBroadcastedX1;
0061    std::string fNBroadcastedX2;
0062    bool fBroadcast = false;
0063 
0064 
0065 public:
0066    ROperator_Comparision(){}
0067    ROperator_Comparision(std::string nameX1, std::string nameX2, std::string nameY):
0068       fNX1(UTILITY::Clean_name(nameX1)), fNX2(UTILITY::Clean_name(nameX2)), fNY(UTILITY::Clean_name(nameY)){}
0069 
0070    // type of output given input
0071    std::vector<ETensorType> TypeInference(std::vector<ETensorType> input) override {
0072       return input;
0073    }
0074 
0075    // shape of output tensors given input tensors
0076    std::vector<std::vector<size_t>> ShapeInference(std::vector<std::vector<size_t>> input) override {
0077       auto ret = input; // return vector size 1 with first input
0078       return ret;
0079    }
0080 
0081    void Initialize(RModel& model) override {
0082       // input must be a graph input, or already initialized intermediate tensor
0083       if (!model.CheckIfTensorAlreadyExist(fNX1)){
0084          throw std::runtime_error(std::string("TMVA SOFIE Comparision Op Input Tensor ") + fNX1 + "is not found in model");
0085       }
0086       if (!model.CheckIfTensorAlreadyExist(fNX2)) {
0087          throw std::runtime_error(std::string("TMVA SOFIE Comparision Op Input Tensor ") + fNX2 + "is not found in model");
0088       }
0089       fShapeX1 = model.GetTensorShape(fNX1);
0090       fShapeX2 = model.GetTensorShape(fNX2);
0091       bool broadcast = !UTILITY::AreSameShape(fShapeX1, fShapeX2);
0092       if (broadcast) {
0093          // Y is the common shape of A and B
0094          fShapeY = UTILITY::UnidirectionalBroadcastShape(fShapeX1, fShapeX2);
0095          bool broadcastX1 = !UTILITY::AreSameShape(fShapeX1, fShapeY);
0096          bool broadcastX2 = !UTILITY::AreSameShape(fShapeX2, fShapeY);
0097          // Broadcast A to Y
0098          if (broadcastX1) {
0099             if (model.IsInitializedTensor(fNX1)) {
0100                auto data = model.GetInitializedTensorData(fNX1);
0101                std::shared_ptr<void> broadcastedData(
0102                   UTILITY::UnidirectionalBroadcast<float>(static_cast<float *>(data.get()), fShapeX1, fShapeY),
0103                   std::default_delete<float[]>());
0104                // Update the data and the shape of A
0105                model.UpdateInitializedTensor(fNX1, model.GetTensorType(fNX1), fShapeY, broadcastedData);
0106                fShapeX1 = fShapeY;
0107             } else {
0108                // Add an intermediate tensor for broadcasting A
0109                fNBroadcastedX1 = "Broadcasted" + fNX1;
0110                model.AddIntermediateTensor(fNBroadcastedX1, model.GetTensorType(fNX1), fShapeY);
0111             }
0112          }
0113          // Broadcast B to Y
0114          if (broadcastX2) {
0115             if (model.IsInitializedTensor(fNX2)) {
0116                auto data = model.GetInitializedTensorData(fNX2);
0117                std::shared_ptr<void> broadcastedData(
0118                   UTILITY::UnidirectionalBroadcast<float>(static_cast<float *>(data.get()), fShapeX2, fShapeY),
0119                   std::default_delete<float[]>());
0120                // Update the data and the shape of B
0121                model.UpdateInitializedTensor(fNX2, model.GetTensorType(fNX2), fShapeY, broadcastedData);
0122                fShapeX2 = fShapeY;
0123             } else {
0124                // Add an intermediate tensor for broadcasting B
0125                fNBroadcastedX2 = "Broadcasted" + fNX2;
0126                model.AddIntermediateTensor(fNBroadcastedX2, model.GetTensorType(fNX2), fShapeY);
0127             }
0128          }
0129       } else {
0130          fShapeY = fShapeX1;
0131       }
0132       model.AddIntermediateTensor(fNY, ETensorType::BOOL , fShapeY);
0133    }
0134 
0135    std::string Generate(std::string OpName) override {
0136       OpName = "op_" + OpName;
0137 
0138      if (fShapeY.empty()) {
0139          throw std::runtime_error("TMVA SOFIE Comparision Op called to Generate without being initialized first");
0140       }
0141       std::stringstream out;
0142       out << SP << "\n//------ " << ComparisionTrait<T,Op>::Name() << "\n";
0143       size_t length = ConvertShapeToLength(fShapeY);
0144       // Broadcast A if it's uninitialized
0145       if (!fNBroadcastedX1.empty()) {
0146          out << SP << "// Broadcasting uninitialized tensor " << fNX1 << "\n";
0147          out << SP << "{\n";
0148          out << SP << SP << "float* data = TMVA::Experimental::SOFIE::UTILITY::UnidirectionalBroadcast<float>(tensor_" << fNX1 << ", " << ConvertShapeToString(fShapeX1) << ", " << ConvertShapeToString(fShapeY) << ");\n";
0149          out << SP << SP << "std::copy(data, data + " << length << ", tensor_" << fNBroadcastedX1 << ");\n";
0150          out << SP << SP << "delete[] data;\n";
0151          out << SP << "}\n";
0152       }
0153       // Broadcast B if it's uninitialized
0154       if (!fNBroadcastedX2.empty()) {
0155          out << SP << "// Broadcasting uninitialized tensor " << fNX2 << "\n";
0156          out << SP << "{\n";
0157          out << SP << SP << "float* data = TMVA::Experimental::SOFIE::UTILITY::UnidirectionalBroadcast<float>(tensor_" << fNX2 << ", " << ConvertShapeToString(fShapeX2) << ", " << ConvertShapeToString(fShapeY) << ");\n";
0158          out << SP << SP << "std::copy(data, data + " << length << ", tensor_" << fNBroadcastedX2 << ");\n";
0159          out << SP << SP << "delete[] data;\n";
0160          out << SP << "}\n";
0161       }
0162       const std::string& nameX1 = fNBroadcastedX1.empty()? fNX1 : fNBroadcastedX1;
0163       const std::string& nameX2 = fNBroadcastedX2.empty()? fNX2 : fNBroadcastedX2;
0164 
0165       out << SP << "for (size_t id = 0; id < " << length << " ; id++){\n";
0166       out << SP << SP << "fTensor_" << fNY << "[id] = " << ComparisionTrait<T,Op>::Op( "tensor_" + nameX1 + "[id]" , "tensor_" + nameX2 + "[id]") <<  " ;\n";
0167       out << SP << "}\n";
0168 
0169       return out.str();
0170    }
0171 
0172 };
0173 
0174 }//SOFIE
0175 }//Experimental
0176 }//TMVA
0177 
0178 
0179 #endif //TMVA_SOFIE_ROperator_Comparision