File indexing completed on 2025-01-18 10:17:55
0001
0002
0003
0004
0005
0006
0007
0008
0009
0010 #include <pybind11/numpy.h>
0011 #include <pybind11/stl.h>
0012
0013 #include "pybind11_tests.h"
0014
0015 #include <cstdint>
0016 #include <utility>
0017
0018
// Pairs two dtypes for one C++ type: `numpy` is the dtype numpy reports for a
// type name, `pybind11` is the dtype pybind11 derives from the C++ type.
// Both fields are exposed read-only to Python by the DtypeCheck binding in
// TEST_SUBMODULE below.
struct DtypeCheck {
    py::dtype numpy{};
    py::dtype pybind11{};
};
0023
0024 template <typename T>
0025 DtypeCheck get_dtype_check(const char *name) {
0026 py::module_ np = py::module_::import("numpy");
0027 DtypeCheck check{};
0028 check.numpy = np.attr("dtype")(np.attr(name));
0029 check.pybind11 = py::dtype::of<T>();
0030 return check;
0031 }
0032
0033 std::vector<DtypeCheck> get_concrete_dtype_checks() {
0034 return {
0035 get_dtype_check<std::int8_t>("int8"),
0036 get_dtype_check<std::uint8_t>("uint8"),
0037 get_dtype_check<std::int16_t>("int16"),
0038 get_dtype_check<std::uint16_t>("uint16"),
0039 get_dtype_check<std::int32_t>("int32"),
0040 get_dtype_check<std::uint32_t>("uint32"),
0041 get_dtype_check<std::int64_t>("int64"),
0042 get_dtype_check<std::uint64_t>("uint64")};
0043 }
0044
// Size comparison for one C++ type T:
//   name       - pybind11's spelling of the type (py::type_id<T>()).
//   size_cpp   - sizeof(T).
//   size_numpy - numpy's itemsize for py::dtype::of<T>().
//   dtype      - the dtype itself. It is used only by the __repr__ binding and
//                is not exposed as its own attribute.
struct DtypeSizeCheck {
    std::string name{};
    int size_cpp{};
    int size_numpy{};

