Warning, file /include/root/TMVA/ROperator_Split.hxx was not indexed
or was modified since last indexation (in which case cross-reference links may be missing, inaccurate or erroneous).
0001 #ifndef TMVA_SOFIE_ROPERATOR_Split
0002 #define TMVA_SOFIE_ROPERATOR_Split
0003
#include "TMVA/SOFIE_common.hxx"
#include "TMVA/ROperator.hxx"
#include "TMVA/RModel.hxx"

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <sstream>
#include <stdexcept>
#include <string>
#include <string_view>
#include <vector>
0009
0010 namespace TMVA{
0011 namespace Experimental{
0012 namespace SOFIE{
0013
0014
0015 class ROperator_Split final : public ROperator
0016 {
0017
0018 private:
0019
0020 int fAxis = 0;
0021 std::string fNX;
0022 std::string fNSplit;
0023 std::vector<std::string> fNYs;
0024 std::vector<size_t> fInputShape;
0025 std::vector<int64_t> fSplit;
0026 std::vector<std::vector<size_t>> fOutputShapes;
0027
0028
0029
0030 public:
0031 ROperator_Split(){}
0032 ROperator_Split(const std::string & nameX, const std::string & nameS, int axis, const std::vector<std::string> & namesY):
0033 fAxis(axis), fNX(UTILITY::Clean_name(nameX)), fNSplit(UTILITY::Clean_name(nameS)) {
0034 fNYs.reserve(namesY.size());
0035 for (auto & name : namesY)
0036 fNYs.push_back(UTILITY::Clean_name(name));
0037
0038 fInputTensorNames = { fNX };
0039 fOutputTensorNames.resize(fNYs.size());
0040 std::transform(fNYs.begin(), fNYs.end(), fOutputTensorNames.begin(),
0041 [](const std::string& s) -> std::string_view { return s; });
0042 }
0043
0044 std::vector<ETensorType> TypeInference(std::vector<ETensorType> input) override {
0045 return input;
0046 }
0047
0048 std::vector<std::vector<size_t>> ShapeInference(std::vector<std::vector<size_t>> input) override {
0049 auto ret = input;
0050 return ret;
0051 }
0052
0053 void Initialize(RModel& model) override {
0054 if (model.CheckIfTensorAlreadyExist(fNX) == false){
0055 throw std::runtime_error("TMVA SOFIE Split Op Input Tensor is not found in model");
0056 }
0057 fInputShape = model.GetTensorShape(fNX);
0058
0059
0060 if (fAxis < 0) fAxis += fInputShape.size();
0061 if (fAxis < 0 || fAxis >= static_cast<int>(fInputShape.size()) )
0062 throw std::runtime_error("TMVA SOFIE Split - invalid axis " + std::to_string(fAxis));
0063
0064
0065 size_t nsplit = fNYs.size();
0066
0067 if (fNSplit.empty()) {
0068 int64_t splitValue = 0;
0069 if (fInputShape[fAxis] % nsplit == 0) {
0070 splitValue = fInputShape[fAxis]/nsplit;
0071 fSplit = std::vector<int64_t>(nsplit, splitValue);
0072 } else {
0073
0074 splitValue = std::ceil(double(fInputShape[fAxis])/nsplit);
0075 fSplit = std::vector<int64_t>(nsplit-1, splitValue);
0076 fSplit.push_back(fInputShape[fAxis] % splitValue);
0077 }
0078 } else {
0079
0080 if (!model.IsInitializedTensor(fNSplit))
0081 throw std::runtime_error("TMVA SOFIE Split - non-initialized split tensors are not supported");
0082 auto splitShape = model.GetTensorShape(fNSplit);
0083 if (splitShape.size() != 1 || splitShape[0] != nsplit)
0084 throw std::runtime_error("TMVA SOFIE Split - split input tensor has invalid shape");
0085 auto split_data = static_cast<int64_t *>(model.GetInitializedTensorData(fNSplit).get());
0086 fSplit = std::vector<int64_t>(split_data, split_data + nsplit);
0087 }
0088
0089 size_t tot_split = 0;
0090 for (size_t i = 0; i < fNYs.size(); i++) {
0091 std::vector<size_t> outputShape = fInputShape;
0092 outputShape[fAxis] = fSplit[i];
0093 tot_split += fSplit[i];
0094 model.AddIntermediateTensor(fNYs[i], model.GetTensorType(fNX), outputShape);
0095 fOutputShapes.push_back(outputShape);
0096 }
0097 if (tot_split != fInputShape[fAxis])
0098 throw std::runtime_error("TMVA SOFIE Split - Sum of split sizes must match the input dimension along the axis");
0099
0100
0101 if (model.Verbose()) {
0102 std::cout << "Split - input shape " << ConvertShapeToString(fInputShape) << " --> ";
0103 for (auto & s : fOutputShapes)
0104 std::cout << ConvertShapeToString(s) << " ";
0105 std::cout << std::endl;
0106 }
0107 }
0108
0109
0110 std::string Generate(std::string OpName) override {
0111 OpName = "op_" + OpName;
0112 if (fOutputShapes.empty()){
0113 throw std::runtime_error("TMVA SOFIE Operator Split called to Generate without being initialized first");
0114 }
0115
0116 auto input_strides = UTILITY::ComputeStrideFromShape(fInputShape);
0117
0118
0119 std::stringstream out;
0120 out << "\n" << SP << "//------ Split\n";
0121 out << SP << "size_t " << OpName << "_axis_offset = 0;\n";
0122
0123 for (size_t i = 0; i < fNYs.size(); i++) {
0124 size_t length = ConvertShapeToLength(fOutputShapes[i]);
0125 auto output_strides = UTILITY::ComputeStrideFromShape(fOutputShapes[i]);
0126
0127 out << SP << "for (int id = 0; id < " << length << " ; id++){\n";
0128
0129 out << SP << SP << "int input_index = 0;\n";
0130 out << SP << SP << "int remaining = id;\n";
0131
0132 for (size_t k = 0; k < fOutputShapes[i].size(); ++k) {
0133 out << SP << SP << "// dim " << k << "\n";
0134 if (k < fOutputShapes[i].size()-1) {
0135 out << SP << SP << "input_index += (int(remaining / " << output_strides[k] << ")";
0136
0137 if (k == static_cast<size_t>(fAxis) && i > 0)
0138 out << " + " << OpName << "_axis_offset";
0139 out << ") * " << input_strides[k] << ";\n";
0140 out << SP << SP << "remaining %= " << output_strides[k] << ";\n";
0141 } else {
0142
0143 out << SP << SP << "input_index += remaining";
0144 if (k == static_cast<size_t>(fAxis) && i > 0)
0145 out << " + " << OpName << "_axis_offset";
0146 out << ";\n\n";
0147 }
0148 }
0149
0150 out << SP << SP << "tensor_" << fNYs[i] << "[id] = tensor_" << fNX <<"[input_index];\n";
0151 out << SP << "}\n";
0152 if (i < fNYs.size()-1) out << SP << OpName << "_axis_offset += " << fSplit[i] << ";\n";
0153 }
0154 return out.str();
0155 }
0156
0157 };
0158
0159 }
0160 }
0161 }
0162
0163
0164 #endif