File indexing completed on 2026-05-06 08:52:54
0001 #ifndef TMVA_SOFIE_ROPERATOR_Concat
0002 #define TMVA_SOFIE_ROPERATOR_Concat
0003
0004
0005 #include "TMVA/SOFIE_common.hxx"
0006 #include "TMVA/ROperator.hxx"
0007 #include "TMVA/RModel.hxx"
0008
0009 #include <sstream>
0010 #include <algorithm>
0011 #include <iterator>
0012 #include <iomanip>
0013 #include <limits>
0014
0015 namespace TMVA{
0016 namespace Experimental{
0017 namespace SOFIE{
0018
0019 class ROperator_Concat final : public ROperator
0020 {
0021 private:
0022 int fAxis=0;
0023 int fnewAxis=0;
0024 std::vector<std::string> fInputs;
0025 std::string fOutput;
0026 std::vector<Dim>fOutputShape;
0027 std::vector<std::vector<Dim>> fInputShapes;
0028
0029 public:
0030
0031 ROperator_Concat(){}
0032 ROperator_Concat(std::vector<std::string> inputs, int axis, int newAxis, std::string output):
0033 fAxis(axis), fnewAxis(newAxis), fOutput(UTILITY::Clean_name(output)) {
0034 fInputs.reserve(inputs.size());
0035 for (auto & name : inputs)
0036 fInputs.push_back(UTILITY::Clean_name(name));
0037
0038 fInputTensorNames.resize(fInputs.size());
0039 std::transform(fInputs.begin(), fInputs.end(), fInputTensorNames.begin(),
0040 [](const std::string& s) -> std::string_view { return s; });
0041 fOutputTensorNames = { fOutput };
0042 }
0043
0044 std::vector<ETensorType> TypeInference(std::vector<ETensorType> input) override {
0045 return input;
0046 }
0047
0048
0049 std::vector<std::vector<size_t>> ShapeInference(std::vector<std::vector<size_t>> inputs) override {
0050 std::vector<std::vector<size_t>> ret(1);
0051
0052 if (fAxis<0) {
0053 fAxis = inputs[0].size()+fAxis;
0054 }
0055 if (fAxis < 0 || fAxis >= (int) inputs[0].size())
0056 throw std::runtime_error("TMVA SOFIE Concat Op - invalid axis value ");
0057
0058 int concat_dim=0;
0059
0060 if(fnewAxis == 0){
0061 for (size_t i = 0; i < inputs.size(); i++) {
0062 if (i > 0 && inputs[i].size() != inputs[i - 1].size())
0063 throw std::runtime_error("TMVA SOFIE Concat Op - input tensors have different shapes " +
0064 ConvertShapeToString(inputs[i]) + " and " + ConvertShapeToString(inputs[i - 1]));
0065 for (size_t iaxis = 0; iaxis < inputs[i].size(); iaxis++) {
0066 if ((int)iaxis == fAxis)
0067 concat_dim += inputs[i][iaxis];
0068 else if (i > 0 && inputs[i][iaxis] != inputs[i - 1][iaxis])
0069 throw std::runtime_error("TMVA SOFIE Concat Op - input tensors have wrong shapes " +
0070 ConvertShapeToString(inputs[i]) + " and " +
0071 ConvertShapeToString(inputs[i - 1]));
0072 }
0073 }
0074
0075
0076 ret[0] = inputs[0];
0077 ret[0][fAxis] = concat_dim;
0078 }
0079 std::vector<int> stack;
0080
0081 if(fnewAxis == 1){
0082 for(size_t i = 0; i < inputs.size(); i++) {
0083 if (i > 0 && inputs[i].size() != inputs[i-1].size() )
0084 throw std::runtime_error("TMVA SOFIE Concat Op - input tensors have different shapes " + fInputs[i] + " : " +
0085 ConvertShapeToString(inputs[i]) + " and " + fInputs[i-1] + " : " + ConvertShapeToString(inputs[i-1]));
0086 for (size_t iaxis = 0; iaxis < inputs[i].size(); iaxis++) {
0087 if ((int) iaxis == fAxis)
0088 stack.push_back(inputs[i][iaxis]);
0089 else
0090 if (i> 0 && inputs[i][iaxis] != inputs[i-1][iaxis])
0091 throw std::runtime_error("TMVA SOFIE Concat Op - input tensors have wrong shapes " +
0092 ConvertShapeToString(inputs[i]) + " and " + ConvertShapeToString(inputs[i-1]));
0093 }
0094
0095 }
0096 for(auto it:stack)
0097 ret[0].push_back(it);
0098 }
0099
0100 return ret;
0101 }
0102
0103
0104 std::vector<Dim> ShapeInference(const std::vector<std::vector<Dim>> & inputs, const RModel & model) {
0105 std::vector<Dim> ret(inputs[0].size());
0106
0107 if (fAxis<0) {
0108 fAxis = inputs[0].size()+fAxis;
0109 }
0110 if (fAxis < 0 || fAxis >= (int) inputs[0].size())
0111 throw std::runtime_error("TMVA SOFIE Concat Op - invalid axis value ");
0112
0113 Dim concat_dim;
0114 if(fnewAxis == 0){
0115 for (size_t i = 0; i < inputs.size(); i++) {
0116 if (i > 0 && inputs[i].size() != inputs[i - 1].size())
0117 throw std::runtime_error("TMVA SOFIE Concat Op - input tensors have different shapes " + fInputs[i] + " : " +
0118 ConvertShapeToString(inputs[i]) + " and " + fInputs[i-1] + " : " + ConvertShapeToString(inputs[i - 1]));
0119 for (size_t iaxis = 0; iaxis < inputs[i].size(); iaxis++) {
0120 if ((int)iaxis == fAxis) {
0121
0122 if (concat_dim.param.empty() && concat_dim.dim == 0)
0123 concat_dim = inputs[i][iaxis];
0124 else if (inputs[i][iaxis].isParam || concat_dim.isParam) {
0125 concat_dim =
0126 Dim{ concat_dim.GetVal() + std::string("+ ") + inputs[i][iaxis].GetVal(),
0127 static_cast<size_t>(-1)};
0128 } else {
0129 concat_dim = Dim { concat_dim.dim + inputs[i][iaxis].dim };
0130 }
0131 }
0132 else if (i == 0) {
0133 ret[iaxis] = inputs[i][iaxis];
0134 }
0135 else if ((!inputs[i][iaxis].isParam && !ret[iaxis].isParam) && (inputs[i][iaxis].dim != ret[iaxis].dim)) {
0136 throw std::runtime_error("TMVA SOFIE Concat Op - input tensors have wrong shapes " +
0137 ConvertShapeToString(inputs[i]) + " and " +
0138 ConvertShapeToString(inputs[i - 1]));
0139 }
0140 else if (!inputs[i][iaxis].isParam && ret[iaxis].isParam){
0141
0142 ret[iaxis] = inputs[i][iaxis];
0143 }
0144 else if (inputs[i][iaxis].isParam && ret[iaxis].isParam) {
0145
0146 auto & dimNames = model.GetDimShapeNames();
0147 auto p1 = std::find(dimNames.begin(), dimNames.end(), inputs[i][iaxis].param);
0148 auto p2 = std::find(dimNames.begin(), dimNames.end(), ret[iaxis].param);
0149 if (p1 < p2) ret[iaxis] = inputs[i][iaxis];
0150 }
0151
0152 }
0153
0154 if (concat_dim.isParam && concat_dim.dim == static_cast<size_t>(-1))
0155 concat_dim = Dim{ std::string("(") + concat_dim.GetVal() + std::string(")"), concat_dim.dim };
0156 }
0157
0158
0159 ret[fAxis] = Dim{concat_dim};
0160
0161 }
0162
0163
0164
0165
0166 if(fnewAxis == 1){
0167 throw std::runtime_error("TMVA SOFIE Concat Op - stacking (i.e. COncatFromSequence with new_axis=1) is not supported ");
0168 }
0169 return ret;
0170 }
0171
0172 void Initialize(RModel& model) override {
0173 for (auto &it : fInputs) {
0174 if (model.CheckIfTensorAlreadyExist(it) == false) {
0175 throw std::runtime_error("TMVA SOFIE Concat Op Input Tensor " + it + " is not found in model");
0176 }
0177 fInputShapes.push_back(model.GetDimTensorShape(it));
0178 }
0179 fOutputShape = ShapeInference(fInputShapes, model);
0180 if (model.Verbose())
0181 std::cout << "Output of concat operator has shape " << ConvertDimShapeToString(fOutputShape) << std::endl;
0182
0183
0184 bool isOutputShape = false;
0185 if (model.GetTensorType(fInputs[0]) == ETensorType::INT64 && fAxis == 0) {
0186 fIsOutputConstant = true;
0187 isOutputShape = true;
0188
0189 for ( auto & input : fInputs) {
0190 if (!model.IsInitializedTensor(input)) {
0191 fIsOutputConstant = false;
0192 if (!model.IsShapeTensor(input)) {
0193 isOutputShape = false;
0194 break;
0195 }
0196 }
0197 }
0198 if (fIsOutputConstant) {
0199 auto outputShape = ConvertShapeToInt(fOutputShape);
0200 std::vector<int64_t> outputData(ConvertShapeToLength(outputShape));
0201 size_t offset = 0;
0202 for ( auto & input : fInputs) {
0203 auto inputData = static_cast<int64_t*>(model.GetInitializedTensorData(input).get());
0204 auto inputShape = model.GetTensorShape(input);
0205 size_t inputLength = ConvertShapeToLength(inputShape);
0206 std::copy(inputData, inputData + inputLength, outputData.begin() + offset );
0207 offset += inputLength;
0208
0209 model.SetNotWritableInitializedTensor(input);
0210 }
0211 model.AddConstantTensor<int64_t>(fOutput, outputShape, outputData.data());
0212 if (model.Verbose()) {
0213 std::cout << "output of Concat is a constant tensor " << ConvertShapeToString(outputShape) << " : "
0214 << ConvertValuesToString(outputData) << " (constant)" << std::endl;
0215 }
0216 } else if (isOutputShape) {
0217 auto outputShape = ConvertShapeToInt(fOutputShape);
0218 std::vector<Dim> outputData(ConvertShapeToLength(outputShape));
0219 size_t offset = 0;
0220 for ( auto & input : fInputs) {
0221 std::vector<Dim> inputData;
0222 auto inputShape = model.GetTensorShape(input);
0223 size_t inputLength = ConvertShapeToLength(inputShape);
0224 if (model.IsShapeTensor(input))
0225 inputData = model.GetShapeTensorValues(input);
0226 else if (model.IsConstantTensor(input)) {
0227 inputData.resize(inputLength);
0228 auto intData = static_cast<int64_t*>(model.GetInitializedTensorData(input).get());
0229 for (size_t i = 0; i < inputData.size(); i++)
0230 inputData[i] = Dim{ static_cast<size_t>(intData[i])};
0231 }
0232 std::cout << "concatenating input data " << inputLength << " " << inputData[0] << std::endl;
0233 std::copy(inputData.begin(), inputData.end(), outputData.begin() + offset );
0234 offset += inputLength;
0235 }
0236
0237 model.AddShapeTensor(fOutput,outputData, false);
0238 if (model.Verbose()) {
0239 std::cout << "output of Concat is a shape tensor " << ConvertShapeToString(outputShape) << " : "
0240 << ConvertShapeToString(outputData) << " (shape)" << std::endl;
0241 }
0242 fIsOutputConstant = true;
0243 }
0244 }
0245 if (!fIsOutputConstant) {
0246 model.AddIntermediateTensor(fOutput, model.GetTensorType(fInputs[0]), fOutputShape);
0247 if (model.Verbose()) {
0248 std::cout << "Concat ---> " << fOutput << " " << ConvertDimShapeToString(fOutputShape) << std::endl;
0249 }
0250 }
0251 }
0252
0253 std::string Generate(std::string opName) override {
0254 if (fIsOutputConstant) return "";
0255 opName = "op_" + opName;
0256 if(fOutputShape.empty()){
0257 throw std::runtime_error("TMVA SOFIE Concat called to Generate without being initialized first");
0258 }
0259 std::stringstream out;
0260 out<<"\n//--------- Concat " << opName << " --> " << ConvertShapeToString(fOutputShape) << "\n";
0261
0262 bool hasShapeOnes = true;
0263 for(int i = 0; i<fAxis; ++i){
0264 if(fInputShapes[0][i].dim !=1){
0265 hasShapeOnes = false;
0266 break;
0267 }
0268 }
0269 if (fAxis == 0 || hasShapeOnes) {
0270 std::string offset;
0271 for(size_t i=0; i<fInputs.size(); ++i) {
0272 auto length = ConvertDimShapeToLength(fInputShapes[i]);
0273 out << SP << "std::copy(tensor_" <<fInputs[i] << ", tensor_" <<fInputs[i] << "+" << length <<", tensor_"<<fOutput;
0274 if (i > 0) out << offset;
0275 offset += " + " + length;
0276 out << ");\n";
0277 }
0278 }
0279 else {
0280
0281 std::vector<Dim> outStride = UTILITY::ComputeStrideFromShape(fOutputShape);
0282 std::vector<std::vector<Dim>> inStrides(fInputs.size());
0283 int idx = 0;
0284 for ( auto &s : inStrides) {
0285 s = UTILITY::ComputeStrideFromShape(fInputShapes[idx]);
0286 idx++;
0287 }
0288 for (int i = 0; i < fAxis; ++i) {
0289
0290 out << SP << "for (size_t i" << i << " = 0; i" << i << " < " << fOutputShape[i].GetVal() << "; ++i" << i <<") {\n";
0291 }
0292
0293 out << SP << SP << SP << "int idxOut = ";
0294 for (int k = 0; k < fAxis; k++) {
0295 if (k > 0) out << " + ";
0296 out << outStride[k].GetVal() << "*i" << k;
0297 }
0298 out << ";\n";
0299
0300 for (size_t j = 0; j < fInputs.size(); j++) {
0301 if (j>0)
0302 out << SP << SP << SP << "idxOut += " << fInputShapes[j-1][fAxis].GetVal() << ";\n";
0303 out << SP << SP << SP << "int idxIn" << j <<" = ";
0304 for (int k = 0; k < fAxis; k++) {
0305 if (k > 0) out << " + ";
0306 out << inStrides[j][k].GetVal() << "*i" << k;
0307 }
0308 out << ";\n";
0309 out << SP << SP << SP << "for (size_t iC = 0; iC < " << fInputShapes[j][fAxis].GetVal() << "; ++iC) {\n";
0310 out << SP << SP << SP << SP << "tensor_" << fOutput << "[idxOut+iC] = tensor_" << fInputs[j] << "[idxIn" << j << "+iC];\n";
0311 out << SP << SP << SP << "}\n";
0312
0313 }
0314 for (int i = 0; i < fAxis; ++i) {
0315 out << SP << "}\n";
0316 }
0317 }
0318
0319 return out.str();
0320 }
0321 };
0322 }
0323 }
0324 }
0325
0326 #endif