File indexing completed on 2026-06-20 07:36:33
0001
0002
0003
0004
0005
0006
0007
0008
0009 #include "ActsExamples/Io/Arrow/ArrowSimHitOutputConverter.hpp"
0010
0011 #include "Acts/Definitions/Algebra.hpp"
0012 #include "Acts/Definitions/Units.hpp"
0013 #include "ActsPlugins/Arrow/ArrowUtil.hpp"
0014
0015 #include <array>
0016 #include <cmath>
0017 #include <cstdint>
0018 #include <limits>
0019 #include <memory>
0020 #include <stdexcept>
0021
0022 #include <arrow/api.h>
0023
0024 namespace ActsExamples {
0025
0026 namespace {
0027
0028 void check(const arrow::Status& s, const char* what) {
0029 if (!s.ok()) {
0030 throw std::runtime_error(std::string(what) + ": " + s.ToString());
0031 }
0032 }
0033
0034 }
0035
0036 ArrowSimHitOutputConverter::ArrowSimHitOutputConverter(
0037 const Config& cfg, std::unique_ptr<const Acts::Logger> logger)
0038 : ArrowOutputConverter("ArrowSimHitOutputConverter", std::move(logger)),
0039 m_cfg(cfg) {
0040 if (m_cfg.inputSimHits.empty()) {
0041 throw std::invalid_argument("Missing sim hits input collection");
0042 }
0043 if (m_cfg.outputTable.empty()) {
0044 throw std::invalid_argument("Missing output table name");
0045 }
0046 if (!m_cfg.detectorResolver) {
0047 throw std::invalid_argument("detectorResolver must be set");
0048 }
0049
0050
0051
0052 const bool hasClusters = !m_cfg.inputClusters.empty();
0053 const bool hasMap = !m_cfg.inputSimHitMeasurementsMap.empty();
0054 if (hasClusters != hasMap) {
0055 throw std::invalid_argument(
0056 "ArrowSimHitOutputConverter: inputClusters and "
0057 "inputSimHitMeasurementsMap must both be set or both be unset");
0058 }
0059
0060 m_inputSimHits.initialize(m_cfg.inputSimHits);
0061 m_inputParticles.maybeInitialize(m_cfg.inputParticles);
0062 m_inputClusters.maybeInitialize(m_cfg.inputClusters);
0063 m_inputSimHitMeasurementsMap.maybeInitialize(
0064 m_cfg.inputSimHitMeasurementsMap);
0065 m_outputTable.initialize(m_cfg.outputTable);
0066 }
0067
0068 std::function<std::uint8_t(Acts::GeometryIdentifier)>
0069 ArrowSimHitOutputConverter::makeVolumeIdDetectorResolver(
0070 const std::unordered_map<std::uint32_t, std::uint8_t>& volumeToDetector,
0071 std::uint8_t defaultValue) {
0072 constexpr std::size_t kNumVolumes =
0073 static_cast<std::size_t>(Acts::GeometryIdentifier::getMaxVolume()) + 1;
0074 std::array<std::uint8_t, kNumVolumes> detectorArray{};
0075 detectorArray.fill(defaultValue);
0076 for (const auto& [volume, detector] : volumeToDetector) {
0077 if (volume >= kNumVolumes) {
0078 throw std::invalid_argument(
0079 "makeVolumeIdDetectorResolver: volume id " + std::to_string(volume) +
0080 " exceeds maximum " +
0081 std::to_string(Acts::GeometryIdentifier::getMaxVolume()));
0082 }
0083 detectorArray[volume] = detector;
0084 }
0085 return [detectorArray](Acts::GeometryIdentifier gid) -> std::uint8_t {
0086 const auto volume = static_cast<std::size_t>(gid.volume());
0087 return detectorArray[volume];
0088 };
0089 }
0090
0091 std::vector<std::string> ArrowSimHitOutputConverter::collections() const {
0092 return {m_cfg.outputTable};
0093 }
0094
0095 ProcessCode ArrowSimHitOutputConverter::execute(
0096 const AlgorithmContext& ctx) const {
0097 const SimHitContainer& simHits = m_inputSimHits(ctx);
0098 const SimParticleContainer* particles =
0099 m_inputParticles.isInitialized() ? &m_inputParticles(ctx) : nullptr;
0100 const ClusterContainer* clusters =
0101 m_inputClusters.isInitialized() ? &m_inputClusters(ctx) : nullptr;
0102 const SimHitMeasurementsMap* simHitMeasMap =
0103 m_inputSimHitMeasurementsMap.isInitialized()
0104 ? &m_inputSimHitMeasurementsMap(ctx)
0105 : nullptr;
0106
0107 auto* pool = arrow::default_memory_pool();
0108
0109 arrow::ListBuilder xList(pool, std::make_shared<arrow::FloatBuilder>(pool));
0110 arrow::ListBuilder yList(pool, std::make_shared<arrow::FloatBuilder>(pool));
0111 arrow::ListBuilder zList(pool, std::make_shared<arrow::FloatBuilder>(pool));
0112 arrow::ListBuilder txList(pool, std::make_shared<arrow::FloatBuilder>(pool));
0113 arrow::ListBuilder tyList(pool, std::make_shared<arrow::FloatBuilder>(pool));
0114 arrow::ListBuilder tzList(pool, std::make_shared<arrow::FloatBuilder>(pool));
0115 arrow::ListBuilder timeList(pool,
0116 std::make_shared<arrow::FloatBuilder>(pool));
0117 arrow::ListBuilder pidList(pool,
0118 std::make_shared<arrow::UInt64Builder>(pool));
0119 arrow::ListBuilder detList(pool, std::make_shared<arrow::UInt8Builder>(pool));
0120 arrow::ListBuilder volList(pool, std::make_shared<arrow::UInt8Builder>(pool));
0121 arrow::ListBuilder layList(pool,
0122 std::make_shared<arrow::UInt16Builder>(pool));
0123 arrow::ListBuilder surfList(pool,
0124 std::make_shared<arrow::UInt32Builder>(pool));
0125
0126 check(xList.Append(), "open x list");
0127 check(yList.Append(), "open y list");
0128 check(zList.Append(), "open z list");
0129 check(txList.Append(), "open true_x list");
0130 check(tyList.Append(), "open true_y list");
0131 check(tzList.Append(), "open true_z list");
0132 check(timeList.Append(), "open time list");
0133 check(pidList.Append(), "open particle_id list");
0134 check(detList.Append(), "open detector list");
0135 check(volList.Append(), "open volume_id list");
0136 check(layList.Append(), "open layer_id list");
0137 check(surfList.Append(), "open surface_id list");
0138
0139
0140
0141 auto* xV = static_cast<arrow::FloatBuilder*>(xList.value_builder());
0142 auto* yV = static_cast<arrow::FloatBuilder*>(yList.value_builder());
0143 auto* zV = static_cast<arrow::FloatBuilder*>(zList.value_builder());
0144 auto* txV = static_cast<arrow::FloatBuilder*>(txList.value_builder());
0145 auto* tyV = static_cast<arrow::FloatBuilder*>(tyList.value_builder());
0146 auto* tzV = static_cast<arrow::FloatBuilder*>(tzList.value_builder());
0147 auto* timeV = static_cast<arrow::FloatBuilder*>(timeList.value_builder());
0148 auto* pidV = static_cast<arrow::UInt64Builder*>(pidList.value_builder());
0149 auto* detV = static_cast<arrow::UInt8Builder*>(detList.value_builder());
0150 auto* volV = static_cast<arrow::UInt8Builder*>(volList.value_builder());
0151 auto* layV = static_cast<arrow::UInt16Builder*>(layList.value_builder());
0152 auto* surfV = static_cast<arrow::UInt32Builder*>(surfList.value_builder());
0153
0154 const auto n = simHits.size();
0155 check(xV->Reserve(n), "reserve x");
0156 check(yV->Reserve(n), "reserve y");
0157 check(zV->Reserve(n), "reserve z");
0158 check(txV->Reserve(n), "reserve true_x");
0159 check(tyV->Reserve(n), "reserve true_y");
0160 check(tzV->Reserve(n), "reserve true_z");
0161 check(timeV->Reserve(n), "reserve time");
0162 check(pidV->Reserve(n), "reserve particle_id");
0163 check(detV->Reserve(n), "reserve detector");
0164 check(volV->Reserve(n), "reserve volume_id");
0165 check(layV->Reserve(n), "reserve layer_id");
0166 check(surfV->Reserve(n), "reserve surface_id");
0167
0168
0169
0170
0171 constexpr std::uint64_t kUnmatched =
0172 std::numeric_limits<std::uint64_t>::max();
0173 constexpr float kNaN = std::numeric_limits<float>::quiet_NaN();
0174
0175 SimHitIndex hitIdx = 0;
0176 for (const auto& hit : simHits) {
0177 const auto& pos = hit.fourPosition();
0178 const float tx = static_cast<float>(pos.x() / Acts::UnitConstants::mm);
0179 const float ty = static_cast<float>(pos.y() / Acts::UnitConstants::mm);
0180 const float tz = static_cast<float>(pos.z() / Acts::UnitConstants::mm);
0181 const float t = static_cast<float>(pos.w() / Acts::UnitConstants::mm);
0182
0183 txV->UnsafeAppend(tx);
0184 tyV->UnsafeAppend(ty);
0185 tzV->UnsafeAppend(tz);
0186 timeV->UnsafeAppend(t);
0187
0188
0189
0190
0191
0192
0193
0194
0195 float gx = kNaN;
0196 float gy = kNaN;
0197 float gz = kNaN;
0198 if (simHitMeasMap != nullptr && clusters != nullptr) {
0199 auto range = simHitMeasMap->equal_range(hitIdx);
0200 if (range.first != range.second) {
0201 const Index clusterIdx = range.first->second;
0202 const Acts::Vector3& global = (*clusters)[clusterIdx].globalPosition;
0203 gx = static_cast<float>(global.x() / Acts::UnitConstants::mm);
0204 gy = static_cast<float>(global.y() / Acts::UnitConstants::mm);
0205 gz = static_cast<float>(global.z() / Acts::UnitConstants::mm);
0206 }
0207 }
0208 xV->UnsafeAppend(gx);
0209 yV->UnsafeAppend(gy);
0210 zV->UnsafeAppend(gz);
0211
0212 std::uint64_t pid = kUnmatched;
0213 if (particles != nullptr) {
0214 auto pIt = particles->find(hit.particleId());
0215 if (pIt != particles->end()) {
0216 pid =
0217 static_cast<std::uint64_t>(std::distance(particles->begin(), pIt));
0218 }
0219 }
0220 pidV->UnsafeAppend(pid);
0221
0222 const auto gid = hit.geometryId();
0223 volV->UnsafeAppend(static_cast<std::uint8_t>(gid.volume()));
0224 layV->UnsafeAppend(static_cast<std::uint16_t>(gid.layer()));
0225 surfV->UnsafeAppend(static_cast<std::uint32_t>(gid.sensitive()));
0226
0227
0228
0229
0230 detV->UnsafeAppend(m_cfg.detectorResolver(gid));
0231
0232 ++hitIdx;
0233 }
0234
0235 auto finish = [](arrow::ListBuilder& b) {
0236 std::shared_ptr<arrow::Array> out;
0237 check(b.Finish(&out), "finish list");
0238 return out;
0239 };
0240
0241 std::vector<std::shared_ptr<arrow::Array>> arrays = {
0242 finish(xList), finish(yList), finish(zList), finish(txList),
0243 finish(tyList), finish(tzList), finish(timeList), finish(pidList),
0244 finish(detList), finish(volList), finish(layList), finish(surfList),
0245 };
0246
0247 auto table =
0248 arrow::Table::Make(ActsPlugins::ArrowUtil::simHitSchema(), arrays);
0249 m_outputTable(ctx, ActsPlugins::ArrowUtil::ArrowTable{std::move(table)});
0250
0251 return ProcessCode::SUCCESS;
0252 }
0253
0254 }