Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2026-06-20 07:36:36

0001 // This file is part of the ACTS project.
0002 //
0003 // Copyright (C) 2016 CERN for the benefit of the ACTS project
0004 //
0005 // This Source Code Form is subject to the terms of the Mozilla Public
0006 // License, v. 2.0. If a copy of the MPL was not distributed with this
0007 // file, You can obtain one at https://mozilla.org/MPL/2.0/.
0008 
0009 #include "ActsExamples/Io/Parquet/ParquetWriter.hpp"
0010 
0011 #include "Acts/Utilities/Logger.hpp"
0012 #include "ActsExamples/Framework/DataHandle.hpp"
0013 #include "ActsPlugins/Arrow/ArrowUtil.hpp"
0014 
0015 #include <cstdint>
0016 #include <format>
0017 #include <map>
0018 #include <memory>
0019 #include <mutex>
0020 #include <stdexcept>
0021 #include <unordered_map>
0022 #include <unordered_set>
0023 #include <vector>
0024 
0025 #include <arrow/api.h>
0026 
0027 namespace ActsExamples {
0028 
/// Private implementation of ParquetWriter.
///
/// Holds the validated configuration and, per configured collection, the
/// output directory, shard file prefix, expected schema, input data handle
/// and the set of currently open shards.
class ParquetWriter::Impl {
 public:
  using TableHandle = ConsumeDataHandle<ActsPlugins::ArrowUtil::ArrowTable>;

  /// One output Parquet file plus the per-event tables buffered for its next
  /// row group.
  struct ShardState {
    /// Location of the shard file on disk.
    std::filesystem::path path;
    /// Locked by the writer around every access to `writer`, `buffer` and
    /// `eventsAccepted`.
    std::mutex mutex;
    ActsPlugins::ArrowUtil::ParquetFileWriter writer;
    /// Tables accepted since the last flush, one entry per event.
    std::vector<std::shared_ptr<arrow::Table>> buffer;
    /// Total number of events routed to this shard so far.
    std::uint64_t eventsAccepted = 0;

    // `path` is declared before `writer`, so it is fully initialised by the
    // time the writer is constructed from it.
    explicit ShardState(std::filesystem::path p)
        : path(std::move(p)), writer(path) {}
  };

  /// Output state for a single configured collection.
  struct CollectionState {
    std::string name;
    std::filesystem::path directory;
    std::string filePrefix;
    std::shared_ptr<arrow::Schema> expectedSchema;
    std::unique_ptr<TableHandle> handle;
    /// Guards `shards` (lookup, insertion and erasure).
    std::mutex shardsMutex;
    /// Open shards keyed by shard id.
    std::map<std::uint64_t, std::unique_ptr<ShardState>> shards;
  };

  /// Validate the configuration and set up one CollectionState per
  /// configured collection, creating the output directories on disk.
  ///
  /// @param config Writer configuration; copied into `m_cfg`. An
  ///        `eventsPerRowGroup` of 0 is replaced by `eventsPerShard`.
  /// @param parent Owning writer, used to register the input data handles.
  /// @throws std::invalid_argument if the configuration is inconsistent
  ///         (no collections, zero shard sizes, row group larger than a
  ///         shard, empty names/paths, duplicate output directories, a path
  ///         without a usable filename stem, missing or stray expected
  ///         schemas, or an expected schema that already has event_id).
  Impl(const ParquetWriter::Config& config, ParquetWriter& parent)
      : m_cfg(config) {
    if (m_cfg.collections.empty()) {
      throw std::invalid_argument("ParquetWriter: no collections configured");
    }
    if (m_cfg.eventsPerShard == 0) {
      throw std::invalid_argument("ParquetWriter: eventsPerShard must be > 0");
    }
    if (m_cfg.maxOpenShards == 0) {
      throw std::invalid_argument("ParquetWriter: maxOpenShards must be > 0");
    }
    // Sentinel: 0 means "one row group per shard".
    if (m_cfg.eventsPerRowGroup == 0) {
      m_cfg.eventsPerRowGroup = m_cfg.eventsPerShard;
    }
    if (m_cfg.eventsPerRowGroup > m_cfg.eventsPerShard) {
      throw std::invalid_argument(std::format(
          "ParquetWriter: eventsPerRowGroup ({}) must be <= eventsPerShard "
          "({})",
          m_cfg.eventsPerRowGroup, m_cfg.eventsPerShard));
    }

    std::unordered_set<std::string> seenDirs;
    for (const auto& [name, rawPath] : m_cfg.collections) {
      if (name.empty()) {
        throw std::invalid_argument("ParquetWriter: empty collection name");
      }
      if (rawPath.empty()) {
        throw std::invalid_argument(std::format(
            "ParquetWriter: empty output directory for collection '{}'", name));
      }

      std::filesystem::path resolved =
          rawPath.is_absolute() ? rawPath : m_cfg.outputDir / rawPath;

      // lexically_normal() keeps a trailing separator ("a/b/" stays "a/b/"),
      // so strip it from the key; otherwise "particles" and "particles/"
      // would be accepted as distinct although they are the same directory.
      // A root-only path has no relative part and is left untouched.
      std::filesystem::path dedupKey = resolved.lexically_normal();
      if (dedupKey.has_relative_path() && !dedupKey.has_filename()) {
        dedupKey = dedupKey.parent_path();
      }
      if (!seenDirs.insert(dedupKey.string()).second) {
        throw std::invalid_argument(
            std::format("ParquetWriter: duplicate output directory '{}'",
                        resolved.string()));
      }

      // Derive a file-name prefix from the configured per-collection
      // directory: e.g. `<outputDir>/particles.parquet/` → `particles`,
      // `<outputDir>/particles/` → `particles`. Walk past empty trailing
      // components so paths with trailing separators still decompose.
      // The loop stops at a root-only path ("/", "C:\"): parent_path() of
      // such a path returns the path itself, so walking further would never
      // terminate. The empty stem is then rejected below.
      auto trimmed = resolved;
      while (trimmed.has_relative_path() && !trimmed.has_filename()) {
        trimmed = trimmed.parent_path();
      }
      std::string filePrefix = trimmed.filename().stem().string();
      if (filePrefix.empty()) {
        throw std::invalid_argument(std::format(
            "ParquetWriter: output directory '{}' for collection '{}' has "
            "no usable filename stem to derive a shard prefix from",
            resolved.string(), name));
      }

      auto schemaIt = m_cfg.expectedSchemas.find(name);
      if (schemaIt == m_cfg.expectedSchemas.end() || !schemaIt->second) {
        throw std::invalid_argument(std::format(
            "ParquetWriter: collection '{}' has no expected schema. Every "
            "configured collection must declare an expected schema; the "
            "writer validates each per-event table against it before "
            "stamping event_id and serialising.",
            name));
      }
      auto expected = schemaIt->second.schema();
      // The event-id column is added by the writer itself, so the declared
      // schema must not already carry it.
      if (expected->GetFieldIndex(
              std::string{ActsPlugins::ArrowUtil::kEventIdColumn}) >= 0) {
        throw std::invalid_argument(std::format(
            "ParquetWriter: expected schema for '{}' must not contain "
            "event_id; the writer prepends it.",
            name));
      }

      std::filesystem::create_directories(resolved);

      auto state = std::make_unique<CollectionState>();
      state->name = name;
      state->directory = std::move(resolved);
      state->filePrefix = std::move(filePrefix);
      state->expectedSchema = std::move(expected);
      state->handle = std::make_unique<TableHandle>(&parent, name);
      state->handle->initialize(name);
      m_collectionStates.push_back(std::move(state));
    }
    // Reject schemas that do not belong to any configured collection.
    for (const auto& [name, _] : m_cfg.expectedSchemas) {
      if (!m_cfg.collections.contains(name)) {
        throw std::invalid_argument(std::format(
            "ParquetWriter: expectedSchemas has entry for '{}' but no "
            "matching collection",
            name));
      }
    }
  }

  /// Concatenate the shard's buffered tables and write them as a single
  /// row group, then empty the buffer. Does nothing if the buffer is empty.
  /// Caller must hold @c shard.mutex.
  ///
  /// @throws std::runtime_error if the tables cannot be concatenated.
  void flushBuffer(ShardState& shard) {
    if (shard.buffer.empty()) {
      return;
    }
    auto result = arrow::ConcatenateTables(shard.buffer);
    if (!result.ok()) {
      throw std::runtime_error(
          std::format("ParquetWriter concat for shard '{}': {}",
                      shard.path.string(), result.status().ToString()));
    }
    shard.writer.write(*result.ValueOrDie());
    shard.buffer.clear();
  }

  /// Build the file path of shard @p shardId inside the collection's
  /// directory, named `<prefix>_<start>-<end>.parquet`.
  std::filesystem::path shardPath(const CollectionState& state,
                                  std::uint64_t shardId) const {
    // Encode the *planned* event window covered by this shard: shardId is
    // assigned by `eventNumber / eventsPerShard`, so this shard owns
    // events [startEvent, endEvent). The final shard of a job may contain
    // fewer events than the name implies; consumers should read the file
    // and trust its row count, not the filename.
    const std::uint64_t startEvent = shardId * m_cfg.eventsPerShard;
    const std::uint64_t endEvent = startEvent + m_cfg.eventsPerShard;
    return state.directory / std::format("{}_{:06}-{:06}.parquet",
                                         state.filePrefix, startEvent,
                                         endEvent);
  }

  /// Validated copy of the writer configuration.
  ParquetWriter::Config m_cfg;
  /// One entry per configured collection.
  std::vector<std::unique_ptr<CollectionState>> m_collectionStates;
};
0182 
0183 ParquetWriter::ParquetWriter(const Config& config,
0184                              std::unique_ptr<const Acts::Logger> logger)
0185     : m_logger(std::move(logger)),
0186       m_impl(std::make_unique<Impl>(config, *this)) {}
0187 
0188 ParquetWriter::ParquetWriter(const Config& config, Acts::Logging::Level level)
0189     : ParquetWriter(config, Acts::getDefaultLogger("ParquetWriter", level)) {}
0190 
/// Defaulted out of line, in the translation unit that defines Impl.
/// NOTE(review): presumably required because m_impl owns an Impl that is
/// only forward-declared in the header — confirm against ParquetWriter.hpp.
0191 ParquetWriter::~ParquetWriter() = default;
0192 
0193 std::string ParquetWriter::name() const {
0194   return "ParquetWriter";
0195 }
0196 
0197 const ParquetWriter::Config& ParquetWriter::config() const {
0198   return m_impl->m_cfg;
0199 }
0200 
/// Write the current event's table of every configured collection.
///
/// For each collection the table is checked against the expected schema,
/// stamped with the event id, and appended to the shard selected by
/// `eventNumber / eventsPerShard`. A shard is flushed when its buffer reaches
/// `eventsPerRowGroup` tables, and flushed, closed and dropped once it has
/// accepted `eventsPerShard` events.
///
/// @param ctx Algorithm context of the event being written
/// @return ProcessCode::SUCCESS, or ProcessCode::ABORT on a null table or a
///         schema mismatch
/// @throws std::runtime_error if opening a new shard would exceed
///         `maxOpenShards` for a collection
ProcessCode ParquetWriter::write(const AlgorithmContext& ctx) {
  const auto& cfg = m_impl->m_cfg;
  const std::uint64_t shardIndex = ctx.eventNumber / cfg.eventsPerShard;

  for (const auto& collection : m_impl->m_collectionStates) {
    auto input = (*collection->handle)(ctx);
    if (!input) {
      ACTS_ERROR("ParquetWriter: null table for collection "
                 << collection->name);
      return ProcessCode::ABORT;
    }

    const auto& actualSchema = *input.table()->schema();
    const bool schemaMatches = actualSchema.Equals(
        *collection->expectedSchema, /*check_metadata=*/false);
    if (!schemaMatches) {
      ACTS_ERROR("ParquetWriter: schema mismatch for collection '"
                 << collection->name << "' at event " << ctx.eventNumber
                 << ".\n  expected: " << collection->expectedSchema->ToString()
                 << "\n  actual:   " << actualSchema.ToString());
      return ProcessCode::ABORT;
    }

    auto stamped = ActsPlugins::ArrowUtil::withEventId(
        input.table(), static_cast<std::uint64_t>(ctx.eventNumber));

    // Look up the shard for this event, opening it if necessary. Only the
    // shard map is touched while the map lock is held.
    Impl::ShardState* target = nullptr;
    {
      std::lock_guard<std::mutex> mapLock(collection->shardsMutex);
      auto& openShards = collection->shards;
      if (auto found = openShards.find(shardIndex);
          found != openShards.end()) {
        target = found->second.get();
      } else {
        if (openShards.size() >= cfg.maxOpenShards) {
          // Build a comma-separated list of the open shard ids for the
          // error message.
          std::string openIds;
          for (const auto& [id, _] : openShards) {
            if (!openIds.empty()) {
              openIds += ", ";
            }
            openIds += std::to_string(id);
          }
          throw std::runtime_error(std::format(
              "ParquetWriter: collection '{}' would exceed maxOpenShards={} "
              "(currently open: [{}], requested shard: {}). This usually "
              "means a worker thread is far behind the event-id frontier.",
              collection->name, cfg.maxOpenShards, openIds, shardIndex));
        }
        auto fresh = std::make_unique<Impl::ShardState>(
            m_impl->shardPath(*collection, shardIndex));
        target = fresh.get();
        openShards.emplace(shardIndex, std::move(fresh));
      }
    }

    // Append under the shard's own lock; flush on a full row group, and
    // flush + close once the shard has received all of its events.
    bool shardComplete = false;
    {
      std::lock_guard<std::mutex> shardLock(target->mutex);
      target->buffer.push_back(std::move(stamped));
      ++target->eventsAccepted;
      shardComplete = target->eventsAccepted >= cfg.eventsPerShard;
      const bool rowGroupReady =
          target->buffer.size() >= cfg.eventsPerRowGroup;
      if (shardComplete || rowGroupReady) {
        m_impl->flushBuffer(*target);
      }
      if (shardComplete) {
        target->writer.close();
      }
    }

    // The shard lock is released before the finished shard is removed from
    // the map, which destroys it.
    if (shardComplete) {
      std::lock_guard<std::mutex> mapLock(collection->shardsMutex);
      collection->shards.erase(shardIndex);
    }
  }

  return ProcessCode::SUCCESS;
}
0273 
0274 ProcessCode ParquetWriter::finalize() {
0275   for (const auto& state : m_impl->m_collectionStates) {
0276     std::lock_guard<std::mutex> guard(state->shardsMutex);
0277     for (auto& [shardId, shard] : state->shards) {
0278       std::lock_guard<std::mutex> sguard(shard->mutex);
0279       m_impl->flushBuffer(*shard);
0280       shard->writer.close();
0281     }
0282     state->shards.clear();
0283   }
0284   return ProcessCode::SUCCESS;
0285 }
0286 
0287 }  // namespace ActsExamples