File indexing completed on 2025-12-16 09:23:35
0001
0002
0003
0004
0005
0006
0007
0008
0009 #include "ActsExamples/TrackFindingGnn/TrackFindingAlgorithmGnn.hpp"
0010
0011 #include "Acts/Definitions/Units.hpp"
0012 #include "Acts/Utilities/Helpers.hpp"
0013 #include "Acts/Utilities/Zip.hpp"
0014 #include "ActsExamples/EventData/Index.hpp"
0015 #include "ActsExamples/EventData/IndexSourceLink.hpp"
0016 #include "ActsExamples/EventData/ProtoTrack.hpp"
0017 #include "ActsExamples/EventData/SimSpacePoint.hpp"
0018 #include "ActsExamples/Framework/WhiteBoard.hpp"
0019 #include "ActsPlugins/Gnn/GraphStoreHook.hpp"
0020 #include "ActsPlugins/Gnn/TruthGraphMetricsHook.hpp"
0021 #include "ActsPlugins/Gnn/detail/NvtxUtils.hpp"
0022
0023 #include <algorithm>
0024 #include <chrono>
0025 #include <numeric>
0026
0027 #include "createFeatures.hpp"
0028
0029 #ifdef ACTS_GNN_WITH_CUDA
0030 #include <cuda_runtime_api.h>
0031 #endif
0032
0033 using namespace Acts;
0034 using namespace ActsPlugins;
0035 using namespace Acts::UnitLiterals;
0036
0037 namespace {
0038
0039 struct LoopHook : public GnnHook {
0040 std::vector<GnnHook*> hooks;
0041
0042 void operator()(const PipelineTensors& tensors,
0043 const ExecutionContext& ctx) const override {
0044 for (auto hook : hooks) {
0045 (*hook)(tensors, ctx);
0046 }
0047 }
0048 };
0049
0050 }
0051
0052 ActsExamples::TrackFindingAlgorithmGnn::TrackFindingAlgorithmGnn(
0053 Config config, Logging::Level level)
0054 : ActsExamples::IAlgorithm("TrackFindingMLBasedAlgorithm", level),
0055 m_cfg(std::move(config)),
0056 m_pipeline(m_cfg.graphConstructor, m_cfg.edgeClassifiers,
0057 m_cfg.trackBuilder, logger().clone()) {
0058 if (m_cfg.inputSpacePoints.empty()) {
0059 throw std::invalid_argument("Missing spacepoint input collection");
0060 }
0061 if (m_cfg.outputProtoTracks.empty()) {
0062 throw std::invalid_argument("Missing protoTrack output collection");
0063 }
0064
0065 m_inputSpacePoints.initialize(m_cfg.inputSpacePoints);
0066 m_inputClusters.maybeInitialize(m_cfg.inputClusters);
0067 m_outputProtoTracks.initialize(m_cfg.outputProtoTracks);
0068
0069 m_inputTruthGraph.maybeInitialize(m_cfg.inputTruthGraph);
0070 m_outputGraph.maybeInitialize(m_cfg.outputGraph);
0071
0072
0073 m_timing.classifierTimes.resize(
0074 m_cfg.edgeClassifiers.size(),
0075 decltype(m_timing.classifierTimes)::value_type{0.f});
0076
0077
0078 const static std::array clFeatures = {
0079 NodeFeature::eClusterLoc0, NodeFeature::eClusterLoc0,
0080 NodeFeature::eCellCount, NodeFeature::eChargeSum,
0081 NodeFeature::eCluster1R, NodeFeature::eCluster2R};
0082
0083 auto wantClFeatures = std::ranges::any_of(
0084 m_cfg.nodeFeatures,
0085 [&](const auto& f) { return rangeContainsValue(clFeatures, f); });
0086
0087 if (wantClFeatures && !m_inputClusters.isInitialized()) {
0088 throw std::invalid_argument("Cluster features requested, but not provided");
0089 }
0090
0091 if (m_cfg.nodeFeatures.size() != m_cfg.featureScales.size()) {
0092 throw std::invalid_argument(
0093 "Number of features mismatches number of scale parameters.");
0094 }
0095
0096
0097 #ifdef ACTS_GNN_WITH_CUDA
0098 cudaGetLastError();
0099 #endif
0100 }
0101
0102
0103
0104 ActsExamples::ProcessCode ActsExamples::TrackFindingAlgorithmGnn::execute(
0105 const ActsExamples::AlgorithmContext& ctx) const {
0106 ACTS_NVTX_START(data_preparation);
0107
0108 using Clock = std::chrono::high_resolution_clock;
0109 using Duration = std::chrono::duration<double, std::milli>;
0110 auto t0 = Clock::now();
0111
0112
0113 LoopHook hook;
0114
0115 std::unique_ptr<TruthGraphMetricsHook> truthGraphHook;
0116 if (m_inputTruthGraph.isInitialized()) {
0117 truthGraphHook = std::make_unique<TruthGraphMetricsHook>(
0118 m_inputTruthGraph(ctx).edges, this->logger().clone());
0119 hook.hooks.push_back(&*truthGraphHook);
0120 }
0121
0122 std::unique_ptr<GraphStoreHook> graphStoreHook;
0123 if (m_outputGraph.isInitialized()) {
0124 graphStoreHook = std::make_unique<GraphStoreHook>();
0125 hook.hooks.push_back(&*graphStoreHook);
0126 }
0127
0128
0129 const auto& spacepoints = m_inputSpacePoints(ctx);
0130
0131 const ClusterContainer* clusters = nullptr;
0132 if (m_inputClusters.isInitialized()) {
0133 clusters = &m_inputClusters(ctx);
0134 }
0135
0136
0137
0138 const std::size_t numSpacepoints = spacepoints.size();
0139 const std::size_t numFeatures = m_cfg.nodeFeatures.size();
0140 ACTS_DEBUG("Received " << numSpacepoints << " spacepoints");
0141 ACTS_DEBUG("Construct " << numFeatures << " node features");
0142
0143 auto t01 = Clock::now();
0144
0145 std::vector<std::uint64_t> moduleIds;
0146 moduleIds.reserve(spacepoints.size());
0147
0148 for (auto isp = 0ul; isp < numSpacepoints; ++isp) {
0149 const auto& sp = spacepoints[isp];
0150
0151
0152
0153
0154
0155 const auto& sl1 = sp.sourceLinks().at(0).template get<IndexSourceLink>();
0156
0157 if (m_cfg.geometryIdMap != nullptr) {
0158 moduleIds.push_back(m_cfg.geometryIdMap->right.at(sl1.geometryId()));
0159 } else {
0160 moduleIds.push_back(sl1.geometryId().value());
0161 }
0162 }
0163
0164 auto t02 = Clock::now();
0165
0166
0167 std::vector<int> idxs(numSpacepoints);
0168 std::iota(idxs.begin(), idxs.end(), 0);
0169 std::ranges::sort(idxs, {}, [&](auto i) { return moduleIds[i]; });
0170
0171 std::ranges::sort(moduleIds);
0172
0173 SimSpacePointContainer sortedSpacepoints;
0174 sortedSpacepoints.reserve(spacepoints.size());
0175 std::ranges::transform(idxs, std::back_inserter(sortedSpacepoints),
0176 [&](auto i) { return spacepoints[i]; });
0177
0178 auto t03 = Clock::now();
0179
0180 auto features = createFeatures(sortedSpacepoints, clusters,
0181 m_cfg.nodeFeatures, m_cfg.featureScales);
0182
0183 auto t1 = Clock::now();
0184
0185 auto ms = [](auto a, auto b) {
0186 return std::chrono::duration<double, std::milli>(b - a).count();
0187 };
0188 ACTS_DEBUG("Setup time: " << ms(t0, t01));
0189 ACTS_DEBUG("ModuleId mapping & copy: " << ms(t01, t02));
0190 ACTS_DEBUG("Spacepoint sort: " << ms(t02, t03));
0191 ACTS_DEBUG("Feature creation: " << ms(t03, t1));
0192
0193
0194 ACTS_NVTX_STOP(data_preparation);
0195 GnnTiming timing;
0196 #ifdef ACTS_GNN_CPUONLY
0197 Device device = {Device::Type::eCPU, 0};
0198 #else
0199 Device device = {Device::Type::eCUDA, 0};
0200 #endif
0201 auto trackCandidates =
0202 m_pipeline.run(features, moduleIds, idxs, device, hook, &timing);
0203 ACTS_NVTX_START(post_processing);
0204
0205 auto t2 = Clock::now();
0206
0207 ACTS_DEBUG("Done with pipeline, received " << trackCandidates.size()
0208 << " candidates");
0209
0210
0211 std::vector<ProtoTrack> protoTracks;
0212 protoTracks.reserve(trackCandidates.size());
0213
0214 int nShortTracks = 0;
0215
0216
0217
0218 for (auto& candidate : trackCandidates) {
0219 ProtoTrack onetrack;
0220 onetrack.reserve(candidate.size());
0221
0222 for (auto i : candidate) {
0223 for (const auto& sl : spacepoints.at(i).sourceLinks()) {
0224 onetrack.push_back(sl.template get<IndexSourceLink>().index());
0225 }
0226 }
0227
0228 if (onetrack.size() < m_cfg.minMeasurementsPerTrack) {
0229 nShortTracks++;
0230 continue;
0231 }
0232
0233 protoTracks.push_back(std::move(onetrack));
0234 }
0235
0236 ACTS_DEBUG("Removed " << nShortTracks << " with less then "
0237 << m_cfg.minMeasurementsPerTrack << " hits");
0238 ACTS_DEBUG("Created " << protoTracks.size() << " proto tracks");
0239
0240 m_outputProtoTracks(ctx, std::move(protoTracks));
0241
0242 if (m_outputGraph.isInitialized()) {
0243 auto graph = graphStoreHook->storedGraph();
0244 std::transform(graph.first.begin(), graph.first.end(), graph.first.begin(),
0245 [&](const auto& a) -> std::int64_t { return idxs.at(a); });
0246 m_outputGraph(ctx, {graph.first, graph.second});
0247 }
0248
0249 auto t3 = Clock::now();
0250
0251 {
0252 std::lock_guard<std::mutex> lock(m_mutex);
0253
0254 m_timing.preprocessingTime(Duration(t1 - t0).count());
0255 m_timing.graphBuildingTime(timing.graphBuildingTime.count());
0256
0257 assert(timing.classifierTimes.size() == m_timing.classifierTimes.size());
0258 for (auto [aggr, a] :
0259 zip(m_timing.classifierTimes, timing.classifierTimes)) {
0260 aggr(a.count());
0261 }
0262
0263 m_timing.trackBuildingTime(timing.trackBuildingTime.count());
0264 m_timing.postprocessingTime(Duration(t3 - t2).count());
0265 m_timing.fullTime(Duration(t3 - t0).count());
0266 }
0267
0268 ACTS_NVTX_STOP(post_processing);
0269 return ActsExamples::ProcessCode::SUCCESS;
0270 }
0271
0272 ActsExamples::ProcessCode ActsExamples::TrackFindingAlgorithmGnn::finalize() {
0273 namespace ba = boost::accumulators;
0274
0275 auto print = [](const auto& t) {
0276 std::stringstream ss;
0277 ss << ba::mean(t) << " +- " << std::sqrt(ba::variance(t)) << " ";
0278 ss << "[" << ba::min(t) << ", " << ba::max(t) << "]";
0279 return ss.str();
0280 };
0281
0282 ACTS_INFO("GNN timing info");
0283 ACTS_INFO("- preprocessing: " << print(m_timing.preprocessingTime));
0284 ACTS_INFO("- graph building: " << print(m_timing.graphBuildingTime));
0285
0286 for (const auto& t : m_timing.classifierTimes) {
0287 ACTS_INFO("- classifier: " << print(t));
0288 }
0289
0290 ACTS_INFO("- track building: " << print(m_timing.trackBuildingTime));
0291 ACTS_INFO("- postprocessing: " << print(m_timing.postprocessingTime));
0292 ACTS_INFO("- full timing: " << print(m_timing.fullTime));
0293
0294 return {};
0295 }