// NOTE (extraction artifact): file /include/root/TMVA/ROperator_BatchNormalization.hxx was not indexed
// or was modified since last indexation (in which case cross-reference links may be missing, inaccurate or erroneous).
0001 #ifndef TMVA_SOFIE_ROPERATOR_BatchNormalization
0002 #define TMVA_SOFIE_ROPERATOR_BatchNormalization
0003
0004 #include "SOFIE_common.hxx"
0005 #include "ROperator.hxx"
0006 #include "RModel.hxx"
0007
0008
0009 #include <cmath>
0010 #include <sstream>
0011
0012 namespace TMVA{
0013 namespace Experimental{
0014 namespace SOFIE{
0015
/// \brief SOFIE code-generation operator for ONNX BatchNormalization (inference only).
///
/// At initialization time the per-channel `scale / sqrt(var + epsilon)` factor is
/// pre-computed ("fused") into a new initialized tensor, so the generated inference
/// code only performs `(x - mean) * fused_scale + bias` per element, optionally
/// followed by a fused ReLU.
///
/// \tparam T element type of the tensors; only `float` is supported (the
///           constructor throws for any other type).
template <typename T>
class ROperator_BatchNormalization final : public ROperator
{

private:

   // ONNX BatchNormalization attributes.
   float fepsilon = 1e-05;          // stability term added to the variance before the sqrt
   float fmomentum = 0.9;           // ONNX `momentum` attribute; stored but not read by the code generated here
   std::size_t ftraining_mode = 0;  // ONNX `training_mode` attribute; stored but not read (inference-only generation)

   // Cleaned tensor names: X (input), Scale, B (bias), Mean, Var (running stats), Y (output).
   std::string fNX;
   std::string fNScale;
   std::string fNB;
   std::string fNMean;
   std::string fNVar;
   std::string fNY;
   EActivationType fActivation;     // optional activation fused into the generated loop (currently only RELU is special-cased)
   std::string fNFusedScale;        // name of the pre-computed scale/sqrt(var+eps) tensor added in Initialize()

   std::vector<Dim> fShapeX;        // shape of the input tensor X (may contain symbolic dims)
   std::vector<Dim> fShapeY;        // shape of the output tensor Y (same as X)

   std::string fType;               // textual element type ("float"); set in the constructor

public:
   ROperator_BatchNormalization() = delete;

   /// Construct the operator from ONNX attributes and tensor names.
   ///
   /// Tensor names are sanitized with UTILITY::Clean_name. The fused-scale tensor
   /// name is derived from the scale tensor name. Throws std::runtime_error if T
   /// is not float.
   ///
   /// \param epsilon        stability term added to the variance
   /// \param momentum       running-statistics momentum (unused at inference)
   /// \param training_mode  ONNX training_mode attribute (unused at inference)
   /// \param nameX,nameScale,nameB,nameMean,nameVar  input tensor names
   /// \param nameY          output tensor name
   /// \param activation     activation to fuse into the generated code (default: none)
   ROperator_BatchNormalization( float epsilon, float momentum, std::size_t training_mode,
   std::string nameX, std::string nameScale, std::string nameB,
   std::string nameMean, std::string nameVar, std::string nameY, EActivationType activation=EActivationType::UNDEFINED):
   fepsilon(epsilon), fmomentum(momentum), ftraining_mode(training_mode),
   fNX(UTILITY::Clean_name(nameX)), fNScale(UTILITY::Clean_name(nameScale)),
   fNB(UTILITY::Clean_name(nameB)), fNMean(UTILITY::Clean_name(nameMean)),
   fNVar(UTILITY::Clean_name(nameVar)), fNY(UTILITY::Clean_name(nameY)), fActivation(activation)
   {
      // Only X is registered as a graph input here; scale/B/mean/var are expected
      // to be initialized (weight) tensors, validated later in Initialize().
      fInputTensorNames = { fNX };
      fOutputTensorNames = { fNY };
      fNFusedScale = fNScale + "_fused_inv_std_dev";

      if(std::is_same<T, float>::value){
         fType = "float";
      }
      else{
         throw
            std::runtime_error("TMVA SOFIE Encountered unsupported type parsing a BatchNormalization operator");
      }
   }

   /// The output tensor has the same element type as the first input.
   std::vector<ETensorType> TypeInference(std::vector<ETensorType> input) override {
      ETensorType out = input[0];
      return {out};
   }

   /// Shape inference: requires exactly 5 input shapes, each of rank 4, and
   /// returns them unchanged.
   /// NOTE(review): this demands rank 4 for every input (including the 1-D
   /// per-channel scale/bias/mean/var tensors), while Initialize() accepts an
   /// X of rank 2-4 — the two checks are not consistent; confirm intended use.
   std::vector<std::vector<size_t>> ShapeInference(std::vector<std::vector<size_t>> input) override {
      if (input.size() != 5 ) {
         throw
            std::runtime_error("TMVA SOFIE BatchNormalization Op Shape inference need 5 input tensors");
      }
      for(size_t i = 0; i < input.size(); i++) {
         if (input[i].size() != 4) {
            throw
               std::runtime_error("TMVA SOFIE BatchNormalization Op Shape inference only accept tensor with 4 dimensions");
         }
      }

      auto ret = input;
      return ret;
   }

   /// Validate the five input tensors against the model, register the output
   /// tensor, and pre-compute the fused per-channel scale:
   ///   fused_scale[c] = scale[c] / sqrt(var[c] + epsilon)
   /// which is added to the model as a new initialized tensor (fNFusedScale).
   ///
   /// Throws std::runtime_error if any input tensor is missing, if X does not
   /// have rank 2-4, or if the scale tensor is not 1-D.
   void Initialize(RModel& model) override {
      if (!model.CheckIfTensorAlreadyExist(fNX)) {
         throw
            std::runtime_error("TMVA SOFIE BatchNormalization op Input Tensor " + fNX + " fnx is not found in model");
      }
      if (!model.CheckIfTensorAlreadyExist(fNScale)) {
         throw
            std::runtime_error("TMVA SOFIE BatchNormalization op Input Tensor " + fNScale + " fns is not found in model");
      }
      if (!model.CheckIfTensorAlreadyExist(fNB)) {
         throw
            std::runtime_error("TMVA SOFIE BatchNormalization op Input Tensor " + fNB + " fnb is not found in model");
      }
      if (!model.CheckIfTensorAlreadyExist(fNMean)) {
         throw
            std::runtime_error("TMVA SOFIE BatchNormalization op Input Tensor " + fNMean + " fnm is not found in model");
      }
      if (!model.CheckIfTensorAlreadyExist(fNVar)) {
         throw
            std::runtime_error("TMVA SOFIE BatchNormalization op Input Tensor " + fNVar + " fnv is not found in model");
      }

      fShapeX = model.GetDimTensorShape(fNX);

      // X must be (N, C) up to (N, C, H, W); Generate() reads dims [0] and [1].
      if (fShapeX.size() < 2 || fShapeX.size() > 4) {
         throw
            std::runtime_error("TMVA SOFIE BatchNormalization Op input tensor " + fNX + " fnx has wrong shape : " + ConvertShapeToString(fShapeX));
      }

      // Output has the same shape and element type as X.
      fShapeY = fShapeX;
      model.AddIntermediateTensor(fNY, model.GetTensorType(fNX), fShapeY);

      // Scale and variance must be initialized (constant) tensors so the fused
      // factor can be computed here at model-build time.
      auto original_S = model.GetInitializedTensorData(fNScale);
      auto original_V = model.GetInitializedTensorData(fNVar);

      auto shape_S = model.GetTensorShape(fNScale);
      if (shape_S.size() != 1) {
         throw std::runtime_error("TMVA SOFIE BatchNormalization 'scale' tensor must be 1D (per-channel).");
      }
      size_t channels = shape_S[0];

      // Only the float branch exists; the constructor guarantees fType == "float",
      // so fNFusedScale is always registered before Generate() runs.
      if (fType == "float") {
         float *original_scale_ptr = static_cast<float *>(original_S.get());
         float *original_var_ptr = static_cast<float *>(original_V.get());
         float *fused_scale_data = new float[channels];

         for (size_t i = 0; i < channels; i++) {
            // Fold the 1/sqrt(var+eps) normalization into the scale once per channel.
            fused_scale_data[i] = original_scale_ptr[i] / std::sqrt(original_var_ptr[i] + fepsilon);
         }

         // Ownership of the raw array is transferred to the model via shared_ptr
         // with an array deleter (matches the new[] above).
         std::shared_ptr<void> fused_scale_ptr(fused_scale_data, std::default_delete<float[]>());
         model.AddInitializedTensor(fNFusedScale, model.GetTensorType(fNScale), {channels}, fused_scale_ptr);
      }
   }

   /// Emit the inference code for this operator as a C++ source fragment.
   ///
   /// The generated code iterates over batch (n), channel (c) and the flattened
   /// spatial dimensions, computing
   ///   y[i] = (x[i] - mean[c]) * fused_scale[c] + bias[c]
   /// with an optional fused ReLU clamp when fActivation == RELU.
   ///
   /// \param OpName operator instance name, prefixed with "op_" in the output
   /// \throws std::runtime_error if called before Initialize()
   std::string Generate(std::string OpName) override {
      OpName = "op_" + OpName;
      if (fShapeX.empty()){
         throw std::runtime_error("TMVA SOFIE Batch Normalization called to Generate without being initialized first");
      }

      std::stringstream out;

      auto batchSize = fShapeX[0].GetVal();
      auto channels = fShapeX[1].GetVal();
      // Flatten all dims after (N, C) into one spatial extent; "1" when X is 2-D.
      // spatial_dim is a string because dims may be symbolic (runtime expressions).
      std::string spatial_dim = "1";
      if (fShapeX.size() > 2) {
         auto spatialShape = fShapeX;
         spatialShape.erase(spatialShape.begin(), spatialShape.begin()+2);
         spatial_dim = ConvertDimShapeToLength( spatialShape);
      }

      out << "\n\n//---- BatchNorm" << (fActivation == EActivationType::RELU ? " + ReLU" : "") << "\n";
      out << SP << "{\n";
      // Single running index i assumes NCHW-contiguous layout of X and Y.
      out << SP << " size_t i = 0;\n";
      out << SP << " for (size_t n = 0; n < " << batchSize << "; ++n) {\n";
      out << SP << " for (size_t c = 0; c < " << channels << "; ++c) {\n";
      // Per-channel constants hoisted out of the spatial loop.
      out << SP << " const float mean_val = tensor_" << fNMean << "[c];\n";
      out << SP << " const float fused_scale_val = tensor_" << fNFusedScale << "[c];\n";
      out << SP << " const float bias_val = tensor_" << fNB << "[c];\n";
      out << SP << " for (size_t sp = 0; sp < " << spatial_dim << "; ++sp) {\n";
      out << SP << " float val = (tensor_" << fNX << "[i] - mean_val) * fused_scale_val + bias_val;\n";

      if (fActivation == EActivationType::RELU) {
         out << SP << " tensor_" << fNY << "[i] = (val > 0.0f) ? val : 0.0f;\n";
      } else {
         out << SP << " tensor_" << fNY << "[i] = val;\n";
      }
      out << SP << " i++;\n";
      out << SP << " }\n";
      out << SP << " }\n";
      out << SP << " }\n";
      out << SP << "}\n";

      return out.str();
   }

   /// No BLAS routines are required by the generated code.
   std::vector<std::string> GetBlasRoutines() override { return {}; }
};
0188
0189 }
0190 }
0191 }
0192
0193
0194 #endif