File indexing completed on 2025-01-18 10:11:08
0001 #ifndef TMVA_SOFIE_ROPERATOR_Softmax
0002 #define TMVA_SOFIE_ROPERATOR_Softmax
0003
0004 #include "TMVA/SOFIE_common.hxx"
0005 #include "TMVA/ROperator.hxx"
0006 #include "TMVA/RModel.hxx"
0007
0008 #include <sstream>
0009
0010 namespace TMVA {
0011 namespace Experimental {
0012 namespace SOFIE {
0013
0014 template <typename T>
0015 class ROperator_Softmax final : public ROperator {
0016
0017 private:
0018 int64_t fAttrAxis;
0019
0020 std::string fNX;
0021 std::string fNY;
0022 std::vector<size_t> fShape;
0023
0024 std::string fType;
0025
0026 public:
0027 ROperator_Softmax() {}
0028 ROperator_Softmax(int64_t attr_axis, std::string nameX, std::string nameY)
0029 : fAttrAxis(attr_axis), fNX(UTILITY::Clean_name(nameX)), fNY(UTILITY::Clean_name(nameY))
0030 {
0031 }
0032
0033 std::vector<ETensorType> TypeInference(std::vector<ETensorType> input) { return input; }
0034
0035 std::vector<std::vector<size_t>> ShapeInference(std::vector<std::vector<size_t>> input)
0036 {
0037 auto ret = input;
0038 return ret;
0039 }
0040
0041 void Initialize(RModel &model)
0042 {
0043 if (model.CheckIfTensorAlreadyExist(fNX) ==
0044 false) {
0045 throw std::runtime_error("TMVA SOFIE Softmax Op Input Tensor is not found in model");
0046 }
0047 fShape = model.GetTensorShape(fNX);
0048 model.AddIntermediateTensor(fNY, model.GetTensorType(fNX), fShape);
0049 fType = ConvertTypeToString(model.GetTensorType(fNX));
0050 }
0051
0052 std::string Generate(std::string OpName)
0053 {
0054 OpName = "op_" + OpName;
0055 if (fShape.empty()) {
0056 throw std::runtime_error("TMVA SOFIE Operator Softmax called to Generate without being initialized first");
0057 }
0058 std::stringstream out;
0059 size_t size = fShape.size();
0060 size_t length = ConvertShapeToLength(fShape);
0061 size_t axis = fAttrAxis < 0 ? size + fAttrAxis : fAttrAxis;
0062 out << "\n" << SP << "//------ SOFTMAX\n";
0063 if (size == 1) {
0064 out << SP << fType << " sum = 0.0;\n";
0065 out << SP << "for (size_t i = 0; i < " << length << " ; i++){\n";
0066 out << SP << SP << "tensor_" << fNY << "[i] = std::exp(tensor_" << fNX << "[i]);\n";
0067 out << SP << SP << "sum += tensor_" << fNY << "[i];\n";
0068 out << SP << "}\n";
0069 out << SP << "for (size_t i = 0; i < " << length << " ; i++){\n";
0070 out << SP << SP << "tensor_" << fNY << "[i] /= sum;\n";
0071 out << SP << "}\n";
0072 } else {
0073 size_t batch = fShape[0];
0074 size_t channel = fShape[1];
0075 size_t width = (size > 2) ? fShape[size - 1] : 1;
0076 size_t height = (size > 3) ? fShape[size - 2] : 1;
0077 size_t depth = (size > 4) ? fShape[size - 3] : 1;
0078 size_t hStride = width;
0079 size_t dStride = height * width;
0080 size_t cStride = depth * dStride;
0081 size_t bStride = channel * cStride;
0082
0083 size_t N = 0;
0084 size_t iStride = 0;
0085 if (axis == 0) {
0086 N = batch;
0087 iStride = bStride;
0088 } else if (axis == 1) {
0089 N = channel;
0090 iStride = cStride;
0091 } else if (axis == size - 1) {
0092 N = width;
0093 iStride = 1;
0094 } else if (size > 3 && axis == size - 2) {
0095 N = height;
0096 iStride = hStride;
0097 } else if (size == 5 && axis == size - 3) {
0098 N = depth;
0099 iStride = dStride;
0100 } else {
0101 throw
0102 std::runtime_error("TMVA::SOFIE - Softmax operator along the axis "
0103 + std::to_string(fAttrAxis) + " with " + std::to_string(size)
0104 + "d input tensor not supported.");
0105 }
0106
0107 bool notBatch = axis != 0;
0108 bool notChannel = axis != 1;
0109 bool notDepth = (size == 5 && axis != 2);
0110 bool notHeight = (size == 5 && axis != 3) || (size == 4 && axis != 2);
0111 bool notWidth = (size == 5 && axis != 4) || (size == 4 && axis != 3) || (size == 3 && axis != 2);
0112
0113 if (notBatch) {
0114 out << SP << "for (size_t n = 0; n < " << batch << " ; n++){\n";
0115 }
0116 if (notChannel) {
0117 out << SP << SP << "for (size_t c = 0; c < " << channel << " ; c++){\n";
0118 }
0119 if (notDepth) {
0120 out << SP << SP << "for (size_t d = 0; d < " << depth << " ; d++){\n";
0121 }
0122 if (notHeight) {
0123 out << SP << SP << "for (size_t h = 0; h < " << height << " ; h++){\n";
0124 }
0125 if (notWidth) {
0126 out << SP << SP << "for (size_t w = 0; w < " << width << " ; w++){\n";
0127 }
0128 out << SP << SP << SP << fType << " sum = 0.;\n";
0129 out << SP << SP << SP << "size_t index = 0";
0130 if (notBatch) {
0131 out << "+ n * " << bStride;
0132 }
0133 if (notChannel) {
0134 out << "+ c * " << cStride;
0135 }
0136 if (notDepth) {
0137 out << "+ d * " << dStride;
0138 }
0139 if (notHeight) {
0140 out << "+ h * " << hStride;
0141 }
0142 if (notWidth) {
0143 out << " + w";
0144 }
0145 out << ";\n";
0146
0147 out << SP << SP << SP << "for (size_t i = 0; i < " << N << "; i++) {\n";
0148 out << SP << SP << SP << SP << "tensor_" << fNY << "[index + i*" << iStride << "] = std::exp(tensor_" << fNX
0149 << "[index + i*" << iStride << "]);\n";
0150 out << SP << SP << SP << SP << "sum += tensor_" << fNY << "[index + i*" << iStride << "];\n";
0151 out << SP << SP << SP << "}\n";
0152 out << SP << SP << SP << "for (size_t i = 0; i < " << N << "; i++) {\n";
0153 out << SP << SP << SP << SP << "tensor_" << fNY << "[index + i*" << iStride << "] /= sum;\n";
0154 out << SP << SP << SP << "}\n";
0155 if (notWidth) {
0156 out << SP << SP << "}\n";
0157 }
0158 if (notHeight) {
0159 out << SP << SP << "}\n";
0160 }
0161 if (notDepth) {
0162 out << SP << SP << "}\n";
0163 }
0164 if (notChannel) {
0165 out << SP << SP << "}\n";
0166 }
0167 if (notBatch) {
0168 out << SP << "}\n";
0169 }
0170 }
0171 return out.str();
0172 }
0173 };
0174
0175 }
0176 }
0177 }
0178
0179 #endif