Warning, file /include/root/TMVA/ROperator_Constant.hxx was not indexed
or was modified since last indexation (in which case cross-reference links may be missing, inaccurate or erroneous).
0001 #ifndef TMVA_SOFIE_ROPERATOR_Constant
0002 #define TMVA_SOFIE_ROPERATOR_Constant
0003
0004 #include "TMVA/SOFIE_common.hxx"
0005 #include "TMVA/ROperator.hxx"
0006 #include "TMVA/RModel.hxx"
0007
0008 #include <sstream>
0009
0010 namespace TMVA{
0011 namespace Experimental{
0012 namespace SOFIE{
0013
0014 template<typename T>
0015 class ROperator_Constant final : public ROperator
0016 {
0017
0018 private:
0019
0020 std::string fNX;
0021 std::string fNY;
0022 std::vector<size_t> fShape;
0023 std::vector<T> fValues;
0024 std::string fAttrType;
0025 bool fIsConstantOfShape = false;
0026
0027 public:
0028 ROperator_Constant(){}
0029
0030 ROperator_Constant(const std::string & type, const std::vector<T> & values, const std::vector<size_t> & shape, std::string nameX, std::string nameY):
0031 fNX(UTILITY::Clean_name(nameX)),
0032 fNY(UTILITY::Clean_name(nameY)),
0033 fShape(shape),
0034 fValues(values),
0035 fAttrType(type)
0036 {
0037 fInputTensorNames = { };
0038 fOutputTensorNames = { };
0039 }
0040
0041 std::vector<ETensorType> TypeInference(std::vector<ETensorType> input) override {
0042 return input;
0043 }
0044
0045 std::vector<std::vector<size_t>> ShapeInference(std::vector<std::vector<size_t>> input) override {
0046 auto ret = input;
0047 return ret;
0048 }
0049
0050 void Initialize(RModel& model) override {
0051
0052 size_t length = 1;
0053 if (!fNX.empty()) {
0054
0055 fIsConstantOfShape = true;
0056 if (model.CheckIfTensorAlreadyExist(fNX) == false){
0057 throw std::runtime_error("TMVA SOFIE ConstantOfShape Op Input Tensor is not found in model");
0058 }
0059
0060
0061 auto dptr = model.GetInitializedTensorData(fNX);
0062 auto input_tensor = static_cast<int64_t *>(dptr.get());
0063 auto input_shape = model.GetTensorShape(fNX);
0064 if (input_shape.size() > 1 )
0065 throw std::runtime_error("TMVA SOFIE ConstantOfShape Op Input Tensor has invalid shape");
0066 if (input_tensor != nullptr && !input_shape.empty()) {
0067 fShape = std::vector<size_t> (input_shape[0]);
0068 for (size_t i = 0; i < fShape.size(); i++)
0069 fShape[i] = input_tensor[i];
0070 } else
0071 fShape = {1};
0072
0073 length = ConvertShapeToLength(fShape);
0074 if (fValues.size() != 1)
0075 throw std::runtime_error("TMVA SOFIE ConstantOfShape Op value Tensor has invalid size " + std::to_string(fValues.size()));
0076
0077 T value = fValues[0];
0078 fValues = std::vector<T>(length, value);
0079
0080 } else {
0081
0082
0083 length = ConvertShapeToLength(fShape);
0084 if (length != fValues.size())
0085 throw std::runtime_error("TMVA SOFIE Constant Op has invalid shape : " + ConvertShapeToString(fShape) +
0086 " with " + std::to_string(fValues.size()) + " values");
0087 }
0088
0089
0090
0091
0092
0093 model.AddConstantTensor(fNY, fShape, fValues);
0094 if (model.Verbose()) {
0095 std::cout << "adding constant tensor " << fNY << " with shape " << ConvertShapeToString(fShape)
0096 << " and values [";
0097 for (auto v : fValues) std::cout << " " << v;
0098 std::cout << "]" << std::endl;
0099 }
0100 }
0101
0102 std::string Generate(std::string ) override {
0103
0104 return "//---------------------------------------\n";
0105 }
0106 };
0107
0108 }
0109 }
0110 }
0111
0112
0113 #endif