    py::dtype dtype{};
};
0052
0053 template <typename T>
0054 DtypeSizeCheck get_dtype_size_check() {
0055 DtypeSizeCheck check{};
0056 check.name = py::type_id<T>();
0057 check.size_cpp = sizeof(T);
0058 check.dtype = py::dtype::of<T>();
0059 check.size_numpy = check.dtype.attr("itemsize").template cast<int>();
0060 return check;
0061 }
0062
0063 std::vector<DtypeSizeCheck> get_platform_dtype_size_checks() {
0064 return {
0065 get_dtype_size_check<short>(),
0066 get_dtype_size_check<unsigned short>(),
0067 get_dtype_size_check<int>(),
0068 get_dtype_size_check<unsigned int>(),
0069 get_dtype_size_check<long>(),
0070 get_dtype_size_check<unsigned long>(),
0071 get_dtype_size_check<long long>(),
0072 get_dtype_size_check<unsigned long long>(),
0073 };
0074 }
0075
0076
0077 using arr = py::array;
0078 using arr_t = py::array_t<uint16_t, 0>;
0079 static_assert(std::is_same<arr_t::value_type, uint16_t>::value, "");
0080
0081 template <typename... Ix>
0082 arr data(const arr &a, Ix... index) {
0083 return arr(a.nbytes() - a.offset_at(index...), (const uint8_t *) a.data(index...));
0084 }
0085
0086 template <typename... Ix>
0087 arr data_t(const arr_t &a, Ix... index) {
0088 return arr(a.size() - a.index_at(index...), a.data(index...));
0089 }
0090
0091 template <typename... Ix>
0092 arr &mutate_data(arr &a, Ix... index) {
0093 auto *ptr = (uint8_t *) a.mutable_data(index...);
0094 for (py::ssize_t i = 0; i < a.nbytes() - a.offset_at(index...); i++) {
0095 ptr[i] = (uint8_t) (ptr[i] * 2);
0096 }
0097 return a;
0098 }
0099
0100 template <typename... Ix>
0101 arr_t &mutate_data_t(arr_t &a, Ix... index) {
0102 auto ptr = a.mutable_data(index...);
0103 for (py::ssize_t i = 0; i < a.size() - a.index_at(index...); i++) {
0104 ptr[i]++;
0105 }
0106 return a;
0107 }
0108
0109 template <typename... Ix>
0110 py::ssize_t index_at(const arr &a, Ix... idx) {
0111 return a.index_at(idx...);
0112 }
0113 template <typename... Ix>
0114 py::ssize_t index_at_t(const arr_t &a, Ix... idx) {
0115 return a.index_at(idx...);
0116 }
0117 template <typename... Ix>
0118 py::ssize_t offset_at(const arr &a, Ix... idx) {
0119 return a.offset_at(idx...);
0120 }
0121 template <typename... Ix>
0122 py::ssize_t offset_at_t(const arr_t &a, Ix... idx) {
0123 return a.offset_at(idx...);
0124 }
0125 template <typename... Ix>
0126 py::ssize_t at_t(const arr_t &a, Ix... idx) {
0127 return a.at(idx...);
0128 }
0129 template <typename... Ix>
0130 arr_t &mutate_at_t(arr_t &a, Ix... idx) {
0131 a.mutable_at(idx...)++;
0132 return a;
0133 }
0134
// Registers four overloads of the helper `fn` on the submodule `sm`, which must be
// in scope where the macro is expanded. The overloads take an `arg_type` array
// followed by zero, one, two or three int indices. The expansion is a sequence of
// statements ending in ';', so the macro is only usable at statement level.
#define def_index_fn(fn, arg_type)                                                          \
    sm.def(#fn, [](arg_type array_arg) { return fn(array_arg); });                          \
    sm.def(#fn, [](arg_type array_arg, int i0) { return fn(array_arg, i0); });              \
    sm.def(#fn, [](arg_type array_arg, int i0, int i1) { return fn(array_arg, i0, i1); });  \
    sm.def(#fn,                                                                             \
           [](arg_type array_arg, int i0, int i1, int i2) { return fn(array_arg, i0, i1, i2); });
0140
0141 template <typename T, typename T2>
0142 py::handle auxiliaries(T &&r, T2 &&r2) {
0143 if (r.ndim() != 2) {
0144 throw std::domain_error("error: ndim != 2");
0145 }
0146 py::list l;
0147 l.append(*r.data(0, 0));
0148 l.append(*r2.mutable_data(0, 0));
0149 l.append(r.data(0, 1) == r2.mutable_data(0, 1));
0150 l.append(r.ndim());
0151 l.append(r.itemsize());
0152 l.append(r.shape(0));
0153 l.append(r.shape(1));
0154 l.append(r.size());
0155 l.append(r.nbytes());
0156 return l.release();
0157 }
0158
0159
// Backing storage for the zero-dimensional "scalar_int" array binding.
static int data_i{42};
0161
// Defines the `numpy_array` test submodule. It binds the helpers above plus many
// small lambdas, each exercising one piece of the py::array / py::array_t /
// py::dtype API. Keep the registration order exactly as-is: several names below
// are registered more than once (overloads), and pybind11 tries overloads in the
// order they were registered.
TEST_SUBMODULE(numpy_array, sm) {
    // If numpy cannot be imported, return without registering anything
    // (presumably so the rest of the test suite can still load without numpy).
    try {
        py::module_::import("numpy");
    } catch (const py::error_already_set &) {
        return;
    }

    // --- dtype consistency: numpy.dtype(numpy.<name>) vs py::dtype::of<T>() ---
    py::class_<DtypeCheck>(sm, "DtypeCheck")
        .def_readonly("numpy", &DtypeCheck::numpy)
        .def_readonly("pybind11", &DtypeCheck::pybind11)
        .def("__repr__", [](const DtypeCheck &self) {
            return py::str("<DtypeCheck numpy={} pybind11={}>").format(self.numpy, self.pybind11);
        });
    sm.def("get_concrete_dtype_checks", &get_concrete_dtype_checks);

    // --- sizeof(T) vs numpy itemsize for the platform integer types ---
    py::class_<DtypeSizeCheck>(sm, "DtypeSizeCheck")
        .def_readonly("name", &DtypeSizeCheck::name)
        .def_readonly("size_cpp", &DtypeSizeCheck::size_cpp)
        .def_readonly("size_numpy", &DtypeSizeCheck::size_numpy)
        .def("__repr__", [](const DtypeSizeCheck &self) {
            return py::str("<DtypeSizeCheck name='{}' size_cpp={} size_numpy={} dtype={}>")
                .format(self.name, self.size_cpp, self.size_numpy, self.dtype);
        });
    sm.def("get_platform_dtype_size_checks", &get_platform_dtype_size_checks);

    // --- basic array properties ---
    // The one-argument shape/strides overloads return the whole vector as a 1-D
    // array of length ndim; the two-argument overloads return a single entry.
    sm.def("ndim", [](const arr &a) { return a.ndim(); });
    sm.def("shape", [](const arr &a) { return arr(a.ndim(), a.shape()); });
    sm.def("shape", [](const arr &a, py::ssize_t dim) { return a.shape(dim); });
    sm.def("strides", [](const arr &a) { return arr(a.ndim(), a.strides()); });
    sm.def("strides", [](const arr &a, py::ssize_t dim) { return a.strides(dim); });
    sm.def("writeable", [](const arr &a) { return a.writeable(); });
    sm.def("size", [](const arr &a) { return a.size(); });
    sm.def("itemsize", [](const arr &a) { return a.itemsize(); });
    sm.def("nbytes", [](const arr &a) { return a.nbytes(); });
    sm.def("owndata", [](const arr &a) { return a.owndata(); });

    // --- index/offset/data/mutation helpers ---
    // Each def_index_fn line registers four overloads (0 to 3 int indices).
    // The `_t` variants operate on the typed uint16_t array `arr_t`.
    def_index_fn(index_at, const arr &);
    def_index_fn(index_at_t, const arr_t &);
    def_index_fn(offset_at, const arr &);
    def_index_fn(offset_at_t, const arr_t &);

    def_index_fn(data, const arr &);
    def_index_fn(data_t, const arr_t &);

    def_index_fn(mutate_data, arr &);
    def_index_fn(mutate_data_t, arr_t &);
    def_index_fn(at_t, const arr_t &);
    def_index_fn(mutate_at_t, arr_t &);

    // 2x2 float arrays with explicit byte strides. With a 4-byte float, {4, 8}
    // makes axis 0 the fastest-varying (Fortran order) and {8, 4} gives C order.
    sm.def("make_f_array", [] { return py::array_t<float>({2, 2}, {4, 8}); });
    sm.def("make_c_array", [] { return py::array_t<float>({2, 2}, {8, 4}); });

    // Empty shape and strides: a zero-dimensional array of dtype "f".
    sm.def("make_empty_shaped_array", [] { return py::array(py::dtype("f"), {}, {}); });
    // Zero-dimensional "i" array constructed from the static data_i.
    sm.def("scalar_int", []() { return py::array(py::dtype("i"), {}, {}, &data_i); });

    // Builds a new array object over a's own buffer: the same dtype, shape,
    // strides and data pointer, with `a` passed as the base handle.
    sm.def("wrap", [](const py::array &a) {
        return py::array(a.dtype(),
                         {a.shape(), a.shape() + a.ndim()},
                         {a.strides(), a.strides() + a.ndim()},
                         a.data(),
                         a);
    });

    // ArrayClass prints from its constructor and destructor, presumably so the
    // Python test can observe its lifetime. numpy_view returns a 2-element int
    // array over the C++ `data` member, passing `obj` as the base handle.
    struct ArrayClass {
        int data[2] = {1, 2};
        ArrayClass() { py::print("ArrayClass()"); }
        ~ArrayClass() { py::print("~ArrayClass()"); }
    };
    py::class_<ArrayClass>(sm, "ArrayClass")
        .def(py::init<>())
        .def("numpy_view", [](py::object &obj) {
            py::print("ArrayClass::numpy_view()");
            auto &a = obj.cast<ArrayClass &>();
            return py::array_t<int>({2}, {4}, a.data, obj);
        });

    // Ignores its argument; only argument conversion to uint64_t is exercised.
    sm.def("function_taking_uint64", [](uint64_t) {});

    // py::isinstance checks: untyped py::array, and typed array_t<double> vs array_t<int>.
    sm.def("isinstance_untyped", [](py::object yes, py::object no) {
        return py::isinstance<py::array>(std::move(yes))
               && !py::isinstance<py::array>(std::move(no));
    });
    sm.def("isinstance_typed", [](const py::object &o) {
        return py::isinstance<py::array_t<double>>(o) && !py::isinstance<py::array_t<int>>(o);
    });

    // Default-constructed arrays, and arrays converted from an arbitrary object,
    // returned in a dict keyed by the array type.
    sm.def("default_constructors", []() {
        return py::dict("array"_a = py::array(),
                        "array_t<int32>"_a = py::array_t<std::int32_t>(),
                        "array_t<double>"_a = py::array_t<double>());
    });
    sm.def("converting_constructors", [](const py::object &o) {
        return py::dict("array"_a = py::array(o),
                        "array_t<int32>"_a = py::array_t<std::int32_t>(o),
                        "array_t<double>"_a = py::array_t<double>(o));
    });

    // --- overload resolution across array_t element types ---
    // Each overload returns the name of the element type it accepted.
    sm.def("overloaded", [](const py::array_t<double> &) { return "double"; });
    sm.def("overloaded", [](const py::array_t<float> &) { return "float"; });
    sm.def("overloaded", [](const py::array_t<int> &) { return "int"; });
    sm.def("overloaded", [](const py::array_t<unsigned short> &) { return "unsigned short"; });
    sm.def("overloaded", [](const py::array_t<long long> &) { return "long long"; });
    sm.def("overloaded",
           [](const py::array_t<std::complex<double>> &) { return "double complex"; });
    sm.def("overloaded", [](const py::array_t<std::complex<float>> &) { return "float complex"; });

    // Complex and real candidates interleaved, each complex overload before its real one.
    sm.def("overloaded2",
           [](const py::array_t<std::complex<double>> &) { return "double complex"; });
    sm.def("overloaded2", [](const py::array_t<double> &) { return "double"; });
    sm.def("overloaded2",
           [](const py::array_t<std::complex<float>> &) { return "float complex"; });
    sm.def("overloaded2", [](const py::array_t<float> &) { return "float"; });

    // int vs double with the argument marked noconvert().
    sm.def(
        "overloaded3", [](const py::array_t<int> &) { return "int"; }, py::arg{}.noconvert());
    sm.def(
        "overloaded3",
        [](const py::array_t<double> &) { return "double"; },
        py::arg{}.noconvert());

    // ExtraFlags = 0: array_t without the default flags (no forcecast).
    sm.def("overloaded4", [](const py::array_t<long long, 0> &) { return "long long"; });
    sm.def("overloaded4", [](const py::array_t<double, 0> &) { return "double"; });

    // unsigned int registered before double.
    sm.def("overloaded5", [](const py::array_t<unsigned int> &) { return "unsigned int"; });
    sm.def("overloaded5", [](const py::array_t<double> &) { return "double"; });

    // std::string, then py::array, then a catch-all py::object (presumably a
    // regression test for pybind11 issue 685).
    sm.def("issue685", [](const std::string &) { return "string"; });
    sm.def("issue685", [](const py::array &) { return "array"; });
    sm.def("issue685", [](const py::object &) { return "other"; });

    // --- unchecked proxies with a compile-time rank ---
    // Adds v to every element through a 2-D mutable proxy. The array argument is
    // noconvert(), presumably so the caller's own array is the one modified.
    sm.def(
        "proxy_add2",
        [](py::array_t<double> a, double v) {
            auto r = a.mutable_unchecked<2>();
            for (py::ssize_t i = 0; i < r.shape(0); i++) {
                for (py::ssize_t j = 0; j < r.shape(1); j++) {
                    r(i, j) += v;
                }
            }
        },
        py::arg{}.noconvert(),
        py::arg());

    // New 3x3x3 C-ordered array filled with start, start+1, ... with k (last axis)
    // as the innermost loop.
    sm.def("proxy_init3", [](double start) {
        py::array_t<double, py::array::c_style> a({3, 3, 3});
        auto r = a.mutable_unchecked<3>();
        for (py::ssize_t i = 0; i < r.shape(0); i++) {
            for (py::ssize_t j = 0; j < r.shape(1); j++) {
                for (py::ssize_t k = 0; k < r.shape(2); k++) {
                    r(i, j, k) = start++;
                }
            }
        }
        return a;
    });
    // Same for an F-ordered array, with i (first axis) as the innermost loop.
    sm.def("proxy_init3F", [](double start) {
        py::array_t<double, py::array::f_style> a({3, 3, 3});
        auto r = a.mutable_unchecked<3>();
        for (py::ssize_t k = 0; k < r.shape(2); k++) {
            for (py::ssize_t j = 0; j < r.shape(1); j++) {
                for (py::ssize_t i = 0; i < r.shape(0); i++) {
                    r(i, j, k) = start++;
                }
            }
        }
        return a;
    });
    // Sum of squares over a 1-D proxy, reading each element through both operator[] and operator().
    sm.def("proxy_squared_L2_norm", [](const py::array_t<double> &a) {
        auto r = a.unchecked<1>();
        double sumsq = 0;
        for (py::ssize_t i = 0; i < r.shape(0); i++) {
            sumsq += r[i] * r(i);
        }
        return sumsq;
    });

    // auxiliaries() on a const and a mutable 2-D proxy of the same array.
    sm.def("proxy_auxiliaries2", [](py::array_t<double> a) {
        auto r = a.unchecked<2>();
        auto r2 = a.mutable_unchecked<2>();
        return auxiliaries(r, r2);
    });

    // Element access through proxies bound to const references.
    sm.def("proxy_auxiliaries1_const_ref", [](py::array_t<double> a) {
        const auto &r = a.unchecked<1>();
        const auto &r2 = a.mutable_unchecked<1>();
        return r(0) == r2(0) && r[0] == r2[0];
    });

    sm.def("proxy_auxiliaries2_const_ref", [](py::array_t<double> a) {
        const auto &r = a.unchecked<2>();
        const auto &r2 = a.mutable_unchecked<2>();
        return r(0, 0) == r2(0, 0);
    });

    // --- unchecked proxies with a runtime rank (no template argument) ---
    // Same as the fixed-rank versions above, but the rank is checked explicitly
    // and a std::domain_error is thrown on mismatch.
    sm.def(
        "proxy_add2_dyn",
        [](py::array_t<double> a, double v) {
            auto r = a.mutable_unchecked();
            if (r.ndim() != 2) {
                throw std::domain_error("error: ndim != 2");
            }
            for (py::ssize_t i = 0; i < r.shape(0); i++) {
                for (py::ssize_t j = 0; j < r.shape(1); j++) {
                    r(i, j) += v;
                }
            }
        },
        py::arg{}.noconvert(),
        py::arg());
    sm.def("proxy_init3_dyn", [](double start) {
        py::array_t<double, py::array::c_style> a({3, 3, 3});
        auto r = a.mutable_unchecked();
        if (r.ndim() != 3) {
            throw std::domain_error("error: ndim != 3");
        }
        for (py::ssize_t i = 0; i < r.shape(0); i++) {
            for (py::ssize_t j = 0; j < r.shape(1); j++) {
                for (py::ssize_t k = 0; k < r.shape(2); k++) {
                    r(i, j, k) = start++;
                }
            }
        }
        return a;
    });
    sm.def("proxy_auxiliaries2_dyn", [](py::array_t<double> a) {
        return auxiliaries(a.unchecked(), a.mutable_unchecked());
    });

    // auxiliaries() on the array_t itself rather than on proxies.
    sm.def("array_auxiliaries2", [](py::array_t<double> a) { return auxiliaries(a, a); });

    // Construction from a null py::object (expected to fail, per the names).
    sm.def("array_fail_test", []() { return py::array(py::object()); });
    sm.def("array_t_fail_test", []() { return py::array_t<double>(py::object()); });

    // Construction with a negative element count (expected to fail, per the name).
    sm.def("array_fail_test_negative_size", []() {
        int c = 0;
        return py::array(-1, &c);
    });

    // array_t built from a single integer and from brace lists of 2, 3 and 4 extents.
    sm.def("array_initializer_list1", []() { return py::array_t<float>(1); });

    sm.def("array_initializer_list2", []() { return py::array_t<float>({1, 2}); });
    sm.def("array_initializer_list3", []() { return py::array_t<float>({1, 2, 3}); });
    sm.def("array_initializer_list4", []() { return py::array_t<float>({1, 2, 3, 4}); });

    // --- resize / reshape / view ---
    // Resizes `a` to a square 2-D shape. Throws std::domain_error if the total
    // size is not a perfect square.
    sm.def("array_reshape2", [](py::array_t<double> a) {
        const auto dim_sz = (py::ssize_t) std::sqrt(a.size());
        if (dim_sz * dim_sz != a.size()) {
            throw std::domain_error(
                "array_reshape2: input array total size is not a squared integer");
        }
        a.resize({dim_sz, dim_sz});
    });

    // Resizes to N x N x N, forwarding the refcheck flag to resize().
    sm.def("array_resize3", [](py::array_t<double> a, size_t N, bool refcheck) {
        a.resize({N, N, N}, refcheck);
    });

    // Resizes a default-constructed array_t to N x N and fills every element with 42.
    sm.def("create_and_resize", [](size_t N) {
        py::array_t<double> a;
        a.resize({N, N});
        std::fill(a.mutable_data(), a.mutable_data() + a.size(), 42.);
        return a;
    });

    // Reinterprets a uint8 array via view() with the given dtype string.
    sm.def("array_view",
           [](py::array_t<uint8_t> a, const std::string &dtype) { return a.view(dtype); });

    // reshape() from a brace list and from a std::vector<int> shape.
    sm.def("reshape_initializer_list", [](py::array_t<int> a, size_t N, size_t M, size_t O) {
        return a.reshape({N, M, O});
    });
    sm.def("reshape_tuple", [](py::array_t<int> a, const std::vector<int> &new_shape) {
        return a.reshape(new_shape);
    });

    // Indexes `a` with the tuple (0, ..., 0).
    sm.def("index_using_ellipsis",
           [](const py::array &a) { return a[py::make_tuple(0, py::ellipsis(), 0)]; });

    // --- argument acceptance for every combination of ExtraFlags ---
    // Each accept_double* function ignores its argument. The variants differ in
    // forcecast / c_style / f_style, and the *_noconvert set marks "a" noconvert().
    sm.def(
        "accept_double", [](const py::array_t<double, 0> &) {}, py::arg("a"));
    sm.def(
        "accept_double_forcecast",
        [](const py::array_t<double, py::array::forcecast> &) {},
        py::arg("a"));
    sm.def(
        "accept_double_c_style",
        [](const py::array_t<double, py::array::c_style> &) {},
        py::arg("a"));
    sm.def(
        "accept_double_c_style_forcecast",
        [](const py::array_t<double, py::array::forcecast | py::array::c_style> &) {},
        py::arg("a"));
    sm.def(
        "accept_double_f_style",
        [](const py::array_t<double, py::array::f_style> &) {},
        py::arg("a"));
    sm.def(
        "accept_double_f_style_forcecast",
        [](const py::array_t<double, py::array::forcecast | py::array::f_style> &) {},
        py::arg("a"));
    sm.def(
        "accept_double_noconvert", [](const py::array_t<double, 0> &) {}, "a"_a.noconvert());
    sm.def(
        "accept_double_forcecast_noconvert",
        [](const py::array_t<double, py::array::forcecast> &) {},
        "a"_a.noconvert());
    sm.def(
        "accept_double_c_style_noconvert",
        [](const py::array_t<double, py::array::c_style> &) {},
        "a"_a.noconvert());
    sm.def(
        "accept_double_c_style_forcecast_noconvert",
        [](const py::array_t<double, py::array::forcecast | py::array::c_style> &) {},
        "a"_a.noconvert());
    sm.def(
        "accept_double_f_style_noconvert",
        [](const py::array_t<double, py::array::f_style> &) {},
        "a"_a.noconvert());
    sm.def(
        "accept_double_f_style_forcecast_noconvert",
        [](const py::array_t<double, py::array::forcecast | py::array::f_style> &) {},
        "a"_a.noconvert());

    // No-op functions whose signatures carry array_t<float/double> and their
    // const-element variants (presumably inspected via the generated signatures).
    sm.def("test_fmt_desc_float", [](const py::array_t<float> &) {});
    sm.def("test_fmt_desc_double", [](const py::array_t<double> &) {});
    sm.def("test_fmt_desc_const_float", [](const py::array_t<const float> &) {});
    sm.def("test_fmt_desc_const_double", [](const py::array_t<const double> &) {});

    // Returns its double argument unchanged.
    sm.def("round_trip_float", [](double d) { return d; });
}