Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-09-17 09:14:38

0001 // Author: Dante Niewenhuis, VU Amsterdam 07/2023
0002 // Author: Kristupas Pranckietis, Vilnius University 05/2024
0003 // Author: Nopphakorn Subsa-Ard, King Mongkut's University of Technology Thonburi (KMUTT) (TH) 08/2024
0004 // Author: Vincenzo Eduardo Padulano, CERN 10/2024
0005 
0006 /*************************************************************************
0007  * Copyright (C) 1995-2024, Rene Brun and Fons Rademakers.               *
0008  * All rights reserved.                                                  *
0009  *                                                                       *
0010  * For the licensing terms see $ROOTSYS/LICENSE.                         *
0011  * For the list of contributors see $ROOTSYS/README/CREDITS.             *
0012  *************************************************************************/
0013 
0014 #ifndef TMVA_RBATCHLOADER
0015 #define TMVA_RBATCHLOADER
0016 
0017 #include <vector>
0018 #include <memory>
0019 #include <numeric>
0020 
0021 // Imports for threading
0022 #include <queue>
0023 #include <mutex>
0024 #include <condition_variable>
0025 
0026 #include "TMVA/RTensor.hxx"
0027 #include "TMVA/Tools.h"
0028 
0029 namespace TMVA {
0030 namespace Experimental {
0031 namespace Internal {
0032 
0033 class RBatchLoader {
0034 private:
0035    const TMVA::Experimental::RTensor<float> &fChunkTensor;
0036    std::size_t fBatchSize;
0037    std::size_t fNumColumns;
0038    std::size_t fMaxBatches;
0039    std::size_t fTrainingRemainderRow = 0;
0040    std::size_t fValidationRemainderRow = 0;
0041 
0042    bool fIsActive = false;
0043 
0044    std::mutex fBatchLock;
0045    std::condition_variable fBatchCondition;
0046 
0047    std::queue<std::unique_ptr<TMVA::Experimental::RTensor<float>>> fTrainingBatchQueue;
0048    std::queue<std::unique_ptr<TMVA::Experimental::RTensor<float>>> fValidationBatchQueue;
0049    std::unique_ptr<TMVA::Experimental::RTensor<float>> fCurrentBatch;
0050 
0051    std::unique_ptr<TMVA::Experimental::RTensor<float>> fTrainingRemainder;
0052    std::unique_ptr<TMVA::Experimental::RTensor<float>> fValidationRemainder;
0053 
0054 public:
0055    RBatchLoader(const TMVA::Experimental::RTensor<float> &chunkTensor, const std::size_t batchSize,
0056                 const std::size_t numColumns, const std::size_t maxBatches)
0057       : fChunkTensor(chunkTensor), fBatchSize(batchSize), fNumColumns(numColumns), fMaxBatches(maxBatches)
0058    {
0059       // Create remainders tensors
0060       fTrainingRemainder =
0061          std::make_unique<TMVA::Experimental::RTensor<float>>(std::vector<std::size_t>{fBatchSize - 1, fNumColumns});
0062       fValidationRemainder =
0063          std::make_unique<TMVA::Experimental::RTensor<float>>(std::vector<std::size_t>{fBatchSize - 1, fNumColumns});
0064    }
0065 
0066    ~RBatchLoader() { DeActivate(); }
0067 
0068 public:
0069    /// \brief Return a batch of data as a unique pointer.
0070    /// After the batch has been processed, it should be destroyed.
0071    /// \return Training batch
0072    const TMVA::Experimental::RTensor<float> &GetTrainBatch()
0073    {
0074       std::unique_lock<std::mutex> lock(fBatchLock);
0075       fBatchCondition.wait(lock, [this]() { return !fTrainingBatchQueue.empty() || !fIsActive; });
0076 
0077       if (fTrainingBatchQueue.empty()) {
0078          fCurrentBatch = std::make_unique<TMVA::Experimental::RTensor<float>>(std::vector<std::size_t>({0}));
0079          return *fCurrentBatch;
0080       }
0081 
0082       fCurrentBatch = std::move(fTrainingBatchQueue.front());
0083       fTrainingBatchQueue.pop();
0084 
0085       fBatchCondition.notify_all();
0086 
0087       return *fCurrentBatch;
0088    }
0089 
0090    /// \brief Returns a batch of data for validation
0091    /// The owner of this batch has to be with the RBatchLoader.
0092    /// This is because the same validation batches should be used in all epochs.
0093    /// \return Validation batch
0094    const TMVA::Experimental::RTensor<float> &GetValidationBatch()
0095    {
0096       if (fValidationBatchQueue.empty()) {
0097          fCurrentBatch = std::make_unique<TMVA::Experimental::RTensor<float>>(std::vector<std::size_t>({0}));
0098          return *fCurrentBatch;
0099       }
0100 
0101       fCurrentBatch = std::move(fValidationBatchQueue.front());
0102       fValidationBatchQueue.pop();
0103 
0104       return *fCurrentBatch;
0105    }
0106 
0107    /// \brief Activate the batchloader so it will accept chunks to batch
0108    void Activate()
0109    {
0110       fTrainingRemainderRow = 0;
0111       fValidationRemainderRow = 0;
0112 
0113       {
0114          std::lock_guard<std::mutex> lock(fBatchLock);
0115          fIsActive = true;
0116       }
0117       fBatchCondition.notify_all();
0118    }
0119 
0120    /// \brief DeActivate the batchloader. This means that no more batches are created.
0121    /// Batches can still be returned if they are already loaded
0122    void DeActivate()
0123    {
0124       {
0125          std::lock_guard<std::mutex> lock(fBatchLock);
0126          fIsActive = false;
0127       }
0128       fBatchCondition.notify_all();
0129    }
0130 
0131    std::unique_ptr<TMVA::Experimental::RTensor<float>>
0132    CreateBatch(const TMVA::Experimental::RTensor<float> &chunkTensor, std::span<const std::size_t> idxs,
0133                std::size_t batchSize)
0134    {
0135       auto batch =
0136          std::make_unique<TMVA::Experimental::RTensor<float>>(std::vector<std::size_t>({batchSize, fNumColumns}));
0137 
0138       for (std::size_t i = 0; i < batchSize; i++) {
0139          std::copy(chunkTensor.GetData() + (idxs[i] * fNumColumns),
0140                    chunkTensor.GetData() + ((idxs[i] + 1) * fNumColumns), batch->GetData() + i * fNumColumns);
0141       }
0142 
0143       return batch;
0144    }
0145 
0146    std::unique_ptr<TMVA::Experimental::RTensor<float>>
0147    CreateFirstBatch(const TMVA::Experimental::RTensor<float> &remainderTensor, std::size_t remainderTensorRow,
0148                     std::span<const std::size_t> eventIndices)
0149    {
0150       auto batch =
0151          std::make_unique<TMVA::Experimental::RTensor<float>>(std::vector<std::size_t>({fBatchSize, fNumColumns}));
0152 
0153       for (size_t i = 0; i < remainderTensorRow; i++) {
0154          std::copy(remainderTensor.GetData() + i * fNumColumns, remainderTensor.GetData() + (i + 1) * fNumColumns,
0155                    batch->GetData() + i * fNumColumns);
0156       }
0157 
0158       for (std::size_t i = 0; i < (fBatchSize - remainderTensorRow); i++) {
0159          std::copy(fChunkTensor.GetData() + eventIndices[i] * fNumColumns,
0160                    fChunkTensor.GetData() + (eventIndices[i] + 1) * fNumColumns,
0161                    batch->GetData() + (i + remainderTensorRow) * fNumColumns);
0162       }
0163 
0164       return batch;
0165    }
0166 
0167    /// @brief save to remaining data when the whole chunk has to be saved
0168    /// @param chunkTensor
0169    /// @param remainderTensor
0170    /// @param remainderTensorRow
0171    /// @param eventIndices
0172    void SaveRemainingData(TMVA::Experimental::RTensor<float> &remainderTensor, const std::size_t remainderTensorRow,
0173                           const std::vector<std::size_t> eventIndices, const std::size_t start = 0)
0174    {
0175       for (std::size_t i = start; i < eventIndices.size(); i++) {
0176          std::copy(fChunkTensor.GetData() + eventIndices[i] * fNumColumns,
0177                    fChunkTensor.GetData() + (eventIndices[i] + 1) * fNumColumns,
0178                    remainderTensor.GetData() + (i - start + remainderTensorRow) * fNumColumns);
0179       }
0180    }
0181 
0182    /// \brief Create training batches from the given chunk of data based on the given event indices
0183    /// Batches are added to the training queue of batches
0184    /// \param chunkTensor
0185    /// \param eventIndices
0186    void CreateTrainingBatches(const std::vector<std::size_t> &eventIndices)
0187    {
0188       // Wait until less than a full chunk of batches are in the queue before splitting the next chunk into
0189       // batches
0190       {
0191          std::unique_lock<std::mutex> lock(fBatchLock);
0192          fBatchCondition.wait(lock, [this]() { return (fTrainingBatchQueue.size() < fMaxBatches) || !fIsActive; });
0193          if (!fIsActive)
0194             return;
0195       }
0196 
0197       std::vector<std::unique_ptr<TMVA::Experimental::RTensor<float>>> batches;
0198 
0199       if (eventIndices.size() + fTrainingRemainderRow >= fBatchSize) {
0200          batches.emplace_back(CreateFirstBatch(*fTrainingRemainder, fTrainingRemainderRow, eventIndices));
0201       } else {
0202          SaveRemainingData(*fTrainingRemainder, fTrainingRemainderRow, eventIndices);
0203          fTrainingRemainderRow += eventIndices.size();
0204          return;
0205       }
0206 
0207       // Create tasks of fBatchSize until all idx are used
0208       std::size_t start = fBatchSize - fTrainingRemainderRow;
0209       for (; (start + fBatchSize) <= eventIndices.size(); start += fBatchSize) {
0210          // Grab the first fBatchSize indices
0211          std::span<const std::size_t> idxs{eventIndices.data() + start, eventIndices.data() + start + fBatchSize};
0212 
0213          // Fill a batch
0214          batches.emplace_back(CreateBatch(fChunkTensor, idxs, fBatchSize));
0215       }
0216 
0217       {
0218          std::unique_lock<std::mutex> lock(fBatchLock);
0219          for (std::size_t i = 0; i < batches.size(); i++) {
0220             fTrainingBatchQueue.push(std::move(batches[i]));
0221          }
0222       }
0223 
0224       fBatchCondition.notify_all();
0225 
0226       fTrainingRemainderRow = eventIndices.size() - start;
0227       SaveRemainingData(*fTrainingRemainder, 0, eventIndices, start);
0228    }
0229 
0230    /// \brief Create validation batches from the given chunk based on the given event indices
0231    /// Batches are added to the vector of validation batches
0232    /// \param chunkTensor
0233    /// \param eventIndices
0234    void CreateValidationBatches(const std::vector<std::size_t> &eventIndices)
0235    {
0236       if (eventIndices.size() + fValidationRemainderRow >= fBatchSize) {
0237          fValidationBatchQueue.push(CreateFirstBatch(*fValidationRemainder, fValidationRemainderRow, eventIndices));
0238       } else {
0239          SaveRemainingData(*fValidationRemainder, fValidationRemainderRow, eventIndices);
0240          fValidationRemainderRow += eventIndices.size();
0241          return;
0242       }
0243 
0244       // Create tasks of fBatchSize untill all idx are used
0245       std::size_t start = fBatchSize - fValidationRemainderRow;
0246       for (; (start + fBatchSize) <= eventIndices.size(); start += fBatchSize) {
0247 
0248          std::vector<std::size_t> idx;
0249 
0250          for (std::size_t i = start; i < (start + fBatchSize); i++) {
0251             idx.push_back(eventIndices[i]);
0252          }
0253 
0254          fValidationBatchQueue.push(CreateBatch(fChunkTensor, idx, fBatchSize));
0255       }
0256 
0257       fValidationRemainderRow = eventIndices.size() - start;
0258       SaveRemainingData(*fValidationRemainder, 0, eventIndices, start);
0259    }
0260 
0261    void LastBatches()
0262    {
0263       {
0264          if (fTrainingRemainderRow) {
0265             std::vector<std::size_t> idx = std::vector<std::size_t>(fTrainingRemainderRow);
0266             std::iota(idx.begin(), idx.end(), 0);
0267 
0268             std::unique_ptr<TMVA::Experimental::RTensor<float>> batch =
0269                CreateBatch(*fTrainingRemainder, idx, fTrainingRemainderRow);
0270 
0271             std::unique_lock<std::mutex> lock(fBatchLock);
0272             fTrainingBatchQueue.push(std::move(batch));
0273          }
0274       }
0275 
0276       if (fValidationRemainderRow) {
0277          std::vector<std::size_t> idx = std::vector<std::size_t>(fValidationRemainderRow);
0278          std::iota(idx.begin(), idx.end(), 0);
0279 
0280          fValidationBatchQueue.push(CreateBatch(*fValidationRemainder, idx, fValidationRemainderRow));
0281       }
0282    }
0283 };
0284 
0285 } // namespace Internal
0286 } // namespace Experimental
0287 } // namespace TMVA
0288 
0289 #endif // TMVA_RBATCHLOADER