Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-12-15 10:28:51

0001 // Author: Dante Niewenhuis, VU Amsterdam 07/2023
0002 // Author: Kristupas Pranckietis, Vilnius University 05/2024
0003 // Author: Nopphakorn Subsa-Ard, King Mongkut's University of Technology Thonburi (KMUTT) (TH) 08/2024
0004 // Author: Vincenzo Eduardo Padulano, CERN 10/2024
0005 // Author: Martin Føll, University of Oslo (UiO) & CERN 05/2025
0006 
0007 /*************************************************************************
0008  * Copyright (C) 1995-2025, Rene Brun and Fons Rademakers.               *
0009  * All rights reserved.                                                  *
0010  *                                                                       *
0011  * For the licensing terms see $ROOTSYS/LICENSE.                         *
0012  * For the list of contributors see $ROOTSYS/README/CREDITS.             *
0013  *************************************************************************/
0014 
0015 #ifndef TMVA_RBATCHGENERATOR
0016 #define TMVA_RBATCHGENERATOR
0017 
#include "TMVA/RTensor.hxx"
#include "ROOT/RDF/RDatasetSpec.hxx"
#include "TMVA/BatchGenerator/RChunkLoader.hxx"
#include "TMVA/BatchGenerator/RBatchLoader.hxx"
#include "TROOT.h"

