File indexing completed on 2025-01-18 10:10:52
0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018
0019 #ifndef TMVA_DNN_ARCHITECTURES_REFERENCE_DATALOADER
0020 #define TMVA_DNN_ARCHITECTURES_REFERENCE_DATALOADER
0021
#include "TMVA/DNN/DataLoader.h"

#include <algorithm>
#include <numeric>
#include <random>
#include <stdexcept>
#include <utility>
#include <vector>
0025
0026 namespace TMVA {
0027 namespace DNN {
0028
/// Forward declaration of the reference architecture. Only the name is needed
/// here to spell the TDataLoader partial specialization that follows.
template <class AReal>
class TReference;
0031
0032 template <typename AData, typename AReal>
0033 class TDataLoader<AData, TReference<AReal>> {
0034 private:
0035 using BatchIterator_t = TBatchIterator<AData, TReference<AReal>>;
0036
0037 const AData &fData;
0038
0039 size_t fNSamples;
0040 size_t fBatchSize;
0041 size_t fNInputFeatures;
0042 size_t fNOutputFeatures;
0043 size_t fBatchIndex;
0044
0045 TMatrixT<AReal> inputMatrix;
0046 TMatrixT<AReal> outputMatrix;
0047 TMatrixT<AReal> weightMatrix;
0048
0049 std::vector<size_t> fSampleIndices;
0050
0051 public:
0052 TDataLoader(const AData &data, size_t nSamples, size_t batchSize, size_t nInputFeatures, size_t nOutputFeatures,
0053 size_t nthreads = 1);
0054 TDataLoader(const TDataLoader &) = default;
0055 TDataLoader(TDataLoader &&) = default;
0056 TDataLoader &operator=(const TDataLoader &) = default;
0057 TDataLoader &operator=(TDataLoader &&) = default;
0058
0059
0060
0061 void CopyInput(TMatrixT<AReal> &matrix, IndexIterator_t begin);
0062
0063
0064 void CopyOutput(TMatrixT<AReal> &matrix, IndexIterator_t begin);
0065
0066
0067 void CopyWeights(TMatrixT<AReal> &matrix, IndexIterator_t begin);
0068
0069 BatchIterator_t begin() { return BatchIterator_t(*this); }
0070 BatchIterator_t end() { return BatchIterator_t(*this, fNSamples / fBatchSize); }
0071
0072
0073
0074
0075 void Shuffle();
0076
0077
0078
0079
0080 TBatch<TReference<AReal>> GetBatch();
0081 };
0082
0083 template <typename AData, typename AReal>
0084 TDataLoader<AData, TReference<AReal>>::TDataLoader(const AData &data, size_t nSamples, size_t batchSize,
0085 size_t nInputFeatures, size_t nOutputFeatures, size_t )
0086 : fData(data), fNSamples(nSamples), fBatchSize(batchSize), fNInputFeatures(nInputFeatures),
0087 fNOutputFeatures(nOutputFeatures), fBatchIndex(0), inputMatrix(batchSize, nInputFeatures),
0088 outputMatrix(batchSize, nOutputFeatures), weightMatrix(batchSize, 1), fSampleIndices()
0089 {
0090 fSampleIndices.reserve(fNSamples);
0091 for (size_t i = 0; i < fNSamples; i++) {
0092 fSampleIndices.push_back(i);
0093 }
0094 }
0095
0096 template <typename AData, typename AReal>
0097 TBatch<TReference<AReal>> TDataLoader<AData, TReference<AReal>>::GetBatch()
0098 {
0099 fBatchIndex %= (fNSamples / fBatchSize);
0100
0101 size_t sampleIndex = fBatchIndex * fBatchSize;
0102 IndexIterator_t sampleIndexIterator = fSampleIndices.begin() + sampleIndex;
0103
0104 CopyInput(inputMatrix, sampleIndexIterator);
0105 CopyOutput(outputMatrix, sampleIndexIterator);
0106 CopyWeights(weightMatrix, sampleIndexIterator);
0107
0108 fBatchIndex++;
0109
0110 return TBatch<TReference<AReal>>(inputMatrix, outputMatrix, weightMatrix);
0111 }
0112
0113
0114 template <typename AData, typename AReal>
0115 void TDataLoader<AData, TReference<AReal>>::Shuffle()
0116 {
0117 std::shuffle(fSampleIndices.begin(), fSampleIndices.end(), std::default_random_engine{});
0118 }
0119
0120 }
0121 }
0122
0123 #endif