File indexing completed on 2026-06-20 07:36:36
0001
0002
0003
0004
0005
0006
0007
0008
0009 #include "ActsExamples/Io/Parquet/ParquetWriter.hpp"
0010
0011 #include "Acts/Utilities/Logger.hpp"
0012 #include "ActsExamples/Framework/DataHandle.hpp"
0013 #include "ActsPlugins/Arrow/ArrowUtil.hpp"
0014
0015 #include <cstdint>
0016 #include <format>
0017 #include <map>
0018 #include <memory>
0019 #include <mutex>
0020 #include <stdexcept>
0021 #include <unordered_map>
0022 #include <unordered_set>
0023 #include <vector>
0024
0025 #include <arrow/api.h>
0026
0027 namespace ActsExamples {
0028
/// Private implementation of ParquetWriter.
///
/// Holds the validated copy of the configuration and, for every configured
/// collection, the output directory, the expected Arrow schema, the input data
/// handle and the map of shards (output files) that are currently open.
class ParquetWriter::Impl {
 public:
  /// Data handle through which the per-event Arrow table of one collection is
  /// obtained in ParquetWriter::write().
  using TableHandle = ConsumeDataHandle<ActsPlugins::ArrowUtil::ArrowTable>;

  /// State of one output file (shard).
  struct ShardState {
    /// Full path of the shard file.
    std::filesystem::path path;
    /// Guards `writer`, `buffer` and `eventsAccepted`; the code in this file
    /// only touches those members while holding it.
    std::mutex mutex;
    /// Writer bound to `path`.
    ActsPlugins::ArrowUtil::ParquetFileWriter writer;
    /// Per-event tables accepted but not yet handed to `writer`.
    std::vector<std::shared_ptr<arrow::Table>> buffer;
    /// Number of events accepted into this shard so far.
    std::uint64_t eventsAccepted = 0;

    /// Opens a shard at @p p.
    /// `writer` is initialised from the member `path`, which is valid because
    /// `path` is declared (and therefore initialised) before `writer`.
    explicit ShardState(std::filesystem::path p)
        : path(std::move(p)), writer(path) {}
  };

  /// Everything the writer tracks for one configured collection.
  struct CollectionState {
    /// Collection name; also used as the data-handle key.
    std::string name;
    /// Resolved output directory of the collection.
    std::filesystem::path directory;
    /// Prefix of every shard file name, derived from the directory name.
    std::string filePrefix;
    /// Schema every per-event table must match (without the event-id column).
    std::shared_ptr<arrow::Schema> expectedSchema;
    /// Input handle for the collection's per-event table.
    std::unique_ptr<TableHandle> handle;
    /// Guards `shards`.
    std::mutex shardsMutex;
    /// Currently open shards, keyed by shard id.
    std::map<std::uint64_t, std::unique_ptr<ShardState>> shards;
  };

  /// Validates @p config and prepares the per-collection state.
  ///
  /// A zero `eventsPerRowGroup` is replaced by `eventsPerShard` in the stored
  /// copy of the configuration. For every collection the output directory is
  /// resolved (relative paths are taken relative to `outputDir`), created on
  /// disk, and a data handle named after the collection is registered on
  /// @p parent.
  ///
  /// @param config Writer configuration; copied.
  /// @param parent Owning writer, passed to the data handles.
  /// @throws std::invalid_argument if no collection is configured, a size
  ///         parameter is zero or inconsistent, a collection name or directory
  ///         is empty, two collections resolve to the same directory, no file
  ///         prefix can be derived from a directory, a collection lacks an
  ///         expected schema, an expected schema already contains the
  ///         event-id column, or an expected schema names an unknown
  ///         collection.
  Impl(const ParquetWriter::Config& config, ParquetWriter& parent)
      : m_cfg(config) {
    if (m_cfg.collections.empty()) {
      throw std::invalid_argument("ParquetWriter: no collections configured");
    }
    if (m_cfg.eventsPerShard == 0) {
      throw std::invalid_argument("ParquetWriter: eventsPerShard must be > 0");
    }
    if (m_cfg.maxOpenShards == 0) {
      throw std::invalid_argument("ParquetWriter: maxOpenShards must be > 0");
    }

    // Zero means "not set": fall back to the shard size.
    if (m_cfg.eventsPerRowGroup == 0) {
      m_cfg.eventsPerRowGroup = m_cfg.eventsPerShard;
    }
    if (m_cfg.eventsPerRowGroup > m_cfg.eventsPerShard) {
      throw std::invalid_argument(std::format(
          "ParquetWriter: eventsPerRowGroup ({}) must be <= eventsPerShard "
          "({})",
          m_cfg.eventsPerRowGroup, m_cfg.eventsPerShard));
    }

    std::unordered_set<std::string> seenDirs;
    for (const auto& [name, rawPath] : m_cfg.collections) {
      if (name.empty()) {
        throw std::invalid_argument("ParquetWriter: empty collection name");
      }
      if (rawPath.empty()) {
        throw std::invalid_argument(std::format(
            "ParquetWriter: empty output directory for collection '{}'", name));
      }

      std::filesystem::path resolved =
          rawPath.is_absolute() ? rawPath : m_cfg.outputDir / rawPath;

      // lexically_normal() keeps a trailing separator, so "out/hits" and
      // "out/hits/" would otherwise be treated as different directories even
      // though both collections would write the same shard files.
      const std::string dirKey =
          stripTrailingSeparators(resolved.lexically_normal()).string();
      if (!seenDirs.insert(dirKey).second) {
        throw std::invalid_argument(
            std::format("ParquetWriter: duplicate output directory '{}'",
                        resolved.string()));
      }

      // The shard file prefix is the stem of the last real path component,
      // e.g. ".../hits/" -> "hits".
      std::string filePrefix =
          stripTrailingSeparators(resolved).filename().stem().string();
      if (filePrefix.empty()) {
        throw std::invalid_argument(std::format(
            "ParquetWriter: output directory '{}' for collection '{}' has "
            "no usable filename stem to derive a shard prefix from",
            resolved.string(), name));
      }

      auto schemaIt = m_cfg.expectedSchemas.find(name);
      if (schemaIt == m_cfg.expectedSchemas.end() || !schemaIt->second) {
        throw std::invalid_argument(std::format(
            "ParquetWriter: collection '{}' has no expected schema. Every "
            "configured collection must declare an expected schema; the "
            "writer validates each per-event table against it before "
            "stamping event_id and serialising.",
            name));
      }
      auto expected = schemaIt->second.schema();
      if (expected->GetFieldIndex(
              std::string{ActsPlugins::ArrowUtil::kEventIdColumn}) >= 0) {
        throw std::invalid_argument(std::format(
            "ParquetWriter: expected schema for '{}' must not contain "
            "event_id; the writer prepends it.",
            name));
      }

      std::filesystem::create_directories(resolved);

      auto state = std::make_unique<CollectionState>();
      state->name = name;
      state->directory = std::move(resolved);
      state->filePrefix = std::move(filePrefix);
      state->expectedSchema = std::move(expected);
      state->handle = std::make_unique<TableHandle>(&parent, name);
      state->handle->initialize(name);
      m_collectionStates.push_back(std::move(state));
    }

    // Reject schemas that refer to a collection that is not configured.
    for (const auto& [name, _] : m_cfg.expectedSchemas) {
      if (!m_cfg.collections.contains(name)) {
        throw std::invalid_argument(std::format(
            "ParquetWriter: expectedSchemas has entry for '{}' but no "
            "matching collection",
            name));
      }
    }
  }

  /// Returns @p p without trailing directory separators.
  ///
  /// A pure root path ("/", "C:\\") is returned unchanged: its parent_path()
  /// is the root itself, so stripping must stop there instead of looping
  /// forever.
  static std::filesystem::path stripTrailingSeparators(
      std::filesystem::path p) {
    while (p.has_relative_path() && p.filename().empty()) {
      p = p.parent_path();
    }
    return p;
  }

  /// Writes all buffered tables of @p shard as one concatenated table and
  /// clears the buffer. Does nothing when the buffer is empty.
  ///
  /// The callers in this file invoke it while holding `shard.mutex`.
  ///
  /// @throws std::runtime_error if the tables cannot be concatenated.
  void flushBuffer(ShardState& shard) {
    if (shard.buffer.empty()) {
      return;
    }
    auto result = arrow::ConcatenateTables(shard.buffer);
    if (!result.ok()) {
      throw std::runtime_error(
          std::format("ParquetWriter concat for shard '{}': {}",
                      shard.path.string(), result.status().ToString()));
    }
    shard.writer.write(*result.ValueOrDie());
    shard.buffer.clear();
  }

  /// Builds the file path of shard @p shardId for the collection @p state.
  ///
  /// The file name is `<prefix>_<start>-<end>.parquet`, both numbers
  /// zero-padded to at least six digits. `<start>` is the first event number
  /// of the shard and `<end>` is `<start> + eventsPerShard`, i.e. the first
  /// event number of the following shard.
  std::filesystem::path shardPath(const CollectionState& state,
                                  std::uint64_t shardId) const {
    const std::uint64_t startEvent = shardId * m_cfg.eventsPerShard;
    const std::uint64_t endEvent = startEvent + m_cfg.eventsPerShard;
    return state.directory / std::format("{}_{:06}-{:06}.parquet",
                                         state.filePrefix, startEvent,
                                         endEvent);
  }

  /// Validated configuration (with `eventsPerRowGroup` defaulted).
  ParquetWriter::Config m_cfg;
  /// One entry per configured collection, in iteration order of
  /// `m_cfg.collections`.
  std::vector<std::unique_ptr<CollectionState>> m_collectionStates;
};
0182
0183 ParquetWriter::ParquetWriter(const Config& config,
0184 std::unique_ptr<const Acts::Logger> logger)
0185 : m_logger(std::move(logger)),
0186 m_impl(std::make_unique<Impl>(config, *this)) {}
0187
0188 ParquetWriter::ParquetWriter(const Config& config, Acts::Logging::Level level)
0189 : ParquetWriter(config, Acts::getDefaultLogger("ParquetWriter", level)) {}
0190
/// Defaulted destructor, defined in this translation unit where Impl is a
/// complete type.
// NOTE(review): presumably required because m_impl is a smart pointer to the
// forward-declared Impl — confirm against the header.
ParquetWriter::~ParquetWriter() = default;
0192
0193 std::string ParquetWriter::name() const {
0194 return "ParquetWriter";
0195 }
0196
0197 const ParquetWriter::Config& ParquetWriter::config() const {
0198 return m_impl->m_cfg;
0199 }
0200
/// Buffers the current event's table of every configured collection into the
/// shard that owns the event, flushing and closing shards as they fill up.
///
/// For each collection the table is validated against the expected schema,
/// stamped with the event number, and appended to the shard's buffer. The
/// buffer is flushed once it holds `eventsPerRowGroup` tables; when the shard
/// has accepted `eventsPerShard` events it is flushed, closed and removed
/// from the collection's map of open shards.
///
/// Locking: the collection's `shardsMutex` is held only while looking up,
/// creating or erasing a shard; the shard's own mutex is held while its
/// buffer and writer are touched. The two are never held at the same time
/// here.
///
/// @param ctx Context of the event being written.
/// @return ProcessCode::ABORT on a null table or schema mismatch,
///         ProcessCode::SUCCESS otherwise.
/// @throws std::runtime_error if a new shard would exceed `maxOpenShards`
///         (and whatever the flush/write helpers throw).
ProcessCode ParquetWriter::write(const AlgorithmContext& ctx) {
  const auto& cfg = m_impl->m_cfg;

  // Integer division maps each run of eventsPerShard consecutive event
  // numbers onto one shard id.
  const std::uint64_t shardId = ctx.eventNumber / cfg.eventsPerShard;

  for (const auto& collection : m_impl->m_collectionStates) {
    auto input = (*collection->handle)(ctx);
    if (!input) {
      ACTS_ERROR("ParquetWriter: null table for collection "
                 << collection->name);
      return ProcessCode::ABORT;
    }

    const auto& actualSchema = *input.table()->schema();
    const bool schemaMatches =
        actualSchema.Equals(*collection->expectedSchema, false);
    if (!schemaMatches) {
      ACTS_ERROR("ParquetWriter: schema mismatch for collection '"
                 << collection->name << "' at event " << ctx.eventNumber
                 << ".\n expected: " << collection->expectedSchema->ToString()
                 << "\n actual: " << actualSchema.ToString());
      return ProcessCode::ABORT;
    }

    // Stamp the table with the event number before touching any shard state.
    auto stamped = ActsPlugins::ArrowUtil::withEventId(
        input.table(), static_cast<std::uint64_t>(ctx.eventNumber));

    // Look up the shard for this event, opening it on first use.
    Impl::ShardState* shard = nullptr;
    {
      std::lock_guard<std::mutex> mapLock(collection->shardsMutex);
      auto& openShards = collection->shards;
      if (const auto found = openShards.find(shardId);
          found != openShards.end()) {
        shard = found->second.get();
      } else {
        if (openShards.size() >= cfg.maxOpenShards) {
          // Build a readable list of the open shard ids for the message.
          std::string openIds;
          for (const auto& entry : openShards) {
            if (!openIds.empty()) {
              openIds += ", ";
            }
            openIds += std::to_string(entry.first);
          }
          throw std::runtime_error(std::format(
              "ParquetWriter: collection '{}' would exceed maxOpenShards={} "
              "(currently open: [{}], requested shard: {}). This usually "
              "means a worker thread is far behind the event-id frontier.",
              collection->name, cfg.maxOpenShards, openIds, shardId));
        }
        auto fresh = std::make_unique<Impl::ShardState>(
            m_impl->shardPath(*collection, shardId));
        shard = fresh.get();
        openShards.emplace(shardId, std::move(fresh));
      }
    }

    // Append under the shard's own lock; report whether the shard is now
    // complete so it can be dropped from the map afterwards.
    const bool shardComplete = [&] {
      std::lock_guard<std::mutex> shardLock(shard->mutex);
      shard->buffer.push_back(std::move(stamped));
      ++shard->eventsAccepted;
      const bool complete = shard->eventsAccepted >= cfg.eventsPerShard;
      if (complete || shard->buffer.size() >= cfg.eventsPerRowGroup) {
        m_impl->flushBuffer(*shard);
      }
      if (complete) {
        shard->writer.close();
      }
      return complete;
    }();

    if (shardComplete) {
      std::lock_guard<std::mutex> mapLock(collection->shardsMutex);
      collection->shards.erase(shardId);
    }
  }

  return ProcessCode::SUCCESS;
}
0273
0274 ProcessCode ParquetWriter::finalize() {
0275 for (const auto& state : m_impl->m_collectionStates) {
0276 std::lock_guard<std::mutex> guard(state->shardsMutex);
0277 for (auto& [shardId, shard] : state->shards) {
0278 std::lock_guard<std::mutex> sguard(shard->mutex);
0279 m_impl->flushBuffer(*shard);
0280 shard->writer.close();
0281 }
0282 state->shards.clear();
0283 }
0284 return ProcessCode::SUCCESS;
0285 }
0286
0287 }