Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2026-06-25 07:48:58

0001 // This file is part of the ACTS project.
0002 //
0003 // Copyright (C) 2016 CERN for the benefit of the ACTS project
0004 //
0005 // This Source Code Form is subject to the terms of the Mozilla Public
0006 // License, v. 2.0. If a copy of the MPL was not distributed with this
0007 // file, You can obtain one at https://mozilla.org/MPL/2.0/.
0008 
0009 #include "ActsPlugins/Arrow/ArrowUtil.hpp"
0010 
0011 #include <algorithm>
0012 #include <exception>
0013 #include <iostream>
0014 #include <mutex>
0015 #include <stdexcept>
0016 #include <vector>
0017 
0018 #include <arrow/c/abi.h>
0019 #include <arrow/c/bridge.h>
0020 #include <arrow/compute/api.h>
0021 #include <arrow/dataset/dataset.h>
0022 #include <arrow/dataset/discovery.h>
0023 #include <arrow/dataset/file_parquet.h>
0024 #include <arrow/dataset/scanner.h>
0025 #include <arrow/filesystem/localfs.h>
0026 #include <arrow/io/file.h>
0027 #include <parquet/arrow/reader.h>
0028 #include <parquet/arrow/writer.h>
0029 
0030 namespace ActsPlugins::ArrowUtil {
0031 
0032 namespace {
0033 
0034 [[noreturn]] void throwArrow(const std::string& what,
0035                              const arrow::Status& status) {
0036   throw std::runtime_error(what + ": " + status.ToString());
0037 }
0038 
0039 template <typename T>
0040 T unwrap(arrow::Result<T> result, const std::string& what) {
0041   if (!result.ok()) {
0042     throwArrow(what, result.status());
0043   }
0044   return std::move(result).ValueOrDie();
0045 }
0046 
0047 // Arrow's compute kernels (e.g. `equal` used for the event_id filter
0048 // pushdown below) are registered lazily. With the linker-isolated arrow
0049 // island the registry is empty until we ask for it explicitly — without
0050 // this call, scan-time filtering fails with "No function registered with
0051 // name: equal".
0052 void ensureComputeInitialized() {
0053   static std::once_flag flag;
0054   std::call_once(flag, [] {
0055     auto status = arrow::compute::Initialize();
0056     if (!status.ok()) {
0057       throwArrow("arrow compute init", status);
0058     }
0059   });
0060 }
0061 
0062 }  // namespace
0063 
0064 std::string ArrowSchemaHandle::toString() const {
0065   return m_schema ? m_schema->ToString() : std::string{"<null schema>"};
0066 }
0067 
0068 std::vector<std::string> ArrowSchemaHandle::fieldNames() const {
0069   std::vector<std::string> names;
0070   if (!m_schema) {
0071     return names;
0072   }
0073   names.reserve(m_schema->num_fields());
0074   for (int i = 0; i < m_schema->num_fields(); ++i) {
0075     names.push_back(m_schema->field(i)->name());
0076   }
0077   return names;
0078 }
0079 
0080 int ArrowSchemaHandle::numFields() const {
0081   return m_schema ? m_schema->num_fields() : 0;
0082 }
0083 
0084 void ArrowSchemaHandle::exportToC(::ArrowSchema* out) const {
0085   if (out == nullptr) {
0086     throw std::invalid_argument("ArrowSchemaHandle::exportToC: null out");
0087   }
0088   if (!m_schema) {
0089     throw std::runtime_error(
0090         "ArrowSchemaHandle::exportToC: handle holds a null schema");
0091   }
0092   auto status = arrow::ExportSchema(*m_schema, out);
0093   if (!status.ok()) {
0094     throwArrow("ExportSchema", status);
0095   }
0096 }
0097 
0098 std::int64_t ArrowTable::numRows() const {
0099   return m_table ? m_table->num_rows() : 0;
0100 }
0101 
0102 int ArrowTable::numColumns() const {
0103   return m_table ? m_table->num_columns() : 0;
0104 }
0105 
0106 ArrowSchemaHandle ArrowTable::schema() const {
0107   return ArrowSchemaHandle{m_table ? m_table->schema() : nullptr};
0108 }
0109 
0110 std::string ArrowTable::toString() const {
0111   return m_table ? m_table->ToString() : std::string{"<null table>"};
0112 }
0113 
0114 void ArrowTable::exportToC(::ArrowSchema* out_schema,
0115                            ::ArrowArray* out_array) const {
0116   if (out_schema == nullptr || out_array == nullptr) {
0117     throw std::invalid_argument("ArrowTable::exportToC: null out");
0118   }
0119   if (!m_table) {
0120     throw std::runtime_error(
0121         "ArrowTable::exportToC: handle holds a null table");
0122   }
0123   // Combine to a single record batch — required by ExportRecordBatch and
0124   // matches what consumers receive through the `__arrow_c_array__`
0125   // protocol. For the canonical 1-row 1-chunk case this is essentially a
0126   // pointer copy.
0127   auto batch =
0128       unwrap(m_table->CombineChunksToBatch(), "table CombineChunksToBatch");
0129   auto status = arrow::ExportRecordBatch(*batch, out_array, out_schema);
0130   if (!status.ok()) {
0131     throwArrow("ExportRecordBatch", status);
0132   }
0133 }
0134 
0135 ArrowTable ArrowTable::importFromC(::ArrowSchema* in_schema,
0136                                    ::ArrowArray* in_array) {
0137   if (in_schema == nullptr || in_array == nullptr) {
0138     throw std::invalid_argument("ArrowTable::importFromC: null input");
0139   }
0140   // ImportRecordBatch consumes the C-Data structs (their release
0141   // callbacks are nulled out on success). Buffers are referenced via the
0142   // batch's internal release wiring, so the producer's memory stays
0143   // alive until our arrow::RecordBatch is destroyed.
0144   auto batch = unwrap(arrow::ImportRecordBatch(in_array, in_schema),
0145                       "ImportRecordBatch");
0146   auto table = unwrap(arrow::Table::FromRecordBatches({std::move(batch)}),
0147                       "Table::FromRecordBatches");
0148   return ArrowTable{std::move(table)};
0149 }
0150 
0151 std::shared_ptr<arrow::Field> eventIdField() {
0152   return arrow::field(std::string{kEventIdColumn}, arrow::uint32(),
0153                       /*nullable=*/false);
0154 }
0155 
0156 namespace {
0157 
0158 // `list<T>` whose inner elements are nullable. Used for perigee parameters
0159 // and track parameters so a per-element failure (no reference surface,
0160 // failed local transform, failed propagation) emits a real null instead
0161 // of a NaN sentinel.
0162 std::shared_ptr<arrow::DataType> nullableFloatList() {
0163   return arrow::list(arrow::field("item", arrow::float32(), true));
0164 }
0165 
0166 }  // namespace
0167 
0168 std::shared_ptr<arrow::Schema> particleSchema() {
0169   return arrow::schema({
0170       arrow::field("particle_id", arrow::list(arrow::uint64()), false),
0171       arrow::field("pdg_id", arrow::list(arrow::int64()), false),
0172       arrow::field("mass", arrow::list(arrow::float32()), false),
0173       arrow::field("energy", arrow::list(arrow::float32()), false),
0174       arrow::field("charge", arrow::list(arrow::float32()), false),
0175       arrow::field("vx", arrow::list(arrow::float32()), false),
0176       arrow::field("vy", arrow::list(arrow::float32()), false),
0177       arrow::field("vz", arrow::list(arrow::float32()), false),
0178       arrow::field("time", arrow::list(arrow::float32()), false),
0179       arrow::field("px", arrow::list(arrow::float32()), false),
0180       arrow::field("py", arrow::list(arrow::float32()), false),
0181       arrow::field("pz", arrow::list(arrow::float32()), false),
0182       arrow::field("perigee_d0", nullableFloatList(), false),
0183       arrow::field("perigee_z0", nullableFloatList(), false),
0184       arrow::field("vertex_primary", arrow::list(arrow::uint16()), false),
0185       arrow::field("parent_id", arrow::list(arrow::int64()), false),
0186       arrow::field("primary", arrow::list(arrow::boolean()), false),
0187   });
0188 }
0189 
0190 std::shared_ptr<arrow::Schema> trackSchema() {
0191   return arrow::schema({
0192       arrow::field("d0", nullableFloatList(), false),
0193       arrow::field("z0", nullableFloatList(), false),
0194       arrow::field("phi", nullableFloatList(), false),
0195       arrow::field("theta", nullableFloatList(), false),
0196       arrow::field("qop", nullableFloatList(), false),
0197       arrow::field("majority_particle_id", arrow::list(arrow::uint64()), false),
0198       arrow::field("hit_ids", arrow::list(arrow::list(arrow::uint32())), false),
0199       arrow::field("track_id", arrow::list(arrow::uint16()), false),
0200       arrow::field("t", nullableFloatList(), true),
0201   });
0202 }
0203 
0204 std::shared_ptr<arrow::Schema> simHitSchema() {
0205   return arrow::schema({
0206       arrow::field("x", arrow::list(arrow::float32()), false),
0207       arrow::field("y", arrow::list(arrow::float32()), false),
0208       arrow::field("z", arrow::list(arrow::float32()), false),
0209       arrow::field("true_x", arrow::list(arrow::float32()), false),
0210       arrow::field("true_y", arrow::list(arrow::float32()), false),
0211       arrow::field("true_z", arrow::list(arrow::float32()), false),
0212       arrow::field("time", arrow::list(arrow::float32()), false),
0213       arrow::field("particle_id", arrow::list(arrow::uint64()), false),
0214       arrow::field("detector", arrow::list(arrow::uint8()), false),
0215       arrow::field("volume_id", arrow::list(arrow::uint8()), false),
0216       arrow::field("layer_id", arrow::list(arrow::uint16()), false),
0217       arrow::field("surface_id", arrow::list(arrow::uint32()), false),
0218   });
0219 }
0220 
0221 std::shared_ptr<arrow::Table> withEventId(
0222     const std::shared_ptr<arrow::Table>& table, std::uint64_t eventId) {
0223   if (table == nullptr) {
0224     throw std::invalid_argument("withEventId: null table");
0225   }
0226   if (table->num_rows() != 1) {
0227     throw std::invalid_argument(
0228         "withEventId: expected a 1-row (nested-layout) table, got " +
0229         std::to_string(table->num_rows()) + " rows");
0230   }
0231   if (table->schema()->GetFieldIndex(kEventIdColumn) != -1) {
0232     throw std::invalid_argument("withEventId: table already has event_id");
0233   }
0234 
0235   arrow::UInt32Builder builder;
0236   if (auto status = builder.Append(static_cast<std::uint32_t>(eventId));
0237       !status.ok()) {
0238     throwArrow("event_id append", status);
0239   }
0240   auto array = unwrap(builder.Finish(), "event_id finish");
0241   auto chunked = std::make_shared<arrow::ChunkedArray>(std::move(array));
0242 
0243   return unwrap(table->AddColumn(0, eventIdField(), std::move(chunked)),
0244                 "add event_id column");
0245 }
0246 
0247 class ParquetFileWriter::Impl {
0248  public:
0249   explicit Impl(std::filesystem::path path) : m_path(std::move(path)) {}
0250 
0251   void write(const arrow::Table& table) {
0252     if (!m_writer) {
0253       auto outfile = unwrap(arrow::io::FileOutputStream::Open(m_path.string()),
0254                             "open parquet");
0255       auto properties = parquet::WriterProperties::Builder()
0256                             .compression(parquet::Compression::ZSTD)
0257                             ->enable_write_page_index()
0258                             ->build();
0259       auto arrowProperties =
0260           parquet::ArrowWriterProperties::Builder().store_schema()->build();
0261       m_writer = unwrap(parquet::arrow::FileWriter::Open(
0262                             *table.schema(), arrow::default_memory_pool(),
0263                             outfile, properties, arrowProperties),
0264                         "open parquet writer");
0265     }
0266     auto status = m_writer->WriteTable(table, table.num_rows());
0267     if (!status.ok()) {
0268       throwArrow("parquet WriteTable", status);
0269     }
0270   }
0271 
0272   void close() {
0273     if (m_writer) {
0274       auto status = m_writer->Close();
0275       m_writer.reset();
0276       if (!status.ok()) {
0277         throwArrow("parquet close", status);
0278       }
0279     }
0280   }
0281 
0282  private:
0283   std::filesystem::path m_path;
0284   std::unique_ptr<parquet::arrow::FileWriter> m_writer;
0285 };
0286 
0287 ParquetFileWriter::ParquetFileWriter(std::filesystem::path path)
0288     : m_impl(std::make_unique<Impl>(std::move(path))) {}
0289 
0290 ParquetFileWriter::~ParquetFileWriter() noexcept {
0291   if (m_impl) {
0292     try {
0293       m_impl->close();
0294     } catch (const std::exception& e) {
0295       std::cerr << "ParquetFileWriter::~ParquetFileWriter failed during close: "
0296                 << e.what() << std::endl;
0297       std::terminate();
0298     } catch (...) {
0299       std::cerr << "ParquetFileWriter::~ParquetFileWriter failed during close "
0300                    "with an unknown exception"
0301                 << std::endl;
0302       std::terminate();
0303     }
0304   }
0305 }
0306 
0307 void ParquetFileWriter::write(const arrow::Table& table) {
0308   m_impl->write(table);
0309 }
0310 
0311 void ParquetFileWriter::close() {
0312   m_impl->close();
0313 }
0314 
0315 class ParquetDatasetReader::Impl {
0316  public:
0317   Impl(std::filesystem::path directory,
0318        const std::shared_ptr<arrow::Schema>& targetSchema)
0319       : m_directory(std::move(directory)) {
0320     ensureComputeInitialized();
0321 
0322     if (!std::filesystem::exists(m_directory) ||
0323         !std::filesystem::is_directory(m_directory)) {
0324       throw std::invalid_argument("ParquetDatasetReader: not a directory: " +
0325                                   m_directory.string());
0326     }
0327 
0328     if (targetSchema != nullptr &&
0329         targetSchema->GetFieldIndex(kEventIdColumn) != -1) {
0330       throw std::invalid_argument(
0331           "ParquetDatasetReader: target schema must not contain event_id; "
0332           "the reader prepends it internally");
0333     }
0334 
0335     std::vector<std::string> files;
0336     for (const auto& e : std::filesystem::directory_iterator(m_directory)) {
0337       if (e.is_regular_file() && e.path().extension() == ".parquet") {
0338         files.push_back(e.path().string());
0339       }
0340     }
0341     if (files.empty()) {
0342       throw std::invalid_argument(
0343           "ParquetDatasetReader: no parquet files under " +
0344           m_directory.string());
0345     }
0346     std::ranges::sort(files);
0347 
0348     m_numEvents = 0;
0349     for (const auto& f : files) {
0350       auto infile =
0351           unwrap(arrow::io::ReadableFile::Open(f, arrow::default_memory_pool()),
0352                  "open parquet footer");
0353       auto reader =
0354           unwrap(parquet::arrow::OpenFile(infile, arrow::default_memory_pool()),
0355                  "open parquet reader");
0356       m_numEvents += reader->parquet_reader()->metadata()->num_rows();
0357     }
0358 
0359     auto fs = std::make_shared<arrow::fs::LocalFileSystem>();
0360     auto format = std::make_shared<arrow::dataset::ParquetFileFormat>();
0361     arrow::dataset::FileSystemFactoryOptions options;
0362     auto factory = unwrap(arrow::dataset::FileSystemDatasetFactory::Make(
0363                               std::move(fs), files, std::move(format), options),
0364                           "make dataset factory");
0365 
0366     arrow::dataset::FinishOptions finishOpts;
0367     if (targetSchema != nullptr) {
0368       // Prepend event_id; the dataset must carry it for the per-event
0369       // filter, but callers see it as an internal column and the
0370       // public schema below strips it again.
0371       auto withEventId =
0372           unwrap(targetSchema->AddField(0, eventIdField()), "prepend event_id");
0373       finishOpts.schema = std::move(withEventId);
0374     } else {
0375       arrow::dataset::InspectOptions inspectOpts;
0376       inspectOpts.fragments =
0377           arrow::dataset::InspectOptions::kInspectAllFragments;
0378       finishOpts.inspect_options = inspectOpts;
0379     }
0380     m_dataset = unwrap(factory->Finish(finishOpts), "finish dataset");
0381 
0382     auto fullSchema = m_dataset->schema();
0383     const int idx = fullSchema->GetFieldIndex(kEventIdColumn);
0384     if (idx < 0) {
0385       throw std::invalid_argument("ParquetDatasetReader: dataset under '" +
0386                                   m_directory.string() +
0387                                   "' lacks event_id column");
0388     }
0389     m_publicSchema =
0390         unwrap(fullSchema->RemoveField(idx), "strip event_id from schema");
0391   }
0392 
0393   std::int64_t numEvents() const { return m_numEvents; }
0394   std::shared_ptr<arrow::Schema> schema() const { return m_publicSchema; }
0395 
0396   std::shared_ptr<arrow::Table> readEvent(std::uint64_t eventId) const {
0397     auto builder = unwrap(m_dataset->NewScan(), "new scan");
0398     auto status = builder->Filter(arrow::compute::equal(
0399         arrow::compute::field_ref(kEventIdColumn.data()),
0400         arrow::compute::literal(static_cast<std::uint32_t>(eventId))));
0401     if (!status.ok()) {
0402       throwArrow("set scan filter", status);
0403     }
0404     status = builder->UseThreads(false);
0405     if (!status.ok()) {
0406       throwArrow("set use threads", status);
0407     }
0408     auto scanner = unwrap(builder->Finish(), "finish scanner");
0409     auto table = unwrap(scanner->ToTable(), "scan to table");
0410 
0411     const int idx = table->schema()->GetFieldIndex(std::string{kEventIdColumn});
0412     if (idx < 0) {
0413       throw std::runtime_error(
0414           "ParquetDatasetReader: scanned table lacks event_id column");
0415     }
0416     return unwrap(table->RemoveColumn(idx), "drop event_id column");
0417   }
0418 
0419  private:
0420   std::filesystem::path m_directory;
0421   std::int64_t m_numEvents = 0;
0422   std::shared_ptr<arrow::dataset::Dataset> m_dataset;
0423   std::shared_ptr<arrow::Schema> m_publicSchema;
0424 };
0425 
0426 ParquetDatasetReader::ParquetDatasetReader(
0427     std::filesystem::path directory,
0428     const std::shared_ptr<arrow::Schema>& targetSchema)
0429     : m_impl(std::make_unique<Impl>(std::move(directory), targetSchema)) {}
0430 
// Defined out of line, in the file where Impl is a complete type.
// NOTE(review): m_impl appears to be a std::unique_ptr<Impl> (it is
// initialized with make_unique), which needs the complete Impl at the point
// of destruction; confirm against the header.
ParquetDatasetReader::~ParquetDatasetReader() = default;
0432 
0433 std::int64_t ParquetDatasetReader::numEvents() const {
0434   return m_impl->numEvents();
0435 }
0436 
0437 std::shared_ptr<arrow::Schema> ParquetDatasetReader::schema() const {
0438   return m_impl->schema();
0439 }
0440 
0441 std::shared_ptr<arrow::Table> ParquetDatasetReader::readEvent(
0442     std::uint64_t eventId) const {
0443   return m_impl->readEvent(eventId);
0444 }
0445 
0446 }  // namespace ActsPlugins::ArrowUtil