Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-01-18 10:11:08

0001 #ifndef TMVA_SOFIE_ROPERATOR_Softmax
0002 #define TMVA_SOFIE_ROPERATOR_Softmax
0003 
0004 #include "TMVA/SOFIE_common.hxx"
0005 #include "TMVA/ROperator.hxx"
0006 #include "TMVA/RModel.hxx"
0007 
0008 #include <sstream>
0009 
0010 namespace TMVA {
0011 namespace Experimental {
0012 namespace SOFIE {
0013 
0014 template <typename T>
0015 class ROperator_Softmax final : public ROperator {
0016 
0017 private:
0018    int64_t fAttrAxis;
0019 
0020    std::string fNX;
0021    std::string fNY;
0022    std::vector<size_t> fShape;
0023 
0024    std::string fType;
0025 
0026 public:
0027    ROperator_Softmax() {}
0028    ROperator_Softmax(int64_t attr_axis, std::string nameX, std::string nameY)
0029       : fAttrAxis(attr_axis), fNX(UTILITY::Clean_name(nameX)), fNY(UTILITY::Clean_name(nameY))
0030    {
0031    }
0032 
0033    std::vector<ETensorType> TypeInference(std::vector<ETensorType> input) { return input; }
0034 
0035    std::vector<std::vector<size_t>> ShapeInference(std::vector<std::vector<size_t>> input)
0036    {
0037       auto ret = input; // suggest copy to compiler
0038       return ret;
0039    }
0040 
0041    void Initialize(RModel &model)
0042    {
0043       if (model.CheckIfTensorAlreadyExist(fNX) ==
0044           false) { // input must be a graph input, or already initialized intermediate tensor
0045          throw std::runtime_error("TMVA SOFIE Softmax Op Input Tensor is not found in model");
0046       }
0047       fShape = model.GetTensorShape(fNX);
0048       model.AddIntermediateTensor(fNY, model.GetTensorType(fNX), fShape);
0049       fType = ConvertTypeToString(model.GetTensorType(fNX));
0050    }
0051 
0052    std::string Generate(std::string OpName)
0053    {
0054       OpName = "op_" + OpName;
0055       if (fShape.empty()) {
0056          throw std::runtime_error("TMVA SOFIE Operator Softmax called to Generate without being initialized first");
0057       }
0058       std::stringstream out;
0059       size_t size = fShape.size();
0060       size_t length = ConvertShapeToLength(fShape);
0061       size_t axis = fAttrAxis < 0 ? size + fAttrAxis : fAttrAxis;
0062       out << "\n" << SP << "//------ SOFTMAX\n";
0063       if (size == 1) {
0064          out << SP << fType << " sum = 0.0;\n";
0065          out << SP << "for (size_t i = 0; i < " << length << " ; i++){\n";
0066          out << SP << SP << "tensor_" << fNY << "[i] = std::exp(tensor_" << fNX << "[i]);\n";
0067          out << SP << SP << "sum += tensor_" << fNY << "[i];\n";
0068          out << SP << "}\n";
0069          out << SP << "for (size_t i = 0; i < " << length << " ; i++){\n";
0070          out << SP << SP << "tensor_" << fNY << "[i] /= sum;\n";
0071          out << SP << "}\n";
0072       } else {
0073          size_t batch = fShape[0];
0074          size_t channel = fShape[1];
0075          size_t width = (size > 2) ? fShape[size - 1] : 1;
0076          size_t height = (size > 3) ? fShape[size - 2] : 1;
0077          size_t depth = (size > 4) ? fShape[size - 3] : 1;
0078          size_t hStride = width;
0079          size_t dStride = height * width;
0080          size_t cStride = depth * dStride;
0081          size_t bStride = channel * cStride;
0082 
0083          size_t N = 0; // Size of the axis
0084          size_t iStride = 0;
0085          if (axis == 0) {
0086             N = batch;
0087             iStride = bStride;
0088          } else if (axis == 1) {
0089             N = channel;
0090             iStride = cStride;
0091          } else if (axis == size - 1) {
0092             N = width;
0093             iStride = 1;
0094          } else if (size > 3 && axis == size - 2) {
0095             N = height;
0096             iStride = hStride;
0097          } else if (size == 5 && axis == size - 3) {
0098             N = depth;
0099             iStride = dStride;
0100          } else {
0101             throw
0102                std::runtime_error("TMVA::SOFIE - Softmax operator along the axis "
0103                   + std::to_string(fAttrAxis) + " with " + std::to_string(size)
0104                   + "d input tensor not supported.");
0105          }
0106 
0107          bool notBatch = axis != 0;
0108          bool notChannel = axis != 1;
0109          bool notDepth = (size == 5 && axis != 2);
0110          bool notHeight = (size == 5 && axis != 3) || (size == 4 && axis != 2);
0111          bool notWidth = (size == 5 && axis != 4) || (size == 4 && axis != 3) || (size == 3 && axis != 2);
0112 
0113          if (notBatch) {
0114             out << SP << "for (size_t n = 0; n < " << batch << " ; n++){\n";
0115          }
0116          if (notChannel) {
0117             out << SP << SP << "for (size_t c = 0; c < " << channel << " ; c++){\n";
0118          }
0119          if (notDepth) {
0120             out << SP << SP << "for (size_t d = 0; d < " << depth << " ; d++){\n";
0121          }
0122          if (notHeight) {
0123             out << SP << SP << "for (size_t h = 0; h < " << height << " ; h++){\n";
0124          }
0125          if (notWidth) {
0126             out << SP << SP << "for (size_t w = 0; w < " << width << " ; w++){\n";
0127          }
0128          out << SP << SP << SP << fType << " sum = 0.;\n";
0129          out << SP << SP << SP << "size_t index = 0";
0130          if (notBatch) {
0131             out << "+ n * " << bStride;
0132          }
0133          if (notChannel) {
0134             out << "+ c * " << cStride;
0135          }
0136          if (notDepth) {
0137             out << "+ d * " << dStride;
0138          }
0139          if (notHeight) {
0140             out << "+ h * " << hStride;
0141          }
0142          if (notWidth) {
0143             out << " + w";
0144          }
0145          out << ";\n";
0146          // apply softmax along the axis
0147          out << SP << SP << SP << "for (size_t i = 0; i < " << N << "; i++) {\n";
0148          out << SP << SP << SP << SP << "tensor_" << fNY << "[index + i*" << iStride << "] = std::exp(tensor_" << fNX
0149              << "[index + i*" << iStride << "]);\n";
0150          out << SP << SP << SP << SP << "sum += tensor_" << fNY << "[index + i*" << iStride << "];\n";
0151          out << SP << SP << SP << "}\n";
0152          out << SP << SP << SP << "for (size_t i = 0; i < " << N << "; i++) {\n";
0153          out << SP << SP << SP << SP << "tensor_" << fNY << "[index + i*" << iStride << "] /= sum;\n";
0154          out << SP << SP << SP << "}\n";
0155          if (notWidth) {
0156             out << SP << SP << "}\n"; // end w
0157          }
0158          if (notHeight) {
0159             out << SP << SP << "}\n"; // end h
0160          }
0161          if (notDepth) {
0162             out << SP << SP << "}\n"; // end d
0163          }
0164          if (notChannel) {
0165             out << SP << SP << "}\n"; // end c
0166          }
0167          if (notBatch) {
0168             out << SP << "}\n"; // end n
0169          }
0170       }
0171       return out.str();
0172    }
0173 };
0174 
0175 } // namespace SOFIE
0176 } // namespace Experimental
0177 } // namespace TMVA
0178 
0179 #endif // TMVA_SOFIE_ROPERATOR_Softmax