Warning, file /include/root/TMVA/ROperator_Random.hxx was not indexed
or was modified since last indexation (in which case cross-reference links may be missing, inaccurate or erroneous).
0001 #ifndef TMVA_SOFIE_ROPERATOR_Random
0002 #define TMVA_SOFIE_ROPERATOR_Random
0003
#include "TMVA/SOFIE_common.hxx"
#include "TMVA/ROperator.hxx"
#include "TMVA/RModel.hxx"

#include <iostream>
#include <limits>
#include <map>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>
0009
0010 namespace TMVA{
0011 namespace Experimental{
0012 namespace SOFIE{
0013
0014
0015
/// Distribution sampled by ROperator_Random.
enum RandomOpMode {
   kUniform = 0, ///< uniform distribution (RandomUniform), parameters "low"/"high"
   kNormal = 1   ///< normal distribution (RandomNormal), parameters "mean"/"scale"
};
0017
0018 class ROperator_Random final : public ROperator
0019 {
0020 public:
0021
0022 bool fUseROOT = true;
0023 private:
0024
0025 RandomOpMode fMode;
0026 ETensorType fType;
0027 std::string fNX;
0028 std::string fNY;
0029 int fSeed;
0030 std::vector<size_t> fShapeY;
0031 std::map<std::string,float> fParams;
0032
0033
0034
0035 public:
0036
0037 ROperator_Random(){}
0038 ROperator_Random(RandomOpMode mode, ETensorType type, const std::string & nameX, const std::string & nameY, const std::vector<size_t> & shape, const std::map<std::string, float> & params, float seed) :
0039 fMode(mode),
0040 fType(type),
0041 fNX(UTILITY::Clean_name(nameX)),
0042 fNY(UTILITY::Clean_name(nameY)),
0043 fSeed(seed),
0044 fShapeY(shape),
0045 fParams(params)
0046 {
0047 fInputTensorNames = { };
0048 fOutputTensorNames = { fNY };
0049 }
0050
0051
0052 std::vector<ETensorType> TypeInference(std::vector<ETensorType> input) override {
0053 return input;
0054 }
0055
0056 std::vector<std::vector<size_t>> ShapeInference(std::vector<std::vector<size_t>> input) override {
0057 auto ret = input;
0058 return ret;
0059 }
0060
0061 void Initialize(RModel& model) override {
0062
0063 model.AddIntermediateTensor(fNY, fType, fShapeY);
0064
0065 if (fUseROOT) {
0066 model.AddNeededCustomHeader("TRandom3.h");
0067 }
0068
0069
0070 if (fMode == kNormal) {
0071 if (fParams.count("mean") == 0 )
0072 fParams["mean"] = 0;
0073 if (fParams.count("scale") == 0)
0074 fParams["scale"] = 1;
0075 }
0076 if (fMode == kUniform) {
0077 if (fParams.count("low") == 0)
0078 fParams["low"] = 0;
0079 if (fParams.count("high") == 0)
0080 fParams["high"] = 1;
0081 }
0082
0083 if (model.Verbose()) {
0084 std::cout << "Random";
0085 if (fMode == kNormal) std::cout << "Normal";
0086 else if (fMode == kUniform) std::cout << "Uniform";
0087 std::cout << " op -> " << fNY << " : " << ConvertShapeToString(fShapeY) << std::endl;
0088 for (auto & p : fParams)
0089 std::cout << p.first << " : " << p.second << std::endl;
0090 }
0091 }
0092
0093 std::string GenerateDeclCode() override {
0094 std::stringstream out;
0095 out << "std::unique_ptr<TRandom> fRndmEngine; // random number engine\n";
0096 return out.str();
0097 }
0098
0099 std::string GenerateInitCode() override {
0100 std::stringstream out;
0101 out << "//--- creating random number generator ----\n";
0102 if (fUseROOT) {
0103
0104 out << SP << "fRndmEngine.reset(new TRandom3(" << fSeed << "));\n";
0105 }
0106 else {
0107
0108 }
0109 return out.str();
0110 }
0111 std::string Generate(std::string OpName) override {
0112 OpName = "op_" + OpName;
0113
0114 std::stringstream out;
0115 out << "\n//------ Random";
0116 if (fMode == kNormal) out << "Normal\n";
0117 else if (fMode == kUniform) out << "Uniform\n";
0118
0119
0120 int length = ConvertShapeToLength(fShapeY);
0121 out << SP << "for (int i = 0; i < " << length << "; i++) {\n";
0122 if (fUseROOT) {
0123 if (fMode == kNormal) {
0124 if (fParams.count("mean") == 0 || fParams.count("scale") == 0)
0125 throw std::runtime_error("TMVA SOFIE RandomNormal op : no mean or scale are defined");
0126 float mean = fParams["mean"];
0127 float scale = fParams["scale"];
0128 out << SP << SP << "tensor_" << fNY << "[i] = fRndmEngine->Gaus(" << mean << "," << scale << ");\n";
0129 } else if (fMode == kUniform) {
0130 if (fParams.count("high") == 0 || fParams.count("low") == 0)
0131 throw std::runtime_error("TMVA SOFIE RandomUniform op : no low or high are defined");
0132 float high = fParams["high"];
0133 float low = fParams["low"];
0134 out << SP << SP << "tensor_" << fNY << "[i] = fRndmEngine->Uniform(" << low << "," << high << ");\n";
0135 }
0136 }
0137 out << SP << "}\n";
0138
0139 return out.str();
0140 }
0141
0142 std::vector<std::string> GetStdLibs() override {
0143 std::vector<std::string> ret = {"memory"};
0144 return ret;
0145 }
0146
0147 };
0148
0149 }
0150 }
0151 }
0152
0153
0154 #endif