Back to home page

EIC code displayed by LXR

 
 

    


Warning, file /include/root/TMVA/ROperator_Pad.hxx was not indexed or was modified since last indexation (in which case cross-reference links may be missing, inaccurate or erroneous).

0001 #ifndef TMVA_SOFIE_ROPERATOR_Pad
0002 #define TMVA_SOFIE_ROPERATOR_Pad
0003 
0004 #include "TMVA/SOFIE_common.hxx"
0005 #include "TMVA/ROperator.hxx"
0006 #include "TMVA/RModel.hxx"
0007 
0008 #include <sstream>
0009 
0010 namespace TMVA{
0011 namespace Experimental{
0012 namespace SOFIE{
0013 
0014 template <typename T>
0015 class ROperator_Pad final : public ROperator
0016 {
0017 public:
0018    enum EMode { kConstant, kReflect, kEdge, kWrap };
0019 private:
0020 
0021    std::string fNX;
0022    std::string fNP;
0023    std::string fNCV;
0024    std::string fNAX;
0025    std::string fNY;
0026    T fConstantValue;
0027    EMode fMode;
0028    std::vector<size_t> fInputShape;
0029    std::vector<size_t> fOutputShape;
0030    std::vector<std::pair<int64_t, int64_t>> fPads;
0031 
0032 public:
0033 
0034    ROperator_Pad(){}
0035    ROperator_Pad(const std::string & nameX, const std::string & nameP,  const std::string & nameCV,
0036                  const std::string & nameAX, const std::string & nameY, const std::string & mode) :
0037       fNX(UTILITY::Clean_name(nameX)), fNP(UTILITY::Clean_name(nameP)),
0038       fNCV(UTILITY::Clean_name(nameCV)), fNAX(UTILITY::Clean_name(nameAX)),
0039       fNY(UTILITY::Clean_name(nameY))
0040       {
0041          fMode = kConstant;
0042          if (mode == "constant")
0043             fMode = kConstant;
0044          else if (mode == "reflect")
0045             fMode = kReflect;
0046          else if (mode == "edge")
0047             fMode = kEdge;
0048          else if (mode == "wrap")
0049             fMode = kWrap;
0050          
0051          fInputTensorNames = { fNX };
0052          fOutputTensorNames = { fNY };
0053       }
0054 
0055    std::vector<ETensorType> TypeInference(std::vector<ETensorType> input) override {
0056       return input;
0057    }
0058 
0059    std::vector<std::vector<size_t>> ShapeInference(std::vector<std::vector<size_t>> input) override {
0060       auto ret = input; //suggest copy to compiler
0061       return ret;
0062    }
0063 
0064    void Initialize(RModel& model) override {
0065       if (model.CheckIfTensorAlreadyExist(fNX) == false){   //input must be a graph input, or already initialized intermediate tensor
0066          throw std::runtime_error("TMVA SOFIE Pad Op Input Tensor is not found in model");
0067       }
0068 
0069       fInputShape = model.GetTensorShape(fNX);
0070 
0071       if (fMode != EMode::kConstant) {
0072          throw std::runtime_error("TMVA SOFIE Pad Op supports now only Constant mode");
0073       }
0074 
0075       // get pads data
0076       int64_t * padsData = nullptr;
0077       if (model.IsInitializedTensor(fNP)) {
0078          padsData = static_cast<int64_t*>(model.GetInitializedTensorData(fNP).get());
0079       } else {
0080          throw std::runtime_error("TMVA SOFIE Pad Op supports now only initialized Pads data");
0081       }
0082       // get constant value
0083       fConstantValue = 0;
0084       if (!fNCV.empty()) {
0085          if (model.IsInitializedTensor(fNCV)) {
0086             T * cData = static_cast<T*>(model.GetInitializedTensorData(fNCV).get());
0087             fConstantValue = cData[0];
0088          } else {
0089             throw std::runtime_error("TMVA SOFIE Pad Op supports now only initialized Constant Value  data");
0090          }
0091       }
0092       std::vector<int64_t> axes;
0093       if (!fNAX.empty()) {
0094          if (model.IsInitializedTensor(fNAX)) {
0095             auto shape = model.GetTensorShape(fNAX);
0096             // it should be a 1D tensor
0097             size_t nax = shape[0];
0098             // switch types
0099             if (model.GetTensorType(fNAX) == ETensorType::INT64) {
0100                auto data = static_cast<int64_t*>(model.GetInitializedTensorData(fNAX).get());
0101                axes = std::vector<int64_t>(data, data + nax);
0102             } else if (model.GetTensorType(fNAX) == ETensorType::INT32) {
0103                auto data = static_cast<int32_t*>(model.GetInitializedTensorData(fNAX).get());
0104                axes.resize(nax);
0105                for (size_t i = 0; i < nax; i++)
0106                   axes[i] = data[i];
0107             }  else {
0108                throw std::runtime_error("TMVA SOFIE Pad Op invalid input Axes type");
0109             }
0110          } else {
0111             throw std::runtime_error("TMVA SOFIE Pad Op supports now only initialized Axes data");
0112          }
0113       }
0114 
0115 
0116       fOutputShape = fInputShape;
0117       size_t axesSize = axes.size();
0118       if (axesSize == 0) {
0119          for (size_t i = 0; i < fInputShape.size(); i++) {
0120             axes.push_back(i);
0121          }
0122          axesSize = fInputShape.size();
0123       }
0124       fPads.resize(fInputShape.size());
0125       for (size_t i = 0; i < fInputShape.size(); i++) {
0126          if (axes[i] < 0) axes[i] += fInputShape.size();
0127          if (axes[i] == int64_t(i)) {
0128             fPads[i].first = padsData[i];
0129             fPads[i].second = padsData[axesSize + i];
0130             int64_t outDim = static_cast<int64_t>(fOutputShape[i]) + fPads[i].first + fPads[i].second;
0131             if (outDim < 0)
0132                throw std::runtime_error("TMVA SOFIE Pad Op : invalid Pads values");
0133             fOutputShape[i] = outDim;
0134          }
0135       }
0136 
0137       model.AddIntermediateTensor(fNY, model.GetTensorType(fNX), fOutputShape);
0138 
0139       if (model.Verbose()) {
0140          std::cout << "initializing Pad operator with pads ..  : ";
0141          for (auto & p : fPads)
0142             std::cout << "{ " << p.first << " , " << p.second << "} ";
0143          std::cout << std::endl;
0144          std::cout <<  "Pad: " << fNX << " " << ConvertShapeToString(fInputShape) << " -> " << fNY << " with shape " << ConvertShapeToString(fOutputShape)
0145                   << std::endl;
0146       }
0147 
0148    }
0149 
0150 
0151    std::string Generate(std::string OpName) override {
0152       OpName = "op_" + OpName;
0153       if (fOutputShape.empty()){
0154          throw std::runtime_error("TMVA SOFIE Operator Pad called to Generate without being initialized first");
0155       }
0156       std::stringstream out;
0157       auto inputStride = UTILITY::ComputeStrideFromShape(fInputShape);
0158       auto outStride = UTILITY::ComputeStrideFromShape(fOutputShape);
0159       out << "\n//------ Pad\n";
0160       // fill first output tensor with the constant values
0161       int length = ConvertShapeToLength(fOutputShape);
0162       int dims = fOutputShape.size();
0163       out << "std::fill(tensor_" << fNY << ", tensor_" << fNY << " + " << length << ","
0164           << fConstantValue << ");\n";
0165 
0166       // copy now data from input tensor in output ones
0167       for (int i = 0; i < dims; i++) {
0168          for (int j = 1; j < i; j++) out << SP;
0169          out << "for (int id" << i << " = 0; id" << i << " < " << fInputShape[i] << "; id"
0170              << i << "++) {\n";
0171       }
0172       // compute index from strides
0173       //linear_index = i_1 * stride[0] + i_2 * stride[1] + ... + i_N * stride[N-1]
0174       for (int j = 0; j < dims; j++) out << SP;
0175       out << "tensor_" << fNY << "[";
0176       for (int i = 0; i < dims; i++) {
0177          out << "(id" << i;
0178          if (fPads[i].first != 0) out << " + " << fPads[i].first;
0179          out << ")";
0180          if (i < dims-1) out << " * " << outStride[i] << " + ";
0181       }
0182       out << "] =\n     tensor_" << fNX << "[";
0183       for (int i = 0; i < dims; i++) {
0184          out << "id" << i;
0185          if (i < dims-1) out << " * " << inputStride[i] << " + ";
0186       }
0187       out << "];\n";
0188       for (int i = dims-1; i >= 0; i--) {
0189          for (int j = 1; j < i; j++) out << SP;
0190          out << "}\n";
0191       }
0192 
0193       return out.str();
0194    }
0195 
0196 };
0197 
0198 }//SOFIE
0199 }//Experimental
0200 }//TMVA
0201 
0202 
0203 #endif //TMVA_SOFIE_ROPERATOR_Pad