Warning, file /include/root/TMVA/ROperator_Constant.hxx was not indexed
or was modified since last indexation (in which case cross-reference links may be missing, inaccurate or erroneous).
0001 #ifndef TMVA_SOFIE_ROPERATOR_Constant
0002 #define TMVA_SOFIE_ROPERATOR_Constant
0003
0004 #include "TMVA/SOFIE_common.hxx"
0005 #include "TMVA/ROperator.hxx"
0006 #include "TMVA/RModel.hxx"
0007
0008 #include <sstream>
0009
0010 namespace TMVA{
0011 namespace Experimental{
0012 namespace SOFIE{
0013
0014 template<typename T>
0015 class ROperator_Constant final : public ROperator
0016 {
0017
0018 private:
0019
0020 std::string fNX;
0021 std::string fNY;
0022 std::vector<size_t> fShape;
0023 std::vector<Dim> fDimShape;
0024 std::vector<Dim> fDimOutputShape;
0025 std::vector<T> fValues;
0026 std::string fAttrType;
0027 bool fIsConstantOfShape = false;
0028 bool fIsUndefinedInputShape = false;
0029
0030 public:
0031 ROperator_Constant(){}
0032
0033 ROperator_Constant(const std::string & type, const std::vector<T> & values, const std::vector<size_t> & shape, std::string nameX, std::string nameY):
0034 fNX(UTILITY::Clean_name(nameX)),
0035 fNY(UTILITY::Clean_name(nameY)),
0036 fShape(shape),
0037 fValues(values),
0038 fAttrType(type)
0039 {
0040 fInputTensorNames = { };
0041 fOutputTensorNames = { };
0042 }
0043
0044 std::vector<ETensorType> TypeInference(std::vector<ETensorType> input) override {
0045 return input;
0046 }
0047
0048 std::vector<std::vector<size_t>> ShapeInference(std::vector<std::vector<size_t>> input) override {
0049 auto ret = input;
0050 return ret;
0051 }
0052
0053 void Initialize(RModel& model) override {
0054
0055 size_t length = 1;
0056
0057 if (!fNX.empty()) {
0058
0059 fIsConstantOfShape = true;
0060 if (model.CheckIfTensorAlreadyExist(fNX) == false){
0061 throw std::runtime_error("TMVA SOFIE ConstantOfShape Op Input Tensor is not found in model");
0062 }
0063
0064
0065 if (model.IsConstantTensor(fNX)) {
0066 fIsOutputConstant = true;
0067 auto dptr = model.GetInitializedTensorData(fNX);
0068 auto input_tensor = static_cast<int64_t *>(dptr.get());
0069 auto input_shape = model.GetTensorShape(fNX);
0070 if (input_shape.size() > 1 )
0071 throw std::runtime_error("TMVA SOFIE ConstantOfShape Op Input Tensor has invalid shape");
0072 if (input_tensor != nullptr && !input_shape.empty()) {
0073 fShape = std::vector<size_t> (input_shape[0]);
0074 for (size_t i = 0; i < fShape.size(); i++)
0075 fShape[i] = input_tensor[i];
0076 } else
0077 fShape = {1};
0078
0079 length = ConvertShapeToLength(fShape);
0080 if (fValues.size() != 1)
0081 throw std::runtime_error("TMVA SOFIE ConstantOfShape Op value Tensor has invalid size " + std::to_string(fValues.size()));
0082
0083 T value = fValues[0];
0084 fValues = std::vector<T>(length, value);
0085 }
0086 else if (model.IsShapeTensor(fNX)) {
0087
0088 fDimOutputShape = model.GetShapeTensorValues(fNX);
0089 } else {
0090
0091
0092 fIsUndefinedInputShape = true;
0093 fDimShape = model.GetDimTensorShape(fNX);
0094 if (fDimShape.size() > 1 )
0095 throw std::runtime_error("TMVA SOFIE ConstantOfShape Op Input Tensor has invalid shape");
0096 if (!fDimShape[0].isParam) {
0097 fDimOutputShape.resize(fDimShape[0].dim);
0098 for (size_t i = 0; i < fDimShape[0].dim; i++) {
0099 fDimOutputShape[i] = Dim{ std::string("s_") + fNY + "_" + std::to_string(i)};
0100 }
0101 }
0102 else {
0103 throw std::runtime_error("TMVA SOFIE ConstantOfShape Op Input Tensor has not defied shape");
0104 }
0105 }
0106
0107 } else {
0108
0109
0110 fIsOutputConstant = true;
0111 length = ConvertShapeToLength(fShape);
0112 if (length != fValues.size())
0113 throw std::runtime_error("TMVA SOFIE Constant Op has invalid shape : " + ConvertShapeToString(fShape) +
0114 " with " + std::to_string(fValues.size()) + " values");
0115 }
0116
0117
0118
0119
0120
0121 if (fIsOutputConstant) {
0122 model.AddConstantTensor(fNY, fShape, fValues);
0123 if (model.Verbose()) {
0124 std::cout << "adding constant tensor " << fNY << " with shape " << ConvertShapeToString(fShape)
0125 << " and values [";
0126 for (auto v : fValues) std::cout << " " << v;
0127 std::cout << "]" << std::endl;
0128 }
0129 } else {
0130 model.AddIntermediateTensor(fNY, ConvertStringToType(TensorType<T>::Name()), fDimOutputShape);
0131 }
0132 }
0133
0134 std::string Generate(std::string opName) override {
0135
0136 std::stringstream out;
0137 if (fIsOutputConstant) {
0138 if (fNX.empty())
0139 out << "// ---- Constant (no-op) " << opName << " --> " << ConvertShapeToString(fDimOutputShape) << "\n";
0140 else
0141 out << "// ---- ConstantOfShape (no-op) " << opName << " --> " << ConvertShapeToString(fDimOutputShape) << "\n";
0142 return out.str();
0143 }
0144
0145
0146
0147 out << "\n//--------- ConstantOfShape " << opName << " --> " << ConvertShapeToString(fDimOutputShape) << "\n";
0148
0149 if (fIsUndefinedInputShape) {
0150 for (size_t i = 0; i < fDimOutputShape.size(); i++) {
0151 out << SP << "size_t " << fDimOutputShape[i].param << " = " << "tensor_" << fNX << "[" << i << "];\n";
0152 }
0153 }
0154 auto length = ConvertDimShapeToLength(fDimOutputShape);
0155
0156 out << SP << "if (" << length << " > fTensor_" << fNY << ".size())\n";
0157 out << SP << SP << "fTensor_" << fNY << ".resize(" << length << ");\n";
0158 out << SP << "std::fill(fTensor_" << fNY << ".begin(), fTensor_" << fNY << ".end(), " << fValues[0] << ");\n";
0159 return out.str();
0160 }
0161 };
0162
0163 }
0164 }
0165 }
0166
0167
0168 #endif