// File indexing completed on 2025-09-17 09:14:54
0001 #ifndef TMVA_SOFIE_ROPERATOR_LAYERNORMALIZATION
0002 #define TMVA_SOFIE_ROPERATOR_LAYERNORMALIZATION
0003
0004 #include "TMVA/RModel.hxx"
0005 #include "TMVA/SOFIE_common.hxx"
0006
0007 #include <sstream>
0008 #include <string>
0009
0010 namespace TMVA {
0011 namespace Experimental {
0012 namespace SOFIE {
0013
0014 template <typename T>
0015 class ROperator_LayerNormalization : public ROperator {
0016 private:
0017 int fAttrAxis;
0018 float fAttrEpsilon;
0019 size_t fAttrStashType;
0020
0021 std::string fNX;
0022 std::string fNScale;
0023 std::string fNB;
0024 std::string fNY;
0025 std::string fNMean;
0026 std::string fNInvStdDev;
0027
0028 std::string fNCastedX;
0029 std::string fNNormalizedX;
0030 std::string fNBroadcastedB;
0031
0032 std::vector<Dim> fShapeX;
0033 std::vector<Dim> fShapeScale;
0034 std::vector<size_t> fShapeB;
0035 std::vector<Dim> fShapeY;
0036 std::vector<Dim> fShapeMean;
0037 std::vector<Dim> fShapeInvStdDev;
0038
0039 size_t fAxis;
0040 size_t fSize;
0041
0042
0043 std::vector<Dim> fNormalizedShape;
0044 std::vector<Dim> fAxesShape;
0045
0046 std::string fLength;
0047 std::string fNormalizedLength;
0048 std::string fAxesLength;
0049
0050 std::string fType;
0051
0052 public:
0053 ROperator_LayerNormalization() {}
0054
0055 ROperator_LayerNormalization(int axis, float epsilon, size_t stashType, const std::string &nameX,
0056 const std::string &nameScale, const std::string &nameB, const std::string &nameY,
0057 const std::string &nameMean, const std::string &nameInvStdDev)
0058 : fAttrAxis(axis), fAttrEpsilon(epsilon), fAttrStashType(stashType), fNX(UTILITY::Clean_name(nameX)),
0059 fNScale(UTILITY::Clean_name(nameScale)), fNB(UTILITY::Clean_name(nameB)),
0060 fNY(UTILITY::Clean_name(nameY)), fNMean(UTILITY::Clean_name(nameMean)), fNInvStdDev(UTILITY::Clean_name(nameInvStdDev))
0061 {
0062 fInputTensorNames = { fNX, fNScale };
0063 if (!fNB.empty()){
0064 fInputTensorNames.emplace_back(fNB);
0065 }
0066
0067 fOutputTensorNames = { fNY };
0068 if (!fNMean.empty()){
0069 fOutputTensorNames.emplace_back(fNMean);
0070 }
0071 if (!fNInvStdDev.empty()){
0072 fOutputTensorNames.emplace_back(fNInvStdDev);
0073 }
0074 }
0075
0076 std::vector<std::vector<size_t>> ShapeInference(std::vector<std::vector<size_t>> input) override { return input; }
0077
0078 std::vector<ETensorType> TypeInference(std::vector<ETensorType> input) override { return input; }
0079
0080 void Initialize(RModel& model) override {
0081 if (!model.CheckIfTensorAlreadyExist(fNX)) {
0082 throw std::runtime_error("TMVA::SOFIE - Tensor " + fNX + " not found.");
0083 }
0084 bool isDynamic = model.IsDynamicTensor(fNX);
0085 fShapeX = model.GetDynamicTensorShape(fNX);
0086 fShapeY = fShapeX;
0087 model.AddIntermediateTensor(fNY, model.GetTensorType(fNX), fShapeY);
0088
0089 fType = ConvertTypeToString(model.GetTensorType(fNX));
0090
0091 fSize = fShapeX.size();
0092
0093 fAxis = (fAttrAxis < 0) ? fSize + fAttrAxis : fAttrAxis;
0094
0095 fAxesShape = std::vector<Dim>(fShapeX.begin(), fShapeX.begin() + fAxis);
0096
0097 fAxesLength = ConvertDynamicShapeToLength(fAxesShape);
0098
0099 fNormalizedShape = std::vector<Dim>(fShapeX.begin() + fAxis, fShapeX.end());
0100
0101 fNormalizedLength = ConvertDynamicShapeToLength(fNormalizedShape);
0102
0103 fLength = ConvertDynamicShapeToLength(fShapeX);
0104
0105 ETensorType type = (fAttrStashType == 1) ? ETensorType::FLOAT : model.GetTensorType(fNX);
0106
0107 if (fNMean.empty()) {
0108 fNMean = "Mean" + fNX;
0109
0110 if (isDynamic)
0111
0112 model.AddIntermediateTensor(fNMean, type, std::vector<Dim>(1,Dim{fAxesLength,std::size_t(-1)}));
0113 else
0114 model.AddIntermediateTensor(fNMean, type, std::vector<size_t>(1,std::stoi(fAxesLength)));
0115 }
0116
0117 if (fNInvStdDev.empty()) {
0118 fNInvStdDev = "InvStdDev" + fNX;
0119 if (isDynamic)
0120 model.AddIntermediateTensor(fNInvStdDev, type, std::vector<Dim>(1,Dim{fAxesLength,std::size_t(-1)}));
0121 else
0122 model.AddIntermediateTensor(fNInvStdDev, type, std::vector<size_t>(1,std::stoi(fAxesLength)));
0123 }
0124
0125 if (fAttrStashType == 1 && model.GetTensorType(fNX) != ETensorType::FLOAT) {
0126 fNCastedX = "Casted" + fNX;
0127 model.AddIntermediateTensor(fNCastedX, ETensorType::FLOAT, fShapeX);
0128 fNNormalizedX = "Normalized" + fNX;
0129 model.AddIntermediateTensor(fNNormalizedX, ETensorType::FLOAT, fShapeX);
0130 }
0131
0132 if (!fNB.empty()) {
0133 fShapeB = model.GetTensorShape(fNB);
0134 size_t lengthB = ConvertShapeToLength(fShapeB);
0135 if (isDynamic || lengthB < static_cast<size_t>(std::stoi(fLength))) {
0136 fNBroadcastedB = "Broadcasted" + fNB;
0137 model.AddIntermediateTensor(fNBroadcastedB, ConvertStringToType(fType), fShapeX);
0138 }
0139 }
0140 model.AddNeededStdLib("cmath");
0141 }
0142
0143 std::string GenerateInitCode() override
0144 {
0145 std::stringstream out;
0146 if (!fNBroadcastedB.empty()) {
0147 out << SP << "// Broadcasting the bias of LayerNormalization op\n";
0148 out << SP << "{\n";
0149 out << SP << SP << "float* data = TMVA::Experimental::SOFIE::UTILITY::UnidirectionalBroadcast<float>(tensor_";
0150 out << fNB << ", " << ConvertShapeToString(fShapeB) << ", " << ConvertDynamicShapeToString(fShapeX) << ");\n";
0151 out << SP << "std::copy(data, data + " << fLength << ", tensor_" << fNBroadcastedB << ");\n";
0152 out << SP << "delete[] data;\n";
0153 out << SP << "}\n";
0154 }
0155 return out.str();
0156 }
0157
0158 std::string Generate(std::string opName) override
0159 {
0160 opName = "op_" + opName;
0161 if (fShapeX.empty()) {
0162 throw std::runtime_error("TMVA::SOFIE LayerNormalization operator " + opName +
0163 " called to generate without being initialized first.");
0164 }
0165 if (fShapeX.size() > 5) {
0166 throw std::runtime_error("TMVA::SOFIE LayerNormalization operator not "
0167 "implemented for input tensor of size > 5.");
0168 }
0169
0170 std::stringstream out;
0171
0172 out << "//---- Layer Normalization operator " << opName << "\n";
0173
0174
0175 std::vector<std::string> inputShape(fSize);
0176
0177 for (size_t i = 0; i < fSize; i++) {
0178 inputShape[i] = fShapeX[i].GetVal();
0179 }
0180
0181 auto strides = UTILITY::ComputeStrideFromShape(fShapeX);
0182 std::string InputIndex = "axis_0 * " + strides[0].GetVal();
0183 for (size_t i = 1; i < fSize; i++) {
0184 InputIndex += " + axis_" + std::to_string(i) + " * " + strides[i].GetVal();
0185 }
0186
0187 auto axesStrides = UTILITY::ComputeStrideFromShape(fAxesShape);
0188 std::string axesIndex = "axis_" + std::to_string(0) + " * " + axesStrides[0].GetVal();
0189 for (size_t i = 1; i < fAxis; i++) {
0190 axesIndex += " + axis_" + std::to_string(i) + " * " + axesStrides[i].GetVal();
0191 }
0192
0193 auto normalizedStrides = UTILITY::ComputeStrideFromShape(fNormalizedShape);
0194 std::string normalizedIndex = "axis_" + std::to_string(fAxis) + " * " + normalizedStrides[0].GetVal();
0195 for (size_t i = fAxis + 1; i < fSize; i++) {
0196 normalizedIndex += " + axis_" + std::to_string(i) + " * " + normalizedStrides[i - fAxis].GetVal();
0197 }
0198
0199 if (!fNCastedX.empty()) {
0200
0201 out << SP << "for (size_t i = 0; i < " << fLength << "; i++) {\n";
0202 out << SP << SP << "tensor_" << fNCastedX << "[i] = " << "static_cast<float>(tensor_" << fNX;
0203 out << "[i]);\n";
0204 out << SP << "}\n";
0205 }
0206
0207 out << SP << "// Compute the mean\n";
0208
0209 for (size_t i = 0; i < fAxis; i++) {
0210 std::string iIdx = "axis_" + std::to_string(i);
0211 out << SP << "for (size_t " << iIdx << " = 0; " << iIdx << " < " << inputShape[i]
0212 << "; " << iIdx << "++) {\n";
0213 }
0214 out << SP << SP << fType << " sum = 0.;\n";
0215
0216 for (size_t j = fAxis; j < fSize; j++) {
0217 std::string jIdx = "axis_" + std::to_string(j);
0218 out << SP << SP << "for (size_t " << jIdx << " = 0; " << jIdx << " < " << inputShape[j]
0219 << "; " << jIdx << "++) {\n";
0220 }
0221 out << SP << SP << SP << "sum += tensor_" << fNX << "[" << InputIndex << "];\n";
0222 for (size_t j = fAxis; j < fSize; j++) {
0223 out << SP << SP << "}\n";
0224 }
0225 out << SP << SP << "tensor_" << fNMean << "[" << axesIndex << "] = sum / " << fType << "(";
0226 out << fNormalizedLength << ");\n";
0227 for (size_t i = fAxis; i < fSize; i++) {
0228 out << SP << "}\n";
0229 }
0230
0231 out << SP << "// Compute the inverse Standard Deviation\n";
0232
0233 for (size_t i = 0; i < fAxis; i++) {
0234 std::string iIdx = "axis_" + std::to_string(i);
0235 out << SP << "for (size_t " << iIdx << " = 0; " << iIdx << " < " << inputShape[i]
0236 << "; " << iIdx << "++){\n";
0237 }
0238
0239 out << SP << SP << fType << " sum = 0.;\n";
0240
0241 for (size_t j = fAxis; j < fSize; j++) {
0242 std::string jIdx = "axis_" + std::to_string(j);
0243 out << SP << SP << "for (size_t " << jIdx << " = 0; " << jIdx << " < " << inputShape[j]
0244 << "; " << jIdx << "++){\n";
0245 }
0246 out << SP << SP << SP << "float tmp = tensor_" << fNX << "[" << InputIndex << "] - tensor_"
0247 << fNMean << "[" << axesIndex << "];\n";
0248 out << SP << SP << SP << "sum += tmp*tmp;\n";
0249 for (size_t j = fAxis; j < fSize; j++) {
0250 out << SP << SP << "}\n";
0251 }
0252 out << SP << SP << "tensor_" << fNInvStdDev << "[" << axesIndex << "] = 1 / std::sqrt(";
0253 out << "sum / " << fType << "(" << fNormalizedLength << ") + " << fAttrEpsilon << ");\n";
0254 for (size_t i = 0; i < fAxis; i++) {
0255 out << SP << "}\n";
0256 }
0257
0258 if (!fNCastedX.empty()) {
0259 out << "// NormalizedX = InvStdDev * (CastedX - Mean)\n";
0260 for (size_t i = 0; i < fAxis; i++) {
0261 std::string iIdx = "axis_" + std::to_string(i);
0262 out << SP << "for (size_t " << iIdx << " = 0; " << iIdx << " < " << inputShape[i]
0263 << "; " << iIdx << "++){\n";
0264 }
0265 for (size_t j = fAxis; j < fSize; j++) {
0266 std::string jIdx = "axis_" + std::to_string(j);
0267 out << SP << SP << "for (size_t " << jIdx << " = 0; " << jIdx << " < " << inputShape[j]
0268 << "; " << jIdx << "++){\n";
0269 }
0270 out << SP << SP << SP << "tensor_" << fNNormalizedX << "[" << InputIndex << "] = tensor_";
0271 out << fNInvStdDev << "[" << axesIndex << "] * (tensor_" << fNCastedX << "[" << InputIndex;
0272 out << "] - tensor_" << fNMean << "[" << axesIndex << "])\n";
0273 for (size_t j = fAxis; j < fSize; j++) {
0274 out << SP << SP << "}\n";
0275 }
0276 for (size_t i = fAxis; i < fSize; i++) {
0277 out << SP << "}\n";
0278 }
0279 out << "// Y = Scale o NormalizedX";
0280 for (size_t i = 0; i < fAxis; i++) {
0281 std::string iIdx = "axis_" + std::to_string(i);
0282 out << SP << "for (size_t " << iIdx << " = 0; " << iIdx << " < " << inputShape[i]
0283 << "; " << iIdx << "++){\n";
0284 }
0285 for (size_t j = fAxis; j < fSize; j++) {
0286 std::string jIdx = "axis_" + std::to_string(j);
0287 out << SP << SP << "for (size_t " << jIdx << " = 0; " << jIdx << " < " << inputShape[j]
0288 << "; " << jIdx << "++){\n";
0289 }
0290 out << SP << SP << SP << "tensor_" << fNY << "[" << InputIndex << "] = tensor_" << fNScale;
0291 out << "[" << axesIndex << "] * static_cast<" << fType << ">(tensor_" << fNCastedX << "[" << InputIndex;
0292 out << "]);\n";
0293 for (size_t j = fAxis; j < fSize; j++) {
0294 out << SP << SP << "}\n";
0295 }
0296 for (size_t i = fAxis; i < fSize; i++) {
0297 out << SP << "}\n";
0298 }
0299 } else {
0300 out << SP << "// Y = Scale o InvStdDev (X - Mean)\n";
0301 for (size_t i = 0; i < fAxis; i++) {
0302 std::string iIdx = "axis_" + std::to_string(i);
0303 out << SP << "for (size_t " << iIdx << " = 0; " << iIdx << " < " << inputShape[i]
0304 << "; " << iIdx << "++){\n";
0305 }
0306 for (size_t j = fAxis; j < fSize; j++) {
0307 std::string jIdx = "axis_" + std::to_string(j);
0308 out << SP << SP << "for (size_t " << jIdx << " = 0; " << jIdx << " < " << inputShape[j]
0309 << "; " << jIdx << "++){\n";
0310 }
0311 out << SP << SP << SP << "tensor_" << fNY << "[" << InputIndex << "] = tensor_" << fNScale;
0312 out << "[" << normalizedIndex << "] * tensor_" << fNInvStdDev << "[" << axesIndex;
0313 out << "] * (tensor_" << fNX << "[" << InputIndex << "] - tensor_" << fNMean << "[";
0314 out << axesIndex << "]);\n";
0315 for (size_t j = fAxis; j < fSize; j++) {
0316 out << SP << SP << "}\n";
0317 }
0318 for (size_t i = fAxis; i < fSize; i++) {
0319 out << SP << "}\n";
0320 }
0321 }
0322
0323 if (!fNB.empty()) {
0324 std::string bias = "tensor_" + (fNBroadcastedB.empty() ? fNB : fNBroadcastedB);
0325 out << SP << "// Add the bias to Y\n";
0326 out << SP << "int " << opName << "_n = " << fLength << ";\n";
0327 out << SP << "float " << opName << "_alpha = 1.;\n";
0328 out << SP << "int " << opName << "_inc = 1;\n";
0329 out << SP << "BLAS::saxpy_(&" << opName << "_n, &" << opName << "_alpha, " << bias << ", &";
0330 out << opName << "_inc, " << "tensor_" << fNY << ", &" << opName << "_inc);\n";
0331 }
0332
0333 return out.str();
0334 }
0335
0336 std::vector<std::string> GetBlasRoutines() override { return { std::string("Axpy") }; }
0337
0338 std::vector<std::string> GetStdLibs() override { return { std::string("cmath") }; }
0339 };
0340
0341 }
0342 }
0343 }
0344
0345 #endif