File indexing completed on 2025-08-28 08:12:25
0001
0002
0003
0004
0005
0006
0007
0008
0009 #include "ActsExamples/Io/Csv/CsvGnnGraphWriter.hpp"
0010
0011 #include "Acts/Definitions/Algebra.hpp"
0012 #include "Acts/Definitions/Common.hpp"
0013 #include "Acts/Definitions/Units.hpp"
0014 #include "ActsExamples/Framework/AlgorithmContext.hpp"
0015 #include "ActsExamples/Io/Csv/CsvInputOutput.hpp"
0016 #include "ActsExamples/Utilities/Paths.hpp"
0017 #include "ActsFatras/EventData/Barcode.hpp"
0018
0019 #include <stdexcept>
0020 #include <vector>
0021
0022 #include "CsvOutputData.hpp"
0023
0024 ActsExamples::CsvGnnGraphWriter::CsvGnnGraphWriter(
0025 const ActsExamples::CsvGnnGraphWriter::Config& config,
0026 Acts::Logging::Level level)
0027 : WriterT(config.inputGraph, "CsvGnnGraphWriter", level), m_cfg(config) {}
0028
0029 ActsExamples::ProcessCode ActsExamples::CsvGnnGraphWriter::writeT(
0030 const ActsExamples::AlgorithmContext& ctx, const Graph& graph) {
0031 assert(graph.weights.empty() ||
0032 (graph.edges.size() / 2 == graph.weights.size()));
0033 assert(graph.edges.size() % 2 == 0);
0034
0035 if (graph.weights.empty()) {
0036 ACTS_DEBUG("No weights provide, write default value of 1");
0037 }
0038
0039 std::string path = perEventFilepath(
0040 m_cfg.outputDir, m_cfg.outputStem + ".csv", ctx.eventNumber);
0041
0042 ActsExamples::NamedTupleCsvWriter<GraphData> writer(path);
0043
0044 const auto nEdges = graph.edges.size() / 2;
0045 for (auto i = 0ul; i < nEdges; ++i) {
0046 GraphData edge{};
0047 edge.edge0 = graph.edges[2 * i];
0048 edge.edge1 = graph.edges[2 * i + 1];
0049 edge.weight = graph.weights.empty() ? 1.f : graph.weights[i];
0050 writer.append(edge);
0051 }
0052
0053 return ActsExamples::ProcessCode::SUCCESS;
0054 }