File indexing completed on 2026-06-20 07:36:33
0001
0002
0003
0004
0005
0006
0007
0008
#include "ActsExamples/Io/Arrow/ArrowTrackOutputConverter.hpp"

#include "Acts/Definitions/TrackParametrization.hpp"
#include "ActsExamples/EventData/IndexSourceLink.hpp"
#include "ActsPlugins/Arrow/ArrowUtil.hpp"

#include <cmath>
#include <cstdint>
#include <iterator>
#include <limits>
#include <memory>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

#include <arrow/api.h>
0023
0024 namespace ActsExamples {
0025
0026 namespace {
0027
0028 void check(const arrow::Status& s, const char* what) {
0029 if (!s.ok()) {
0030 throw std::runtime_error(std::string(what) + ": " + s.ToString());
0031 }
0032 }
0033
0034 }
0035
0036 ArrowTrackOutputConverter::ArrowTrackOutputConverter(
0037 const Config& cfg, std::unique_ptr<const Acts::Logger> logger)
0038 : ArrowOutputConverter("ArrowTrackOutputConverter", std::move(logger)),
0039 m_cfg(cfg) {
0040 if (m_cfg.inputTracks.empty()) {
0041 throw std::invalid_argument("Missing tracks input collection");
0042 }
0043 if (m_cfg.outputTable.empty()) {
0044 throw std::invalid_argument("Missing output table name");
0045 }
0046 m_inputTracks.initialize(m_cfg.inputTracks);
0047 m_inputTrackParticleMatching.maybeInitialize(
0048 m_cfg.inputTrackParticleMatching);
0049 m_inputParticles.maybeInitialize(m_cfg.inputParticles);
0050 m_inputMeasurementSimHitsMap.maybeInitialize(
0051 m_cfg.inputMeasurementSimHitsMap);
0052 m_outputTable.initialize(m_cfg.outputTable);
0053 }
0054
0055 std::vector<std::string> ArrowTrackOutputConverter::collections() const {
0056 return {m_cfg.outputTable};
0057 }
0058
0059 ProcessCode ArrowTrackOutputConverter::execute(
0060 const AlgorithmContext& ctx) const {
0061 const ConstTrackContainer& tracks = m_inputTracks(ctx);
0062 const TrackParticleMatching* matching =
0063 m_inputTrackParticleMatching.isInitialized()
0064 ? &m_inputTrackParticleMatching(ctx)
0065 : nullptr;
0066 const SimParticleContainer* particles =
0067 m_inputParticles.isInitialized() ? &m_inputParticles(ctx) : nullptr;
0068 const MeasurementSimHitsMap* measToSimHits =
0069 m_inputMeasurementSimHitsMap.isInitialized()
0070 ? &m_inputMeasurementSimHitsMap(ctx)
0071 : nullptr;
0072
0073 auto* pool = arrow::default_memory_pool();
0074
0075 arrow::ListBuilder d0List(pool, std::make_shared<arrow::FloatBuilder>(pool));
0076 arrow::ListBuilder z0List(pool, std::make_shared<arrow::FloatBuilder>(pool));
0077 arrow::ListBuilder phiList(pool, std::make_shared<arrow::FloatBuilder>(pool));
0078 arrow::ListBuilder thetaList(pool,
0079 std::make_shared<arrow::FloatBuilder>(pool));
0080 arrow::ListBuilder qopList(pool, std::make_shared<arrow::FloatBuilder>(pool));
0081 arrow::ListBuilder tList(pool, std::make_shared<arrow::FloatBuilder>(pool));
0082 arrow::ListBuilder majIdList(pool,
0083 std::make_shared<arrow::UInt64Builder>(pool));
0084 arrow::ListBuilder hitIdsList(
0085 pool, std::make_shared<arrow::ListBuilder>(
0086 pool, std::make_shared<arrow::UInt32Builder>(pool)));
0087 arrow::ListBuilder trackIdList(pool,
0088 std::make_shared<arrow::UInt16Builder>(pool));
0089
0090 check(d0List.Append(), "open d0 list");
0091 check(z0List.Append(), "open z0 list");
0092 check(phiList.Append(), "open phi list");
0093 check(thetaList.Append(), "open theta list");
0094 check(qopList.Append(), "open qop list");
0095 if (m_cfg.writeTime) {
0096 check(tList.Append(), "open t list");
0097 } else {
0098
0099
0100 check(tList.AppendNull(), "append null t list");
0101 }
0102 check(majIdList.Append(), "open majority_particle_id list");
0103 check(hitIdsList.Append(), "open hit_ids outer list");
0104 check(trackIdList.Append(), "open track_id list");
0105
0106 auto* d0V = static_cast<arrow::FloatBuilder*>(d0List.value_builder());
0107 auto* z0V = static_cast<arrow::FloatBuilder*>(z0List.value_builder());
0108 auto* phiV = static_cast<arrow::FloatBuilder*>(phiList.value_builder());
0109 auto* thetaV = static_cast<arrow::FloatBuilder*>(thetaList.value_builder());
0110 auto* qopV = static_cast<arrow::FloatBuilder*>(qopList.value_builder());
0111 auto* tV = static_cast<arrow::FloatBuilder*>(tList.value_builder());
0112 auto* majIdV = static_cast<arrow::UInt64Builder*>(majIdList.value_builder());
0113 auto* hitIdsInner =
0114 static_cast<arrow::ListBuilder*>(hitIdsList.value_builder());
0115 auto* hitIdsV =
0116 static_cast<arrow::UInt32Builder*>(hitIdsInner->value_builder());
0117 auto* trackIdV =
0118 static_cast<arrow::UInt16Builder*>(trackIdList.value_builder());
0119
0120 const auto n = tracks.size();
0121 check(d0V->Reserve(n), "reserve d0");
0122 check(z0V->Reserve(n), "reserve z0");
0123 check(phiV->Reserve(n), "reserve phi");
0124 check(thetaV->Reserve(n), "reserve theta");
0125 check(qopV->Reserve(n), "reserve qop");
0126 if (m_cfg.writeTime) {
0127 check(tV->Reserve(n), "reserve t");
0128 }
0129 check(majIdV->Reserve(n), "reserve majority_particle_id");
0130 check(trackIdV->Reserve(n), "reserve track_id");
0131
0132
0133
0134
0135 constexpr std::uint64_t kUnmatched =
0136 std::numeric_limits<std::uint64_t>::max();
0137
0138 for (const auto& track : tracks) {
0139 if (track.hasReferenceSurface()) {
0140 const auto& p = track.parameters();
0141 d0V->UnsafeAppend(static_cast<float>(p[Acts::eBoundLoc0]));
0142 z0V->UnsafeAppend(static_cast<float>(p[Acts::eBoundLoc1]));
0143 phiV->UnsafeAppend(static_cast<float>(p[Acts::eBoundPhi]));
0144 thetaV->UnsafeAppend(static_cast<float>(p[Acts::eBoundTheta]));
0145 qopV->UnsafeAppend(static_cast<float>(p[Acts::eBoundQOverP]));
0146 if (m_cfg.writeTime) {
0147 tV->UnsafeAppend(static_cast<float>(p[Acts::eBoundTime]));
0148 }
0149 } else {
0150 d0V->UnsafeAppendNull();
0151 z0V->UnsafeAppendNull();
0152 phiV->UnsafeAppendNull();
0153 thetaV->UnsafeAppendNull();
0154 qopV->UnsafeAppendNull();
0155 if (m_cfg.writeTime) {
0156 tV->UnsafeAppendNull();
0157 }
0158 }
0159
0160
0161
0162
0163 std::uint64_t majId = kUnmatched;
0164 if (matching != nullptr && particles != nullptr) {
0165 auto it = matching->find(track.index());
0166 if (it != matching->end() && it->second.particle.has_value()) {
0167 const auto& bc = it->second.particle.value();
0168 auto pIt = particles->find(bc);
0169 if (pIt != particles->end()) {
0170 majId = static_cast<std::uint64_t>(
0171 std::distance(particles->begin(), pIt));
0172 }
0173 }
0174 }
0175 majIdV->UnsafeAppend(majId);
0176
0177 check(hitIdsInner->Append(), "open hit_ids inner list");
0178
0179
0180
0181 if (measToSimHits != nullptr) {
0182
0183
0184
0185
0186 std::vector<std::uint32_t> hitIds;
0187 for (const auto& state : track.trackStatesReversed()) {
0188 if (!state.hasUncalibratedSourceLink()) {
0189 continue;
0190 }
0191 const auto sl =
0192 state.getUncalibratedSourceLink().template get<IndexSourceLink>();
0193 const auto measIdx = static_cast<Index>(sl.index());
0194
0195
0196 auto range = measToSimHits->equal_range(measIdx);
0197 for (auto it = range.first; it != range.second; ++it) {
0198 hitIds.push_back(static_cast<std::uint32_t>(it->second));
0199 }
0200 }
0201 for (auto rit = hitIds.rbegin(); rit != hitIds.rend(); ++rit) {
0202 check(hitIdsV->Append(*rit), "append hit_id");
0203 }
0204 }
0205
0206 trackIdV->UnsafeAppend(static_cast<std::uint16_t>(track.index()));
0207 }
0208
0209 auto finish = [](arrow::ListBuilder& b) {
0210 std::shared_ptr<arrow::Array> out;
0211 check(b.Finish(&out), "finish list");
0212 return out;
0213 };
0214
0215 std::vector<std::shared_ptr<arrow::Array>> arrays = {
0216 finish(d0List), finish(z0List), finish(phiList),
0217 finish(thetaList), finish(qopList), finish(majIdList),
0218 finish(hitIdsList), finish(trackIdList), finish(tList),
0219 };
0220
0221 auto table =
0222 arrow::Table::Make(ActsPlugins::ArrowUtil::trackSchema(), arrays);
0223 m_outputTable(ctx, ActsPlugins::ArrowUtil::ArrowTable{std::move(table)});
0224
0225 return ProcessCode::SUCCESS;
0226 }
0227
0228 }