File indexing completed on 2026-06-25 07:48:58
0001
0002
0003
0004
0005
0006
0007
0008
#include "ActsPlugins/Arrow/ArrowUtil.hpp"

#include <algorithm>
#include <exception>
#include <iostream>
#include <limits>
#include <mutex>
#include <stdexcept>
#include <string>
#include <vector>

#include <arrow/c/abi.h>
#include <arrow/c/bridge.h>
#include <arrow/compute/api.h>
#include <arrow/dataset/dataset.h>
#include <arrow/dataset/discovery.h>
#include <arrow/dataset/file_parquet.h>
#include <arrow/dataset/scanner.h>
#include <arrow/filesystem/localfs.h>
#include <arrow/io/file.h>
#include <parquet/arrow/reader.h>
#include <parquet/arrow/writer.h>
0029
0030 namespace ActsPlugins::ArrowUtil {
0031
0032 namespace {
0033
0034 [[noreturn]] void throwArrow(const std::string& what,
0035 const arrow::Status& status) {
0036 throw std::runtime_error(what + ": " + status.ToString());
0037 }
0038
0039 template <typename T>
0040 T unwrap(arrow::Result<T> result, const std::string& what) {
0041 if (!result.ok()) {
0042 throwArrow(what, result.status());
0043 }
0044 return std::move(result).ValueOrDie();
0045 }
0046
0047
0048
0049
0050
0051
0052 void ensureComputeInitialized() {
0053 static std::once_flag flag;
0054 std::call_once(flag, [] {
0055 auto status = arrow::compute::Initialize();
0056 if (!status.ok()) {
0057 throwArrow("arrow compute init", status);
0058 }
0059 });
0060 }
0061
0062 }
0063
0064 std::string ArrowSchemaHandle::toString() const {
0065 return m_schema ? m_schema->ToString() : std::string{"<null schema>"};
0066 }
0067
0068 std::vector<std::string> ArrowSchemaHandle::fieldNames() const {
0069 std::vector<std::string> names;
0070 if (!m_schema) {
0071 return names;
0072 }
0073 names.reserve(m_schema->num_fields());
0074 for (int i = 0; i < m_schema->num_fields(); ++i) {
0075 names.push_back(m_schema->field(i)->name());
0076 }
0077 return names;
0078 }
0079
0080 int ArrowSchemaHandle::numFields() const {
0081 return m_schema ? m_schema->num_fields() : 0;
0082 }
0083
0084 void ArrowSchemaHandle::exportToC(::ArrowSchema* out) const {
0085 if (out == nullptr) {
0086 throw std::invalid_argument("ArrowSchemaHandle::exportToC: null out");
0087 }
0088 if (!m_schema) {
0089 throw std::runtime_error(
0090 "ArrowSchemaHandle::exportToC: handle holds a null schema");
0091 }
0092 auto status = arrow::ExportSchema(*m_schema, out);
0093 if (!status.ok()) {
0094 throwArrow("ExportSchema", status);
0095 }
0096 }
0097
0098 std::int64_t ArrowTable::numRows() const {
0099 return m_table ? m_table->num_rows() : 0;
0100 }
0101
0102 int ArrowTable::numColumns() const {
0103 return m_table ? m_table->num_columns() : 0;
0104 }
0105
0106 ArrowSchemaHandle ArrowTable::schema() const {
0107 return ArrowSchemaHandle{m_table ? m_table->schema() : nullptr};
0108 }
0109
0110 std::string ArrowTable::toString() const {
0111 return m_table ? m_table->ToString() : std::string{"<null table>"};
0112 }
0113
0114 void ArrowTable::exportToC(::ArrowSchema* out_schema,
0115 ::ArrowArray* out_array) const {
0116 if (out_schema == nullptr || out_array == nullptr) {
0117 throw std::invalid_argument("ArrowTable::exportToC: null out");
0118 }
0119 if (!m_table) {
0120 throw std::runtime_error(
0121 "ArrowTable::exportToC: handle holds a null table");
0122 }
0123
0124
0125
0126
0127 auto batch =
0128 unwrap(m_table->CombineChunksToBatch(), "table CombineChunksToBatch");
0129 auto status = arrow::ExportRecordBatch(*batch, out_array, out_schema);
0130 if (!status.ok()) {
0131 throwArrow("ExportRecordBatch", status);
0132 }
0133 }
0134
0135 ArrowTable ArrowTable::importFromC(::ArrowSchema* in_schema,
0136 ::ArrowArray* in_array) {
0137 if (in_schema == nullptr || in_array == nullptr) {
0138 throw std::invalid_argument("ArrowTable::importFromC: null input");
0139 }
0140
0141
0142
0143
0144 auto batch = unwrap(arrow::ImportRecordBatch(in_array, in_schema),
0145 "ImportRecordBatch");
0146 auto table = unwrap(arrow::Table::FromRecordBatches({std::move(batch)}),
0147 "Table::FromRecordBatches");
0148 return ArrowTable{std::move(table)};
0149 }
0150
0151 std::shared_ptr<arrow::Field> eventIdField() {
0152 return arrow::field(std::string{kEventIdColumn}, arrow::uint32(),
0153 false);
0154 }
0155
0156 namespace {
0157
0158
0159
0160
0161
0162 std::shared_ptr<arrow::DataType> nullableFloatList() {
0163 return arrow::list(arrow::field("item", arrow::float32(), true));
0164 }
0165
0166 }
0167
0168 std::shared_ptr<arrow::Schema> particleSchema() {
0169 return arrow::schema({
0170 arrow::field("particle_id", arrow::list(arrow::uint64()), false),
0171 arrow::field("pdg_id", arrow::list(arrow::int64()), false),
0172 arrow::field("mass", arrow::list(arrow::float32()), false),
0173 arrow::field("energy", arrow::list(arrow::float32()), false),
0174 arrow::field("charge", arrow::list(arrow::float32()), false),
0175 arrow::field("vx", arrow::list(arrow::float32()), false),
0176 arrow::field("vy", arrow::list(arrow::float32()), false),
0177 arrow::field("vz", arrow::list(arrow::float32()), false),
0178 arrow::field("time", arrow::list(arrow::float32()), false),
0179 arrow::field("px", arrow::list(arrow::float32()), false),
0180 arrow::field("py", arrow::list(arrow::float32()), false),
0181 arrow::field("pz", arrow::list(arrow::float32()), false),
0182 arrow::field("perigee_d0", nullableFloatList(), false),
0183 arrow::field("perigee_z0", nullableFloatList(), false),
0184 arrow::field("vertex_primary", arrow::list(arrow::uint16()), false),
0185 arrow::field("parent_id", arrow::list(arrow::int64()), false),
0186 arrow::field("primary", arrow::list(arrow::boolean()), false),
0187 });
0188 }
0189
0190 std::shared_ptr<arrow::Schema> trackSchema() {
0191 return arrow::schema({
0192 arrow::field("d0", nullableFloatList(), false),
0193 arrow::field("z0", nullableFloatList(), false),
0194 arrow::field("phi", nullableFloatList(), false),
0195 arrow::field("theta", nullableFloatList(), false),
0196 arrow::field("qop", nullableFloatList(), false),
0197 arrow::field("majority_particle_id", arrow::list(arrow::uint64()), false),
0198 arrow::field("hit_ids", arrow::list(arrow::list(arrow::uint32())), false),
0199 arrow::field("track_id", arrow::list(arrow::uint16()), false),
0200 arrow::field("t", nullableFloatList(), true),
0201 });
0202 }
0203
0204 std::shared_ptr<arrow::Schema> simHitSchema() {
0205 return arrow::schema({
0206 arrow::field("x", arrow::list(arrow::float32()), false),
0207 arrow::field("y", arrow::list(arrow::float32()), false),
0208 arrow::field("z", arrow::list(arrow::float32()), false),
0209 arrow::field("true_x", arrow::list(arrow::float32()), false),
0210 arrow::field("true_y", arrow::list(arrow::float32()), false),
0211 arrow::field("true_z", arrow::list(arrow::float32()), false),
0212 arrow::field("time", arrow::list(arrow::float32()), false),
0213 arrow::field("particle_id", arrow::list(arrow::uint64()), false),
0214 arrow::field("detector", arrow::list(arrow::uint8()), false),
0215 arrow::field("volume_id", arrow::list(arrow::uint8()), false),
0216 arrow::field("layer_id", arrow::list(arrow::uint16()), false),
0217 arrow::field("surface_id", arrow::list(arrow::uint32()), false),
0218 });
0219 }
0220
0221 std::shared_ptr<arrow::Table> withEventId(
0222 const std::shared_ptr<arrow::Table>& table, std::uint64_t eventId) {
0223 if (table == nullptr) {
0224 throw std::invalid_argument("withEventId: null table");
0225 }
0226 if (table->num_rows() != 1) {
0227 throw std::invalid_argument(
0228 "withEventId: expected a 1-row (nested-layout) table, got " +
0229 std::to_string(table->num_rows()) + " rows");
0230 }
0231 if (table->schema()->GetFieldIndex(kEventIdColumn) != -1) {
0232 throw std::invalid_argument("withEventId: table already has event_id");
0233 }
0234
0235 arrow::UInt32Builder builder;
0236 if (auto status = builder.Append(static_cast<std::uint32_t>(eventId));
0237 !status.ok()) {
0238 throwArrow("event_id append", status);
0239 }
0240 auto array = unwrap(builder.Finish(), "event_id finish");
0241 auto chunked = std::make_shared<arrow::ChunkedArray>(std::move(array));
0242
0243 return unwrap(table->AddColumn(0, eventIdField(), std::move(chunked)),
0244 "add event_id column");
0245 }
0246
0247 class ParquetFileWriter::Impl {
0248 public:
0249 explicit Impl(std::filesystem::path path) : m_path(std::move(path)) {}
0250
0251 void write(const arrow::Table& table) {
0252 if (!m_writer) {
0253 auto outfile = unwrap(arrow::io::FileOutputStream::Open(m_path.string()),
0254 "open parquet");
0255 auto properties = parquet::WriterProperties::Builder()
0256 .compression(parquet::Compression::ZSTD)
0257 ->enable_write_page_index()
0258 ->build();
0259 auto arrowProperties =
0260 parquet::ArrowWriterProperties::Builder().store_schema()->build();
0261 m_writer = unwrap(parquet::arrow::FileWriter::Open(
0262 *table.schema(), arrow::default_memory_pool(),
0263 outfile, properties, arrowProperties),
0264 "open parquet writer");
0265 }
0266 auto status = m_writer->WriteTable(table, table.num_rows());
0267 if (!status.ok()) {
0268 throwArrow("parquet WriteTable", status);
0269 }
0270 }
0271
0272 void close() {
0273 if (m_writer) {
0274 auto status = m_writer->Close();
0275 m_writer.reset();
0276 if (!status.ok()) {
0277 throwArrow("parquet close", status);
0278 }
0279 }
0280 }
0281
0282 private:
0283 std::filesystem::path m_path;
0284 std::unique_ptr<parquet::arrow::FileWriter> m_writer;
0285 };
0286
0287 ParquetFileWriter::ParquetFileWriter(std::filesystem::path path)
0288 : m_impl(std::make_unique<Impl>(std::move(path))) {}
0289
0290 ParquetFileWriter::~ParquetFileWriter() noexcept {
0291 if (m_impl) {
0292 try {
0293 m_impl->close();
0294 } catch (const std::exception& e) {
0295 std::cerr << "ParquetFileWriter::~ParquetFileWriter failed during close: "
0296 << e.what() << std::endl;
0297 std::terminate();
0298 } catch (...) {
0299 std::cerr << "ParquetFileWriter::~ParquetFileWriter failed during close "
0300 "with an unknown exception"
0301 << std::endl;
0302 std::terminate();
0303 }
0304 }
0305 }
0306
0307 void ParquetFileWriter::write(const arrow::Table& table) {
0308 m_impl->write(table);
0309 }
0310
0311 void ParquetFileWriter::close() {
0312 m_impl->close();
0313 }
0314
0315 class ParquetDatasetReader::Impl {
0316 public:
0317 Impl(std::filesystem::path directory,
0318 const std::shared_ptr<arrow::Schema>& targetSchema)
0319 : m_directory(std::move(directory)) {
0320 ensureComputeInitialized();
0321
0322 if (!std::filesystem::exists(m_directory) ||
0323 !std::filesystem::is_directory(m_directory)) {
0324 throw std::invalid_argument("ParquetDatasetReader: not a directory: " +
0325 m_directory.string());
0326 }
0327
0328 if (targetSchema != nullptr &&
0329 targetSchema->GetFieldIndex(kEventIdColumn) != -1) {
0330 throw std::invalid_argument(
0331 "ParquetDatasetReader: target schema must not contain event_id; "
0332 "the reader prepends it internally");
0333 }
0334
0335 std::vector<std::string> files;
0336 for (const auto& e : std::filesystem::directory_iterator(m_directory)) {
0337 if (e.is_regular_file() && e.path().extension() == ".parquet") {
0338 files.push_back(e.path().string());
0339 }
0340 }
0341 if (files.empty()) {
0342 throw std::invalid_argument(
0343 "ParquetDatasetReader: no parquet files under " +
0344 m_directory.string());
0345 }
0346 std::ranges::sort(files);
0347
0348 m_numEvents = 0;
0349 for (const auto& f : files) {
0350 auto infile =
0351 unwrap(arrow::io::ReadableFile::Open(f, arrow::default_memory_pool()),
0352 "open parquet footer");
0353 auto reader =
0354 unwrap(parquet::arrow::OpenFile(infile, arrow::default_memory_pool()),
0355 "open parquet reader");
0356 m_numEvents += reader->parquet_reader()->metadata()->num_rows();
0357 }
0358
0359 auto fs = std::make_shared<arrow::fs::LocalFileSystem>();
0360 auto format = std::make_shared<arrow::dataset::ParquetFileFormat>();
0361 arrow::dataset::FileSystemFactoryOptions options;
0362 auto factory = unwrap(arrow::dataset::FileSystemDatasetFactory::Make(
0363 std::move(fs), files, std::move(format), options),
0364 "make dataset factory");
0365
0366 arrow::dataset::FinishOptions finishOpts;
0367 if (targetSchema != nullptr) {
0368
0369
0370
0371 auto withEventId =
0372 unwrap(targetSchema->AddField(0, eventIdField()), "prepend event_id");
0373 finishOpts.schema = std::move(withEventId);
0374 } else {
0375 arrow::dataset::InspectOptions inspectOpts;
0376 inspectOpts.fragments =
0377 arrow::dataset::InspectOptions::kInspectAllFragments;
0378 finishOpts.inspect_options = inspectOpts;
0379 }
0380 m_dataset = unwrap(factory->Finish(finishOpts), "finish dataset");
0381
0382 auto fullSchema = m_dataset->schema();
0383 const int idx = fullSchema->GetFieldIndex(kEventIdColumn);
0384 if (idx < 0) {
0385 throw std::invalid_argument("ParquetDatasetReader: dataset under '" +
0386 m_directory.string() +
0387 "' lacks event_id column");
0388 }
0389 m_publicSchema =
0390 unwrap(fullSchema->RemoveField(idx), "strip event_id from schema");
0391 }
0392
0393 std::int64_t numEvents() const { return m_numEvents; }
0394 std::shared_ptr<arrow::Schema> schema() const { return m_publicSchema; }
0395
0396 std::shared_ptr<arrow::Table> readEvent(std::uint64_t eventId) const {
0397 auto builder = unwrap(m_dataset->NewScan(), "new scan");
0398 auto status = builder->Filter(arrow::compute::equal(
0399 arrow::compute::field_ref(kEventIdColumn.data()),
0400 arrow::compute::literal(static_cast<std::uint32_t>(eventId))));
0401 if (!status.ok()) {
0402 throwArrow("set scan filter", status);
0403 }
0404 status = builder->UseThreads(false);
0405 if (!status.ok()) {
0406 throwArrow("set use threads", status);
0407 }
0408 auto scanner = unwrap(builder->Finish(), "finish scanner");
0409 auto table = unwrap(scanner->ToTable(), "scan to table");
0410
0411 const int idx = table->schema()->GetFieldIndex(std::string{kEventIdColumn});
0412 if (idx < 0) {
0413 throw std::runtime_error(
0414 "ParquetDatasetReader: scanned table lacks event_id column");
0415 }
0416 return unwrap(table->RemoveColumn(idx), "drop event_id column");
0417 }
0418
0419 private:
0420 std::filesystem::path m_directory;
0421 std::int64_t m_numEvents = 0;
0422 std::shared_ptr<arrow::dataset::Dataset> m_dataset;
0423 std::shared_ptr<arrow::Schema> m_publicSchema;
0424 };
0425
0426 ParquetDatasetReader::ParquetDatasetReader(
0427 std::filesystem::path directory,
0428 const std::shared_ptr<arrow::Schema>& targetSchema)
0429 : m_impl(std::make_unique<Impl>(std::move(directory), targetSchema)) {}
0430
// Defaulted here, after the definition of ParquetDatasetReader::Impl, so the
// Impl object created via std::make_unique<Impl> in the constructor is
// destroyed where Impl is a complete type.
ParquetDatasetReader::~ParquetDatasetReader() = default;
0432
0433 std::int64_t ParquetDatasetReader::numEvents() const {
0434 return m_impl->numEvents();
0435 }
0436
0437 std::shared_ptr<arrow::Schema> ParquetDatasetReader::schema() const {
0438 return m_impl->schema();
0439 }
0440
0441 std::shared_ptr<arrow::Table> ParquetDatasetReader::readEvent(
0442 std::uint64_t eventId) const {
0443 return m_impl->readEvent(eventId);
0444 }
0445
0446 }