Warning: file /include/root/TMVA/ROperator_Tile.hxx was not indexed
or was modified since it was last indexed (in which case cross-reference links may be missing, inaccurate, or erroneous).
0001 #ifndef TMVA_SOFIE_ROPERATOR_Tile
0002 #define TMVA_SOFIE_ROPERATOR_Tile
0003
0004 #include "TMVA/SOFIE_common.hxx"
0005 #include "TMVA/ROperator.hxx"
0006 #include "TMVA/RModel.hxx"
0007
0008 #include <sstream>
0009
0010 namespace TMVA{
0011 namespace Experimental{
0012 namespace SOFIE{
0013
0014 template <typename T>
0015 class ROperator_Tile final : public ROperator
0016 {
0017
0018 private:
0019
0020 std::string fNRepeats;
0021 std::string fNInput;
0022 std::string fNY;
0023 std::vector<size_t>fShapeInput;
0024 std::vector<size_t> fShapeY;
0025
0026 public:
0027 ROperator_Tile(){}
0028 ROperator_Tile(std::string nameRepeat, std::string nameInput, std::string nameY):
0029 fNRepeats(UTILITY::Clean_name(nameRepeat)),fNInput(UTILITY::Clean_name(nameInput)), fNY(UTILITY::Clean_name(nameY)){
0030 fInputTensorNames = { fNRepeats, fNInput };
0031 fOutputTensorNames = { fNY };
0032 }
0033
0034 std::vector<ETensorType> TypeInference(std::vector<ETensorType> input) override {
0035 return input;
0036 }
0037
0038 std::vector<std::vector<size_t>> ShapeInference(std::vector<std::vector<size_t>> input) override {
0039 std::vector<size_t> ret = input[0];
0040
0041 for(size_t i=0; i < input[1].size(); i++) {
0042 ret[i]=ret[i]*input[1][i];
0043 }
0044 return {ret};
0045 }
0046
0047 void Initialize(RModel& model) override {
0048
0049 if (model.CheckIfTensorAlreadyExist(fNInput) == false){
0050 throw std::runtime_error("TMVA SOFIE Tile Op Input Tensor is not found in model");
0051 }
0052 if (model.CheckIfTensorAlreadyExist(fNRepeats) == false){
0053 throw std::runtime_error("TMVA SOFIE Tile Op Input Tensor is not found in model");
0054 }
0055 fShapeInput=model.GetTensorShape(fNInput);
0056
0057
0058
0059 if (!model.IsInitializedTensor(fNRepeats)) {
0060 throw std::runtime_error("TMVA SOFIE Tile Op: non-initialized repeats input is not supported");
0061 }
0062
0063
0064 auto repptr = model.GetInitializedTensorData(fNRepeats);
0065
0066 auto repeats_data = static_cast<int64_t*>(repptr.get());
0067 if (repeats_data == nullptr) {
0068 throw std::runtime_error("Failed to retrieve the data for the repeats tensor.");
0069 }
0070
0071 auto repeats_shape = model.GetTensorShape(fNRepeats);
0072
0073 if (repeats_shape.size() != 1) {
0074 throw std::runtime_error("Repeats tensor is not 1D.");
0075 }
0076 size_t num_elements = repeats_shape[0];
0077
0078 std::vector<size_t> repeats_vector(num_elements);
0079 std::copy(repeats_data, repeats_data + num_elements, repeats_vector.begin());
0080
0081
0082 fShapeY = ShapeInference({fShapeInput,repeats_vector})[0];
0083
0084 model.AddIntermediateTensor(fNY, model.GetTensorType(fNInput), fShapeY);
0085
0086 if (model.Verbose())
0087 std::cout << "Tile: " << fNInput << " " << ConvertShapeToString(fShapeInput) << " -> " << fNY << " with shape " << ConvertShapeToString(fShapeY)
0088 << " given repeats " << ConvertShapeToString(repeats_vector) << std::endl;
0089 }
0090
0091 std::string Generate(std::string OpName) override {
0092 OpName = "op_" + OpName;
0093 if (fShapeInput.empty() || fShapeY.empty()) {
0094 throw std::runtime_error("TMVA SOFIE Tile Op called to Generate without being initialized first");
0095 }
0096
0097
0098
0099
0100
0101 std::stringstream out;
0102 std::string input = "tensor_" + fNInput;
0103 std::string output = "tensor_" + fNY;
0104 out << "///-------- Tile operator\n";
0105 out << "{\n";
0106 out << "const int input_shape[" << fShapeInput.size() << "] = " << ConvertShapeToString(fShapeInput) << ";\n";
0107
0108 out << "int inputLength = " << ConvertShapeToLength(fShapeInput) << ";\n";
0109 out << "int s = 1;\n";
0110
0111 out << "for (int i = " << fShapeInput.size()-1 << "; i >=0; i--) {\n";
0112 out << SP << "int r = tensor_" << fNRepeats << "[i];\n";
0113
0114
0115 out << SP << "int i_offset = 0, o_offset = 0;\n";
0116 out << SP << "s = s * input_shape[i];\n";
0117
0118 out << SP << "if (i == " << fShapeInput.size()-1 << ") {\n";
0119 out << SP << SP << "for (int j = 0; j < inputLength/s ; j++) {\n";
0120 out << SP << SP << SP << "for (int k = 0; k < r ; k++) {\n";
0121 out << SP << SP << SP << SP << "std::copy(" << input << "+ i_offset, "
0122 << input << "+ i_offset + s, " << output << "+ o_offset);\n";
0123 out << SP << SP << SP << SP << "o_offset += s;\n";
0124 out << SP << SP << SP << "}\n";
0125 out << SP << SP << SP << "i_offset += s;\n";
0126 out << SP << SP << "}\n";
0127 out << SP << "} else {\n";
0128
0129 out << SP << SP << "for (int j = inputLength/s - 1 ; j>=0; j--) {\n";
0130 out << SP << SP << SP << "o_offset = j*s*r;\n";
0131 out << SP << SP << SP << "i_offset = j*s;\n";
0132 out << SP << SP << SP << "for (int k = 0; k < r ; k++) {\n";
0133 out << SP << SP << SP << SP << "std::copy(" << output << "+ i_offset, "
0134 << output << "+ i_offset + s, " << output << "+ o_offset);\n";
0135 out << SP << SP << SP << SP << "o_offset += s;\n";
0136 out << SP << SP << SP << "}\n";
0137 out << SP << SP << "}\n";
0138 out << SP << "}\n";
0139 out << SP << "s *= r;\n";
0140 out << SP << "inputLength *= r;\n";
0141 out << "}\n";
0142 out << "}\n";
0143 return out.str();
0144 }
0145 };
0146
0147 }
0148 }
0149 }
0150
0151 #endif