File indexing completed on 2025-01-18 10:17:51
0001
0002
0003
0004
0005
0006
0007
0008
0009
0010 #include <pybind11/stl.h>
0011
0012 #include "constructor_stats.h"
0013 #include "pybind11_tests.h"
0014
0015 TEST_SUBMODULE(buffers, m) {
0016
0017 class Matrix {
0018 public:
0019 Matrix(py::ssize_t rows, py::ssize_t cols) : m_rows(rows), m_cols(cols) {
0020 print_created(this, std::to_string(m_rows) + "x" + std::to_string(m_cols) + " matrix");
0021
0022 m_data = new float[(size_t) (rows * cols)];
0023 memset(m_data, 0, sizeof(float) * (size_t) (rows * cols));
0024 }
0025
0026 Matrix(const Matrix &s) : m_rows(s.m_rows), m_cols(s.m_cols) {
0027 print_copy_created(this,
0028 std::to_string(m_rows) + "x" + std::to_string(m_cols) + " matrix");
0029
0030 m_data = new float[(size_t) (m_rows * m_cols)];
0031 memcpy(m_data, s.m_data, sizeof(float) * (size_t) (m_rows * m_cols));
0032 }
0033
0034 Matrix(Matrix &&s) noexcept : m_rows(s.m_rows), m_cols(s.m_cols), m_data(s.m_data) {
0035 print_move_created(this);
0036 s.m_rows = 0;
0037 s.m_cols = 0;
0038 s.m_data = nullptr;
0039 }
0040
0041 ~Matrix() {
0042 print_destroyed(this,
0043 std::to_string(m_rows) + "x" + std::to_string(m_cols) + " matrix");
0044 delete[] m_data;
0045 }
0046
0047 Matrix &operator=(const Matrix &s) {
0048 if (this == &s) {
0049 return *this;
0050 }
0051 print_copy_assigned(this,
0052 std::to_string(m_rows) + "x" + std::to_string(m_cols) + " matrix");
0053 delete[] m_data;
0054 m_rows = s.m_rows;
0055 m_cols = s.m_cols;
0056 m_data = new float[(size_t) (m_rows * m_cols)];
0057 memcpy(m_data, s.m_data, sizeof(float) * (size_t) (m_rows * m_cols));
0058 return *this;
0059 }
0060
0061 Matrix &operator=(Matrix &&s) noexcept {
0062 print_move_assigned(this,
0063 std::to_string(m_rows) + "x" + std::to_string(m_cols) + " matrix");
0064 if (&s != this) {
0065 delete[] m_data;
0066 m_rows = s.m_rows;
0067 m_cols = s.m_cols;
0068 m_data = s.m_data;
0069 s.m_rows = 0;
0070 s.m_cols = 0;
0071 s.m_data = nullptr;
0072 }
0073 return *this;
0074 }
0075
0076 float operator()(py::ssize_t i, py::ssize_t j) const {
0077 return m_data[(size_t) (i * m_cols + j)];
0078 }
0079
0080 float &operator()(py::ssize_t i, py::ssize_t j) {
0081 return m_data[(size_t) (i * m_cols + j)];
0082 }
0083
0084 float *data() { return m_data; }
0085
0086 py::ssize_t rows() const { return m_rows; }
0087 py::ssize_t cols() const { return m_cols; }
0088
0089 private:
0090 py::ssize_t m_rows;
0091 py::ssize_t m_cols;
0092 float *m_data;
0093 };
0094 py::class_<Matrix>(m, "Matrix", py::buffer_protocol())
0095 .def(py::init<py::ssize_t, py::ssize_t>())
0096
0097 .def(py::init([](const py::buffer &b) {
0098 py::buffer_info info = b.request();
0099 if (info.format != py::format_descriptor<float>::format() || info.ndim != 2) {
0100 throw std::runtime_error("Incompatible buffer format!");
0101 }
0102
0103 auto *v = new Matrix(info.shape[0], info.shape[1]);
0104 memcpy(v->data(), info.ptr, sizeof(float) * (size_t) (v->rows() * v->cols()));
0105 return v;
0106 }))
0107
0108 .def("rows", &Matrix::rows)
0109 .def("cols", &Matrix::cols)
0110
0111
0112 .def("__getitem__",
0113 [](const Matrix &m, std::pair<py::ssize_t, py::ssize_t> i) {
0114 if (i.first >= m.rows() || i.second >= m.cols()) {
0115 throw py::index_error();
0116 }
0117 return m(i.first, i.second);
0118 })
0119 .def("__setitem__",
0120 [](Matrix &m, std::pair<py::ssize_t, py::ssize_t> i, float v) {
0121 if (i.first >= m.rows() || i.second >= m.cols()) {
0122 throw py::index_error();
0123 }
0124 m(i.first, i.second) = v;
0125 })
0126
0127 .def_buffer([](Matrix &m) -> py::buffer_info {
0128 return py::buffer_info(
0129 m.data(),
0130 {m.rows(), m.cols()},
0131 {sizeof(float) * size_t(m.cols()),
0132 sizeof(float)});
0133 });
0134
0135
0136 class SquareMatrix : public Matrix {
0137 public:
0138 explicit SquareMatrix(py::ssize_t n) : Matrix(n, n) {}
0139 };
0140
0141 py::class_<SquareMatrix, Matrix>(m, "SquareMatrix").def(py::init<py::ssize_t>());
0142
0143
0144
0145
0146 struct Buffer {
0147 int32_t value = 0;
0148
0149 py::buffer_info get_buffer_info() {
0150 return py::buffer_info(
0151 &value, sizeof(value), py::format_descriptor<int32_t>::format(), 1);
0152 }
0153 };
0154 py::class_<Buffer>(m, "Buffer", py::buffer_protocol())
0155 .def(py::init<>())
0156 .def_readwrite("value", &Buffer::value)
0157 .def_buffer(&Buffer::get_buffer_info);
0158
0159 class ConstBuffer {
0160 std::unique_ptr<int32_t> value;
0161
0162 public:
0163 int32_t get_value() const { return *value; }
0164 void set_value(int32_t v) { *value = v; }
0165
0166 py::buffer_info get_buffer_info() const {
0167 return py::buffer_info(
0168 value.get(), sizeof(*value), py::format_descriptor<int32_t>::format(), 1);
0169 }
0170
0171 ConstBuffer() : value(new int32_t{0}) {}
0172 };
0173 py::class_<ConstBuffer>(m, "ConstBuffer", py::buffer_protocol())
0174 .def(py::init<>())
0175 .def_property("value", &ConstBuffer::get_value, &ConstBuffer::set_value)
0176 .def_buffer(&ConstBuffer::get_buffer_info);
0177
0178 struct DerivedBuffer : public Buffer {};
0179 py::class_<DerivedBuffer>(m, "DerivedBuffer", py::buffer_protocol())
0180 .def(py::init<>())
0181 .def_readwrite("value", (int32_t DerivedBuffer::*) &DerivedBuffer::value)
0182 .def_buffer(&DerivedBuffer::get_buffer_info);
0183
0184 struct BufferReadOnly {
0185 const uint8_t value = 0;
0186 explicit BufferReadOnly(uint8_t value) : value(value) {}
0187
0188 py::buffer_info get_buffer_info() { return py::buffer_info(&value, 1); }
0189 };
0190 py::class_<BufferReadOnly>(m, "BufferReadOnly", py::buffer_protocol())
0191 .def(py::init<uint8_t>())
0192 .def_buffer(&BufferReadOnly::get_buffer_info);
0193
0194 struct BufferReadOnlySelect {
0195 uint8_t value = 0;
0196 bool readonly = false;
0197
0198 py::buffer_info get_buffer_info() { return py::buffer_info(&value, 1, readonly); }
0199 };
0200 py::class_<BufferReadOnlySelect>(m, "BufferReadOnlySelect", py::buffer_protocol())
0201 .def(py::init<>())
0202 .def_readwrite("value", &BufferReadOnlySelect::value)
0203 .def_readwrite("readonly", &BufferReadOnlySelect::readonly)
0204 .def_buffer(&BufferReadOnlySelect::get_buffer_info);
0205
0206
0207 py::class_<py::buffer_info>(m, "buffer_info")
0208 .def(py::init<>())
0209 .def_readonly("itemsize", &py::buffer_info::itemsize)
0210 .def_readonly("size", &py::buffer_info::size)
0211 .def_readonly("format", &py::buffer_info::format)
0212 .def_readonly("ndim", &py::buffer_info::ndim)
0213 .def_readonly("shape", &py::buffer_info::shape)
0214 .def_readonly("strides", &py::buffer_info::strides)
0215 .def_readonly("readonly", &py::buffer_info::readonly)
0216 .def("__repr__", [](py::handle self) {
0217 return py::str("itemsize={0.itemsize!r}, size={0.size!r}, format={0.format!r}, "
0218 "ndim={0.ndim!r}, shape={0.shape!r}, strides={0.strides!r}, "
0219 "readonly={0.readonly!r}")
0220 .format(self);
0221 });
0222
0223 m.def("get_buffer_info", [](const py::buffer &buffer) { return buffer.request(); });
0224 }