Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-01-18 10:17:55

0001 /*
0002     tests/test_numpy_array.cpp -- test core array functionality
0003 
0004     Copyright (c) 2016 Ivan Smirnov <i.s.smirnov@gmail.com>
0005 
0006     All rights reserved. Use of this source code is governed by a
0007     BSD-style license that can be found in the LICENSE file.
0008 */
0009 
0010 #include <pybind11/numpy.h>
0011 #include <pybind11/stl.h>
0012 
0013 #include "pybind11_tests.h"
0014 
0015 #include <cstdint>
0016 #include <utility>
0017 
0018 // Size / dtype checks.
0019 struct DtypeCheck {
0020     py::dtype numpy{};
0021     py::dtype pybind11{};
0022 };
0023 
0024 template <typename T>
0025 DtypeCheck get_dtype_check(const char *name) {
0026     py::module_ np = py::module_::import("numpy");
0027     DtypeCheck check{};
0028     check.numpy = np.attr("dtype")(np.attr(name));
0029     check.pybind11 = py::dtype::of<T>();
0030     return check;
0031 }
0032 
0033 std::vector<DtypeCheck> get_concrete_dtype_checks() {
0034     return {// Normalization
0035             get_dtype_check<std::int8_t>("int8"),
0036             get_dtype_check<std::uint8_t>("uint8"),
0037             get_dtype_check<std::int16_t>("int16"),
0038             get_dtype_check<std::uint16_t>("uint16"),
0039             get_dtype_check<std::int32_t>("int32"),
0040             get_dtype_check<std::uint32_t>("uint32"),
0041             get_dtype_check<std::int64_t>("int64"),
0042             get_dtype_check<std::uint64_t>("uint64")};
0043 }
0044 
0045 struct DtypeSizeCheck {
0046     std::string name{};
0047     int size_cpp{};
0048     int size_numpy{};
0049     // For debugging.
0050     py::dtype dtype{};
0051 };
0052 
0053 template <typename T>
0054 DtypeSizeCheck get_dtype_size_check() {
0055     DtypeSizeCheck check{};
0056     check.name = py::type_id<T>();
0057     check.size_cpp = sizeof(T);
0058     check.dtype = py::dtype::of<T>();
0059     check.size_numpy = check.dtype.attr("itemsize").template cast<int>();
0060     return check;
0061 }
0062 
0063 std::vector<DtypeSizeCheck> get_platform_dtype_size_checks() {
0064     return {
0065         get_dtype_size_check<short>(),
0066         get_dtype_size_check<unsigned short>(),
0067         get_dtype_size_check<int>(),
0068         get_dtype_size_check<unsigned int>(),
0069         get_dtype_size_check<long>(),
0070         get_dtype_size_check<unsigned long>(),
0071         get_dtype_size_check<long long>(),
0072         get_dtype_size_check<unsigned long long>(),
0073     };
0074 }
0075 
0076 // Arrays.
0077 using arr = py::array;
0078 using arr_t = py::array_t<uint16_t, 0>;
0079 static_assert(std::is_same<arr_t::value_type, uint16_t>::value, "");
0080 
0081 template <typename... Ix>
0082 arr data(const arr &a, Ix... index) {
0083     return arr(a.nbytes() - a.offset_at(index...), (const uint8_t *) a.data(index...));
0084 }
0085 
0086 template <typename... Ix>
0087 arr data_t(const arr_t &a, Ix... index) {
0088     return arr(a.size() - a.index_at(index...), a.data(index...));
0089 }
0090 
0091 template <typename... Ix>
0092 arr &mutate_data(arr &a, Ix... index) {
0093     auto *ptr = (uint8_t *) a.mutable_data(index...);
0094     for (py::ssize_t i = 0; i < a.nbytes() - a.offset_at(index...); i++) {
0095         ptr[i] = (uint8_t) (ptr[i] * 2);
0096     }
0097     return a;
0098 }
0099 
0100 template <typename... Ix>
0101 arr_t &mutate_data_t(arr_t &a, Ix... index) {
0102     auto ptr = a.mutable_data(index...);
0103     for (py::ssize_t i = 0; i < a.size() - a.index_at(index...); i++) {
0104         ptr[i]++;
0105     }
0106     return a;
0107 }
0108 
0109 template <typename... Ix>
0110 py::ssize_t index_at(const arr &a, Ix... idx) {
0111     return a.index_at(idx...);
0112 }
0113 template <typename... Ix>
0114 py::ssize_t index_at_t(const arr_t &a, Ix... idx) {
0115     return a.index_at(idx...);
0116 }
0117 template <typename... Ix>
0118 py::ssize_t offset_at(const arr &a, Ix... idx) {
0119     return a.offset_at(idx...);
0120 }
0121 template <typename... Ix>
0122 py::ssize_t offset_at_t(const arr_t &a, Ix... idx) {
0123     return a.offset_at(idx...);
0124 }
0125 template <typename... Ix>
0126 py::ssize_t at_t(const arr_t &a, Ix... idx) {
0127     return a.at(idx...);
0128 }
0129 template <typename... Ix>
0130 arr_t &mutate_at_t(arr_t &a, Ix... idx) {
0131     a.mutable_at(idx...)++;
0132     return a;
0133 }
0134 
// Registers the helper `name` on `sm` four times, as overloads that take the
// array (of parameter type `type`) followed by zero, one, two or three int
// indices. Requires a module object called `sm` at the expansion site.
#define def_index_fn(name, type)                                                                  \
    sm.def(#name, [](type target) { return name(target); });                                      \
    sm.def(#name, [](type target, int i0) { return name(target, i0); });                          \
    sm.def(#name, [](type target, int i0, int i1) { return name(target, i0, i1); });              \
    sm.def(#name, [](type target, int i0, int i1, int i2) { return name(target, i0, i1, i2); });
0140 
0141 template <typename T, typename T2>
0142 py::handle auxiliaries(T &&r, T2 &&r2) {
0143     if (r.ndim() != 2) {
0144         throw std::domain_error("error: ndim != 2");
0145     }
0146     py::list l;
0147     l.append(*r.data(0, 0));
0148     l.append(*r2.mutable_data(0, 0));
0149     l.append(r.data(0, 1) == r2.mutable_data(0, 1));
0150     l.append(r.ndim());
0151     l.append(r.itemsize());
0152     l.append(r.shape(0));
0153     l.append(r.shape(1));
0154     l.append(r.size());
0155     l.append(r.nbytes());
0156     return l.release();
0157 }
0158 
// Deliberately declared at file scope with static storage: declaring it
// locally would create a dangling reference.
static int data_i{42};
0161 
// Registers, on submodule `sm`, every binding used by the numpy array tests:
// dtype and size checks, array attribute accessors, index/offset/data helpers,
// constructors, overload-resolution probes, unchecked proxies, resize/reshape
// helpers and argument-conversion probes. Same-named `sm.def` calls form
// overload sets; their registration order is left untouched because the
// overload tests presumably depend on it (confirm against the pybind11 docs).
0162 TEST_SUBMODULE(numpy_array, sm) {
         // Register nothing at all when numpy cannot be imported.
0163     try {
0164         py::module_::import("numpy");
0165     } catch (const py::error_already_set &) {
0166         return;
0167     }
0168 
0169     // test_dtypes
0170     py::class_<DtypeCheck>(sm, "DtypeCheck")
0171         .def_readonly("numpy", &DtypeCheck::numpy)
0172         .def_readonly("pybind11", &DtypeCheck::pybind11)
0173         .def("__repr__", [](const DtypeCheck &self) {
0174             return py::str("<DtypeCheck numpy={} pybind11={}>").format(self.numpy, self.pybind11);
0175         });
0176     sm.def("get_concrete_dtype_checks", &get_concrete_dtype_checks);
0177 
         // Note: `dtype` is not exposed as an attribute; it only appears in __repr__.
0178     py::class_<DtypeSizeCheck>(sm, "DtypeSizeCheck")
0179         .def_readonly("name", &DtypeSizeCheck::name)
0180         .def_readonly("size_cpp", &DtypeSizeCheck::size_cpp)
0181         .def_readonly("size_numpy", &DtypeSizeCheck::size_numpy)
0182         .def("__repr__", [](const DtypeSizeCheck &self) {
0183             return py::str("<DtypeSizeCheck name='{}' size_cpp={} size_numpy={} dtype={}>")
0184                 .format(self.name, self.size_cpp, self.size_numpy, self.dtype);
0185         });
0186     sm.def("get_platform_dtype_size_checks", &get_platform_dtype_size_checks);
0187 
0188     // test_array_attributes
         // The no-argument `shape` / `strides` overloads return all ndim entries as a new array;
         // the overloads taking `dim` return the single entry for that dimension.
0189     sm.def("ndim", [](const arr &a) { return a.ndim(); });
0190     sm.def("shape", [](const arr &a) { return arr(a.ndim(), a.shape()); });
0191     sm.def("shape", [](const arr &a, py::ssize_t dim) { return a.shape(dim); });
0192     sm.def("strides", [](const arr &a) { return arr(a.ndim(), a.strides()); });
0193     sm.def("strides", [](const arr &a, py::ssize_t dim) { return a.strides(dim); });
0194     sm.def("writeable", [](const arr &a) { return a.writeable(); });
0195     sm.def("size", [](const arr &a) { return a.size(); });
0196     sm.def("itemsize", [](const arr &a) { return a.itemsize(); });
0197     sm.def("nbytes", [](const arr &a) { return a.nbytes(); });
0198     sm.def("owndata", [](const arr &a) { return a.owndata(); });
0199 
         // Each def_index_fn line below expands to four overloads of the named helper, taking
         // the array plus zero to three int indices.
0200     // test_index_offset
0201     def_index_fn(index_at, const arr &);
0202     def_index_fn(index_at_t, const arr_t &);
0203     def_index_fn(offset_at, const arr &);
0204     def_index_fn(offset_at_t, const arr_t &);
0205     // test_data
0206     def_index_fn(data, const arr &);
0207     def_index_fn(data_t, const arr_t &);
0208     // test_mutate_data, test_mutate_readonly
0209     def_index_fn(mutate_data, arr &);
0210     def_index_fn(mutate_data_t, arr_t &);
0211     def_index_fn(at_t, const arr_t &);
0212     def_index_fn(mutate_at_t, arr_t &);
0213 
0214     // test_make_c_f_array
         // Both are 2x2 float arrays given explicit strides: {4, 8} puts the smaller stride on
         // the first axis, {8, 4} on the last. The literals assume 4-byte floats -- TODO confirm.
0215     sm.def("make_f_array", [] { return py::array_t<float>({2, 2}, {4, 8}); });
0216     sm.def("make_c_array", [] { return py::array_t<float>({2, 2}, {8, 4}); });
0217 
0218     // test_empty_shaped_array
0219     sm.def("make_empty_shaped_array", [] { return py::array(py::dtype("f"), {}, {}); });
0220     // test numpy scalars (empty shape, ndim==0)
0221     sm.def("scalar_int", []() { return py::array(py::dtype("i"), {}, {}, &data_i); });
0222 
0223     // test_wrap
         // Builds a second array with the same dtype, shape, strides and data pointer as `a`,
         // passing `a` itself as the final (base) argument.
0224     sm.def("wrap", [](const py::array &a) {
0225         return py::array(a.dtype(),
0226                          {a.shape(), a.shape() + a.ndim()},
0227                          {a.strides(), a.strides() + a.ndim()},
0228                          a.data(),
0229                          a);
0230     });
0231 
0232     // test_numpy_view
         // ArrayClass prints from its constructor and destructor so the test can observe lifetime.
0233     struct ArrayClass {
0234         int data[2] = {1, 2};
0235         ArrayClass() { py::print("ArrayClass()"); }
0236         ~ArrayClass() { py::print("~ArrayClass()"); }
0237     };
0238     py::class_<ArrayClass>(sm, "ArrayClass")
0239         .def(py::init<>())
0240         .def("numpy_view", [](py::object &obj) {
0241             py::print("ArrayClass::numpy_view()");
0242             auto &a = obj.cast<ArrayClass &>();
                 // `obj` is handed over as the base argument, presumably so the instance outlives
                 // the view; the stride literal 4 assumes a 4-byte int -- TODO confirm both.
0243             return py::array_t<int>({2}, {4}, a.data, obj);
0244         });
0245 
0246     // test_cast_numpy_int64_to_uint64
0247     sm.def("function_taking_uint64", [](uint64_t) {});
0248 
0249     // test_isinstance
0250     sm.def("isinstance_untyped", [](py::object yes, py::object no) {
0251         return py::isinstance<py::array>(std::move(yes))
0252                && !py::isinstance<py::array>(std::move(no));
0253     });
0254     sm.def("isinstance_typed", [](const py::object &o) {
0255         return py::isinstance<py::array_t<double>>(o) && !py::isinstance<py::array_t<int>>(o);
0256     });
0257 
0258     // test_constructors
0259     sm.def("default_constructors", []() {
0260         return py::dict("array"_a = py::array(),
0261                         "array_t<int32>"_a = py::array_t<std::int32_t>(),
0262                         "array_t<double>"_a = py::array_t<double>());
0263     });
0264     sm.def("converting_constructors", [](const py::object &o) {
0265         return py::dict("array"_a = py::array(o),
0266                         "array_t<int32>"_a = py::array_t<std::int32_t>(o),
0267                         "array_t<double>"_a = py::array_t<double>(o));
0268     });
0269 
0270     // test_overload_resolution
         // Every overload returns a string naming its own element type, so the caller can tell
         // which one was selected.
0271     sm.def("overloaded", [](const py::array_t<double> &) { return "double"; });
0272     sm.def("overloaded", [](const py::array_t<float> &) { return "float"; });
0273     sm.def("overloaded", [](const py::array_t<int> &) { return "int"; });
0274     sm.def("overloaded", [](const py::array_t<unsigned short> &) { return "unsigned short"; });
0275     sm.def("overloaded", [](const py::array_t<long long> &) { return "long long"; });
0276     sm.def("overloaded",
0277            [](const py::array_t<std::complex<double>> &) { return "double complex"; });
0278     sm.def("overloaded", [](const py::array_t<std::complex<float>> &) { return "float complex"; });
0279 
0280     sm.def("overloaded2",
0281            [](const py::array_t<std::complex<double>> &) { return "double complex"; });
0282     sm.def("overloaded2", [](const py::array_t<double> &) { return "double"; });
0283     sm.def("overloaded2",
0284            [](const py::array_t<std::complex<float>> &) { return "float complex"; });
0285     sm.def("overloaded2", [](const py::array_t<float> &) { return "float"; });
0286 
0287     // [workaround(intel)] ICC 20/21 breaks with py::arg().stuff, using py::arg{}.stuff works.
0288 
0289     // Only accept the exact types:
0290     sm.def(
0291         "overloaded3", [](const py::array_t<int> &) { return "int"; }, py::arg{}.noconvert());
0292     sm.def(
0293         "overloaded3",
0294         [](const py::array_t<double> &) { return "double"; },
0295         py::arg{}.noconvert());
0296 
0297     // Make sure we don't do unsafe coercion (e.g. float to int) when not using forcecast, but
0298     // rather that float gets converted via the safe (conversion to double) overload:
0299     sm.def("overloaded4", [](const py::array_t<long long, 0> &) { return "long long"; });
0300     sm.def("overloaded4", [](const py::array_t<double, 0> &) { return "double"; });
0301 
0302     // But we do allow conversion to int if forcecast is enabled (but only if no overload matches
0303     // without conversion)
0304     sm.def("overloaded5", [](const py::array_t<unsigned int> &) { return "unsigned int"; });
0305     sm.def("overloaded5", [](const py::array_t<double> &) { return "double"; });
0306 
0307     // test_greedy_string_overload
0308     // Issue 685: ndarray shouldn't go to std::string overload
0309     sm.def("issue685", [](const std::string &) { return "string"; });
0310     sm.def("issue685", [](const py::array &) { return "array"; });
0311     sm.def("issue685", [](const py::object &) { return "other"; });
0312 
0313     // test_array_unchecked_fixed_dims
         // Adds `v` to every element of the 2-D array in place; returns nothing.
0314     sm.def(
0315         "proxy_add2",
0316         [](py::array_t<double> a, double v) {
0317             auto r = a.mutable_unchecked<2>();
0318             for (py::ssize_t i = 0; i < r.shape(0); i++) {
0319                 for (py::ssize_t j = 0; j < r.shape(1); j++) {
0320                     r(i, j) += v;
0321                 }
0322             }
0323         },
0324         py::arg{}.noconvert(),
0325         py::arg());
0326 
         // Fills a fresh 3x3x3 c_style array with start, start+1, ... with the last index
         // varying fastest.
0327     sm.def("proxy_init3", [](double start) {
0328         py::array_t<double, py::array::c_style> a({3, 3, 3});
0329         auto r = a.mutable_unchecked<3>();
0330         for (py::ssize_t i = 0; i < r.shape(0); i++) {
0331             for (py::ssize_t j = 0; j < r.shape(1); j++) {
0332                 for (py::ssize_t k = 0; k < r.shape(2); k++) {
0333                     r(i, j, k) = start++;
0334                 }
0335             }
0336         }
0337         return a;
0338     });
         // Same fill for an f_style array, but with the loops nested so that the first index
         // varies fastest.
0339     sm.def("proxy_init3F", [](double start) {
0340         py::array_t<double, py::array::f_style> a({3, 3, 3});
0341         auto r = a.mutable_unchecked<3>();
0342         for (py::ssize_t k = 0; k < r.shape(2); k++) {
0343             for (py::ssize_t j = 0; j < r.shape(1); j++) {
0344                 for (py::ssize_t i = 0; i < r.shape(0); i++) {
0345                     r(i, j, k) = start++;
0346                 }
0347             }
0348         }
0349         return a;
0350     });
0351     sm.def("proxy_squared_L2_norm", [](const py::array_t<double> &a) {
0352         auto r = a.unchecked<1>();
0353         double sumsq = 0;
0354         for (py::ssize_t i = 0; i < r.shape(0); i++) {
0355             sumsq += r[i] * r(i); // Either notation works for a 1D array
0356         }
0357         return sumsq;
0358     });
0359 
0360     sm.def("proxy_auxiliaries2", [](py::array_t<double> a) {
0361         auto r = a.unchecked<2>();
0362         auto r2 = a.mutable_unchecked<2>();
0363         return auxiliaries(r, r2);
0364     });
0365 
         // The two *_const_ref bindings hold the proxies through const references and compare
         // element 0 (or (0, 0)) as read through the read-only and the mutable proxy.
0366     sm.def("proxy_auxiliaries1_const_ref", [](py::array_t<double> a) {
0367         const auto &r = a.unchecked<1>();
0368         const auto &r2 = a.mutable_unchecked<1>();
0369         return r(0) == r2(0) && r[0] == r2[0];
0370     });
0371 
0372     sm.def("proxy_auxiliaries2_const_ref", [](py::array_t<double> a) {
0373         const auto &r = a.unchecked<2>();
0374         const auto &r2 = a.mutable_unchecked<2>();
0375         return r(0, 0) == r2(0, 0);
0376     });
0377 
0378     // test_array_unchecked_dyn_dims
0379     // Same as the above, but without a compile-time dimensions specification:
0380     sm.def(
0381         "proxy_add2_dyn",
0382         [](py::array_t<double> a, double v) {
0383             auto r = a.mutable_unchecked();
0384             if (r.ndim() != 2) {
0385                 throw std::domain_error("error: ndim != 2");
0386             }
0387             for (py::ssize_t i = 0; i < r.shape(0); i++) {
0388                 for (py::ssize_t j = 0; j < r.shape(1); j++) {
0389                     r(i, j) += v;
0390                 }
0391             }
0392         },
0393         py::arg{}.noconvert(),
0394         py::arg());
0395     sm.def("proxy_init3_dyn", [](double start) {
0396         py::array_t<double, py::array::c_style> a({3, 3, 3});
0397         auto r = a.mutable_unchecked();
0398         if (r.ndim() != 3) {
0399             throw std::domain_error("error: ndim != 3");
0400         }
0401         for (py::ssize_t i = 0; i < r.shape(0); i++) {
0402             for (py::ssize_t j = 0; j < r.shape(1); j++) {
0403                 for (py::ssize_t k = 0; k < r.shape(2); k++) {
0404                     r(i, j, k) = start++;
0405                 }
0406             }
0407         }
0408         return a;
0409     });
0410     sm.def("proxy_auxiliaries2_dyn", [](py::array_t<double> a) {
0411         return auxiliaries(a.unchecked(), a.mutable_unchecked());
0412     });
0413 
         // Passes the array itself (not a proxy) as both the read and the mutable argument.
0414     sm.def("array_auxiliaries2", [](py::array_t<double> a) { return auxiliaries(a, a); });
0415 
0416     // test_array_failures
0417     // Issue #785: Uninformative "Unknown internal error" exception when constructing array from
0418     // empty object:
0419     sm.def("array_fail_test", []() { return py::array(py::object()); });
0420     sm.def("array_t_fail_test", []() { return py::array_t<double>(py::object()); });
0421     // Make sure the error from numpy is being passed through:
0422     sm.def("array_fail_test_negative_size", []() {
0423         int c = 0;
0424         return py::array(-1, &c);
0425     });
0426 
0427     // test_initializer_list
0428     // Issue (unnumbered; reported in #788): regression: initializer lists can be ambiguous
0429     sm.def("array_initializer_list1", []() { return py::array_t<float>(1); });
0430     // { 1 } also works for the above, but clang warns about it
0431     sm.def("array_initializer_list2", []() { return py::array_t<float>({1, 2}); });
0432     sm.def("array_initializer_list3", []() { return py::array_t<float>({1, 2, 3}); });
0433     sm.def("array_initializer_list4", []() { return py::array_t<float>({1, 2, 3, 4}); });
0434 
0435     // test_array_resize
0436     // reshape array to 2D without changing size
0437     sm.def("array_reshape2", [](py::array_t<double> a) {
             // Truncated square root of the element count; the check below rejects any size
             // that is not a perfect square.
0438         const auto dim_sz = (py::ssize_t) std::sqrt(a.size());
0439         if (dim_sz * dim_sz != a.size()) {
0440             throw std::domain_error(
0441                 "array_reshape2: input array total size is not a squared integer");
0442         }
0443         a.resize({dim_sz, dim_sz});
0444     });
0445 
0446     // resize to 3D array with each dimension = N
0447     sm.def("array_resize3", [](py::array_t<double> a, size_t N, bool refcheck) {
0448         a.resize({N, N, N}, refcheck);
0449     });
0450 
0451     // test_array_create_and_resize
0452     // return 2D array with Nrows = Ncols = N
0453     sm.def("create_and_resize", [](size_t N) {
0454         py::array_t<double> a;
0455         a.resize({N, N});
             // Every element of the resized array is set to 42.
0456         std::fill(a.mutable_data(), a.mutable_data() + a.size(), 42.);
0457         return a;
0458     });
0459 
0460     sm.def("array_view",
0461            [](py::array_t<uint8_t> a, const std::string &dtype) { return a.view(dtype); });
0462 
0463     sm.def("reshape_initializer_list", [](py::array_t<int> a, size_t N, size_t M, size_t O) {
0464         return a.reshape({N, M, O});
0465     });
0466     sm.def("reshape_tuple", [](py::array_t<int> a, const std::vector<int> &new_shape) {
0467         return a.reshape(new_shape);
0468     });
0469 
0470     sm.def("index_using_ellipsis",
0471            [](const py::array &a) { return a[py::make_tuple(0, py::ellipsis(), 0)]; });
0472 
0473     // test_argument_conversions
         // All accept_double* bindings have empty bodies; they differ only in the array_t
         // flags (0, forcecast, c_style, f_style and combinations) and in whether the
         // argument is marked noconvert().
0474     sm.def(
0475         "accept_double", [](const py::array_t<double, 0> &) {}, py::arg("a"));
0476     sm.def(
0477         "accept_double_forcecast",
0478         [](const py::array_t<double, py::array::forcecast> &) {},
0479         py::arg("a"));
0480     sm.def(
0481         "accept_double_c_style",
0482         [](const py::array_t<double, py::array::c_style> &) {},
0483         py::arg("a"));
0484     sm.def(
0485         "accept_double_c_style_forcecast",
0486         [](const py::array_t<double, py::array::forcecast | py::array::c_style> &) {},
0487         py::arg("a"));
0488     sm.def(
0489         "accept_double_f_style",
0490         [](const py::array_t<double, py::array::f_style> &) {},
0491         py::arg("a"));
0492     sm.def(
0493         "accept_double_f_style_forcecast",
0494         [](const py::array_t<double, py::array::forcecast | py::array::f_style> &) {},
0495         py::arg("a"));
0496     sm.def(
0497         "accept_double_noconvert", [](const py::array_t<double, 0> &) {}, "a"_a.noconvert());
0498     sm.def(
0499         "accept_double_forcecast_noconvert",
0500         [](const py::array_t<double, py::array::forcecast> &) {},
0501         "a"_a.noconvert());
0502     sm.def(
0503         "accept_double_c_style_noconvert",
0504         [](const py::array_t<double, py::array::c_style> &) {},
0505         "a"_a.noconvert());
0506     sm.def(
0507         "accept_double_c_style_forcecast_noconvert",
0508         [](const py::array_t<double, py::array::forcecast | py::array::c_style> &) {},
0509         "a"_a.noconvert());
0510     sm.def(
0511         "accept_double_f_style_noconvert",
0512         [](const py::array_t<double, py::array::f_style> &) {},
0513         "a"_a.noconvert());
0514     sm.def(
0515         "accept_double_f_style_forcecast_noconvert",
0516         [](const py::array_t<double, py::array::forcecast | py::array::f_style> &) {},
0517         "a"_a.noconvert());
0518 
0519     // Check that types returns correct npy format descriptor
0520     sm.def("test_fmt_desc_float", [](const py::array_t<float> &) {});
0521     sm.def("test_fmt_desc_double", [](const py::array_t<double> &) {});
0522     sm.def("test_fmt_desc_const_float", [](const py::array_t<const float> &) {});
0523     sm.def("test_fmt_desc_const_double", [](const py::array_t<const double> &) {});
0524 
         // Takes a double and returns it unchanged.
0525     sm.def("round_trip_float", [](double d) { return d; });
0526 }