File indexing completed on 2025-01-18 10:11:11
0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018
0019
0020
0021
0022
0023
0024 #ifndef ROOT_TMVA_TNeuron
0025 #define ROOT_TMVA_TNeuron
0026
0027
0028
0029
0030
0031
0032
0033
0034
0035 #include <iostream>
0036
0037 #include "TString.h"
0038 #include "TObjArray.h"
0039 #include "TFormula.h"
0040
0041 #include "TMVA/TSynapse.h"
0042 #include "TMVA/TActivation.h"
0043 #include "TMVA/Types.h"
0044
0045 namespace TMVA {
0046
0047 class TNeuronInput;
0048
0049 class TNeuron : public TObject {
0050
0051 public:
0052
0053 TNeuron();
0054 virtual ~TNeuron();
0055
0056
0057 void ForceValue(Double_t value);
0058
0059
0060 void CalculateValue();
0061
0062
0063 void CalculateActivationValue();
0064
0065
0066 void CalculateDelta();
0067
0068
0069 void SetActivationEqn(TActivation* activation);
0070
0071
0072 void SetInputCalculator(TNeuronInput* calculator);
0073
0074
0075 void AddPreLink(TSynapse* pre);
0076
0077
0078 void AddPostLink(TSynapse* post);
0079
0080
0081 void DeletePreLinks();
0082
0083
0084 void SetError(Double_t error);
0085
0086
0087
0088 void UpdateSynapsesBatch();
0089
0090
0091 void UpdateSynapsesSequential();
0092
0093
0094
0095 void AdjustSynapseWeights();
0096
0097
0098 void InitSynapseDeltas();
0099
0100
0101 void PrintActivationEqn();
0102
0103
0104 Double_t GetValue() const { return fValue; }
0105 Double_t GetActivationValue() const { return fActivationValue; }
0106 Double_t GetDelta() const { return fDelta; }
0107 Double_t GetDEDw() const { return fDEDw; }
0108 Int_t NumPreLinks() const { return NumLinks(fLinksIn); }
0109 Int_t NumPostLinks() const { return NumLinks(fLinksOut); }
0110 TSynapse* PreLinkAt ( Int_t index ) const { return (TSynapse*)fLinksIn->At(index); }
0111 TSynapse* PostLinkAt( Int_t index ) const { return (TSynapse*)fLinksOut->At(index); }
0112 void SetInputNeuron() { NullifyLinks(fLinksIn); }
0113 void SetOutputNeuron() { NullifyLinks(fLinksOut); }
0114 void SetBiasNeuron() { NullifyLinks(fLinksIn); }
0115 void SetDEDw( Double_t DEDw ) { fDEDw = DEDw; }
0116 Bool_t IsInputNeuron() const { return fLinksIn == nullptr; }
0117 Bool_t IsOutputNeuron() const { return fLinksOut == nullptr; }
0118 void PrintPreLinks() const { PrintLinks(fLinksIn); return; }
0119 void PrintPostLinks() const { PrintLinks(fLinksOut); return; }
0120
0121 virtual void Print(Option_t* = "") const {
0122 std::cout << fValue << std::endl;
0123
0124 }
0125
0126 private:
0127
0128
0129 void InitNeuron();
0130 void DeleteLinksArray( TObjArray*& links );
0131 void PrintLinks ( TObjArray* links ) const;
0132 void PrintMessage ( EMsgType, TString message );
0133
0134
0135 Int_t NumLinks(TObjArray* links) const {
0136 if (links == nullptr) return 0;
0137 else return links->GetEntriesFast();
0138 }
0139 void NullifyLinks(TObjArray*& links) {
0140 if (links != nullptr) { delete links; links = nullptr; }
0141 }
0142
0143
0144 TObjArray* fLinksIn;
0145 TObjArray* fLinksOut;
0146 Double_t fValue;
0147 Double_t fActivationValue;
0148 Double_t fDelta;
0149 Double_t fDEDw;
0150 Double_t fError;
0151 Bool_t fForcedValue;
0152 TActivation* fActivation;
0153 TNeuronInput* fInputCalculator;
0154
0155 MsgLogger& Log() const;
0156
0157 ClassDef(TNeuron,0);
0158 };
0159
0160 }
0161
0162 #endif