Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-09-17 09:14:37

0001 // Author: Dante Niewenhuis, VU Amsterdam 07/2023
0002 // Author: Kristupas Pranckietis, Vilnius University 05/2024
0003 // Author: Nopphakorn Subsa-Ard, King Mongkut's University of Technology Thonburi (KMUTT) (TH) 08/2024
0004 // Author: Vincenzo Eduardo Padulano, CERN 10/2024
0005 
0006 /*************************************************************************
0007  * Copyright (C) 1995-2024, Rene Brun and Fons Rademakers.               *
0008  * All rights reserved.                                                  *
0009  *                                                                       *
0010  * For the licensing terms see $ROOTSYS/LICENSE.                         *
0011  * For the list of contributors see $ROOTSYS/README/CREDITS.             *
0012  *************************************************************************/
0013 
0014 #ifndef TMVA_RBATCHGENERATOR
0015 #define TMVA_RBATCHGENERATOR
0016 
#include "TMVA/RTensor.hxx"
#include "ROOT/RDF/RDatasetSpec.hxx"
#include "TMVA/BatchGenerator/RChunkLoader.hxx"
#include "TMVA/BatchGenerator/RBatchLoader.hxx"
#include "TROOT.h"

#include <algorithm> // std::shuffle
#include <cmath>
#include <memory>
#include <mutex>
#include <numeric> // std::iota
#include <random>
#include <thread>
#include <utility> // std::pair
#include <variant>
#include <vector>
0030 
0031 namespace TMVA {
0032 namespace Experimental {
0033 namespace Internal {
0034 
0035 template <typename... Args>
0036 class RBatchGenerator {
0037 private:
0038    std::mt19937 fRng;
0039    std::mt19937 fFixedRng;
0040    std::random_device::result_type fFixedSeed;
0041 
0042    std::size_t fChunkSize;
0043    std::size_t fMaxChunks;
0044    std::size_t fBatchSize;
0045    std::size_t fNumEntries;
0046 
0047    float fValidationSplit;
0048 
0049    std::variant<std::shared_ptr<RChunkLoader<Args...>>, std::shared_ptr<RChunkLoaderFilters<Args...>>> fChunkLoader;
0050 
0051    std::unique_ptr<RBatchLoader> fBatchLoader;
0052 
0053    std::unique_ptr<std::thread> fLoadingThread;
0054 
0055    std::unique_ptr<TMVA::Experimental::RTensor<float>> fChunkTensor;
0056 
0057    ROOT::RDF::RNode &f_rdf;
0058 
0059    std::mutex fIsActiveMutex;
0060 
0061    bool fDropRemainder;
0062    bool fShuffle;
0063    bool fIsActive{false}; // Whether the loading thread is active
0064    bool fNotFiltered;
0065    bool fUseWholeFile;
0066 
0067 public:
0068    RBatchGenerator(ROOT::RDF::RNode &rdf, const std::size_t chunkSize, const std::size_t batchSize,
0069                    const std::vector<std::string> &cols, const std::size_t numColumns,
0070                    const std::vector<std::size_t> &vecSizes = {}, const float vecPadding = 0.0,
0071                    const float validationSplit = 0.0, const std::size_t maxChunks = 0, bool shuffle = true,
0072                    bool dropRemainder = true)
0073       : fRng(std::random_device{}()),
0074         fFixedSeed(std::uniform_int_distribution<std::random_device::result_type>{}(fRng)),
0075         f_rdf(rdf),
0076         fChunkSize(chunkSize),
0077         fBatchSize(batchSize),
0078         fValidationSplit(validationSplit),
0079         fMaxChunks(maxChunks),
0080         fDropRemainder(dropRemainder),
0081         fShuffle(shuffle),
0082         fNotFiltered(f_rdf.GetFilterNames().empty()),
0083         fUseWholeFile(maxChunks == 0)
0084    {
0085 
0086       // Create tensor to load the chunk into
0087       fChunkTensor =
0088          std::make_unique<TMVA::Experimental::RTensor<float>>(std::vector<std::size_t>{fChunkSize, numColumns});
0089 
0090       if (fNotFiltered) {
0091          fNumEntries = f_rdf.Count().GetValue();
0092 
0093          fChunkLoader = std::make_unique<TMVA::Experimental::Internal::RChunkLoader<Args...>>(
0094             f_rdf, *fChunkTensor, fChunkSize, cols, vecSizes, vecPadding);
0095       } else {
0096          auto report = f_rdf.Report();
0097          fNumEntries = f_rdf.Count().GetValue();
0098          std::size_t numAllEntries = report.begin()->GetAll();
0099 
0100          fChunkLoader = std::make_unique<TMVA::Experimental::Internal::RChunkLoaderFilters<Args...>>(
0101             f_rdf, *fChunkTensor, fChunkSize, cols, fNumEntries, numAllEntries, vecSizes, vecPadding);
0102       }
0103 
0104       std::size_t maxBatches = ceil((fChunkSize / fBatchSize) * (1 - fValidationSplit));
0105 
0106       // limits the number of batches that can be contained in the batchqueue based on the chunksize
0107       fBatchLoader = std::make_unique<TMVA::Experimental::Internal::RBatchLoader>(*fChunkTensor, fBatchSize, numColumns,
0108                                                                                   maxBatches);
0109    }
0110 
0111    ~RBatchGenerator() { DeActivate(); }
0112 
0113    /// \brief De-activate the loading process by deactivating the batchgenerator
0114    /// and joining the loading thread
0115    void DeActivate()
0116    {
0117       {
0118          std::lock_guard<std::mutex> lock(fIsActiveMutex);
0119          fIsActive = false;
0120       }
0121 
0122       fBatchLoader->DeActivate();
0123 
0124       if (fLoadingThread) {
0125          if (fLoadingThread->joinable()) {
0126             fLoadingThread->join();
0127          }
0128       }
0129    }
0130 
0131    /// \brief Activate the loading process by starting the batchloader, and
0132    /// spawning the loading thread.
0133    void Activate()
0134    {
0135       if (fIsActive)
0136          return;
0137 
0138       {
0139          std::lock_guard<std::mutex> lock(fIsActiveMutex);
0140          fIsActive = true;
0141       }
0142 
0143       fFixedRng.seed(fFixedSeed);
0144       fBatchLoader->Activate();
0145       // fLoadingThread = std::make_unique<std::thread>(&RBatchGenerator::LoadChunks, this);
0146       if (fNotFiltered) {
0147          fLoadingThread = std::make_unique<std::thread>(&RBatchGenerator::LoadChunksNoFilters, this);
0148       } else {
0149          fLoadingThread = std::make_unique<std::thread>(&RBatchGenerator::LoadChunksFilters, this);
0150       }
0151    }
0152 
0153    /// \brief Returns the next batch of training data if available.
0154    /// Returns empty RTensor otherwise.
0155    /// \return
0156    const TMVA::Experimental::RTensor<float> &GetTrainBatch()
0157    {
0158       // Get next batch if available
0159       return fBatchLoader->GetTrainBatch();
0160    }
0161 
0162    /// \brief Returns the next batch of validation data if available.
0163    /// Returns empty RTensor otherwise.
0164    /// \return
0165    const TMVA::Experimental::RTensor<float> &GetValidationBatch()
0166    {
0167       // Get next batch if available
0168       return fBatchLoader->GetValidationBatch();
0169    }
0170 
0171    std::size_t NumberOfTrainingBatches()
0172    {
0173       std::size_t entriesForTraining =
0174          (fNumEntries / fChunkSize) * (fChunkSize - floor(fChunkSize * fValidationSplit)) + fNumEntries % fChunkSize -
0175          floor(fValidationSplit * (fNumEntries % fChunkSize));
0176 
0177       if (fDropRemainder || !(entriesForTraining % fBatchSize)) {
0178          return entriesForTraining / fBatchSize;
0179       }
0180 
0181       return entriesForTraining / fBatchSize + 1;
0182    }
0183 
0184    /// @brief Return number of training remainder rows
0185    /// @return
0186    std::size_t TrainRemainderRows()
0187    {
0188       std::size_t entriesForTraining =
0189          (fNumEntries / fChunkSize) * (fChunkSize - floor(fChunkSize * fValidationSplit)) + fNumEntries % fChunkSize -
0190          floor(fValidationSplit * (fNumEntries % fChunkSize));
0191 
0192       if (fDropRemainder || !(entriesForTraining % fBatchSize)) {
0193          return 0;
0194       }
0195 
0196       return entriesForTraining % fBatchSize;
0197    }
0198 
0199    /// @brief Calculate number of validation batches and return it
0200    /// @return
0201    std::size_t NumberOfValidationBatches()
0202    {
0203       std::size_t entriesForValidation = (fNumEntries / fChunkSize) * floor(fChunkSize * fValidationSplit) +
0204                                          floor((fNumEntries % fChunkSize) * fValidationSplit);
0205 
0206       if (fDropRemainder || !(entriesForValidation % fBatchSize)) {
0207 
0208          return entriesForValidation / fBatchSize;
0209       }
0210 
0211       return entriesForValidation / fBatchSize + 1;
0212    }
0213 
0214    /// @brief Return number of validation remainder rows
0215    /// @return
0216    std::size_t ValidationRemainderRows()
0217    {
0218       std::size_t entriesForValidation = (fNumEntries / fChunkSize) * floor(fChunkSize * fValidationSplit) +
0219                                          floor((fNumEntries % fChunkSize) * fValidationSplit);
0220 
0221       if (fDropRemainder || !(entriesForValidation % fBatchSize)) {
0222 
0223          return 0;
0224       }
0225 
0226       return entriesForValidation % fBatchSize;
0227    }
0228 
0229    /// @brief Load chunks when no filters are applied on rdataframe
0230    void LoadChunksNoFilters()
0231    {
0232       for (std::size_t currentChunk = 0, currentEntry = 0;
0233            ((currentChunk < fMaxChunks) || fUseWholeFile) && currentEntry < fNumEntries; currentChunk++) {
0234 
0235          // stop the loop when the loading is not active anymore
0236          {
0237             std::lock_guard<std::mutex> lock(fIsActiveMutex);
0238             if (!fIsActive)
0239                return;
0240          }
0241 
0242          // A pair that consists the proccessed, and passed events while loading the chunk
0243          std::size_t report = std::get<std::shared_ptr<RChunkLoader<Args...>>>(fChunkLoader)->LoadChunk(currentEntry);
0244          currentEntry += report;
0245 
0246          CreateBatches(report);
0247       }
0248 
0249       if (!fDropRemainder) {
0250          fBatchLoader->LastBatches();
0251       }
0252 
0253       fBatchLoader->DeActivate();
0254    }
0255 
0256    void LoadChunksFilters()
0257    {
0258       std::size_t currentChunk = 0;
0259       for (std::size_t processedEvents = 0, currentRow = 0;
0260            ((currentChunk < fMaxChunks) || fUseWholeFile) && processedEvents < fNumEntries; currentChunk++) {
0261 
0262          // stop the loop when the loading is not active anymore
0263          {
0264             std::lock_guard<std::mutex> lock(fIsActiveMutex);
0265             if (!fIsActive)
0266                return;
0267          }
0268 
0269          // A pair that consists the proccessed, and passed events while loading the chunk
0270          std::pair<std::size_t, std::size_t> report =
0271             std::get<std::shared_ptr<RChunkLoaderFilters<Args...>>>(fChunkLoader)->LoadChunk(currentRow);
0272 
0273          currentRow += report.first;
0274          processedEvents += report.second;
0275 
0276          CreateBatches(report.second);
0277       }
0278 
0279       if (currentChunk < fMaxChunks || fUseWholeFile) {
0280          CreateBatches(std::get<std::shared_ptr<RChunkLoaderFilters<Args...>>>(fChunkLoader)->LastChunk());
0281       }
0282 
0283       if (!fDropRemainder) {
0284          fBatchLoader->LastBatches();
0285       }
0286 
0287       fBatchLoader->DeActivate();
0288    }
0289 
0290    /// \brief Create batches
0291    /// \param processedEvents
0292    void CreateBatches(std::size_t processedEvents)
0293    {
0294       auto &&[trainingIndices, validationIndices] = createIndices(processedEvents);
0295 
0296       fBatchLoader->CreateTrainingBatches(trainingIndices);
0297       fBatchLoader->CreateValidationBatches(validationIndices);
0298    }
0299 
0300    /// \brief split the events of the current chunk into training and validation events, shuffle if needed
0301    /// \param events
0302    std::pair<std::vector<std::size_t>, std::vector<std::size_t>> createIndices(std::size_t events)
0303    {
0304       // Create a vector of number 1..events
0305       std::vector<std::size_t> row_order = std::vector<std::size_t>(events);
0306       std::iota(row_order.begin(), row_order.end(), 0);
0307 
0308       if (fShuffle) {
0309          // Shuffle the entry indices at every new epoch
0310          std::shuffle(row_order.begin(), row_order.end(), fFixedRng);
0311       }
0312 
0313       // calculate the number of events used for validation
0314       std::size_t num_validation = floor(events * fValidationSplit);
0315 
0316       // Devide the vector into training and validation and return
0317       std::vector<std::size_t> trainingIndices =
0318          std::vector<std::size_t>({row_order.begin(), row_order.end() - num_validation});
0319       std::vector<std::size_t> validationIndices =
0320          std::vector<std::size_t>({row_order.end() - num_validation, row_order.end()});
0321 
0322       if (fShuffle) {
0323          std::shuffle(trainingIndices.begin(), trainingIndices.end(), fRng);
0324       }
0325 
0326       return std::make_pair(trainingIndices, validationIndices);
0327    }
0328 
0329    bool IsActive() { return fIsActive; }
0330 };
0331 
0332 } // namespace Internal
0333 } // namespace Experimental
0334 } // namespace TMVA
0335 
0336 #endif // TMVA_RBATCHGENERATOR