Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2026-06-20 07:36:32

0001 // This file is part of the ACTS project.
0002 //
0003 // Copyright (C) 2016 CERN for the benefit of the ACTS project
0004 //
0005 // This Source Code Form is subject to the terms of the Mozilla Public
0006 // License, v. 2.0. If a copy of the MPL was not distributed with this
0007 // file, You can obtain one at https://mozilla.org/MPL/2.0/.
0008 
0009 #include "ActsExamples/Io/Arrow/ArrowParticleOutputConverter.hpp"
0010 
0011 #include "ActsPlugins/Arrow/ArrowUtil.hpp"
0012 
0013 #include <cstdint>
0014 #include <iterator>
0015 #include <memory>
0016 #include <stdexcept>
0017 #include <string>
0018 #include <vector>
0019 
0020 #include <arrow/api.h>
0021 
0022 namespace ActsExamples {
0023 
0024 namespace {
0025 
0026 void check(const arrow::Status& s, const char* what) {
0027   if (!s.ok()) {
0028     throw std::runtime_error(std::string(what) + ": " + s.ToString());
0029   }
0030 }
0031 
0032 }  // namespace
0033 
0034 ArrowParticleOutputConverter::ArrowParticleOutputConverter(
0035     const Config& cfg, std::unique_ptr<const Acts::Logger> logger)
0036     : ArrowOutputConverter("ArrowParticleOutputConverter", std::move(logger)),
0037       m_cfg(cfg) {
0038   if (m_cfg.inputParticles.empty()) {
0039     throw std::invalid_argument("Missing particles input collection");
0040   }
0041   if (m_cfg.outputTable.empty()) {
0042     throw std::invalid_argument("Missing output table name");
0043   }
0044   m_inputParticles.initialize(m_cfg.inputParticles);
0045   m_outputTable.initialize(m_cfg.outputTable);
0046 }
0047 
0048 ArrowParticleOutputConverter::~ArrowParticleOutputConverter() = default;
0049 
0050 std::vector<std::string> ArrowParticleOutputConverter::collections() const {
0051   return {m_cfg.outputTable};
0052 }
0053 
0054 ProcessCode ArrowParticleOutputConverter::execute(
0055     const AlgorithmContext& ctx) const {
0056   const SimParticleContainer& particles = m_inputParticles(ctx);
0057   auto* pool = arrow::default_memory_pool();
0058 
0059   arrow::ListBuilder idList(pool, std::make_shared<arrow::UInt64Builder>(pool));
0060   arrow::ListBuilder pdgList(pool, std::make_shared<arrow::Int64Builder>(pool));
0061   arrow::ListBuilder massList(pool,
0062                               std::make_shared<arrow::FloatBuilder>(pool));
0063   arrow::ListBuilder energyList(pool,
0064                                 std::make_shared<arrow::FloatBuilder>(pool));
0065   arrow::ListBuilder chargeList(pool,
0066                                 std::make_shared<arrow::FloatBuilder>(pool));
0067   arrow::ListBuilder vxList(pool, std::make_shared<arrow::FloatBuilder>(pool));
0068   arrow::ListBuilder vyList(pool, std::make_shared<arrow::FloatBuilder>(pool));
0069   arrow::ListBuilder vzList(pool, std::make_shared<arrow::FloatBuilder>(pool));
0070   arrow::ListBuilder timeList(pool,
0071                               std::make_shared<arrow::FloatBuilder>(pool));
0072   arrow::ListBuilder pxList(pool, std::make_shared<arrow::FloatBuilder>(pool));
0073   arrow::ListBuilder pyList(pool, std::make_shared<arrow::FloatBuilder>(pool));
0074   arrow::ListBuilder pzList(pool, std::make_shared<arrow::FloatBuilder>(pool));
0075   arrow::ListBuilder d0List(pool, std::make_shared<arrow::FloatBuilder>(pool));
0076   arrow::ListBuilder z0List(pool, std::make_shared<arrow::FloatBuilder>(pool));
0077   arrow::ListBuilder vprimList(pool,
0078                                std::make_shared<arrow::UInt16Builder>(pool));
0079   arrow::ListBuilder parentList(pool,
0080                                 std::make_shared<arrow::Int64Builder>(pool));
0081   arrow::ListBuilder primaryList(pool,
0082                                  std::make_shared<arrow::BooleanBuilder>(pool));
0083 
0084   check(idList.Append(), "open particle_id list");
0085   check(pdgList.Append(), "open pdg_id list");
0086   check(massList.Append(), "open mass list");
0087   check(energyList.Append(), "open energy list");
0088   check(chargeList.Append(), "open charge list");
0089   check(vxList.Append(), "open vx list");
0090   check(vyList.Append(), "open vy list");
0091   check(vzList.Append(), "open vz list");
0092   check(timeList.Append(), "open time list");
0093   check(pxList.Append(), "open px list");
0094   check(pyList.Append(), "open py list");
0095   check(pzList.Append(), "open pz list");
0096   check(d0List.Append(), "open perigee_d0 list");
0097   check(z0List.Append(), "open perigee_z0 list");
0098   check(vprimList.Append(), "open vertex_primary list");
0099   check(parentList.Append(), "open parent_id list");
0100   check(primaryList.Append(), "open primary list");
0101 
0102   auto* idV = static_cast<arrow::UInt64Builder*>(idList.value_builder());
0103   auto* pdgV = static_cast<arrow::Int64Builder*>(pdgList.value_builder());
0104   auto* massV = static_cast<arrow::FloatBuilder*>(massList.value_builder());
0105   auto* energyV = static_cast<arrow::FloatBuilder*>(energyList.value_builder());
0106   auto* chargeV = static_cast<arrow::FloatBuilder*>(chargeList.value_builder());
0107   auto* vxV = static_cast<arrow::FloatBuilder*>(vxList.value_builder());
0108   auto* vyV = static_cast<arrow::FloatBuilder*>(vyList.value_builder());
0109   auto* vzV = static_cast<arrow::FloatBuilder*>(vzList.value_builder());
0110   auto* timeV = static_cast<arrow::FloatBuilder*>(timeList.value_builder());
0111   auto* pxV = static_cast<arrow::FloatBuilder*>(pxList.value_builder());
0112   auto* pyV = static_cast<arrow::FloatBuilder*>(pyList.value_builder());
0113   auto* pzV = static_cast<arrow::FloatBuilder*>(pzList.value_builder());
0114   auto* d0V = static_cast<arrow::FloatBuilder*>(d0List.value_builder());
0115   auto* z0V = static_cast<arrow::FloatBuilder*>(z0List.value_builder());
0116   auto* vprimV = static_cast<arrow::UInt16Builder*>(vprimList.value_builder());
0117   auto* parentV = static_cast<arrow::Int64Builder*>(parentList.value_builder());
0118   auto* primaryV =
0119       static_cast<arrow::BooleanBuilder*>(primaryList.value_builder());
0120 
0121   const auto n = particles.size();
0122   check(idV->Reserve(n), "reserve particle_id");
0123   check(pdgV->Reserve(n), "reserve pdg_id");
0124   check(massV->Reserve(n), "reserve mass");
0125   check(energyV->Reserve(n), "reserve energy");
0126   check(chargeV->Reserve(n), "reserve charge");
0127   check(vxV->Reserve(n), "reserve vx");
0128   check(vyV->Reserve(n), "reserve vy");
0129   check(vzV->Reserve(n), "reserve vz");
0130   check(timeV->Reserve(n), "reserve time");
0131   check(pxV->Reserve(n), "reserve px");
0132   check(pyV->Reserve(n), "reserve py");
0133   check(pzV->Reserve(n), "reserve pz");
0134   check(d0V->Reserve(n), "reserve perigee_d0");
0135   check(z0V->Reserve(n), "reserve perigee_z0");
0136   check(vprimV->Reserve(n), "reserve vertex_primary");
0137   check(parentV->Reserve(n), "reserve parent_id");
0138   check(primaryV->Reserve(n), "reserve primary");
0139 
0140   std::int64_t rowIndex = 0;
0141   for (const auto& particle : particles) {
0142     const auto& s = particle.initialState();
0143     const auto mom = s.momentum();
0144     const auto pos = s.position();
0145     const auto bc = s.particleId();
0146 
0147     // Emit the row index as the particle id (matches the colliderml
0148     // convention). Indices are stable within this event/output table even
0149     // when upstream filtering has dropped some EDM4hep particles.
0150     idV->UnsafeAppend(static_cast<std::uint64_t>(rowIndex));
0151     pdgV->UnsafeAppend(static_cast<std::int64_t>(s.pdg()));
0152     massV->UnsafeAppend(static_cast<float>(s.mass()));
0153     energyV->UnsafeAppend(static_cast<float>(s.energy()));
0154     chargeV->UnsafeAppend(static_cast<float>(s.charge()));
0155     vxV->UnsafeAppend(static_cast<float>(pos.x()));
0156     vyV->UnsafeAppend(static_cast<float>(pos.y()));
0157     vzV->UnsafeAppend(static_cast<float>(pos.z()));
0158     timeV->UnsafeAppend(static_cast<float>(s.time()));
0159     pxV->UnsafeAppend(static_cast<float>(mom.x()));
0160     pyV->UnsafeAppend(static_cast<float>(mom.y()));
0161     pzV->UnsafeAppend(static_cast<float>(mom.z()));
0162 
0163     // Perigee parameters are not computed here yet: the truth-to-perigee
0164     // propagation will be added back in a follow-up PR. Until then these
0165     // columns are emitted as null for every particle.
0166     d0V->UnsafeAppendNull();
0167     z0V->UnsafeAppendNull();
0168 
0169     vprimV->UnsafeAppend(static_cast<std::uint16_t>(bc.vertexPrimary()));
0170     // Emit the parent's row index in this same table so consumers can walk
0171     // the chain. -1 means "unknown" (parent was filtered out, or simulation
0172     // engine didn't record it). The container is a flat_set sorted by
0173     // barcode, so find()+distance is O(log N) with O(1) random access.
0174     const auto parentBc = particle.parentParticleId();
0175     std::int64_t parentRow = -1;
0176     if (parentBc.isValid()) {
0177       auto it = particles.find(parentBc);
0178       if (it != particles.end()) {
0179         parentRow = std::distance(particles.begin(), it);
0180       }
0181     }
0182     parentV->UnsafeAppend(parentRow);
0183     check(primaryV->Append(bc.generation() == 0), "append primary");
0184     ++rowIndex;
0185   }
0186 
0187   auto finish = [](arrow::ListBuilder& b) {
0188     std::shared_ptr<arrow::Array> out;
0189     check(b.Finish(&out), "finish list");
0190     return out;
0191   };
0192 
0193   std::vector<std::shared_ptr<arrow::Array>> arrays = {
0194       finish(idList),     finish(pdgList),     finish(massList),
0195       finish(energyList), finish(chargeList),  finish(vxList),
0196       finish(vyList),     finish(vzList),      finish(timeList),
0197       finish(pxList),     finish(pyList),      finish(pzList),
0198       finish(d0List),     finish(z0List),      finish(vprimList),
0199       finish(parentList), finish(primaryList),
0200   };
0201 
0202   auto table =
0203       arrow::Table::Make(ActsPlugins::ArrowUtil::particleSchema(), arrays);
0204   m_outputTable(ctx, ActsPlugins::ArrowUtil::ArrowTable{std::move(table)});
0205 
0206   return ProcessCode::SUCCESS;
0207 }
0208 
0209 }  // namespace ActsExamples