Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2026-06-20 07:36:33

0001 // This file is part of the ACTS project.
0002 //
0003 // Copyright (C) 2016 CERN for the benefit of the ACTS project
0004 //
0005 // This Source Code Form is subject to the terms of the Mozilla Public
0006 // License, v. 2.0. If a copy of the MPL was not distributed with this
0007 // file, You can obtain one at https://mozilla.org/MPL/2.0/.
0008 
0009 #include "ActsExamples/Io/Arrow/ArrowSimHitOutputConverter.hpp"
0010 
0011 #include "Acts/Definitions/Algebra.hpp"
0012 #include "Acts/Definitions/Units.hpp"
0013 #include "ActsPlugins/Arrow/ArrowUtil.hpp"
0014 
0015 #include <array>
0016 #include <cmath>
0017 #include <cstdint>
0018 #include <limits>
0019 #include <memory>
0020 #include <stdexcept>
0021 
0022 #include <arrow/api.h>
0023 
0024 namespace ActsExamples {
0025 
0026 namespace {
0027 
0028 void check(const arrow::Status& s, const char* what) {
0029   if (!s.ok()) {
0030     throw std::runtime_error(std::string(what) + ": " + s.ToString());
0031   }
0032 }
0033 
0034 }  // namespace
0035 
0036 ArrowSimHitOutputConverter::ArrowSimHitOutputConverter(
0037     const Config& cfg, std::unique_ptr<const Acts::Logger> logger)
0038     : ArrowOutputConverter("ArrowSimHitOutputConverter", std::move(logger)),
0039       m_cfg(cfg) {
0040   if (m_cfg.inputSimHits.empty()) {
0041     throw std::invalid_argument("Missing sim hits input collection");
0042   }
0043   if (m_cfg.outputTable.empty()) {
0044     throw std::invalid_argument("Missing output table name");
0045   }
0046   if (!m_cfg.detectorResolver) {
0047     throw std::invalid_argument("detectorResolver must be set");
0048   }
0049 
0050   // The digitized x,y,z columns require both (clusters, map) — partial wiring
0051   // would silently produce stale or NaN positions, so reject it up front.
0052   const bool hasClusters = !m_cfg.inputClusters.empty();
0053   const bool hasMap = !m_cfg.inputSimHitMeasurementsMap.empty();
0054   if (hasClusters != hasMap) {
0055     throw std::invalid_argument(
0056         "ArrowSimHitOutputConverter: inputClusters and "
0057         "inputSimHitMeasurementsMap must both be set or both be unset");
0058   }
0059 
0060   m_inputSimHits.initialize(m_cfg.inputSimHits);
0061   m_inputParticles.maybeInitialize(m_cfg.inputParticles);
0062   m_inputClusters.maybeInitialize(m_cfg.inputClusters);
0063   m_inputSimHitMeasurementsMap.maybeInitialize(
0064       m_cfg.inputSimHitMeasurementsMap);
0065   m_outputTable.initialize(m_cfg.outputTable);
0066 }
0067 
0068 std::function<std::uint8_t(Acts::GeometryIdentifier)>
0069 ArrowSimHitOutputConverter::makeVolumeIdDetectorResolver(
0070     const std::unordered_map<std::uint32_t, std::uint8_t>& volumeToDetector,
0071     std::uint8_t defaultValue) {
0072   constexpr std::size_t kNumVolumes =
0073       static_cast<std::size_t>(Acts::GeometryIdentifier::getMaxVolume()) + 1;
0074   std::array<std::uint8_t, kNumVolumes> detectorArray{};
0075   detectorArray.fill(defaultValue);
0076   for (const auto& [volume, detector] : volumeToDetector) {
0077     if (volume >= kNumVolumes) {
0078       throw std::invalid_argument(
0079           "makeVolumeIdDetectorResolver: volume id " + std::to_string(volume) +
0080           " exceeds maximum " +
0081           std::to_string(Acts::GeometryIdentifier::getMaxVolume()));
0082     }
0083     detectorArray[volume] = detector;
0084   }
0085   return [detectorArray](Acts::GeometryIdentifier gid) -> std::uint8_t {
0086     const auto volume = static_cast<std::size_t>(gid.volume());
0087     return detectorArray[volume];
0088   };
0089 }
0090 
0091 std::vector<std::string> ArrowSimHitOutputConverter::collections() const {
0092   return {m_cfg.outputTable};
0093 }
0094 
0095 ProcessCode ArrowSimHitOutputConverter::execute(
0096     const AlgorithmContext& ctx) const {
0097   const SimHitContainer& simHits = m_inputSimHits(ctx);
0098   const SimParticleContainer* particles =
0099       m_inputParticles.isInitialized() ? &m_inputParticles(ctx) : nullptr;
0100   const ClusterContainer* clusters =
0101       m_inputClusters.isInitialized() ? &m_inputClusters(ctx) : nullptr;
0102   const SimHitMeasurementsMap* simHitMeasMap =
0103       m_inputSimHitMeasurementsMap.isInitialized()
0104           ? &m_inputSimHitMeasurementsMap(ctx)
0105           : nullptr;
0106 
0107   auto* pool = arrow::default_memory_pool();
0108 
0109   arrow::ListBuilder xList(pool, std::make_shared<arrow::FloatBuilder>(pool));
0110   arrow::ListBuilder yList(pool, std::make_shared<arrow::FloatBuilder>(pool));
0111   arrow::ListBuilder zList(pool, std::make_shared<arrow::FloatBuilder>(pool));
0112   arrow::ListBuilder txList(pool, std::make_shared<arrow::FloatBuilder>(pool));
0113   arrow::ListBuilder tyList(pool, std::make_shared<arrow::FloatBuilder>(pool));
0114   arrow::ListBuilder tzList(pool, std::make_shared<arrow::FloatBuilder>(pool));
0115   arrow::ListBuilder timeList(pool,
0116                               std::make_shared<arrow::FloatBuilder>(pool));
0117   arrow::ListBuilder pidList(pool,
0118                              std::make_shared<arrow::UInt64Builder>(pool));
0119   arrow::ListBuilder detList(pool, std::make_shared<arrow::UInt8Builder>(pool));
0120   arrow::ListBuilder volList(pool, std::make_shared<arrow::UInt8Builder>(pool));
0121   arrow::ListBuilder layList(pool,
0122                              std::make_shared<arrow::UInt16Builder>(pool));
0123   arrow::ListBuilder surfList(pool,
0124                               std::make_shared<arrow::UInt32Builder>(pool));
0125 
0126   check(xList.Append(), "open x list");
0127   check(yList.Append(), "open y list");
0128   check(zList.Append(), "open z list");
0129   check(txList.Append(), "open true_x list");
0130   check(tyList.Append(), "open true_y list");
0131   check(tzList.Append(), "open true_z list");
0132   check(timeList.Append(), "open time list");
0133   check(pidList.Append(), "open particle_id list");
0134   check(detList.Append(), "open detector list");
0135   check(volList.Append(), "open volume_id list");
0136   check(layList.Append(), "open layer_id list");
0137   check(surfList.Append(), "open surface_id list");
0138 
0139   // TODO: Keep typed child builder handles when constructing the list builders
0140   // instead of recovering them through value_builder() and static_cast.
0141   auto* xV = static_cast<arrow::FloatBuilder*>(xList.value_builder());
0142   auto* yV = static_cast<arrow::FloatBuilder*>(yList.value_builder());
0143   auto* zV = static_cast<arrow::FloatBuilder*>(zList.value_builder());
0144   auto* txV = static_cast<arrow::FloatBuilder*>(txList.value_builder());
0145   auto* tyV = static_cast<arrow::FloatBuilder*>(tyList.value_builder());
0146   auto* tzV = static_cast<arrow::FloatBuilder*>(tzList.value_builder());
0147   auto* timeV = static_cast<arrow::FloatBuilder*>(timeList.value_builder());
0148   auto* pidV = static_cast<arrow::UInt64Builder*>(pidList.value_builder());
0149   auto* detV = static_cast<arrow::UInt8Builder*>(detList.value_builder());
0150   auto* volV = static_cast<arrow::UInt8Builder*>(volList.value_builder());
0151   auto* layV = static_cast<arrow::UInt16Builder*>(layList.value_builder());
0152   auto* surfV = static_cast<arrow::UInt32Builder*>(surfList.value_builder());
0153 
0154   const auto n = simHits.size();
0155   check(xV->Reserve(n), "reserve x");
0156   check(yV->Reserve(n), "reserve y");
0157   check(zV->Reserve(n), "reserve z");
0158   check(txV->Reserve(n), "reserve true_x");
0159   check(tyV->Reserve(n), "reserve true_y");
0160   check(tzV->Reserve(n), "reserve true_z");
0161   check(timeV->Reserve(n), "reserve time");
0162   check(pidV->Reserve(n), "reserve particle_id");
0163   check(detV->Reserve(n), "reserve detector");
0164   check(volV->Reserve(n), "reserve volume_id");
0165   check(layV->Reserve(n), "reserve layer_id");
0166   check(surfV->Reserve(n), "reserve surface_id");
0167 
0168   // Sentinel matches the convention used by ArrowTrackOutputConverter for
0169   // unmatched rows.
0170   // @TODO: Turn into explicit optionals?
0171   constexpr std::uint64_t kUnmatched =
0172       std::numeric_limits<std::uint64_t>::max();
0173   constexpr float kNaN = std::numeric_limits<float>::quiet_NaN();
0174 
0175   SimHitIndex hitIdx = 0;
0176   for (const auto& hit : simHits) {
0177     const auto& pos = hit.fourPosition();
0178     const float tx = static_cast<float>(pos.x() / Acts::UnitConstants::mm);
0179     const float ty = static_cast<float>(pos.y() / Acts::UnitConstants::mm);
0180     const float tz = static_cast<float>(pos.z() / Acts::UnitConstants::mm);
0181     const float t = static_cast<float>(pos.w() / Acts::UnitConstants::mm);
0182 
0183     txV->UnsafeAppend(tx);
0184     tyV->UnsafeAppend(ty);
0185     tzV->UnsafeAppend(tz);
0186     timeV->UnsafeAppend(t);
0187 
0188     // Digitized global position: reuse the precomputed global position of the
0189     // first matched cluster. Clusters are indexed one-to-one with
0190     // measurements, so the sim-hit → measurement map doubles as a sim-hit →
0191     // cluster map. Multiple measurements per hit would only happen if a hit
0192     // migrated across modules during clustering; we take the first
0193     // deterministically and leave the rest for a future "merged hits"
0194     // extension.
0195     float gx = kNaN;
0196     float gy = kNaN;
0197     float gz = kNaN;
0198     if (simHitMeasMap != nullptr && clusters != nullptr) {
0199       auto range = simHitMeasMap->equal_range(hitIdx);
0200       if (range.first != range.second) {
0201         const Index clusterIdx = range.first->second;
0202         const Acts::Vector3& global = (*clusters)[clusterIdx].globalPosition;
0203         gx = static_cast<float>(global.x() / Acts::UnitConstants::mm);
0204         gy = static_cast<float>(global.y() / Acts::UnitConstants::mm);
0205         gz = static_cast<float>(global.z() / Acts::UnitConstants::mm);
0206       }
0207     }
0208     xV->UnsafeAppend(gx);
0209     yV->UnsafeAppend(gy);
0210     zV->UnsafeAppend(gz);
0211 
0212     std::uint64_t pid = kUnmatched;
0213     if (particles != nullptr) {
0214       auto pIt = particles->find(hit.particleId());
0215       if (pIt != particles->end()) {
0216         pid =
0217             static_cast<std::uint64_t>(std::distance(particles->begin(), pIt));
0218       }
0219     }
0220     pidV->UnsafeAppend(pid);
0221 
0222     const auto gid = hit.geometryId();
0223     volV->UnsafeAppend(static_cast<std::uint8_t>(gid.volume()));
0224     layV->UnsafeAppend(static_cast<std::uint16_t>(gid.layer()));
0225     surfV->UnsafeAppend(static_cast<std::uint32_t>(gid.sensitive()));
0226     // The default `detectorResolver` reads the geometry id's
0227     // `extra` byte, which we rely on geometry construction to stamp with a
0228     // per-surface subsystem id. By default, every hit gets `extra() == 0`
0229     // unless the user supplies a custom resolver.
0230     detV->UnsafeAppend(m_cfg.detectorResolver(gid));
0231 
0232     ++hitIdx;
0233   }
0234 
0235   auto finish = [](arrow::ListBuilder& b) {
0236     std::shared_ptr<arrow::Array> out;
0237     check(b.Finish(&out), "finish list");
0238     return out;
0239   };
0240 
0241   std::vector<std::shared_ptr<arrow::Array>> arrays = {
0242       finish(xList),   finish(yList),   finish(zList),    finish(txList),
0243       finish(tyList),  finish(tzList),  finish(timeList), finish(pidList),
0244       finish(detList), finish(volList), finish(layList),  finish(surfList),
0245   };
0246 
0247   auto table =
0248       arrow::Table::Make(ActsPlugins::ArrowUtil::simHitSchema(), arrays);
0249   m_outputTable(ctx, ActsPlugins::ArrowUtil::ArrowTable{std::move(table)});
0250 
0251   return ProcessCode::SUCCESS;
0252 }
0253 
0254 }  // namespace ActsExamples