Back to home page

EIC code displayed by LXR

 
 

    


Warning: file /include/root/TMVA/ROperator_Einsum.hxx was not indexed, or was modified since the last indexing (in which case cross-reference links may be missing, inaccurate, or erroneous).

0001 #ifndef TMVA_SOFIE_ROperator_Einsum
0002 #define TMVA_SOFIE_ROperator_Einsum
0003 
0004 #include "TMVA/SOFIE_common.hxx"
0005 #include "TMVA/ROperator.hxx"
0006 #include "TMVA/RModel.hxx"
0007 
0008 #include <sstream>
0009 #include <cassert>
0010 
0011 namespace TMVA{
0012 namespace Experimental{
0013 namespace SOFIE{
0014 
0015 
0016 
0017 template<typename T>
0018 class ROperator_Einsum final : public ROperator{
0019 private:
0020 
0021    bool fIsInputBoolTensor = false;
0022 
0023 
0024    std::vector<std::string> fNInputs;
0025    std::string fNY;
0026 
0027    std::vector<std::string> fInputLabels;
0028    std::string fOutputLabels;
0029    std::string fSumLabels;  // string containing the reducing labels
0030    std::string fGemmType;
0031 
0032    std::vector<int> fSumDims; // dimension of the labels we use to perform summing
0033 
0034    std::vector<std::vector<size_t>> fShapeInputs;
0035    std::vector<size_t> fShapeY;
0036 
0037 
0038 
0039 
0040 public:
0041    ROperator_Einsum(){}
0042    ROperator_Einsum(const std::string & equation, const std::vector<std::string> & namesX, const std::string & nameY):
0043       fNInputs(namesX.size()), fNY(UTILITY::Clean_name(nameY))
0044    {
0045       for (size_t i = 0; i < namesX.size(); i++)
0046          fNInputs[i] = UTILITY::Clean_name(namesX[i]);
0047 
0048       // parse teh equations to find labels
0049       if (!ParseEquation(equation))
0050          throw std::runtime_error("TMVA SOFIE Einsum Op: Error parsing the equation " + equation);
0051 
0052       fInputTensorNames.resize(fNInputs.size());
0053       std::transform(fNInputs.begin(), fNInputs.end(), fInputTensorNames.begin(),
0054                   [](const std::string& s) -> std::string_view { return s; });
0055       fOutputTensorNames = { fNY };
0056    }
0057 
0058    bool ParseEquation(const std::string & input_equation) {
0059       std::string eq (input_equation);
0060       // remove blank spaces
0061       eq.erase(std::remove(eq.begin(), eq.end(), ' '), eq.end());
0062       // look for '->'  finding the first occurrence
0063       std::string target("->");
0064       size_t pos = eq.find(target);
0065       if (pos == std::string::npos) {
0066          std::cout << "'->' not found in the equation." << std::endl;
0067          return false;
0068       }
0069       // Substring before the target
0070       std::string inputStr = eq.substr(0, pos);
0071       // Substring after the target
0072       std::string outputStr = eq.substr(pos + target.length());
0073 
0074       // look now for the group of labels separated by "," in the inputs
0075       size_t start = 0;
0076       size_t pos1 = 0;
0077       // Extract labels separated by commas
0078       while ((pos1 = inputStr.find(',', start)) != std::string::npos) {
0079          std::string labels = inputStr.substr(start, pos1 - start);
0080          fInputLabels.push_back(labels);
0081          start = pos1 + 1; // Move past the comma
0082       }
0083       // Add the last label (after the final comma)
0084       fInputLabels.push_back(inputStr.substr(start));
0085 
0086       // check if labels are ok and do not contain alphanumeric characters
0087       auto checkLabel = [](const std::string & label) {
0088          for (char c : label) {
0089             if (!std::isalnum(c)) {
0090                std::cout << "Wrong tensor label " << label << std::endl;
0091                return false;
0092             }
0093          }
0094          // empty label is OK , is a scalar
0095          return true;
0096       };
0097       for (auto & label : fInputLabels) {
0098          if (!checkLabel(label)) return false;
0099       }
0100       if (!checkLabel(outputStr)) {
0101          std::cout << "invalid output label" << std::endl;
0102          return false;
0103       }
0104       fOutputLabels = outputStr;
0105 
0106       if (fInputLabels.size() != fNInputs.size()) {
0107          std::cout << "Invalid number of input labels found " << fInputLabels.size() << " for #inputs = " << fNInputs.size() << std::endl;
0108          return false;
0109       }
0110       // ignore for the time being broadcasting, empty output label and other features
0111       return true;
0112    }
0113 
0114    // type of output given input
0115    std::vector<ETensorType> TypeInference(std::vector<ETensorType> input) override {
0116       return input;
0117    }
0118 
0119    // shape of output tensors given input tensors
0120    std::vector<std::vector<size_t>> ShapeInference(std::vector<std::vector<size_t>> input) override {
0121       // assume now inputs have same shape (no broadcasting)
0122       auto ret = std::vector<std::vector<size_t>>(1, input[0]); // return vector size 1 with first input
0123       return ret;
0124    }
0125 
0126    void Initialize(RModel& model) override {
0127       // input must be a graph input, or already initialized intermediate tensor
0128       size_t i = 0;
0129       std::map<char, int> labelsMap;
0130       for ( auto & name : fNInputs) {
0131          if (!model.CheckIfTensorAlreadyExist(name))
0132             throw std::runtime_error(std::string("TMVA SOFIE Einsum Op Input Tensor ") + name + "is not found in model");
0133 
0134          // if (model.IsDynamicTensor(name) || model.IsDimInputTensor(name) ) {
0135          //    // not yet supported
0136          // } else {
0137          auto shape = model.GetTensorShape(name);
0138          fShapeInputs.push_back(shape);
0139          //}
0140          // fill the label maps
0141          std::string labels = fInputLabels[i];
0142          for (size_t j = 0; j < shape.size(); j++) {
0143             if (j >= labels.length()) {
0144                throw std::runtime_error(std::string("TMVA SOFIE Einsum Op Input Tensor has invalid label or shape ") + labels + " " + ConvertShapeToString(shape));
0145             }
0146             labelsMap[labels[j]] = shape[j];
0147          }
0148          i++;
0149       }
0150       // get output shape from label maps
0151       for (char l : fOutputLabels) {
0152          if (labelsMap.count(l) == 0)
0153             throw std::runtime_error(std::string("TMVA SOFIE Einsum Op : output label ") + std::string(&l) + " is not present in inputs");
0154          fShapeY.push_back(labelsMap[l]);
0155       }
0156       // we need to get the labels we are going to sum
0157       // these are the labels not present in the output
0158       fSumLabels = "";
0159       fSumDims.clear();
0160       for (auto & l : labelsMap) {
0161          if (fOutputLabels.find(l.first) == std::string::npos) {
0162             fSumLabels += l.first;
0163             fSumDims.push_back(l.second);
0164          }
0165       }
0166 
0167       // check if we can use MatMul for EinSum
0168       // need to have one sum labels in the last 2 and have the first in common
0169       if (fNInputs.size() == 2 && fSumDims.size() == 1 && fShapeInputs[0].size() >=2 && fShapeInputs[1].size() >= 2 ) {
0170          // find positions of dum labels
0171          char l = fSumLabels[0];
0172          size_t pos1 = fInputLabels[0].find(l);
0173          size_t pos2 = fInputLabels[1].find(l);
0174          // check if summing is done in the last 2 indices of tensor
0175 
0176          if (pos1 == fInputLabels[0].length() - 1 && pos2 == fInputLabels[1].length() - 2)
0177             fGemmType = "nn";
0178          else if (pos1 == fInputLabels[0].length() - 2 && pos2 == fInputLabels[1].length() - 2)
0179             fGemmType = "tn";
0180          else if (pos1 == fInputLabels[0].length() - 1 && pos2 == fInputLabels[1].length() - 1)
0181             fGemmType = "nt";
0182          else if (pos1 == fInputLabels[0].length() - 2 && pos2 == fInputLabels[1].length() - 1)
0183             fGemmType = "tt";
0184          else
0185             fGemmType = "";
0186       }
0187 
0188       model.AddIntermediateTensor(fNY, model.GetTensorType(fNInputs[0]), fShapeY);
0189 
0190       if (model.Verbose()) {
0191          std::cout << "Einsum op ";
0192          for (i = 0; i < fNInputs.size(); i++) {
0193             if (i > 0) std::cout << ", ";
0194             std::cout << fNInputs[i] << " " << ConvertShapeToString(fShapeInputs[i]) << " " << fInputLabels[i];
0195          }
0196          std::cout << " --> " << fNY << "  " << ConvertShapeToString(fShapeY) << "  " << fOutputLabels << std::endl;
0197       }
0198 
0199    }
0200 
0201    std::string GenerateInitCode() override {
0202       std::stringstream out;
0203       return out.str();
0204    }
0205 
0206    std::string Generate(std::string opName) override {
0207 
0208       if (fIsOutputConstant) return "";
0209 
0210       opName = "op_" + opName;
0211 
0212       if (fShapeY.size() != fOutputLabels.length()) {
0213          throw std::runtime_error("TMVA SOFIE Einsum Op called to Generate without being initialized first");
0214       }
0215 
0216       // function to write compute expression index from strides
0217       auto tensorIndex = [](const std::vector<size_t> & stride, const std::string & labels) {
0218          std::stringstream strst;
0219          int dims = labels.length();
0220          // scalar case
0221          if (dims == 0) return std::string("0");
0222          assert (dims == (int) stride.size());
0223          for (int i = 0; i < dims-1; i++) {
0224             strst << stride[i] << "*" << std::string{labels[i]} << " + ";
0225          }
0226          strst << std::string{labels[dims-1]};
0227          return strst.str();
0228       };
0229 
0230       std::stringstream out;
0231       out << SP << "\n//-------- Einsum   \n";
0232 
0233       auto outputStride = UTILITY::ComputeStrideFromShape(fShapeY);
0234 
0235       // loops on the output indices  i0,....iN
0236       if (fGemmType.empty()) {
0237       int outDims = fShapeY.size();
0238       int inDims = fSumLabels.length();
0239       assert(outDims == int(fOutputLabels.size()));
0240       assert(inDims == int(fSumDims.size()));
0241       for (int i = 0; i < outDims; i++) {
0242          for (int j = 0; j < i; j++) out << SP;
0243          std::string l {fOutputLabels[i]};
0244          out << "for (int " << l << " = 0; " << l << " < " << fShapeY[i] << "; " << l << "++) {\n";
0245       }
0246       // reset to zero output tensor
0247       std::string outputIndex = tensorIndex(outputStride,fOutputLabels);
0248 
0249       for (int j = 0; j < outDims; j++) out << SP;
0250       out << "tensor_" << fNY << "[" << outputIndex << "] = 0;\n";
0251       // loop on remaining indices where we perform the sum
0252       for (int i = 0; i < inDims; i++) {
0253          for (int j = 0; j < outDims + i; j++) out << SP;
0254          std::string l {fSumLabels[i]};
0255          out << "for (int " << l << " = 0; " << l << " < " << fSumDims[i] << "; " << l << "++) {\n";
0256       }
0257       for (int j = 0; j < outDims+inDims; j++) out << SP;
0258       // tensor_out[outId] += t_in_0[ind0] * t_in1[ind1] *....
0259       out << "tensor_" << fNY << "[" << outputIndex << "] +=\n";
0260       for (size_t k = 0; k < fNInputs.size(); k++) {
0261          auto inputStride = UTILITY::ComputeStrideFromShape(fShapeInputs[k]);
0262          std::string inputIndex = tensorIndex(inputStride,fInputLabels[k]);
0263          for (int j = 0; j < outDims+inDims; j++) out << SP;
0264          out << SP << "tensor_" << fNInputs[k] << "[" << inputIndex << "]";
0265          if (fNInputs.size() > 1 && k < fNInputs.size() -1) out << " *\n";
0266       }
0267       out << ";\n";
0268 
0269       // end loops on all indices i0,....iN
0270       for (int i = outDims+inDims-1; i >= 0; i--) {
0271          for (int j = 0; j < i; j++) out << SP;
0272          out << "}\n";
0273       }
0274 
0275 
0276       } else {
0277          // case we use Gemm
0278          out << SP << "// implementing Einsum using MatMul   \n";
0279          // note A is second input and B first one - due to transpose of Fortran rep.
0280          out << SP << "char " << opName << "_transA = '" << fGemmType[0] << "';\n";
0281          out << SP << "char " << opName << "_transB = '" << fGemmType[1] << "';\n";
0282          // need to consider case A and B have dim > 2 (for MatMul)
0283          int64_t dimA = fShapeInputs[0].size();
0284          int64_t dimB = fShapeInputs[1].size();
0285 
0286          auto m = (fGemmType[0] == 't') ? fShapeInputs[0][dimA-1] : fShapeInputs[0][dimA-2];
0287          auto n = (fGemmType[1] == 't') ? fShapeInputs[1][dimB-2] : fShapeInputs[1][dimB-1];
0288          auto k = (fGemmType[0] == 't') ? fShapeInputs[0][dimA-2] : fShapeInputs[0][dimA-1];
0289 
0290          out << SP << "int " << opName << "_m = " << m << ";\n";
0291          out << SP << "int " << opName << "_n = " << n << ";\n";
0292          out << SP << "int " << opName << "_k = " << k << ";\n";
0293          out << SP << "float " << opName << "_alpha = 1.0;\n";
0294          out << SP << "float " << opName << "_beta = 0.0;\n";
0295          out << SP << "int " << opName << "_lda = " << ((fGemmType[0] == 't') ? m : k) << ";\n";
0296          out << SP << "int " << opName << "_ldb = " << ((fGemmType[1] == 't') ? k : n) << ";\n";
0297 
0298          auto inputStrideA = UTILITY::ComputeStrideFromShape(fShapeInputs[0]);
0299          auto inputStrideB = UTILITY::ComputeStrideFromShape(fShapeInputs[1]);
0300 
0301          int stackDims = fShapeY.size()-2;
0302          for (int i = 0; i < stackDims; i++) {
0303             for (int j = 0; j < i; j++) out << SP;
0304             std::string l {fOutputLabels[i]};
0305             out << "for (int " << l << " = 0; " << l << " < " << fShapeY[i] << "; " << l << "++) {\n";
0306          }
0307          auto tensorOffset = [](const std::vector<size_t> & stride, const std::string & labels) {
0308             std::stringstream strst;
0309             int dims = labels.length()-2;
0310             // scalar case
0311             if (dims == 0) return std::string("0");
0312             assert (dims +2 == (int) stride.size());
0313             for (int i = 0; i < dims; i++) {
0314                strst << stride[i] << "*" << std::string{labels[i]};
0315                if (i < dims-1) strst << " + ";
0316             }
0317             return strst.str();
0318          };
0319          // only float type supported
0320          out << SP << "BLAS::sgemm_(&" << opName << "_transB, &" << opName << "_transA, &" << opName
0321              << "_n, &" << opName << "_m, &" << opName << "_k, &" << opName << "_alpha, "
0322              << "&tensor_" << fNInputs[1] << "[" << tensorOffset(inputStrideB, fInputLabels[1])
0323              << "], &" << opName << "_ldb, "
0324              << "&tensor_" << fNInputs[0] << "[" << tensorOffset(inputStrideA, fInputLabels[0]  ) << "], &" << opName << "_lda, &" << opName << "_beta, "
0325              << "&tensor_" << fNY << "[" << tensorOffset(outputStride,fOutputLabels) << "],  &" << opName << "_n);\n";
0326 
0327 
0328          for (int i = stackDims-1; i >= 0; i--) {
0329             for (int j = 0; j < i; j++) out << SP;
0330             out << "}\n";
0331          }
0332 
0333       }
0334 
0335 
0336       return out.str();
0337    }
0338 
0339    std::vector<std::string> GetBlasRoutines() override {
0340       return { std::string("Gemm") };
0341    }
0342 };
0343 
0344 }//SOFIE
0345 }//Experimental
0346 }//TMVA
0347 
0348 
0349 #endif //TMVA_SOFIE_ROperator_Einsum