Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-01-18 10:11:11

0001 // @(#)root/tmva $Id$
0002 // Author: Matt Jachowski
0003 
0004 /**********************************************************************************
0005  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis       *
0006  * Package: TMVA                                                                  *
0007  * Class  : TMVA::TSynapse                                                        *
0008  *                                                                                *
0009  *                                                                                *
0010  * Description:                                                                   *
0011  *      Synapse class for use in derivatives of MethodANNBase                     *
0012  *                                                                                *
0013  * Authors (alphabetical):                                                        *
0014  *      Matt Jachowski  <jachowski@stanford.edu> - Stanford University, USA       *
0015  *                                                                                *
0016  * Copyright (c) 2005:                                                            *
0017  *      CERN, Switzerland                                                         *
0018  *                                                                                *
0019  * Redistribution and use in source and binary forms, with or without             *
0020  * modification, are permitted according to the terms listed in LICENSE           *
0021  * (see tmva/doc/LICENSE)                                                         *
0022  **********************************************************************************/
0023 
0024 #ifndef ROOT_TMVA_TSynapse
0025 #define ROOT_TMVA_TSynapse
0026 
0027 //////////////////////////////////////////////////////////////////////////
0028 //                                                                      //
0029 // TSynapse                                                             //
0030 //                                                                      //
0031 // Synapse used by derivatives of MethodANNBase                         //
0032 //                                                                      //
0033 //////////////////////////////////////////////////////////////////////////
0034 
0035 #include "TObject.h"
0036 
0037 namespace TMVA {
0038 
0039    class TNeuron;
0040    class MsgLogger;
0041 
0042    class TSynapse : public TObject {
0043 
0044    public:
0045 
0046       TSynapse();
0047       virtual ~TSynapse();
0048 
0049       // set the weight of the synapse
0050       void SetWeight(Double_t weight);
0051 
0052       // get the weight of the synapse
0053       Double_t GetWeight()                 { return fWeight;         }
0054 
0055       // set the learning rate
0056       void SetLearningRate(Double_t rate)  { fLearnRate = rate;      }
0057 
0058       // get the learning rate
0059       Double_t GetLearningRate()           { return fLearnRate;      }
0060 
0061       // decay the learning rate
0062       void DecayLearningRate(Double_t rate){ fLearnRate *= (1-rate); }
0063 
0064       // set the pre-neuron
0065       void SetPreNeuron(TNeuron* pre)      { fPreNeuron = pre;       }
0066 
0067       // set the post-neuron
0068       void SetPostNeuron(TNeuron* post)    { fPostNeuron = post;     }
0069 
0070       // get the weighted output of the pre-neuron
0071       Double_t GetWeightedValue();
0072 
0073       // get the weighted error field of the post-neuron
0074       Double_t GetWeightedDelta();
0075 
0076       // force the synapse to adjust its weight according to its error field
0077       void AdjustWeight();
0078 
0079       // calculate the error field of the synapse
0080       void CalculateDelta();
0081 
0082       // initialize the error field of the synapse to 0
0083       void InitDelta()           { fDelta = 0.0; fCount = 0; }
0084 
0085       void SetDEDw(Double_t DEDw)              { fDEDw = DEDw;           }
0086       Double_t GetDEDw()                       { return fDEDw;           }
0087       Double_t GetDelta()                      { return fDelta;          }
0088 
0089    private:
0090 
0091       Double_t fWeight;            ///< weight of the synapse
0092       Double_t fLearnRate;         ///< learning rate parameter
0093       Double_t fDelta;             ///< local error field
0094       Double_t fDEDw;              ///< sum of deltas
0095       Int_t    fCount;             ///< number of updates contributing to error field
0096       TNeuron* fPreNeuron;         ///< pointer to pre-neuron
0097       TNeuron* fPostNeuron;        ///< pointer to post-neuron
0098 
0099       MsgLogger& Log() const;
0100 
0101       ClassDef(TSynapse,0); // Synapse class used by MethodANNBase and derivatives
0102    };
0103 
0104 } // namespace TMVA
0105 
0106 #endif