Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-01-18 10:11:11

0001 // @(#)root/tmva $Id$
0002 // Author: Matt Jachowski
0003 
0004 /**********************************************************************************
0005  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis       *
0006  * Package: TMVA                                                                  *
0007  * Class  : TMVA::TNeuron                                                         *
0008  *                                             *
0009  *                                                                                *
0010  * Description:                                                                   *
0011  *      Neuron class to be used in MethodANNBase and its derivatives.             *
0012  *                                                                                *
0013  * Authors (alphabetical):                                                        *
0014  *      Matt Jachowski  <jachowski@stanford.edu> - Stanford University, USA       *
0015  *                                                                                *
0016  * Copyright (c) 2005:                                                            *
0017  *      CERN, Switzerland                                                         *
0018  *                                                                                *
0019  * Redistribution and use in source and binary forms, with or without             *
0020  * modification, are permitted according to the terms listed in LICENSE           *
0021  * (see tmva/doc/LICENSE)                                          *
0022  **********************************************************************************/
0023 
0024 #ifndef ROOT_TMVA_TNeuron
0025 #define ROOT_TMVA_TNeuron
0026 
0027 //////////////////////////////////////////////////////////////////////////
0028 //                                                                      //
0029 // TNeuron                                                              //
0030 //                                                                      //
0031 // Neuron used by derivatives of MethodANNBase                          //
0032 //                                                                      //
0033 //////////////////////////////////////////////////////////////////////////
0034 
0035 #include <iostream>
0036 
0037 #include "TString.h"
0038 #include "TObjArray.h"
0039 #include "TFormula.h"
0040 
0041 #include "TMVA/TSynapse.h"
0042 #include "TMVA/TActivation.h"
0043 #include "TMVA/Types.h"
0044 
0045 namespace TMVA {
0046 
0047    class TNeuronInput;
0048 
0049    class TNeuron : public TObject {
0050 
0051    public:
0052 
0053       TNeuron();
0054       virtual ~TNeuron();
0055 
0056       // force the input value
0057       void ForceValue(Double_t value);
0058 
0059       // calculate the input value
0060       void CalculateValue();
0061 
0062       // calculate the activation value
0063       void CalculateActivationValue();
0064 
0065       // calculate the error field of the neuron
0066       void CalculateDelta();
0067 
0068       // set the activation function
0069       void SetActivationEqn(TActivation* activation);
0070 
0071       // set the input calculator
0072       void SetInputCalculator(TNeuronInput* calculator);
0073 
0074       // add a synapse as a pre-link
0075       void AddPreLink(TSynapse* pre);
0076 
0077       // add a synapse as a post-link
0078       void AddPostLink(TSynapse* post);
0079 
0080       // delete all pre-links
0081       void DeletePreLinks();
0082 
0083       // set the error
0084       void SetError(Double_t error);
0085 
0086       // update the error fields of all pre-synapses, batch mode
0087       // to actually update the weights, call adjust synapse weights
0088       void UpdateSynapsesBatch();
0089 
0090       // update the error fields and weights of all pre-synapses, sequential mode
0091       void UpdateSynapsesSequential();
0092 
0093       // update the weights of the all pre-synapses, batch mode
0094       //(call UpdateSynapsesBatch first)
0095       void AdjustSynapseWeights();
0096 
0097       // explicitly initialize error fields of pre-synapses, batch mode
0098       void InitSynapseDeltas();
0099 
0100       // print activation equation, for debugging
0101       void PrintActivationEqn();
0102 
0103       // inlined functions
0104       Double_t  GetValue() const                { return fValue;                          }
0105       Double_t  GetActivationValue() const      { return fActivationValue;                }
0106       Double_t  GetDelta() const                { return fDelta;                          }
0107       Double_t  GetDEDw() const                 { return fDEDw;                           }
0108       Int_t     NumPreLinks() const             { return NumLinks(fLinksIn);              }
0109       Int_t     NumPostLinks() const            { return NumLinks(fLinksOut);             }
0110       TSynapse* PreLinkAt ( Int_t index ) const { return (TSynapse*)fLinksIn->At(index);  }
0111       TSynapse* PostLinkAt( Int_t index ) const { return (TSynapse*)fLinksOut->At(index); }
0112       void      SetInputNeuron()                { NullifyLinks(fLinksIn);                 }
0113       void      SetOutputNeuron()               { NullifyLinks(fLinksOut);                }
0114       void      SetBiasNeuron()                 { NullifyLinks(fLinksIn);                 }
0115       void      SetDEDw( Double_t DEDw )        { fDEDw = DEDw;                           }
0116       Bool_t    IsInputNeuron() const           { return fLinksIn == nullptr;             }
0117       Bool_t    IsOutputNeuron() const          { return fLinksOut == nullptr;            }
0118       void      PrintPreLinks() const           { PrintLinks(fLinksIn); return;           }
0119       void      PrintPostLinks() const          { PrintLinks(fLinksOut); return;          }
0120 
0121       virtual void Print(Option_t* = "") const {
0122          std::cout << fValue << std::endl;
0123          //PrintPreLinks(); PrintPostLinks();
0124       }
0125 
0126    private:
0127 
0128       // private helper functions
0129       void InitNeuron();
0130       void DeleteLinksArray( TObjArray*& links );
0131       void PrintLinks      ( TObjArray* links ) const;
0132       void PrintMessage    ( EMsgType, TString message );
0133 
0134       // inlined helper functions
0135       Int_t NumLinks(TObjArray* links) const {
0136          if (links == nullptr) return 0;
0137          else return links->GetEntriesFast();
0138       }
0139       void NullifyLinks(TObjArray*& links) {
0140          if (links != nullptr) { delete links; links = nullptr; }
0141       }
0142 
0143       // private member variables
0144       TObjArray*    fLinksIn;                 ///< array of input synapses
0145       TObjArray*    fLinksOut;                ///< array of output synapses
0146       Double_t      fValue;                   ///< input value
0147       Double_t      fActivationValue;         ///< activation/output value
0148       Double_t      fDelta;                   ///< error field of neuron
0149       Double_t      fDEDw;                    ///< sum of all deltas
0150       Double_t      fError;                   ///< error, only set for output neurons
0151       Bool_t        fForcedValue;             ///< flag for forced input value
0152       TActivation*  fActivation;              ///< activation equation
0153       TNeuronInput* fInputCalculator;         ///< input calculator
0154 
0155       MsgLogger& Log() const;
0156 
0157       ClassDef(TNeuron,0); // Neuron class used by MethodANNBase derivative ANNs
0158    };
0159 
0160 } // namespace TMVA
0161 
0162 #endif