File indexing completed on 2025-01-18 10:10:57
0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018
0019
0020
0021
0022
0023
0024
0025
0026
0027 #ifndef TMVA_DNN_RESHAPELAYER
0028 #define TMVA_DNN_RESHAPELAYER
0029
0030 #include "TMatrix.h"
0031
0032 #include "TMVA/DNN/GeneralLayer.h"
0033 #include "TMVA/DNN/Functions.h"
0034
0035 #include <iostream>
0036
0037 namespace TMVA {
0038 namespace DNN {
0039
0040 template <typename Architecture_t>
0041 class TReshapeLayer : public VGeneralLayer<Architecture_t> {
0042 public:
0043 using Tensor_t = typename Architecture_t::Tensor_t;
0044 using Matrix_t = typename Architecture_t::Matrix_t;
0045 using Scalar_t = typename Architecture_t::Scalar_t;
0046
0047 private:
0048 bool fFlattening;
0049
0050 public:
0051
0052 TReshapeLayer(size_t BatchSize, size_t InputDepth, size_t InputHeight, size_t InputWidth, size_t Depth,
0053 size_t Height, size_t Width, size_t OutputNSlices, size_t OutputNRows, size_t OutputNCols,
0054 bool Flattening);
0055
0056
0057 TReshapeLayer(TReshapeLayer<Architecture_t> *layer);
0058
0059
0060 TReshapeLayer(const TReshapeLayer &);
0061
0062
0063 ~TReshapeLayer();
0064
0065
0066
0067
0068 void Forward(Tensor_t &input, bool applyDropout = false);
0069
0070 void Backward(Tensor_t &gradients_backward, const Tensor_t &activations_backward);
0071
0072
0073
0074 void Print() const;
0075
0076
0077 virtual void AddWeightsXMLTo(void *parent);
0078
0079
0080 virtual void ReadWeightsFromXML(void *parent);
0081
0082
0083
0084
0085
0086 bool isFlattening() const { return fFlattening; }
0087 };
0088
0089
0090
0091
0092
0093 template <typename Architecture_t>
0094 TReshapeLayer<Architecture_t>::TReshapeLayer(size_t batchSize, size_t inputDepth, size_t inputHeight, size_t inputWidth,
0095 size_t depth, size_t height, size_t width, size_t outputNSlices,
0096 size_t outputNRows, size_t outputNCols, bool flattening)
0097 : VGeneralLayer<Architecture_t>(batchSize, inputDepth, inputHeight, inputWidth, depth, height, width, 0, 0, 0, 0, 0,
0098 0, outputNSlices, outputNRows, outputNCols, EInitialization::kZero),
0099 fFlattening(flattening)
0100 {
0101 if (this->GetInputDepth() * this->GetInputHeight() * this->GetInputWidth() !=
0102 this->GetDepth() * this->GetHeight() * this->GetWidth()) {
0103 std::cout << "Reshape Dimensions not compatible \n"
0104 << this->GetInputDepth() << " x " << this->GetInputHeight() << " x " << this->GetInputWidth() << " --> "
0105 << this->GetDepth() << " x " << this->GetHeight() << " x " << this->GetWidth() << std::endl;
0106 return;
0107 }
0108 }
0109
0110
0111 template <typename Architecture_t>
0112 TReshapeLayer<Architecture_t>::TReshapeLayer(TReshapeLayer<Architecture_t> *layer)
0113 : VGeneralLayer<Architecture_t>(layer), fFlattening(layer->isFlattening())
0114 {
0115 }
0116
0117
0118 template <typename Architecture_t>
0119 TReshapeLayer<Architecture_t>::TReshapeLayer(const TReshapeLayer &layer)
0120 : VGeneralLayer<Architecture_t>(layer), fFlattening(layer.fFlattening)
0121 {
0122
0123 }
0124
0125
0126 template <typename Architecture_t>
0127 TReshapeLayer<Architecture_t>::~TReshapeLayer()
0128 {
0129
0130 }
0131
0132
0133 template <typename Architecture_t>
0134 auto TReshapeLayer<Architecture_t>::Forward(Tensor_t &input, bool ) -> void
0135 {
0136 if (fFlattening) {
0137
0138 Architecture_t::Flatten(this->GetOutput(), input);
0139
0140 return;
0141 } else {
0142
0143 Architecture_t::Deflatten(this->GetOutput(), input);
0144 return;
0145 }
0146 }
0147
0148 template <typename Architecture_t>
0149 auto TReshapeLayer<Architecture_t>::Backward(Tensor_t &gradients_backward, const Tensor_t &
0150 ) -> void
0151
0152
0153 {
0154 size_t size = gradients_backward.GetSize();
0155
0156 if (size == 0) return;
0157 if (fFlattening) {
0158
0159 Architecture_t::Deflatten(gradients_backward, this->GetActivationGradients());
0160 return;
0161 } else {
0162 Architecture_t::Flatten(gradients_backward, this->GetActivationGradients() );
0163 return;
0164 }
0165 }
0166
0167
0168 template <typename Architecture_t>
0169 auto TReshapeLayer<Architecture_t>::Print() const -> void
0170 {
0171 std::cout << " RESHAPE Layer \t ";
0172 std::cout << "Input = ( " << this->GetInputDepth() << " , " << this->GetInputHeight() << " , " << this->GetInputWidth() << " ) ";
0173 if (this->GetOutput().GetSize() > 0) {
0174 std::cout << "\tOutput = ( " << this->GetOutput().GetFirstSize() << " , " << this->GetOutput().GetHSize() << " , " << this->GetOutput().GetWSize() << " ) ";
0175 }
0176 std::cout << std::endl;
0177 }
0178
0179 template <typename Architecture_t>
0180 auto TReshapeLayer<Architecture_t>::AddWeightsXMLTo(void *parent) -> void
0181 {
0182 auto layerxml = gTools().xmlengine().NewChild(parent, nullptr, "ReshapeLayer");
0183
0184
0185 gTools().xmlengine().NewAttr(layerxml, nullptr, "Depth", gTools().StringFromInt(this->GetDepth()));
0186 gTools().xmlengine().NewAttr(layerxml, nullptr, "Height", gTools().StringFromInt(this->GetHeight()));
0187 gTools().xmlengine().NewAttr(layerxml, nullptr, "Width", gTools().StringFromInt(this->GetWidth()));
0188 gTools().xmlengine().NewAttr(layerxml, nullptr, "Flattening", gTools().StringFromInt(this->isFlattening()));
0189
0190
0191 }
0192
0193
0194 template <typename Architecture_t>
0195 void TReshapeLayer<Architecture_t>::ReadWeightsFromXML(void * )
0196 {
0197
0198 }
0199
0200
0201
0202 }
0203 }
0204
0205 #endif