#ifndef TMVA_SOFIE_ROPERATOR_Softmax
#define TMVA_SOFIE_ROPERATOR_Softmax

#include "TMVA/SOFIE_common.hxx"
#include "TMVA/ROperator.hxx"
#include "TMVA/RModel.hxx"

#include <cstdint>
#include <iostream>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>

namespace TMVA {
namespace Experimental {
namespace SOFIE {

template <typename T>
class ROperator_Softmax final : public ROperator {

private:
   int64_t fAttrAxis = -1;     // axis along which the softmax is applied; -1 (last axis) follows the ONNX default for recent opsets

   std::string fNX;            // input tensor name
   std::string fNY;            // output tensor name
   std::vector<size_t> fShape; // shape of the input (and output) tensor

   std::string fType;          // tensor element type as a string (e.g. "float")

public:
   ROperator_Softmax() {}
   ROperator_Softmax(int64_t attr_axis, std::string nameX, std::string nameY)
      : fAttrAxis(attr_axis), fNX(UTILITY::Clean_name(nameX)), fNY(UTILITY::Clean_name(nameY))
   {
      fInputTensorNames = { fNX };
      fOutputTensorNames = { fNY };
   }

   // the output tensor has the same type as the input
   std::vector<ETensorType> TypeInference(std::vector<ETensorType> input) override { return input; }

   // softmax preserves the input shape, so the output shape is a copy of the input shape
   std::vector<std::vector<size_t>> ShapeInference(std::vector<std::vector<size_t>> input) override {
      auto ret = input;
      return ret;
   }

   void Initialize(RModel& model) override {
      // the input must be a graph input or an already initialized intermediate tensor
      if (!model.CheckIfTensorAlreadyExist(fNX)) {
         throw std::runtime_error("TMVA SOFIE Softmax Op Input Tensor is not found in model");
      }
      fShape = model.GetTensorShape(fNX);
      model.AddIntermediateTensor(fNY, model.GetTensorType(fNX), fShape);
      fType = ConvertTypeToString(model.GetTensorType(fNX));
      if (model.Verbose()) {
         std::cout << "Softmax -> " << fNY << " " << ConvertShapeToString(fShape) << std::endl;
      }
   }

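   // Generate emits the C++ code that evaluates this softmax when the model is compiled.
   // As an illustration only (hypothetical shape and tensor names, not produced verbatim),
   // a 2-d input of shape {1, 3} with axis = 1 would yield code of roughly this form:
   //
   //    for (size_t n = 0; n < 1 ; n++){
   //       float sum = 0.;
   //       size_t index = 0 + n * 3;
   //       float vmax = tensor_X[index];
   //       for (size_t i = 1; i < 3; i++) {
   //          if (tensor_X[index + i*1] > vmax) vmax = tensor_X[index + i*1];
   //       }
   //       for (size_t i = 0; i < 3; i++) {
   //          tensor_Y[index + i*1] = std::exp(tensor_X[index + i*1] - vmax);
   //          sum += tensor_Y[index + i*1];
   //       }
   //       for (size_t i = 0; i < 3; i++) {
   //          tensor_Y[index + i*1] /= sum;
   //       }
   //    }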
   std::string Generate(std::string OpName) override {
      OpName = "op_" + OpName;
      if (fShape.empty()) {
         throw std::runtime_error("TMVA SOFIE Operator Softmax called to Generate without being initialized first");
      }
      std::stringstream out;
      size_t size = fShape.size();
      size_t length = ConvertShapeToLength(fShape);
      // a negative axis counts from the end, e.g. -1 selects the innermost dimension
      size_t axis = fAttrAxis < 0 ? size + fAttrAxis : fAttrAxis;
      out << "\n" << SP << "//------ SOFTMAX - " << size << "  " << length << "  " << axis << "\n";
      // numerically safe implementation: subtract the maximum of the tensor before exponentiating
      if (size == 1) {
         out << SP << fType << " vmax = tensor_" << fNX << "[0];\n";
         out << SP << "for (size_t i = 1; i < " << length << " ; i++){\n";
         out << SP << SP << "if (tensor_" << fNX << "[i] > vmax) vmax = tensor_" << fNX << "[i];\n";
         out << SP << "}\n";
         out << SP << fType << " sum = 0.0;\n";
         out << SP << "for (size_t i = 0; i < " << length << " ; i++){\n";
         out << SP << SP << "tensor_" << fNY << "[i] = std::exp(tensor_" << fNX << "[i] - vmax);\n";
         out << SP << SP << "sum += tensor_" << fNY << "[i];\n";
         out << SP << "}\n";
         out << SP << "for (size_t i = 0; i < " << length << " ; i++){\n";
         out << SP << SP << "tensor_" << fNY << "[i] /= sum;\n";
         out << SP << "}\n";
      } else {
         size_t batch = fShape[0];
         size_t channel = fShape[1];
         size_t width = (size > 2) ? fShape[size - 1] : 1;
         size_t height = (size > 3) ? fShape[size - 2] : 1;
         size_t depth = (size > 4) ? fShape[size - 3] : 1;
         size_t hStride = width;
         size_t dStride = height * width;
         size_t cStride = depth * dStride;
         size_t bStride = channel * cStride;
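         // Row-major strides of the tensor interpreted as {batch, channel, depth, height, width},
         // with missing dimensions treated as 1. For example (hypothetical shape), fShape = {2, 3, 4, 5}
         // gives width = 5, height = 4, depth = 1, hStride = 5, dStride = 20, cStride = 20, bStride = 60.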

         size_t N = 0;       // number of elements along the softmax axis
         size_t iStride = 0; // stride between consecutive elements along that axis
         if (axis == 0) {
            N = batch;
            iStride = bStride;
         } else if (axis == 1) {
            N = channel;
            iStride = cStride;
         } else if (axis == size - 1) {
            N = width;
            iStride = 1;
         } else if (size > 3 && axis == size - 2) {
            N = height;
            iStride = hStride;
         } else if (size == 5 && axis == size - 3) {
            N = depth;
            iStride = dStride;
         } else {
            throw std::runtime_error("TMVA::SOFIE - Softmax operator along the axis " + std::to_string(fAttrAxis) +
                                     " with " + std::to_string(size) + "d input tensor not supported.");
         }
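         // For example (hypothetical), with fShape = {2, 3, 4, 5} and axis = 1 this selects
         // N = channel = 3 and iStride = cStride = 20, while axis = -1 (normalized to 3)
         // selects N = width = 5 and iStride = 1.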

         // loop over every dimension except the softmax axis; the softmax itself runs along that axis
         bool notBatch = axis != 0;
         bool notChannel = axis != 1;
         bool notDepth = (size == 5 && axis != 2);
         bool notHeight = (size == 5 && axis != 3) || (size == 4 && axis != 2);
         bool notWidth = (size == 5 && axis != 4) || (size == 4 && axis != 3) || (size == 3 && axis != 2);

         if (notBatch) {
            out << SP << "for (size_t n = 0; n < " << batch << " ; n++){\n";
         }
         if (notChannel) {
            out << SP << SP << "for (size_t c = 0; c < " << channel << " ; c++){\n";
         }
         if (notDepth) {
            out << SP << SP << "for (size_t d = 0; d < " << depth << " ; d++){\n";
         }
         if (notHeight) {
            out << SP << SP << "for (size_t h = 0; h < " << height << " ; h++){\n";
         }
         if (notWidth) {
            out << SP << SP << "for (size_t w = 0; w < " << width << " ; w++){\n";
         }
         out << SP << SP << SP << fType << " sum = 0.;\n";
         out << SP << SP << SP << "size_t index = 0";
         if (notBatch) {
            out << " + n * " << bStride;
         }
         if (notChannel) {
            out << " + c * " << cStride;
         }
         if (notDepth) {
            out << " + d * " << dStride;
         }
         if (notHeight) {
            out << " + h * " << hStride;
         }
         if (notWidth) {
            out << " + w";
         }
         out << ";\n";
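         // At this point `index` is the flat (row-major) offset of the first element of the
         // 1-d slice being normalized; element i of that slice sits at index + i*iStride.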
         // apply the softmax along the axis: first find the maximum value for numerical stability
         if (N == 0)
            throw std::runtime_error("TMVA::SOFIE - Softmax operator is along axis with zero elements");
         out << SP << SP << SP << fType << " vmax = tensor_" << fNX << "[index];\n";
         out << SP << SP << SP << "for (size_t i = 1; i < " << N << "; i++) {\n";
         out << SP << SP << SP << SP << "if (tensor_" << fNX << "[index + i*" << iStride << "] > vmax)\n";
         out << SP << SP << SP << SP << SP << "vmax = tensor_" << fNX << "[index + i*" << iStride << "];\n";
         out << SP << SP << SP << "}\n";
         out << SP << SP << SP << "for (size_t i = 0; i < " << N << "; i++) {\n";
         out << SP << SP << SP << SP << "tensor_" << fNY << "[index + i*" << iStride << "] = std::exp(tensor_" << fNX
             << "[index + i*" << iStride << "] - vmax);\n";
         out << SP << SP << SP << SP << "sum += tensor_" << fNY << "[index + i*" << iStride << "];\n";
         out << SP << SP << SP << "}\n";
         out << SP << SP << SP << "for (size_t i = 0; i < " << N << "; i++) {\n";
         out << SP << SP << SP << SP << "tensor_" << fNY << "[index + i*" << iStride << "] /= sum;\n";
         out << SP << SP << SP << "}\n";
         if (notWidth) {
            out << SP << SP << "}\n"; // end w
         }
         if (notHeight) {
            out << SP << SP << "}\n"; // end h
         }
         if (notDepth) {
            out << SP << SP << "}\n"; // end d
         }
         if (notChannel) {
            out << SP << SP << "}\n"; // end c
         }
         if (notBatch) {
            out << SP << "}\n"; // end n
         }
      }
      return out.str();
   }
};

} // namespace SOFIE
} // namespace Experimental
} // namespace TMVA

#endif // TMVA_SOFIE_ROPERATOR_Softmax