// File: /include/root/TMVA/ROperator_BatchNormalization.hxx

#ifndef TMVA_SOFIE_ROPERATOR_BatchNormalization
#define TMVA_SOFIE_ROPERATOR_BatchNormalization

#include "SOFIE_common.hxx"
#include "ROperator.hxx"
#include "RModel.hxx"


#include <cmath>
#include <sstream>

namespace TMVA{
namespace Experimental{
namespace SOFIE{

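// Inference-time BatchNormalization operator for SOFIE code generation.
// It normalizes the input per channel using the stored mean, variance, scale
// and bias tensors, and can optionally fuse a ReLU activation into the output.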
template <typename T>
class ROperator_BatchNormalization final : public ROperator
{

private:

   /* Attributes */
   float fepsilon = 1e-05;
   float fmomentum = 0.9;
   std::size_t ftraining_mode = 0;

   /* Tensor names: input X, scale, bias B, running mean, running variance, output Y */
   std::string fNX;
   std::string fNScale;
   std::string fNB;
   std::string fNMean;
   std::string fNVar;
   std::string fNY;
   EActivationType fActivation;   // optional activation fused into the output (e.g. ReLU)
   std::string fNFusedScale;      // name of the precomputed scale / sqrt(var + epsilon) tensor

   std::vector<Dim> fShapeX;
   std::vector<Dim> fShapeY;

   std::string fType;

public:
   ROperator_BatchNormalization() = delete;

   /* Constructor */
   ROperator_BatchNormalization( float epsilon, float momentum, std::size_t training_mode,
   std::string nameX, std::string nameScale, std::string nameB,
   std::string nameMean, std::string nameVar, std::string nameY, EActivationType activation=EActivationType::UNDEFINED):
   fepsilon(epsilon), fmomentum(momentum), ftraining_mode(training_mode),
   fNX(UTILITY::Clean_name(nameX)), fNScale(UTILITY::Clean_name(nameScale)),
   fNB(UTILITY::Clean_name(nameB)), fNMean(UTILITY::Clean_name(nameMean)),
   fNVar(UTILITY::Clean_name(nameVar)), fNY(UTILITY::Clean_name(nameY)), fActivation(activation)
   {
      fInputTensorNames = { fNX };
      fOutputTensorNames = { fNY };
      fNFusedScale = fNScale + "_fused_inv_std_dev";

      if(std::is_same<T, float>::value){
         fType = "float";
      }
      else{
         throw
            std::runtime_error("TMVA SOFIE Encountered unsupported type parsing a BatchNormalization operator");
      }
   }


   std::vector<ETensorType> TypeInference(std::vector<ETensorType> input) override {
      ETensorType out = input[0];
      return {out};
   }

   std::vector<std::vector<size_t>> ShapeInference(std::vector<std::vector<size_t>> input) override {
      if (input.size() != 5) {
         throw
            std::runtime_error("TMVA SOFIE BatchNormalization Op Shape inference needs 5 input tensors");
      }
      for(size_t i = 0; i < input.size(); i++) {
         if (input[i].size() != 4) {
            throw
               std::runtime_error("TMVA SOFIE BatchNormalization Op Shape inference only accepts tensors with 4 dimensions");
         }
      }

      auto ret = input;
      return ret;
   }

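   // Checks that the five input tensors (X, scale, B, mean, var) exist in the
   // model, registers the output tensor, and precomputes the per-channel factor
   // scale / sqrt(variance + epsilon) as a new initialized tensor so that the
   // generated inference code needs only one multiply-add per element.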
   void Initialize(RModel& model) override {
      if (!model.CheckIfTensorAlreadyExist(fNX)) {
         throw
            std::runtime_error("TMVA SOFIE BatchNormalization op Input Tensor " + fNX + " is not found in model");
      }
      if (!model.CheckIfTensorAlreadyExist(fNScale)) {
         throw
            std::runtime_error("TMVA SOFIE BatchNormalization op Input Tensor " + fNScale + " is not found in model");
      }
      if (!model.CheckIfTensorAlreadyExist(fNB)) {
         throw
            std::runtime_error("TMVA SOFIE BatchNormalization op Input Tensor " + fNB + " is not found in model");
      }
      if (!model.CheckIfTensorAlreadyExist(fNMean)) {
         throw
            std::runtime_error("TMVA SOFIE BatchNormalization op Input Tensor " + fNMean + " is not found in model");
      }
      if (!model.CheckIfTensorAlreadyExist(fNVar)) {
         throw
            std::runtime_error("TMVA SOFIE BatchNormalization op Input Tensor " + fNVar + " is not found in model");
      }

      fShapeX = model.GetDimTensorShape(fNX);

      if (fShapeX.size() < 2 || fShapeX.size() > 4) {
         throw
            std::runtime_error("TMVA SOFIE BatchNormalization Op input tensor " + fNX + " has wrong shape: " + ConvertShapeToString(fShapeX));
      }

      fShapeY = fShapeX;
      model.AddIntermediateTensor(fNY, model.GetTensorType(fNX), fShapeY);

      auto original_S = model.GetInitializedTensorData(fNScale);
      auto original_V = model.GetInitializedTensorData(fNVar);

      auto shape_S = model.GetTensorShape(fNScale);
      if (shape_S.size() != 1) {
         throw std::runtime_error("TMVA SOFIE BatchNormalization 'scale' tensor must be 1D (per-channel).");
      }
      size_t channels = shape_S[0];

      if (fType == "float") {
         float *original_scale_ptr = static_cast<float *>(original_S.get());
         float *original_var_ptr = static_cast<float *>(original_V.get());
         float *fused_scale_data = new float[channels];

         for (size_t i = 0; i < channels; i++) {
            // Calculate scale * (1 / sqrt(variance + epsilon))
            fused_scale_data[i] = original_scale_ptr[i] / std::sqrt(original_var_ptr[i] + fepsilon);
         }

         std::shared_ptr<void> fused_scale_ptr(fused_scale_data, std::default_delete<float[]>());
         model.AddInitializedTensor(fNFusedScale, model.GetTensorType(fNScale), {channels}, fused_scale_ptr);
      }
   }

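   // Generates the C++ inference code for this operator: nested loops over
   // batch, channel and spatial positions computing
   // y = (x - mean) * fused_scale + bias, with an optional fused ReLU.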
   std::string Generate(std::string OpName) override {
      OpName = "op_" + OpName;
      if (fShapeX.empty()){
         throw std::runtime_error("TMVA SOFIE Batch Normalization called to Generate without being initialized first");
      }

      std::stringstream out;
      //// Batch Norm op
      auto batchSize = fShapeX[0].GetVal();
      auto channels = fShapeX[1].GetVal();
      std::string spatial_dim = "1";
      if (fShapeX.size() > 2) {
         auto spatialShape = fShapeX;
         spatialShape.erase(spatialShape.begin(), spatialShape.begin()+2);
         spatial_dim = ConvertDimShapeToLength(spatialShape);
      }

      out << "\n\n//---- BatchNorm" << (fActivation == EActivationType::RELU ? " + ReLU" : "") << "\n";
      out << SP << "{\n";
      out << SP << "   size_t i = 0;\n";
      out << SP << "   for (size_t n = 0; n < " << batchSize << "; ++n) {\n";
      out << SP << "      for (size_t c = 0; c < " << channels << "; ++c) {\n";
      out << SP << "         const float mean_val = tensor_" << fNMean << "[c];\n";
      out << SP << "         const float fused_scale_val = tensor_" << fNFusedScale << "[c];\n";
      out << SP << "         const float bias_val = tensor_" << fNB << "[c];\n";
      out << SP << "         for (size_t sp = 0; sp < " << spatial_dim << "; ++sp) {\n";
      out << SP << "            float val = (tensor_" << fNX << "[i] - mean_val) * fused_scale_val + bias_val;\n";

      if (fActivation == EActivationType::RELU) {
         out << SP << "            tensor_" << fNY << "[i] = (val > 0.0f) ? val : 0.0f;\n";
      } else {
         out << SP << "            tensor_" << fNY << "[i] = val;\n";
      }
      out << SP << "            i++;\n";
      out << SP << "         }\n";
      out << SP << "      }\n";
      out << SP << "   }\n";
      out << SP << "}\n";

      return out.str();
   }

   std::vector<std::string> GetBlasRoutines() override { return {}; }
};

}//SOFIE
}//Experimental
}//TMVA


#endif //TMVA_SOFIE_ROPERATOR_BatchNormalization