// File indexing completed on 2025-01-18 10:11:09
0001 #ifndef TMVA_SOFIE_ROPERATOR_TRANSPOSE
0002 #define TMVA_SOFIE_ROPERATOR_TRANSPOSE
0003
#include "TMVA/SOFIE_common.hxx"
#include "TMVA/ROperator.hxx"
#include "TMVA/RModel.hxx"

#include <algorithm>
#include <cassert>
#include <iostream>
#include <sstream>
0010
0011 namespace TMVA{
0012 namespace Experimental{
0013 namespace SOFIE{
0014
0015
0016
0017
0018 template <typename T>
0019 class ROperator_Transpose final : public ROperator
0020 {
0021
0022 private:
0023 std::vector<int_t> fAttrPerm;
0024
0025 std::string fNData;
0026 std::string fNOutput;
0027 std::vector<size_t> fShapeData;
0028 std::vector<size_t> fShapeOutput;
0029
0030 public:
0031
0032 ROperator_Transpose(){}
0033 ROperator_Transpose(std::vector<int_t> attr_perm, std::string nameData, std::string nameOutput):
0034 fAttrPerm(attr_perm), fNData(UTILITY::Clean_name(nameData)), fNOutput(UTILITY::Clean_name(nameOutput)) {
0035 }
0036
0037 ROperator_Transpose(std::string nameData, std::string nameOutput):
0038 fNData(UTILITY::Clean_name(nameData)), fNOutput(UTILITY::Clean_name(nameOutput)) {
0039 }
0040
0041 std::vector<ETensorType> TypeInference(std::vector<ETensorType> input){
0042 return input;
0043 }
0044
0045 std::vector<std::vector<size_t>> ShapeInference(std::vector<std::vector<size_t>> input){
0046 if (input.size() > 1) throw std::runtime_error("TMVA SOFIE Tranpose Op Shape Inference only need 1 input tensor");
0047 auto& data = input[0];
0048 if (fAttrPerm.size() != data.size() )
0049 throw std::runtime_error("TMVA SOFIE Tranpose Op - Invalid axes attributes");
0050
0051 std::vector<size_t> output_shape(fAttrPerm.size());
0052 for (size_t i = 0; i < fAttrPerm.size(); i++){
0053 output_shape[i] = data[fAttrPerm[i]];
0054 }
0055 std::vector<std::vector<size_t>> ret;
0056 ret.push_back(output_shape);
0057 return ret;
0058 }
0059
0060
0061 void Initialize(RModel& model){
0062 if (model.CheckIfTensorAlreadyExist(fNData) == false){
0063 std::cout<<"Input tensor for transspose: "<<fNData<<'\n';
0064 throw std::runtime_error("TMVA SOFIE Tranpose Op Input Tensor is not found in model");
0065 }
0066 fShapeData = model.GetTensorShape(fNData);
0067 if (fAttrPerm.empty()){
0068 fAttrPerm.reserve(fShapeData.size());
0069 for (int i = fShapeData.size() - 1; i >= 0; i--){
0070 fAttrPerm.push_back(i);
0071 }
0072 }
0073 std::vector<std::vector<size_t>> inputs = { fShapeData };
0074 fShapeOutput = ShapeInference(inputs).front();
0075 model.AddIntermediateTensor(fNOutput, model.GetTensorType(fNData), fShapeOutput);
0076 }
0077
0078 std::string Generate(std::string OpName){
0079 OpName = "op_" + OpName;
0080 if (fShapeData.empty() || fShapeOutput.empty()){
0081 throw std::runtime_error("TMVA SOFIE Transpose Op called to Generate without being initialized first");
0082 }
0083 int dim = fShapeData.size();
0084 auto inStrides = UTILITY::ComputeStrideFromShape(fShapeData);
0085 auto outStrides = UTILITY::ComputeStrideFromShape(fShapeOutput);
0086 size_t length = inStrides[0]*fShapeData[0];
0087 assert (length == outStrides[0]*fShapeOutput[0]);
0088
0089 std::stringstream out;
0090
0091
0092
0093
0094
0095
0096
0097
0098
0099
0100 out << SP << "///------- Transpose operator\n" << std::endl;
0101 out << SP << "for (size_t id = 0; id < " << length << " ; id++){\n";
0102 out << SP << SP << "tensor_" << fNOutput << "[id] = tensor_" << fNData << "[ ";
0103
0104 std::vector<std::string> i_out(dim);
0105 for (int k =0; k < dim; k++){
0106 if (k == 0)
0107 i_out[k] = "id";
0108 else
0109 i_out[k] = "(id % " + std::to_string(outStrides[k-1]) + ")";
0110 if (k < dim-1)
0111 i_out[k] += " / " + std::to_string(outStrides[k]);
0112 }
0113
0114
0115 for (int k =0; k < dim; k++){
0116
0117 int l = std::find(fAttrPerm.begin(), fAttrPerm.end(), k) - fAttrPerm.begin();
0118 assert(l >= 0 && l < dim);
0119 out << "( " << i_out[l] << " )";
0120 if (k < dim-1) {
0121 out << " * " << inStrides[k];
0122 out << " + ";
0123 }
0124 }
0125 out << "];\n";
0126 out << SP << "}\n";
0127 return out.str();
0128 }
0129
0130
0131 };
0132
0133 }
0134 }
0135 }
0136
0137
0138 #endif