File indexing completed on 2025-09-17 09:14:53
0001 #ifndef TMVA_SOFIE_ROperator_Expand
0002 #define TMVA_SOFIE_ROperator_Expand
0003
0004 #include "TMVA/SOFIE_common.hxx"
0005 #include "TMVA/ROperator.hxx"
0006 #include "TMVA/RModel.hxx"
0007
0008 #include <sstream>
0009
0010 namespace TMVA{
0011 namespace Experimental{
0012 namespace SOFIE{
0013
0014 template<typename T>
0015 class ROperator_Expand final : public ROperator{
0016 private:
0017
0018 std::vector<size_t> fShapeX;
0019 std::vector<size_t> fShape;
0020 std::vector<size_t> fShapeY;
0021
0022 std::string fNX;
0023 std::string fNShape;
0024 std::string fNY;
0025 std::string fType;
0026
0027 bool fInitialized = false;
0028
0029 public:
0030 ROperator_Expand(){}
0031 ROperator_Expand(std::string nameX, std::string nameShape, std::string nameY):
0032 fNX(UTILITY::Clean_name(nameX)), fNShape(UTILITY::Clean_name(nameShape)), fNY(UTILITY::Clean_name(nameY)){
0033 fInputTensorNames = { fNX };
0034 fOutputTensorNames = { fNY };
0035 }
0036
0037
0038 std::vector<ETensorType> TypeInference(std::vector<ETensorType> input) override {
0039 return input;
0040 }
0041
0042 std::vector<std::vector<size_t>> ShapeInference(std::vector<std::vector<size_t>> input) override {
0043 return input;
0044 }
0045
0046 void Initialize(RModel& model) override {
0047
0048 if (!model.CheckIfTensorAlreadyExist(fNX)) {
0049 throw std::runtime_error("TMVA SOFIE Expand Op Input Tensor " + fNX + " is not found in model");
0050 }
0051 fShapeX = model.GetTensorShape(fNX);
0052 if (!model.IsInitializedTensor(fNShape)) {
0053 throw std::runtime_error("TMVA::SOFIE - Tensor " + fNShape + " is not initialized.");
0054 }
0055 int64_t *shapeData =
0056 static_cast<int64_t *>(model.GetInitializedTensorData(fNShape).get());
0057 fShape = model.GetTensorShape(fNShape);
0058 if (fShape.size() != 1) {
0059 throw std::runtime_error("TMVA::SOFIE - Expand operator shape must be a 1d tensor.");
0060 }
0061 size_t N = fShape[0];
0062 std::vector<size_t> shape(shapeData, shapeData + N);
0063
0064 fShapeY = TMVA::Experimental::SOFIE::UTILITY::UnidirectionalBroadcastShape(
0065 fShapeX, shape);
0066 fInitialized = model.IsInitializedTensor(fNX);
0067
0068 bool broadcast = !UTILITY::AreSameShape(fShapeX, fShapeY);
0069 if (model.IsInitializedTensor(fNX)) {
0070
0071 auto data = model.GetInitializedTensorData(fNX);
0072 if (broadcast) {
0073 std::shared_ptr<void> broadcastedData(
0074 UTILITY::UnidirectionalBroadcast<T>(static_cast<T *>(data.get()), fShapeX, fShapeY),
0075 std::default_delete<T[]>());
0076
0077 model.UpdateInitializedTensor(fNX, model.GetTensorType(fNX), fShapeY, broadcastedData);
0078 fShapeX = fShapeY;
0079
0080 model.SetNotWritableInitializedTensor(fNX);
0081 data = broadcastedData;
0082 }
0083 if (broadcast || model.IsConstantTensor(fNX)) {
0084 fIsOutputConstant = true;
0085 model.AddConstantTensor(fNY, model.GetTensorType(fNX), fShapeY, data);
0086 fOutputTensorNames.pop_back();
0087 } else {
0088 model.AddIntermediateTensor(fNY, model.GetTensorType(fNX), fShapeY);
0089 }
0090 } else {
0091
0092 model.AddIntermediateTensor(fNY, model.GetTensorType(fNX), fShapeY);
0093 }
0094 fType = ConvertTypeToString(model.GetTensorType(fNX));
0095 if (model.Verbose())
0096 std::cout << "Expand - output is with shape " << ConvertShapeToString(fShapeY) << std::endl;
0097 }
0098
0099 std::string GenerateInitCode() override {
0100 std::stringstream out;
0101 if (!fIsOutputConstant && (fInitialized || fShapeX == fShapeY ) ) {
0102 size_t length = ConvertShapeToLength(fShapeY);
0103 out << "// Copying initialized tensor " << fNX << " to " << fNY << "\n";
0104 out << SP << "std::copy(tensor_" << fNX << ", " << "tensor_" << fNX << " + " << length << ", tensor_" << fNY << ");\n";
0105 }
0106 return out.str();
0107 }
0108
0109 std::string Generate(std::string OpName) override {
0110 if (fIsOutputConstant) return "";
0111 OpName = "op_" + OpName;
0112 if (fShapeY.empty()) {
0113 throw std::runtime_error("TMVA SOFIE Expand Op called to Generate without being initialized first");
0114 }
0115 std::stringstream out;
0116 out << SP << "\n//------ Expand Op" << "\n";
0117
0118 if (!fInitialized && fShapeX != fShapeY) {
0119 out << SP << "// Broadcasting uninitialized tensor " << fNX << "\n";
0120 out << SP << "TMVA::Experimental::SOFIE::UTILITY::UnidirectionalBroadcast<" << fType << ">(tensor_" << fNX << ", " << ConvertShapeToString(fShapeX) << ", " << ConvertShapeToString(fShapeY)
0121 << ", std::span<"<<fType<<">(tensor_"<<fNY<<", "<<ConvertShapeToLength(fShapeY)<<"));\n";
0122 }
0123 return out.str();
0124 }
0125
0126 };
0127
0128 }
0129 }
0130 }
0131
0132
0133 #endif