Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-01-18 10:10:53

0001 // @(#)root/tmva/tmva/dnn:$Id$
0002 // Author: Joana Niermann 23/07/19
0003 
0004 /*************************************************************************
0005  * Copyright (C) 2019, Joana Niermann                                    *
0006  * All rights reserved.                                                  *
0007  *                                                                       *
0008  * For the licensing terms see $ROOTSYS/LICENSE.                         *
0009  * For the list of contributors see $ROOTSYS/README/CREDITS.             *
0010  *************************************************************************/
0011 
0012 ///////////////////////////////////////////////////////////////////
0013 // Definition of the TCudnn architecture class, which provides   //
0014 // a wrapping of the low-level functionality for neural networks //
0015 // in the cuDNN library.                                         //
0016 ///////////////////////////////////////////////////////////////////
0017 
0018 #ifndef TMVA_DNN_ARCHITECTURES_CUDNN
0019 #define TMVA_DNN_ARCHITECTURES_CUDNN
0020 
0021 #include "RConfigure.h"   // for definition of R__HAS_CUDNN
0022 
0023 #ifndef R__HAS_CUDNN
0024 #error This file can be compiled only when cudnn is available in ROOT
0025 #else
0026 
0027 #include "cudnn.h"
0028 
0029 #include "TMVA/DNN/Functions.h"
0030 #include "TMVA/DNN/CNN/ContextHandles.h"
0031 //#include "TMVA/DNN/CNN/Descriptors.h"
0032 #include "TMVA/DNN/BatchNormLayer.h"
0033 #include "TMVA/DNN/CNN/ConvLayer.h"
0034 #include "TMVA/DNN/CNN/MaxPoolLayer.h"
0035 #include "TMVA/DNN/RNN/RNNLayer.h"
0036 #include "TMVA/DNN/RNN/LSTMLayer.h"
0037 #include "TMVA/DNN/RNN/GRULayer.h"
0038 
0039 
0040 #include "Cuda/CudaBuffers.h"
0041 #include "Cuda/CudaTensor.h"
0042 #include "TMVA/DNN/TensorDataLoader.h"
0043 #include <utility>
0044 #include <vector>
0045 #include <string>
0046 
0047 #include "TMVA/DNN/Architectures/Cuda.h"
0048 
0049 class TRandom;
0050 
0051 namespace TMVA
0052 {
0053 namespace DNN
0054 {
0055 
/// Stateless placeholder used wherever a descriptor type is required by the
/// interface but no cuDNN descriptor is needed (see EmptyDescriptor_t, and
/// RNNDataDescriptor_t when CUDNN_VERSION < 8000).
struct TCudnnEmptyDescriptor {
};
0057 
0058 
0059 /** The TCudnn architecture class.
0060  *
0061  * Low-level interface class for CUDA computing architectures using the cuDNN
0062  * library as backend. Contains as public types the declaration of the scalar,
0063  * matrix and buffer types for this architecture, as well as the remaining
0064  * functions in the low-level interface in the form of static members.
0065  */
0066 template<typename AFloat = Float_t>
0067 class TCudnn
0068 {
0069 private:
0070    static TRandom * fgRandomGen;
0071 public:
0072 
   // Scalar, matrix/tensor and buffer types of this architecture.
   // Note that Matrix_t and Tensor_t are the same type (TCudaTensor<AFloat>).
   using Scalar_t       = AFloat;
   using Matrix_t       = TCudaTensor<AFloat>;
   using Tensor_t       = TCudaTensor<AFloat>;
   using DeviceBuffer_t = TCudaDeviceBuffer<AFloat>;
   using HostBuffer_t   = TCudaHostBuffer<AFloat>;

   // The descriptors for the (tensor) data are held by the data classes (CudaTensor).
   // The aliases below map the generic descriptor / algorithm names used by the layers
   // onto the corresponding cuDNN handle and enum types.
   using ActivationDescriptor_t  = cudnnActivationDescriptor_t;
   using ConvolutionDescriptor_t = cudnnConvolutionDescriptor_t;
   using DropoutDescriptor_t     = cudnnDropoutDescriptor_t;
   using FilterDescriptor_t      = cudnnFilterDescriptor_t;
   //using OpTensorDescriptor_t    = cudnnOpTensorDescriptor_t;
   using PoolingDescriptor_t     = cudnnPoolingDescriptor_t;
   //using ReductionDescriptor_t   = cudnnReduceTensorDescriptor_t;
   // Convolution algorithm selectors: forward, backward-data and backward-filter.
   using AlgorithmForward_t      = cudnnConvolutionFwdAlgo_t;
   using AlgorithmBackward_t     = cudnnConvolutionBwdDataAlgo_t;
   using AlgorithmHelper_t       = cudnnConvolutionBwdFilterAlgo_t;
   using AlgorithmDataType_t     = cudnnDataType_t;
   using ReduceTensorDescriptor_t = cudnnReduceTensorDescriptor_t;
   using TensorDescriptor_t       = cudnnTensorDescriptor_t;
   using RecurrentDescriptor_t    = cudnnRNNDescriptor_t;
#if (CUDNN_VERSION >= 8000)
   using RNNDataDescriptor_t         = cudnnRNNDataDescriptor_t;
#else
   // cuDNN < 8.0: no cuDNN RNN data descriptor is aliased; an empty placeholder is used instead.
   using RNNDataDescriptor_t         = TCudnnEmptyDescriptor;
#endif
   using EmptyDescriptor_t       = TCudnnEmptyDescriptor;        // Used if a descriptor is not needed in a class
0100 
   // Layer types instantiated for this architecture, together with the descriptor
   // and workspace types associated with them.
   using BNormLayer_t            = TBatchNormLayer<TCudnn<AFloat>>;
   using BNormDescriptors_t      = TDNNGenDescriptors<BNormLayer_t>;
   //using BNormWorkspace_t        = CNN::TCNNWorkspace<BNormLayer_t>;
   using ConvLayer_t             = CNN::TConvLayer<TCudnn<AFloat>>;
   using ConvDescriptors_t       = CNN::TCNNDescriptors<ConvLayer_t>;
   using ConvWorkspace_t         = CNN::TCNNWorkspace<ConvLayer_t>;
   using PoolingLayer_t          = CNN::TMaxPoolLayer<TCudnn<AFloat>>;
   using PoolingDescriptors_t    = CNN::TCNNDescriptors<PoolingLayer_t>;
   using PoolingWorkspace_t      = CNN::TCNNWorkspace<PoolingLayer_t>;

   // Recurrent layers. Unlike the CNN types above, the RNN descriptor and workspace types
   // are parameterized on the architecture (TCudnn<AFloat>), not on a layer type.
   // NOTE(review): presumably they are also used for the LSTM and GRU layers, since the
   // per-layer LSTM/GRU aliases below are commented out — confirm.
   using RNNLayer_t              = RNN::TBasicRNNLayer<TCudnn<AFloat>>;
   using RNNDescriptors_t        = RNN::TRNNDescriptors<TCudnn<AFloat>>;
   using RNNWorkspace_t          = RNN::TRNNWorkspace<TCudnn<AFloat>>;

   using LSTMLayer_t             = RNN::TBasicLSTMLayer<TCudnn<AFloat>>;
   // using LSTMDescriptors_t       = RNN::TRNNDescriptors<LSTMLayer_t>;
   // using LSTMWorkspace_t         = RNN::TRNNWorkspace<LSTMLayer_t>;

   using GRULayer_t              = RNN::TBasicGRULayer<TCudnn<AFloat>>;
   // using GRUDescriptors_t        = RNN::TRNNDescriptors<GRULayer_t>;
   // using GRUWorkspace_t          = RNN::TRNNWorkspace<GRULayer_t>;

   // template <typename AFloat>
   // using ConvDescriptors_t = CNN::TCNNDescriptors<CNN::TConvLayer<TCudnn<AFloat>>>;
0125 
   // Convolution options. These are static members, so there is one setting per
   // TCudnn<AFloat> instantiation (definitions live outside this class body).
   // The algorithm selectors default to -1, which leaves the choice to cuDNN.
   struct CNNOptions  {

      static int ConvFwdAlgorithm;
      static int ConvBwdDataAlgorithm;
      static int ConvBwdFilterAlgorithm;
      // default is 0 (left to cudnn : a value -1 will indicate to not use any space)
      static Long_t ConvMaxWorkspaceSize;
   }; // struct CNNOptions
0136 
0137    static TMVA::Experimental::MemoryLayout GetTensorLayout() { return TMVA::Experimental::MemoryLayout::RowMajor; }
0138 
0139 
0140    static Tensor_t CreateTensor(size_t n, size_t c, size_t h, size_t w) {
0141       return Tensor_t( {n,c,h,w}, GetTensorLayout(), 0, 0);
0142    }
0143 
0144    static Tensor_t CreateTensor(DeviceBuffer_t buffer, size_t n, size_t c, size_t h, size_t w) {
0145       return Tensor_t( buffer, {n,c,h,w}, GetTensorLayout(), 0, 0);
0146    }
0147 
0148    static Tensor_t CreateTensor(size_t n, size_t c, size_t w)
0149    {
0150       return Tensor_t({n, c, w}, GetTensorLayout(), 0, 0);
0151    }
0152 
0153    static Tensor_t CreateTensor(DeviceBuffer_t buffer, size_t n, size_t c, size_t w)
0154    {
0155       return Tensor_t(buffer, {n, c, w}, GetTensorLayout(), 0, 0);
0156    }
0157 
0158    static bool IsCudnn() { return true; }
0159 
0160    // create a weight tensor/matrix vector   from another tensor/weight  vector using the given tensor shapes
0161    // this function is used by the optimizers to store intermediate weights representations
0162    static void  CreateWeightTensors( std::vector<Matrix_t> & newWeights, const std::vector<Matrix_t> & weights) {
0163       if (!newWeights.empty()) newWeights.clear();
0164       size_t n =  weights.size();
0165       for (size_t i = 0; i < n; ++i)
0166          newWeights.emplace_back( weights[i].GetShape(), weights[i].GetLayout(), 0, 0);
0167    }
0168    //____________________________________________________________________________
0169    //
0170    // Architecture Initialization
0171    //____________________________________________________________________________
0172 
   // Initialization of the per-layer descriptor objects (declared here, defined out of line).
   // The descriptor pointer is taken by reference (TDescriptors *&), so the callee can set it;
   // the layer argument defaults to nullptr.
   static void InitializeBNormDescriptors(TDescriptors * & descriptors,
                                          BNormLayer_t *L = nullptr);

   static void InitializeConvDescriptors(TDescriptors * & descriptors,
                                         ConvLayer_t *L = nullptr);

   static void InitializePoolDescriptors(TDescriptors * & descriptors,
                                        PoolingLayer_t *L = nullptr);
0181 
0182    static void InitializeRNNDescriptors(TDescriptors *&descriptors, RNNLayer_t *layer)
0183    {
0184       InitializeRecurrentDescriptors<RNNLayer_t>(descriptors, layer);
0185    }
0186    static void InitializeLSTMDescriptors(TDescriptors *&descriptors, LSTMLayer_t *layer) {
0187       InitializeRecurrentDescriptors<LSTMLayer_t>(descriptors, layer);
0188    }
0189    static void InitializeGRUDescriptors(TDescriptors *&descriptors, GRULayer_t *layer) {
0190       InitializeRecurrentDescriptors<GRULayer_t>(descriptors, layer);
0191    }
0192    template<typename RNNLayer>
0193    static void InitializeRecurrentDescriptors(TDescriptors *&descriptors, RNNLayer *L);
0194    // static void InitializeRNNDescriptors(TDescriptors *&descriptors, LSTMLayer_t *L = nullptr);
0195    // static void InitializeRNNDescriptors(TDescriptors *&descriptors, GRULayer_t *L = nullptr);
0196 
   /// Set up the cuDNN activation descriptor \p descriptors for \p activFunc.
   /// NOTE(review): \p coef (default 0.0) is an extra activation parameter whose meaning
   /// depends on the activation type — confirm in the implementation.
   static void InitializeActivationDescriptor(ActivationDescriptor_t & descriptors, EActivationFunction activFunc, double coef = 0.0);

   // Release the descriptor sets of the different layer types.
   static void ReleaseConvDescriptors(TDescriptors    * descriptors );
   static void ReleasePoolDescriptors(TDescriptors * descriptors );
   static void ReleaseRNNDescriptors(TDescriptors *descriptors);
   static void ReleaseBNormDescriptors(TDescriptors * descriptors );
   // Release a single descriptor; one overload per descriptor type.
   static void ReleaseDescriptor(EmptyDescriptor_t       & emptyDescr) {}        // Does nothing
   static void ReleaseDescriptor(ActivationDescriptor_t  & activationDescr);
   static void ReleaseDescriptor(ConvolutionDescriptor_t & convolutionDescr);
   static void ReleaseDescriptor(DropoutDescriptor_t     & dropoutDescr);
   static void ReleaseDescriptor(FilterDescriptor_t      & filterDescr);
   static void ReleaseDescriptor(PoolingDescriptor_t     & poolingDescr);
   static void ReleaseDescriptor(TensorDescriptor_t      & tensorDescr);
0210 
0211 
   // Workspace setup for convolutional and pooling/dropout layers (declared here, defined
   // out of line). The workspace and descriptor pointers are taken by reference, so the
   // callee can set them; the layer argument defaults to nullptr.
   static void InitializeConvWorkspace(TWorkspace * & workspace,
                                       TDescriptors * & descriptors,
                                       const DNN::CNN::TConvParams & params,
                                       ConvLayer_t *L = nullptr);
   static void InitializePoolDropoutWorkspace(TWorkspace * & workspace,
                                       TDescriptors * & descriptors,
                                       const DNN::CNN::TConvParams & params,
                                       PoolingLayer_t *L = nullptr);
0220 
0221    static void InitializeRNNWorkspace(TWorkspace *&workspace, TDescriptors *&descriptors, RNNLayer_t *layer)
0222    {
0223       InitializeRecurrentWorkspace<RNNLayer_t>(workspace, descriptors, layer);
0224    }
0225    static void InitializeLSTMWorkspace(TWorkspace *&workspace, TDescriptors *&descriptors, LSTMLayer_t *layer)
0226    {
0227       InitializeRecurrentWorkspace<LSTMLayer_t>(workspace, descriptors, layer);
0228    }
0229    static void InitializeGRUWorkspace(TWorkspace *&workspace, TDescriptors *&descriptors, GRULayer_t *layer)
0230    {
0231       InitializeRecurrentWorkspace<GRULayer_t>(workspace, descriptors, layer);
0232    }
0233    template<typename RNNLayer>
0234    static void InitializeRecurrentWorkspace(TWorkspace *&workspace, TDescriptors *&descriptors,
0235                                              RNNLayer *layer);
0236 
   // Free a workspace; counterparts of the Initialize*Workspace functions.
   static void FreeConvWorkspace(TWorkspace * workspace);
   static void FreePoolDropoutWorkspace(TWorkspace * workspace);
   static void FreeRNNWorkspace(TWorkspace *workspace);
0240 
0241    // tensor inizialization for recurrent networks
0242    static void InitializeRNNTensors(RNNLayer_t *layer) { InitializeRecurrentTensors<RNNLayer_t>(layer); }
0243    static void InitializeLSTMTensors(LSTMLayer_t *layer) { InitializeRecurrentTensors<LSTMLayer_t>(layer); }
0244    static void InitializeGRUTensors(GRULayer_t *layer) { InitializeRecurrentTensors<GRULayer_t>(layer); }
0245    template <typename RNNLayer>
0246    static void InitializeRecurrentTensors(RNNLayer *layer);
0247 
0248    //____________________________________________________________________________
0249    //
0250    // Propagation
0251    //____________________________________________________________________________
0252 
   /** @name Forward Propagation
    * Low-level functions required for the forward propagation of activations
    * through the network.
    */
   ///@{
   /** Matrix-multiply \p input with the transpose of \p weights and
    *  write the results into \p output. */
   static void MultiplyTranspose(Tensor_t &output, const Tensor_t &input, const Matrix_t &weights);

   /** Add the bias vector \p biases row-wise to the matrix \p output. */
   static void AddRowWise(Tensor_t &output,const Matrix_t &biases);
   ///@}
0264 
   /** @name Backward Propagation (Dense Layers)
    * Low-level functions required for the backward propagation of gradients
    * through the network.
    */
   ///@{
   /** Perform the complete backward propagation step. If the provided
    *  \p activationGradientsBackward matrix is not empty, compute the
    *  gradients of the objective function with respect to the activations
    *  of the previous layer (backward direction).
    *  Also compute the weight and the bias gradients. Modifies the values
    *  in \p df and thus produces only a valid result, if it is applied the
    *  first time after the corresponding forward propagation has been
    *  performed. */
   static void Backward(Tensor_t & activationGradientsBackward,
                        Matrix_t & weightGradients,
                        Matrix_t & biasGradients,
                        Tensor_t & df,
                        const Tensor_t & activationGradients,
                        const Matrix_t & weights,
                        const Tensor_t & activationBackward);
0285 
   /** Scale-and-add: combine \p B into \p A using the factors \p alpha and \p beta
    *  (both default to 1). NOTE(review): the exact formula (e.g. A = beta*A + alpha*B)
    *  is defined by the implementation — confirm. */
   static void ScaleAdd(Tensor_t & A, const Tensor_t & B,
                        Scalar_t alpha = 1.0,
                        Scalar_t beta = 1.0);

   /** Deep copy from B to A. */
   static void Copy(Tensor_t & A, const Tensor_t & B);

   /** Copy into \p A from a tensor \p B of a (possibly) different architecture type. */
   template<typename ATensor_t>
   static void CopyDiffArch(Tensor_t & A,
                            const ATensor_t & B);

   /** Copy weights into \p A from a tensor \p B of a (possibly) different architecture type. */
   template <typename ATensor_t>
   static void CopyWeightsDiffArch(Tensor_t &A, const ATensor_t &B);
0301 
   // Same-architecture overload: simply forwards to Copy(). Being a non-template, it is
   // preferred over the CopyDiffArch template when both arguments are Tensor_t.
   // NOTE(review): A is taken by value, so Copy() writes into a copy of the caller's tensor
   // object; this only updates the caller's data if TCudaTensor copies share the same device
   // buffer — confirm. (Taking A by value also lets this overload accept temporaries.)
   //template<>
   static void CopyDiffArch(Tensor_t A, const Tensor_t & B ) { Copy(A,B); }

   // Copy element-wise from a vector of matrices/tensors of different types into \p A.
   template<typename AMatrix_t>
   static void CopyDiffArch(std::vector<Tensor_t>  & A,
                            const std::vector<AMatrix_t> & B);
0309 
0310 
0311    //____________________________________________________________________________
0312    //
0313    // Activation Functions
0314    //____________________________________________________________________________
0315 
0316    /** @name Activation Functions
0317     * For each activation function, the low-level interface contains two routines.
0318     * One that applies the activation function to a matrix and one that evaluate
0319     * the derivatives of the activation function at the elements of a given matrix
0320     * and writes the results into the result matrix.
0321     */
0322    ///@{
0323    static void Identity(Tensor_t & X) {}
0324    static void IdentityDerivative(Tensor_t & dX, Tensor_t& X,
0325                                   Tensor_t & Y,  Tensor_t & dY,
0326                                   ActivationDescriptor_t activationDescr,
0327                                   const AFloat alpha = 1,
0328                                   const AFloat beta = 1) {}
0329 
   /** Apply the activation function \p activFunct in place to \p X, using the cuDNN
    *  activation descriptor \p activationDescr. \p coef defaults to 0.0; \p alpha and
    *  \p beta (defaults 1 and 0) are presumably the usual cuDNN blending factors — confirm. */
   static void ActivationFunctionForward(Tensor_t & X, EActivationFunction activFunct,
                          const ActivationDescriptor_t activationDescr,
                          const double coef = 0.0, const AFloat alpha = 1,
                          const AFloat beta = 0);

   // same as above but using different input/output tensors: reads \p X, writes \p Y
   static void ActivationFunctionForward(Tensor_t &Y, const Tensor_t & X, EActivationFunction activFunct,
                                         const ActivationDescriptor_t activationDescr, const double coef = 0.0,
                                         const AFloat alpha = 1, const AFloat beta = 0);

   /** Computes the gradient of the activation function, writing it into \p dX from the
    *  forward output \p Y, the incoming gradient \p dY and the forward input \p X. */
   static void ActivationFunctionBackward(Tensor_t & dX, const Tensor_t & Y,
                                          const Tensor_t & dY,  const Tensor_t & X,
                                          EActivationFunction activFunct,
                                          const ActivationDescriptor_t activationDescr,
                                          const AFloat alpha = 1,
                                          const AFloat beta = 0);
0347 
0348    //
0349    // No cudnn implementation for the following activation functions
0350    //
0351    //static void SymmetricRelu(Tensor_t & B);
0352 
0353    // implementations not used by Cudnn
0354    static void Relu(Tensor_t &) {}
0355    static void Sigmoid(Tensor_t &) {}
0356    static void Tanh(Tensor_t &) {}
0357    static void FastTanh(Tensor_t &) {}
0358    static void SymmetricRelu(Tensor_t &) {}
0359    static void SoftSign(Tensor_t &) {}
0360    static void Gauss(Tensor_t &) {}
0361 
0362    static void IdentityDerivative(Tensor_t &, const Tensor_t &) {}
0363    static void ReluDerivative(Tensor_t &, const Tensor_t &) {}
0364    static void SigmoidDerivative(Tensor_t &, const Tensor_t &) {}
0365    static void TanhDerivative(Tensor_t &, const Tensor_t &) {}
0366    static void FastTanhDerivative(Tensor_t &, const Tensor_t &) {}
0367    static void SymmetricReluDerivative(Tensor_t & , const Tensor_t & ) {}
0368    static void SoftSignDerivative(Tensor_t & , const Tensor_t & ) {}
0369    static void GaussDerivative(Tensor_t & ,  const Tensor_t & ) {}
0370    ///@}
0371 
0372    //____________________________________________________________________________
0373    //
0374    // Loss Functions
0375    //____________________________________________________________________________
0376 
   /** @name Loss Functions
    * Loss functions compute a scalar value given the \p output of the network
    * for a given training input and the expected network prediction \p Y that
    * quantifies the quality of the prediction. For each function also a routine
    * that computes the gradients (suffixed by Gradients) must be provided for
    * the starting of the backpropagation algorithm.
    */
      ///@{

   /** Mean squared error between \p Y and \p output, weighted by \p weights;
    *  the Gradients variant writes the gradient into \p dY. */
   static Scalar_t MeanSquaredError(const Matrix_t &Y, const Matrix_t &output,
                                    const Matrix_t &weights);
   static void MeanSquaredErrorGradients(Matrix_t &dY, const Matrix_t &Y,
                                         const Matrix_t &output, const Matrix_t &weights);

   /** Sigmoid transformation is implicitly applied, thus \p output should
    *  hold the linear activations of the last layer in the net. */
   static Scalar_t CrossEntropy(const Matrix_t &Y, const Matrix_t &output,
                                const Matrix_t &weights);

   static void CrossEntropyGradients(Matrix_t &dY, const Matrix_t &Y,
                                     const Matrix_t &output, const Matrix_t &weights);

   /** Softmax transformation is implicitly applied, thus \p output should
    *  hold the linear activations of the last layer in the net. */
   static Scalar_t SoftmaxCrossEntropy(const Matrix_t &Y, const Matrix_t &output,
                                       const Matrix_t &weights);
   static void SoftmaxCrossEntropyGradients(Matrix_t &dY, const Matrix_t &Y,
                                            const Matrix_t &output, const Matrix_t &weights);
   ///@}
0406 
0407    //____________________________________________________________________________
0408    //
0409    // Output Functions
0410    //____________________________________________________________________________
0411 
   /** @name Output Functions
    * Output functions transform the activations \p output of the
    * output layer in the network to a valid prediction \p YHat for
    * the desired usage of the network, e.g.  the identity function
    * for regression or the sigmoid transformation for two-class
    * classification.
    */
   ///@{
   /** Sigmoid / softmax output transformation: the second argument holds the output-layer
    *  activations, the result is written into \p YHat. */
   static void Sigmoid(Matrix_t &YHat,
                       const Matrix_t & );
   static void Softmax(Matrix_t &YHat,
                       const Matrix_t & );
   ///@}
0425 
0426 
0427 
0428       //____________________________________________________________________________
0429       //
0430       // Dropout
0431       //____________________________________________________________________________
0432 
   /** @name Dropout
    */
   ///@{

   /** Apply dropout with activation probability \p p to the given
    *  tensor \p A and scale the result by reciprocal of \p p.
    *  Uses the cuDNN state held in \p descriptors and \p workspace. */
   static void DropoutForward(Tensor_t & A,
                              TDescriptors * descriptors,
                              TWorkspace         * workspace,
                              Scalar_t p);

   /** Backward pass of dropout on the gradient tensor \p A (presumably reusing the
    *  dropout state stored by DropoutForward in \p descriptors / \p workspace — confirm). */
   static void DropoutBackward(Tensor_t & A,
                               TDescriptors * descriptors,
                               TWorkspace   * workspace);

   ///@}
0449 
0450    //____________________________________________________________________________
0451    //
0452    // Batch Normalization
0453    //____________________________________________________________________________
0454 
   /** @name Batch Normalization Layer Propagation
    */
   ///@{

   /** During training the inputs of each batch are normalized to have zero mean and unit
     * variance, and they are then scaled by two parameters, different for each input variable:
     *  - a scale factor \f$\gamma\f$ (gamma)
     *  - an offset \f$\beta\f$ (beta)
     * NOTE(review): the unnamed Matrix_t & between \p mean and \p iVariance is presumably the
     * batch variance (compare BatchNormLayerBackward) — confirm. */

   static void BatchNormLayerForwardTraining(int axis, const Tensor_t &x, Tensor_t &y, Matrix_t &gamma, Matrix_t &beta,
                                             Matrix_t &mean, Matrix_t &, Matrix_t &iVariance, Matrix_t &runningMeans,
                                             Matrix_t &runningVars, Scalar_t nTrainedBatches, Scalar_t momentum,
                                             Scalar_t epsilon, const TensorDescriptor_t &bnParDescriptor);

   /** During inference the inputs are not normalized using the batch mean but using the
    * previously computed running mean and variance (\p runningMeans, \p runningVars). */

   static void BatchNormLayerForwardInference(int axis, const Tensor_t &x, Matrix_t &gamma, Matrix_t &beta,
                                              Tensor_t &y, const Matrix_t &runningMeans,
                                              const Matrix_t &runningVars, Scalar_t epsilon,
                                              const TensorDescriptor_t &);

   /** Backward pass: computes \p dx, \p dgamma and \p dbeta from \p x and \p dy, using the
    *  statistics saved during the forward training pass. */
   static void BatchNormLayerBackward(int axis, const Tensor_t &x, const Tensor_t &dy, Tensor_t &dx,
                                      Matrix_t &gamma, //  Matrix_t &beta, (not needed)
                                      Matrix_t &dgamma, Matrix_t &dbeta, const Matrix_t &mean, const Matrix_t &variance,
                                      const Matrix_t &iVariance, Scalar_t epsilon, const TensorDescriptor_t &);
0481 
0482    //____________________________________________________________________________
0483    //
0484    // Regularization
0485    //____________________________________________________________________________
0486 
0487    /** @name Regularization
0488     * For each regularization type two functions are required, one named
0489     * <tt><Type>Regularization</tt> that evaluates the corresponding
0490     * regularization functional for a given weight matrix and the
0491     * <tt>Add<Type>RegularizationGradients</tt>, that adds the regularization
0492     * component in the gradients to the provided matrix.
0493     */
0494 
0495    static Scalar_t L1Regularization(const Matrix_t &W)
0496    {
0497       TCudaMatrix<AFloat> mW(W.GetDeviceBuffer(), W.GetSize(), 1);
0498       return TCuda<AFloat>::L1Regularization(mW);
0499    }
0500    static void AddL1RegularizationGradients(Matrix_t &A, const Matrix_t &W, Scalar_t weightDecay)
0501    {
0502       TCudaMatrix<AFloat> mA(A.GetDeviceBuffer(), A.GetSize(), 1);
0503       TCudaMatrix<AFloat> mW(W.GetDeviceBuffer(), W.GetSize(), 1);
0504       return TCuda<AFloat>::AddL1RegularizationGradients(mA, mW, weightDecay);
0505    }
0506 
0507    static Scalar_t L2Regularization(const Matrix_t &W)
0508    {
0509       TCudaMatrix<AFloat> mW(W.GetDeviceBuffer(), W.GetSize(), 1);
0510       return TCuda<AFloat>::L2Regularization(mW);
0511    }
0512    static void AddL2RegularizationGradients(Matrix_t &A, const Matrix_t &W, Scalar_t weightDecay)
0513    {
0514       TCudaMatrix<AFloat> mA(A.GetDeviceBuffer(), A.GetSize(), 1);
0515       TCudaMatrix<AFloat> mW(W.GetDeviceBuffer(), W.GetSize(), 1);
0516       return TCuda<AFloat>::AddL1RegularizationGradients(mA, mW, weightDecay);
0517    }
0518     ///@}
0519 
0520     //____________________________________________________________________________
0521     //
0522     // Initialization
0523     //____________________________________________________________________________
0524 
    /** @name Initialization
     * For each initialization method, one function in the low-level interface
     * is provided. The naming scheme is <tt>Initialize<Type></tt> for a given
     * initialization method Type.
     */
    ///@{

   static void InitializeGauss(Matrix_t &A);
   static void InitializeUniform(Matrix_t &A);
   static void InitializeIdentity(Matrix_t &A);
   static void InitializeZero(Matrix_t &A);
   static void InitializeGlorotNormal(Matrix_t &A);
   static void InitializeGlorotUniform(Matrix_t &A);

   // return static instance of random generator used for initialization
   // if generator does not exist it is created the first time with a random seed (e.g. seed = 0)
   static TRandom &GetRandomGenerator();
   // set random seed for the static generator
   // if the static generator does not exist it is created
   static void SetRandomSeed(size_t seed);
   ///@}
0546 
0547    //____________________________________________________________________________
0548    //
0549    // Dropout
0550    //____________________________________________________________________________
0551 
0552    /** @name Dropout
0553     */
0554    ///@{
0555 
0556    /** Apply dropout with activation probability \p p to the given
0557     *  tensor \p A and scale the result by reciprocal of \p p. */
0558    static void Dropout(Tensor_t &A, Scalar_t p) {}
0559 
0560    ///@}
0561 
0562    //____________________________________________________________________________
0563    //
0564    //  Convolutional Layer Propagation
0565    //____________________________________________________________________________
0566 
   /** @name Forward Propagation in Convolutional Layer
    */
   ///@{

   /** Add the biases \p biases to \p output in the Convolutional Layer. */
   static void AddConvBiases(Matrix_t &output, const Matrix_t &biases);
   ///@}
0574 
0575    /** Dummy placeholder - preparation is currently only required for the CUDA architecture. */
0576    static void PrepareInternals(Tensor_t &) {}
0577 
   /** Forward propagation in the Convolutional layer. \p inputActivationFunc receives the
    *  convolution output before the activation function is applied; \p output receives
    *  the activated result. The \p inputPrime argument is unnamed (unused by this declaration's
    *  signature naming; NOTE(review): presumably not needed by cuDNN — confirm). */
   static void ConvLayerForward(Tensor_t &output,
                                Tensor_t &inputActivationFunc, // this is output conv w/o activ func.
                                const Tensor_t &input, const Matrix_t &weights, const Matrix_t &biases,
                                const DNN::CNN::TConvParams &params, EActivationFunction activFunc,
                                Tensor_t & /* inputPrime */, const ConvDescriptors_t &descriptors,
                                ConvWorkspace_t &workspace);
   // const AFloat alpha = 1,
   // const AFloat beta  = 1);
0587 
   /** @name Backward Propagation in Convolutional Layer
    */
   ///@{

   /** Perform the complete backward propagation step in a Convolutional Layer.
    *  If the provided \p activationGradientsBackward matrix is not empty, compute the
    *  gradients of the objective function with respect to the activations
    *  of the previous layer (backward direction).
    *  Also compute the weight and the bias gradients. Modifies the values
    *  in \p df and thus produces only a valid result, if it is applied the
    *  first time after the corresponding forward propagation has been
    *  performed.
    *  The trailing size_t arguments have their names commented out (NOTE(review): presumably
    *  unused here and kept for interface compatibility with other architectures — confirm). */
   static void ConvLayerBackward(Tensor_t &activationGradientsBackward, Matrix_t &weightGradients,
                                 Matrix_t &biasGradients, Tensor_t &inputActivation, Tensor_t &activationGradients,
                                 const Matrix_t &weights, const Tensor_t &activationBackward,
                                 const Tensor_t &outputTensor, EActivationFunction activFunc,
                                 const ConvDescriptors_t &descriptors, ConvWorkspace_t &workspace, size_t /*batchSize*/,
                                 size_t /*inputHeight*/, size_t /*inputWidth*/, size_t /*depth*/, size_t /*height*/,
                                 size_t /*width*/, size_t /*filterDepth*/, size_t /*filterHeight*/,
                                 size_t /*filterWidth*/, size_t /*nLocalViews*/);

   ///@}
0610 
0611    //____________________________________________________________________________
0612    //
0613    //  Max Pooling Layer Propagation
0614    //____________________________________________________________________________
   /** @name Forward Propagation in Max Pooling Layer
    */
   ///@{

   /** Downsample the tensor \p C to the tensor \p A using the max operation.
    * The index tensor \p B (winning indices) is not used: no winning indices are
    * needed for cuDNN. */
   static void Downsample(Tensor_t &A, Tensor_t & /*B*/, const Tensor_t &C, const PoolingDescriptors_t &descriptors,
                          PoolingWorkspace_t &workspace, size_t imgHeight, size_t imgWidth, size_t fltHeight,
                          size_t fltWidth, size_t strideRows, size_t strideCols);

   ///@}
0627 
   /** @name Backward Propagation in Max Pooling Layer
    */
   ///@{
   /** Perform the complete backward propagation step in a Pooling Layer. Based on the
    *  input to and output from the MaxPoolLayer (\p inputActivation, \p outputTensor), the
    *  gradients for the winning pixels are computed. The index matrix is not used. */
   static void MaxPoolLayerBackward(Tensor_t &activationGradientsBackward, const Tensor_t &activationGradients,
                                    const Tensor_t & /*indexMatrix*/, const Tensor_t &inputActivation,
                                    const Tensor_t &outputTensor, const PoolingDescriptors_t &descriptors,
                                    PoolingWorkspace_t &workspace, size_t imgHeight, size_t imgWidth, size_t fltHeight,
                                    size_t fltWidth, size_t strideRows, size_t strideCols, size_t nLocalViews);

   ///@}
0641 
0642    //____________________________________________________________________________
0643    //
0644    //  Reshape Layer Propagation
0645    //____________________________________________________________________________
0646    /** @name Forward and Backward Propagation in Reshape Layer
0647     */
0648    ///@{
0649 
0650    /** Transform the matrix \p B to a matrix with different dimensions \p A */
0651    // static void Reshape(Matrix_t &A, const Matrix_t &B);
0652 
   /** Flatten the tensor \p B into \p A, so that each matrix of \p B is stretched
    *  into one row of the resulting matrix \p A. */
   static void Flatten(Tensor_t &A, const Tensor_t &B);
0656 
   /** Transform each row of \p B back into a matrix and store the result in the
    *  tensor \p A (presumably the inverse of Flatten). */
   static void Deflatten(Tensor_t &A, const Tensor_t &B); // size_t index, size_t nRows,size_t nCols);
0660 
   /** Rearrange data according to time: fill the B x T x D tensor \p out from the
    *  T x B x D tensor \p in. */
   static void Rearrange(Tensor_t &out, const Tensor_t &in);
0663 
   // RNN functions

   /** Forward pass of a recurrent layer.
    *
    *  Read-only inputs: \p x, \p hx, \p cx and \p weights. Outputs (non-const): \p y, \p hy, \p cy.
    *  The names follow the cuDNN RNN convention: presumably input sequence, initial hidden/cell
    *  state, output sequence and final hidden/cell state — TODO confirm.
    *  \p isTraining presumably selects the training (rather than inference) code path.
    */
   static void RNNForward(const Tensor_t &x, const Tensor_t &hx, const Tensor_t &cx, const Tensor_t &weights,
                          Tensor_t &y, Tensor_t &hy, Tensor_t &cy, const RNNDescriptors_t &descr,
                          RNNWorkspace_t &workspace, bool isTraining);
0668 
   /** Backward pass of a recurrent layer.
    *
    *  Uses the forward inputs/outputs (\p x, \p hx, \p cx, \p y), the incoming gradients
    *  (\p dy, \p dhy, \p dcy) and \p weights. It writes the non-const gradients \p dx, \p dhx,
    *  \p dcx and the weight gradient \p dw.
    */
   static void RNNBackward(const Tensor_t &x, const Tensor_t &hx, const Tensor_t &cx, const Tensor_t &y, const Tensor_t &dy,
                    const Tensor_t &dhy, const Tensor_t &dcy, const Tensor_t &weights, Tensor_t &dx, Tensor_t &dhx,
                    Tensor_t &dcx, Tensor_t &dw, const RNNDescriptors_t &desc, RNNWorkspace_t &workspace);
0672 
0673 
   // Backward pass for Recurrent Networks functions used by another architectures
   //******************************************************************************************
   /** No-op placeholder that keeps the generic architecture interface complete.
    *
    *  It returns \p state_gradients_backward unchanged; every other argument is ignored (their
    *  names are commented out). Presumably the cuDNN backend computes RNN gradients for the whole
    *  sequence in RNNBackward instead — TODO confirm.
    */
   static Matrix_t &RecurrentLayerBackward(Matrix_t &state_gradients_backward, // BxH
                                           Matrix_t & /* input_weight_gradients */,
                                           Matrix_t & /* state_weight_gradients */, Matrix_t & /* bias_gradients */,
                                           Matrix_t & /* df */,                  // DxH
                                           const Matrix_t & /* state */,         // BxH
                                           const Matrix_t & /* weights_input */, // HxD
                                           const Matrix_t & /* weights_state */, // HxH
                                           const Matrix_t & /* input */,         // BxD
                                           Matrix_t & /* input_gradient */)
   {
      return state_gradients_backward;
   }
   /** No-op LSTM backward placeholder for the generic architecture interface.
    *
    *  It returns \p state_gradients_backward unchanged and ignores all other arguments (their
    *  names are commented out). Presumably the cuDNN path computes LSTM gradients via RNNBackward
    *  — TODO confirm.
    */
   static Matrix_t &LSTMLayerBackward(
      Matrix_t & state_gradients_backward , Matrix_t & /*cell_gradients_backward*/,
      Matrix_t & /*input_weight_gradients*/, Matrix_t & /*forget_weight_gradients*/,
      Matrix_t & /*candidate_weight_gradients*/, Matrix_t & /*output_weight_gradients*/,
      Matrix_t & /*input_state_weight_gradients*/, Matrix_t & /*forget_state_weight_gradients*/,
      Matrix_t & /*candidate_state_weight_gradients*/,
      Matrix_t & /*output_state_weight_gradients*/, Matrix_t & /*input_bias_gradients*/,
      Matrix_t & /*forget_bias_gradients*/, Matrix_t & /*candidate_bias_gradients*/,
      Matrix_t & /*output_bias_gradients*/, Matrix_t & /*di*/, Matrix_t & /*df*/,
      Matrix_t & /*dc*/, Matrix_t & /*dout*/,
      const Matrix_t & /*precStateActivations*/, const Matrix_t & /*precCellActivations*/,
      const Matrix_t & /*fInput*/, const Matrix_t & /*fForget*/,
      const Matrix_t & /*fCandidate*/, const Matrix_t & /*fOutput*/,
      const Matrix_t & /*weights_input*/, const Matrix_t & /*weights_forget*/,
      const Matrix_t & /*weights_candidate*/, const Matrix_t & /*weights_output*/,
      const Matrix_t & /*weights_input_state*/, const Matrix_t & /*weights_forget_state*/,
      const Matrix_t & /*weights_candidate_state*/, const Matrix_t & /*weights_output_state*/,
      const Matrix_t & /*input*/, Matrix_t & /*input_gradient*/,
      Matrix_t & /*cell_gradient*/, Matrix_t & /*cell_tanh*/)
   {
      return state_gradients_backward;
   }
0710 
   /** Backward pass for GRU Network: no-op placeholder for the generic architecture interface.
    *
    *  It returns \p state_gradients_backward unchanged; all other arguments, including the trailing
    *  unnamed bool, are ignored. Presumably the cuDNN path computes GRU gradients via RNNBackward
    *  — TODO confirm.
    */
   static Matrix_t &GRULayerBackward(
      Matrix_t &  state_gradients_backward, Matrix_t & /*reset_weight_gradients*/,
      Matrix_t & /*update_weight_gradients*/, Matrix_t & /*candidate_weight_gradients*/,
      Matrix_t & /*reset_state_weight_gradients*/, Matrix_t & /*update_state_weight_gradients*/,
      Matrix_t & /*candidate_state_weight_gradients*/, Matrix_t & /*reset_bias_gradients*/,
      Matrix_t & /*update_bias_gradients*/, Matrix_t & /*candidate_bias_gradients*/,
      Matrix_t & /*dr*/, Matrix_t & /*du*/, Matrix_t & /*dc*/,
      const Matrix_t & /*precStateActivations*/, const Matrix_t & /*fReset*/,
      const Matrix_t & /*fUpdate*/, const Matrix_t & /*fCandidate*/,
      const Matrix_t & /*weights_reset*/, const Matrix_t & /*weights_update*/,
      const Matrix_t & /*weights_candidate*/, const Matrix_t & /*weights_reset_state*/,
      const Matrix_t & /*weights_update_state*/, const Matrix_t & /*weights_candidate_state*/,
      const Matrix_t & /*input*/, Matrix_t & /*input_gradient*/, bool)
   {
      return state_gradients_backward;
   }
0728 
0729    ///@}
0730 
0731    //____________________________________________________________________________
0732    //
0733    // Additional Arithmetic Functions
0734    //____________________________________________________________________________
0735 
0736    /** @name Additional Arithmetic Functions
0737     *
0738     * Additional arithmetic on CUDA matrices  used to implement the low-level
0739     * interface.
0740     */
0741 
0742    /** In-place Hadamard (element-wise) product of matrices \p A and \p B
0743     *  with the result being written into \p A.
0744     */
0745    static void Hadamard(Tensor_t &A, const Tensor_t &B)
0746    {
0747       TCudaMatrix<AFloat> tmpA(A.GetDeviceBuffer(), 1, A.GetSize());
0748       TCudaMatrix<AFloat> tmpB(B.GetDeviceBuffer(), 1, B.GetSize());
0749       assert(A.GetSize() == B.GetSize());
0750       TCuda<AFloat>::Hadamard(tmpA, tmpB);
0751    }
0752    // static void Hadamard(Matrix_t &A,
0753    //                      const Matrix_t &B);*/
0754    // {
0755    //    Tensor_t tA(A);
0756    //    Hadamard( tA, Tensor_t(B));
0757    // }
0758 
0759 
   /** Compute the sum of all elements in \p A.
    *  NOTE(review): the roles of \p alpha (default 1) and \p beta (default 0) are defined by the
    *  implementation, which is not visible here — presumably cuDNN-style scaling factors; confirm. */
   static Scalar_t Sum(const Matrix_t &A, Scalar_t alpha = 1.0, Scalar_t beta = 0.0);
0762 
0763    /** Check two matrices for equality, taking floating point arithmetic errors into account. */
0764    //static bool AlmostEquals(const Matrix_t &A, const Matrix_t &B, double epsilon = 0.1);
0765 
0766    /** Add the constant \p beta to all the elements of matrix \p A and write the
0767     * result into \p A.
0768     */
0769    static void ConstAdd(Matrix_t &A, Scalar_t beta) {
0770       TCudaMatrix<AFloat> tmp(A.GetDeviceBuffer(), 1, A.GetSize());
0771       TCuda<AFloat>::ConstAdd(tmp,beta);
0772    }
0773 
0774    /** Multiply the constant \p beta to all the elements of matrix \p A and write the
0775     * result into \p A.
0776     */
0777    static void ConstMult(Matrix_t &A, Scalar_t beta) {
0778       TCudaMatrix<AFloat> tmp(A.GetDeviceBuffer(), 1, A.GetSize());
0779       TCuda<AFloat>::ConstMult(tmp,beta);
0780    }
0781 
0782    /** Reciprocal each element of the matrix \p A and write the result into
0783     * \p A
0784     */
0785    static void ReciprocalElementWise(Matrix_t &A) {
0786       TCudaMatrix<AFloat> tmp(A.GetDeviceBuffer(), 1, A.GetSize());
0787       TCuda<AFloat>::ReciprocalElementWise(tmp);
0788    }
0789 
0790    /** Square each element of the matrix \p A and write the result into
0791     * \p A
0792     */
0793    static void SquareElementWise(Matrix_t &A) {
0794       TCudaMatrix<AFloat> tmp(A.GetDeviceBuffer(), 1, A.GetSize());
0795       TCuda<AFloat>::SquareElementWise(tmp);
0796    }
0797 
0798    /** Square root each element of the matrix \p A and write the result into
0799     * \p A
0800     */
0801    //static void SqrtElementWise(Matrix_t &A, Scalar_t alpha = 1, Scalar_t beta = 0, Scalar_t gamma = 0) {
0802    static void SqrtElementWise(Matrix_t &A) {
0803       TCudaMatrix<AFloat> tmp(A.GetDeviceBuffer(), 1, A.GetSize());
0804       TCuda<AFloat>::SqrtElementWise(tmp);
0805    }
0806 
0807       // optimizer functions
0808    static void AdamUpdate(Matrix_t & A, const Matrix_t & M, const Matrix_t & V, Scalar_t alpha, Scalar_t eps) {
0809       TCudaMatrix<AFloat> tmpA(A.GetDeviceBuffer(), A.GetSize(),1);
0810       TCudaMatrix<AFloat> tmpM(M.GetDeviceBuffer(), M.GetSize(),1);
0811       TCudaMatrix<AFloat> tmpV(V.GetDeviceBuffer(), V.GetSize(),1);
0812       TCuda<AFloat>::AdamUpdate(tmpA, tmpM, tmpV,alpha, eps);
0813    }
0814    static void AdamUpdateFirstMom(Matrix_t & A, const Matrix_t & B, Scalar_t beta) {
0815       TCudaMatrix<AFloat> tmpA(A.GetDeviceBuffer(), A.GetSize(),1);
0816       TCudaMatrix<AFloat> tmpB(B.GetDeviceBuffer(), B.GetSize(),1);
0817       TCuda<AFloat>::AdamUpdateFirstMom(tmpA, tmpB,  beta);
0818    }
0819    static void AdamUpdateSecondMom(Matrix_t & A, const Matrix_t & B, Scalar_t beta) {
0820       TCudaMatrix<AFloat> tmpA(A.GetDeviceBuffer(), A.GetSize(),1);
0821       TCudaMatrix<AFloat> tmpB(B.GetDeviceBuffer(), B.GetSize(),1);
0822       TCuda<AFloat>::AdamUpdateSecondMom(tmpA, tmpB,  beta);
0823    }
0824 
   // printing of tensor
   /** Debug print of \p A to std::cout: name, total size, shape and strides, then the elements.
    *  The unnamed bool argument (`truncate`, default true) limits how many entries are printed. */
   static void PrintTensor( const Tensor_t & A, const std::string name = "tensor", bool = true);
0827 
   /** Debug print of the shape and strides stored in a 4-d cuDNN tensor descriptor. */
   static void PrintTensor4dDescriptor(TensorDescriptor_t descriptor);
   /** Debug print of an N-d cuDNN tensor descriptor; \p n is the number of dimensions requested
    *  from cuDNN (the definition names this parameter `ndim`). */
   static void PrintTensorNdDescriptor(TensorDescriptor_t descriptor, int n = 10);
0830 
   ///////////////////////////////////////////////////////////////////////////////
   /// extra functions defined only for CPU architecture !!!
   //////////////////////////////////////////////////////////////////////////////
   // NOTE(review): despite the banner above, SumRows is declared for this cuDNN architecture too;
   // confirm where (and whether) it is implemented for TCudnn.

   /** Sum rows of (m x n) matrix \p A and write the results into the first
    * m elements in \p B.
    */
   static void SumRows(Matrix_t &B, const Matrix_t &A);
0839 };
0840 
0841 
0842 //____________________________________________________________________________
0843 template <typename AFloat>
0844 template <typename ATensor>
0845 void TCudnn<AFloat>::CopyDiffArch(TCudaTensor<AFloat> &B,
0846                         const ATensor &A)
0847 {
0848 
0849    // should add static assert that A has not to be same type as B
0850 
0851    // this copying tensors from different architectures
0852    if (B.GetLayout() == GetTensorLayout()) {
0853       if ( B.GetShape().size() == 4) {
0854          assert(B.GetShape().size() == 4);
0855          size_t firstSize = (A.GetLayout() == GetTensorLayout()) ? A.GetShape()[0] : A.GetShape().back();
0856          for (size_t i = 0; i < firstSize; ++i) {
0857             TMatrixT<AFloat> matIn = A.At(i).GetMatrix(); // this convert tensor (B,D,HW) in  (D,HW)i -> (D,HW)i
0858             // TMAtrix has the correct layout (row-wise) no need to traspose in this case
0859             TCudaTensor<AFloat> tmpOut = B.At(i); // matrix (D,HW)
0860             // copy will copy the buffer
0861             TCudaTensor<AFloat> tmpIn(matIn.GetMatrixArray(), tmpOut.GetShape(), tmpOut.GetLayout());
0862             Copy(tmpOut, tmpIn);
0863          }
0864       }
0865       else {
0866          // for RNN weights
0867          TMatrixT<AFloat> tmp = A;
0868          TCudaMatrix<AFloat> tmp2(tmp);
0869          TCudaTensor<AFloat> tA(tmp2);
0870          Copy(B, tA);
0871       }
0872    } else {
0873       // case of same layout (column major)
0874       TMatrixT<AFloat> tmp = A;
0875       TCudaMatrix<AFloat> tmp2(tmp);
0876       TCudaTensor<AFloat> tA(tmp2);
0877       Copy(B, tA);
0878    }
0879 }
0880 
0881 //____________________________________________________________________________
0882 template <typename AFloat>
0883 template <typename AMatrix>
0884 void TCudnn<AFloat>::CopyWeightsDiffArch(TCudaTensor<AFloat> &B, const  AMatrix &A)
0885 {
0886    // copy from another architecture using the reference one
0887    // this is not very efficient since creates temporary objects
0888    TMatrixT<AFloat> tmp = A; // .GetMatrix();
0889    // we need to traspose for different layout
0890    if (B.GetLayout() == GetTensorLayout()  ) {
0891       // this is for CNN weights that are in row-major formats
0892       //assert(B.GetShape().size() == 4);  // weights shape should be 4
0893       tmp.T();
0894    }
0895    TCudaMatrix<AFloat> tmp2(tmp);
0896    TCudaTensor<AFloat> tA(tmp2);
0897    Copy(B, tA);
0898 }
0899 
0900 //____________________________________________________________________________
0901 template <typename AFloat>
0902 template <typename AMatrix_t>
0903 void TCudnn<AFloat>::CopyDiffArch(std::vector<Tensor_t> &B,
0904                             const std::vector<AMatrix_t> &A)
0905 {
0906    for (size_t i = 0; i < B.size(); ++i) {
0907       CopyWeightsDiffArch(B[i], A[i]);
0908    }
0909 }
0910 
0911 template <typename AFloat>
0912 void TCudnn<AFloat>::PrintTensor(const typename TCudnn<AFloat>::Tensor_t & A, const std::string name, bool truncate )
0913 {
0914    std::cout << name << "  size = " << A.GetSize() << " shape = { ";
0915    auto shape = A.GetShape();
0916    for (size_t k = 0; k < shape.size()-1; ++k)
0917       std::cout << shape[k] << " , ";
0918    std::cout << shape.back() << " } ";
0919    std::cout << " strides = { ";
0920    auto strides = A.GetStrides();
0921    for (size_t k = 0; k < strides.size()-1; ++k)
0922       std::cout << strides[k] << " , ";
0923    std::cout << strides.back() << " }\n ";
0924    if (A.GetShape().size() == 1 ) {
0925       size_t n =  A.GetShape()[0];
0926       if (truncate) n = std::min(n,size_t(10));
0927       for (size_t j = 0; j < n; ++j) {
0928          std::cout << A(0,j) << " ";
0929       }
0930       if (truncate && n < A.GetShape()[0]) std::cout << " ...... ";
0931       std::cout << " } " << std::endl;
0932    } else if (A.GetShape().size() == 2 ) {
0933       size_t n1 =  A.GetShape()[0];
0934       size_t n2 =  A.GetShape()[1];
0935       if (truncate) n1 = std::min(n1,size_t(10));
0936       for (size_t i = 0; i < n1; ++i) {
0937          std::cout << "{ ";
0938          if (truncate) n2 = std::min(n2,size_t(10));
0939          for (size_t j = 0; j < n2; ++j) {
0940             std::cout << A(i,j) << " ";
0941          }
0942          if (truncate && n2 < A.GetShape()[1]) std::cout << " ...... ";
0943          std::cout << " } " << std::endl;
0944       }
0945       if (truncate && n1 < A.GetShape()[0]) std::cout << " ...............\n";
0946    } else if  (A.GetShape().size() == 3 ) {
0947       size_t n1 =  A.GetFirstSize();
0948       size_t n2 =  A.GetHSize();
0949       size_t n3 = A.GetWSize();
0950       if (truncate) n1 = std::min(n1,size_t(10));
0951       if (truncate) n2 = std::min(n2,size_t(10));
0952       if (truncate) n3 = std::min(n3,size_t(10));
0953       for (size_t i = 0; i < n1; ++i) {
0954          std::cout << "{ ";
0955          for (size_t j = 0; j < n2; ++j) {
0956             std::cout << "{ ";
0957             for (size_t k = 0; k < n3; ++k) {
0958                std::cout << A(i,j,k) << " ";
0959             }
0960             if (truncate && n3 < A.GetWSize()) std::cout << " ...... ";
0961             std::cout << " } " << std::endl;
0962          }
0963          if (truncate && n2 < A.GetHSize()) std::cout << ".................\n";
0964          std::cout << " } " << std::endl;
0965       }
0966       if (truncate && n1 < A.GetFirstSize()) std::cout << "...................\n";
0967    } else if  (A.GetShape().size() == 4 ) {
0968       for (size_t i = 0; i < A.GetShape()[0]; ++i) {
0969          std::cout << "{ ";
0970          for (size_t j = 0; j < A.GetShape()[1]; ++j) {
0971             std::cout << "{ ";
0972             for (size_t k = 0; k < A.GetShape()[2]; ++k) {
0973                size_t n =  A.GetShape()[3];
0974                if (truncate)  n = std::min(n,size_t(10));
0975                for (size_t l = 0; l < n; ++l) {
0976                   std::cout << A(i,j,k,l) << " ";
0977                }
0978                if (truncate && n < A.GetShape()[3]) std::cout << " ...... ";
0979                std::cout << " } " << std::endl;
0980             }
0981             std::cout << " } " << std::endl;
0982          }
0983          std::cout << " } " << std::endl;
0984       }
0985    }
0986    else {
0987       for (size_t l = 0; l < A.GetSize(); ++l) {
0988          std::cout << A.GetData()[l] << " ";
0989       }
0990       std::cout << "\n";
0991    }
0992 }
0993 
0994 template <typename AFloat>
0995 void TCudnn<AFloat>::PrintTensor4dDescriptor(TensorDescriptor_t descriptor) {
0996    int n, c, h, w = 0;
0997    int s1, s2, s3, s4 = 0;
0998    cudnnDataType_t dataType;
0999    cudnnGetTensor4dDescriptor(descriptor, &dataType, &n, &c, &h, &w, &s1, &s2, &s3, &s4);
1000    std::cout << "Descriptor for 4d tensor of shape  { " << n << " , " << c << " , " << h << " , " << w << " }"
1001              << " and strides { " << s1 << " , " << s2 << " , " << s3 << " , " << s4 << " }" << std::endl;
1002 }
1003 template <typename AFloat>
1004 void TCudnn<AFloat>::PrintTensorNdDescriptor(TensorDescriptor_t descriptor, int ndim)
1005 {
1006    int n = 0;
1007    std::vector<int> dims(ndim);
1008    std::vector<int> strides(ndim);
1009    cudnnDataType_t dataType;
1010    cudnnGetTensorNdDescriptor(descriptor, ndim, &dataType, &n, dims.data(), strides.data());
1011    dims.resize(n);
1012    strides.resize(n);
1013    std::cout << "Descriptor for Nd tensor of dim = " << n << " shape  { ";
1014    for (auto d : dims)
1015       std::cout << d << " , ";
1016    std::cout << "} and strides { ";
1017    for (auto s : strides)
1018       std::cout << s << " , ";
1019    std::cout << " }" << std::endl;
1020 }
1021 
// Definitions of the static TCudnn<AFloat>::CNNOptions members, one set per AFloat instantiation.
//
// The algorithm selectors hold integer codes of the corresponding cuDNN algorithm enums.
// Possible options for the forward convolution (codes 0 to 7) include:
//
//  0 : CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM;
//  1 : CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM;
//  6 : CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD;
//  7 : CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED;  (lots of memory)

// For the backward-data convolution (codes 0 to 5):
//  1 : CUDNN_CONVOLUTION_BWD_DATA_ALGO_1;
//  5 : CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED;
//
// NOTE(review): every selector defaults to -1. For ConvMaxWorkspaceSize, -1 means "use the cuDNN
// defaults" (see below). For the algorithm selectors, -1 presumably means "let the implementation
// pick an algorithm" — TODO confirm in the convolution code. No code list is given here for
// ConvBwdFilterAlgorithm.

template <typename AFloat>
int TCudnn<AFloat>::CNNOptions::ConvFwdAlgorithm = -1;
template <typename AFloat>
int TCudnn<AFloat>::CNNOptions::ConvBwdDataAlgorithm = -1;
template <typename AFloat>
int TCudnn<AFloat>::CNNOptions::ConvBwdFilterAlgorithm = -1;
template <typename AFloat>
Long_t TCudnn<AFloat>::CNNOptions::ConvMaxWorkspaceSize = -1;  // -1 let use Cudnn defaults
1042 
1043 } // namespace DNN
1044 } // namespace TMVA
1045 
1046 #endif
1047 #endif