Warning, file /include/root/TMVA/ROperator_Einsum.hxx was not indexed
or was modified since last indexation (in which case cross-reference links may be missing, inaccurate or erroneous).
0001 #ifndef TMVA_SOFIE_ROperator_Einsum
0002 #define TMVA_SOFIE_ROperator_Einsum
0003
0004 #include "TMVA/SOFIE_common.hxx"
0005 #include "TMVA/ROperator.hxx"
0006 #include "TMVA/RModel.hxx"
0007
0008 #include <sstream>
0009 #include <cassert>
0010
0011 namespace TMVA{
0012 namespace Experimental{
0013 namespace SOFIE{
0014
0015
0016
0017 template<typename T>
0018 class ROperator_Einsum final : public ROperator{
0019 private:
0020
0021 bool fIsInputBoolTensor = false;
0022
0023
0024 std::vector<std::string> fNInputs;
0025 std::string fNY;
0026
0027 std::vector<std::string> fInputLabels;
0028 std::string fOutputLabels;
0029 std::string fSumLabels;
0030 std::string fGemmType;
0031
0032 std::vector<int> fSumDims;
0033
0034 std::vector<std::vector<size_t>> fShapeInputs;
0035 std::vector<size_t> fShapeY;
0036
0037
0038
0039
0040 public:
0041 ROperator_Einsum(){}
0042 ROperator_Einsum(const std::string & equation, const std::vector<std::string> & namesX, const std::string & nameY):
0043 fNInputs(namesX.size()), fNY(UTILITY::Clean_name(nameY))
0044 {
0045 for (size_t i = 0; i < namesX.size(); i++)
0046 fNInputs[i] = UTILITY::Clean_name(namesX[i]);
0047
0048
0049 if (!ParseEquation(equation))
0050 throw std::runtime_error("TMVA SOFIE Einsum Op: Error parsing the equation " + equation);
0051
0052 fInputTensorNames.resize(fNInputs.size());
0053 std::transform(fNInputs.begin(), fNInputs.end(), fInputTensorNames.begin(),
0054 [](const std::string& s) -> std::string_view { return s; });
0055 fOutputTensorNames = { fNY };
0056 }
0057
0058 bool ParseEquation(const std::string & input_equation) {
0059 std::string eq (input_equation);
0060
0061 eq.erase(std::remove(eq.begin(), eq.end(), ' '), eq.end());
0062
0063 std::string target("->");
0064 size_t pos = eq.find(target);
0065 if (pos == std::string::npos) {
0066 std::cout << "'->' not found in the equation." << std::endl;
0067 return false;
0068 }
0069
0070 std::string inputStr = eq.substr(0, pos);
0071
0072 std::string outputStr = eq.substr(pos + target.length());
0073
0074
0075 size_t start = 0;
0076 size_t pos1 = 0;
0077
0078 while ((pos1 = inputStr.find(',', start)) != std::string::npos) {
0079 std::string labels = inputStr.substr(start, pos1 - start);
0080 fInputLabels.push_back(labels);
0081 start = pos1 + 1;
0082 }
0083
0084 fInputLabels.push_back(inputStr.substr(start));
0085
0086
0087 auto checkLabel = [](const std::string & label) {
0088 for (char c : label) {
0089 if (!std::isalnum(c)) {
0090 std::cout << "Wrong tensor label " << label << std::endl;
0091 return false;
0092 }
0093 }
0094
0095 return true;
0096 };
0097 for (auto & label : fInputLabels) {
0098 if (!checkLabel(label)) return false;
0099 }
0100 if (!checkLabel(outputStr)) {
0101 std::cout << "invalid output label" << std::endl;
0102 return false;
0103 }
0104 fOutputLabels = outputStr;
0105
0106 if (fInputLabels.size() != fNInputs.size()) {
0107 std::cout << "Invalid number of input labels found " << fInputLabels.size() << " for #inputs = " << fNInputs.size() << std::endl;
0108 return false;
0109 }
0110
0111 return true;
0112 }
0113
0114
0115 std::vector<ETensorType> TypeInference(std::vector<ETensorType> input) override {
0116 return input;
0117 }
0118
0119
0120 std::vector<std::vector<size_t>> ShapeInference(std::vector<std::vector<size_t>> input) override {
0121
0122 auto ret = std::vector<std::vector<size_t>>(1, input[0]);
0123 return ret;
0124 }
0125
0126 void Initialize(RModel& model) override {
0127
0128 size_t i = 0;
0129 std::map<char, int> labelsMap;
0130 for ( auto & name : fNInputs) {
0131 if (!model.CheckIfTensorAlreadyExist(name))
0132 throw std::runtime_error(std::string("TMVA SOFIE Einsum Op Input Tensor ") + name + "is not found in model");
0133
0134
0135
0136
0137 auto shape = model.GetTensorShape(name);
0138 fShapeInputs.push_back(shape);
0139
0140
0141 std::string labels = fInputLabels[i];
0142 for (size_t j = 0; j < shape.size(); j++) {
0143 if (j >= labels.length()) {
0144 throw std::runtime_error(std::string("TMVA SOFIE Einsum Op Input Tensor has invalid label or shape ") + labels + " " + ConvertShapeToString(shape));
0145 }
0146 labelsMap[labels[j]] = shape[j];
0147 }
0148 i++;
0149 }
0150
0151 for (char l : fOutputLabels) {
0152 if (labelsMap.count(l) == 0)
0153 throw std::runtime_error(std::string("TMVA SOFIE Einsum Op : output label ") + std::string(&l) + " is not present in inputs");
0154 fShapeY.push_back(labelsMap[l]);
0155 }
0156
0157
0158 fSumLabels = "";
0159 fSumDims.clear();
0160 for (auto & l : labelsMap) {
0161 if (fOutputLabels.find(l.first) == std::string::npos) {
0162 fSumLabels += l.first;
0163 fSumDims.push_back(l.second);
0164 }
0165 }
0166
0167
0168
0169 if (fNInputs.size() == 2 && fSumDims.size() == 1 && fShapeInputs[0].size() >=2 && fShapeInputs[1].size() >= 2 ) {
0170
0171 char l = fSumLabels[0];
0172 size_t pos1 = fInputLabels[0].find(l);
0173 size_t pos2 = fInputLabels[1].find(l);
0174
0175
0176 if (pos1 == fInputLabels[0].length() - 1 && pos2 == fInputLabels[1].length() - 2)
0177 fGemmType = "nn";
0178 else if (pos1 == fInputLabels[0].length() - 2 && pos2 == fInputLabels[1].length() - 2)
0179 fGemmType = "tn";
0180 else if (pos1 == fInputLabels[0].length() - 1 && pos2 == fInputLabels[1].length() - 1)
0181 fGemmType = "nt";
0182 else if (pos1 == fInputLabels[0].length() - 2 && pos2 == fInputLabels[1].length() - 1)
0183 fGemmType = "tt";
0184 else
0185 fGemmType = "";
0186 }
0187
0188 model.AddIntermediateTensor(fNY, model.GetTensorType(fNInputs[0]), fShapeY);
0189
0190 if (model.Verbose()) {
0191 std::cout << "Einsum op ";
0192 for (i = 0; i < fNInputs.size(); i++) {
0193 if (i > 0) std::cout << ", ";
0194 std::cout << fNInputs[i] << " " << ConvertShapeToString(fShapeInputs[i]) << " " << fInputLabels[i];
0195 }
0196 std::cout << " --> " << fNY << " " << ConvertShapeToString(fShapeY) << " " << fOutputLabels << std::endl;
0197 }
0198
0199 }
0200
0201 std::string GenerateInitCode() override {
0202 std::stringstream out;
0203 return out.str();
0204 }
0205
0206 std::string Generate(std::string opName) override {
0207
0208 if (fIsOutputConstant) return "";
0209
0210 opName = "op_" + opName;
0211
0212 if (fShapeY.size() != fOutputLabels.length()) {
0213 throw std::runtime_error("TMVA SOFIE Einsum Op called to Generate without being initialized first");
0214 }
0215
0216
0217 auto tensorIndex = [](const std::vector<size_t> & stride, const std::string & labels) {
0218 std::stringstream strst;
0219 int dims = labels.length();
0220
0221 if (dims == 0) return std::string("0");
0222 assert (dims == (int) stride.size());
0223 for (int i = 0; i < dims-1; i++) {
0224 strst << stride[i] << "*" << std::string{labels[i]} << " + ";
0225 }
0226 strst << std::string{labels[dims-1]};
0227 return strst.str();
0228 };
0229
0230 std::stringstream out;
0231 out << SP << "\n//-------- Einsum \n";
0232
0233 auto outputStride = UTILITY::ComputeStrideFromShape(fShapeY);
0234
0235
0236 if (fGemmType.empty()) {
0237 int outDims = fShapeY.size();
0238 int inDims = fSumLabels.length();
0239 assert(outDims == int(fOutputLabels.size()));
0240 assert(inDims == int(fSumDims.size()));
0241 for (int i = 0; i < outDims; i++) {
0242 for (int j = 0; j < i; j++) out << SP;
0243 std::string l {fOutputLabels[i]};
0244 out << "for (int " << l << " = 0; " << l << " < " << fShapeY[i] << "; " << l << "++) {\n";
0245 }
0246
0247 std::string outputIndex = tensorIndex(outputStride,fOutputLabels);
0248
0249 for (int j = 0; j < outDims; j++) out << SP;
0250 out << "tensor_" << fNY << "[" << outputIndex << "] = 0;\n";
0251
0252 for (int i = 0; i < inDims; i++) {
0253 for (int j = 0; j < outDims + i; j++) out << SP;
0254 std::string l {fSumLabels[i]};
0255 out << "for (int " << l << " = 0; " << l << " < " << fSumDims[i] << "; " << l << "++) {\n";
0256 }
0257 for (int j = 0; j < outDims+inDims; j++) out << SP;
0258
0259 out << "tensor_" << fNY << "[" << outputIndex << "] +=\n";
0260 for (size_t k = 0; k < fNInputs.size(); k++) {
0261 auto inputStride = UTILITY::ComputeStrideFromShape(fShapeInputs[k]);
0262 std::string inputIndex = tensorIndex(inputStride,fInputLabels[k]);
0263 for (int j = 0; j < outDims+inDims; j++) out << SP;
0264 out << SP << "tensor_" << fNInputs[k] << "[" << inputIndex << "]";
0265 if (fNInputs.size() > 1 && k < fNInputs.size() -1) out << " *\n";
0266 }
0267 out << ";\n";
0268
0269
0270 for (int i = outDims+inDims-1; i >= 0; i--) {
0271 for (int j = 0; j < i; j++) out << SP;
0272 out << "}\n";
0273 }
0274
0275
0276 } else {
0277
0278 out << SP << "// implementing Einsum using MatMul \n";
0279
0280 out << SP << "char " << opName << "_transA = '" << fGemmType[0] << "';\n";
0281 out << SP << "char " << opName << "_transB = '" << fGemmType[1] << "';\n";
0282
0283 int64_t dimA = fShapeInputs[0].size();
0284 int64_t dimB = fShapeInputs[1].size();
0285
0286 auto m = (fGemmType[0] == 't') ? fShapeInputs[0][dimA-1] : fShapeInputs[0][dimA-2];
0287 auto n = (fGemmType[1] == 't') ? fShapeInputs[1][dimB-2] : fShapeInputs[1][dimB-1];
0288 auto k = (fGemmType[0] == 't') ? fShapeInputs[0][dimA-2] : fShapeInputs[0][dimA-1];
0289
0290 out << SP << "int " << opName << "_m = " << m << ";\n";
0291 out << SP << "int " << opName << "_n = " << n << ";\n";
0292 out << SP << "int " << opName << "_k = " << k << ";\n";
0293 out << SP << "float " << opName << "_alpha = 1.0;\n";
0294 out << SP << "float " << opName << "_beta = 0.0;\n";
0295 out << SP << "int " << opName << "_lda = " << ((fGemmType[0] == 't') ? m : k) << ";\n";
0296 out << SP << "int " << opName << "_ldb = " << ((fGemmType[1] == 't') ? k : n) << ";\n";
0297
0298 auto inputStrideA = UTILITY::ComputeStrideFromShape(fShapeInputs[0]);
0299 auto inputStrideB = UTILITY::ComputeStrideFromShape(fShapeInputs[1]);
0300
0301 int stackDims = fShapeY.size()-2;
0302 for (int i = 0; i < stackDims; i++) {
0303 for (int j = 0; j < i; j++) out << SP;
0304 std::string l {fOutputLabels[i]};
0305 out << "for (int " << l << " = 0; " << l << " < " << fShapeY[i] << "; " << l << "++) {\n";
0306 }
0307 auto tensorOffset = [](const std::vector<size_t> & stride, const std::string & labels) {
0308 std::stringstream strst;
0309 int dims = labels.length()-2;
0310
0311 if (dims == 0) return std::string("0");
0312 assert (dims +2 == (int) stride.size());
0313 for (int i = 0; i < dims; i++) {
0314 strst << stride[i] << "*" << std::string{labels[i]};
0315 if (i < dims-1) strst << " + ";
0316 }
0317 return strst.str();
0318 };
0319
0320 out << SP << "BLAS::sgemm_(&" << opName << "_transB, &" << opName << "_transA, &" << opName
0321 << "_n, &" << opName << "_m, &" << opName << "_k, &" << opName << "_alpha, "
0322 << "&tensor_" << fNInputs[1] << "[" << tensorOffset(inputStrideB, fInputLabels[1])
0323 << "], &" << opName << "_ldb, "
0324 << "&tensor_" << fNInputs[0] << "[" << tensorOffset(inputStrideA, fInputLabels[0] ) << "], &" << opName << "_lda, &" << opName << "_beta, "
0325 << "&tensor_" << fNY << "[" << tensorOffset(outputStride,fOutputLabels) << "], &" << opName << "_n);\n";
0326
0327
0328 for (int i = stackDims-1; i >= 0; i--) {
0329 for (int j = 0; j < i; j++) out << SP;
0330 out << "}\n";
0331 }
0332
0333 }
0334
0335
0336 return out.str();
0337 }
0338
0339 std::vector<std::string> GetBlasRoutines() override {
0340 return { std::string("Gemm") };
0341 }
0342 };
0343
0344 }
0345 }
0346 }
0347
0348
0349 #endif