File indexing completed on 2025-01-18 10:11:04
0001 #ifndef TMVA_BATCHGENERATOR
0002 #define TMVA_BATCHGENERATOR
0003
#include <algorithm>
#include <cmath>
#include <iostream>
#include <memory>
#include <mutex>
#include <numeric>
#include <stdexcept>
#include <string>
#include <thread>
#include <utility>
#include <vector>

#include "TMVA/RTensor.hxx"
#include "ROOT/RDF/RDatasetSpec.hxx"
#include "TMVA/RChunkLoader.hxx"
#include "TMVA/RBatchLoader.hxx"
#include "TMVA/Tools.h"
#include "TRandom3.h"
#include "TROOT.h"
0018
0019 namespace TMVA {
0020 namespace Experimental {
0021 namespace Internal {
0022
0023 template <typename... Args>
0024 class RBatchGenerator {
0025 private:
0026 TMVA::RandomGenerator<TRandom3> fRng = TMVA::RandomGenerator<TRandom3>(0);
0027
0028 std::string fFileName;
0029 std::string fTreeName;
0030
0031 std::vector<std::string> fCols;
0032 std::string fFilters;
0033
0034 std::size_t fChunkSize;
0035 std::size_t fMaxChunks;
0036 std::size_t fBatchSize;
0037 std::size_t fMaxBatches;
0038 std::size_t fNumColumns;
0039 std::size_t fNumEntries;
0040 std::size_t fCurrentRow = 0;
0041
0042 float fValidationSplit;
0043
0044 std::unique_ptr<TMVA::Experimental::Internal::RChunkLoader<Args...>> fChunkLoader;
0045 std::unique_ptr<TMVA::Experimental::Internal::RBatchLoader> fBatchLoader;
0046
0047 std::unique_ptr<std::thread> fLoadingThread;
0048
0049 bool fUseWholeFile = true;
0050
0051 std::unique_ptr<TMVA::Experimental::RTensor<float>> fChunkTensor;
0052 std::unique_ptr<TMVA::Experimental::RTensor<float>> fCurrentBatch;
0053
0054 std::vector<std::vector<std::size_t>> fTrainingIdxs;
0055 std::vector<std::vector<std::size_t>> fValidationIdxs;
0056
0057
0058 std::mutex fIsActiveLock;
0059
0060 bool fShuffle = true;
0061 bool fIsActive = false;
0062
0063 std::vector<std::size_t> fVecSizes;
0064 float fVecPadding;
0065
0066 public:
0067 RBatchGenerator(const std::string &treeName, const std::string &fileName, const std::size_t chunkSize,
0068 const std::size_t batchSize, const std::vector<std::string> &cols, const std::string &filters = "",
0069 const std::vector<std::size_t> &vecSizes = {}, const float vecPadding = 0.0,
0070 const float validationSplit = 0.0, const std::size_t maxChunks = 0, const std::size_t numColumns = 0,
0071 bool shuffle = true)
0072 : fTreeName(treeName),
0073 fFileName(fileName),
0074 fChunkSize(chunkSize),
0075 fBatchSize(batchSize),
0076 fCols(cols),
0077 fFilters(filters),
0078 fVecSizes(vecSizes),
0079 fVecPadding(vecPadding),
0080 fValidationSplit(validationSplit),
0081 fMaxChunks(maxChunks),
0082 fNumColumns((numColumns != 0) ? numColumns : cols.size()),
0083 fShuffle(shuffle),
0084 fUseWholeFile(maxChunks == 0)
0085 {
0086
0087 fMaxBatches = ceil((fChunkSize / fBatchSize) * (1 - fValidationSplit));
0088
0089
0090 std::unique_ptr<TFile> f{TFile::Open(fFileName.c_str())};
0091 std::unique_ptr<TTree> t{f->Get<TTree>(fTreeName.c_str())};
0092 fNumEntries = t->GetEntries();
0093
0094 fChunkLoader = std::make_unique<TMVA::Experimental::Internal::RChunkLoader<Args...>>(
0095 fTreeName, fFileName, fChunkSize, fCols, fFilters, fVecSizes, fVecPadding);
0096 fBatchLoader = std::make_unique<TMVA::Experimental::Internal::RBatchLoader>(fBatchSize, fNumColumns, fMaxBatches);
0097
0098
0099 fChunkTensor =
0100 std::make_unique<TMVA::Experimental::RTensor<float>>(std::vector<std::size_t>{fChunkSize, fNumColumns});
0101 }
0102
0103 ~RBatchGenerator() { DeActivate(); }
0104
0105
0106
0107 void DeActivate()
0108 {
0109 {
0110 std::lock_guard<std::mutex> lock(fIsActiveLock);
0111 fIsActive = false;
0112 }
0113
0114 fBatchLoader->DeActivate();
0115
0116 if (fLoadingThread) {
0117 if (fLoadingThread->joinable()) {
0118 fLoadingThread->join();
0119 }
0120 }
0121 }
0122
0123
0124
0125 void Activate()
0126 {
0127 if (fIsActive)
0128 return;
0129
0130 {
0131 std::lock_guard<std::mutex> lock(fIsActiveLock);
0132 fIsActive = true;
0133 }
0134
0135 fCurrentRow = 0;
0136 fBatchLoader->Activate();
0137 fLoadingThread = std::make_unique<std::thread>(&RBatchGenerator::LoadChunks, this);
0138 }
0139
0140
0141
0142
0143 const TMVA::Experimental::RTensor<float> &GetTrainBatch()
0144 {
0145
0146 return fBatchLoader->GetTrainBatch();
0147 }
0148
0149
0150
0151
0152 const TMVA::Experimental::RTensor<float> &GetValidationBatch()
0153 {
0154
0155 return fBatchLoader->GetValidationBatch();
0156 }
0157
0158 bool HasTrainData() { return fBatchLoader->HasTrainData(); }
0159
0160 bool HasValidationData() { return fBatchLoader->HasValidationData(); }
0161
0162 void LoadChunks()
0163 {
0164 ROOT::EnableThreadSafety();
0165
0166 for (std::size_t current_chunk = 0; ((current_chunk < fMaxChunks) || fUseWholeFile) && fCurrentRow < fNumEntries;
0167 current_chunk++) {
0168
0169
0170 {
0171 std::lock_guard<std::mutex> lock(fIsActiveLock);
0172 if (!fIsActive)
0173 return;
0174 }
0175
0176
0177 std::pair<std::size_t, std::size_t> report = fChunkLoader->LoadChunk(*fChunkTensor, fCurrentRow);
0178 fCurrentRow += report.first;
0179
0180 CreateBatches(current_chunk, report.second);
0181
0182
0183 if (report.first < fChunkSize) {
0184 break;
0185 }
0186 }
0187
0188 fBatchLoader->DeActivate();
0189 }
0190
0191
0192
0193
0194 void CreateBatches(std::size_t currentChunk, std::size_t processedEvents)
0195 {
0196
0197
0198 if (fTrainingIdxs.size() > currentChunk) {
0199 fBatchLoader->CreateTrainingBatches(*fChunkTensor, fTrainingIdxs[currentChunk], fShuffle);
0200 } else {
0201
0202 createIdxs(processedEvents);
0203 fBatchLoader->CreateTrainingBatches(*fChunkTensor, fTrainingIdxs[currentChunk], fShuffle);
0204 fBatchLoader->CreateValidationBatches(*fChunkTensor, fValidationIdxs[currentChunk]);
0205 }
0206 }
0207
0208
0209
0210 void createIdxs(std::size_t processedEvents)
0211 {
0212
0213 std::vector<std::size_t> row_order = std::vector<std::size_t>(processedEvents);
0214 std::iota(row_order.begin(), row_order.end(), 0);
0215
0216 if (fShuffle) {
0217 std::shuffle(row_order.begin(), row_order.end(), fRng);
0218 }
0219
0220
0221 std::size_t num_validation = ceil(processedEvents * fValidationSplit);
0222
0223
0224 std::vector<std::size_t> valid_idx({row_order.begin(), row_order.begin() + num_validation});
0225 std::vector<std::size_t> train_idx({row_order.begin() + num_validation, row_order.end()});
0226
0227 fTrainingIdxs.push_back(train_idx);
0228 fValidationIdxs.push_back(valid_idx);
0229 }
0230
0231 void StartValidation() { fBatchLoader->StartValidation(); }
0232 bool IsActive() { return fIsActive; }
0233 };
0234
0235 }
0236 }
0237 }
0238
0239 #endif