File indexing completed on 2025-09-17 09:14:56
0001 #ifndef TMVA_SOFIE_ROPERATOR_Swish
0002 #define TMVA_SOFIE_ROPERATOR_Swish
0003
0004 #include "TMVA/SOFIE_common.hxx"
0005 #include "TMVA/ROperator.hxx"
0006 #include "TMVA/RModel.hxx"
0007
0008 #include <sstream>
0009
0010 namespace TMVA{
0011 namespace Experimental{
0012 namespace SOFIE{
0013
0014 template <typename T>
0015 class ROperator_Swish final : public ROperator
0016 {
0017
0018 private:
0019
0020 std::string fNX;
0021 std::string fNY;
0022 std::vector<size_t> fShape;
0023
0024 public:
0025 ROperator_Swish(){}
0026 ROperator_Swish(std::string nameX, std::string nameY):
0027 fNX(UTILITY::Clean_name(nameX)), fNY(UTILITY::Clean_name(nameY)){
0028 fInputTensorNames = { fNX };
0029 fOutputTensorNames = { fNY };
0030 }
0031
0032 std::vector<ETensorType> TypeInference(std::vector<ETensorType> input) override {
0033 return input;
0034 }
0035
0036 std::vector<std::vector<size_t>> ShapeInference(std::vector<std::vector<size_t>> input) override {
0037 auto ret = input;
0038 return ret;
0039 }
0040
0041 void Initialize(RModel& model) override {
0042 if (model.CheckIfTensorAlreadyExist(fNX) == false){
0043 throw std::runtime_error("TMVA SOFIE Swish Op Input Tensor is not found in model");
0044 }
0045 fShape = model.GetTensorShape(fNX);
0046 model.AddIntermediateTensor(fNY, model.GetTensorType(fNX), fShape);
0047 }
0048
0049
0050 std::string Generate(std::string OpName) override {
0051 OpName = "op_" + OpName;
0052 if (fShape.empty()){
0053 throw std::runtime_error("TMVA SOFIE Operator Swish called to Generate without being initialized first");
0054 }
0055 std::stringstream out;
0056 int length = 1;
0057 for(auto& i: fShape){
0058 length *= i;
0059 }
0060 out << "\t" << "for (int id = 0; id < " << length << " ; id++){\n";
0061 out << "\t\t" << "tensor_" << fNY << "[id] = tensor_" << fNX <<"[id] / (1 + std::exp( - tensor_" << fNX << "[id]));\n";
0062 out << "\t}\n";
0063 return out.str();
0064 }
0065
0066 std::vector<std::string> GetStdLibs() override { return { std::string("cmath") };}
0067 };
0068
0069 }
0070 }
0071 }
0072
0073
0074 #endif