File indexing completed on 2025-01-18 10:11:04
0001 #ifndef TMVA_SOFIE_RFUNCTION_MLP
0002 #define TMVA_SOFIE_RFUNCTION_MLP
0003
#include "TMVA/RFunction.hxx"

#include <stdexcept>
#include <string>
#include <vector>
0007
0008
0009 namespace TMVA {
0010 namespace Experimental {
0011 namespace SOFIE {
0012
/// Activation functions selectable for an RFunction_MLP (see its
/// `activation_function` constructor argument).
/// NOTE(review): `Invalid` looks like a sentinel for an unsupported
/// activation — confirm against the code that consumes this enum.
enum class Activation {
   RELU    = 0,
   Invalid = 1
};
0017
0018 class RFunction_MLP: public RFunction_Update {
0019 private:
0020 Int_t fNumLayers;
0021 Activation fActivationFunction;
0022 bool fActivateFinal;
0023 std::vector<std::string> fKernelTensors;
0024 std::vector<std::string> fBiasTensors;
0025
0026 public:
0027 virtual ~RFunction_MLP() {}
0028 RFunction_MLP(FunctionTarget target, Int_t numLayers, Activation activation_function=Activation::RELU, bool activate_final=false, GraphType gType=GraphType::GNN);
0029
0030 void Initialize();
0031
0032 void AddLayerNormalization(int axis, float epsilon, size_t stashType, const std::string &nameX,
0033 const std::string &nameScale, const std::string &nameB, const std::string &nameY);
0034
0035 void AddInitializedTensors(const std::vector<std::vector<std::string>>& initialized_tensors) {
0036 fKernelTensors = initialized_tensors[0];
0037 fBiasTensors = initialized_tensors[1];
0038 }
0039 };
0040
0041 }
0042 }
0043 }
0044
0045 #endif