Warning, file /include/root/TMVA/ROperator_Split.hxx was not indexed
or was modified since last indexation (in which case cross-reference links may be missing, inaccurate or erroneous).
0001 #ifndef TMVA_SOFIE_ROPERATOR_Split
0002 #define TMVA_SOFIE_ROPERATOR_Split
0003
0004 #include "TMVA/SOFIE_common.hxx"
0005 #include "TMVA/ROperator.hxx"
0006 #include "TMVA/RModel.hxx"
0007
0008 #include <sstream>
0009
0010 namespace TMVA{
0011 namespace Experimental{
0012 namespace SOFIE{
0013
0014
0015 class ROperator_Split final : public ROperator
0016 {
0017
0018 private:
0019
0020 int fAxis = 0;
0021 std::string fNX;
0022 std::string fNSplit;
0023 std::vector<std::string> fNYs;
0024 std::vector<Dim> fInputShape;
0025 std::vector<int64_t> fSplit;
0026 std::vector<std::vector<Dim>> fOutputShapes;
0027
0028
0029
0030 public:
0031 ROperator_Split(){}
0032 ROperator_Split(const std::string & nameX, const std::string & nameS, int axis, const std::vector<std::string> & namesY):
0033 fAxis(axis), fNX(UTILITY::Clean_name(nameX)), fNSplit(UTILITY::Clean_name(nameS)) {
0034 fNYs.reserve(namesY.size());
0035 for (auto & name : namesY)
0036 fNYs.push_back(UTILITY::Clean_name(name));
0037
0038 fInputTensorNames = { fNX };
0039 fOutputTensorNames.resize(fNYs.size());
0040 std::transform(fNYs.begin(), fNYs.end(), fOutputTensorNames.begin(),
0041 [](const std::string& s) -> std::string_view { return s; });
0042 }
0043
0044 std::vector<ETensorType> TypeInference(std::vector<ETensorType> input) override {
0045 return input;
0046 }
0047
0048 std::vector<std::vector<size_t>> ShapeInference(std::vector<std::vector<size_t>> input) override {
0049 auto ret = input;
0050 return ret;
0051 }
0052
0053 void Initialize(RModel& model) override {
0054 if (model.CheckIfTensorAlreadyExist(fNX) == false){
0055 throw std::runtime_error("TMVA SOFIE Split Op Input Tensor is not found in model");
0056 }
0057 fInputShape = model.GetDimTensorShape(fNX);
0058
0059
0060 if (fAxis < 0) fAxis += fInputShape.size();
0061 if (fAxis < 0 || fAxis >= static_cast<int>(fInputShape.size()) )
0062 throw std::runtime_error("TMVA SOFIE Split - invalid axis " + std::to_string(fAxis));
0063
0064
0065 if (fInputShape[fAxis].isParam)
0066 throw std::runtime_error("TMVA SOFIE Split - splitting in dynamic axis is not supported");
0067
0068 size_t origValue = fInputShape[fAxis].dim;
0069
0070
0071 size_t nsplit = fNYs.size();
0072
0073 if (fNSplit.empty()) {
0074 int64_t splitValue = 0;
0075 if (origValue % nsplit == 0) {
0076 splitValue = origValue/nsplit;
0077 fSplit = std::vector<int64_t>(nsplit, splitValue);
0078 } else {
0079
0080 splitValue = std::ceil(double(origValue)/nsplit);
0081 fSplit = std::vector<int64_t>(nsplit-1, splitValue);
0082 fSplit.push_back(origValue % splitValue);
0083 }
0084 } else {
0085
0086
0087 if (!model.IsInitializedTensor(fNSplit))
0088 throw std::runtime_error("TMVA SOFIE Split - non-initialized split tensors are not supported");
0089 auto splitShape = model.GetTensorShape(fNSplit);
0090 if (splitShape.size() != 1 || splitShape[0] != nsplit)
0091 throw std::runtime_error("TMVA SOFIE Split - split input tensor has invalid shape");
0092 auto split_data = static_cast<int64_t *>(model.GetInitializedTensorData(fNSplit).get());
0093 fSplit = std::vector<int64_t>(split_data, split_data + nsplit);
0094 }
0095
0096 size_t tot_split = 0;
0097 for (size_t i = 0; i < fNYs.size(); i++) {
0098 std::vector<Dim> outputShape = fInputShape;
0099 outputShape[fAxis] = Dim{ static_cast<size_t>(fSplit[i]) };
0100 tot_split += fSplit[i];
0101 model.AddIntermediateTensor(fNYs[i], model.GetTensorType(fNX), outputShape);
0102 fOutputShapes.push_back(outputShape);
0103 }
0104 if (tot_split != origValue)
0105 throw std::runtime_error("TMVA SOFIE Split - Sum of split sizes must match the input dimension along the axis");
0106
0107
0108 if (model.Verbose()) {
0109 std::cout << "Split - input shape " << ConvertShapeToString(fInputShape) << " --> ";
0110 for (auto & s : fOutputShapes)
0111 std::cout << ConvertShapeToString(s) << " ";
0112 std::cout << std::endl;
0113 }
0114 }
0115
0116
0117 std::string Generate(std::string OpName) override {
0118 OpName = "op_" + OpName;
0119 if (fOutputShapes.empty()){
0120 throw std::runtime_error("TMVA SOFIE Operator Split called to Generate without being initialized first");
0121 }
0122
0123 auto input_strides = UTILITY::ComputeStrideFromShape(fInputShape);
0124
0125
0126 std::stringstream out;
0127 out << "\n" << SP << "//------ Split\n";
0128 out << SP << "size_t " << OpName << "_axis_offset = 0;\n";
0129
0130 for (size_t i = 0; i < fNYs.size(); i++) {
0131 auto length = ConvertDimShapeToLength(fOutputShapes[i]);
0132 auto output_strides = UTILITY::ComputeStrideFromShape(fOutputShapes[i]);
0133
0134 out << SP << "for (int id = 0; id < " << length << " ; id++){\n";
0135
0136 out << SP << SP << "int input_index = 0;\n";
0137 out << SP << SP << "int remaining = id;\n";
0138
0139 for (size_t k = 0; k < fOutputShapes[i].size(); ++k) {
0140 out << SP << SP << "// dim " << k << "\n";
0141 if (k < fOutputShapes[i].size()-1) {
0142 out << SP << SP << "input_index += (int(remaining / " << output_strides[k] << ")";
0143
0144 if (k == static_cast<size_t>(fAxis) && i > 0)
0145 out << " + " << OpName << "_axis_offset";
0146 out << ") * " << input_strides[k] << ";\n";
0147 out << SP << SP << "remaining %= " << output_strides[k] << ";\n";
0148 } else {
0149
0150 out << SP << SP << "input_index += remaining";
0151 if (k == static_cast<size_t>(fAxis) && i > 0)
0152 out << " + " << OpName << "_axis_offset";
0153 out << ";\n\n";
0154 }
0155 }
0156
0157 out << SP << SP << "tensor_" << fNYs[i] << "[id] = tensor_" << fNX <<"[input_index];\n";
0158 out << SP << "}\n";
0159 if (i < fNYs.size()-1) out << SP << OpName << "_axis_offset += " << fSplit[i] << ";\n";
0160 }
0161 return out.str();
0162 }
0163
0164 };
0165
0166 }
0167 }
0168 }
0169
0170
0171 #endif