File indexing completed on 2025-01-18 10:11:05
0001 #ifndef TMVA_SOFIE_ROPERATOR_Concat
0002 #define TMVA_SOFIE_ROPERATOR_Concat
0003
0004
0005 #include "TMVA/SOFIE_common.hxx"
0006 #include "TMVA/ROperator.hxx"
0007 #include "TMVA/RModel.hxx"
0008
0009 #include <sstream>
0010 #include <algorithm>
0011 #include <iterator>
0012 #include <iomanip>
0013 #include <limits>
0014
0015 namespace TMVA{
0016 namespace Experimental{
0017 namespace SOFIE{
0018
0019 template <typename T>
0020 class ROperator_Concat final : public ROperator
0021 {
0022 private:
0023 int fAxis=0;
0024 int fnewAxis=0;
0025 std::vector<std::string> fInputs;
0026 std::string fOutput;
0027 std::vector<Dim>fOutputShape;
0028 std::vector<std::vector<Dim>> fInputShapes;
0029
0030 public:
0031 ROperator_Concat(){}
0032 ROperator_Concat(std::vector<std::string> inputs, int axis, int newAxis, std::string output):
0033 fAxis(axis), fnewAxis(newAxis), fOutput(UTILITY::Clean_name(output)) {
0034 fInputs.reserve(inputs.size());
0035 for (auto & name : inputs)
0036 fInputs.push_back(UTILITY::Clean_name(name));
0037 }
0038
0039 std::vector<ETensorType> TypeInference(std::vector<ETensorType> input){
0040 return input;
0041 }
0042
0043
0044 std::vector<std::vector<size_t>> ShapeInference(std::vector<std::vector<size_t>> inputs){
0045 std::vector<std::vector<size_t>> ret(1);
0046
0047 if (fAxis<0) {
0048 fAxis = inputs[0].size()+fAxis;
0049 }
0050 if (fAxis < 0 || fAxis >= (int) inputs[0].size())
0051 throw std::runtime_error("TMVA SOFIE Concat Op - invalid axis value ");
0052
0053 int concat_dim=0;
0054 if(fnewAxis == 0){
0055 for (size_t i = 0; i < inputs.size(); i++) {
0056 if (i > 0 && inputs[i].size() != inputs[i - 1].size())
0057 throw std::runtime_error("TMVA SOFIE Concat Op - input tensors have different shapes " +
0058 ConvertShapeToString(inputs[i]) + " and " + ConvertShapeToString(inputs[i - 1]));
0059 for (size_t iaxis = 0; iaxis < inputs[i].size(); iaxis++) {
0060 if ((int)iaxis == fAxis)
0061 concat_dim += inputs[i][iaxis];
0062 else if (i > 0 && inputs[i][iaxis] != inputs[i - 1][iaxis])
0063 throw std::runtime_error("TMVA SOFIE Concat Op - input tensors have wrong shapes " +
0064 ConvertShapeToString(inputs[i]) + " and " +
0065 ConvertShapeToString(inputs[i - 1]));
0066 }
0067 }
0068
0069
0070 ret[0] = inputs[0];
0071 ret[0][fAxis] = concat_dim;
0072 }
0073 std::vector<int> stack;
0074 if(fnewAxis == 1){
0075 for(size_t i = 0; i < inputs.size(); i++) {
0076 if (i > 0 && inputs[i].size() != inputs[i-1].size() )
0077 throw std::runtime_error("TMVA SOFIE Concat Op - input tensors have different shapes " +
0078 ConvertShapeToString(inputs[i]) + " and " + ConvertShapeToString(inputs[i-1]));
0079 for (size_t iaxis = 0; iaxis < inputs[i].size(); iaxis++) {
0080 if ((int) iaxis == fAxis)
0081 stack.push_back(inputs[i][iaxis]);
0082 else
0083 if (i> 0 && inputs[i][iaxis] != inputs[i-1][iaxis])
0084 throw std::runtime_error("TMVA SOFIE Concat Op - input tensors have wrong shapes " +
0085 ConvertShapeToString(inputs[i]) + " and " + ConvertShapeToString(inputs[i-1]));
0086 }
0087
0088 }
0089 for(auto it:stack)
0090 ret[0].push_back(it);
0091 }
0092
0093 return ret;
0094 }
0095
0096
0097 std::vector<std::vector<Dim>> ShapeInference(const std::vector<std::vector<Dim>> & inputs){
0098 std::vector<std::vector<Dim>> ret(1);
0099
0100 if (fAxis<0) {
0101 fAxis = inputs[0].size()+fAxis;
0102 }
0103 if (fAxis < 0 || fAxis >= (int) inputs[0].size())
0104 throw std::runtime_error("TMVA SOFIE Concat Op - invalid axis value ");
0105
0106 int concat_dim=0;
0107 if(fnewAxis == 0){
0108 for (size_t i = 0; i < inputs.size(); i++) {
0109 if (i > 0 && inputs[i].size() != inputs[i - 1].size())
0110 throw std::runtime_error("TMVA SOFIE Concat Op - input tensors have different shapes " +
0111 ConvertDynamicShapeToString(inputs[i]) + " and " + ConvertDynamicShapeToString(inputs[i - 1]));
0112 for (size_t iaxis = 0; iaxis < inputs[i].size(); iaxis++) {
0113 if ((int)iaxis == fAxis) {
0114
0115 if (inputs[i][iaxis].isParam)
0116 throw std::runtime_error("TMVA SOFIE Concat Op - not supporting input param dimensions for concatenation axis. Input shape is " +
0117 ConvertDynamicShapeToString(inputs[i]));
0118 concat_dim += inputs[i][iaxis].dim;
0119 }
0120
0121 else if (i > 0 && inputs[i][iaxis].GetVal() != inputs[i - 1][iaxis].GetVal())
0122 throw std::runtime_error("TMVA SOFIE Concat Op - input tensors have wrong shapes " +
0123 ConvertDynamicShapeToString(inputs[i]) + " and " +
0124 ConvertDynamicShapeToString(inputs[i - 1]));
0125 }
0126 }
0127
0128
0129 ret[0] = inputs[0];
0130 ret[0][fAxis].dim = concat_dim;
0131 }
0132
0133
0134
0135
0136 if(fnewAxis == 1){
0137 throw std::runtime_error("TMVA SOFIE Concat Op - stacking (i.e. COncatFromSequence with new_axis=1) is not supported ");
0138 }
0139 return ret;
0140 }
0141
0142 void Initialize(RModel &model)
0143 {
0144 for (auto &it : fInputs) {
0145 if (model.CheckIfTensorAlreadyExist(it) == false) {
0146 throw std::runtime_error("TMVA SOFIE Concat Op Input Tensor " + it + " is not found in model");
0147 }
0148 fInputShapes.push_back(model.GetDynamicTensorShape(it));
0149 }
0150 fOutputShape = ShapeInference(fInputShapes)[0];
0151 model.AddIntermediateTensor(fOutput, model.GetTensorType(fInputs[0]), fOutputShape);
0152 }
0153
0154 std::string Generate(std::string OpName){
0155 OpName = "op_"+OpName;
0156 if(fOutputShape.empty()){
0157 throw std::runtime_error("TMVA SOFIE Concat called to Generate without being initialized first");
0158 }
0159 std::stringstream out;
0160 out<<"\n//--------- Concat\n";
0161
0162 bool hasShapeOnes = true;
0163 for(int i = 0; i<fAxis; ++i){
0164 if(fInputShapes[0][i].dim !=1){
0165 hasShapeOnes = false;
0166 break;
0167 }
0168 }
0169 if (fAxis == 0 || hasShapeOnes) {
0170 std::string offset;
0171 for(size_t i=0; i<fInputs.size(); ++i) {
0172 std::string length = ConvertDynamicShapeToLength(fInputShapes[i]);
0173 out << SP << "std::copy(tensor_" <<fInputs[i] << ", tensor_" <<fInputs[i] << "+" << length <<", tensor_"<<fOutput;
0174 if (i > 0) out << offset;
0175 offset += " + " + length;
0176 out << ");\n";
0177 }
0178 }
0179 else {
0180
0181 std::vector<Dim> outStride = UTILITY::ComputeStrideFromShape(fOutputShape);
0182 std::vector<std::vector<Dim>> inStrides(fInputs.size());
0183 int idx = 0;
0184 for ( auto &s : inStrides) {
0185 s = UTILITY::ComputeStrideFromShape(fInputShapes[idx]);
0186 idx++;
0187 }
0188 for (int i = 0; i < fAxis; ++i) {
0189
0190 out << SP << "for (size_t i" << i << " = 0; i" << i << " < " << fOutputShape[i].GetVal() << "; ++i" << i <<") {\n";
0191 }
0192
0193 out << SP << SP << SP << "int idxOut = ";
0194 for (int k = 0; k < fAxis; k++) {
0195 if (k > 0) out << " + ";
0196 out << outStride[k].GetVal() << "*i" << k;
0197 }
0198 out << ";\n";
0199
0200 for (size_t j = 0; j < fInputs.size(); j++) {
0201 if (j>0)
0202 out << SP << SP << SP << "idxOut += " << fInputShapes[j-1][fAxis].GetVal() << ";\n";
0203 out << SP << SP << SP << "int idxIn" << j <<" = ";
0204 for (int k = 0; k < fAxis; k++) {
0205 if (k > 0) out << " + ";
0206 out << inStrides[j][k].GetVal() << "*i" << k;
0207 }
0208 out << ";\n";
0209 out << SP << SP << SP << "for (size_t iC = 0; iC < " << fInputShapes[j][fAxis].GetVal() << "; ++iC) {\n";
0210 out << SP << SP << SP << SP << "tensor_" << fOutput << "[idxOut+iC] = tensor_" << fInputs[j] << "[idxIn" << j << "+iC];\n";
0211 out << SP << SP << SP << "}\n";
0212
0213 }
0214 for (int i = 0; i < fAxis; ++i) {
0215 out << SP << "}\n";
0216 }
0217 }
0218
0219 return out.str();
0220 }
0221 };
0222 }
0223 }
0224 }
0225
0226 #endif