File indexing completed on 2025-01-18 10:11:08
0001 #ifndef TMVA_SOFIE_ROPERATOR_SLICE
0002 #define TMVA_SOFIE_ROPERATOR_SLICE
0003
0004 #include "TMVA/SOFIE_common.hxx"
0005 #include "TMVA/ROperator.hxx"
0006 #include "TMVA/RModel.hxx"
0007
0008 #include <cassert>
0009 #include <sstream>
0010 #include <numeric>
0011
0012 namespace TMVA{
0013 namespace Experimental{
0014 namespace SOFIE{
0015
0016
0017
0018 template <typename T, typename IType>
0019 class ROperator_Slice final : public ROperator
0020 {
0021
0022 private:
0023
0024 std::string fNData;
0025 std::string fNOutput;
0026 std::vector<std::string> fNames;
0027 std::vector<size_t> fShapeInput;
0028 std::vector<size_t> fShapeOutput;
0029
0030
0031 std::vector<size_t> fStart;
0032 std::vector<size_t> fEnd;
0033 std::vector<size_t> fSteps;
0034
0035 std::vector<std::vector<IType>> fAttributes;
0036
0037
0038 public:
0039
0040 ROperator_Slice(){}
0041
0042
0043 ROperator_Slice(std::string nameData, std::vector<std::string> names, std::string nameOutput)
0044 : fNData(UTILITY::Clean_name(nameData)),
0045 fNOutput(UTILITY::Clean_name(nameOutput))
0046 {
0047 fNames.resize(4);
0048 for (size_t i = 0; i < names.size(); ++i) {
0049 fNames[i] = UTILITY::Clean_name(names[i]);
0050 }
0051
0052 if (names.size() == 3) {
0053 if (names[2] != "axes") {
0054 fNames[3] = fNames[2];
0055 fNames[2] = "";
0056 }
0057 else {
0058 fNames[3] = "";
0059 }
0060 }
0061 }
0062
0063 ROperator_Slice(std::string nameData, std::vector<IType> starts, std::vector<IType> ends, std::vector<IType> axes, std::string nameOutput)
0064 : fNData(UTILITY::Clean_name(nameData)),
0065 fNOutput(UTILITY::Clean_name(nameOutput))
0066 {
0067 fAttributes.push_back(starts);
0068 fAttributes.push_back(ends);
0069 fAttributes.push_back(axes);
0070 }
0071
0072
0073 std::vector<ETensorType> TypeInference(std::vector<ETensorType> input){
0074 auto ret = std::vector<ETensorType>(1, input[0]);
0075 return ret;
0076 }
0077
0078
0079 std::vector<std::vector<size_t>> ShapeInference(std::vector<std::vector<size_t>> input){
0080 auto & input_shape = input[0];
0081
0082 std::vector<std::vector<size_t>> ret(1, input_shape);
0083 auto & output_shape = ret[0];
0084 for (size_t i = 0; i < input_shape.size(); i++) {
0085 output_shape[i] = (fEnd[i]-fStart[i])/ fSteps[i];
0086 }
0087 return ret;
0088 }
0089
0090
0091 void Initialize(RModel& model){
0092 if (model.CheckIfTensorAlreadyExist(fNData) == false){
0093 throw std::runtime_error("TMVA Slice Op Input Tensor is not found in model");
0094 }
0095
0096 std::vector<std::vector<size_t>> shapes;
0097 fShapeInput = model.GetTensorShape(fNData);
0098 shapes.push_back(fShapeInput);
0099
0100 std::vector<std::vector<IType>> itensors(4);
0101 if (fNames.size() > 0) {
0102
0103 for (size_t i = 0; i < fNames.size(); ++i) {
0104 if (!fNames[i].empty()) {
0105
0106 auto dptr = model.GetInitializedTensorData(fNames[i]);
0107 auto tensor = static_cast<IType *>(dptr.get());
0108 auto vec = model.GetTensorShape(fNames[i]);
0109 assert(vec.size() == 1);
0110 itensors[i] = std::vector<IType>(tensor, tensor + vec[0]);
0111 }
0112 else {
0113 switch (i)
0114 {
0115 case 2:
0116 itensors[2] = std::vector<IType>(fShapeInput.size());
0117 std::iota(itensors[2].begin(), itensors[2].end(), 0);
0118 break;
0119 case 3:
0120 itensors[3] = std::vector<IType>(itensors[0].size(), 1);
0121 default:
0122 break;
0123 }
0124 }
0125 }
0126 } else {
0127 assert (fAttributes.size() > 1);
0128 for (size_t i = 0; i < fAttributes.size(); i++) {
0129 itensors[i] = fAttributes[i];
0130 }
0131 }
0132 size_t dim = fShapeInput.size();
0133
0134 fSteps = std::vector<size_t>(dim, 1);
0135 fStart = std::vector<size_t>(dim, 0);
0136 fEnd = fShapeInput;
0137
0138 auto istart = itensors[0];
0139 auto iend = itensors[1];
0140 auto iaxes = itensors[2];
0141 auto isteps = itensors[3];
0142
0143
0144
0145 if (iaxes.size() > 0) {
0146 for (size_t i = 0; i < iaxes.size(); i++) {
0147
0148 if (iaxes[i] < 0) iaxes[i] = dim + iaxes[i];
0149 size_t jaxis = static_cast<size_t>(iaxes[i]);
0150 assert(jaxis < dim);
0151 size_t imax = fShapeInput[jaxis];
0152
0153 IType start = (istart[i] >= 0) ? istart[i] : imax + istart[i];
0154 if (start < 0) start = 0;
0155 if (start > static_cast<IType>(imax))
0156 start = imax;
0157 fStart[jaxis] = start;
0158 IType ie = (iend[i] >= 0) ? iend[i] : imax + iend[i];
0159 if (ie < 0) ie = 0;
0160 if (ie > static_cast<IType>(imax))
0161 ie = imax;
0162 fEnd[jaxis] = ie;
0163
0164 if (isteps.size() > 0) {
0165 if (isteps[i] < 0) {
0166
0167 throw std::runtime_error("TMVA Slice Op : negative steps not supported");
0168 }
0169 fSteps[jaxis] = isteps[i];
0170 assert(fSteps[jaxis] > 0 && fSteps[jaxis] < fShapeInput[jaxis]);
0171 }
0172 }
0173 }
0174
0175 fShapeOutput = ShapeInference({fShapeInput})[0];
0176 model.AddIntermediateTensor(fNOutput, model.GetTensorType(fNData), fShapeOutput);
0177 }
0178
0179 std::string Generate(std::string OpName){
0180 OpName = "op_" + OpName;
0181 if (fShapeInput.empty() || fShapeOutput.empty()){
0182 throw std::runtime_error("TMVA SOFIE Slice Op called to Generate without being initialized first");
0183 }
0184
0185 std::stringstream out;
0186
0187
0188 out << SP << "///------- Slice operator\n" << std::endl;
0189
0190 size_t ndim = fShapeInput.size();
0191 std::vector<size_t> strides(ndim,1);
0192 for (int i = int(ndim-2); i >=0 ; i--) {
0193 strides[i] = strides[i+1]*fShapeInput[i+1];
0194 }
0195
0196 out << SP << "size_t iOut = 0;\n";
0197 std::string MSP = SP;
0198 for (size_t idim = 0; idim < ndim; idim++) {
0199 out << MSP << "for (size_t i" << idim << " = " << fStart[idim] << "; i" << idim << " < " << fEnd[idim]
0200 << "; i" << idim << "+= " << fSteps[idim] << ") {\n";
0201 MSP += SP;
0202 if (idim < ndim-1) out << MSP << "size_t stride" << idim << " = " << strides[idim] << "*i" << idim << ";\n";
0203 }
0204 out << MSP << "size_t iInput = ";
0205 for (size_t idim = 0; idim < ndim-1; idim++) out << " stride" << idim << " + ";
0206
0207 out << "i" << ndim-1 << ";\n";
0208 out << MSP << "tensor_" << fNOutput << "[iOut++] = tensor_" <<fNData << "[iInput];\n";
0209 for (size_t idim = 0; idim < ndim; idim++) {
0210 MSP = MSP.replace(0,SP.length(),"");
0211 out << MSP << "}\n";
0212 }
0213
0214 return out.str();
0215 }
0216
0217 };
0218
0219 }
0220 }
0221 }
0222
0223
0224 #endif