EIC code displayed by LXR

#ifndef TMVA_BATCHGENERATOR
#define TMVA_BATCHGENERATOR

#include <iostream>
#include <vector>
#include <thread>
#include <memory>
#include <cmath>
#include <mutex>
#include <numeric>   // std::iota
#include <algorithm> // std::shuffle
#include <utility>   // std::pair

#include "TMVA/RTensor.hxx"
#include "ROOT/RDF/RDatasetSpec.hxx"
#include "TMVA/RChunkLoader.hxx"
#include "TMVA/RBatchLoader.hxx"
#include "TMVA/Tools.h"
#include "TFile.h"
#include "TTree.h"
#include "TRandom3.h"
#include "TROOT.h"

namespace TMVA {
namespace Experimental {
namespace Internal {

template <typename... Args>
class RBatchGenerator {
private:
   TMVA::RandomGenerator<TRandom3> fRng = TMVA::RandomGenerator<TRandom3>(0);

   std::string fFileName;
   std::string fTreeName;

   std::vector<std::string> fCols;
   std::string fFilters;

   std::size_t fChunkSize;
   std::size_t fMaxChunks;
   std::size_t fBatchSize;
   std::size_t fMaxBatches;
   std::size_t fNumColumns;
   std::size_t fNumEntries;
   std::size_t fCurrentRow = 0;

   float fValidationSplit;

   std::unique_ptr<TMVA::Experimental::Internal::RChunkLoader<Args...>> fChunkLoader;
   std::unique_ptr<TMVA::Experimental::Internal::RBatchLoader> fBatchLoader;

   std::unique_ptr<std::thread> fLoadingThread;

   bool fUseWholeFile = true;

   std::unique_ptr<TMVA::Experimental::RTensor<float>> fChunkTensor;
   std::unique_ptr<TMVA::Experimental::RTensor<float>> fCurrentBatch;

   std::vector<std::vector<std::size_t>> fTrainingIdxs;
   std::vector<std::vector<std::size_t>> fValidationIdxs;

   // Guards the active state of the loading thread
   std::mutex fIsActiveLock;

   bool fShuffle = true;
   bool fIsActive = false;

   std::vector<std::size_t> fVecSizes;
   float fVecPadding;

public:
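   /// \brief Construct the batch generator and set up the chunk loader, the batch
   /// loader and the chunk tensor for the given tree and configuration.
   /// \param treeName Name of the TTree to read
   /// \param fileName Path of the ROOT file that contains the tree
   /// \param chunkSize Number of entries loaded per chunk
   /// \param batchSize Number of entries per batch
   /// \param cols Names of the columns to load
   /// \param filters Filter string applied while loading the chunks
   /// \param vecSizes Number of elements to load for each vector-valued column
   /// \param vecPadding Value used to pad vector-valued columns that are too short
   /// \param validationSplit Fraction of each chunk reserved for validation
   /// \param maxChunks Maximum number of chunks to load; 0 means the whole file is used
   /// \param numColumns Number of columns in the output tensor; defaults to cols.size() when 0
   /// \param shuffle Whether the entries are shuffled before batching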
   RBatchGenerator(const std::string &treeName, const std::string &fileName, const std::size_t chunkSize,
                   const std::size_t batchSize, const std::vector<std::string> &cols, const std::string &filters = "",
                   const std::vector<std::size_t> &vecSizes = {}, const float vecPadding = 0.0,
                   const float validationSplit = 0.0, const std::size_t maxChunks = 0, const std::size_t numColumns = 0,
                   bool shuffle = true)
      : fTreeName(treeName),
        fFileName(fileName),
        fChunkSize(chunkSize),
        fBatchSize(batchSize),
        fCols(cols),
        fFilters(filters),
        fVecSizes(vecSizes),
        fVecPadding(vecPadding),
        fValidationSplit(validationSplit),
        fMaxChunks(maxChunks),
        fNumColumns((numColumns != 0) ? numColumns : cols.size()),
        fShuffle(shuffle),
        fUseWholeFile(maxChunks == 0)
   {
      // Limit the number of batches that the batch queue can hold, based on the chunk size
      fMaxBatches = ceil((fChunkSize / fBatchSize) * (1 - fValidationSplit));

      // Get the total number of entries in the input tree
      std::unique_ptr<TFile> f{TFile::Open(fFileName.c_str())};
      std::unique_ptr<TTree> t{f->Get<TTree>(fTreeName.c_str())};
      fNumEntries = t->GetEntries();

      fChunkLoader = std::make_unique<TMVA::Experimental::Internal::RChunkLoader<Args...>>(
         fTreeName, fFileName, fChunkSize, fCols, fFilters, fVecSizes, fVecPadding);
      fBatchLoader = std::make_unique<TMVA::Experimental::Internal::RBatchLoader>(fBatchSize, fNumColumns, fMaxBatches);

      // Create tensor to load the chunk into
      fChunkTensor =
         std::make_unique<TMVA::Experimental::RTensor<float>>(std::vector<std::size_t>{fChunkSize, fNumColumns});
   }

   ~RBatchGenerator() { DeActivate(); }

   /// \brief De-activate the loading process by deactivating the batch generator
   /// and joining the loading thread
   void DeActivate()
   {
      {
         std::lock_guard<std::mutex> lock(fIsActiveLock);
         fIsActive = false;
      }

      fBatchLoader->DeActivate();

      if (fLoadingThread) {
         if (fLoadingThread->joinable()) {
            fLoadingThread->join();
         }
      }
   }

   /// \brief Activate the loading process by starting the batch loader and
   /// spawning the loading thread.
   void Activate()
   {
      if (fIsActive)
         return;

      {
         std::lock_guard<std::mutex> lock(fIsActiveLock);
         fIsActive = true;
      }

      fCurrentRow = 0;
      fBatchLoader->Activate();
      fLoadingThread = std::make_unique<std::thread>(&RBatchGenerator::LoadChunks, this);
   }

   /// \brief Returns the next batch of training data if available.
   /// Returns an empty RTensor otherwise.
   /// \return Reference to the next training batch
   const TMVA::Experimental::RTensor<float> &GetTrainBatch()
   {
      // Get next batch if available
      return fBatchLoader->GetTrainBatch();
   }

   /// \brief Returns the next batch of validation data if available.
   /// Returns an empty RTensor otherwise.
   /// \return Reference to the next validation batch
   const TMVA::Experimental::RTensor<float> &GetValidationBatch()
   {
      // Get next batch if available
      return fBatchLoader->GetValidationBatch();
   }

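   /// \brief Returns whether a training batch is currently available from the batch loader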
   bool HasTrainData() { return fBatchLoader->HasTrainData(); }

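   /// \brief Returns whether a validation batch is currently available from the batch loader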
   bool HasValidationData() { return fBatchLoader->HasValidationData(); }

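   /// \brief Load the data chunk by chunk, handing each loaded chunk to the batch
   /// loader, until the maximum number of chunks or the end of the dataset is reached.
   /// This runs in the loading thread spawned by Activate().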
   void LoadChunks()
   {
      ROOT::EnableThreadSafety();

      for (std::size_t current_chunk = 0; ((current_chunk < fMaxChunks) || fUseWholeFile) && fCurrentRow < fNumEntries;
           current_chunk++) {

         // Stop the loop when the loading is no longer active
         {
            std::lock_guard<std::mutex> lock(fIsActiveLock);
            if (!fIsActive)
               return;
         }

         // A pair holding the number of events processed and the number of events that passed while loading the chunk
         std::pair<std::size_t, std::size_t> report = fChunkLoader->LoadChunk(*fChunkTensor, fCurrentRow);
         fCurrentRow += report.first;

         CreateBatches(current_chunk, report.second);

         // Stop loading if the number of processed events is smaller than the desired chunk size
         if (report.first < fChunkSize) {
            break;
         }
      }

      fBatchLoader->DeActivate();
   }

   /// \brief Create batches for the current chunk.
   /// \param currentChunk
   /// \param processedEvents
   void CreateBatches(std::size_t currentChunk, std::size_t processedEvents)
   {
      // Check if the indices of this chunk were already split into training and validation
      if (fTrainingIdxs.size() > currentChunk) {
         fBatchLoader->CreateTrainingBatches(*fChunkTensor, fTrainingIdxs[currentChunk], fShuffle);
      } else {
         // First pass over this chunk: split the indices and create both the training and validation batches
         createIdxs(processedEvents);
         fBatchLoader->CreateTrainingBatches(*fChunkTensor, fTrainingIdxs[currentChunk], fShuffle);
         fBatchLoader->CreateValidationBatches(*fChunkTensor, fValidationIdxs[currentChunk]);
      }
   }

   /// \brief Split the events of the current chunk into validation and training events
   /// \param processedEvents
   void createIdxs(std::size_t processedEvents)
   {
      // Create a vector of the numbers 0..processedEvents - 1
      std::vector<std::size_t> row_order = std::vector<std::size_t>(processedEvents);
      std::iota(row_order.begin(), row_order.end(), 0);

      if (fShuffle) {
         std::shuffle(row_order.begin(), row_order.end(), fRng);
      }

      // Calculate the number of events used for validation
      std::size_t num_validation = ceil(processedEvents * fValidationSplit);

      // Divide the vector into training and validation indices
      std::vector<std::size_t> valid_idx({row_order.begin(), row_order.begin() + num_validation});
      std::vector<std::size_t> train_idx({row_order.begin() + num_validation, row_order.end()});

      fTrainingIdxs.push_back(train_idx);
      fValidationIdxs.push_back(valid_idx);
   }

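   /// \brief Start the validation phase by forwarding the call to the batch loader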
   void StartValidation() { fBatchLoader->StartValidation(); }
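   /// \brief Returns whether the batch generator is currently active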
   bool IsActive() { return fIsActive; }
};

} // namespace Internal
} // namespace Experimental
} // namespace TMVA

#endif // TMVA_BATCHGENERATOR
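
Usage note: the sketch below illustrates how this internal class could be driven directly, based only on the interface shown above; it is a minimal example rather than a definitive recipe, and in practice the class is meant to be used through higher-level TMVA batch-generator interfaces. The header path "TMVA/RBatchGenerator.hxx", the file "data.root", the tree name "tree" and the columns "x" and "y" are hypothetical, and the two float template arguments assume two scalar float-valued columns.

#include "TMVA/RBatchGenerator.hxx" // assumed install path of the header above

int main()
{
   using TMVA::Experimental::Internal::RBatchGenerator;

   // Hypothetical input: a TTree "tree" in "data.root" with two float columns "x" and "y"
   RBatchGenerator<float, float> generator(
      /*treeName=*/"tree", /*fileName=*/"data.root", /*chunkSize=*/1000, /*batchSize=*/100,
      /*cols=*/{"x", "y"}, /*filters=*/"", /*vecSizes=*/{}, /*vecPadding=*/0.0,
      /*validationSplit=*/0.2, /*maxChunks=*/0, /*numColumns=*/0, /*shuffle=*/true);

   generator.Activate(); // spawns the chunk-loading thread

   // Consume training batches while the batch loader reports data
   while (generator.HasTrainData()) {
      const TMVA::Experimental::RTensor<float> &batch = generator.GetTrainBatch();
      // ... feed the batch to a model ...
      (void)batch;
   }

   // Switch to the validation batches collected while loading
   generator.StartValidation();
   while (generator.HasValidationData()) {
      const TMVA::Experimental::RTensor<float> &batch = generator.GetValidationBatch();
      // ... evaluate on the batch ...
      (void)batch;
   }

   generator.DeActivate(); // stop loading and join the thread
   return 0;
}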