File indexing completed on 2025-01-18 10:10:53
0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018
0019
0020
0021
0022
0023
0024
0025
0026
0027 #ifndef MAXPOOLLAYER_H_
0028 #define MAXPOOLLAYER_H_
0029
0030 #include "TMatrix.h"
0031
0032 #include "TMVA/DNN/CNN/ConvLayer.h"
0033 #include "TMVA/DNN/Functions.h"
0034 #include "TMVA/DNN/CNN/ContextHandles.h"
0035
0036 #include <iostream>
0037
0038 namespace TMVA {
0039 namespace DNN {
0040 namespace CNN {
0041
0042
0043
0044
0045
0046
0047
0048
0049
0050
0051
0052
0053
0054
0055
0056
0057
0058 template <typename Architecture_t>
0059 class TMaxPoolLayer : public VGeneralLayer<Architecture_t> {
0060
0061 public:
0062 using Tensor_t = typename Architecture_t::Tensor_t;
0063 using Matrix_t = typename Architecture_t::Matrix_t;
0064 using Scalar_t = typename Architecture_t::Scalar_t;
0065
0066 using LayerDescriptor_t = typename Architecture_t::PoolingDescriptor_t;
0067 using WeightsDescriptor_t = typename Architecture_t::EmptyDescriptor_t;
0068 using HelperDescriptor_t = typename Architecture_t::DropoutDescriptor_t;
0069
0070
0071 using AlgorithmForward_t = typename Architecture_t::AlgorithmForward_t;
0072 using AlgorithmBackward_t = typename Architecture_t::AlgorithmBackward_t;
0073 using AlgorithmHelper_t = typename Architecture_t::AlgorithmHelper_t;
0074
0075
0076 using AlgorithmDataType_t = typename Architecture_t::AlgorithmDataType_t;
0077
0078 using ReduceTensorDescriptor_t = typename Architecture_t::ReduceTensorDescriptor_t;
0079
0080 protected:
0081 size_t fFilterDepth;
0082 size_t fFilterHeight;
0083 size_t fFilterWidth;
0084
0085 size_t fStrideRows;
0086 size_t fStrideCols;
0087
0088 size_t fNLocalViewPixels;
0089 size_t fNLocalViews;
0090
0091 Scalar_t fDropoutProbability;
0092
0093 TDescriptors *fDescriptors = nullptr;
0094
0095 TWorkspace *fWorkspace = nullptr;
0096
0097 private:
0098 Tensor_t fIndexTensor;
0099
0100 void InitializeDescriptors();
0101 void ReleaseDescriptors();
0102 void InitializeWorkspace();
0103 void FreeWorkspace();
0104
0105 public:
0106
0107 TMaxPoolLayer(size_t BatchSize, size_t InputDepth, size_t InputHeight, size_t InputWidth, size_t FilterHeight,
0108 size_t FilterWidth, size_t StrideRows, size_t StrideCols, Scalar_t DropoutProbability);
0109
0110
0111 TMaxPoolLayer(TMaxPoolLayer<Architecture_t> *layer);
0112
0113
0114 TMaxPoolLayer(const TMaxPoolLayer &);
0115
0116
0117 virtual ~TMaxPoolLayer();
0118
0119
0120
0121
0122
0123
0124
0125 void Forward(Tensor_t &input, bool applyDropout = true);
0126
0127
0128
0129
0130
0131 void Backward(Tensor_t &gradients_backward, const Tensor_t &activations_backward);
0132
0133
0134
0135 virtual void AddWeightsXMLTo(void *parent);
0136
0137
0138 virtual void ReadWeightsFromXML(void *parent);
0139
0140
0141 void Print() const;
0142
0143
0144 size_t GetFilterDepth() const { return fFilterDepth; }
0145 size_t GetFilterHeight() const { return fFilterHeight; }
0146 size_t GetFilterWidth() const { return fFilterWidth; }
0147
0148 size_t GetStrideRows() const { return fStrideRows; }
0149 size_t GetStrideCols() const { return fStrideCols; }
0150
0151 size_t GetNLocalViews() const { return fNLocalViews; }
0152
0153 Scalar_t GetDropoutProbability() const { return fDropoutProbability; }
0154
0155 const Tensor_t & GetIndexTensor() const { return fIndexTensor; }
0156 Tensor_t & GetIndexTensor() { return fIndexTensor; }
0157
0158
0159 TDescriptors *GetDescriptors() { return fDescriptors; }
0160 const TDescriptors *GetDescriptors() const { return fDescriptors; }
0161
0162 TWorkspace *GetWorkspace() { return fWorkspace; }
0163 const TWorkspace *GetWorkspace() const { return fWorkspace; }
0164 };
0165
0166
/// Construct a max-pooling layer.
///
/// \param batchSize           batch size; forwarded to the base layer and used as the first
///                            dimension of fIndexTensor
/// \param inputDepth          input depth; also used as the layer depth and as fFilterDepth
/// \param inputHeight         input height
/// \param inputWidth          input width
/// \param filterHeight        pooling-window height
/// \param filterWidth         pooling-window width
/// \param strideRows          row step between window positions
/// \param strideCols          column step between window positions
/// \param dropoutProbability  dropout setting; exactly 1.0 disables the dropout step
template <typename Architecture_t>
TMaxPoolLayer<Architecture_t>::TMaxPoolLayer(size_t batchSize, size_t inputDepth, size_t inputHeight, size_t inputWidth,
                                             size_t filterHeight, size_t filterWidth, size_t strideRows,
                                             size_t strideCols, Scalar_t dropoutProbability)
   : VGeneralLayer<Architecture_t>(
        batchSize, inputDepth, inputHeight, inputWidth, inputDepth,
        // Spatial output size; the third argument of calculateDimension is presumably the
        // padding (0 here) -- verify against TConvLayer.
        TConvLayer<Architecture_t>::calculateDimension(inputHeight, filterHeight, 0, strideRows),
        TConvLayer<Architecture_t>::calculateDimension(inputWidth, filterWidth, 0, strideCols), 0, 0, 0, 0, 0,
        0, // presumably the (empty) weight/bias shape arguments -- verify against VGeneralLayer
        batchSize, inputDepth,
        TConvLayer<Architecture_t>::calculateNLocalViews(inputHeight, filterHeight, 0, strideRows, inputWidth,
                                                         filterWidth, 0, strideCols),
        EInitialization::kZero),
     fFilterDepth(inputDepth), fFilterHeight(filterHeight), fFilterWidth(filterWidth), fStrideRows(strideRows),
     fStrideCols(strideCols),
     fNLocalViews(TConvLayer<Architecture_t>::calculateNLocalViews(inputHeight, filterHeight, 0, strideRows,
                                                                   inputWidth, filterWidth, 0, strideCols)),
     // fIndexTensor is declared after fNLocalViews, so fNLocalViews is already set when it is read here.
     fDropoutProbability(dropoutProbability), fIndexTensor(batchSize, inputDepth, fNLocalViews)
{
   InitializeDescriptors();
   InitializeWorkspace();
}
0189
0190
0191 template <typename Architecture_t>
0192 TMaxPoolLayer<Architecture_t>::TMaxPoolLayer(TMaxPoolLayer<Architecture_t> *layer)
0193 : VGeneralLayer<Architecture_t>(layer), fFilterDepth(layer->GetFilterDepth()),
0194 fFilterHeight(layer->GetFilterHeight()), fFilterWidth(layer->GetFilterWidth()),
0195 fStrideRows(layer->GetStrideRows()), fStrideCols(layer->GetStrideCols()), fNLocalViews(layer->GetNLocalViews()),
0196 fDropoutProbability(layer->GetDropoutProbability()), fIndexTensor(layer->GetIndexTensor().GetShape())
0197 {
0198 InitializeDescriptors();
0199 InitializeWorkspace();
0200 }
0201
0202
0203 template <typename Architecture_t>
0204 TMaxPoolLayer<Architecture_t>::TMaxPoolLayer(const TMaxPoolLayer &layer)
0205 : VGeneralLayer<Architecture_t>(layer), fFilterDepth(layer.fFilterDepth), fFilterHeight(layer.fFilterHeight),
0206 fFilterWidth(layer.fFilterWidth), fStrideRows(layer.fStrideRows), fStrideCols(layer.fStrideCols),
0207 fNLocalViews(layer.fNLocalViews), fDropoutProbability(layer.fDropoutProbability),
0208 fIndexTensor(layer.GetIndexTensor().GetShape())
0209 {
0210 InitializeDescriptors();
0211 InitializeWorkspace();
0212 }
0213
0214
0215 template <typename Architecture_t>
0216 TMaxPoolLayer<Architecture_t>::~TMaxPoolLayer()
0217 {
0218 if (fDescriptors) {
0219 ReleaseDescriptors();
0220 delete fDescriptors;
0221 fDescriptors = nullptr;
0222 }
0223
0224 if (fWorkspace) {
0225 FreeWorkspace();
0226 delete fWorkspace;
0227 fWorkspace = nullptr;
0228 }
0229 }
0230
0231
0232 template <typename Architecture_t>
0233 auto TMaxPoolLayer<Architecture_t>::Forward(Tensor_t &input, bool applyDropout) -> void
0234 {
0235 if (applyDropout && (this->GetDropoutProbability() != 1.0)) {
0236 Architecture_t::DropoutForward(input, fDescriptors, fWorkspace, this->GetDropoutProbability());
0237 }
0238
0239 Architecture_t::Downsample(
0240 this->GetOutput(), fIndexTensor, input, (TCNNDescriptors<TMaxPoolLayer<Architecture_t>> &)*fDescriptors,
0241 (TCNNWorkspace<TMaxPoolLayer<Architecture_t>> &)*fWorkspace, this->GetInputHeight(), this->GetInputWidth(),
0242 this->GetFilterHeight(), this->GetFilterWidth(), this->GetStrideRows(), this->GetStrideCols());
0243 }
0244
0245
0246 template <typename Architecture_t>
0247 auto TMaxPoolLayer<Architecture_t>::Backward(Tensor_t &gradients_backward, const Tensor_t &activations_backward) -> void
0248
0249 {
0250
0251 if (this->GetDropoutProbability() != 1.0) {
0252 Architecture_t::DropoutBackward(this->GetActivationGradients(), fDescriptors, fWorkspace);
0253 }
0254 Architecture_t::MaxPoolLayerBackward(
0255 gradients_backward, this->GetActivationGradients(), fIndexTensor, activations_backward, this->GetOutput(),
0256 (TCNNDescriptors<TMaxPoolLayer<Architecture_t>> &)(*fDescriptors),
0257 (TCNNWorkspace<TMaxPoolLayer<Architecture_t>> &)(*fWorkspace), this->GetInputHeight(), this->GetInputWidth(),
0258 this->GetFilterHeight(), this->GetFilterWidth(), this->GetStrideRows(), this->GetStrideCols(),
0259 this->GetNLocalViews());
0260 }
0261
0262
0263 template <typename Architecture_t>
0264 auto TMaxPoolLayer<Architecture_t>::Print() const -> void
0265 {
0266 std::cout << " POOL Layer: \t";
0267 std::cout << "( W = " << this->GetWidth() << " , ";
0268 std::cout << " H = " << this->GetHeight() << " , ";
0269 std::cout << " D = " << this->GetDepth() << " ) ";
0270
0271 std::cout << "\t Filter ( W = " << this->GetFilterWidth() << " , ";
0272 std::cout << " H = " << this->GetFilterHeight() << " ) ";
0273
0274 if (this->GetOutput().GetSize() > 0) {
0275 std::cout << "\tOutput = ( " << this->GetOutput().GetFirstSize() << " , " << this->GetOutput().GetCSize()
0276 << " , " << this->GetOutput().GetHSize() << " , " << this->GetOutput().GetWSize() << " ) ";
0277 }
0278 std::cout << std::endl;
0279 }
0280
0281
0282 template <typename Architecture_t>
0283 void TMaxPoolLayer<Architecture_t>::AddWeightsXMLTo(void *parent)
0284 {
0285 auto layerxml = gTools().xmlengine().NewChild(parent, nullptr, "MaxPoolLayer");
0286
0287
0288 gTools().xmlengine().NewAttr(layerxml, nullptr, "FilterHeight", gTools().StringFromInt(this->GetFilterHeight()));
0289 gTools().xmlengine().NewAttr(layerxml, nullptr, "FilterWidth", gTools().StringFromInt(this->GetFilterWidth()));
0290 gTools().xmlengine().NewAttr(layerxml, nullptr, "StrideRows", gTools().StringFromInt(this->GetStrideRows()));
0291 gTools().xmlengine().NewAttr(layerxml, nullptr, "StrideCols", gTools().StringFromInt(this->GetStrideCols()));
0292
0293 }
0294
0295
0296 template <typename Architecture_t>
0297 void TMaxPoolLayer<Architecture_t>::ReadWeightsFromXML(void * )
0298 {
0299
0300 }
0301
0302
0303 template <typename Architecture_t>
0304 void TMaxPoolLayer<Architecture_t>::InitializeDescriptors() {
0305 Architecture_t::InitializePoolDescriptors(fDescriptors, this);
0306 }
0307
0308 template <typename Architecture_t>
0309 void TMaxPoolLayer<Architecture_t>::ReleaseDescriptors() {
0310 Architecture_t::ReleasePoolDescriptors(fDescriptors);
0311 }
0312
0313
0314 template <typename Architecture_t>
0315 void TMaxPoolLayer<Architecture_t>::InitializeWorkspace() {
0316
0317
0318
0319
0320
0321
0322
0323 }
0324
0325 template <typename Architecture_t>
0326 void TMaxPoolLayer<Architecture_t>::FreeWorkspace() {
0327
0328 }
0329
0330 }
0331 }
0332 }
0333
0334 #endif