Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-12-16 09:23:35

0001 // This file is part of the ACTS project.
0002 //
0003 // Copyright (C) 2016 CERN for the benefit of the ACTS project
0004 //
0005 // This Source Code Form is subject to the terms of the Mozilla Public
0006 // License, v. 2.0. If a copy of the MPL was not distributed with this
0007 // file, You can obtain one at https://mozilla.org/MPL/2.0/.
0008 
#include "ActsExamples/TrackFindingGnn/TrackFindingFromPrototrackAlgorithm.hpp"

#include "Acts/EventData/ProxyAccessor.hpp"
#include "Acts/TrackFinding/TrackStateCreator.hpp"
#include "ActsExamples/EventData/IndexSourceLink.hpp"
#include "ActsExamples/EventData/MeasurementCalibration.hpp"

#include <algorithm>
#include <cassert>
#include <cmath>
#include <iterator>
#include <memory>
#include <mutex>
#include <ranges>
#include <stdexcept>
#include <utility>
#include <vector>

#include <boost/accumulators/accumulators.hpp>
#include <boost/accumulators/statistics.hpp>
0021 
0022 using namespace Acts;
0023 
0024 namespace {
0025 
0026 using namespace ActsExamples;
0027 
0028 struct ProtoTrackSourceLinkAccessor
0029     : GeometryIdMultisetAccessor<IndexSourceLink> {
0030   using BaseIterator = GeometryIdMultisetAccessor<IndexSourceLink>::Iterator;
0031   using Iterator = SourceLinkAdapterIterator<BaseIterator>;
0032 
0033   std::unique_ptr<const Logger> loggerPtr;
0034   Container protoTrackSourceLinks;
0035 
0036   // get the range of elements with requested geoId
0037   std::pair<Iterator, Iterator> range(const Surface& surface) const {
0038     const auto& logger = *loggerPtr;
0039 
0040     if (protoTrackSourceLinks.contains(surface.geometryId())) {
0041       auto [begin, end] =
0042           protoTrackSourceLinks.equal_range(surface.geometryId());
0043       ACTS_VERBOSE("Select " << std::distance(begin, end)
0044                              << " source-links from prototrack on "
0045                              << surface.geometryId());
0046       return {Iterator{begin}, Iterator{end}};
0047     }
0048 
0049     assert(container != nullptr);
0050     auto [begin, end] = container->equal_range(surface.geometryId());
0051     ACTS_VERBOSE("Select " << std::distance(begin, end)
0052                            << " source-links from collection on "
0053                            << surface.geometryId());
0054     return {Iterator{begin}, Iterator{end}};
0055   }
0056 };
0057 
0058 }  // namespace
0059 
0060 namespace ActsExamples {
0061 
0062 TrackFindingFromPrototrackAlgorithm::TrackFindingFromPrototrackAlgorithm(
0063     Config cfg, Logging::Level lvl)
0064     : IAlgorithm(cfg.tag + "CkfFromProtoTracks", lvl), m_cfg(cfg) {
0065   m_inputInitialTrackParameters.initialize(m_cfg.inputInitialTrackParameters);
0066   m_inputMeasurements.initialize(m_cfg.inputMeasurements);
0067   m_inputProtoTracks.initialize(m_cfg.inputProtoTracks);
0068   m_outputTracks.initialize(m_cfg.outputTracks);
0069 }
0070 
0071 ActsExamples::ProcessCode TrackFindingFromPrototrackAlgorithm::execute(
0072     const ActsExamples::AlgorithmContext& ctx) const {
0073   const auto& measurements = m_inputMeasurements(ctx);
0074   const auto& protoTracks = m_inputProtoTracks(ctx);
0075   const auto& initialParameters = m_inputInitialTrackParameters(ctx);
0076 
0077   if (initialParameters.size() != protoTracks.size()) {
0078     ACTS_FATAL("Inconsistent number of parameters and prototracks");
0079     return ProcessCode::ABORT;
0080   }
0081 
0082   // Construct a perigee surface as the target surface
0083   auto pSurface = Surface::makeShared<PerigeeSurface>(Vector3{0., 0., 0.});
0084 
0085   PropagatorPlainOptions pOptions(ctx.geoContext, ctx.magFieldContext);
0086   pOptions.maxSteps = 10000;
0087 
0088   PassThroughCalibrator pcalibrator;
0089   MeasurementCalibratorAdapter calibrator(pcalibrator, measurements);
0090   GainMatrixUpdater kfUpdater;
0091   GainMatrixSmoother kfSmoother;
0092   MeasurementSelector measSel{m_cfg.measurementSelectorCfg};
0093 
0094   // The source link accessor
0095   ProtoTrackSourceLinkAccessor sourceLinkAccessor;
0096   sourceLinkAccessor.loggerPtr = logger().clone("SourceLinkAccessor");
0097   sourceLinkAccessor.container = &measurements.orderedIndices();
0098 
0099   using TrackStateCreatorType =
0100       TrackStateCreator<IndexSourceLinkAccessor::Iterator, TrackContainer>;
0101   TrackStateCreatorType trackStateCreator;
0102   trackStateCreator.sourceLinkAccessor
0103       .template connect<&ProtoTrackSourceLinkAccessor::range>(
0104           &sourceLinkAccessor);
0105   trackStateCreator.calibrator
0106       .connect<&MeasurementCalibratorAdapter::calibrate>(&calibrator);
0107   trackStateCreator.measurementSelector.connect<&MeasurementSelector::select<
0108       typename TrackContainer::TrackStateContainerBackend>>(&measSel);
0109 
0110   CombinatorialKalmanFilterExtensions<TrackContainer> extensions;
0111   extensions.updater.connect<&GainMatrixUpdater::operator()<
0112       typename TrackContainer::TrackStateContainerBackend>>(&kfUpdater);
0113   extensions.createTrackStates
0114       .template connect<&TrackStateCreatorType ::createTrackStates>(
0115           &trackStateCreator);
0116 
0117   // Set the CombinatorialKalmanFilter options
0118   TrackFindingAlgorithm::TrackFinderOptions options(
0119       ctx.geoContext, ctx.magFieldContext, ctx.calibContext, extensions,
0120       pOptions, &(*pSurface));
0121 
0122   // Perform the track finding for all initial parameters
0123   ACTS_DEBUG("Invoke track finding with " << initialParameters.size()
0124                                           << " seeds.");
0125 
0126   auto trackContainer = std::make_shared<VectorTrackContainer>();
0127   auto trackStateContainer = std::make_shared<VectorMultiTrajectory>();
0128 
0129   TrackContainer tracks(trackContainer, trackStateContainer);
0130 
0131   tracks.addColumn<unsigned int>("trackGroup");
0132   ProxyAccessor<unsigned int> seedNumber("trackGroup");
0133 
0134   std::size_t nSeed = 0;
0135   std::size_t nFailed = 0;
0136 
0137   std::vector<std::size_t> nTracksPerSeeds;
0138   nTracksPerSeeds.reserve(initialParameters.size());
0139 
0140   for (auto i = 0ul; i < initialParameters.size(); ++i) {
0141     sourceLinkAccessor.protoTrackSourceLinks.clear();
0142 
0143     // Fill the source links via their indices from the container
0144     for (const auto hitIndex : protoTracks.at(i)) {
0145       if (auto it = measurements.orderedIndices().nth(hitIndex);
0146           it != measurements.orderedIndices().end()) {
0147         sourceLinkAccessor.protoTrackSourceLinks.insert(*it);
0148       } else {
0149         ACTS_FATAL("Proto track " << i << " contains invalid hit index"
0150                                   << hitIndex);
0151         return ProcessCode::ABORT;
0152       }
0153     }
0154 
0155     auto rootBranch = tracks.makeTrack();
0156     auto result = (*m_cfg.findTracks)(initialParameters.at(i), options, tracks,
0157                                       rootBranch);
0158     nSeed++;
0159 
0160     if (!result.ok()) {
0161       nFailed++;
0162       ACTS_WARNING("Track finding failed for proto track " << i << " with error"
0163                                                            << result.error());
0164       continue;
0165     }
0166 
0167     auto& tracksForSeed = result.value();
0168 
0169     nTracksPerSeeds.push_back(tracksForSeed.size());
0170 
0171     for (auto& track : tracksForSeed) {
0172       // Set the seed number, this number decrease by 1 since the seed number
0173       // has already been updated
0174       seedNumber(track) = nSeed - 1;
0175     }
0176   }
0177 
0178   {
0179     std::lock_guard<std::mutex> guard(m_mutex);
0180 
0181     std::copy(nTracksPerSeeds.begin(), nTracksPerSeeds.end(),
0182               std::back_inserter(m_nTracksPerSeeds));
0183   }
0184 
0185   // TODO The computeSharedHits function is still a member function of
0186   // TrackFindingAlgorithm, but could also be a free function. Uncomment this
0187   // once this is done.
0188   // Compute shared hits from all the reconstructed tracks if
0189   // (m_cfg.computeSharedHits) {
0190   //   computeSharedHits(measurements, tracks);
0191   // }
0192 
0193   ACTS_INFO("Event " << ctx.eventNumber << ": " << nFailed << " / " << nSeed
0194                      << " failed (" << ((100.f * nFailed) / nSeed) << "%)");
0195   ACTS_DEBUG("Finalized track finding with " << tracks.size()
0196                                              << " track candidates.");
0197   auto constTrackStateContainer = std::make_shared<ConstVectorMultiTrajectory>(
0198       std::move(*trackStateContainer));
0199 
0200   auto constTrackContainer =
0201       std::make_shared<ConstVectorTrackContainer>(std::move(*trackContainer));
0202 
0203   ConstTrackContainer constTracks{constTrackContainer,
0204                                   constTrackStateContainer};
0205 
0206   m_outputTracks(ctx, std::move(constTracks));
0207   return ActsExamples::ProcessCode::SUCCESS;
0208 }
0209 
0210 ActsExamples::ProcessCode TrackFindingFromPrototrackAlgorithm::finalize() {
0211   assert(std::distance(m_nTracksPerSeeds.begin(), m_nTracksPerSeeds.end()) > 0);
0212 
0213   ACTS_INFO("TrackFindingFromPrototracksAlgorithm statistics:");
0214   namespace ba = boost::accumulators;
0215   using Accumulator = ba::accumulator_set<
0216       float, ba::features<ba::tag::sum, ba::tag::mean, ba::tag::variance>>;
0217 
0218   Accumulator totalAcc;
0219   std::ranges::for_each(m_nTracksPerSeeds,
0220                         [&](auto v) { totalAcc(static_cast<float>(v)); });
0221   ACTS_INFO("- total number tracks: " << ba::sum(totalAcc));
0222   ACTS_INFO("- avg tracks per seed: " << ba::mean(totalAcc) << " +- "
0223                                       << std::sqrt(ba::variance(totalAcc)));
0224 
0225   return {};
0226 }
0227 
0228 }  // namespace ActsExamples