Warning, file /include/root/TMVA/ROperator_Where.hxx was not indexed
or was modified since last indexation (in which case cross-reference links may be missing, inaccurate or erroneous).
0001 #ifndef TMVA_SOFIE_ROperator_Where
0002 #define TMVA_SOFIE_ROperator_Where
0003
0004 #include "TMVA/SOFIE_common.hxx"
0005 #include "TMVA/ROperator.hxx"
0006 #include "TMVA/RModel.hxx"
0007
0008 #include <sstream>
0009
0010 namespace TMVA{
0011 namespace Experimental{
0012 namespace SOFIE{
0013
0014
0015
0016 template<typename T>
0017 class ROperator_Where final : public ROperator{
0018 private:
0019
0020 bool fIsInputBoolTensor = false;
0021
0022
0023 std::string fNA;
0024 std::string fNB;
0025 std::string fNC;
0026 std::string fNBroadcastedA;
0027 std::string fNBroadcastedB;
0028 std::string fNBroadcastedC;
0029 std::string fNY;
0030
0031
0032 std::vector<size_t> fShapeA;
0033 std::vector<size_t> fShapeB;
0034 std::vector<size_t> fShapeC;
0035 std::vector<size_t> fShapeY;
0036
0037
0038 public:
0039 ROperator_Where(){}
0040 ROperator_Where(const std::string & nameA, const std::string & nameB, const std::string & nameC, const std::string & nameY):
0041 fNA(UTILITY::Clean_name(nameA)), fNB(UTILITY::Clean_name(nameB)), fNC(UTILITY::Clean_name(nameC)), fNY(UTILITY::Clean_name(nameY)){
0042 fInputTensorNames = { fNA, fNB, fNC };
0043 fOutputTensorNames = { fNY };
0044 }
0045
0046
0047 std::vector<ETensorType> TypeInference(std::vector<ETensorType> input) override {
0048 return input;
0049 }
0050
0051
0052 std::vector<std::vector<size_t>> ShapeInference(std::vector<std::vector<size_t>> input) override {
0053
0054 auto ret = std::vector<std::vector<size_t>>(1, input[0]);
0055 return ret;
0056 }
0057
0058 void Initialize(RModel& model) override {
0059
0060 if (!model.CheckIfTensorAlreadyExist(fNA)){
0061 throw std::runtime_error(std::string("TMVA SOFIE Where Op Input Tensor ") + fNA + "is not found in model");
0062 }
0063 if (!model.CheckIfTensorAlreadyExist(fNB)) {
0064 throw std::runtime_error(std::string("TMVA SOFIE Where Op Input Tensor ") + fNB + "is not found in model");
0065 }
0066 if (!model.CheckIfTensorAlreadyExist(fNC)) {
0067 throw std::runtime_error(std::string("TMVA SOFIE Where Op Input Tensor ") + fNC + "is not found in model");
0068 }
0069
0070 if (model.IsReadyInputTensor(fNC))
0071 fIsInputBoolTensor = true;
0072
0073 fShapeA = model.GetTensorShape(fNA);
0074 fShapeB = model.GetTensorShape(fNB);
0075 fShapeC = model.GetTensorShape(fNC);
0076 bool broadcast = !UTILITY::AreSameShape(fShapeA, fShapeB) || !UTILITY::AreSameShape(fShapeA, fShapeC);
0077 if (broadcast) {
0078
0079 size_t lengthA = ConvertShapeToLength(fShapeA);
0080 size_t lengthB = ConvertShapeToLength(fShapeB);
0081 size_t lengthC = ConvertShapeToLength(fShapeC);
0082 bool broadcastA = false, broadcastB = false, broadcastC = false;
0083 if (lengthA >= lengthB && lengthA >= lengthC) {
0084 fShapeY = fShapeA;
0085
0086 broadcastB = (lengthB != lengthA);
0087 broadcastC = (lengthC != lengthA);
0088 }
0089 else if (lengthB >= lengthA && lengthB >= lengthC) {
0090 fShapeY = fShapeB;
0091
0092 broadcastA = (lengthA != lengthB);
0093 broadcastC = (lengthC != lengthB);
0094 }
0095 else if (lengthC >= lengthA && lengthC >= lengthB) {
0096 fShapeY = fShapeC;
0097
0098 broadcastA = (lengthA != lengthC);
0099 broadcastB = (lengthB != lengthC);
0100 }
0101
0102
0103 if (broadcastA) {
0104 fNBroadcastedA = "BC_" + fNA + "_to_" + fNY;
0105 if (model.IsInitializedTensor(fNA)) {
0106 auto data = model.GetInitializedTensorData(fNA);
0107 std::shared_ptr<void> broadcastedData(
0108 UTILITY::UnidirectionalBroadcast<T>(static_cast<T *>(data.get()), fShapeA, fShapeY),
0109 std::default_delete<T[]>());
0110
0111 model.AddConstantTensor(fNBroadcastedA, model.GetTensorType(fNA), fShapeY, broadcastedData);
0112 fShapeA = fShapeY;
0113 } else {
0114
0115 model.AddIntermediateTensor(fNBroadcastedA, model.GetTensorType(fNA), fShapeY);
0116 }
0117 }
0118
0119 if (broadcastB) {
0120 fNBroadcastedB = "BC_" + fNB + "_to_" + fNY;
0121 if (model.IsInitializedTensor(fNB)) {
0122 auto data = model.GetInitializedTensorData(fNB);
0123 std::shared_ptr<void> broadcastedData(
0124 UTILITY::UnidirectionalBroadcast<T>(static_cast<T *>(data.get()), fShapeB, fShapeY),
0125 std::default_delete<T[]>());
0126
0127 model.AddConstantTensor(fNBroadcastedB, model.GetTensorType(fNB), fShapeY, broadcastedData);
0128 fShapeB = fShapeY;
0129 } else {
0130
0131 model.AddIntermediateTensor(fNBroadcastedB, model.GetTensorType(fNB), fShapeY);
0132 }
0133 }
0134
0135 if (broadcastC) {
0136 fNBroadcastedC = "BC_" + fNC + "_to_" + fNY;
0137 if (model.IsInitializedTensor(fNC)) {
0138 auto data = model.GetInitializedTensorData(fNC);
0139 std::shared_ptr<void> broadcastedData(
0140 UTILITY::UnidirectionalBroadcast<T>(static_cast<T *>(data.get()), fShapeC, fShapeY),
0141 std::default_delete<T[]>());
0142
0143 model.AddConstantTensor(fNBroadcastedC, model.GetTensorType(fNC), fShapeY, broadcastedData);
0144 fShapeC = fShapeY;
0145 } else {
0146
0147 model.AddIntermediateTensor(fNBroadcastedC, model.GetTensorType(fNC), fShapeY);
0148 }
0149 }
0150 } else {
0151 fShapeY = fShapeA;
0152 }
0153
0154 if (model.IsInitializedTensor(fNA) && model.IsInitializedTensor(fNB) && model.IsInitializedTensor(fNC)) {
0155 std::string nameA = fNBroadcastedA.empty()? fNA : fNBroadcastedA;
0156 std::string nameB = fNBroadcastedB.empty()? fNB : fNBroadcastedB;
0157 std::string nameC = fNBroadcastedC.empty()? fNC : fNBroadcastedC;
0158 auto dataA = static_cast<T *>(model.GetInitializedTensorData(nameA).get());
0159 auto dataB = static_cast<T *>(model.GetInitializedTensorData(nameB).get());
0160 auto dataC = static_cast<bool *>(model.GetInitializedTensorData(nameC).get());
0161 std::vector<T> dataY(ConvertShapeToLength(fShapeY));
0162 for (size_t i = 0; i < dataY.size(); i++)
0163 dataY[i] = (dataC[i]) ? dataA[i] : dataB[i];
0164 model.AddConstantTensor<T>(fNY, fShapeY, dataY.data());
0165
0166 model.SetNotWritableInitializedTensor(nameA);
0167 model.SetNotWritableInitializedTensor(nameB);
0168 model.SetNotWritableInitializedTensor(nameC);
0169
0170 fIsOutputConstant = true;
0171 if (model.Verbose())
0172 std::cout << "Where op ---> " << fNY << " " << ConvertShapeToString(fShapeY) << " : "
0173 << ConvertValuesToString(dataY) << std::endl;
0174
0175
0176 fOutputTensorNames.pop_back();
0177 }
0178 else {
0179 model.AddIntermediateTensor(fNY, model.GetTensorType(fNA), fShapeY);
0180 }
0181 }
0182
0183 std::string GenerateInitCode() override {
0184 std::stringstream out;
0185 return out.str();
0186 }
0187
0188 std::string Generate(std::string OpName) override {
0189
0190 if (fIsOutputConstant) return "";
0191
0192 OpName = "op_" + OpName;
0193
0194 if (fShapeY.empty()) {
0195 throw std::runtime_error("TMVA SOFIE Where Op called to Generate without being initialized first");
0196 }
0197 std::stringstream out;
0198 out << SP << "\n//-------- Where \n";
0199 size_t length = ConvertShapeToLength(fShapeY);
0200 std::string typeName = TensorType<T>::Name();
0201
0202 if (fShapeA != fShapeY) {
0203 out << SP << "// Broadcasting uninitialized tensor " << fNA << "\n";
0204
0205 out << SP << "TMVA::Experimental::SOFIE::UTILITY::UnidirectionalBroadcast<" << typeName << ">(tensor_" << fNA << ", " << ConvertShapeToString(fShapeA) << ", " << ConvertShapeToString(fShapeY)
0206 << ", fTensor_" << fNBroadcastedA << ");\n";
0207 }
0208
0209 if (fShapeB != fShapeY) {
0210 out << SP << "// Broadcasting uninitialized tensor " << fNB << "\n";
0211
0212 out << SP << "TMVA::Experimental::SOFIE::UTILITY::UnidirectionalBroadcast<" << typeName << ">(tensor_" << fNB << ", " << ConvertShapeToString(fShapeB) << ", " << ConvertShapeToString(fShapeY)
0213 << ", fTensor_" << fNBroadcastedB << ");\n";
0214 }
0215
0216 if (fShapeC != fShapeY) {
0217
0218 if (fIsInputBoolTensor) {
0219 size_t inputLength = ConvertShapeToLength(fShapeC);
0220 out << SP << "std::vector<bool> fTensor_" << fNC << "(tensor_" << fNC << ", tensor_" << fNC << " + " << inputLength << ");\n";
0221 }
0222 out << SP << "// Broadcasting uninitialized tensor " << fNC << "\n";
0223
0224
0225 out << SP << "TMVA::Experimental::SOFIE::UTILITY::UnidirectionalBroadcast(fTensor_" << fNC << ", " << ConvertShapeToString(fShapeC) << ", " << ConvertShapeToString(fShapeY)
0226 << ", fTensor_" << fNBroadcastedC << ");\n";
0227 }
0228 std::string nameA = fNBroadcastedA.empty()? fNA : fNBroadcastedA;
0229 std::string nameB = fNBroadcastedB.empty()? fNB : fNBroadcastedB;
0230 std::string nameC = fNBroadcastedC.empty()? fNC : fNBroadcastedC;
0231 out << SP << "for (size_t id = 0; id < " << length << " ; id++){\n";
0232
0233 out << SP << SP << "tensor_" << fNY << "[id] = " << "(fTensor_" << nameC << "[id]) ? tensor_"
0234 << nameA << "[id] : tensor_" + nameB + "[id];\n";
0235 out << SP << "}\n";
0236 return out.str();
0237 }
0238
0239 };
0240
0241 }
0242 }
0243 }
0244
0245
0246 #endif