// @(#)root/tmva/tmva/dnn:$Id$
// Author: Vladimir Ilievski

/**********************************************************************************
 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis       *
 * Package: TMVA                                                                  *
 * Class  : TReshapeLayer                                                         *
 *                                                                                *
 * Description:                                                                   *
 *      Reshape Deep Neural Network Layer                                         *
 *                                                                                *
 * Authors (alphabetical):                                                        *
 *      Vladimir Ilievski      <ilievski.vladimir@live.com>  - CERN, Switzerland  *
 *                                                                                *
 * Copyright (c) 2005-2015:                                                       *
 *      CERN, Switzerland                                                         *
 *      U. of Victoria, Canada                                                    *
 *      MPI-K Heidelberg, Germany                                                 *
 *      U. of Bonn, Germany                                                       *
 *                                                                                *
 * Redistribution and use in source and binary forms, with or without             *
 * modification, are permitted according to the terms listed in LICENSE           *
 * (see tmva/doc/LICENSE)                                                         *
 **********************************************************************************/

#ifndef TMVA_DNN_RESHAPELAYER
#define TMVA_DNN_RESHAPELAYER

#include "TMatrix.h"

#include "TMVA/DNN/GeneralLayer.h"
#include "TMVA/DNN/Functions.h"

#include <iostream>

namespace TMVA {
namespace DNN {

template <typename Architecture_t>
class TReshapeLayer : public VGeneralLayer<Architecture_t> {
public:
   using Tensor_t = typename Architecture_t::Tensor_t;
   using Matrix_t = typename Architecture_t::Matrix_t;
   using Scalar_t = typename Architecture_t::Scalar_t;

private:
   bool fFlattening; ///< Whether the layer is doing flattening

public:
   /*! Constructor */
   TReshapeLayer(size_t BatchSize, size_t InputDepth, size_t InputHeight, size_t InputWidth, size_t Depth,
                 size_t Height, size_t Width, size_t OutputNSlices, size_t OutputNRows, size_t OutputNCols,
                 bool Flattening);

   /*! Copy the reshape layer provided as a pointer */
   TReshapeLayer(TReshapeLayer<Architecture_t> *layer);

   /*! Copy Constructor */
   TReshapeLayer(const TReshapeLayer &);

   /*! Destructor. */
   ~TReshapeLayer();

   /*! The input must be in 3D tensor form with the different matrices
    *  corresponding to different events in the batch. It transforms the
    *  input matrices. */
   void Forward(Tensor_t &input, bool applyDropout = false);

   void Backward(Tensor_t &gradients_backward, const Tensor_t &activations_backward);
   //              Tensor_t &inp1, Tensor_t &inp2);

   /*! Prints the info about the layer. */
   void Print() const;

   /*! Writes the information and the weights about the layer in an XML node. */
   virtual void AddWeightsXMLTo(void *parent);

   /*! Read the information and the weights about the layer from XML node. */
   virtual void ReadWeightsFromXML(void *parent);

   /*! Returns whether this layer flattens its input, which the dense layer needs
    *  to know: B x D1 x D2 --> 1 x B x (D1 * D2) */
   bool isFlattening() const { return fFlattening; }
};
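
// Illustrative usage sketch (not part of the original header): constructing a
// flattening reshape layer. The TCpu<Double_t> backend from
// "TMVA/DNN/Architectures/Cpu.h" and the parameter values below are assumptions,
// chosen so that the per-event shapes 3 x 8 x 8 and 1 x 1 x 192 hold the same
// number of elements, as required by the constructor check.
//
//    using Architecture_t = TMVA::DNN::TCpu<Double_t>;
//    TMVA::DNN::TReshapeLayer<Architecture_t> flatten(
//       32, 3, 8, 8,   // batch size and per-event input shape (depth, height, width)
//       1, 1, 192,     // per-event output shape (depth, height, width)
//       1, 32, 192,    // output tensor shape: 1 slice of batchSize rows x flattened columns
//       true);         // flattening
//
//    // flatten.Forward(inputTensor) would then fill the layer's 1 x 32 x 192 output tensor.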

//
//
//  The Reshape Layer Class - Implementation
//_________________________________________________________________________________________________
template <typename Architecture_t>
TReshapeLayer<Architecture_t>::TReshapeLayer(size_t batchSize, size_t inputDepth, size_t inputHeight, size_t inputWidth,
                                             size_t depth, size_t height, size_t width, size_t outputNSlices,
                                             size_t outputNRows, size_t outputNCols, bool flattening)
   : VGeneralLayer<Architecture_t>(batchSize, inputDepth, inputHeight, inputWidth, depth, height, width, 0, 0, 0, 0, 0,
                                   0, outputNSlices, outputNRows, outputNCols, EInitialization::kZero),
     fFlattening(flattening)
{
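   // A reshape must preserve the total number of elements per event, so the
   // per-event input volume (depth x height x width) has to match the per-event
   // output volume; otherwise only a warning is printed and the layer is left as is.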
   if (this->GetInputDepth() * this->GetInputHeight() * this->GetInputWidth() !=
       this->GetDepth() * this->GetHeight() * this->GetWidth()) {
      std::cout << "Reshape Dimensions not compatible \n"
                << this->GetInputDepth() << " x " << this->GetInputHeight() << " x " << this->GetInputWidth() << " --> "
                << this->GetDepth() << " x " << this->GetHeight() << " x " << this->GetWidth() << std::endl;
      return;
   }
}

//_________________________________________________________________________________________________
template <typename Architecture_t>
TReshapeLayer<Architecture_t>::TReshapeLayer(TReshapeLayer<Architecture_t> *layer)
   : VGeneralLayer<Architecture_t>(layer), fFlattening(layer->isFlattening())
{
}

//_________________________________________________________________________________________________
template <typename Architecture_t>
TReshapeLayer<Architecture_t>::TReshapeLayer(const TReshapeLayer &layer)
   : VGeneralLayer<Architecture_t>(layer), fFlattening(layer.fFlattening)
{
   // Nothing to do here.
}

//_________________________________________________________________________________________________
template <typename Architecture_t>
TReshapeLayer<Architecture_t>::~TReshapeLayer()
{
   // Nothing to do here.
}

//_________________________________________________________________________________________________
template <typename Architecture_t>
auto TReshapeLayer<Architecture_t>::Forward(Tensor_t &input, bool /*applyDropout*/) -> void
{
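   // Flattening packs each event of the input tensor into a single row of one output
   // matrix (B x D1 x D2 --> 1 x B x (D1 * D2)); deflattening applies the inverse
   // mapping. Both write into the pre-allocated output tensor of this layer.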
   if (fFlattening) {
      Architecture_t::Flatten(this->GetOutput(), input);
   } else {
      Architecture_t::Deflatten(this->GetOutput(), input);
   }
}
//_________________________________________________________________________________________________
template <typename Architecture_t>
auto TReshapeLayer<Architecture_t>::Backward(Tensor_t &gradients_backward, const Tensor_t &
                                             /*activations_backward*/) -> void
//                                             Tensor_t & /*inp1*/, Tensor_t &
//                                             /*inp2*/) -> void
{
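   // The backward pass applies the inverse reshape to the activation gradients:
   // a flattening layer deflattens them into gradients_backward, a deflattening
   // layer flattens them.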
   size_t size = gradients_backward.GetSize();
   // for the first layer the gradient tensor is empty - nothing to do
   if (size == 0) return;
   if (fFlattening) {
      // deflatten in backprop
      Architecture_t::Deflatten(gradients_backward, this->GetActivationGradients());
   } else {
      Architecture_t::Flatten(gradients_backward, this->GetActivationGradients());
   }
}

//_________________________________________________________________________________________________
template <typename Architecture_t>
auto TReshapeLayer<Architecture_t>::Print() const -> void
{
   std::cout << " RESHAPE Layer \t ";
   std::cout << "Input = ( " << this->GetInputDepth() << " , " << this->GetInputHeight() << " , " << this->GetInputWidth() << " ) ";
   if (this->GetOutput().GetSize() > 0) {
      std::cout << "\tOutput = ( " << this->GetOutput().GetFirstSize() << " , " << this->GetOutput().GetHSize() << " , " << this->GetOutput().GetWSize() << " ) ";
   }
   std::cout << std::endl;
}

template <typename Architecture_t>
auto TReshapeLayer<Architecture_t>::AddWeightsXMLTo(void *parent) -> void
{
   auto layerxml = gTools().xmlengine().NewChild(parent, nullptr, "ReshapeLayer");

   // write info for the reshape layer
   gTools().xmlengine().NewAttr(layerxml, nullptr, "Depth", gTools().StringFromInt(this->GetDepth()));
   gTools().xmlengine().NewAttr(layerxml, nullptr, "Height", gTools().StringFromInt(this->GetHeight()));
   gTools().xmlengine().NewAttr(layerxml, nullptr, "Width", gTools().StringFromInt(this->GetWidth()));
   gTools().xmlengine().NewAttr(layerxml, nullptr, "Flattening", gTools().StringFromInt(this->isFlattening()));
}
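// The node written above has roughly the form (attribute values are illustrative):
//    <ReshapeLayer Depth="1" Height="1" Width="192" Flattening="1"/>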

//______________________________________________________________________________
template <typename Architecture_t>
void TReshapeLayer<Architecture_t>::ReadWeightsFromXML(void * /*parent*/)
{
   // no info to read
}

} // namespace DNN
} // namespace TMVA

#endif