File indexing completed on 2025-01-18 10:17:55
0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
#include <pybind11/numpy.h>

#include "pybind11_tests.h"

#include <array>
#include <complex>
#include <utility>
#include <vector>
0016
0017 double my_func(int x, float y, double z) {
0018 py::print("my_func(x:int={}, y:float={:.0f}, z:float={:.0f})"_s.format(x, y, z));
0019 return (float) x * y * z;
0020 }
0021
0022 TEST_SUBMODULE(numpy_vectorize, m) {
0023 try {
0024 py::module_::import("numpy");
0025 } catch (const py::error_already_set &) {
0026 return;
0027 }
0028
0029
0030
0031 m.def("vectorized_func", py::vectorize(my_func));
0032
0033
0034
0035 m.def("vectorized_func2", [](py::array_t<int> x, py::array_t<float> y, float z) {
0036 return py::vectorize([z](int x, float y) { return my_func(x, y, z); })(std::move(x),
0037 std::move(y));
0038 });
0039
0040
0041 m.def("vectorized_func3",
0042 py::vectorize([](std::complex<double> c) { return c * std::complex<double>(2.f); }));
0043
0044
0045
0046
0047
0048 m.def("selective_func",
0049 [](const py::array_t<int, py::array::c_style> &) { return "Int branch taken."; });
0050 m.def("selective_func",
0051 [](const py::array_t<float, py::array::c_style> &) { return "Float branch taken."; });
0052 m.def("selective_func", [](const py::array_t<std::complex<float>, py::array::c_style> &) {
0053 return "Complex float branch taken.";
0054 });
0055
0056
0057
0058
0059 struct NonPODClass {
0060 explicit NonPODClass(int v) : value{v} {}
0061 int value;
0062 };
0063 py::class_<NonPODClass>(m, "NonPODClass")
0064 .def(py::init<int>())
0065 .def_readwrite("value", &NonPODClass::value);
0066 m.def("vec_passthrough",
0067 py::vectorize([](const double *a,
0068 double b,
0069
0070
0071 py::array_t<double> c,
0072 const int &d,
0073 int &e,
0074 NonPODClass f,
0075 const double g) { return *a + b + c.at(0) + d + e + f.value + g; }));
0076
0077
0078 struct VectorizeTestClass {
0079 explicit VectorizeTestClass(int v) : value{v} {};
0080 float method(int x, float y) const { return y + (float) (x + value); }
0081 int value = 0;
0082 };
0083 py::class_<VectorizeTestClass> vtc(m, "VectorizeTestClass");
0084 vtc.def(py::init<int>()).def_readwrite("value", &VectorizeTestClass::value);
0085
0086
0087 vtc.def("method", py::vectorize(&VectorizeTestClass::method));
0088
0089
0090
0091 py::enum_<py::detail::broadcast_trivial>(m, "trivial")
0092 .value("f_trivial", py::detail::broadcast_trivial::f_trivial)
0093 .value("c_trivial", py::detail::broadcast_trivial::c_trivial)
0094 .value("non_trivial", py::detail::broadcast_trivial::non_trivial);
0095 m.def("vectorized_is_trivial",
0096 [](const py::array_t<int, py::array::forcecast> &arg1,
0097 const py::array_t<float, py::array::forcecast> &arg2,
0098 const py::array_t<double, py::array::forcecast> &arg3) {
0099 py::ssize_t ndim = 0;
0100 std::vector<py::ssize_t> shape;
0101 std::array<py::buffer_info, 3> buffers{
0102 {arg1.request(), arg2.request(), arg3.request()}};
0103 return py::detail::broadcast(buffers, ndim, shape);
0104 });
0105
0106 m.def("add_to", py::vectorize([](NonPODClass &x, int a) { x.value += a; }));
0107 }