File indexing completed on 2025-09-13 09:10:35
0001 #ifndef TMVA_SOFIE_ROPERATOR_Concat
0002 #define TMVA_SOFIE_ROPERATOR_Concat
0003
0004
0005 #include "TMVA/SOFIE_common.hxx"
0006 #include "TMVA/ROperator.hxx"
0007 #include "TMVA/RModel.hxx"
0008
0009 #include <sstream>
0010 #include <algorithm>
0011 #include <iterator>
0012 #include <iomanip>
0013 #include <limits>
0014
0015 namespace TMVA{
0016 namespace Experimental{
0017 namespace SOFIE{
0018
0019 class ROperator_Concat final : public ROperator
0020 {
0021 private:
0022 int fAxis=0;
0023 int fnewAxis=0;
0024 std::vector<std::string> fInputs;
0025 std::string fOutput;
0026 std::vector<Dim>fOutputShape;
0027 std::vector<std::vector<Dim>> fInputShapes;
0028
0029 public:
0030 ROperator_Concat(){}
0031 ROperator_Concat(std::vector<std::string> inputs, int axis, int newAxis, std::string output):
0032 fAxis(axis), fnewAxis(newAxis), fOutput(UTILITY::Clean_name(output)) {
0033 fInputs.reserve(inputs.size());
0034 for (auto & name : inputs)
0035 fInputs.push_back(UTILITY::Clean_name(name));
0036
0037 fInputTensorNames.resize(fInputs.size());
0038 std::transform(fInputs.begin(), fInputs.end(), fInputTensorNames.begin(),
0039 [](const std::string& s) -> std::string_view { return s; });
0040 fOutputTensorNames = { fOutput };
0041 }
0042
0043 std::vector<ETensorType> TypeInference(std::vector<ETensorType> input) override {
0044 return input;
0045 }
0046
0047
0048 std::vector<std::vector<size_t>> ShapeInference(std::vector<std::vector<size_t>> inputs) override {
0049 std::vector<std::vector<size_t>> ret(1);
0050
0051 if (fAxis<0) {
0052 fAxis = inputs[0].size()+fAxis;
0053 }
0054 if (fAxis < 0 || fAxis >= (int) inputs[0].size())
0055 throw std::runtime_error("TMVA SOFIE Concat Op - invalid axis value ");
0056
0057 int concat_dim=0;
0058 if(fnewAxis == 0){
0059 for (size_t i = 0; i < inputs.size(); i++) {
0060 if (i > 0 && inputs[i].size() != inputs[i - 1].size())
0061 throw std::runtime_error("TMVA SOFIE Concat Op - input tensors have different shapes " +
0062 ConvertShapeToString(inputs[i]) + " and " + ConvertShapeToString(inputs[i - 1]));
0063 for (size_t iaxis = 0; iaxis < inputs[i].size(); iaxis++) {
0064 if ((int)iaxis == fAxis)
0065 concat_dim += inputs[i][iaxis];
0066 else if (i > 0 && inputs[i][iaxis] != inputs[i - 1][iaxis])
0067 throw std::runtime_error("TMVA SOFIE Concat Op - input tensors have wrong shapes " +
0068 ConvertShapeToString(inputs[i]) + " and " +
0069 ConvertShapeToString(inputs[i - 1]));
0070 }
0071 }
0072
0073
0074 ret[0] = inputs[0];
0075 ret[0][fAxis] = concat_dim;
0076 }
0077 std::vector<int> stack;
0078 if(fnewAxis == 1){
0079 for(size_t i = 0; i < inputs.size(); i++) {
0080 if (i > 0 && inputs[i].size() != inputs[i-1].size() )
0081 throw std::runtime_error("TMVA SOFIE Concat Op - input tensors have different shapes " + fInputs[i] + " : " +
0082 ConvertShapeToString(inputs[i]) + " and " + fInputs[i-1] + " : " + ConvertShapeToString(inputs[i-1]));
0083 for (size_t iaxis = 0; iaxis < inputs[i].size(); iaxis++) {
0084 if ((int) iaxis == fAxis)
0085 stack.push_back(inputs[i][iaxis]);
0086 else
0087 if (i> 0 && inputs[i][iaxis] != inputs[i-1][iaxis])
0088 throw std::runtime_error("TMVA SOFIE Concat Op - input tensors have wrong shapes " +
0089 ConvertShapeToString(inputs[i]) + " and " + ConvertShapeToString(inputs[i-1]));
0090 }
0091
0092 }
0093 for(auto it:stack)
0094 ret[0].push_back(it);
0095 }
0096
0097 return ret;
0098 }
0099
0100
0101 std::vector<std::vector<Dim>> ShapeInference(const std::vector<std::vector<Dim>> & inputs) {
0102 std::vector<std::vector<Dim>> ret(1);
0103
0104 if (fAxis<0) {
0105 fAxis = inputs[0].size()+fAxis;
0106 }
0107 if (fAxis < 0 || fAxis >= (int) inputs[0].size())
0108 throw std::runtime_error("TMVA SOFIE Concat Op - invalid axis value ");
0109
0110 int concat_dim=0;
0111 if(fnewAxis == 0){
0112 for (size_t i = 0; i < inputs.size(); i++) {
0113 if (i > 0 && inputs[i].size() != inputs[i - 1].size())
0114 throw std::runtime_error("TMVA SOFIE Concat Op - input tensors have different shapes " + fInputs[i] + " : " +
0115 ConvertDynamicShapeToString(inputs[i]) + " and " + fInputs[i-1] + " : " + ConvertDynamicShapeToString(inputs[i - 1]));
0116 for (size_t iaxis = 0; iaxis < inputs[i].size(); iaxis++) {
0117 if ((int)iaxis == fAxis) {
0118
0119 if (inputs[i][iaxis].isParam)
0120 throw std::runtime_error("TMVA SOFIE Concat Op - not supporting input param dimensions for concatenation axis. Input shape is " +
0121 ConvertDynamicShapeToString(inputs[i]));
0122 concat_dim += inputs[i][iaxis].dim;
0123 }
0124
0125 else if (i > 0 && inputs[i][iaxis].GetVal() != inputs[i - 1][iaxis].GetVal())
0126 throw std::runtime_error("TMVA SOFIE Concat Op - input tensors have wrong shapes " +
0127 ConvertDynamicShapeToString(inputs[i]) + " and " +
0128 ConvertDynamicShapeToString(inputs[i - 1]));
0129 }
0130 }
0131
0132
0133 ret[0] = inputs[0];
0134 ret[0][fAxis].dim = concat_dim;
0135 }
0136
0137
0138
0139
0140 if(fnewAxis == 1){
0141 throw std::runtime_error("TMVA SOFIE Concat Op - stacking (i.e. COncatFromSequence with new_axis=1) is not supported ");
0142 }
0143 return ret;
0144 }
0145
0146 void Initialize(RModel& model) override {
0147 for (auto &it : fInputs) {
0148 if (model.CheckIfTensorAlreadyExist(it) == false) {
0149 throw std::runtime_error("TMVA SOFIE Concat Op Input Tensor " + it + " is not found in model");
0150 }
0151 fInputShapes.push_back(model.GetDynamicTensorShape(it));
0152 }
0153 fOutputShape = ShapeInference(fInputShapes)[0];
0154 if (model.Verbose())
0155 std::cout << "Output of concat operator has shape " << ConvertDynamicShapeToString(fOutputShape) << std::endl;
0156
0157
0158 if (model.GetTensorType(fInputs[0]) == ETensorType::INT64 && fAxis == 0) {
0159 fIsOutputConstant = true;
0160 for ( auto & input : fInputs) {
0161 if (!model.IsInitializedTensor(input)) {
0162 fIsOutputConstant = false;
0163 break;
0164 }
0165 }
0166 if (fIsOutputConstant) {
0167 auto outputShape = ConvertShapeToInt(fOutputShape);
0168 std::vector<int64_t> outputData(ConvertShapeToLength(outputShape));
0169 size_t offset = 0;
0170 for ( auto & input : fInputs) {
0171 auto inputData = static_cast<int64_t*>(model.GetInitializedTensorData(input).get());
0172 auto inputShape = model.GetTensorShape(input);
0173 size_t inputLength = ConvertShapeToLength(inputShape);
0174 std::copy(inputData, inputData + inputLength, outputData.begin() + offset );
0175 offset += inputLength;
0176
0177 model.SetNotWritableInitializedTensor(input);
0178 }
0179 model.AddConstantTensor<int64_t>(fOutput, outputShape, outputData.data());
0180 if (model.Verbose()) {
0181 std::cout << "output of Concat is a constant tensor " << ConvertShapeToString(outputShape) << " : "
0182 << ConvertValuesToString(outputData) << std::endl;
0183 }
0184 }
0185 }
0186 if (!fIsOutputConstant) {
0187 model.AddIntermediateTensor(fOutput, model.GetTensorType(fInputs[0]), fOutputShape);
0188 if (model.Verbose()) {
0189 std::cout << "Concat ---> " << fOutput << " " << ConvertDynamicShapeToString(fOutputShape) << std::endl;
0190 }
0191 }
0192 }
0193
0194 std::string Generate(std::string OpName) override {
0195 if (fIsOutputConstant) return "";
0196 OpName = "op_"+OpName;
0197 if(fOutputShape.empty()){
0198 throw std::runtime_error("TMVA SOFIE Concat called to Generate without being initialized first");
0199 }
0200 std::stringstream out;
0201 out<<"\n//--------- Concat\n";
0202
0203 bool hasShapeOnes = true;
0204 for(int i = 0; i<fAxis; ++i){
0205 if(fInputShapes[0][i].dim !=1){
0206 hasShapeOnes = false;
0207 break;
0208 }
0209 }
0210 if (fAxis == 0 || hasShapeOnes) {
0211 std::string offset;
0212 for(size_t i=0; i<fInputs.size(); ++i) {
0213 std::string length = ConvertDynamicShapeToLength(fInputShapes[i]);
0214 out << SP << "std::copy(tensor_" <<fInputs[i] << ", tensor_" <<fInputs[i] << "+" << length <<", tensor_"<<fOutput;
0215 if (i > 0) out << offset;
0216 offset += " + " + length;
0217 out << ");\n";
0218 }
0219 }
0220 else {
0221
0222 std::vector<Dim> outStride = UTILITY::ComputeStrideFromShape(fOutputShape);
0223 std::vector<std::vector<Dim>> inStrides(fInputs.size());
0224 int idx = 0;
0225 for ( auto &s : inStrides) {
0226 s = UTILITY::ComputeStrideFromShape(fInputShapes[idx]);
0227 idx++;
0228 }
0229 for (int i = 0; i < fAxis; ++i) {
0230
0231 out << SP << "for (size_t i" << i << " = 0; i" << i << " < " << fOutputShape[i].GetVal() << "; ++i" << i <<") {\n";
0232 }
0233
0234 out << SP << SP << SP << "int idxOut = ";
0235 for (int k = 0; k < fAxis; k++) {
0236 if (k > 0) out << " + ";
0237 out << outStride[k].GetVal() << "*i" << k;
0238 }
0239 out << ";\n";
0240
0241 for (size_t j = 0; j < fInputs.size(); j++) {
0242 if (j>0)
0243 out << SP << SP << SP << "idxOut += " << fInputShapes[j-1][fAxis].GetVal() << ";\n";
0244 out << SP << SP << SP << "int idxIn" << j <<" = ";
0245 for (int k = 0; k < fAxis; k++) {
0246 if (k > 0) out << " + ";
0247 out << inStrides[j][k].GetVal() << "*i" << k;
0248 }
0249 out << ";\n";
0250 out << SP << SP << SP << "for (size_t iC = 0; iC < " << fInputShapes[j][fAxis].GetVal() << "; ++iC) {\n";
0251 out << SP << SP << SP << SP << "tensor_" << fOutput << "[idxOut+iC] = tensor_" << fInputs[j] << "[idxIn" << j << "+iC];\n";
0252 out << SP << SP << SP << "}\n";
0253
0254 }
0255 for (int i = 0; i < fAxis; ++i) {
0256 out << SP << "}\n";
0257 }
0258 }
0259
0260 return out.str();
0261 }
0262 };
0263 }
0264 }
0265 }
0266
0267 #endif