Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2026-04-09 07:46:30

0001 // This file is part of the ACTS project.
0002 //
0003 // Copyright (C) 2016 CERN for the benefit of the ACTS project
0004 //
0005 // This Source Code Form is subject to the terms of the Mozilla Public
0006 // License, v. 2.0. If a copy of the MPL was not distributed with this
0007 // file, You can obtain one at https://mozilla.org/MPL/2.0/.
0008 
0009 #include "ActsExamples/Io/Root/RootSimHitReader.hpp"
0010 
0011 #include "Acts/Definitions/PdgParticle.hpp"
0012 #include "Acts/Utilities/Logger.hpp"
0013 #include "ActsExamples/EventData/SimParticle.hpp"
0014 #include "ActsExamples/Framework/AlgorithmContext.hpp"
0015 
0016 #include <algorithm>
0017 #include <cstdint>
0018 #include <stdexcept>
0019 
0020 #include <TChain.h>
0021 #include <TMathBase.h>
0022 
0023 namespace ActsExamples {
0024 
0025 RootSimHitReader::RootSimHitReader(const RootSimHitReader::Config& config,
0026                                    Acts::Logging::Level level)
0027     : IReader(),
0028       m_cfg(config),
0029       m_logger(Acts::getDefaultLogger(name(), level)) {
0030   m_inputChain = std::make_unique<TChain>(m_cfg.treeName.c_str());
0031 
0032   if (m_cfg.filePath.empty()) {
0033     throw std::invalid_argument("Missing input filename");
0034   }
0035   if (m_cfg.treeName.empty()) {
0036     throw std::invalid_argument("Missing tree name");
0037   }
0038 
0039   m_outputSimHits.initialize(m_cfg.outputSimHits);
0040 
0041   // add file to the input chain
0042   m_inputChain->Add(m_cfg.filePath.c_str());
0043   m_inputChain->LoadTree(0);
0044   ACTS_DEBUG("Adding File " << m_cfg.filePath << " to tree '" << m_cfg.treeName
0045                             << "'.");
0046 
0047   // Set the branches
0048   auto setBranches = [&]<class T>(const auto& keys, T& columns) {
0049     using MappedType = typename std::remove_reference_t<T>::mapped_type;
0050     for (auto key : keys) {
0051       MappedType a{};  // 0 or nullptr
0052       columns.emplace(key, a);
0053     }
0054     for (auto key : keys) {
0055       m_inputChain->SetBranchAddress(key, &columns.at(key));
0056     }
0057   };
0058 
0059   setBranches(m_floatKeys, m_floatColumns);
0060   setBranches(m_uint32Keys, m_uint32Columns);
0061   setBranches(m_uint64Keys, m_uint64Columns);
0062   setBranches(m_int32Keys, m_int32Columns);
0063 
0064   if (m_inputChain->FindBranch("barcode") != nullptr) {
0065     m_hasBarcodeVector = true;
0066     m_barcodeVector.allocate();
0067     m_inputChain->SetBranchAddress("barcode", &m_barcodeVector.get());
0068   } else {
0069     m_hasBarcodeVector = false;
0070     for (const auto* key : m_barcodeComponentKeys) {
0071       if (!m_uint32Columns.contains(key)) {
0072         m_uint32Columns.emplace(key, std::uint32_t{0});
0073       }
0074       m_inputChain->SetBranchAddress(key, &m_uint32Columns.at(key));
0075     }
0076   }
0077 
0078   // Because each hit is stored in a single entry in the root file, we need to
0079   // scan the file first for the positions of the events in the file in order to
0080   // efficiently read the events later on.
0081   // TODO change the file format to store one event per entry
0082 
0083   // Disable all branches and only enable event-id for a first scan of the file
0084   m_inputChain->SetBranchStatus("*", false);
0085   m_inputChain->SetBranchStatus("event_id", true);
0086 
0087   auto nEntries = static_cast<std::size_t>(m_inputChain->GetEntriesFast());
0088   if (nEntries == 0) {
0089     throw std::runtime_error("Did not find any entries in input file");
0090   }
0091 
0092   // Add the first entry
0093   m_inputChain->GetEntry(0);
0094   m_eventMap.push_back({m_uint32Columns.at("event_id"), 0ul, 0ul});
0095 
0096   // Go through all entries and store the position of new events
0097   for (auto i = 1ul; i < nEntries; ++i) {
0098     m_inputChain->GetEntry(i);
0099     const auto evtId = m_uint32Columns.at("event_id");
0100 
0101     if (evtId != std::get<0>(m_eventMap.back())) {
0102       std::get<2>(m_eventMap.back()) = i;
0103       m_eventMap.push_back({evtId, i, i});
0104     }
0105   }
0106 
0107   std::get<2>(m_eventMap.back()) = nEntries;
0108 
0109   // Sort by event id
0110   std::ranges::sort(m_eventMap, {},
0111                     [](const auto& m) { return std::get<0>(m); });
0112 
0113   // Re-Enable all branches
0114   m_inputChain->SetBranchStatus("*", true);
0115   ACTS_DEBUG("Event range: " << availableEvents().first << " - "
0116                              << availableEvents().second);
0117 }
0118 
0119 std::pair<std::size_t, std::size_t> RootSimHitReader::availableEvents() const {
0120   return {std::get<0>(m_eventMap.front()), std::get<0>(m_eventMap.back()) + 1};
0121 }
0122 
0123 ProcessCode RootSimHitReader::read(const AlgorithmContext& context) {
0124   auto it = std::ranges::find_if(m_eventMap, [&](const auto& a) {
0125     return std::get<0>(a) == context.eventNumber;
0126   });
0127 
0128   if (it == m_eventMap.end()) {
0129     // explicitly warn if it happens for the first or last event as that might
0130     // indicate a human error
0131     if ((context.eventNumber == availableEvents().first) &&
0132         (context.eventNumber == availableEvents().second - 1)) {
0133       ACTS_WARNING("Reading empty event: " << context.eventNumber);
0134     } else {
0135       ACTS_DEBUG("Reading empty event: " << context.eventNumber);
0136     }
0137 
0138     m_outputSimHits(context, {});
0139 
0140     // Return success flag
0141     return ProcessCode::SUCCESS;
0142   }
0143 
0144   // lock the mutex
0145   std::lock_guard<std::mutex> lock(m_read_mutex);
0146 
0147   ACTS_DEBUG("Reading event: " << std::get<0>(*it)
0148                                << " stored in entries: " << std::get<1>(*it)
0149                                << " - " << std::get<2>(*it));
0150 
0151   SimHitContainer hits;
0152   for (auto entry = std::get<1>(*it); entry < std::get<2>(*it); ++entry) {
0153     m_inputChain->GetEntry(entry);
0154 
0155     auto eventId = m_uint32Columns.at("event_id");
0156     if (eventId != context.eventNumber) {
0157       break;
0158     }
0159 
0160     const Acts::GeometryIdentifier geoid{m_uint64Columns.at("geometry_id")};
0161     SimBarcode pid = SimBarcode::Invalid();
0162     if (m_hasBarcodeVector && m_barcodeVector.hasValue()) {
0163       pid = SimBarcode().withData(*m_barcodeVector);
0164     } else {
0165       pid = SimBarcode()
0166                 .withVertexPrimary(static_cast<SimBarcode::PrimaryVertexId>(
0167                     m_uint32Columns.at("barcode_vertex_primary")))
0168                 .withVertexSecondary(static_cast<SimBarcode::SecondaryVertexId>(
0169                     m_uint32Columns.at("barcode_vertex_secondary")))
0170                 .withParticle(static_cast<SimBarcode::ParticleId>(
0171                     m_uint32Columns.at("barcode_particle")))
0172                 .withGeneration(static_cast<SimBarcode::GenerationId>(
0173                     m_uint32Columns.at("barcode_generation")))
0174                 .withSubParticle(static_cast<SimBarcode::SubParticleId>(
0175                     m_uint32Columns.at("barcode_sub_particle")));
0176     }
0177     const auto index = m_int32Columns.at("index");
0178 
0179     const Acts::Vector4 pos4 = {
0180         m_floatColumns.at("tx") * Acts::UnitConstants::mm,
0181         m_floatColumns.at("ty") * Acts::UnitConstants::mm,
0182         m_floatColumns.at("tz") * Acts::UnitConstants::mm,
0183         m_floatColumns.at("tt") * Acts::UnitConstants::mm,
0184     };
0185 
0186     const Acts::Vector4 before4 = {
0187         m_floatColumns.at("tpx") * Acts::UnitConstants::GeV,
0188         m_floatColumns.at("tpy") * Acts::UnitConstants::GeV,
0189         m_floatColumns.at("tpz") * Acts::UnitConstants::GeV,
0190         m_floatColumns.at("te") * Acts::UnitConstants::GeV,
0191     };
0192 
0193     const Acts::Vector4 delta = {
0194         m_floatColumns.at("deltapx") * Acts::UnitConstants::GeV,
0195         m_floatColumns.at("deltapy") * Acts::UnitConstants::GeV,
0196         m_floatColumns.at("deltapz") * Acts::UnitConstants::GeV,
0197         m_floatColumns.at("deltae") * Acts::UnitConstants::GeV,
0198     };
0199 
0200     SimHit hit(geoid, pid, pos4, before4, before4 + delta, index);
0201 
0202     hits.insert(hit);
0203   }
0204 
0205   ACTS_DEBUG("Read " << hits.size() << " hits for event "
0206                      << context.eventNumber);
0207 
0208   m_outputSimHits(context, std::move(hits));
0209 
0210   // Return success flag
0211   return ProcessCode::SUCCESS;
0212 }
0213 
// Destructor, defaulted out of line in this file, where <TChain.h> is
// included. NOTE(review): presumably this keeps the chain member (assigned
// from std::make_unique<TChain> in the constructor) destructible when the
// header only forward-declares TChain; confirm against the header.
RootSimHitReader::~RootSimHitReader() = default;
0215 
0216 }  // namespace ActsExamples