File indexing completed on 2025-09-17 09:14:37
0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014 #ifndef TMVA_RBATCHGENERATOR
0015 #define TMVA_RBATCHGENERATOR
0016
#include "TMVA/RTensor.hxx"
#include "ROOT/RDF/RDatasetSpec.hxx"
#include "TMVA/BatchGenerator/RChunkLoader.hxx"
#include "TMVA/BatchGenerator/RBatchLoader.hxx"
#include "TROOT.h"

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <memory>
#include <mutex>
#include <numeric>
#include <random>
#include <stdexcept>
#include <string>
#include <thread>
#include <utility>
#include <variant>
#include <vector>
0030
0031 namespace TMVA {
0032 namespace Experimental {
0033 namespace Internal {
0034
0035 template <typename... Args>
0036 class RBatchGenerator {
0037 private:
0038 std::mt19937 fRng;
0039 std::mt19937 fFixedRng;
0040 std::random_device::result_type fFixedSeed;
0041
0042 std::size_t fChunkSize;
0043 std::size_t fMaxChunks;
0044 std::size_t fBatchSize;
0045 std::size_t fNumEntries;
0046
0047 float fValidationSplit;
0048
0049 std::variant<std::shared_ptr<RChunkLoader<Args...>>, std::shared_ptr<RChunkLoaderFilters<Args...>>> fChunkLoader;
0050
0051 std::unique_ptr<RBatchLoader> fBatchLoader;
0052
0053 std::unique_ptr<std::thread> fLoadingThread;
0054
0055 std::unique_ptr<TMVA::Experimental::RTensor<float>> fChunkTensor;
0056
0057 ROOT::RDF::RNode &f_rdf;
0058
0059 std::mutex fIsActiveMutex;
0060
0061 bool fDropRemainder;
0062 bool fShuffle;
0063 bool fIsActive{false};
0064 bool fNotFiltered;
0065 bool fUseWholeFile;
0066
0067 public:
0068 RBatchGenerator(ROOT::RDF::RNode &rdf, const std::size_t chunkSize, const std::size_t batchSize,
0069 const std::vector<std::string> &cols, const std::size_t numColumns,
0070 const std::vector<std::size_t> &vecSizes = {}, const float vecPadding = 0.0,
0071 const float validationSplit = 0.0, const std::size_t maxChunks = 0, bool shuffle = true,
0072 bool dropRemainder = true)
0073 : fRng(std::random_device{}()),
0074 fFixedSeed(std::uniform_int_distribution<std::random_device::result_type>{}(fRng)),
0075 f_rdf(rdf),
0076 fChunkSize(chunkSize),
0077 fBatchSize(batchSize),
0078 fValidationSplit(validationSplit),
0079 fMaxChunks(maxChunks),
0080 fDropRemainder(dropRemainder),
0081 fShuffle(shuffle),
0082 fNotFiltered(f_rdf.GetFilterNames().empty()),
0083 fUseWholeFile(maxChunks == 0)
0084 {
0085
0086
0087 fChunkTensor =
0088 std::make_unique<TMVA::Experimental::RTensor<float>>(std::vector<std::size_t>{fChunkSize, numColumns});
0089
0090 if (fNotFiltered) {
0091 fNumEntries = f_rdf.Count().GetValue();
0092
0093 fChunkLoader = std::make_unique<TMVA::Experimental::Internal::RChunkLoader<Args...>>(
0094 f_rdf, *fChunkTensor, fChunkSize, cols, vecSizes, vecPadding);
0095 } else {
0096 auto report = f_rdf.Report();
0097 fNumEntries = f_rdf.Count().GetValue();
0098 std::size_t numAllEntries = report.begin()->GetAll();
0099
0100 fChunkLoader = std::make_unique<TMVA::Experimental::Internal::RChunkLoaderFilters<Args...>>(
0101 f_rdf, *fChunkTensor, fChunkSize, cols, fNumEntries, numAllEntries, vecSizes, vecPadding);
0102 }
0103
0104 std::size_t maxBatches = ceil((fChunkSize / fBatchSize) * (1 - fValidationSplit));
0105
0106
0107 fBatchLoader = std::make_unique<TMVA::Experimental::Internal::RBatchLoader>(*fChunkTensor, fBatchSize, numColumns,
0108 maxBatches);
0109 }
0110
0111 ~RBatchGenerator() { DeActivate(); }
0112
0113
0114
0115 void DeActivate()
0116 {
0117 {
0118 std::lock_guard<std::mutex> lock(fIsActiveMutex);
0119 fIsActive = false;
0120 }
0121
0122 fBatchLoader->DeActivate();
0123
0124 if (fLoadingThread) {
0125 if (fLoadingThread->joinable()) {
0126 fLoadingThread->join();
0127 }
0128 }
0129 }
0130
0131
0132
0133 void Activate()
0134 {
0135 if (fIsActive)
0136 return;
0137
0138 {
0139 std::lock_guard<std::mutex> lock(fIsActiveMutex);
0140 fIsActive = true;
0141 }
0142
0143 fFixedRng.seed(fFixedSeed);
0144 fBatchLoader->Activate();
0145
0146 if (fNotFiltered) {
0147 fLoadingThread = std::make_unique<std::thread>(&RBatchGenerator::LoadChunksNoFilters, this);
0148 } else {
0149 fLoadingThread = std::make_unique<std::thread>(&RBatchGenerator::LoadChunksFilters, this);
0150 }
0151 }
0152
0153
0154
0155
0156 const TMVA::Experimental::RTensor<float> &GetTrainBatch()
0157 {
0158
0159 return fBatchLoader->GetTrainBatch();
0160 }
0161
0162
0163
0164
0165 const TMVA::Experimental::RTensor<float> &GetValidationBatch()
0166 {
0167
0168 return fBatchLoader->GetValidationBatch();
0169 }
0170
0171 std::size_t NumberOfTrainingBatches()
0172 {
0173 std::size_t entriesForTraining =
0174 (fNumEntries / fChunkSize) * (fChunkSize - floor(fChunkSize * fValidationSplit)) + fNumEntries % fChunkSize -
0175 floor(fValidationSplit * (fNumEntries % fChunkSize));
0176
0177 if (fDropRemainder || !(entriesForTraining % fBatchSize)) {
0178 return entriesForTraining / fBatchSize;
0179 }
0180
0181 return entriesForTraining / fBatchSize + 1;
0182 }
0183
0184
0185
0186 std::size_t TrainRemainderRows()
0187 {
0188 std::size_t entriesForTraining =
0189 (fNumEntries / fChunkSize) * (fChunkSize - floor(fChunkSize * fValidationSplit)) + fNumEntries % fChunkSize -
0190 floor(fValidationSplit * (fNumEntries % fChunkSize));
0191
0192 if (fDropRemainder || !(entriesForTraining % fBatchSize)) {
0193 return 0;
0194 }
0195
0196 return entriesForTraining % fBatchSize;
0197 }
0198
0199
0200
0201 std::size_t NumberOfValidationBatches()
0202 {
0203 std::size_t entriesForValidation = (fNumEntries / fChunkSize) * floor(fChunkSize * fValidationSplit) +
0204 floor((fNumEntries % fChunkSize) * fValidationSplit);
0205
0206 if (fDropRemainder || !(entriesForValidation % fBatchSize)) {
0207
0208 return entriesForValidation / fBatchSize;
0209 }
0210
0211 return entriesForValidation / fBatchSize + 1;
0212 }
0213
0214
0215
0216 std::size_t ValidationRemainderRows()
0217 {
0218 std::size_t entriesForValidation = (fNumEntries / fChunkSize) * floor(fChunkSize * fValidationSplit) +
0219 floor((fNumEntries % fChunkSize) * fValidationSplit);
0220
0221 if (fDropRemainder || !(entriesForValidation % fBatchSize)) {
0222
0223 return 0;
0224 }
0225
0226 return entriesForValidation % fBatchSize;
0227 }
0228
0229
0230 void LoadChunksNoFilters()
0231 {
0232 for (std::size_t currentChunk = 0, currentEntry = 0;
0233 ((currentChunk < fMaxChunks) || fUseWholeFile) && currentEntry < fNumEntries; currentChunk++) {
0234
0235
0236 {
0237 std::lock_guard<std::mutex> lock(fIsActiveMutex);
0238 if (!fIsActive)
0239 return;
0240 }
0241
0242
0243 std::size_t report = std::get<std::shared_ptr<RChunkLoader<Args...>>>(fChunkLoader)->LoadChunk(currentEntry);
0244 currentEntry += report;
0245
0246 CreateBatches(report);
0247 }
0248
0249 if (!fDropRemainder) {
0250 fBatchLoader->LastBatches();
0251 }
0252
0253 fBatchLoader->DeActivate();
0254 }
0255
0256 void LoadChunksFilters()
0257 {
0258 std::size_t currentChunk = 0;
0259 for (std::size_t processedEvents = 0, currentRow = 0;
0260 ((currentChunk < fMaxChunks) || fUseWholeFile) && processedEvents < fNumEntries; currentChunk++) {
0261
0262
0263 {
0264 std::lock_guard<std::mutex> lock(fIsActiveMutex);
0265 if (!fIsActive)
0266 return;
0267 }
0268
0269
0270 std::pair<std::size_t, std::size_t> report =
0271 std::get<std::shared_ptr<RChunkLoaderFilters<Args...>>>(fChunkLoader)->LoadChunk(currentRow);
0272
0273 currentRow += report.first;
0274 processedEvents += report.second;
0275
0276 CreateBatches(report.second);
0277 }
0278
0279 if (currentChunk < fMaxChunks || fUseWholeFile) {
0280 CreateBatches(std::get<std::shared_ptr<RChunkLoaderFilters<Args...>>>(fChunkLoader)->LastChunk());
0281 }
0282
0283 if (!fDropRemainder) {
0284 fBatchLoader->LastBatches();
0285 }
0286
0287 fBatchLoader->DeActivate();
0288 }
0289
0290
0291
0292 void CreateBatches(std::size_t processedEvents)
0293 {
0294 auto &&[trainingIndices, validationIndices] = createIndices(processedEvents);
0295
0296 fBatchLoader->CreateTrainingBatches(trainingIndices);
0297 fBatchLoader->CreateValidationBatches(validationIndices);
0298 }
0299
0300
0301
0302 std::pair<std::vector<std::size_t>, std::vector<std::size_t>> createIndices(std::size_t events)
0303 {
0304
0305 std::vector<std::size_t> row_order = std::vector<std::size_t>(events);
0306 std::iota(row_order.begin(), row_order.end(), 0);
0307
0308 if (fShuffle) {
0309
0310 std::shuffle(row_order.begin(), row_order.end(), fFixedRng);
0311 }
0312
0313
0314 std::size_t num_validation = floor(events * fValidationSplit);
0315
0316
0317 std::vector<std::size_t> trainingIndices =
0318 std::vector<std::size_t>({row_order.begin(), row_order.end() - num_validation});
0319 std::vector<std::size_t> validationIndices =
0320 std::vector<std::size_t>({row_order.end() - num_validation, row_order.end()});
0321
0322 if (fShuffle) {
0323 std::shuffle(trainingIndices.begin(), trainingIndices.end(), fRng);
0324 }
0325
0326 return std::make_pair(trainingIndices, validationIndices);
0327 }
0328
0329 bool IsActive() { return fIsActive; }
0330 };
0331
0332 }
0333 }
0334 }
0335
0336 #endif