File indexing completed on 2025-12-15 10:28:51
0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015 #ifndef TMVA_RBATCHGENERATOR
0016 #define TMVA_RBATCHGENERATOR
0017
#include "TMVA/RTensor.hxx"
#include "ROOT/RDF/RDatasetSpec.hxx"
#include "TMVA/BatchGenerator/RChunkLoader.hxx"
#include "TMVA/BatchGenerator/RBatchLoader.hxx"
#include "TROOT.h"

#include <cmath>
#include <cstddef>
#include <memory>
#include <mutex>
#include <numeric>
#include <random>
#include <stdexcept>
#include <string>
#include <thread>
#include <variant>
#include <vector>
0031
0032 namespace TMVA {
0033 namespace Experimental {
0034 namespace Internal {
0035
0036
0037
0038
0039
0040
0041
0042
0043
0044
0045 template <typename... Args>
0046 class RBatchGenerator {
0047 private:
0048 std::vector<std::string> fCols;
0049
0050 std::size_t fChunkSize;
0051 std::size_t fMaxChunks;
0052 std::size_t fBatchSize;
0053 std::size_t fBlockSize;
0054 std::size_t fNumColumns;
0055 std::size_t fNumChunkCols;
0056 std::size_t fNumEntries;
0057 std::size_t fSetSeed;
0058 std::size_t fSumVecSizes;
0059
0060 ROOT::RDF::RResultPtr<std::vector<ULong64_t>> fEntries;
0061 float fValidationSplit;
0062
0063 std::unique_ptr<RChunkLoader<Args...>> fChunkLoader;
0064 std::unique_ptr<RBatchLoader> fBatchLoader;
0065
0066 std::unique_ptr<std::thread> fLoadingThread;
0067
0068 std::size_t fTrainingChunkNum;
0069 std::size_t fValidationChunkNum;
0070
0071 ROOT::RDF::RNode &f_rdf;
0072
0073 std::mutex fIsActiveMutex;
0074
0075 bool fDropRemainder;
0076 bool fShuffle;
0077 bool fIsActive{false};
0078 bool fNotFiltered;
0079 bool fUseWholeFile;
0080
0081 bool fEpochActive{false};
0082 bool fTrainingEpochActive{false};
0083 bool fValidationEpochActive{false};
0084
0085 std::size_t fNumTrainingEntries;
0086 std::size_t fNumValidationEntries;
0087
0088 std::size_t fNumTrainingChunks;
0089 std::size_t fNumValidationChunks;
0090
0091 std::size_t fLeftoverTrainingBatchSize;
0092 std::size_t fLeftoverValidationBatchSize;
0093
0094 std::size_t fNumFullTrainingBatches;
0095 std::size_t fNumFullValidationBatches;
0096
0097 std::size_t fNumLeftoverTrainingBatches;
0098 std::size_t fNumLeftoverValidationBatches;
0099
0100 std::size_t fNumTrainingBatches;
0101 std::size_t fNumValidationBatches;
0102
0103 TMVA::Experimental::RTensor<float> fTrainTensor;
0104 TMVA::Experimental::RTensor<float> fTrainChunkTensor;
0105
0106 TMVA::Experimental::RTensor<float> fValidationTensor;
0107 TMVA::Experimental::RTensor<float> fValidationChunkTensor;
0108
0109 public:
0110 RBatchGenerator(ROOT::RDF::RNode &rdf, const std::size_t chunkSize, const std::size_t blockSize,
0111 const std::size_t batchSize, const std::vector<std::string> &cols,
0112 const std::vector<std::size_t> &vecSizes = {}, const float vecPadding = 0.0,
0113 const float validationSplit = 0.0, const std::size_t maxChunks = 0, bool shuffle = true,
0114 bool dropRemainder = true, const std::size_t setSeed = 0)
0115
0116 : f_rdf(rdf),
0117 fCols(cols),
0118 fChunkSize(chunkSize),
0119 fBlockSize(blockSize),
0120 fBatchSize(batchSize),
0121 fValidationSplit(validationSplit),
0122 fMaxChunks(maxChunks),
0123 fDropRemainder(dropRemainder),
0124 fSetSeed(setSeed),
0125 fShuffle(shuffle),
0126 fNotFiltered(f_rdf.GetFilterNames().empty()),
0127 fUseWholeFile(maxChunks == 0),
0128 fNumColumns(cols.size()),
0129 fTrainTensor({0, 0}),
0130 fTrainChunkTensor({0, 0}),
0131 fValidationTensor({0, 0}),
0132 fValidationChunkTensor({0, 0})
0133 {
0134
0135 fNumEntries = f_rdf.Count().GetValue();
0136 fEntries = f_rdf.Take<ULong64_t>("rdfentry_");
0137
0138 fSumVecSizes = std::accumulate(vecSizes.begin(), vecSizes.end(), 0);
0139 fNumChunkCols = fNumColumns + fSumVecSizes - vecSizes.size();
0140
0141
0142 fEntries->push_back((*fEntries)[fNumEntries - 1] + 1);
0143
0144 fChunkLoader =
0145 std::make_unique<RChunkLoader<Args...>>(f_rdf, fNumEntries, fEntries, fChunkSize, fBlockSize, fValidationSplit,
0146 fCols, vecSizes, vecPadding, fShuffle, fSetSeed);
0147 fBatchLoader = std::make_unique<RBatchLoader>(fBatchSize, fNumChunkCols);
0148
0149
0150 fChunkLoader->SplitDataset();
0151
0152
0153 fNumValidationEntries = static_cast<std::size_t>(fValidationSplit * fNumEntries);
0154 fNumTrainingEntries = fNumEntries - fNumValidationEntries;
0155
0156 fLeftoverTrainingBatchSize = fNumTrainingEntries % fBatchSize;
0157 fLeftoverValidationBatchSize = fNumValidationEntries % fBatchSize;
0158
0159 fNumFullTrainingBatches = fNumTrainingEntries / fBatchSize;
0160 fNumFullValidationBatches = fNumValidationEntries / fBatchSize;
0161
0162 fNumLeftoverTrainingBatches = fLeftoverTrainingBatchSize == 0 ? 0 : 1;
0163 fNumLeftoverValidationBatches = fLeftoverValidationBatchSize == 0 ? 0 : 1;
0164
0165 if (dropRemainder) {
0166 fNumTrainingBatches = fNumFullTrainingBatches;
0167 fNumValidationBatches = fNumFullValidationBatches;
0168 }
0169
0170 else {
0171 fNumTrainingBatches = fNumFullTrainingBatches + fNumLeftoverTrainingBatches;
0172 fNumValidationBatches = fNumFullValidationBatches + fNumLeftoverValidationBatches;
0173 }
0174
0175
0176 fNumTrainingChunks = fChunkLoader->GetNumTrainingChunks();
0177 fNumValidationChunks = fChunkLoader->GetNumValidationChunks();
0178
0179 fTrainingChunkNum = 0;
0180 fValidationChunkNum = 0;
0181 }
0182
0183 ~RBatchGenerator() { DeActivate(); }
0184
0185 void DeActivate()
0186 {
0187 {
0188 std::lock_guard<std::mutex> lock(fIsActiveMutex);
0189 fIsActive = false;
0190 }
0191
0192 fBatchLoader->DeActivate();
0193
0194 if (fLoadingThread) {
0195 if (fLoadingThread->joinable()) {
0196 fLoadingThread->join();
0197 }
0198 }
0199 }
0200
0201
0202
0203 void Activate()
0204 {
0205 if (fIsActive)
0206 return;
0207
0208 {
0209 std::lock_guard<std::mutex> lock(fIsActiveMutex);
0210 fIsActive = true;
0211 }
0212
0213 fBatchLoader->Activate();
0214
0215 }
0216
0217 void ActivateEpoch() { fEpochActive = true; }
0218
0219 void DeActivateEpoch() { fEpochActive = false; }
0220
0221 void ActivateTrainingEpoch() { fTrainingEpochActive = true; }
0222
0223 void DeActivateTrainingEpoch() { fTrainingEpochActive = false; }
0224
0225 void ActivateValidationEpoch() { fValidationEpochActive = true; }
0226
0227 void DeActivateValidationEpoch() { fValidationEpochActive = false; }
0228
0229
0230 void CreateTrainBatches()
0231 {
0232
0233 fChunkLoader->CreateTrainingChunksIntervals();
0234 fTrainingEpochActive = true;
0235 fTrainingChunkNum = 0;
0236 fChunkLoader->LoadTrainingChunk(fTrainChunkTensor, fTrainingChunkNum);
0237 std::size_t lastTrainingBatch = fNumTrainingChunks - fTrainingChunkNum;
0238 fBatchLoader->CreateTrainingBatches(fTrainChunkTensor, lastTrainingBatch, fLeftoverTrainingBatchSize,
0239 fDropRemainder);
0240 fTrainingChunkNum++;
0241 }
0242
0243
0244 void CreateValidationBatches()
0245 {
0246
0247 fChunkLoader->CreateValidationChunksIntervals();
0248 fValidationEpochActive = true;
0249 fValidationChunkNum = 0;
0250 fChunkLoader->LoadValidationChunk(fValidationChunkTensor, fValidationChunkNum);
0251 std::size_t lastValidationBatch = fNumValidationChunks - fValidationChunkNum;
0252 fBatchLoader->CreateValidationBatches(fValidationChunkTensor, lastValidationBatch, fLeftoverValidationBatchSize,
0253 fDropRemainder);
0254 fValidationChunkNum++;
0255 }
0256
0257
0258 TMVA::Experimental::RTensor<float> GetTrainBatch()
0259 {
0260 auto batchQueue = fBatchLoader->GetNumTrainingBatchQueue();
0261
0262
0263 if (batchQueue < 1 && fTrainingChunkNum < fNumTrainingChunks) {
0264 fChunkLoader->LoadTrainingChunk(fTrainChunkTensor, fTrainingChunkNum);
0265 std::size_t lastTrainingBatch = fNumTrainingChunks - fTrainingChunkNum;
0266 fBatchLoader->CreateTrainingBatches(fTrainChunkTensor, lastTrainingBatch, fLeftoverTrainingBatchSize,
0267 fDropRemainder);
0268 fTrainingChunkNum++;
0269 }
0270
0271 else {
0272 ROOT::Internal::RDF::ChangeBeginAndEndEntries(f_rdf, 0, fNumEntries);
0273 }
0274
0275
0276 return fBatchLoader->GetTrainBatch();
0277 }
0278
0279
0280 TMVA::Experimental::RTensor<float> GetValidationBatch()
0281 {
0282 auto batchQueue = fBatchLoader->GetNumValidationBatchQueue();
0283
0284
0285 if (batchQueue < 1 && fValidationChunkNum < fNumValidationChunks) {
0286 fChunkLoader->LoadValidationChunk(fValidationChunkTensor, fValidationChunkNum);
0287 std::size_t lastValidationBatch = fNumValidationChunks - fValidationChunkNum;
0288 fBatchLoader->CreateValidationBatches(fValidationChunkTensor, lastValidationBatch,
0289 fLeftoverValidationBatchSize, fDropRemainder);
0290 fValidationChunkNum++;
0291 }
0292
0293 else {
0294 ROOT::Internal::RDF::ChangeBeginAndEndEntries(f_rdf, 0, fNumEntries);
0295 }
0296
0297
0298 return fBatchLoader->GetValidationBatch();
0299 }
0300
0301 std::size_t NumberOfTrainingBatches() { return fNumTrainingBatches; }
0302 std::size_t NumberOfValidationBatches() { return fNumValidationBatches; }
0303
0304 std::size_t TrainRemainderRows() { return fLeftoverTrainingBatchSize; }
0305 std::size_t ValidationRemainderRows() { return fLeftoverValidationBatchSize; }
0306
0307 bool IsActive() { return fIsActive; }
0308 bool TrainingIsActive() { return fTrainingEpochActive; }
0309
0310
0311 };
0312
0313 }
0314 }
0315 }
0316
0317 #endif