File indexing completed on 2026-06-16 07:48:13
0001
0002
0003
0004
0005
0006
0007
0008
0009 #include "ActsPlugins/Arrow/ArrowUtil.hpp"
0010 #include "ActsPython/Utilities/WhiteBoardRegistry.hpp"
0011
0012 #include <memory>
0013 #include <stdexcept>
0014
0015 #include <arrow/c/abi.h>
0016 #include <pybind11/pybind11.h>
0017 #include <pybind11/stl.h>
0018
0019 namespace py = pybind11;
0020 using ActsPlugins::ArrowUtil::ArrowSchemaHandle;
0021 using ActsPlugins::ArrowUtil::ArrowTable;
0022 using ActsPython::WhiteBoardRegistry;
0023 namespace ArrowUtil = ActsPlugins::ArrowUtil;
0024
0025 namespace {
0026
0027
0028
0029
0030
0031 void releaseArrowSchemaCapsule(PyObject* capsule) {
0032 auto* c =
0033 static_cast<ArrowSchema*>(PyCapsule_GetPointer(capsule, "arrow_schema"));
0034 if (c == nullptr) {
0035 PyErr_Clear();
0036 return;
0037 }
0038 if (c->release != nullptr) {
0039 c->release(c);
0040 }
0041 delete c;
0042 }
0043
0044 void releaseArrowArrayCapsule(PyObject* capsule) {
0045 auto* c =
0046 static_cast<ArrowArray*>(PyCapsule_GetPointer(capsule, "arrow_array"));
0047 if (c == nullptr) {
0048 PyErr_Clear();
0049 return;
0050 }
0051 if (c->release != nullptr) {
0052 c->release(c);
0053 }
0054 delete c;
0055 }
0056
0057 }
0058
0059 PYBIND11_MODULE(ActsPluginsPythonBindingsArrow, m) {
0060
0061
0062
0063
0064
0065
0066
0067
0068
0069
0070
0071 py::class_<ArrowSchemaHandle>(m, "ArrowSchema")
0072 .def("__repr__",
0073 [](const ArrowSchemaHandle& s) {
0074 return "<ArrowSchema " + s.toString() + ">";
0075 })
0076 .def("__str__", &ArrowSchemaHandle::toString)
0077 .def("field_names", &ArrowSchemaHandle::fieldNames)
0078 .def("__len__", &ArrowSchemaHandle::numFields)
0079
0080
0081 .def("__arrow_c_schema__", [](const ArrowSchemaHandle& self) {
0082 auto* c_schema = new ArrowSchema{};
0083 try {
0084 self.exportToC(c_schema);
0085 } catch (...) {
0086 delete c_schema;
0087 throw;
0088 }
0089 return py::reinterpret_steal<py::object>(
0090 PyCapsule_New(c_schema, "arrow_schema", releaseArrowSchemaCapsule));
0091 });
0092
0093
0094
0095
0096
0097
0098
0099
0100
0101
0102
0103 auto arrowTableClass =
0104 py::class_<ArrowTable, py::smart_holder>(m, "ArrowTable")
0105 .def(py::init<>())
0106 .def("__repr__",
0107 [](const ArrowTable& t) {
0108 return "<ArrowTable " + std::to_string(t.numRows()) +
0109 " rows x " + std::to_string(t.numColumns()) + " cols>";
0110 })
0111 .def("__str__", &ArrowTable::toString)
0112 .def_property_readonly("num_rows", &ArrowTable::numRows)
0113 .def_property_readonly("num_columns", &ArrowTable::numColumns)
0114 .def_property_readonly("schema", &ArrowTable::schema)
0115
0116
0117
0118
0119
0120 .def(
0121 "__arrow_c_array__",
0122 [](const ArrowTable& self,
0123 [[maybe_unused]] const py::object& requested_schema) {
0124 auto* c_schema = new ArrowSchema{};
0125 auto* c_array = new ArrowArray{};
0126 try {
0127 self.exportToC(c_schema, c_array);
0128 } catch (...) {
0129 delete c_schema;
0130 delete c_array;
0131 throw;
0132 }
0133 py::object schemaCap =
0134 py::reinterpret_steal<py::object>(PyCapsule_New(
0135 c_schema, "arrow_schema", releaseArrowSchemaCapsule));
0136 py::object arrayCap =
0137 py::reinterpret_steal<py::object>(PyCapsule_New(
0138 c_array, "arrow_array", releaseArrowArrayCapsule));
0139 return py::make_tuple(std::move(schemaCap),
0140 std::move(arrayCap));
0141 },
0142 py::arg("requested_schema") = py::none())
0143
0144
0145
0146 .def("as_table",
0147 [](py::object self) {
0148 auto pa = py::module_::import("pyarrow");
0149 auto batch = pa.attr("record_batch")(self);
0150 return pa.attr("Table").attr("from_batches")(
0151 py::make_tuple(batch));
0152 })
0153
0154
0155
0156
0157
0158
0159
0160 .def_static(
0161 "from_arrow",
0162 [](const py::object& obj) {
0163 if (py::hasattr(obj, "__arrow_c_array__")) {
0164 py::tuple capsules = obj.attr("__arrow_c_array__")();
0165 if (capsules.size() != 2) {
0166 throw py::type_error(
0167 "__arrow_c_array__ returned a tuple of size " +
0168 std::to_string(capsules.size()) + ", expected 2");
0169 }
0170 auto* sc = static_cast<ArrowSchema*>(
0171 PyCapsule_GetPointer(capsules[0].ptr(), "arrow_schema"));
0172 auto* ar = static_cast<ArrowArray*>(
0173 PyCapsule_GetPointer(capsules[1].ptr(), "arrow_array"));
0174 if (sc == nullptr || ar == nullptr) {
0175 throw py::value_error(
0176 "__arrow_c_array__ returned invalid PyCapsules");
0177 }
0178
0179
0180
0181 return ArrowTable::importFromC(sc, ar);
0182 }
0183 if (py::hasattr(obj, "__arrow_c_stream__")) {
0184
0185
0186
0187
0188 auto pa = py::module_::import("pyarrow");
0189 auto pa_table = pa.attr("table")(obj);
0190 auto combined = pa_table.attr("combine_chunks")();
0191
0192
0193
0194
0195 py::list batches = combined.attr("to_batches")();
0196 if (batches.size() != 1) {
0197 throw py::value_error(
0198 "expected 1 batch after combine_chunks, got " +
0199 std::to_string(batches.size()));
0200 }
0201 py::tuple capsules = batches[0].attr("__arrow_c_array__")();
0202 auto* sc = static_cast<ArrowSchema*>(
0203 PyCapsule_GetPointer(capsules[0].ptr(), "arrow_schema"));
0204 auto* ar = static_cast<ArrowArray*>(
0205 PyCapsule_GetPointer(capsules[1].ptr(), "arrow_array"));
0206 if (sc == nullptr || ar == nullptr) {
0207 throw py::value_error(
0208 "__arrow_c_array__ returned invalid PyCapsules");
0209 }
0210 return ArrowTable::importFromC(sc, ar);
0211 }
0212 throw py::type_error(
0213 "ArrowTable.from_arrow: object does not implement the "
0214 "Arrow C Data Interface (__arrow_c_array__ or "
0215 "__arrow_c_stream__)");
0216 },
0217 py::arg("obj"),
0218 "Build an ArrowTable from any object implementing the Arrow "
0219 "C Data Interface (pyarrow Table/RecordBatch, polars "
0220 "DataFrame, duckdb relation, etc.). Zero-copy via the "
0221 "release-callback wiring; the producer's buffers stay alive "
0222 "until the resulting ArrowTable is destroyed.");
0223
0224 WhiteBoardRegistry::registerClass(arrowTableClass);
0225
0226 auto wrap = [](std::shared_ptr<arrow::Schema> s) {
0227 return ArrowSchemaHandle{std::move(s)};
0228 };
0229 m.def(
0230 "particleSchema", [wrap]() { return wrap(ArrowUtil::particleSchema()); },
0231 "Schema produced by ArrowParticleOutputConverter.");
0232 m.def(
0233 "trackSchema", [wrap]() { return wrap(ArrowUtil::trackSchema()); },
0234 "Schema produced by ArrowTrackOutputConverter.");
0235 m.def(
0236 "simHitSchema", [wrap]() { return wrap(ArrowUtil::simHitSchema()); },
0237 "Schema produced by ArrowSimHitOutputConverter.");
0238 }