// File indexing completed on 2025-09-18 09:32:42
0001 #ifndef TMVA_SOFIE_SOFIE_COMMON
0002 #define TMVA_SOFIE_SOFIE_COMMON
0003
0004 #include "TMVA/RTensor.hxx"
0005
0006 #include "ROOT/RSpan.hxx"
0007
0008 #include <stdexcept>
0009 #include <type_traits>
0010 #include <cstdint>
0011 #include <cstring>
0012 #include <complex>
0013 #include <string>
0014 #include <vector>
0015 #include <map>
0016 #include <memory>
0017 #include <regex>
0018 #include <sstream>
0019 #include <iostream>
0020 #include <iomanip>
0021 #include <cassert>
0022 #include <limits>
0023
0024 namespace TMVA{
0025 namespace Experimental{
0026 namespace SOFIE{
0027
/// Element data types for tensors; values follow the ONNX TensorProto::DataType numbering.
enum class ETensorType{
   UNDEFINED = 0, FLOAT = 1, UINT8 = 2, INT8 = 3, UINT16 = 4, INT16 = 5, INT32 = 6, INT64 = 7, STRING = 8, BOOL = 9,
   // NOTE(review): "COMPLEX28" is presumably a typo for COMPLEX128 (ONNX value 15).
   // Renaming the enumerator would break existing users, so it is only flagged here.
   FLOAT16 = 10, DOUBLE = 11, UINT32 = 12, UINT64 = 13, COMPLEX64 = 14, COMPLEX28 = 15, BFLOAT16 = 16
};
0032
/// Activation functions recognised by the generated/fused operators.
enum class EActivationType{
   UNDEFINED = 0, RELU = 1, SOFTMAX = 2, SIGMOID = 3, LEAKYRELU = 4, TANH = 5, ELU = 6
};
0036
0037 constexpr size_t GetTypeSize(ETensorType type) {
0038 switch (type) {
0039 case ETensorType::FLOAT: return sizeof(float);
0040 case ETensorType::DOUBLE: return sizeof(double);
0041 case ETensorType::UINT8: return sizeof(uint8_t);
0042 case ETensorType::INT8: return sizeof(int8_t);
0043 case ETensorType::UINT16: return sizeof(uint16_t);
0044 case ETensorType::INT16: return sizeof(int16_t);
0045 case ETensorType::INT32: return sizeof(int32_t);
0046 case ETensorType::INT64: return sizeof(int64_t);
0047 case ETensorType::UINT32: return sizeof(uint32_t);
0048 case ETensorType::UINT64: return sizeof(uint64_t);
0049 case ETensorType::BOOL: return sizeof(bool);
0050 case ETensorType::STRING: return sizeof(std::string);
0051 default: return 0;
0052 }
0053 }
0054
0055 typedef std::int64_t int_t;
0056
0057 std::string ConvertTypeToString(ETensorType type);
0058 ETensorType ConvertStringToType(std::string type);
0059
/// A single tensor dimension: either a concrete numeric extent or a named
/// (symbolic) parameter such as a dynamic batch size.
struct Dim {
   bool isParam = false;   ///< true when the dimension is symbolic
   size_t dim = 0;         ///< numeric extent (fallback value when symbolic)
   std::string param;      ///< parameter name when symbolic

   /// Known-size dimension of extent zero.
   Dim() = default;

   /// Symbolic dimension named `p`, optionally carrying a fallback extent `d`.
   Dim(const std::string & p, size_t d = 0) : isParam(true), dim(d), param(p) {}

   /// Known-size dimension of extent `d`.
   Dim(size_t d) : dim(d) {}

   /// Render the dimension: the parameter name when symbolic, the number otherwise.
   std::string GetVal() const {
      if (isParam)
         return param;
      return std::to_string(dim);
   }
};
0078
0079
/// Type and (possibly symbolic) shape describing a model input tensor.
struct InputTensorInfo{
   ETensorType type;
   std::vector<Dim> shape;
};
0084
/// Type and fully-known numeric shape of a tensor.
struct TensorInfo{
   ETensorType type;
   std::vector<size_t> shape;
};
0089
/// Type and shape of a tensor whose dimensions may be symbolic (dynamic).
struct DynamicTensorInfo{
   ETensorType type;
   std::vector<Dim> shape;
};
0094
0095
/// Maps a C++ scalar type to the type name emitted in generated code.
/// The primary template is intentionally empty so that unsupported types
/// fail to compile instead of producing wrong code.
template <typename T>
struct TensorType {};
template<>
struct TensorType<float> {
   static const std::string Name() { return "float"; }
};
template<>
struct TensorType<double> {
   static const std::string Name() { return "double"; }
};
template<>
struct TensorType<int64_t> {
   static const std::string Name() { return "int64_t"; }
};
template<>
struct TensorType<int32_t> {
   static const std::string Name() { return "int32_t"; }
};
template<>
struct TensorType<uint32_t> {
   static const std::string Name() { return "uint32_t"; }
};
template<>
struct TensorType<uint64_t> {
   static const std::string Name() { return "uint64_t"; }
};
0122
/// A named chunk of tensor memory inside a pool.
struct TensorMemoryInfo {
   std::string_view tensor_name;   ///< name of the tensor occupying the chunk
   size_t tensor_size;             ///< chunk size in bytes

   /// Carve `new_size` bytes off this chunk and return them as a new chunk
   /// named `new_name`. Throws std::invalid_argument when not enough bytes remain.
   TensorMemoryInfo split(const std::string_view new_name, size_t new_size) {
      if (new_size > tensor_size)
         throw std::invalid_argument("New size exceeds available tensor size.");
      tensor_size -= new_size;
      return {new_name, new_size};
   }

   /// Fold the byte count of `other` back into this chunk (name unchanged).
   void merge(const TensorMemoryInfo& other) { tensor_size += other.tensor_size; }
};
0140
/// Book-keeping for a pool of tensor memory.
struct MemoryPoolInfo {
   // Chunks currently assigned to tensors, keyed by an integer position —
   // presumably the chunk's byte offset in the pool; confirm against the
   // memory-planning code that fills these maps.
   std::map<size_t, TensorMemoryInfo> total_stack;
   // Free chunk sizes, keyed the same way (key -> available byte count).
   std::map<size_t, size_t> available_stack;
};
0149
0150 std::vector<Dim> ConvertShapeToDim(std::vector<size_t> shape);
0151
0152 std::vector<size_t> ConvertShapeToInt(std::vector<Dim> shape);
0153
0154 std::size_t ConvertShapeToLength(std::vector<size_t> shape);
0155
0156 std::string ConvertShapeToString(std::vector<size_t> shape);
0157 std::string ConvertDynamicShapeToString(std::vector<Dim> shape);
0158
0159
0160
0161
0162 std::string ConvertDynamicShapeToLength(std::vector<Dim> shape);
0163
/// Convert a single value to its textual representation.
/// For floating-point types the stream precision is raised to
/// max_digits10 so the printed value round-trips exactly.
template<class T>
std::string ConvertValToString(T value) {
   std::stringstream ret;
   // Compile-time dispatch: the precision branch is only instantiated for
   // floating-point T (the original used a runtime `if` on this trait).
   if constexpr (std::is_floating_point_v<T>)
      ret << std::setprecision(std::numeric_limits<T>::max_digits10);
   ret << value;
   return ret.str();
}
0172
0173
0174
/// Convert an array of `n` values to a braced, comma-separated string,
/// e.g. "{ 1, 2, 3}". Floating-point values are printed with max_digits10
/// precision so they round-trip exactly.
template<class T>
std::string ConvertValuesToString(size_t n, const T * data) {
   std::stringstream ret;
   // Set the precision once up front: it is loop-invariant and sticky on the
   // stream (the original re-applied it on every iteration, via a runtime `if`
   // on a compile-time trait).
   if constexpr (std::is_floating_point_v<T>)
      ret << std::setprecision(std::numeric_limits<T>::max_digits10);
   ret << "{ ";
   for (size_t i = 0; i < n; i++) {
      ret << data[i];
      if (i < n-1) ret << ", ";
   }
   ret << "}";
   return ret.str();
}
/// Convenience overload: render an entire std::vector as "{ v0, v1, ... }".
template<class T>
std::string ConvertValuesToString(const std::vector<T> & data) {
   const T * values = data.data();
   return ConvertValuesToString(data.size(), values);
}
0192
/// An initialized tensor (weight or constant): element type, shape and a
/// shared buffer with the data, plus conversion to/from a flat "persistent"
/// char array used for serialisation.
class InitializedTensor {
public:
   InitializedTensor() = default;
   /// @param type         element type of the tensor
   /// @param shape        tensor shape (copied into the object)
   /// @param data         shared buffer holding the elements
   /// @param typeConstant true when this tensor comes from a Constant operator
   ///                     rather than being a trainable weight
   InitializedTensor(ETensorType type, std::span<std::size_t> shape, std::shared_ptr<void> data, bool typeConstant = false)
      : fConstant(typeConstant), fType{type}, fShape{shape.begin(), shape.end()}, fData{data}
   {
   }

   ETensorType const &type() const { return fType; }
   std::vector<std::size_t> const &shape() const { return fShape; }
   std::shared_ptr<void> const &sharedptr() const { return fData; }

   /// True for tensors originating from a Constant operator.
   bool IsConstantTensor() const { return fConstant;}

   /// Weight tensors are the non-constant, still-writable ones.
   bool IsWeightTensor() const { return !fConstant && !fIsNotWritable;}

   /// Mark the tensor as no longer writable (it stops counting as a weight).
   void SetNotWritable() { fIsNotWritable = true;}

   /// Typed, non-owning view of the underlying buffer.
   template <class T = void>
   T const *data() const
   {
      return static_cast<T const *>(fData.get());
   }

   /// Prepare for serialisation: compute the buffer size in bytes from shape
   /// and type, and point fPersistentData at the shared buffer (no copy here).
   void CastSharedToPersistent()
   {
      // Total bytes = product of the shape extents times the element size.
      fSize = 1;
      for (std::size_t item : fShape) {
         fSize *= static_cast<int>(item);
      }
      switch (fType) {
      case ETensorType::FLOAT: fSize *= sizeof(float); break;
      case ETensorType::DOUBLE: fSize *= sizeof(double); break;
      case ETensorType::INT32: fSize *= sizeof(int32_t); break;
      case ETensorType::INT64: fSize *= sizeof(int64_t); break;
      case ETensorType::BOOL: fSize *= sizeof(bool); break;
      default:
         throw std::runtime_error("TMVA::SOFIE doesn't yet supports serialising data-type " +
                                  ConvertTypeToString(fType));
      }
      fPersistentData = static_cast<char *>(fData.get());
   }
   /// Restore the shared buffer from deserialised flat data, copying it into
   /// freshly allocated memory owned by fData.
   void CastPersistentToShared()
   {
      // Nothing to do when no persistent data was read back.
      if (fSize == 0 || fPersistentData == nullptr) {
         return;
      }

      // If the persistent pointer still aliases the shared buffer (set by
      // CastSharedToPersistent, nothing was deserialised), leave it alone.
      if (fPersistentData == static_cast<char *>(fData.get())) {
         return;
      }

      // Copy the deserialised bytes into a new malloc-owned shared buffer.
      fData = std::shared_ptr<void>{malloc(fSize), free};
      std::memcpy(fData.get(), fPersistentData, fSize);

      // NOTE(review): delete[] assumes fPersistentData was allocated with
      // new char[] by the deserialisation machinery (it cannot be the
      // malloc-owned fData buffer here because of the alias check above) —
      // confirm against the I/O code that fills fPersistentData.
      delete[] fPersistentData;
      fPersistentData = nullptr;
      fSize = 0;
   }

private:
   bool fConstant = false;       ///< Constant-operator tensor, not a weight
   bool fIsNotWritable = false;  ///< set by SetNotWritable()
   ETensorType fType;            ///< element type
   std::vector<std::size_t> fShape;  ///< tensor shape
   std::shared_ptr<void> fData;      ///< shared buffer with the elements
   int fSize = 0;                    ///< byte size used during (de)serialisation
   char *fPersistentData = nullptr;  ///< flat serialisation buffer
};
0269
0270 template <typename T>
0271 ETensorType GetTemplatedType(T ){
0272 if (std::is_same<T, float>::value) return ETensorType::FLOAT;
0273 if (std::is_same<T, uint8_t>::value) return ETensorType::UINT8;
0274 if (std::is_same<T, int8_t>::value) return ETensorType::INT8;
0275 if (std::is_same<T, uint16_t>::value) return ETensorType::UINT16;
0276 if (std::is_same<T, int16_t>::value) return ETensorType::INT16;
0277 if (std::is_same<T, int32_t>::value) return ETensorType::INT32;
0278 if (std::is_same<T, int64_t>::value) return ETensorType::INT64;
0279 if (std::is_same<T, std::string>::value) return ETensorType::STRING;
0280 if (std::is_same<T, bool>::value) return ETensorType::BOOL;
0281
0282 if (std::is_same<T, double>::value) return ETensorType::DOUBLE;
0283 if (std::is_same<T, uint32_t>::value) return ETensorType::UINT32;
0284 if (std::is_same<T, uint64_t>::value) return ETensorType::UINT64;
0285
0286 }
0287
0288 namespace UTILITY{
0289
0290 bool AreSameShape(const std::vector<size_t>&, const std::vector<size_t>&);
0291 bool AreSameShape(const std::vector<size_t>&, const std::vector<Dim>&);
0292 bool AreSameShape(const std::vector<Dim>&, const std::vector<Dim>&);
0293
0294
0295
0296 std::vector<size_t> MultidirectionalBroadcastShape(std::vector<std::vector<size_t>>);
0297
0298
0299 std::vector<size_t> UnidirectionalBroadcastShape(std::vector<size_t>, std::vector<size_t>);
0300
0301 std::string Clean_name(std::string input_tensor_name);
0302
0303 template<typename T>
0304 T* BroadcastConvBias(const T* data, const size_t channel, const std::vector<size_t>& targetShape) {
0305 size_t size = targetShape.size();
0306 if (targetShape[1] != channel) {
0307 std::stringstream ss;
0308 ss << "TMVA::SOFIE - Error broadcasting Conv Bias of shape {";
0309 ss << std::to_string(channel);
0310 ss << "} to ";
0311 ss << ConvertShapeToString(targetShape);
0312 throw
0313 std::runtime_error(ss.str());
0314 }
0315
0316 size_t targetLength = ConvertShapeToLength(targetShape);
0317 T* newData = new T[targetLength];
0318
0319 if (targetLength == channel) {
0320 std::copy(data, data + channel, newData);
0321 return newData;
0322 }
0323
0324
0325 size_t cStride = 1;
0326 for (size_t i = 2; i < size; i++)
0327 cStride *= targetShape[i];
0328
0329
0330 for (size_t i = 0; i < channel; i++) {
0331 std::fill(newData + i * cStride, newData + (i + 1) * cStride, data[i]);
0332 }
0333
0334 size_t batch = targetShape[0];
0335 size_t bStride = channel * cStride;
0336 for (size_t i = 1; i < batch; i++) {
0337 std::copy(newData, newData + bStride, newData + i * bStride);
0338 }
0339 return newData;
0340 }
0341
0342
0343
0344
/// Broadcast `data` of shape `shape` into the caller-provided buffer
/// `broadcastedData` of shape `targetShape`.
/// Assumes both shapes have the same rank (callers left-pad with 1s) and that
/// every axis of `shape` is either 1 or equal to the target axis — TODO confirm
/// these preconditions at the call sites; they are not checked here.
template<typename T, class ConstContT = std::span<const T>, class ContT = std::span<T> >
void BroadcastTensor(ConstContT data, const std::vector<size_t>& shape, const std::vector<size_t>& targetShape, ContT broadcastedData) {

   size_t size = shape.size();

   size_t curLength = data.size();
   size_t targetLength = broadcastedData.size();
   assert(ConvertShapeToLength(targetShape) == targetLength);

   // Fast path: leading axes match and all trailing axes of `shape` are 1 —
   // each input element just fills one contiguous block of the output.
   if (shape.front() == targetShape.front() && shape.back() == 1 && size > 1) {
      size_t bsize = targetShape.back();

      // Extend the block over every trailing axis that is 1 in the input.
      for (int k = int(size)-2; k >=0; k--) {
         if (shape[k] != 1) break;
         bsize *= targetShape[k];
      }
      for (size_t i = 0; i < curLength; i++) {
         std::fill(broadcastedData.begin() + i*bsize, broadcastedData.begin() + (i+1)*bsize , data[i]);
      }
      return;
   }

   // General path: start from a copy of the input at the front of the output
   // buffer, then expand axis by axis from the outermost inwards.
   std::copy(data.begin(), data.end(), broadcastedData.begin());

   // Number of independent sub-arrays processed so far (product of handled target axes).
   size_t arrayNum = 1;

   // Scratch buffer for one expansion step.
   std::vector<T> newData(targetLength);

   for (size_t idx = 0; idx < size; idx++) {
      size_t dim = shape[idx];
      size_t targetDim = targetShape[idx];
      if (dim == 1 && targetDim > 1) {
         // Expanding this axis multiplies the current data length by targetDim.
         size_t newLength = curLength * targetDim;

         // Length of each sub-array that must be replicated targetDim times.
         size_t arrayLength = curLength / arrayNum;

         if (arrayLength > 1) {
            // Replicate each sub-array block-wise into the scratch buffer.
            for (size_t arrayIdx = 0; arrayIdx < arrayNum; arrayIdx++) {
               for (size_t targetIdx = 0; targetIdx < targetDim; targetIdx++) {
                  size_t offset = arrayIdx * arrayLength * targetDim + targetIdx * arrayLength;
                  std::copy(broadcastedData.begin() + arrayIdx * arrayLength,
                     broadcastedData.begin() + (arrayIdx + 1) * arrayLength,
                     newData.begin() + offset);
               }
            }
         } else {
            // Sub-arrays are single elements: a fill per element is cheaper.
            for (size_t arrayIdx = 0; arrayIdx < arrayNum; arrayIdx++) {
               std::fill(newData.begin() + arrayIdx * targetDim,
                  newData.begin() + (arrayIdx + 1) * targetDim, broadcastedData[arrayIdx]);
            }
         }

         curLength = newLength;

         // Write the expanded data back into the output buffer for the next axis.
         std::copy(newData.begin(), newData.begin() + newLength, broadcastedData.begin());
      }

      arrayNum *= targetDim;
   }

}
0409
0410
0411 template<typename T>
0412 T* CreateBroadcastTensor(const T* data, const std::vector<size_t>& shape, const std::vector<size_t>& targetShape, size_t targetLength) {
0413
0414 T* broadcastedData = new T[targetLength];
0415 std::span<T> bData(broadcastedData, broadcastedData+targetLength);
0416 size_t curLength = ConvertShapeToLength(shape);
0417 std::span<const T> inData(data, curLength);
0418 BroadcastTensor<T, std::span<const T>, std::span<T>>(inData, shape, targetShape, bData);
0419 return broadcastedData;
0420 }
0421
0422
0423 template<typename T>
0424 T* UnidirectionalBroadcast(const T* data, const std::vector<size_t>& shape, const std::vector<size_t>& targetShape) {
0425
0426 if (shape.size() < targetShape.size()) {
0427 size_t targetSize = targetShape.size();
0428 std::vector<size_t> newShape(targetSize, 1);
0429 size_t offset = targetSize - shape.size();
0430 std::copy(shape.begin(), shape.end(), newShape.begin() + offset);
0431 return CreateBroadcastTensor<T>(data, newShape, targetShape, ConvertShapeToLength(targetShape));
0432 }
0433 return CreateBroadcastTensor<T>(data, shape, targetShape, ConvertShapeToLength(targetShape));
0434 }
0435
0436
0437 template<typename T>
0438 void UnidirectionalBroadcast(const T* data, const std::vector<size_t>& shape, const std::vector<size_t>& targetShape, std::span<T> broadcastedData) {
0439 size_t curLength = ConvertShapeToLength(shape);
0440 std::span<T> inData(const_cast<T*>(data), curLength);
0441
0442 if (shape.size() < targetShape.size()) {
0443 size_t targetSize = targetShape.size();
0444 std::vector<size_t> newShape(targetSize, 1);
0445 size_t offset = targetSize - shape.size();
0446 std::copy(shape.begin(), shape.end(), newShape.begin() + offset);
0447 BroadcastTensor<T>(inData, newShape, targetShape, broadcastedData);
0448 }
0449 BroadcastTensor<T, std::span<T>>(inData, shape, targetShape, broadcastedData);
0450 }
0451
0452 void UnidirectionalBroadcast(const std::vector<bool> & data, const std::vector<size_t>& shape, const std::vector<size_t>& targetShape, std::vector<bool> & broadcastedData);
0453
0454
0455 std::vector<size_t> ComputeStrideFromShape(const std::vector<size_t> & shape);
0456 std::vector<Dim> ComputeStrideFromShape(const std::vector<Dim> & shape);
0457
0458
0459
/// True iff 0 <= a < b, evaluated with a single unsigned comparison:
/// a negative `a` wraps to a huge unsigned value and fails the `< b` test.
inline bool is_a_ge_zero_and_a_lt_b(int a, int b) {
   const auto ua = static_cast<unsigned>(a);
   const auto ub = static_cast<unsigned>(b);
   return ua < ub;
}
0463
0464
0465
0466
0467
0468
0469
0470
0471
0472
0473
0474
0475
0476
0477
0478
0479
0480
0481
0482
0483
/// im2col: unroll the sliding-window patches of a 2-d image
/// (channels x height x width, row-major) into columns so a convolution
/// becomes a single matrix multiplication. Positions falling into the
/// padding are written as 0.
/// @param data_im  input image of channels*height*width elements
/// @param data_col output buffer of channels*kernel_h*kernel_w*output_h*output_w elements
template <typename T>
void Im2col(const T *data_im, const int channels, const int height, const int width, const int kernel_h,
            const int kernel_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w,
            const int dilation_h, const int dilation_w, T *data_col)
{
   // Standard convolution output-extent formulas.
   const int output_h = (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
   const int output_w = (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
   const int channel_size = height * width;
   for (int channel = channels; channel--; data_im += channel_size) {
      for (int kernel_row = 0; kernel_row < kernel_h; kernel_row++) {
         for (int kernel_col = 0; kernel_col < kernel_w; kernel_col++) {
            int input_row = -pad_h + kernel_row * dilation_h;
            for (int output_rows = output_h; output_rows; output_rows--) {
               if (!is_a_ge_zero_and_a_lt_b(input_row, height)) {
                  // Whole row lies in the padding: emit a run of zeros.
                  for (int output_cols = output_w; output_cols; output_cols--) {
                     *(data_col++) = 0;
                  }
               } else {
                  int input_col = -pad_w + kernel_col * dilation_w;
                  for (int output_col = output_w; output_col; output_col--) {
                     if (is_a_ge_zero_and_a_lt_b(input_col, width)) {
                        *(data_col++) = data_im[input_row * width + input_col];
                     } else {
                        *(data_col++) = 0;  // column in the padding
                     }
                     input_col += stride_w;
                  }
               }
               input_row += stride_h;
            }
         }
      }
   }
}
0518
0519
/// 3-d variant of im2col: unroll the patches of a volume
/// (channels x depth x height x width, row-major) into columns for a 3-d
/// convolution. Positions falling into the padding are written as 0.
/// @param data_im  input volume of channels*depth*height*width elements
/// @param data_col output buffer of
///                 channels*kernel_d*kernel_h*kernel_w*output_d*output_h*output_w elements
template <typename T>
void Im2col_3d(const T *data_im, const int channels,
               const int depth, const int height, const int width,
               const int kernel_d, const int kernel_h, const int kernel_w,
               const int pad_d, const int pad_h, const int pad_w,
               const int stride_d, const int stride_h, const int stride_w,
               const int dilation_d, const int dilation_h, const int dilation_w, T *data_col)
{
   // Standard convolution output-extent formulas, per spatial axis.
   const int output_h = (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
   const int output_w = (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
   const int output_d = (depth + 2 * pad_d - (dilation_d * (kernel_d - 1) + 1)) / stride_d + 1;
   const int channel_size = height * width * depth;

   for (int channel = channels; channel--; data_im += channel_size) {
      for (int kernel_depth = 0; kernel_depth < kernel_d; kernel_depth++) {
         for (int kernel_row = 0; kernel_row < kernel_h; kernel_row++) {
            for (int kernel_col = 0; kernel_col < kernel_w; kernel_col++) {
               int input_dep = -pad_d + kernel_depth * dilation_d;
               for (int output_dep = output_d; output_dep; output_dep--) {
                  if (!is_a_ge_zero_and_a_lt_b(input_dep, depth)) {
                     // Whole depth slice lies in the padding: emit zeros.
                     for (int output_rows = output_h; output_rows; output_rows--) {
                        for (int output_cols = output_w; output_cols; output_cols--) {
                           *(data_col++) = 0;
                        }
                     }
                  } else {
                     int input_row = -pad_h + kernel_row * dilation_h;
                     for (int output_rows = output_h; output_rows; output_rows--) {
                        if (!is_a_ge_zero_and_a_lt_b(input_row, height)) {
                           // Whole row lies in the padding: emit zeros.
                           for (int output_cols = output_w; output_cols; output_cols--) {
                              *(data_col++) = 0;
                           }
                        } else {
                           int input_col = -pad_w + kernel_col * dilation_w;
                           for (int output_col = output_w; output_col; output_col--) {
                              if (is_a_ge_zero_and_a_lt_b(input_col, width)) {
                                 *(data_col++) = data_im[input_dep * width * height + input_row * width + input_col];
                              } else {
                                 *(data_col++) = 0;  // column in the padding
                              }
                              input_col += stride_w;
                           }
                        }
                        input_row += stride_h;
                     }
                  }
                  input_dep += stride_d;
               }
            }
         }
      }
   }
}
0573
/// col2im: inverse of Im2col — scatter the column buffer back into an image
/// of channels x height x width. Where patches overlap, contributions are
/// summed (the image is zeroed first and written with +=).
/// @param data_col input buffer laid out as produced by Im2col
/// @param data_im  output image of channels*height*width elements
template <typename Dtype>
void col2im(const Dtype* data_col, const int channels,
            const int height, const int width, const int kernel_h, const int kernel_w,
            const int pad_h, const int pad_w,
            const int stride_h, const int stride_w,
            const int dilation_h, const int dilation_w,
            Dtype* data_im) {
   // Zero the whole image first: values are accumulated below.
   std::fill(data_im, data_im + height * width * channels, 0.);

   // Standard convolution output-extent formulas.
   const int output_h = (height + 2 * pad_h -
                         (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
   const int output_w = (width + 2 * pad_w -
                         (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
   const int channel_size = height * width;
   for (int channel = channels; channel--; data_im += channel_size) {
      for (int kernel_row = 0; kernel_row < kernel_h; kernel_row++) {
         for (int kernel_col = 0; kernel_col < kernel_w; kernel_col++) {
            int input_row = -pad_h + kernel_row * dilation_h;
            for (int output_rows = output_h; output_rows; output_rows--) {
               if (!is_a_ge_zero_and_a_lt_b(input_row, height)) {
                  // Whole row lies in the padding: skip its column entries.
                  data_col += output_w;
               } else {
                  int input_col = -pad_w + kernel_col * dilation_w;
                  for (int output_col = output_w; output_col; output_col--) {
                     if (is_a_ge_zero_and_a_lt_b(input_col, width)) {
                        // Accumulate: overlapping patch positions sum up here.
                        data_im[input_row * width + input_col] += *data_col;
                     }
                     data_col++;
                     input_col += stride_w;
                  }
               }
               input_row += stride_h;
            }
         }
      }
   }

}
0620
0621
/// Copy the first `n` elements of the C array `arr` into `out`, resizing it.
/// Uses vector::assign instead of the original resize-plus-index loop.
template <class T>
void FillOutput(T const *arr, std::vector<T> &out, std::size_t n)
{
   out.assign(arr, arr + n);
}
0630
0631
0632 inline void FillOutput(std::vector<bool> const &vec, std::vector<std::uint8_t> &out, std::size_t n)
0633 {
0634 out.resize(n);
0635 for (std::size_t i = 0; i < n; ++i) {
0636 out[i] = vec[i];
0637 }
0638 }
0639
0640 }
0641
0642 namespace BLAS{
0643 extern "C" void sgemm_(const char * transa, const char * transb, const int * m, const int * n, const int * k,
0644 const float * alpha, const float * A, const int * lda, const float * B, const int * ldb,
0645 const float * beta, float * C, const int * ldc);
0646 }
0647
0648
/// Input data of a graph for GNN models: per-node, per-edge and graph-level
/// feature tensors plus the edge connectivity.
struct GNN_Data {
   RTensor<float> node_data;     ///< per-node feature tensor
   RTensor<float> edge_data;     ///< per-edge feature tensor
   RTensor<float> global_data;   ///< graph-level (global) feature tensor
   // Edge connectivity as integer node indices — presumably sender/receiver
   // pairs; confirm the layout against the GNN model code that consumes it.
   RTensor<int> edge_index;

   // All members are constructed explicitly as empty tensors.
   GNN_Data(): node_data(RTensor<float>({})), edge_data(RTensor<float>({})), global_data(RTensor<float>({})), edge_index(RTensor<int>({})) {}

};
0661
/// Concatenate two RTensors along `axis` (row-major layout only).
/// All axes other than `axis` must agree; the result's `axis` extent is the
/// sum of the inputs'. Throws std::runtime_error on layout or shape mismatch.
template<typename T>
TMVA::Experimental::RTensor<T> Concatenate( TMVA::Experimental::RTensor<T> & t1, TMVA::Experimental::RTensor<T> & t2, int axis = 0)
{
   // Both tensors must use the same memory layout.
   if (t1.GetMemoryLayout() != t2.GetMemoryLayout())
      throw std::runtime_error("TMVA RTensor Concatenate - tensors have different memory layout");
   auto & shape1 = t1.GetShape();
   auto & shape2 = t2.GetShape();
   // The total size away from the concatenation axis must agree.
   if (t1.GetSize()/shape1[axis] != t2.GetSize()/shape2[axis]) {
      std::cout << "axis " << axis << " sizes " << t1.GetSize() << " " << t2.GetSize() << " ";
      std::cout << "shape 1 : " << ConvertShapeToString(t1.GetShape());
      std::cout << " shape 2 : " << ConvertShapeToString(t2.GetShape()) << std::endl;
      throw std::runtime_error("TMVA RTensor Concatenate - tensors have incompatible shapes");
   }
   std::vector<size_t> outShape = shape1;
   outShape[axis] = shape1[axis] + shape2[axis];
   TMVA::Experimental::RTensor<T> tout(outShape, t1.GetMemoryLayout());
   if (t1.GetMemoryLayout() == TMVA::Experimental::MemoryLayout::ColumnMajor) {
      throw std::runtime_error("TMVA RTensor Concatenate is not yet supported for column major tensors");
   }

   auto & stride1 = t1.GetStrides();
   auto & stride2 = t2.GetStrides();
   auto & outStride = tout.GetStrides();

   // Slab size preceding the concat axis: copy slab-by-slab, writing t1's
   // slab followed by t2's slab into each output slab.
   size_t s1 = (axis > 0) ? stride1[axis-1] : t1.GetSize();
   size_t s2 = (axis > 0) ? stride2[axis-1] : t2.GetSize();
   size_t sout = (axis > 0) ? outStride[axis-1] : tout.GetSize();
   size_t nb = t1.GetSize()/s1;
   for (size_t i = 0; i < nb; i++) {
      std::copy(t1.GetData() + i*s1, t1.GetData() + (i+1)*s1, tout.GetData() + i * sout );
      std::copy(t2.GetData() + i*s2, t2.GetData() + (i+1)*s2, tout.GetData() + i * sout + s1 );
   }

   return tout;
}
0698
0699
/// Concatenate two GNN_Data: node and edge features along `axis`,
/// global features along `axis-1` (the global tensor presumably has one
/// dimension fewer — confirm against the callers).
inline GNN_Data Concatenate(GNN_Data & data1, GNN_Data & data2, int axis = 0) {
   GNN_Data out;
   out.node_data = Concatenate(data1.node_data,data2.node_data, axis);
   out.edge_data = Concatenate(data1.edge_data,data2.edge_data, axis);
   out.global_data = Concatenate<float>(data1.global_data,data2.global_data, axis-1);
   // NOTE(review): only data1's edge_index is kept — this assumes both inputs
   // share the same graph topology; verify at the call sites.
   out.edge_index = data1.edge_index.Copy();
   return out;
}
0709
0710 inline GNN_Data Copy(const GNN_Data & data) {
0711 GNN_Data out;
0712 out.node_data = RTensor<float>(data.node_data.GetShape());
0713 out.edge_data = RTensor<float>(data.edge_data.GetShape());
0714 out.global_data = RTensor<float>(data.global_data.GetShape());
0715 out.edge_index = RTensor<int>(data.edge_index.GetShape());
0716 std::copy(data.node_data.GetData(), data.node_data.GetData()+ data.node_data.GetSize(), out.node_data.GetData());
0717 std::copy(data.edge_data.GetData(), data.edge_data.GetData()+ data.edge_data.GetSize(), out.edge_data.GetData());
0718 std::copy(data.global_data.GetData(), data.global_data.GetData()+ data.global_data.GetSize(), out.global_data.GetData());
0719 std::copy(data.edge_index.GetData(), data.edge_index.GetData()+ data.edge_index.GetSize(), out.edge_index.GetData());
0720 return out;
0721 }
0722
/// Thin wrapper around Fortran BLAS sgemm_:
///   output = alpha * op(A) * op(B) + beta * output
/// with op(X) = X or X^T depending on transa/transb, and an m x n result in
/// column-major (Fortran) order. When C is non-null it is copied into
/// `output` first, so the beta term accumulates onto C.
inline void Gemm_Call(float *output, bool transa, bool transb, int m, int n, int k, float alpha, const float *A,
                      const float *B, float beta, const float *C)
{
   char ct = 't';
   char cn = 'n';
   // Leading dimensions per the BLAS convention: the row count of each
   // operand as stored, which flips when the operand is transposed.
   const int *lda = transa ? &k : &m;
   const int *ldb = transb ? &n : &k;
   const int *ldc = &m;
   if (C != nullptr) {
      // Seed the output with C so that beta * C is included by sgemm_.
      std::copy(C, C + m * n, output);
   }
   TMVA::Experimental::SOFIE::BLAS::sgemm_(transa ? &ct : &cn, transb ? &ct : &cn, &m, &n, &k, &alpha, A, lda, B, ldb,
                                           &beta, output, ldc);
}
0737
0738 }
0739 }
0740 }
0741
0742 #endif