File indexing completed on 2025-01-30 10:22:58
0001 #ifndef TMVA_SOFIE_ROPERATOR_Cast
0002 #define TMVA_SOFIE_ROPERATOR_Cast
0003
0004 #include "TMVA/SOFIE_common.hxx"
0005 #include "TMVA/ROperator.hxx"
0006 #include "TMVA/RModel.hxx"
0007
0008 #include <sstream>
0009
0010 namespace TMVA{
0011 namespace Experimental{
0012 namespace SOFIE{
0013
0014
0015 template <typename T>
0016 class ROperator_Cast final : public ROperator
0017 {
0018
0019 private:
0020
0021 std::string fNX;
0022 std::string fNY;
0023 std::vector<size_t> fShape;
0024 std::string fAttrType = "float";
0025
0026 public:
0027 ROperator_Cast(){}
0028 ROperator_Cast(std::string attr_type,std::string nameX, std::string nameY):
0029 fNX(UTILITY::Clean_name(nameX)), fNY(UTILITY::Clean_name(nameY)),
0030 fAttrType(attr_type) {}
0031
0032 std::vector<ETensorType> TypeInference(std::vector<ETensorType> input){
0033 return input;
0034 }
0035
0036 std::vector<std::vector<size_t>> ShapeInference(std::vector<std::vector<size_t>> input){
0037 auto ret = input;
0038 return ret;
0039 }
0040
0041 void Initialize(RModel& model){
0042
0043 if (model.CheckIfTensorAlreadyExist(fNX) == false){
0044 throw std::runtime_error("TMVA SOFIE Cast Op Input Tensor is not found in model");
0045 }
0046 fShape = model.GetTensorShape(fNX);
0047 model.AddIntermediateTensor(fNY, ConvertStringToType(fAttrType), fShape);
0048 }
0049
0050
0051 std::string Generate(std::string OpName){
0052 OpName = "op_" + OpName;
0053 if (fShape.empty()) {
0054 throw std::runtime_error("TMVA SOFIE Cast called to Generate without being initialized first");
0055 }
0056 std::stringstream out;
0057 size_t length = ConvertShapeToLength(fShape);
0058
0059
0060 out << "\n//------ CAST\n";
0061 out << SP << "for (int id = 0; id < " << length << " ; id++){\n";
0062
0063 out << SP << SP << "tensor_" << fNY << "[id] = static_cast<"<< fAttrType << ">(tensor_" << fNX << "[id]);\n";
0064
0065 out << SP << "}\n";
0066 return out.str();
0067 }
0068
0069 };
0070
0071 }
0072 }
0073 }
0074
0075
0076 #endif