File indexing completed on 2025-08-28 08:12:14
0001
0002
0003
0004
0005
0006
0007
0008
0009 #include "ActsExamples/TrackFindingGnn/TrackFindingAlgorithmGnn.hpp"
0010
0011 #include "Acts/Definitions/Units.hpp"
0012 #include "Acts/Plugins/Gnn/GraphStoreHook.hpp"
0013 #include "Acts/Plugins/Gnn/TruthGraphMetricsHook.hpp"
0014 #include "Acts/Utilities/Helpers.hpp"
0015 #include "Acts/Utilities/Zip.hpp"
0016 #include "ActsExamples/EventData/Index.hpp"
0017 #include "ActsExamples/EventData/IndexSourceLink.hpp"
0018 #include "ActsExamples/EventData/ProtoTrack.hpp"
0019 #include "ActsExamples/EventData/SimSpacePoint.hpp"
0020 #include "ActsExamples/Framework/WhiteBoard.hpp"
0021
0022 #include <algorithm>
0023 #include <chrono>
0024 #include <numeric>
0025
0026 #include "createFeatures.hpp"
0027
0028 using namespace ActsExamples;
0029 using namespace Acts::UnitLiterals;
0030
0031 namespace {
0032
0033 struct LoopHook : public Acts::GnnHook {
0034 std::vector<Acts::GnnHook*> hooks;
0035
0036 void operator()(const Acts::PipelineTensors& tensors,
0037 const Acts::ExecutionContext& ctx) const override {
0038 for (auto hook : hooks) {
0039 (*hook)(tensors, ctx);
0040 }
0041 }
0042 };
0043
0044 }
0045
0046 ActsExamples::TrackFindingAlgorithmGnn::TrackFindingAlgorithmGnn(
0047 Config config, Acts::Logging::Level level)
0048 : ActsExamples::IAlgorithm("TrackFindingMLBasedAlgorithm", level),
0049 m_cfg(std::move(config)),
0050 m_pipeline(m_cfg.graphConstructor, m_cfg.edgeClassifiers,
0051 m_cfg.trackBuilder, logger().clone()) {
0052 if (m_cfg.inputSpacePoints.empty()) {
0053 throw std::invalid_argument("Missing spacepoint input collection");
0054 }
0055 if (m_cfg.outputProtoTracks.empty()) {
0056 throw std::invalid_argument("Missing protoTrack output collection");
0057 }
0058
0059 m_inputSpacePoints.initialize(m_cfg.inputSpacePoints);
0060 m_inputClusters.maybeInitialize(m_cfg.inputClusters);
0061 m_outputProtoTracks.initialize(m_cfg.outputProtoTracks);
0062
0063 m_inputTruthGraph.maybeInitialize(m_cfg.inputTruthGraph);
0064 m_outputGraph.maybeInitialize(m_cfg.outputGraph);
0065
0066
0067 m_timing.classifierTimes.resize(
0068 m_cfg.edgeClassifiers.size(),
0069 decltype(m_timing.classifierTimes)::value_type{0.f});
0070
0071
0072 const static std::array clFeatures = {
0073 NodeFeature::eClusterLoc0, NodeFeature::eClusterLoc0,
0074 NodeFeature::eCellCount, NodeFeature::eChargeSum,
0075 NodeFeature::eCluster1R, NodeFeature::eCluster2R};
0076
0077 auto wantClFeatures = std::ranges::any_of(
0078 m_cfg.nodeFeatures,
0079 [&](const auto& f) { return Acts::rangeContainsValue(clFeatures, f); });
0080
0081 if (wantClFeatures && !m_inputClusters.isInitialized()) {
0082 throw std::invalid_argument("Cluster features requested, but not provided");
0083 }
0084
0085 if (m_cfg.nodeFeatures.size() != m_cfg.featureScales.size()) {
0086 throw std::invalid_argument(
0087 "Number of features mismatches number of scale parameters.");
0088 }
0089 }
0090
0091
0092
0093 ActsExamples::ProcessCode ActsExamples::TrackFindingAlgorithmGnn::execute(
0094 const ActsExamples::AlgorithmContext& ctx) const {
0095 using Clock = std::chrono::high_resolution_clock;
0096 using Duration = std::chrono::duration<double, std::milli>;
0097 auto t0 = Clock::now();
0098
0099
0100 LoopHook hook;
0101
0102 std::unique_ptr<Acts::TruthGraphMetricsHook> truthGraphHook;
0103 if (m_inputTruthGraph.isInitialized()) {
0104 truthGraphHook = std::make_unique<Acts::TruthGraphMetricsHook>(
0105 m_inputTruthGraph(ctx).edges, this->logger().clone());
0106 hook.hooks.push_back(&*truthGraphHook);
0107 }
0108
0109 std::unique_ptr<Acts::GraphStoreHook> graphStoreHook;
0110 if (m_outputGraph.isInitialized()) {
0111 graphStoreHook = std::make_unique<Acts::GraphStoreHook>();
0112 hook.hooks.push_back(&*graphStoreHook);
0113 }
0114
0115
0116 const auto& spacepoints = m_inputSpacePoints(ctx);
0117
0118 const ClusterContainer* clusters = nullptr;
0119 if (m_inputClusters.isInitialized()) {
0120 clusters = &m_inputClusters(ctx);
0121 }
0122
0123
0124
0125 const std::size_t numSpacepoints = spacepoints.size();
0126 const std::size_t numFeatures = m_cfg.nodeFeatures.size();
0127 ACTS_DEBUG("Received " << numSpacepoints << " spacepoints");
0128 ACTS_DEBUG("Construct " << numFeatures << " node features");
0129
0130 auto t01 = Clock::now();
0131
0132 std::vector<std::uint64_t> moduleIds;
0133 moduleIds.reserve(spacepoints.size());
0134
0135 for (auto isp = 0ul; isp < numSpacepoints; ++isp) {
0136 const auto& sp = spacepoints[isp];
0137
0138
0139
0140
0141
0142 const auto& sl1 = sp.sourceLinks().at(0).template get<IndexSourceLink>();
0143
0144 if (m_cfg.geometryIdMap != nullptr) {
0145 moduleIds.push_back(m_cfg.geometryIdMap->right.at(sl1.geometryId()));
0146 } else {
0147 moduleIds.push_back(sl1.geometryId().value());
0148 }
0149 }
0150
0151 auto t02 = Clock::now();
0152
0153
0154 std::vector<int> idxs(numSpacepoints);
0155 std::iota(idxs.begin(), idxs.end(), 0);
0156 std::ranges::sort(idxs, {}, [&](auto i) { return moduleIds[i]; });
0157
0158 std::ranges::sort(moduleIds);
0159
0160 SimSpacePointContainer sortedSpacepoints;
0161 sortedSpacepoints.reserve(spacepoints.size());
0162 std::ranges::transform(idxs, std::back_inserter(sortedSpacepoints),
0163 [&](auto i) { return spacepoints[i]; });
0164
0165 auto t03 = Clock::now();
0166
0167 auto features = createFeatures(sortedSpacepoints, clusters,
0168 m_cfg.nodeFeatures, m_cfg.featureScales);
0169
0170 auto t1 = Clock::now();
0171
0172 auto ms = [](auto a, auto b) {
0173 return std::chrono::duration<double, std::milli>(b - a).count();
0174 };
0175 ACTS_DEBUG("Setup time: " << ms(t0, t01));
0176 ACTS_DEBUG("ModuleId mapping & copy: " << ms(t01, t02));
0177 ACTS_DEBUG("Spacepoint sort: " << ms(t02, t03));
0178 ACTS_DEBUG("Feature creation: " << ms(t03, t1));
0179
0180
0181 Acts::GnnTiming timing;
0182 #ifdef ACTS_GNN_CPUONLY
0183 Acts::Device device = {Acts::Device::Type::eCPU, 0};
0184 #else
0185 Acts::Device device = {Acts::Device::Type::eCUDA, 0};
0186 #endif
0187 auto trackCandidates =
0188 m_pipeline.run(features, moduleIds, idxs, device, hook, &timing);
0189
0190 auto t2 = Clock::now();
0191
0192 ACTS_DEBUG("Done with pipeline, received " << trackCandidates.size()
0193 << " candidates");
0194
0195
0196 std::vector<ProtoTrack> protoTracks;
0197 protoTracks.reserve(trackCandidates.size());
0198
0199 int nShortTracks = 0;
0200
0201
0202
0203 for (auto& candidate : trackCandidates) {
0204 ProtoTrack onetrack;
0205 onetrack.reserve(candidate.size());
0206
0207 for (auto i : candidate) {
0208 for (const auto& sl : spacepoints.at(i).sourceLinks()) {
0209 onetrack.push_back(sl.template get<IndexSourceLink>().index());
0210 }
0211 }
0212
0213 if (onetrack.size() < m_cfg.minMeasurementsPerTrack) {
0214 nShortTracks++;
0215 continue;
0216 }
0217
0218 protoTracks.push_back(std::move(onetrack));
0219 }
0220
0221 ACTS_DEBUG("Removed " << nShortTracks << " with less then "
0222 << m_cfg.minMeasurementsPerTrack << " hits");
0223 ACTS_DEBUG("Created " << protoTracks.size() << " proto tracks");
0224
0225 m_outputProtoTracks(ctx, std::move(protoTracks));
0226
0227 if (m_outputGraph.isInitialized()) {
0228 auto graph = graphStoreHook->storedGraph();
0229 std::transform(graph.first.begin(), graph.first.end(), graph.first.begin(),
0230 [&](const auto& a) -> std::int64_t { return idxs.at(a); });
0231 m_outputGraph(ctx, {graph.first, graph.second});
0232 }
0233
0234 auto t3 = Clock::now();
0235
0236 {
0237 std::lock_guard<std::mutex> lock(m_mutex);
0238
0239 m_timing.preprocessingTime(Duration(t1 - t0).count());
0240 m_timing.graphBuildingTime(timing.graphBuildingTime.count());
0241
0242 assert(timing.classifierTimes.size() == m_timing.classifierTimes.size());
0243 for (auto [aggr, a] :
0244 Acts::zip(m_timing.classifierTimes, timing.classifierTimes)) {
0245 aggr(a.count());
0246 }
0247
0248 m_timing.trackBuildingTime(timing.trackBuildingTime.count());
0249 m_timing.postprocessingTime(Duration(t3 - t2).count());
0250 m_timing.fullTime(Duration(t3 - t0).count());
0251 }
0252
0253 return ActsExamples::ProcessCode::SUCCESS;
0254 }
0255
0256 ActsExamples::ProcessCode TrackFindingAlgorithmGnn::finalize() {
0257 namespace ba = boost::accumulators;
0258
0259 auto print = [](const auto& t) {
0260 std::stringstream ss;
0261 ss << ba::mean(t) << " +- " << std::sqrt(ba::variance(t)) << " ";
0262 ss << "[" << ba::min(t) << ", " << ba::max(t) << "]";
0263 return ss.str();
0264 };
0265
0266 ACTS_INFO("GNN timing info");
0267 ACTS_INFO("- preprocessing: " << print(m_timing.preprocessingTime));
0268 ACTS_INFO("- graph building: " << print(m_timing.graphBuildingTime));
0269
0270 for (const auto& t : m_timing.classifierTimes) {
0271 ACTS_INFO("- classifier: " << print(t));
0272 }
0273
0274 ACTS_INFO("- track building: " << print(m_timing.trackBuildingTime));
0275 ACTS_INFO("- postprocessing: " << print(m_timing.postprocessingTime));
0276 ACTS_INFO("- full timing: " << print(m_timing.fullTime));
0277
0278 return {};
0279 }