File indexing completed on 2025-09-17 09:14:38
0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014 #ifndef TMVA_RBATCHLOADER
0015 #define TMVA_RBATCHLOADER
0016
#include <algorithm>
#include <condition_variable>
#include <cstddef>
#include <memory>
#include <mutex>
#include <numeric>
#include <queue>
#include <span>
#include <stdexcept>
#include <vector>

#include "TMVA/RTensor.hxx"
#include "TMVA/Tools.h"
0028
0029 namespace TMVA {
0030 namespace Experimental {
0031 namespace Internal {
0032
0033 class RBatchLoader {
0034 private:
0035 const TMVA::Experimental::RTensor<float> &fChunkTensor;
0036 std::size_t fBatchSize;
0037 std::size_t fNumColumns;
0038 std::size_t fMaxBatches;
0039 std::size_t fTrainingRemainderRow = 0;
0040 std::size_t fValidationRemainderRow = 0;
0041
0042 bool fIsActive = false;
0043
0044 std::mutex fBatchLock;
0045 std::condition_variable fBatchCondition;
0046
0047 std::queue<std::unique_ptr<TMVA::Experimental::RTensor<float>>> fTrainingBatchQueue;
0048 std::queue<std::unique_ptr<TMVA::Experimental::RTensor<float>>> fValidationBatchQueue;
0049 std::unique_ptr<TMVA::Experimental::RTensor<float>> fCurrentBatch;
0050
0051 std::unique_ptr<TMVA::Experimental::RTensor<float>> fTrainingRemainder;
0052 std::unique_ptr<TMVA::Experimental::RTensor<float>> fValidationRemainder;
0053
0054 public:
0055 RBatchLoader(const TMVA::Experimental::RTensor<float> &chunkTensor, const std::size_t batchSize,
0056 const std::size_t numColumns, const std::size_t maxBatches)
0057 : fChunkTensor(chunkTensor), fBatchSize(batchSize), fNumColumns(numColumns), fMaxBatches(maxBatches)
0058 {
0059
0060 fTrainingRemainder =
0061 std::make_unique<TMVA::Experimental::RTensor<float>>(std::vector<std::size_t>{fBatchSize - 1, fNumColumns});
0062 fValidationRemainder =
0063 std::make_unique<TMVA::Experimental::RTensor<float>>(std::vector<std::size_t>{fBatchSize - 1, fNumColumns});
0064 }
0065
0066 ~RBatchLoader() { DeActivate(); }
0067
0068 public:
0069
0070
0071
0072 const TMVA::Experimental::RTensor<float> &GetTrainBatch()
0073 {
0074 std::unique_lock<std::mutex> lock(fBatchLock);
0075 fBatchCondition.wait(lock, [this]() { return !fTrainingBatchQueue.empty() || !fIsActive; });
0076
0077 if (fTrainingBatchQueue.empty()) {
0078 fCurrentBatch = std::make_unique<TMVA::Experimental::RTensor<float>>(std::vector<std::size_t>({0}));
0079 return *fCurrentBatch;
0080 }
0081
0082 fCurrentBatch = std::move(fTrainingBatchQueue.front());
0083 fTrainingBatchQueue.pop();
0084
0085 fBatchCondition.notify_all();
0086
0087 return *fCurrentBatch;
0088 }
0089
0090
0091
0092
0093
0094 const TMVA::Experimental::RTensor<float> &GetValidationBatch()
0095 {
0096 if (fValidationBatchQueue.empty()) {
0097 fCurrentBatch = std::make_unique<TMVA::Experimental::RTensor<float>>(std::vector<std::size_t>({0}));
0098 return *fCurrentBatch;
0099 }
0100
0101 fCurrentBatch = std::move(fValidationBatchQueue.front());
0102 fValidationBatchQueue.pop();
0103
0104 return *fCurrentBatch;
0105 }
0106
0107
0108 void Activate()
0109 {
0110 fTrainingRemainderRow = 0;
0111 fValidationRemainderRow = 0;
0112
0113 {
0114 std::lock_guard<std::mutex> lock(fBatchLock);
0115 fIsActive = true;
0116 }
0117 fBatchCondition.notify_all();
0118 }
0119
0120
0121
0122 void DeActivate()
0123 {
0124 {
0125 std::lock_guard<std::mutex> lock(fBatchLock);
0126 fIsActive = false;
0127 }
0128 fBatchCondition.notify_all();
0129 }
0130
0131 std::unique_ptr<TMVA::Experimental::RTensor<float>>
0132 CreateBatch(const TMVA::Experimental::RTensor<float> &chunkTensor, std::span<const std::size_t> idxs,
0133 std::size_t batchSize)
0134 {
0135 auto batch =
0136 std::make_unique<TMVA::Experimental::RTensor<float>>(std::vector<std::size_t>({batchSize, fNumColumns}));
0137
0138 for (std::size_t i = 0; i < batchSize; i++) {
0139 std::copy(chunkTensor.GetData() + (idxs[i] * fNumColumns),
0140 chunkTensor.GetData() + ((idxs[i] + 1) * fNumColumns), batch->GetData() + i * fNumColumns);
0141 }
0142
0143 return batch;
0144 }
0145
0146 std::unique_ptr<TMVA::Experimental::RTensor<float>>
0147 CreateFirstBatch(const TMVA::Experimental::RTensor<float> &remainderTensor, std::size_t remainderTensorRow,
0148 std::span<const std::size_t> eventIndices)
0149 {
0150 auto batch =
0151 std::make_unique<TMVA::Experimental::RTensor<float>>(std::vector<std::size_t>({fBatchSize, fNumColumns}));
0152
0153 for (size_t i = 0; i < remainderTensorRow; i++) {
0154 std::copy(remainderTensor.GetData() + i * fNumColumns, remainderTensor.GetData() + (i + 1) * fNumColumns,
0155 batch->GetData() + i * fNumColumns);
0156 }
0157
0158 for (std::size_t i = 0; i < (fBatchSize - remainderTensorRow); i++) {
0159 std::copy(fChunkTensor.GetData() + eventIndices[i] * fNumColumns,
0160 fChunkTensor.GetData() + (eventIndices[i] + 1) * fNumColumns,
0161 batch->GetData() + (i + remainderTensorRow) * fNumColumns);
0162 }
0163
0164 return batch;
0165 }
0166
0167
0168
0169
0170
0171
0172 void SaveRemainingData(TMVA::Experimental::RTensor<float> &remainderTensor, const std::size_t remainderTensorRow,
0173 const std::vector<std::size_t> eventIndices, const std::size_t start = 0)
0174 {
0175 for (std::size_t i = start; i < eventIndices.size(); i++) {
0176 std::copy(fChunkTensor.GetData() + eventIndices[i] * fNumColumns,
0177 fChunkTensor.GetData() + (eventIndices[i] + 1) * fNumColumns,
0178 remainderTensor.GetData() + (i - start + remainderTensorRow) * fNumColumns);
0179 }
0180 }
0181
0182
0183
0184
0185
0186 void CreateTrainingBatches(const std::vector<std::size_t> &eventIndices)
0187 {
0188
0189
0190 {
0191 std::unique_lock<std::mutex> lock(fBatchLock);
0192 fBatchCondition.wait(lock, [this]() { return (fTrainingBatchQueue.size() < fMaxBatches) || !fIsActive; });
0193 if (!fIsActive)
0194 return;
0195 }
0196
0197 std::vector<std::unique_ptr<TMVA::Experimental::RTensor<float>>> batches;
0198
0199 if (eventIndices.size() + fTrainingRemainderRow >= fBatchSize) {
0200 batches.emplace_back(CreateFirstBatch(*fTrainingRemainder, fTrainingRemainderRow, eventIndices));
0201 } else {
0202 SaveRemainingData(*fTrainingRemainder, fTrainingRemainderRow, eventIndices);
0203 fTrainingRemainderRow += eventIndices.size();
0204 return;
0205 }
0206
0207
0208 std::size_t start = fBatchSize - fTrainingRemainderRow;
0209 for (; (start + fBatchSize) <= eventIndices.size(); start += fBatchSize) {
0210
0211 std::span<const std::size_t> idxs{eventIndices.data() + start, eventIndices.data() + start + fBatchSize};
0212
0213
0214 batches.emplace_back(CreateBatch(fChunkTensor, idxs, fBatchSize));
0215 }
0216
0217 {
0218 std::unique_lock<std::mutex> lock(fBatchLock);
0219 for (std::size_t i = 0; i < batches.size(); i++) {
0220 fTrainingBatchQueue.push(std::move(batches[i]));
0221 }
0222 }
0223
0224 fBatchCondition.notify_all();
0225
0226 fTrainingRemainderRow = eventIndices.size() - start;
0227 SaveRemainingData(*fTrainingRemainder, 0, eventIndices, start);
0228 }
0229
0230
0231
0232
0233
0234 void CreateValidationBatches(const std::vector<std::size_t> &eventIndices)
0235 {
0236 if (eventIndices.size() + fValidationRemainderRow >= fBatchSize) {
0237 fValidationBatchQueue.push(CreateFirstBatch(*fValidationRemainder, fValidationRemainderRow, eventIndices));
0238 } else {
0239 SaveRemainingData(*fValidationRemainder, fValidationRemainderRow, eventIndices);
0240 fValidationRemainderRow += eventIndices.size();
0241 return;
0242 }
0243
0244
0245 std::size_t start = fBatchSize - fValidationRemainderRow;
0246 for (; (start + fBatchSize) <= eventIndices.size(); start += fBatchSize) {
0247
0248 std::vector<std::size_t> idx;
0249
0250 for (std::size_t i = start; i < (start + fBatchSize); i++) {
0251 idx.push_back(eventIndices[i]);
0252 }
0253
0254 fValidationBatchQueue.push(CreateBatch(fChunkTensor, idx, fBatchSize));
0255 }
0256
0257 fValidationRemainderRow = eventIndices.size() - start;
0258 SaveRemainingData(*fValidationRemainder, 0, eventIndices, start);
0259 }
0260
0261 void LastBatches()
0262 {
0263 {
0264 if (fTrainingRemainderRow) {
0265 std::vector<std::size_t> idx = std::vector<std::size_t>(fTrainingRemainderRow);
0266 std::iota(idx.begin(), idx.end(), 0);
0267
0268 std::unique_ptr<TMVA::Experimental::RTensor<float>> batch =
0269 CreateBatch(*fTrainingRemainder, idx, fTrainingRemainderRow);
0270
0271 std::unique_lock<std::mutex> lock(fBatchLock);
0272 fTrainingBatchQueue.push(std::move(batch));
0273 }
0274 }
0275
0276 if (fValidationRemainderRow) {
0277 std::vector<std::size_t> idx = std::vector<std::size_t>(fValidationRemainderRow);
0278 std::iota(idx.begin(), idx.end(), 0);
0279
0280 fValidationBatchQueue.push(CreateBatch(*fValidationRemainder, idx, fValidationRemainderRow));
0281 }
0282 }
0283 };
0284
0285 }
0286 }
0287 }
0288
0289 #endif