Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-09-17 09:14:54

0001 #ifndef TMVA_SOFIE_ROPERATOR_LAYERNORMALIZATION
0002 #define TMVA_SOFIE_ROPERATOR_LAYERNORMALIZATION
0003 
0004 #include "TMVA/RModel.hxx"
0005 #include "TMVA/SOFIE_common.hxx"
0006 
0007 #include <sstream>
0008 #include <string>
0009 
0010 namespace TMVA {
0011 namespace Experimental {
0012 namespace SOFIE {
0013 
0014 template <typename T>
0015 class ROperator_LayerNormalization : public ROperator {
0016 private:
0017    int fAttrAxis;
0018    float fAttrEpsilon;
0019    size_t fAttrStashType;
0020 
0021    std::string fNX;
0022    std::string fNScale;
0023    std::string fNB;
0024    std::string fNY;
0025    std::string fNMean;
0026    std::string fNInvStdDev;
0027 
0028    std::string fNCastedX;
0029    std::string fNNormalizedX;
0030    std::string fNBroadcastedB;
0031 
0032    std::vector<Dim> fShapeX;
0033    std::vector<Dim> fShapeScale;
0034    std::vector<size_t> fShapeB;  // shape of input Bias (B) is assumed to be fully defined
0035    std::vector<Dim> fShapeY;
0036    std::vector<Dim> fShapeMean;
0037    std::vector<Dim> fShapeInvStdDev;
0038 
0039    size_t fAxis; // axis in [0, size)
0040    size_t fSize; // Size of the input
0041    // size_t fAxisDim;
0042 
0043    std::vector<Dim> fNormalizedShape;
0044    std::vector<Dim> fAxesShape;
0045    // lengths in string format
0046    std::string fLength; // Length of the input
0047    std::string fNormalizedLength;
0048    std::string fAxesLength;
0049 
0050    std::string fType;
0051 
0052 public:
0053    ROperator_LayerNormalization() {}
0054 
0055    ROperator_LayerNormalization(int axis, float epsilon, size_t stashType, const std::string &nameX,
0056                                 const std::string &nameScale, const std::string &nameB, const std::string &nameY,
0057                                 const std::string &nameMean, const std::string &nameInvStdDev)
0058       : fAttrAxis(axis), fAttrEpsilon(epsilon), fAttrStashType(stashType), fNX(UTILITY::Clean_name(nameX)),
0059         fNScale(UTILITY::Clean_name(nameScale)), fNB(UTILITY::Clean_name(nameB)),
0060         fNY(UTILITY::Clean_name(nameY)), fNMean(UTILITY::Clean_name(nameMean)), fNInvStdDev(UTILITY::Clean_name(nameInvStdDev))
0061    {
0062          fInputTensorNames = { fNX, fNScale };
0063          if (!fNB.empty()){
0064             fInputTensorNames.emplace_back(fNB);
0065          }
0066 
0067          fOutputTensorNames = { fNY };
0068          if (!fNMean.empty()){
0069             fOutputTensorNames.emplace_back(fNMean);
0070          }
0071          if (!fNInvStdDev.empty()){
0072             fOutputTensorNames.emplace_back(fNInvStdDev);
0073          }
0074    }
0075 
0076    std::vector<std::vector<size_t>> ShapeInference(std::vector<std::vector<size_t>> input) override { return input; }
0077 
0078    std::vector<ETensorType> TypeInference(std::vector<ETensorType> input) override { return input; }
0079 
0080    void Initialize(RModel& model) override {
0081       if (!model.CheckIfTensorAlreadyExist(fNX)) {
0082          throw std::runtime_error("TMVA::SOFIE - Tensor " + fNX + " not found.");
0083       }
0084       bool isDynamic = model.IsDynamicTensor(fNX);
0085       fShapeX = model.GetDynamicTensorShape(fNX);
0086       fShapeY = fShapeX;
0087       model.AddIntermediateTensor(fNY, model.GetTensorType(fNX), fShapeY);
0088       // Type of the output
0089       fType = ConvertTypeToString(model.GetTensorType(fNX));
0090       // Size of the input
0091       fSize = fShapeX.size();
0092       // Axis in [0, size)
0093       fAxis = (fAttrAxis < 0) ? fSize + fAttrAxis : fAttrAxis;
0094       // Shape of fShapeX[0, ..., fAxis)
0095       fAxesShape = std::vector<Dim>(fShapeX.begin(), fShapeX.begin() + fAxis);
0096       // Length of the axes
0097       fAxesLength = ConvertDynamicShapeToLength(fAxesShape);
0098       // Shape of fShapeX[fAxis, ..., fSize)
0099       fNormalizedShape = std::vector<Dim>(fShapeX.begin() + fAxis, fShapeX.end());
0100       // Length of the normalized axis
0101       fNormalizedLength = ConvertDynamicShapeToLength(fNormalizedShape);
0102       // length of the input
0103       fLength = ConvertDynamicShapeToLength(fShapeX);
0104       // Type of mean and std
0105       ETensorType type = (fAttrStashType == 1) ? ETensorType::FLOAT : model.GetTensorType(fNX);
0106       // Mean
0107       if (fNMean.empty()) {
0108          fNMean = "Mean" + fNX;
0109          // cannot use initializer list with one element since it is ambiguous
0110          if (isDynamic)
0111             // add size_t(-1) to indicate that shape is an expression
0112             model.AddIntermediateTensor(fNMean, type, std::vector<Dim>(1,Dim{fAxesLength,std::size_t(-1)}));
0113          else
0114             model.AddIntermediateTensor(fNMean, type, std::vector<size_t>(1,std::stoi(fAxesLength)));
0115       }
0116       // Inverse Standard Deviation
0117       if (fNInvStdDev.empty()) {
0118          fNInvStdDev = "InvStdDev" + fNX;
0119          if (isDynamic)
0120             model.AddIntermediateTensor(fNInvStdDev, type, std::vector<Dim>(1,Dim{fAxesLength,std::size_t(-1)}));
0121          else
0122             model.AddIntermediateTensor(fNInvStdDev, type, std::vector<size_t>(1,std::stoi(fAxesLength)));
0123       }
0124       // Cast X to float
0125       if (fAttrStashType == 1 && model.GetTensorType(fNX) != ETensorType::FLOAT) {
0126          fNCastedX = "Casted" + fNX;
0127          model.AddIntermediateTensor(fNCastedX, ETensorType::FLOAT, fShapeX);
0128          fNNormalizedX = "Normalized" + fNX;
0129          model.AddIntermediateTensor(fNNormalizedX, ETensorType::FLOAT, fShapeX);
0130       }
0131       // Broadcast the bias
0132       if (!fNB.empty()) {
0133          fShapeB = model.GetTensorShape(fNB);
0134          size_t lengthB = ConvertShapeToLength(fShapeB);
0135          if (isDynamic || lengthB < static_cast<size_t>(std::stoi(fLength))) {
0136             fNBroadcastedB = "Broadcasted" + fNB;
0137             model.AddIntermediateTensor(fNBroadcastedB, ConvertStringToType(fType), fShapeX);
0138          }
0139       }
0140       model.AddNeededStdLib("cmath");
0141    }
0142 
0143    std::string GenerateInitCode() override
0144    {
0145       std::stringstream out;
0146       if (!fNBroadcastedB.empty()) {
0147          out << SP << "// Broadcasting the bias of LayerNormalization op\n";
0148          out << SP << "{\n";
0149          out << SP << SP << "float* data = TMVA::Experimental::SOFIE::UTILITY::UnidirectionalBroadcast<float>(tensor_";
0150          out << fNB << ", " << ConvertShapeToString(fShapeB) << ", " << ConvertDynamicShapeToString(fShapeX) << ");\n";
0151          out << SP << "std::copy(data, data + " << fLength << ", tensor_" << fNBroadcastedB << ");\n";
0152          out << SP << "delete[] data;\n";
0153          out << SP << "}\n";
0154       }
0155       return out.str();
0156    }
0157 
0158    std::string Generate(std::string opName) override
0159    {
0160       opName = "op_" + opName;
0161       if (fShapeX.empty()) {
0162          throw std::runtime_error("TMVA::SOFIE LayerNormalization operator " + opName +
0163                                   " called to generate without being initialized first.");
0164       }
0165       if (fShapeX.size() > 5) {
0166          throw std::runtime_error("TMVA::SOFIE LayerNormalization operator not "
0167                                   "implemented for input tensor of size > 5.");
0168       }
0169 
0170       std::stringstream out;
0171 
0172       out << "//---- Layer Normalization  operator " << opName << "\n";
0173 
0174       // Loop over all the normalized axes i.e. [axis, ..., size)
0175       std::vector<std::string> inputShape(fSize);
0176 
0177       for (size_t i = 0; i < fSize; i++) {
0178          inputShape[i] = fShapeX[i].GetVal();
0179       }
0180 
0181       auto strides = UTILITY::ComputeStrideFromShape(fShapeX);
0182       std::string InputIndex = "axis_0 * " + strides[0].GetVal();
0183       for (size_t i = 1; i < fSize; i++) {
0184          InputIndex += " + axis_" + std::to_string(i) + " * " + strides[i].GetVal();
0185       }
0186 
0187       auto axesStrides = UTILITY::ComputeStrideFromShape(fAxesShape);
0188       std::string axesIndex = "axis_" + std::to_string(0) + " * " + axesStrides[0].GetVal();
0189       for (size_t i = 1; i < fAxis; i++) {
0190          axesIndex += " + axis_" + std::to_string(i) + " * " + axesStrides[i].GetVal();
0191       }
0192 
0193       auto normalizedStrides = UTILITY::ComputeStrideFromShape(fNormalizedShape);
0194       std::string normalizedIndex = "axis_" + std::to_string(fAxis) + " * " + normalizedStrides[0].GetVal();
0195       for (size_t i = fAxis + 1; i < fSize; i++) {
0196          normalizedIndex += " + axis_" + std::to_string(i) + " * " + normalizedStrides[i - fAxis].GetVal();
0197       }
0198 
0199       if (!fNCastedX.empty()) {
0200          // Cast X to float
0201          out << SP << "for (size_t i = 0; i < " << fLength << "; i++) {\n";
0202          out << SP << SP << "tensor_" << fNCastedX << "[i] = " << "static_cast<float>(tensor_" << fNX;
0203          out << "[i]);\n";
0204          out << SP << "}\n";
0205       }
0206 
0207       out << SP << "// Compute the mean\n";
0208       // Loop over the normalized dimensions
0209       for (size_t i = 0; i < fAxis; i++) {
0210          std::string iIdx = "axis_" + std::to_string(i);
0211          out << SP << "for (size_t " << iIdx << " = 0; " << iIdx << " < " << inputShape[i]
0212                       << "; " << iIdx << "++) {\n";
0213       }
0214       out << SP << SP << fType << " sum = 0.;\n";
0215       // loop over all the dims in [0, fAxis)
0216       for (size_t j = fAxis; j < fSize; j++) {
0217          std::string jIdx = "axis_" + std::to_string(j);
0218          out << SP << SP << "for (size_t " << jIdx << " = 0; " << jIdx << " < " << inputShape[j]
0219                          << "; " << jIdx << "++) {\n";
0220       }
0221       out << SP << SP << SP << "sum += tensor_" << fNX << "[" << InputIndex << "];\n";
0222       for (size_t j = fAxis; j < fSize; j++) {
0223          out << SP << SP << "}\n";
0224       }
0225       out << SP << SP << "tensor_" << fNMean << "[" << axesIndex << "] = sum / " << fType << "(";
0226       out << fNormalizedLength << ");\n";
0227       for (size_t i = fAxis; i < fSize; i++) {
0228          out << SP << "}\n";
0229       }
0230 
0231       out << SP << "// Compute the inverse Standard Deviation\n";
0232       // Loop over the normalized dimensions
0233       for (size_t i = 0; i < fAxis; i++) {
0234          std::string iIdx = "axis_" + std::to_string(i);
0235          out << SP << "for (size_t " << iIdx << " = 0; " << iIdx << " < " << inputShape[i]
0236                    << "; " << iIdx << "++){\n";
0237       }
0238       // Set sum = 0
0239       out << SP << SP << fType << " sum = 0.;\n";
0240       // loop over all the dims in [0, fAxis)
0241       for (size_t j = fAxis; j < fSize; j++) {
0242          std::string jIdx = "axis_" + std::to_string(j);
0243          out << SP << SP << "for (size_t " << jIdx << " = 0; " << jIdx << " < " << inputShape[j]
0244                           << "; " << jIdx << "++){\n";
0245       }
0246       out << SP << SP << SP << "float tmp = tensor_" << fNX << "[" << InputIndex << "] - tensor_"
0247                             << fNMean << "[" << axesIndex << "];\n";
0248       out << SP << SP << SP << "sum += tmp*tmp;\n";
0249       for (size_t j = fAxis; j < fSize; j++) {
0250          out << SP << SP << "}\n";
0251       }
0252       out << SP << SP << "tensor_" << fNInvStdDev << "[" << axesIndex << "] = 1 / std::sqrt(";
0253       out << "sum / " << fType << "(" << fNormalizedLength << ") + " << fAttrEpsilon << ");\n";
0254       for (size_t i = 0; i < fAxis; i++) {
0255          out << SP << "}\n";
0256       }
0257 
0258       if (!fNCastedX.empty()) {
0259          out << "// NormalizedX = InvStdDev * (CastedX - Mean)\n";
0260          for (size_t i = 0; i < fAxis; i++) {
0261             std::string iIdx = "axis_" + std::to_string(i);
0262             out << SP << "for (size_t " << iIdx << " = 0; " << iIdx << " < " << inputShape[i]
0263                           << "; " << iIdx << "++){\n";
0264          }
0265          for (size_t j = fAxis; j < fSize; j++) {
0266             std::string jIdx = "axis_" + std::to_string(j);
0267             out << SP << SP << "for (size_t " << jIdx << " = 0; " << jIdx << " < " << inputShape[j]
0268                              << "; " << jIdx << "++){\n";
0269          }
0270          out << SP << SP << SP << "tensor_" << fNNormalizedX << "[" << InputIndex << "] = tensor_";
0271          out << fNInvStdDev << "[" << axesIndex << "] * (tensor_" << fNCastedX << "[" << InputIndex;
0272          out << "] - tensor_" << fNMean << "[" << axesIndex << "])\n";
0273          for (size_t j = fAxis; j < fSize; j++) {
0274             out << SP << SP << "}\n";
0275          }
0276          for (size_t i = fAxis; i < fSize; i++) {
0277             out << SP << "}\n";
0278          }
0279          out << "// Y = Scale o NormalizedX";
0280          for (size_t i = 0; i < fAxis; i++) {
0281             std::string iIdx = "axis_" + std::to_string(i);
0282             out << SP << "for (size_t " << iIdx << " = 0; " << iIdx << " < " << inputShape[i]
0283                       << "; " << iIdx << "++){\n";
0284          }
0285          for (size_t j = fAxis; j < fSize; j++) {
0286             std::string jIdx = "axis_" + std::to_string(j);
0287             out << SP << SP << "for (size_t " << jIdx << " = 0; " << jIdx << " < " << inputShape[j]
0288                             << "; " << jIdx << "++){\n";
0289          }
0290          out << SP << SP << SP << "tensor_" << fNY << "[" << InputIndex << "] = tensor_" << fNScale;
0291          out << "[" << axesIndex << "] * static_cast<" << fType << ">(tensor_" << fNCastedX << "[" << InputIndex;
0292          out << "]);\n";
0293          for (size_t j = fAxis; j < fSize; j++) {
0294             out << SP << SP << "}\n";
0295          }
0296          for (size_t i = fAxis; i < fSize; i++) {
0297             out << SP << "}\n";
0298          }
0299       } else {
0300          out << SP << "// Y = Scale o InvStdDev (X - Mean)\n";
0301          for (size_t i = 0; i < fAxis; i++) {
0302             std::string iIdx = "axis_" + std::to_string(i);
0303             out << SP << "for (size_t " << iIdx << " = 0; " << iIdx << " < " << inputShape[i]
0304                          << "; " << iIdx << "++){\n";
0305          }
0306          for (size_t j = fAxis; j < fSize; j++) {
0307             std::string jIdx = "axis_" + std::to_string(j);
0308             out << SP << SP << "for (size_t " << jIdx << " = 0; " << jIdx << " < " << inputShape[j]
0309                            << "; " << jIdx << "++){\n";
0310          }
0311          out << SP << SP << SP << "tensor_" << fNY << "[" << InputIndex << "] = tensor_" << fNScale;
0312          out << "[" << normalizedIndex << "] * tensor_" << fNInvStdDev << "[" << axesIndex;
0313          out << "] * (tensor_" << fNX << "[" << InputIndex << "] - tensor_" << fNMean << "[";
0314          out << axesIndex << "]);\n";
0315          for (size_t j = fAxis; j < fSize; j++) {
0316             out << SP << SP << "}\n";
0317          }
0318          for (size_t i = fAxis; i < fSize; i++) {
0319             out << SP << "}\n";
0320          }
0321       }
0322 
0323       if (!fNB.empty()) {
0324          std::string bias = "tensor_" + (fNBroadcastedB.empty() ? fNB : fNBroadcastedB);
0325          out << SP << "// Add the bias to Y\n";
0326          out << SP << "int " << opName << "_n = " << fLength << ";\n";
0327          out << SP << "float " << opName << "_alpha = 1.;\n";
0328          out << SP << "int " << opName << "_inc = 1;\n";
0329          out << SP << "BLAS::saxpy_(&" << opName << "_n, &" << opName << "_alpha, " << bias << ", &";
0330          out << opName << "_inc, " << "tensor_" << fNY << ", &" << opName << "_inc);\n";
0331       }
0332 
0333       return out.str();
0334    }
0335 
0336    std::vector<std::string> GetBlasRoutines() override { return { std::string("Axpy") }; }
0337 
0338    std::vector<std::string> GetStdLibs() override { return { std::string("cmath") }; }
0339 };
0340 
0341 } // namespace SOFIE
0342 } // namespace Experimental
0343 } // namespace TMVA
0344 
0345 #endif