EIC code displayed by LXR

File: /include/root/TMVA/ROperator_Where.hxx

#ifndef TMVA_SOFIE_ROperator_Where
#define TMVA_SOFIE_ROperator_Where

#include "TMVA/SOFIE_common.hxx"
#include "TMVA/ROperator.hxx"
#include "TMVA/RModel.hxx"

#include <sstream>

namespace TMVA{
namespace Experimental{
namespace SOFIE{

// SOFIE implementation of the ONNX Where operator:
// Y[i] = C[i] ? A[i] : B[i], where C is a boolean condition tensor and
// A, B are the value tensors (broadcast to a common shape if needed).
template<typename T>
class ROperator_Where final : public ROperator{
private:

   // true if the condition tensor C is a (non-initialized) boolean input of the model
   bool fIsInputBoolTensor = false;

   // tensor names: A and B are the value tensors, C is the boolean condition, Y is the output;
   // the fNBroadcastedX names are set only when the corresponding input needs broadcasting
   std::string fNA;
   std::string fNB;
   std::string fNC;
   std::string fNBroadcastedA;
   std::string fNBroadcastedB;
   std::string fNBroadcastedC;
   std::string fNY;

   // shapes of the input tensors and of the output tensor
   std::vector<size_t> fShapeA;
   std::vector<size_t> fShapeB;
   std::vector<size_t> fShapeC;
   std::vector<size_t> fShapeY;


public:
   ROperator_Where(){}
   ROperator_Where(const std::string & nameA, const std::string & nameB, const std::string & nameC, const std::string & nameY):
      fNA(UTILITY::Clean_name(nameA)), fNB(UTILITY::Clean_name(nameB)), fNC(UTILITY::Clean_name(nameC)), fNY(UTILITY::Clean_name(nameY)){
         fInputTensorNames = { fNA, fNB, fNC };
         fOutputTensorNames = { fNY };
      }

   // type of output given input
   std::vector<ETensorType> TypeInference(std::vector<ETensorType> input) override {
      return input;
   }

   // shape of output tensors given input tensors
   std::vector<std::vector<size_t>> ShapeInference(std::vector<std::vector<size_t>> input) override {
      // assume for now that the inputs have the same shape (no broadcasting)
      auto ret = std::vector<std::vector<size_t>>(1, input[0]); // return a vector of size 1 holding the shape of the first input
      return ret;
   }

   void Initialize(RModel& model) override {
      // inputs must be graph inputs or already initialized intermediate tensors
      if (!model.CheckIfTensorAlreadyExist(fNA)){
         throw std::runtime_error(std::string("TMVA SOFIE Where Op Input Tensor ") + fNA + " is not found in model");
      }
      if (!model.CheckIfTensorAlreadyExist(fNB)) {
         throw std::runtime_error(std::string("TMVA SOFIE Where Op Input Tensor ") + fNB + " is not found in model");
      }
      if (!model.CheckIfTensorAlreadyExist(fNC)) {
         throw std::runtime_error(std::string("TMVA SOFIE Where Op Input Tensor ") + fNC + " is not found in model");
      }
      // check whether the condition tensor fNC is a (boolean) input of the model,
      // i.e. not an initialized/constant tensor
      if (model.IsReadyInputTensor(fNC))
         fIsInputBoolTensor = true;
      // check broadcast for A, B and C
      fShapeA = model.GetTensorShape(fNA);
      fShapeB = model.GetTensorShape(fNB);
      fShapeC = model.GetTensorShape(fNC);
      bool broadcast = !UTILITY::AreSameShape(fShapeA, fShapeB) || !UTILITY::AreSameShape(fShapeA, fShapeC);
      if (broadcast) {
         // find the shape to broadcast A, B and C to, looking for the maximum length
         size_t lengthA = ConvertShapeToLength(fShapeA);
         size_t lengthB = ConvertShapeToLength(fShapeB);
         size_t lengthC = ConvertShapeToLength(fShapeC);
         bool broadcastA = false, broadcastB = false, broadcastC = false;
         if (lengthA >= lengthB && lengthA >= lengthC) {
            fShapeY = fShapeA;
            // broadcast B and C if different from A
            broadcastB = (lengthB != lengthA);
            broadcastC = (lengthC != lengthA);
         }
         else if (lengthB >= lengthA && lengthB >= lengthC) {
            fShapeY = fShapeB;
            // broadcast A and C if different from B
            broadcastA = (lengthA != lengthB);
            broadcastC = (lengthC != lengthB);
         }
         else if (lengthC >= lengthA && lengthC >= lengthB) {
            fShapeY = fShapeC;
            // broadcast A and B if different from C
            broadcastA = (lengthA != lengthC);
            broadcastB = (lengthB != lengthC);
         }

         // Broadcast A to Y
         if (broadcastA) {
            fNBroadcastedA = "BC_" + fNA + "_to_" + fNY;
            if (model.IsInitializedTensor(fNA)) {
               auto data = model.GetInitializedTensorData(fNA);
               std::shared_ptr<void> broadcastedData(
                  UTILITY::UnidirectionalBroadcast<T>(static_cast<T *>(data.get()), fShapeA, fShapeY),
                  std::default_delete<T[]>());
               // Update the data and the shape of A
               model.AddConstantTensor(fNBroadcastedA, model.GetTensorType(fNA), fShapeY, broadcastedData);
               fShapeA = fShapeY;
            } else {
               // Add an intermediate tensor for broadcasting A
               model.AddIntermediateTensor(fNBroadcastedA, model.GetTensorType(fNA), fShapeY);
            }
         }
         // Broadcast B to Y
         if (broadcastB) {
            fNBroadcastedB = "BC_" + fNB + "_to_" + fNY;
            if (model.IsInitializedTensor(fNB)) {
               auto data = model.GetInitializedTensorData(fNB);
               std::shared_ptr<void> broadcastedData(
                  UTILITY::UnidirectionalBroadcast<T>(static_cast<T *>(data.get()), fShapeB, fShapeY),
                  std::default_delete<T[]>());
               // do not update tensor B but add a broadcast copy (B can be an input to other operators)
               model.AddConstantTensor(fNBroadcastedB, model.GetTensorType(fNB), fShapeY, broadcastedData);
               fShapeB = fShapeY;
            } else {
               // Add an intermediate tensor for broadcasting B
               model.AddIntermediateTensor(fNBroadcastedB, model.GetTensorType(fNB), fShapeY);
            }
         }
         // Broadcast C to Y
         if (broadcastC) {
            fNBroadcastedC = "BC_" + fNC + "_to_" + fNY;
            if (model.IsInitializedTensor(fNC)) {
               auto data = model.GetInitializedTensorData(fNC);
               std::shared_ptr<void> broadcastedData(
                  UTILITY::UnidirectionalBroadcast<T>(static_cast<T *>(data.get()), fShapeC, fShapeY),
                  std::default_delete<T[]>());
               // do not update tensor C but add a broadcast copy (C can be an input to other operators)
               model.AddConstantTensor(fNBroadcastedC, model.GetTensorType(fNC), fShapeY, broadcastedData);
               fShapeC = fShapeY;
            } else {
               // Add an intermediate tensor for broadcasting C
               model.AddIntermediateTensor(fNBroadcastedC, model.GetTensorType(fNC), fShapeY);
            }
         }
      } else {
         fShapeY = fShapeA;
      }
      // check the case of a constant output (all inputs are initialized tensors)
      if (model.IsInitializedTensor(fNA) && model.IsInitializedTensor(fNB) && model.IsInitializedTensor(fNC)) {
         std::string nameA = fNBroadcastedA.empty()? fNA : fNBroadcastedA;
         std::string nameB = fNBroadcastedB.empty()? fNB : fNBroadcastedB;
         std::string nameC = fNBroadcastedC.empty()? fNC : fNBroadcastedC;
         auto dataA = static_cast<T *>(model.GetInitializedTensorData(nameA).get());
         auto dataB = static_cast<T *>(model.GetInitializedTensorData(nameB).get());
         auto dataC = static_cast<bool *>(model.GetInitializedTensorData(nameC).get());
         std::vector<T> dataY(ConvertShapeToLength(fShapeY));
         for (size_t i = 0; i < dataY.size(); i++)
             dataY[i] = (dataC[i]) ? dataA[i] : dataB[i];
         model.AddConstantTensor<T>(fNY, fShapeY, dataY.data());
         // flag the input tensors so that they are not written to the weight file
         model.SetNotWritableInitializedTensor(nameA);
         model.SetNotWritableInitializedTensor(nameB);
         model.SetNotWritableInitializedTensor(nameC);

         fIsOutputConstant = true;
         if (model.Verbose())
            std::cout << "Where op ---> " << fNY << "  " << ConvertShapeToString(fShapeY) << " : "
               << ConvertValuesToString(dataY) << std::endl;

         // output is a constant tensor
         fOutputTensorNames.pop_back();
      }
      else {
         model.AddIntermediateTensor(fNY, model.GetTensorType(fNA), fShapeY);
      }
   }

   std::string GenerateInitCode() override {
      std::stringstream out;
      return out.str();
   }

   std::string Generate(std::string OpName) override {

      if (fIsOutputConstant) return "";

      OpName = "op_" + OpName;

      if (fShapeY.empty()) {
         throw std::runtime_error("TMVA SOFIE Where Op called to Generate without being initialized first");
      }
      std::stringstream out;
      out << SP << "\n//-------- Where\n";
      size_t length = ConvertShapeToLength(fShapeY);
      std::string typeName = TensorType<T>::Name();
      // Broadcast A at run time (initialized tensors were already broadcast in Initialize)
      if (fShapeA != fShapeY) {
         out << SP << "// Broadcasting uninitialized tensor " << fNA << "\n";
         out << SP << "TMVA::Experimental::SOFIE::UTILITY::UnidirectionalBroadcast<" << typeName << ">(tensor_" << fNA << ", " << ConvertShapeToString(fShapeA) << ", " << ConvertShapeToString(fShapeY)
                   << ", fTensor_" << fNBroadcastedA << ");\n";
      }
      // Broadcast B at run time (initialized tensors were already broadcast in Initialize)
      if (fShapeB != fShapeY) {
         out << SP << "// Broadcasting uninitialized tensor " << fNB << "\n";
         out << SP << "TMVA::Experimental::SOFIE::UTILITY::UnidirectionalBroadcast<" << typeName << ">(tensor_" << fNB << ", " << ConvertShapeToString(fShapeB) << ", " << ConvertShapeToString(fShapeY)
                   << ", fTensor_" << fNBroadcastedB << ");\n";
      }
      // Broadcast C at run time (initialized tensors were already broadcast in Initialize)
      if (fShapeC != fShapeY) {
         // special case if C is an input tensor
         if (fIsInputBoolTensor) {
            size_t inputLength = ConvertShapeToLength(fShapeC);
            out << SP << "std::vector<bool> fTensor_" << fNC << "(tensor_" << fNC << ", tensor_" << fNC << " + " << inputLength << ");\n";
         }
         out << SP << "// Broadcasting uninitialized tensor " << fNC << "\n";
         // for booleans we need to pass a std::vector<bool> and use the non-template version of the function
         out << SP << "TMVA::Experimental::SOFIE::UTILITY::UnidirectionalBroadcast(fTensor_" << fNC << ", " << ConvertShapeToString(fShapeC) << ", " << ConvertShapeToString(fShapeY)
                   << ", fTensor_" << fNBroadcastedC << ");\n";
      }
      std::string nameA = fNBroadcastedA.empty()? fNA : fNBroadcastedA;
      std::string nameB = fNBroadcastedB.empty()? fNB : fNBroadcastedB;
      std::string nameC = fNBroadcastedC.empty()? fNC : fNBroadcastedC;
      out << SP << "for (size_t id = 0; id < " << length << " ; id++){\n";
      // fill the output tensor applying the condition (the boolean tensor must be accessed through the
      // std::vector<bool> itself, since std::vector<bool> provides no data() pointer)
      out << SP << SP << "tensor_" << fNY << "[id] = " << "(fTensor_" << nameC << "[id]) ? tensor_"
                      << nameA << "[id] : tensor_" << nameB << "[id];\n";
      out << SP << "}\n";
      return out.str();
   }

};

}//SOFIE
}//Experimental
}//TMVA


#endif //TMVA_SOFIE_ROperator_Where
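
For reference, the code emitted by Generate() boils down to an element-wise select over tensors that have already been broadcast to the common output shape. The following standalone sketch (hypothetical tensor names and values, not part of SOFIE) illustrates the same semantics, including the use of std::vector<bool> for the condition since that specialization has no data() member:

#include <cstdio>
#include <vector>

int main() {
   // hypothetical inputs, already broadcast to a common length of 4
   std::vector<float> tensor_A = {1.f, 2.f, 3.f, 4.f};       // values taken where the condition is true
   std::vector<float> tensor_B = {10.f, 20.f, 30.f, 40.f};   // values taken where the condition is false
   std::vector<bool>  tensor_C = {true, false, true, false}; // boolean condition
   std::vector<float> tensor_Y(tensor_A.size());

   // the loop emitted by ROperator_Where<float>::Generate has this form
   for (size_t id = 0; id < tensor_Y.size(); id++)
      tensor_Y[id] = (tensor_C[id]) ? tensor_A[id] : tensor_B[id];

   for (float y : tensor_Y)
      std::printf("%g ", y);   // prints: 1 20 3 40
   std::printf("\n");
   return 0;
}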