File indexing completed on 2025-01-18 10:11:06
0001 #ifndef TMVA_SOFIE_ROperator_Expand
0002 #define TMVA_SOFIE_ROperator_Expand
0003
0004 #include "TMVA/SOFIE_common.hxx"
0005 #include "TMVA/ROperator.hxx"
0006 #include "TMVA/RModel.hxx"
0007
0008 #include <sstream>
0009
0010 namespace TMVA{
0011 namespace Experimental{
0012 namespace SOFIE{
0013
0014 template<typename T>
0015 class ROperator_Expand final : public ROperator{
0016 private:
0017
0018 std::vector<size_t> fShapeX;
0019 std::vector<size_t> fShape;
0020 std::vector<size_t> fShapeY;
0021
0022 std::string fNX;
0023 std::string fNShape;
0024 std::string fNY;
0025 std::string fType;
0026
0027 bool fInitialized = false;
0028
0029 public:
0030 ROperator_Expand(){}
0031 ROperator_Expand(std::string nameX, std::string nameShape, std::string nameY):
0032 fNX(UTILITY::Clean_name(nameX)), fNShape(UTILITY::Clean_name(nameShape)), fNY(UTILITY::Clean_name(nameY)){}
0033
0034
0035 std::vector<ETensorType> TypeInference(std::vector<ETensorType> input) override {
0036 return input;
0037 }
0038
0039 std::vector<std::vector<size_t>> ShapeInference(std::vector<std::vector<size_t>> input) override {
0040 return input;
0041 }
0042
0043 void Initialize(RModel& model) override {
0044
0045 if (!model.CheckIfTensorAlreadyExist(fNX)) {
0046 throw std::runtime_error("TMVA SOFIE Expand Op Input Tensor " + fNX + " is not found in model");
0047 }
0048 fShapeX = model.GetTensorShape(fNX);
0049 if (!model.IsInitializedTensor(fNShape)) {
0050 throw std::runtime_error("TMVA::SOFIE - Tensor " + fNShape + " is not initialized.");
0051 }
0052 int64_t *shapeData =
0053 static_cast<int64_t *>(model.GetInitializedTensorData(fNShape).get());
0054 fShape = model.GetTensorShape(fNShape);
0055 if (fShape.size() != 1) {
0056 throw std::runtime_error("TMVA::SOFIE - Expand operator shape must be a 1d tensor.");
0057 }
0058 size_t N = fShape[0];
0059 std::vector<size_t> shape(shapeData, shapeData + N);
0060
0061 fShapeY = TMVA::Experimental::SOFIE::UTILITY::UnidirectionalBroadcastShape(
0062 fShapeX, shape);
0063 fInitialized = model.IsInitializedTensor(fNX);
0064
0065 bool broadcast = !UTILITY::AreSameShape(fShapeX, fShapeY);
0066 if (broadcast && model.IsInitializedTensor(fNX)) {
0067
0068 auto data = model.GetInitializedTensorData(fNX);
0069 std::shared_ptr<void> broadcastedData(
0070 UTILITY::UnidirectionalBroadcast<float>(static_cast<float *>(data.get()), fShapeX, fShapeY),
0071 std::default_delete<float[]>());
0072
0073 model.UpdateInitializedTensor(fNX, model.GetTensorType(fNX), fShapeY, broadcastedData);
0074 fShapeX = fShapeY;
0075 }
0076 model.AddIntermediateTensor(fNY, model.GetTensorType(fNX), fShapeY);
0077 fType = ConvertTypeToString(model.GetTensorType(fNX));
0078 }
0079
0080 std::string GenerateInitCode() override {
0081 std::stringstream out;
0082 return out.str();
0083 }
0084
0085 std::string Generate(std::string OpName) override {
0086 OpName = "op_" + OpName;
0087
0088 if (fShapeY.empty()) {
0089 throw std::runtime_error("TMVA SOFIE Expand Op called to Generate without being initialized first");
0090 }
0091 std::stringstream out;
0092 out << SP << "\n//------ Expand Op" << "\n";
0093 size_t length = ConvertShapeToLength(fShapeY);
0094
0095 if (fInitialized) {
0096 out << "// Copying initialized tensor " << fNX << " to " << fNY << "\n";
0097 out << SP << "std::copy(tensor_" << fNX << ", " << "tensor_" << fNX << " + " << length << ", tensor_" << fNY << ");\n";
0098 } else {
0099 out << SP << "// Broadcasting uninitialized tensor " << fNX << "\n";
0100 out << SP << "{\n";
0101 out << SP << SP << "float* data = TMVA::Experimental::SOFIE::UTILITY::UnidirectionalBroadcast<float>(tensor_" << fNX << ", " << ConvertShapeToString(fShapeX) << ", " << ConvertShapeToString(fShapeY) << ");\n";
0102 out << SP << SP << "std::copy(data, data + " << length << ", tensor_" << fNY << ");\n";
0103 out << SP << SP << "delete[] data;\n";
0104 out << SP << "}\n";
0105 }
0106 return out.str();
0107 }
0108
0109 };
0110
0111 }
0112 }
0113 }
0114
0115
0116 #endif