File indexing completed on 2025-01-18 10:10:52
0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018
0019
0020
0021
0022
0023
0024
0025
0026
0027
0028
0029
0030
0031
0032
0033
0034 #ifndef TMVA_DNN_ARCHITECTURES_REFERENCE_TENSORDATALOADER
0035 #define TMVA_DNN_ARCHITECTURES_REFERENCE_TENSORDATALOADER
0036
#include "TMVA/DNN/TensorDataLoader.h"

#include <algorithm>
#include <iostream>
#include <numeric>
#include <utility>
#include <vector>
0039
0040 namespace TMVA {
0041 namespace DNN {
0042
0043 template <typename AReal>
0044 class TReference;
0045
0046 template <typename AData, typename AReal>
0047 class TTensorDataLoader<AData, TReference<AReal>> {
0048 private:
0049 using BatchIterator_t = TTensorBatchIterator<AData, TReference<AReal>>;
0050
0051 const AData &fData;
0052
0053 size_t fNSamples;
0054
0055 size_t fBatchDepth;
0056 size_t fBatchHeight;
0057 size_t fBatchWidth;
0058 size_t fNOutputFeatures;
0059 size_t fBatchIndex;
0060
0061 std::vector<size_t> fInputShape;
0062
0063 std::vector<TMatrixT<AReal>> inputTensor;
0064 TMatrixT<AReal> outputMatrix;
0065 TMatrixT<AReal> weightMatrix;
0066
0067 std::vector<size_t> fSampleIndices;
0068
0069 public:
0070
0071 TTensorDataLoader(const AData &data, size_t nSamples, size_t batchDepth,
0072 size_t batchHeight, size_t batchWidth, size_t nOutputFeatures,
0073 std::vector<size_t> inputShape, size_t nStreams = 1);
0074
0075 TTensorDataLoader(const TTensorDataLoader &) = default;
0076 TTensorDataLoader(TTensorDataLoader &&) = default;
0077 TTensorDataLoader &operator=(const TTensorDataLoader &) = default;
0078 TTensorDataLoader &operator=(TTensorDataLoader &&) = default;
0079
0080
0081
0082 void CopyTensorInput(std::vector<TMatrixT<AReal>> &tensor, IndexIterator_t sampleIterator);
0083
0084
0085 void CopyTensorOutput(TMatrixT<AReal> &matrix, IndexIterator_t sampleIterator);
0086
0087
0088 void CopyTensorWeights(TMatrixT<AReal> &matrix, IndexIterator_t sampleIterator);
0089
0090 BatchIterator_t begin() { return BatchIterator_t(*this); }
0091 BatchIterator_t end() { return BatchIterator_t(*this, fNSamples / fInputShape[0]); }
0092
0093
0094
0095
0096 template<typename RNG>
0097 void Shuffle(RNG & rng);
0098
0099
0100
0101
0102 TTensorBatch<TReference<AReal>> GetTensorBatch();
0103 };
0104
0105
0106
0107
0108 template <typename AData, typename AReal>
0109 TTensorDataLoader<AData, TReference<AReal>>::TTensorDataLoader(const AData &data, size_t nSamples, size_t batchDepth,
0110 size_t batchHeight, size_t batchWidth, size_t nOutputFeatures,
0111 std::vector<size_t> inputShape, size_t )
0112 : fData(data), fNSamples(nSamples), fBatchDepth(batchDepth), fBatchHeight(batchHeight),
0113 fBatchWidth(batchWidth), fNOutputFeatures(nOutputFeatures), fBatchIndex(0), fInputShape(std::move(inputShape)), inputTensor(),
0114 outputMatrix(inputShape[0], nOutputFeatures), weightMatrix(inputShape[0], 1), fSampleIndices()
0115 {
0116
0117 inputTensor.reserve(fBatchDepth);
0118 for (size_t i = 0; i < fBatchDepth; i++) {
0119 inputTensor.emplace_back(batchHeight, batchWidth);
0120 }
0121
0122 fSampleIndices.reserve(fNSamples);
0123 for (size_t i = 0; i < fNSamples; i++) {
0124 fSampleIndices.push_back(i);
0125 }
0126 }
0127
0128 template <typename AData, typename AReal>
0129 template <typename RNG>
0130 void TTensorDataLoader<AData, TReference<AReal>>::Shuffle(RNG & rng)
0131 {
0132 std::shuffle(fSampleIndices.begin(), fSampleIndices.end(), rng);
0133 }
0134
0135 template <typename AData, typename AReal>
0136 auto TTensorDataLoader<AData, TReference<AReal>>::GetTensorBatch() -> TTensorBatch<TReference<AReal>>
0137 {
0138 fBatchIndex %= (fNSamples / fInputShape[0]);
0139
0140 size_t sampleIndex = fBatchIndex * fInputShape[0];
0141 IndexIterator_t sampleIndexIterator = fSampleIndices.begin() + sampleIndex;
0142
0143 CopyTensorInput(inputTensor, sampleIndexIterator);
0144 CopyTensorOutput(outputMatrix, sampleIndexIterator);
0145 CopyTensorWeights(weightMatrix, sampleIndexIterator);
0146
0147 fBatchIndex++;
0148 return TTensorBatch<TReference<AReal>>(inputTensor, outputMatrix, weightMatrix);
0149 }
0150
0151 }
0152 }
0153
0154 #endif