File indexing completed on 2025-01-18 10:11:04
0001 #ifndef TMVA_RBatchLoader
0002 #define TMVA_RBatchLoader
0003
#include <algorithm>
#include <condition_variable>
#include <cstddef>
#include <iostream>
#include <memory>
#include <mutex>
#include <queue>
#include <utility>
#include <vector>

#include "TMVA/RTensor.hxx"
#include "TMVA/Tools.h"
#include "TRandom3.h"
0016
0017 namespace TMVA {
0018 namespace Experimental {
0019 namespace Internal {
0020
0021 class RBatchLoader {
0022 private:
0023 std::size_t fBatchSize;
0024 std::size_t fNumColumns;
0025 std::size_t fMaxBatches;
0026
0027 bool fIsActive = false;
0028 TMVA::RandomGenerator<TRandom3> fRng = TMVA::RandomGenerator<TRandom3>(0);
0029
0030 std::mutex fBatchLock;
0031 std::condition_variable fBatchCondition;
0032
0033 std::queue<std::unique_ptr<TMVA::Experimental::RTensor<float>>> fTrainingBatchQueue;
0034 std::vector<std::unique_ptr<TMVA::Experimental::RTensor<float>>> fValidationBatches;
0035 std::unique_ptr<TMVA::Experimental::RTensor<float>> fCurrentBatch;
0036
0037 std::size_t fValidationIdx = 0;
0038
0039 TMVA::Experimental::RTensor<float> fEmptyTensor = TMVA::Experimental::RTensor<float>({0});
0040
0041 public:
0042 RBatchLoader(const std::size_t batchSize, const std::size_t numColumns, const std::size_t maxBatches)
0043 : fBatchSize(batchSize), fNumColumns(numColumns), fMaxBatches(maxBatches)
0044 {
0045 }
0046
0047 ~RBatchLoader() { DeActivate(); }
0048
0049 public:
0050
0051
0052
0053 const TMVA::Experimental::RTensor<float> &GetTrainBatch()
0054 {
0055 std::unique_lock<std::mutex> lock(fBatchLock);
0056 fBatchCondition.wait(lock, [this]() { return !fTrainingBatchQueue.empty() || !fIsActive; });
0057
0058 if (fTrainingBatchQueue.empty()) {
0059 fCurrentBatch = std::make_unique<TMVA::Experimental::RTensor<float>>(std::vector<std::size_t>({0}));
0060 return *fCurrentBatch;
0061 }
0062
0063 fCurrentBatch = std::move(fTrainingBatchQueue.front());
0064 fTrainingBatchQueue.pop();
0065
0066 fBatchCondition.notify_all();
0067
0068 return *fCurrentBatch;
0069 }
0070
0071
0072
0073
0074
0075 const TMVA::Experimental::RTensor<float> &GetValidationBatch()
0076 {
0077 if (HasValidationData()) {
0078 return *fValidationBatches[fValidationIdx++].get();
0079 }
0080
0081 return fEmptyTensor;
0082 }
0083
0084
0085
0086 bool HasTrainData()
0087 {
0088 {
0089 std::unique_lock<std::mutex> lock(fBatchLock);
0090 if (!fTrainingBatchQueue.empty() || fIsActive)
0091 return true;
0092 }
0093
0094 return false;
0095 }
0096
0097
0098
0099 bool HasValidationData()
0100 {
0101 std::unique_lock<std::mutex> lock(fBatchLock);
0102 return fValidationIdx < fValidationBatches.size();
0103 }
0104
0105
0106 void Activate()
0107 {
0108 {
0109 std::lock_guard<std::mutex> lock(fBatchLock);
0110 fIsActive = true;
0111 }
0112 fBatchCondition.notify_all();
0113 }
0114
0115
0116
0117 void DeActivate()
0118 {
0119 {
0120 std::lock_guard<std::mutex> lock(fBatchLock);
0121 fIsActive = false;
0122 }
0123 fBatchCondition.notify_all();
0124 }
0125
0126
0127
0128
0129
0130 std::unique_ptr<TMVA::Experimental::RTensor<float>>
0131 CreateBatch(const TMVA::Experimental::RTensor<float> &chunkTensor, const std::vector<std::size_t> idx)
0132 {
0133 auto batch =
0134 std::make_unique<TMVA::Experimental::RTensor<float>>(std::vector<std::size_t>({fBatchSize, fNumColumns}));
0135
0136 for (std::size_t i = 0; i < fBatchSize; i++) {
0137 std::copy(chunkTensor.GetData() + (idx[i] * fNumColumns), chunkTensor.GetData() + ((idx[i] + 1) * fNumColumns),
0138 batch->GetData() + i * fNumColumns);
0139 }
0140
0141 return batch;
0142 }
0143
0144
0145
0146
0147
0148
0149
0150 void CreateTrainingBatches(const TMVA::Experimental::RTensor<float> &chunkTensor,
0151 std::vector<std::size_t> eventIndices, const bool shuffle = true)
0152 {
0153
0154
0155 {
0156 std::unique_lock<std::mutex> lock(fBatchLock);
0157 fBatchCondition.wait(lock, [this]() { return (fTrainingBatchQueue.size() < fMaxBatches) || !fIsActive; });
0158 if (!fIsActive)
0159 return;
0160 }
0161
0162 if (shuffle)
0163 std::shuffle(eventIndices.begin(), eventIndices.end(), fRng);
0164
0165 std::vector<std::unique_ptr<TMVA::Experimental::RTensor<float>>> batches;
0166
0167
0168 for (std::size_t start = 0; (start + fBatchSize) <= eventIndices.size(); start += fBatchSize) {
0169
0170
0171 std::vector<std::size_t> idx;
0172 for (std::size_t i = start; i < (start + fBatchSize); i++) {
0173 idx.push_back(eventIndices[i]);
0174 }
0175
0176
0177 batches.emplace_back(CreateBatch(chunkTensor, idx));
0178 }
0179
0180 {
0181 std::unique_lock<std::mutex> lock(fBatchLock);
0182 for (std::size_t i = 0; i < batches.size(); i++) {
0183 fTrainingBatchQueue.push(std::move(batches[i]));
0184 }
0185 }
0186
0187 fBatchCondition.notify_one();
0188 }
0189
0190
0191
0192
0193
0194 void CreateValidationBatches(const TMVA::Experimental::RTensor<float> &chunkTensor,
0195 const std::vector<std::size_t> eventIndices)
0196 {
0197
0198 for (std::size_t start = 0; (start + fBatchSize) <= eventIndices.size(); start += fBatchSize) {
0199
0200 std::vector<std::size_t> idx;
0201
0202 for (std::size_t i = start; i < (start + fBatchSize); i++) {
0203 idx.push_back(eventIndices[i]);
0204 }
0205
0206 {
0207 std::unique_lock<std::mutex> lock(fBatchLock);
0208 fValidationBatches.emplace_back(CreateBatch(chunkTensor, idx));
0209 }
0210 }
0211 }
0212
0213
0214 void StartValidation()
0215 {
0216 std::unique_lock<std::mutex> lock(fBatchLock);
0217 fValidationIdx = 0;
0218 }
0219 };
0220
0221 }
0222 }
0223 }
0224
0225 #endif