Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-01-18 10:10:52

0001 // @(#)root/tmva/tmva/dnn:$Id$
0002 // Author: Simon Pfreundschuh 06/06/17
0003 
0004 /*************************************************************************
0005  * Copyright (C) 2016, Simon Pfreundschuh                                *
0006  * All rights reserved.                                                  *
0007  *                                                                       *
0008  * For the licensing terms see $ROOTSYS/LICENSE.                         *
0009  * For the list of contributors see $ROOTSYS/README/CREDITS.             *
0010  *************************************************************************/
0011 
0012 /////////////////////////////////////////////////////////////////////
0013 // Partial specialization of the TDataLoader class to adapt it to  //
0014 // the TMatrix class. Also the data transfer is kept simple, since //
// this implementation (being intended as reference and fallback)  //
0016 // is not optimized for performance.                               //
0017 /////////////////////////////////////////////////////////////////////
0018 
0019 #ifndef TMVA_DNN_ARCHITECTURES_REFERENCE_DATALOADER
0020 #define TMVA_DNN_ARCHITECTURES_REFERENCE_DATALOADER
0021 
#include "TMVA/DNN/DataLoader.h"

#include <algorithm>
#include <random>
#include <stdexcept>
#include <vector>
0025 
0026 namespace TMVA {
0027 namespace DNN {
0028 
0029 template <typename AReal>
0030 class TReference;
0031 
0032 template <typename AData, typename AReal>
0033 class TDataLoader<AData, TReference<AReal>> {
0034 private:
0035    using BatchIterator_t = TBatchIterator<AData, TReference<AReal>>;
0036 
0037    const AData &fData;
0038 
0039    size_t fNSamples;
0040    size_t fBatchSize;
0041    size_t fNInputFeatures;
0042    size_t fNOutputFeatures;
0043    size_t fBatchIndex;
0044 
0045    TMatrixT<AReal> inputMatrix;
0046    TMatrixT<AReal> outputMatrix;
0047    TMatrixT<AReal> weightMatrix;
0048 
0049    std::vector<size_t> fSampleIndices; ///< Ordering of the samples in the epoch.
0050 
0051 public:
0052    TDataLoader(const AData &data, size_t nSamples, size_t batchSize, size_t nInputFeatures, size_t nOutputFeatures,
0053                size_t nthreads = 1);
0054    TDataLoader(const TDataLoader &) = default;
0055    TDataLoader(TDataLoader &&) = default;
0056    TDataLoader &operator=(const TDataLoader &) = default;
0057    TDataLoader &operator=(TDataLoader &&) = default;
0058 
0059    /** Copy input matrix into the given host buffer. Function to be specialized by
0060     *  the architecture-specific backend. */
0061    void CopyInput(TMatrixT<AReal> &matrix, IndexIterator_t begin);
0062    /** Copy output matrix into the given host buffer. Function to be specialized
0063     * by the architecture-specific backend. */
0064    void CopyOutput(TMatrixT<AReal> &matrix, IndexIterator_t begin);
0065    /** Copy weight matrix into the given host buffer. Function to be specialized
0066     * by the architecture-specific backend. */
0067    void CopyWeights(TMatrixT<AReal> &matrix, IndexIterator_t begin);
0068 
0069    BatchIterator_t begin() { return BatchIterator_t(*this); }
0070    BatchIterator_t end() { return BatchIterator_t(*this, fNSamples / fBatchSize); }
0071 
0072    /** Shuffle the order of the samples in the batch. The shuffling is indirect,
0073     *  i.e. only the indices are shuffled. No input data is moved by this
0074     * routine. */
0075    void Shuffle();
0076 
0077    /** Return the next batch from the training set. The TDataLoader object
0078     *  keeps an internal counter that cycles over the batches in the training
0079     *  set. */
0080    TBatch<TReference<AReal>> GetBatch();
0081 };
0082 
0083 template <typename AData, typename AReal>
0084 TDataLoader<AData, TReference<AReal>>::TDataLoader(const AData &data, size_t nSamples, size_t batchSize,
0085                                                    size_t nInputFeatures, size_t nOutputFeatures, size_t /*nthreads*/)
0086    : fData(data), fNSamples(nSamples), fBatchSize(batchSize), fNInputFeatures(nInputFeatures),
0087      fNOutputFeatures(nOutputFeatures), fBatchIndex(0), inputMatrix(batchSize, nInputFeatures),
0088      outputMatrix(batchSize, nOutputFeatures), weightMatrix(batchSize, 1), fSampleIndices()
0089 {
0090    fSampleIndices.reserve(fNSamples);
0091    for (size_t i = 0; i < fNSamples; i++) {
0092       fSampleIndices.push_back(i);
0093    }
0094 }
0095 
0096 template <typename AData, typename AReal>
0097 TBatch<TReference<AReal>> TDataLoader<AData, TReference<AReal>>::GetBatch()
0098 {
0099    fBatchIndex %= (fNSamples / fBatchSize); // Cycle through samples.
0100 
0101    size_t sampleIndex = fBatchIndex * fBatchSize;
0102    IndexIterator_t sampleIndexIterator = fSampleIndices.begin() + sampleIndex;
0103 
0104    CopyInput(inputMatrix, sampleIndexIterator);
0105    CopyOutput(outputMatrix, sampleIndexIterator);
0106    CopyWeights(weightMatrix, sampleIndexIterator);
0107 
0108    fBatchIndex++;
0109 
0110    return TBatch<TReference<AReal>>(inputMatrix, outputMatrix, weightMatrix);
0111 }
0112 
0113 //______________________________________________________________________________
0114 template <typename AData, typename AReal>
0115 void TDataLoader<AData, TReference<AReal>>::Shuffle()
0116 {
0117    std::shuffle(fSampleIndices.begin(), fSampleIndices.end(), std::default_random_engine{});
0118 }
0119 
0120 } // namespace DNN
0121 } // namespace TMVA
0122 
0123 #endif // TMVA_DNN_ARCHITECTURES_REFERENCE_DATALOADER