#include <cmath>
#include <cstddef>
#include <memory>
#include <mutex>
#include <numeric>
#include <random>
#include <stdexcept>
#include <string>
#include <thread>
#include <variant>
#include <vector>
0031 
0032 namespace TMVA {
0033 namespace Experimental {
0034 namespace Internal {
0035 
0036 // clang-format off
0037 /**
0038 \class ROOT::TMVA::Experimental::Internal::RBatchGenerator
0039 \ingroup tmva
0040 \brief 
0041 
0042 In this class, the processes of loading chunks (see RChunkLoader) and creating batches from those chunks (see RBatchLoader) are combined, allowing batches from the training and validation sets to be loaded directly from a dataset in an RDataFrame.
0043 */
0044 
0045 template <typename... Args>
0046 class RBatchGenerator {
0047 private:
0048    std::vector<std::string> fCols;
0049    // clang-format on
0050    std::size_t fChunkSize;
0051    std::size_t fMaxChunks;
0052    std::size_t fBatchSize;
0053    std::size_t fBlockSize;
0054    std::size_t fNumColumns;
0055    std::size_t fNumChunkCols;
0056    std::size_t fNumEntries;
0057    std::size_t fSetSeed;
0058    std::size_t fSumVecSizes;
0059 
0060    ROOT::RDF::RResultPtr<std::vector<ULong64_t>> fEntries;
0061    float fValidationSplit;
0062 
0063    std::unique_ptr<RChunkLoader<Args...>> fChunkLoader;
0064    std::unique_ptr<RBatchLoader> fBatchLoader;
0065 
0066    std::unique_ptr<std::thread> fLoadingThread;
0067 
0068    std::size_t fTrainingChunkNum;
0069    std::size_t fValidationChunkNum;
0070 
0071    ROOT::RDF::RNode &f_rdf;
0072 
0073    std::mutex fIsActiveMutex;
0074 
0075    bool fDropRemainder;
0076    bool fShuffle;
0077    bool fIsActive{false}; // Whether the loading thread is active
0078    bool fNotFiltered;
0079    bool fUseWholeFile;
0080 
0081    bool fEpochActive{false};
0082    bool fTrainingEpochActive{false};
0083    bool fValidationEpochActive{false};
0084 
0085    std::size_t fNumTrainingEntries;
0086    std::size_t fNumValidationEntries;
0087 
0088    std::size_t fNumTrainingChunks;
0089    std::size_t fNumValidationChunks;
0090 
0091    std::size_t fLeftoverTrainingBatchSize;
0092    std::size_t fLeftoverValidationBatchSize;
0093 
0094    std::size_t fNumFullTrainingBatches;
0095    std::size_t fNumFullValidationBatches;
0096 
0097    std::size_t fNumLeftoverTrainingBatches;
0098    std::size_t fNumLeftoverValidationBatches;
0099 
0100    std::size_t fNumTrainingBatches;
0101    std::size_t fNumValidationBatches;
0102 
0103    TMVA::Experimental::RTensor<float> fTrainTensor;
0104    TMVA::Experimental::RTensor<float> fTrainChunkTensor;
0105 
0106    TMVA::Experimental::RTensor<float> fValidationTensor;
0107    TMVA::Experimental::RTensor<float> fValidationChunkTensor;
0108 
0109 public:
0110    RBatchGenerator(ROOT::RDF::RNode &rdf, const std::size_t chunkSize, const std::size_t blockSize,
0111                    const std::size_t batchSize, const std::vector<std::string> &cols,
0112                    const std::vector<std::size_t> &vecSizes = {}, const float vecPadding = 0.0,
0113                    const float validationSplit = 0.0, const std::size_t maxChunks = 0, bool shuffle = true,
0114                    bool dropRemainder = true, const std::size_t setSeed = 0)
0115 
0116       : f_rdf(rdf),
0117         fCols(cols),
0118         fChunkSize(chunkSize),
0119         fBlockSize(blockSize),
0120         fBatchSize(batchSize),
0121         fValidationSplit(validationSplit),
0122         fMaxChunks(maxChunks),
0123         fDropRemainder(dropRemainder),
0124         fSetSeed(setSeed),
0125         fShuffle(shuffle),
0126         fNotFiltered(f_rdf.GetFilterNames().empty()),
0127         fUseWholeFile(maxChunks == 0),
0128         fNumColumns(cols.size()),
0129         fTrainTensor({0, 0}),
0130         fTrainChunkTensor({0, 0}),
0131         fValidationTensor({0, 0}),
0132         fValidationChunkTensor({0, 0})
0133    {
0134 
0135       fNumEntries = f_rdf.Count().GetValue();
0136       fEntries = f_rdf.Take<ULong64_t>("rdfentry_");
0137 
0138       fSumVecSizes = std::accumulate(vecSizes.begin(), vecSizes.end(), 0);
0139       fNumChunkCols = fNumColumns + fSumVecSizes - vecSizes.size();
0140       
0141       // add the last element in entries to not go out of range when filling chunks
0142       fEntries->push_back((*fEntries)[fNumEntries - 1] + 1);
0143 
0144       fChunkLoader =
0145          std::make_unique<RChunkLoader<Args...>>(f_rdf, fNumEntries, fEntries, fChunkSize, fBlockSize, fValidationSplit,
0146                                                  fCols, vecSizes, vecPadding, fShuffle, fSetSeed);
0147       fBatchLoader = std::make_unique<RBatchLoader>(fBatchSize, fNumChunkCols);
0148 
0149       // split the dataset into training and validation sets
0150       fChunkLoader->SplitDataset();
0151 
0152       // number of training and validation entries after the split
0153       fNumValidationEntries = static_cast<std::size_t>(fValidationSplit * fNumEntries);
0154       fNumTrainingEntries = fNumEntries - fNumValidationEntries;
0155       
0156       fLeftoverTrainingBatchSize = fNumTrainingEntries % fBatchSize;
0157       fLeftoverValidationBatchSize = fNumValidationEntries % fBatchSize;
0158 
0159       fNumFullTrainingBatches = fNumTrainingEntries / fBatchSize;
0160       fNumFullValidationBatches = fNumValidationEntries / fBatchSize;
0161 
0162       fNumLeftoverTrainingBatches = fLeftoverTrainingBatchSize == 0 ? 0 : 1;
0163       fNumLeftoverValidationBatches = fLeftoverValidationBatchSize == 0 ? 0 : 1;
0164 
0165       if (dropRemainder) {
0166          fNumTrainingBatches = fNumFullTrainingBatches;
0167          fNumValidationBatches = fNumFullValidationBatches;
0168       }
0169 
0170       else {
0171          fNumTrainingBatches = fNumFullTrainingBatches + fNumLeftoverTrainingBatches;
0172          fNumValidationBatches = fNumFullValidationBatches + fNumLeftoverValidationBatches;
0173       }
0174 
0175       // number of training and validation chunks, calculated in RChunkConstructor
0176       fNumTrainingChunks = fChunkLoader->GetNumTrainingChunks();
0177       fNumValidationChunks = fChunkLoader->GetNumValidationChunks();
0178 
0179       fTrainingChunkNum = 0;
0180       fValidationChunkNum = 0;
0181    }
0182 
0183    ~RBatchGenerator() { DeActivate(); }
0184 
0185    void DeActivate()
0186    {
0187       {
0188          std::lock_guard<std::mutex> lock(fIsActiveMutex);
0189          fIsActive = false;
0190       }
0191 
0192       fBatchLoader->DeActivate();
0193 
0194       if (fLoadingThread) {
0195          if (fLoadingThread->joinable()) {
0196             fLoadingThread->join();
0197          }
0198       }
0199    }
0200 
0201    /// \brief Activate the loading process by starting the batchloader, and
0202    /// spawning the loading thread.
0203    void Activate()
0204    {
0205       if (fIsActive)
0206          return;
0207 
0208       {
0209          std::lock_guard<std::mutex> lock(fIsActiveMutex);
0210          fIsActive = true;
0211       }
0212 
0213       fBatchLoader->Activate();
0214       // fLoadingThread = std::make_unique<std::thread>(&RBatchGenerator::LoadChunks, this);
0215    }
0216 
0217    void ActivateEpoch() { fEpochActive = true; }
0218 
0219    void DeActivateEpoch() { fEpochActive = false; }
0220 
0221    void ActivateTrainingEpoch() { fTrainingEpochActive = true; }
0222 
0223    void DeActivateTrainingEpoch() { fTrainingEpochActive = false; }
0224 
0225    void ActivateValidationEpoch() { fValidationEpochActive = true; }
0226 
0227    void DeActivateValidationEpoch() { fValidationEpochActive = false; }
0228 
0229    /// \brief Create training batches by first loading a chunk (see RChunkLoader) and split it into batches (see RBatchLoader)
0230    void CreateTrainBatches()
0231    {
0232 
0233       fChunkLoader->CreateTrainingChunksIntervals();
0234       fTrainingEpochActive = true;
0235       fTrainingChunkNum = 0;
0236       fChunkLoader->LoadTrainingChunk(fTrainChunkTensor, fTrainingChunkNum);
0237       std::size_t lastTrainingBatch = fNumTrainingChunks - fTrainingChunkNum;
0238       fBatchLoader->CreateTrainingBatches(fTrainChunkTensor, lastTrainingBatch, fLeftoverTrainingBatchSize,
0239                                           fDropRemainder);
0240       fTrainingChunkNum++;
0241    }
0242 
0243    /// \brief Creates validation batches by first loading a chunk (see RChunkLoader), and then split it into batches (see RBatchLoader)   
0244    void CreateValidationBatches()
0245    {
0246 
0247       fChunkLoader->CreateValidationChunksIntervals();
0248       fValidationEpochActive = true;
0249       fValidationChunkNum = 0;
0250       fChunkLoader->LoadValidationChunk(fValidationChunkTensor, fValidationChunkNum);
0251       std::size_t lastValidationBatch = fNumValidationChunks - fValidationChunkNum;
0252       fBatchLoader->CreateValidationBatches(fValidationChunkTensor, lastValidationBatch, fLeftoverValidationBatchSize,
0253                                             fDropRemainder);
0254       fValidationChunkNum++;
0255    }
0256 
0257    /// \brief Loads a training batch from the queue
0258    TMVA::Experimental::RTensor<float> GetTrainBatch()
0259    {
0260       auto batchQueue = fBatchLoader->GetNumTrainingBatchQueue();
0261 
0262       // load the next chunk if the queue is empty
0263       if (batchQueue < 1 && fTrainingChunkNum < fNumTrainingChunks) {
0264          fChunkLoader->LoadTrainingChunk(fTrainChunkTensor, fTrainingChunkNum);
0265          std::size_t lastTrainingBatch = fNumTrainingChunks - fTrainingChunkNum;
0266          fBatchLoader->CreateTrainingBatches(fTrainChunkTensor, lastTrainingBatch, fLeftoverTrainingBatchSize,
0267                                              fDropRemainder);
0268          fTrainingChunkNum++;
0269       }
0270 
0271       else {
0272          ROOT::Internal::RDF::ChangeBeginAndEndEntries(f_rdf, 0, fNumEntries);
0273       }
0274 
0275       // Get next batch if available
0276       return fBatchLoader->GetTrainBatch();
0277    }
0278 
0279    /// \brief Loads a validation batch from the queue   
0280    TMVA::Experimental::RTensor<float> GetValidationBatch()
0281    {
0282       auto batchQueue = fBatchLoader->GetNumValidationBatchQueue();
0283 
0284       // load the next chunk if the queue is empty      
0285       if (batchQueue < 1 && fValidationChunkNum < fNumValidationChunks) {
0286          fChunkLoader->LoadValidationChunk(fValidationChunkTensor, fValidationChunkNum);
0287          std::size_t lastValidationBatch = fNumValidationChunks - fValidationChunkNum;
0288          fBatchLoader->CreateValidationBatches(fValidationChunkTensor, lastValidationBatch,
0289                                                fLeftoverValidationBatchSize, fDropRemainder);
0290          fValidationChunkNum++;
0291       }
0292 
0293       else {
0294          ROOT::Internal::RDF::ChangeBeginAndEndEntries(f_rdf, 0, fNumEntries);
0295       }
0296 
0297       // Get next batch if available
0298       return fBatchLoader->GetValidationBatch();
0299    }
0300 
0301    std::size_t NumberOfTrainingBatches() { return fNumTrainingBatches; }
0302    std::size_t NumberOfValidationBatches() { return fNumValidationBatches; }
0303 
0304    std::size_t TrainRemainderRows() { return fLeftoverTrainingBatchSize; }
0305    std::size_t ValidationRemainderRows() { return fLeftoverValidationBatchSize; }
0306 
0307    bool IsActive() { return fIsActive; }
0308    bool TrainingIsActive() { return fTrainingEpochActive; }
0309    /// \brief Returns the next batch of validation data if available.
0310    /// Returns empty RTensor otherwise.
0311 };
0312 
0313 } // namespace Internal
0314 } // namespace Experimental
0315 } // namespace TMVA
0316 
0317 #endif // TMVA_RBATCHGENERATOR