File indexing completed on 2025-09-18 09:32:39
0001 #ifndef TMVA_SOFIE_ROPERATOR_Softmax
0002 #define TMVA_SOFIE_ROPERATOR_Softmax
0003
0004 #include "TMVA/SOFIE_common.hxx"
0005 #include "TMVA/ROperator.hxx"
0006 #include "TMVA/RModel.hxx"
0007
0008 #include <sstream>
0009
0010 namespace TMVA {
0011 namespace Experimental {
0012 namespace SOFIE {
0013
0014 template <typename T>
0015 class ROperator_Softmax final : public ROperator {
0016
0017 private:
0018 int64_t fAttrAxis;
0019
0020 std::string fNX;
0021 std::string fNY;
0022 std::vector<size_t> fShape;
0023
0024 std::string fType;
0025
0026 public:
0027 ROperator_Softmax() {}
0028 ROperator_Softmax(int64_t attr_axis, std::string nameX, std::string nameY)
0029 : fAttrAxis(attr_axis), fNX(UTILITY::Clean_name(nameX)), fNY(UTILITY::Clean_name(nameY))
0030 {
0031 fInputTensorNames = { fNX };
0032 fOutputTensorNames = { fNY };
0033 }
0034
0035 std::vector<ETensorType> TypeInference(std::vector<ETensorType> input) override { return input; }
0036
0037 std::vector<std::vector<size_t>> ShapeInference(std::vector<std::vector<size_t>> input) override {
0038 auto ret = input;
0039 return ret;
0040 }
0041
0042 void Initialize(RModel& model) override {
0043 if (model.CheckIfTensorAlreadyExist(fNX) ==
0044 false) {
0045 throw std::runtime_error("TMVA SOFIE Softmax Op Input Tensor is not found in model");
0046 }
0047 fShape = model.GetTensorShape(fNX);
0048 model.AddIntermediateTensor(fNY, model.GetTensorType(fNX), fShape);
0049 fType = ConvertTypeToString(model.GetTensorType(fNX));
0050 if (model.Verbose()) {
0051 std::cout << "Softmax -> " << fNY << " " << ConvertShapeToString(fShape) << std::endl;
0052 }
0053 }
0054
0055 std::string Generate(std::string OpName) override {
0056 OpName = "op_" + OpName;
0057 if (fShape.empty()) {
0058 throw std::runtime_error("TMVA SOFIE Operator Softmax called to Generate without being initialized first");
0059 }
0060 std::stringstream out;
0061 size_t size = fShape.size();
0062 size_t length = ConvertShapeToLength(fShape);
0063 size_t axis = fAttrAxis < 0 ? size + fAttrAxis : fAttrAxis;
0064 out << "\n" << SP << "//------ SOFTMAX - " << size << " " << length << " " << axis << "\n";
0065
0066 if (size == 1) {
0067 out << SP << fType << " vmax = tensor_" << fNX << "[0];\n";
0068 out << SP << "for (size_t i = 1; i < " << length << " ; i++){\n";
0069 out << SP << SP << "if (tensor_" << fNX << "[i] > vmax) vmax = tensor_" << fNX << "[i];\n";
0070 out << SP << "}\n";
0071 out << SP << fType << " sum = 0.0;\n";
0072 out << SP << "for (size_t i = 0; i < " << length << " ; i++){\n";
0073 out << SP << SP << "tensor_" << fNY << "[i] = std::exp(tensor_" << fNX << "[i] - vmax);\n";
0074 out << SP << SP << "sum += tensor_" << fNY << "[i];\n";
0075 out << SP << "}\n";
0076 out << SP << "for (size_t i = 0; i < " << length << " ; i++){\n";
0077 out << SP << SP << "tensor_" << fNY << "[i] /= sum;\n";
0078 out << SP << "}\n";
0079 } else {
0080 size_t batch = fShape[0];
0081 size_t channel = fShape[1];
0082 size_t width = (size > 2) ? fShape[size - 1] : 1;
0083 size_t height = (size > 3) ? fShape[size - 2] : 1;
0084 size_t depth = (size > 4) ? fShape[size - 3] : 1;
0085 size_t hStride = width;
0086 size_t dStride = height * width;
0087 size_t cStride = depth * dStride;
0088 size_t bStride = channel * cStride;
0089
0090 size_t N = 0;
0091 size_t iStride = 0;
0092 if (axis == 0) {
0093 N = batch;
0094 iStride = bStride;
0095 } else if (axis == 1) {
0096 N = channel;
0097 iStride = cStride;
0098 } else if (axis == size - 1) {
0099 N = width;
0100 iStride = 1;
0101 } else if (size > 3 && axis == size - 2) {
0102 N = height;
0103 iStride = hStride;
0104 } else if (size == 5 && axis == size - 3) {
0105 N = depth;
0106 iStride = dStride;
0107 } else {
0108 throw
0109 std::runtime_error("TMVA::SOFIE - Softmax operator along the axis "
0110 + std::to_string(fAttrAxis) + " with " + std::to_string(size)
0111 + "d input tensor not supported.");
0112 }
0113
0114 bool notBatch = axis != 0;
0115 bool notChannel = axis != 1;
0116 bool notDepth = (size == 5 && axis != 2);
0117 bool notHeight = (size == 5 && axis != 3) || (size == 4 && axis != 2);
0118 bool notWidth = (size == 5 && axis != 4) || (size == 4 && axis != 3) || (size == 3 && axis != 2);
0119
0120 if (notBatch) {
0121 out << SP << "for (size_t n = 0; n < " << batch << " ; n++){\n";
0122 }
0123 if (notChannel) {
0124 out << SP << SP << "for (size_t c = 0; c < " << channel << " ; c++){\n";
0125 }
0126 if (notDepth) {
0127 out << SP << SP << "for (size_t d = 0; d < " << depth << " ; d++){\n";
0128 }
0129 if (notHeight) {
0130 out << SP << SP << "for (size_t h = 0; h < " << height << " ; h++){\n";
0131 }
0132 if (notWidth) {
0133 out << SP << SP << "for (size_t w = 0; w < " << width << " ; w++){\n";
0134 }
0135 out << SP << SP << SP << fType << " sum = 0.;\n";
0136 out << SP << SP << SP << "size_t index = 0";
0137 if (notBatch) {
0138 out << " + n * " << bStride;
0139 }
0140 if (notChannel) {
0141 out << "+ c * " << cStride;
0142 }
0143 if (notDepth) {
0144 out << " + d * " << dStride;
0145 }
0146 if (notHeight) {
0147 out << " + h * " << hStride;
0148 }
0149 if (notWidth) {
0150 out << " + w";
0151 }
0152 out << ";\n";
0153
0154 if (N == 0)
0155 throw std::runtime_error("TMVA::SOFIE - Softmax operator is along axis with zero elements");
0156 out << SP << SP << SP << fType << " vmax = tensor_" << fNX << "[index];\n";
0157 out << SP << SP << SP << "for (size_t i = 1; i < " << N << "; i++) {\n";
0158 out << SP << SP << SP << SP << "if (tensor_" << fNX << "[index + i*" << iStride << "] > vmax)\n";
0159 out << SP << SP << SP << SP << SP << "vmax = tensor_" << fNX << "[index + i*" << iStride << "];\n";
0160 out << SP << SP << SP << "}\n";
0161 out << SP << SP << SP << "for (size_t i = 0; i < " << N << "; i++) {\n";
0162 out << SP << SP << SP << SP << "tensor_" << fNY << "[index + i*" << iStride << "] = std::exp(tensor_" << fNX
0163 << "[index + i*" << iStride << "] - vmax);\n";
0164 out << SP << SP << SP << SP << "sum += tensor_" << fNY << "[index + i*" << iStride << "];\n";
0165 out << SP << SP << SP << "}\n";
0166 out << SP << SP << SP << "for (size_t i = 0; i < " << N << "; i++) {\n";
0167 out << SP << SP << SP << SP << "tensor_" << fNY << "[index + i*" << iStride << "] /= sum;\n";
0168 out << SP << SP << SP << "}\n";
0169 if (notWidth) {
0170 out << SP << SP << "}\n";
0171 }
0172 if (notHeight) {
0173 out << SP << SP << "}\n";
0174 }
0175 if (notDepth) {
0176 out << SP << SP << "}\n";
0177 }
0178 if (notChannel) {
0179 out << SP << SP << "}\n";
0180 }
0181 if (notBatch) {
0182 out << SP << "}\n";
0183 }
0184 }
0185 return out.str();
0186 }
0187 };
0188
0189 }
0190 }
0191 }
0192
0193 #endif