Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-01-18 10:11:04

0001 #ifndef TMVA_RBatchLoader
0002 #define TMVA_RBatchLoader
0003 
#include <algorithm>
#include <cstddef>
#include <iostream>
#include <memory>
#include <vector>

// Imports for threading
#include <condition_variable>
#include <mutex>
#include <queue>

#include "TMVA/RTensor.hxx"
#include "TMVA/Tools.h"
#include "TRandom3.h"
0016 
0017 namespace TMVA {
0018 namespace Experimental {
0019 namespace Internal {
0020 
0021 class RBatchLoader {
0022 private:
0023    std::size_t fBatchSize;
0024    std::size_t fNumColumns;
0025    std::size_t fMaxBatches;
0026 
0027    bool fIsActive = false;
0028    TMVA::RandomGenerator<TRandom3> fRng = TMVA::RandomGenerator<TRandom3>(0);
0029 
0030    std::mutex fBatchLock;
0031    std::condition_variable fBatchCondition;
0032 
0033    std::queue<std::unique_ptr<TMVA::Experimental::RTensor<float>>> fTrainingBatchQueue;
0034    std::vector<std::unique_ptr<TMVA::Experimental::RTensor<float>>> fValidationBatches;
0035    std::unique_ptr<TMVA::Experimental::RTensor<float>> fCurrentBatch;
0036 
0037    std::size_t fValidationIdx = 0;
0038 
0039    TMVA::Experimental::RTensor<float> fEmptyTensor = TMVA::Experimental::RTensor<float>({0});
0040 
0041 public:
0042    RBatchLoader(const std::size_t batchSize, const std::size_t numColumns, const std::size_t maxBatches)
0043       : fBatchSize(batchSize), fNumColumns(numColumns), fMaxBatches(maxBatches)
0044    {
0045    }
0046 
0047    ~RBatchLoader() { DeActivate(); }
0048 
0049 public:
0050    /// \brief Return a batch of data as a unique pointer.
0051    /// After the batch has been processed, it should be distroyed.
0052    /// \return Training batch
0053    const TMVA::Experimental::RTensor<float> &GetTrainBatch()
0054    {
0055       std::unique_lock<std::mutex> lock(fBatchLock);
0056       fBatchCondition.wait(lock, [this]() { return !fTrainingBatchQueue.empty() || !fIsActive; });
0057 
0058       if (fTrainingBatchQueue.empty()) {
0059          fCurrentBatch = std::make_unique<TMVA::Experimental::RTensor<float>>(std::vector<std::size_t>({0}));
0060          return *fCurrentBatch;
0061       }
0062 
0063       fCurrentBatch = std::move(fTrainingBatchQueue.front());
0064       fTrainingBatchQueue.pop();
0065 
0066       fBatchCondition.notify_all();
0067 
0068       return *fCurrentBatch;
0069    }
0070 
0071    /// \brief Returns a batch of data for validation
0072    /// The owner of this batch has to be with the RBatchLoader.
0073    /// This is because the same validation batches should be used in all epochs.
0074    /// \return Validation batch
0075    const TMVA::Experimental::RTensor<float> &GetValidationBatch()
0076    {
0077       if (HasValidationData()) {
0078          return *fValidationBatches[fValidationIdx++].get();
0079       }
0080 
0081       return fEmptyTensor;
0082    }
0083 
0084    /// \brief Checks if there are more training batches available
0085    /// \return
0086    bool HasTrainData()
0087    {
0088       {
0089          std::unique_lock<std::mutex> lock(fBatchLock);
0090          if (!fTrainingBatchQueue.empty() || fIsActive)
0091             return true;
0092       }
0093 
0094       return false;
0095    }
0096 
0097    /// \brief Checks if there are more training batches available
0098    /// \return
0099    bool HasValidationData()
0100    {
0101       std::unique_lock<std::mutex> lock(fBatchLock);
0102       return fValidationIdx < fValidationBatches.size();
0103    }
0104 
0105    /// \brief Activate the batchloader so it will accept chunks to batch
0106    void Activate()
0107    {
0108       {
0109          std::lock_guard<std::mutex> lock(fBatchLock);
0110          fIsActive = true;
0111       }
0112       fBatchCondition.notify_all();
0113    }
0114 
0115    /// \brief DeActivate the batchloader. This means that no more batches are created.
0116    /// Batches can still be returned if they are already loaded
0117    void DeActivate()
0118    {
0119       {
0120          std::lock_guard<std::mutex> lock(fBatchLock);
0121          fIsActive = false;
0122       }
0123       fBatchCondition.notify_all();
0124    }
0125 
0126    /// \brief Create a batch filled with the events on the given idx
0127    /// \param chunkTensor
0128    /// \param idx
0129    /// \return
0130    std::unique_ptr<TMVA::Experimental::RTensor<float>>
0131    CreateBatch(const TMVA::Experimental::RTensor<float> &chunkTensor, const std::vector<std::size_t> idx)
0132    {
0133       auto batch =
0134          std::make_unique<TMVA::Experimental::RTensor<float>>(std::vector<std::size_t>({fBatchSize, fNumColumns}));
0135 
0136       for (std::size_t i = 0; i < fBatchSize; i++) {
0137          std::copy(chunkTensor.GetData() + (idx[i] * fNumColumns), chunkTensor.GetData() + ((idx[i] + 1) * fNumColumns),
0138                    batch->GetData() + i * fNumColumns);
0139       }
0140 
0141       return batch;
0142    }
0143 
0144    /// \brief Create training batches from the given chunk of data based on the given event indices
0145    /// Batches are added to the training queue of batches
0146    /// The eventIndices can be shuffled to ensure random order for each epoch
0147    /// \param chunkTensor
0148    /// \param eventIndices
0149    /// \param shuffle
0150    void CreateTrainingBatches(const TMVA::Experimental::RTensor<float> &chunkTensor,
0151                               std::vector<std::size_t> eventIndices, const bool shuffle = true)
0152    {
0153       // Wait until less than a full chunk of batches are in the queue before loading splitting the next chunk into
0154       // batches
0155       {
0156          std::unique_lock<std::mutex> lock(fBatchLock);
0157          fBatchCondition.wait(lock, [this]() { return (fTrainingBatchQueue.size() < fMaxBatches) || !fIsActive; });
0158          if (!fIsActive)
0159             return;
0160       }
0161 
0162       if (shuffle)
0163          std::shuffle(eventIndices.begin(), eventIndices.end(), fRng); // Shuffle the order of idx
0164 
0165       std::vector<std::unique_ptr<TMVA::Experimental::RTensor<float>>> batches;
0166 
0167       // Create tasks of fBatchSize untill all idx are used
0168       for (std::size_t start = 0; (start + fBatchSize) <= eventIndices.size(); start += fBatchSize) {
0169 
0170          // Grab the first fBatchSize indices from the
0171          std::vector<std::size_t> idx;
0172          for (std::size_t i = start; i < (start + fBatchSize); i++) {
0173             idx.push_back(eventIndices[i]);
0174          }
0175 
0176          // Fill a batch
0177          batches.emplace_back(CreateBatch(chunkTensor, idx));
0178       }
0179 
0180       {
0181          std::unique_lock<std::mutex> lock(fBatchLock);
0182          for (std::size_t i = 0; i < batches.size(); i++) {
0183             fTrainingBatchQueue.push(std::move(batches[i]));
0184          }
0185       }
0186 
0187       fBatchCondition.notify_one();
0188    }
0189 
0190    /// \brief Create validation batches from the given chunk based on the given event indices
0191    /// Batches are added to the vector of validation batches
0192    /// \param chunkTensor
0193    /// \param eventIndices
0194    void CreateValidationBatches(const TMVA::Experimental::RTensor<float> &chunkTensor,
0195                                 const std::vector<std::size_t> eventIndices)
0196    {
0197       // Create tasks of fBatchSize untill all idx are used
0198       for (std::size_t start = 0; (start + fBatchSize) <= eventIndices.size(); start += fBatchSize) {
0199 
0200          std::vector<std::size_t> idx;
0201 
0202          for (std::size_t i = start; i < (start + fBatchSize); i++) {
0203             idx.push_back(eventIndices[i]);
0204          }
0205 
0206          {
0207             std::unique_lock<std::mutex> lock(fBatchLock);
0208             fValidationBatches.emplace_back(CreateBatch(chunkTensor, idx));
0209          }
0210       }
0211    }
0212 
0213    /// \brief Reset the validation process
0214    void StartValidation()
0215    {
0216       std::unique_lock<std::mutex> lock(fBatchLock);
0217       fValidationIdx = 0;
0218    }
0219 };
0220 
0221 } // namespace Internal
0222 } // namespace Experimental
0223 } // namespace TMVA
0224 
0225 #endif // TMVA_RBatchLoader