File indexing completed on 2025-12-22 10:28:06
0001 #ifndef TMVA_SOFIE_ROPERATOR_Softmax
0002 #define TMVA_SOFIE_ROPERATOR_Softmax
0003
0004 #include "TMVA/SOFIE_common.hxx"
0005 #include "TMVA/ROperator.hxx"
0006 #include "TMVA/RModel.hxx"
0007
0008 #include <sstream>
0009
0010 namespace TMVA {
0011 namespace Experimental {
0012 namespace SOFIE {
0013
0014 template <typename T>
0015 class ROperator_Softmax final : public ROperator {
0016
0017 private:
0018 int64_t fAttrAxis;
0019
0020 std::string fNX;
0021 std::string fNY;
0022 std::vector<Dim> fShape;
0023
0024 std::string fType;
0025
0026 public:
0027 ROperator_Softmax() {}
0028 ROperator_Softmax(int64_t attr_axis, std::string nameX, std::string nameY)
0029 : fAttrAxis(attr_axis), fNX(UTILITY::Clean_name(nameX)), fNY(UTILITY::Clean_name(nameY))
0030 {
0031 fInputTensorNames = { fNX };
0032 fOutputTensorNames = { fNY };
0033 }
0034
0035 std::vector<ETensorType> TypeInference(std::vector<ETensorType> input) override { return input; }
0036
0037 std::vector<std::vector<size_t>> ShapeInference(std::vector<std::vector<size_t>> input) override {
0038 auto ret = input;
0039 return ret;
0040 }
0041
0042 void Initialize(RModel& model) override {
0043 if (model.CheckIfTensorAlreadyExist(fNX) ==
0044 false) {
0045 throw std::runtime_error("TMVA SOFIE Softmax Op Input Tensor is not found in model");
0046 }
0047 fShape = model.GetDimTensorShape(fNX);
0048 model.AddIntermediateTensor(fNY, model.GetTensorType(fNX), fShape);
0049 fType = ConvertTypeToString(model.GetTensorType(fNX));
0050 if (model.Verbose()) {
0051 std::cout << "Softmax -> " << fNY << " " << ConvertShapeToString(fShape) << std::endl;
0052 }
0053 }
0054
0055 std::string Generate(std::string OpName) override {
0056 OpName = "op_" + OpName;
0057 if (fShape.empty()) {
0058 throw std::runtime_error("TMVA SOFIE Operator Softmax called to Generate without being initialized first");
0059 }
0060 std::stringstream out;
0061 size_t size = fShape.size();
0062 auto length_str = ConvertDimShapeToLength(fShape);
0063 size_t axis = fAttrAxis < 0 ? size + fAttrAxis : fAttrAxis;
0064
0065
0066 if (axis == size - 1) {
0067 std::string axis_size = fShape[axis].GetVal();
0068 std::string num_rows;
0069 if (IsInteger(length_str) && IsInteger(axis_size)) {
0070 num_rows = std::to_string(std::stoul(length_str) / std::stoul(axis_size));
0071 } else {
0072 num_rows = "(" + length_str + ") / (" + axis_size + ")";
0073 }
0074
0075 out << "\n" << SP << "//------ SOFTMAX - " << size << " " << length_str << " " << axis << "\n";
0076 out << SP << "for (int i = 0; i < " << num_rows << "; ++i) {\n";
0077 out << SP << SP << "size_t offset = i * " << axis_size << ";\n";
0078 out << SP << SP << fType << " const * x_ptr = &tensor_" << fNX << "[offset];\n";
0079 out << SP << SP << fType << " * y_ptr = &tensor_" << fNY << "[offset];\n";
0080
0081 out << SP << SP << fType << " vmax = x_ptr[0];\n";
0082 out << SP << SP << "for (int j = 1; j < " << axis_size << "; ++j) {\n";
0083 out << SP << SP << SP << "if (x_ptr[j] > vmax) vmax = x_ptr[j];\n";
0084 out << SP << SP << "}\n";
0085
0086 out << SP << SP << fType << " sum = 0.0;\n";
0087 out << SP << SP << "for (int j = 0; j < " << axis_size << "; ++j) {\n";
0088 out << SP << SP << SP << "y_ptr[j] = std::exp(x_ptr[j] - vmax);\n";
0089 out << SP << SP << SP << "sum += y_ptr[j];\n";
0090 out << SP << SP << "}\n";
0091
0092 out << SP << SP << fType << " inv_sum = 1.0f / sum;\n";
0093 out << SP << SP << "for (int j = 0; j < " << axis_size << "; ++j) {\n";
0094 out << SP << SP << SP << "y_ptr[j] *= inv_sum;\n";
0095 out << SP << SP << "}\n";
0096 out << SP << "}\n";
0097
0098 } else {
0099 auto stride = UTILITY::ComputeStrideFromShape(fShape);
0100 size_t k = 0;
0101 std::vector<std::string> l(size);
0102 for (size_t i = 0; i < size; i++) {
0103 if (i != axis) {
0104 for (size_t j = 0; j < k; j++) out << SP;
0105 l[i] = std::string("i") + std::to_string(i);
0106 out << "for (int " << l[i] << " = 0; " << l[i] << " < " << fShape[i] << "; " << l[i] << "++) {\n";
0107 k++;
0108 }
0109 }
0110 for (size_t j = 0; j < size-1; j++) out << SP;
0111 out << fType << " sum = 0.;\n";
0112 for (size_t j = 0; j < size-1; j++) out << SP;
0113 out << "size_t index = ";
0114 bool first = true;
0115 for (size_t i = 0; i < size; i++) {
0116 if (i == axis) continue;
0117 if (!first) out << " + ";
0118 if (stride[i].GetVal() != "1")
0119 out << stride[i] << "*";
0120 out << l[i];
0121 first = false;
0122 }
0123 out << ";\n";
0124
0125 for (size_t j = 0; j < size-1; j++) out << SP;
0126 out << fType << " vmax = tensor_" << fNX << "[index];\n";
0127 for (size_t j = 0; j < size-1; j++) out << SP;
0128 out << "for (int i = 1; i < " << fShape[axis] << "; i++) {\n";
0129 for (size_t j = 0; j < size; j++) out << SP;
0130 out << fType << " x = tensor_" << fNX << "[index + i";
0131 if (stride[axis].GetVal() != "1") out << "*(" << stride[axis] << ")";
0132 out << "];\n";
0133 for (size_t j = 0; j < size; j++) out << SP;
0134 out << "if (x > vmax) vmax = x;\n";
0135 for (size_t j = 0; j < size-1; j++) out << SP;
0136 out << "}\n";
0137
0138 for (size_t j = 0; j < size-1; j++) out << SP;
0139 out << "for (int i = 0; i < " << fShape[axis] << "; i++) {\n";
0140 for (size_t j = 0; j < size; j++) out << SP;
0141 out << "size_t id = index + i";
0142 if (stride[axis].GetVal() != "1") out << "*(" << stride[axis] << ")";
0143 out << ";\n";
0144 for (size_t j = 0; j < size; j++) out << SP;
0145 out << "tensor_" << fNY << "[id] = std::exp(tensor_" << fNX << "[id] - vmax);\n";
0146 for (size_t j = 0; j < size; j++) out << SP;
0147 out << "sum += tensor_" << fNY << "[id];\n";
0148 for (size_t j = 0; j < size-1; j++) out << SP;
0149 out << "}\n";
0150
0151 for (size_t j = 0; j < size-1; j++) out << SP;
0152 out << "for (int i = 0; i < " << fShape[axis] << "; i++) {\n";
0153 for (size_t j = 0; j < size; j++) out << SP;
0154 out << "tensor_" << fNY << "[index + i";
0155 if (stride[axis].GetVal() != "1") out << "*(" << stride[axis] << ")";
0156 out << "] /= sum;\n";
0157 for (size_t j = 0; j < size-1; j++) out << SP;
0158 out << "}\n";
0159
0160 for (int i = static_cast<int>(k) - 1; i >= 0; i--) {
0161 for (int j = 0; j < i; j++) out << SP;
0162 out << "}\n";
0163 }
0164 }
0165 return out.str();
0166 }
0167 std::vector<std::string> GetStdLibs() override { return { std::string("cmath") }; }
0168 };
0169
0170 }
0171 }
0172 }
0173
0174 #endif