Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-01-18 10:11:05

0001 #ifndef TMVA_SOFIE_ROPERATOR_BASICNARY
0002 #define TMVA_SOFIE_ROPERATOR_BASICNARY
0003 
#include "TMVA/SOFIE_common.hxx"
#include "TMVA/ROperator.hxx"
#include "TMVA/RModel.hxx"

#include <algorithm>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>
0011 
0012 namespace TMVA{
0013 namespace Experimental{
0014 namespace SOFIE{
0015 
/// Element-wise variadic operations supported by ROperator_BasicNary.
enum class EBasicNaryOperator {Max, Min, Mean, Sum};

/// Code-generation traits for each (element type, operation) pair.
/// The primary template is intentionally empty: only the combinations specialized
/// below provide Name() and Op(), so using an unsupported one fails to compile.
template<typename T, EBasicNaryOperator Op>
struct NaryOperatorTraits {};

/// Throws if an Nary operation is asked to combine zero operands; every generated
/// expression below starts from inputs[0].
/// @throws std::runtime_error if `inputs` is empty
inline void CheckNaryOperatorInputs(const std::vector<std::string>& inputs, const std::string& opName) {
   if (inputs.empty()) {
      throw std::runtime_error("TMVA SOFIE BasicNary " + opName + " called without inputs");
   }
}

/// Max: emits `res = inputs[0];` followed by `res = std::max(res, inputs[i]);` per further operand.
template<typename T>
struct NaryOperatorTraits<T, EBasicNaryOperator::Max> {
   static std::string Name() {return "Max";}
   /// @param res    expression receiving the result in the generated code
   /// @param inputs expressions of the operands (must not be empty)
   /// @return generated statements, one per line, indented with two tabs
   static std::string Op(const std::string& res, const std::vector<std::string>& inputs) {
      CheckNaryOperatorInputs(inputs, Name());
      std::stringstream out;
      out << "\t" << "\t" << res << " = " << inputs[0] << ";\n";
      for (size_t i = 1; i < inputs.size(); i++) {
         out << "\t" << "\t" << res << " = std::max(" << res << ", " << inputs[i] << ");\n";
      }
      return out.str();
   }
};

/// Min: emits `res = inputs[0];` followed by `res = std::min(res, inputs[i]);` per further operand.
template<typename T>
struct NaryOperatorTraits<T, EBasicNaryOperator::Min> {
   static std::string Name() {return "Min";}
   /// @param res    expression receiving the result in the generated code
   /// @param inputs expressions of the operands (must not be empty)
   /// @return generated statements, one per line, indented with two tabs
   static std::string Op(const std::string& res, const std::vector<std::string>& inputs) {
      CheckNaryOperatorInputs(inputs, Name());
      std::stringstream out;
      out << "\t" << "\t" << res << " = " << inputs[0] << ";\n";
      for (size_t i = 1; i < inputs.size(); i++) {
         out << "\t" << "\t" << res << " = std::min(" << res << ", " << inputs[i] << ");\n";
      }
      return out.str();
   }
};

/// Mean is only provided for the floating-point types specialized below;
/// other element types get this empty specialization (no Name()/Op()).
template<typename T>
struct NaryOperatorTraits<T, EBasicNaryOperator::Mean> {};

/// Shared Mean code generator: emits `res = (inputs[0] + ... + inputs[n-1]) / typeName(n);`.
/// @throws std::runtime_error if `inputs` is empty
inline std::string NaryMeanOperatorOp(const std::string& res, const std::vector<std::string>& inputs,
                                      const std::string& typeName) {
   CheckNaryOperatorInputs(inputs, "Mean");
   std::stringstream out;
   out << "\t" << "\t" << res << " = (" << inputs[0];
   for (size_t i = 1; i < inputs.size(); i++) {
      out << " + " << inputs[i];
   }
   out << ") / " << typeName << "(" << inputs.size() << ");\n";
   return out.str();
}

template<>
struct NaryOperatorTraits<float, EBasicNaryOperator::Mean> {
   static std::string Name() {return "Mean";}
   static std::string Op(const std::string& res, const std::vector<std::string>& inputs) {
      return NaryMeanOperatorOp(res, inputs, "float");
   }
};

template<>
struct NaryOperatorTraits<double, EBasicNaryOperator::Mean> {
   static std::string Name() {return "Mean";}
   static std::string Op(const std::string& res, const std::vector<std::string>& inputs) {
      return NaryMeanOperatorOp(res, inputs, "double");
   }
};

/// Sum: emits the single statement `res = inputs[0] + ... + inputs[n-1];`.
template<typename T>
struct NaryOperatorTraits<T, EBasicNaryOperator::Sum> {
   static std::string Name() {return "Sum";}
   /// @param res    expression receiving the result in the generated code
   /// @param inputs expressions of the operands (must not be empty)
   /// @return one generated statement indented with two tabs
   static std::string Op(const std::string& res, const std::vector<std::string>& inputs) {
      CheckNaryOperatorInputs(inputs, Name());
      std::stringstream out;
      out << "\t" << "\t" << res << " = " << inputs[0];
      for (size_t i = 1; i < inputs.size(); i++) {
         out << " + " << inputs[i];
      }
      out << ";\n";
      return out.str();
   }
};
0077 
0078 template <typename T, EBasicNaryOperator Op>
0079 class ROperator_BasicNary final : public ROperator
0080 {
0081 
0082 private:
0083 
0084    std::vector<std::string> fNInputs;
0085    std::string fNY;
0086    std::vector<std::vector<size_t>> fShapeInputs;
0087 
0088    std::vector<std::string> fNBroadcastedInputs;
0089    std::vector<size_t> fShapeY;
0090 
0091    bool fBroadcast = false;
0092 
0093    std::string fType;
0094 
0095 public:
0096    ROperator_BasicNary(){}
0097 
0098    ROperator_BasicNary( const std::vector<std::string> & inputNames, const std::string& nameY):
0099    fNY(UTILITY::Clean_name(nameY)){
0100       fNInputs.reserve(inputNames.size());
0101       for (auto & name : inputNames)
0102          fNInputs.push_back(UTILITY::Clean_name(name));
0103    }
0104 
0105    // type of output given input
0106    std::vector<ETensorType> TypeInference(std::vector<ETensorType> input){
0107       return input;
0108    }
0109 
0110    // shape of output tensors given input tensors
0111    std::vector<std::vector<size_t>> ShapeInference(std::vector<std::vector<size_t>> input){
0112       auto ret = std::vector<std::vector<size_t>>(1, input[0]);
0113       return ret;
0114    }
0115 
0116    void Initialize(RModel& model){
0117       for (auto &it : fNInputs) {
0118          if (!model.CheckIfTensorAlreadyExist(it)) {
0119             throw std::runtime_error("TMVA SOFIE BasicNary Op Input Tensor " + it + " is not found in model");
0120          }
0121          fShapeInputs.push_back(model.GetTensorShape(it));
0122       }
0123       // Find the common shape of the input tensors
0124       fShapeY = UTILITY::MultidirectionalBroadcastShape(fShapeInputs);
0125       model.AddIntermediateTensor(fNY, model.GetTensorType(fNInputs[0]), fShapeY);
0126       // Broadcasting
0127       size_t N = fNInputs.size();
0128       fNBroadcastedInputs.reserve(N);
0129       for (size_t i = 0; i < N; i++) {
0130          if (!UTILITY::AreSameShape(model.GetTensorShape(fNInputs[i]), fShapeY)) {
0131             fBroadcast = true;
0132             std::string name = "Broadcasted"  + fNInputs[i];
0133             model.AddIntermediateTensor(name, model.GetTensorType(fNInputs[0]), fShapeY);
0134             fNBroadcastedInputs.emplace_back("tensor_" + name);
0135          } else {
0136             fNBroadcastedInputs.emplace_back("tensor_" + fNInputs[i]);
0137          }
0138       }
0139       fType = ConvertTypeToString(model.GetTensorType(fNInputs[0]));
0140    }
0141 
0142    std::string Generate(std::string OpName){
0143       OpName = "op_" + OpName;
0144       if (fShapeY.empty()) {
0145          throw std::runtime_error("TMVA SOFIE BasicNary called to Generate without being initialized first");
0146       }
0147       std::stringstream out;
0148       size_t length = ConvertShapeToLength(fShapeY);
0149       out << SP << "\n//------ BasicNary operator\n";
0150       if (fBroadcast) {
0151          for (size_t i = 0; i < fNInputs.size(); i++) {
0152             if (fNBroadcastedInputs[i] != fNInputs[i]) {
0153                out << SP << SP << "// Broadcasting " << fNInputs[i] << " to " << ConvertShapeToString(fShapeY) << "\n";
0154                out << SP << SP << "{\n";
0155                out << SP << SP << SP << fType << "* data = TMVA::Experimental::SOFIE::UTILITY::UnidirectionalBroadcast<" << fType << ">(tensor_" + fNInputs[i] << ", " << ConvertShapeToString(fShapeInputs[i]);
0156                out << ", " << ConvertShapeToString(fShapeY) << ");\n";
0157                out << SP << SP << SP << "std::copy(data, data + " << length << ", " << fNBroadcastedInputs[i] << ");\n";
0158                out << SP << SP << SP << "delete[] data;\n";
0159                out << SP << SP << "}\n";
0160             }
0161          }
0162       }
0163 
0164       if (fNInputs.size() == 1) {
0165          out << SP << "std::copy(tensor_" << fNInputs[0] << ", tensor_" << fNInputs[0] << " + ";
0166          out << length << ", tensor_" << fNY << ");\n";
0167       } else {
0168          std::vector<std::string> inputs(fNBroadcastedInputs.size());
0169          for (size_t i = 0; i < fNBroadcastedInputs.size(); i++) {
0170             inputs[i] = fNBroadcastedInputs[i] + "[id]";
0171          }
0172          out << SP << "for (size_t id = 0; id < " << length << "; id++) {\n";
0173          out << NaryOperatorTraits<T,Op>::Op("tensor_" + fNY + "[id]", inputs);
0174          out << SP << "}\n";
0175       }
0176       return out.str();
0177    }
0178 
0179    std::vector<std::string> GetStdLibs() {return { std::string("cmath") }; }
0180 };
0181 
0182 }//SOFIE
0183 }//Experimental
0184 }//TMVA
0185 
0186 
0187 #endif //TMVA_SOFIE_ROPERATOR_BasicNary