Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-01-18 10:17:51

0001 #include <pybind11/embed.h>
0002 
// Silence MSVC C++17 deprecation warning (C4996) from Catch regarding std::uncaught_exception (up to
// catch 2.0.1; this should be fixed in the next catch release after 2.0.1).
0005 PYBIND11_WARNING_DISABLE_MSVC(4996)
0006 
0007 #include <catch.hpp>
0008 #include <cstdlib>
0009 #include <fstream>
0010 #include <functional>
0011 #include <thread>
0012 #include <utility>
0013 
0014 namespace py = pybind11;
0015 using namespace py::literals;
0016 
0017 size_t get_sys_path_size() {
0018     auto sys_path = py::module::import("sys").attr("path");
0019     return py::len(sys_path);
0020 }
0021 
/// Abstract C++ base class exposed to Python as `widget_module.Widget`.
///
/// Holds a message fixed at construction time. Concrete behaviour for the two
/// pure virtual hooks comes from subclasses (in these tests, a Python subclass
/// reached through the PyWidget trampoline).
class Widget {
public:
    explicit Widget(std::string msg) : message_(std::move(msg)) {}
    virtual ~Widget() = default;

    /// Returns a copy of the message passed to the constructor.
    std::string the_message() const { return message_; }

    /// Implemented by subclasses; the tests expect the Python subclass to return 42.
    virtual int the_answer() const = 0;

    /// Implemented by subclasses; the tests compare it against the argv[0] the
    /// interpreter was started with (presumably the Python-side sys.argv[0]).
    virtual std::string argv0() const = 0;

private:
    std::string message_;
};
0034 
0035 class PyWidget final : public Widget {
0036     using Widget::Widget;
0037 
0038     int the_answer() const override { PYBIND11_OVERRIDE_PURE(int, Widget, the_answer); }
0039     std::string argv0() const override { PYBIND11_OVERRIDE_PURE(std::string, Widget, argv0); }
0040 };
0041 
/// Small polymorphic helper whose virtual `func()` is overridden from Python
/// (through test_override_cache_helper_trampoline) in the "Override cache" test.
class test_override_cache_helper {
public:
    /// Base implementation; returns 0.
    virtual int func() { return 0; }

    test_override_cache_helper() = default;
    virtual ~test_override_cache_helper() = default;

    // Copy construction and copy assignment are deleted. Because no move
    // operations are declared, the class is not movable either.
    test_override_cache_helper(const test_override_cache_helper &) = delete;
    test_override_cache_helper &operator=(const test_override_cache_helper &) = delete;
};
0053 
0054 class test_override_cache_helper_trampoline : public test_override_cache_helper {
0055     int func() override { PYBIND11_OVERRIDE(int, test_override_cache_helper, func); }
0056 };
0057 
0058 PYBIND11_EMBEDDED_MODULE(widget_module, m) {
0059     py::class_<Widget, PyWidget>(m, "Widget")
0060         .def(py::init<std::string>())
0061         .def_property_readonly("the_message", &Widget::the_message);
0062 
0063     m.def("add", [](int i, int j) { return i + j; });
0064 }
0065 
0066 PYBIND11_EMBEDDED_MODULE(trampoline_module, m) {
0067     py::class_<test_override_cache_helper,
0068                test_override_cache_helper_trampoline,
0069                std::shared_ptr<test_override_cache_helper>>(m, "test_override_cache_helper")
0070         .def(py::init_alias<>())
0071         .def("func", &test_override_cache_helper::func);
0072 }
0073 
0074 PYBIND11_EMBEDDED_MODULE(throw_exception, ) { throw std::runtime_error("C++ Error"); }
0075 
0076 PYBIND11_EMBEDDED_MODULE(throw_error_already_set, ) {
0077     auto d = py::dict();
0078     d["missing"].cast<py::object>();
0079 }
0080 
0081 TEST_CASE("PYTHONPATH is used to update sys.path") {
0082     // The setup for this TEST_CASE is in catch.cpp!
0083     auto sys_path = py::str(py::module_::import("sys").attr("path")).cast<std::string>();
0084     REQUIRE_THAT(sys_path,
0085                  Catch::Matchers::Contains("pybind11_test_embed_PYTHONPATH_2099743835476552"));
0086 }
0087 
0088 TEST_CASE("Pass classes and data between modules defined in C++ and Python") {
0089     auto module_ = py::module_::import("test_interpreter");
0090     REQUIRE(py::hasattr(module_, "DerivedWidget"));
0091 
0092     auto locals = py::dict("hello"_a = "Hello, World!", "x"_a = 5, **module_.attr("__dict__"));
0093     py::exec(R"(
0094         widget = DerivedWidget("{} - {}".format(hello, x))
0095         message = widget.the_message
0096     )",
0097              py::globals(),
0098              locals);
0099     REQUIRE(locals["message"].cast<std::string>() == "Hello, World! - 5");
0100 
0101     auto py_widget = module_.attr("DerivedWidget")("The question");
0102     auto message = py_widget.attr("the_message");
0103     REQUIRE(message.cast<std::string>() == "The question");
0104 
0105     const auto &cpp_widget = py_widget.cast<const Widget &>();
0106     REQUIRE(cpp_widget.the_answer() == 42);
0107 }
0108 
0109 TEST_CASE("Override cache") {
0110     auto module_ = py::module_::import("test_trampoline");
0111     REQUIRE(py::hasattr(module_, "func"));
0112     REQUIRE(py::hasattr(module_, "func2"));
0113 
0114     auto locals = py::dict(**module_.attr("__dict__"));
0115 
0116     int i = 0;
0117     for (; i < 1500; ++i) {
0118         std::shared_ptr<test_override_cache_helper> p_obj;
0119         std::shared_ptr<test_override_cache_helper> p_obj2;
0120 
0121         py::object loc_inst = locals["func"]();
0122         p_obj = py::cast<std::shared_ptr<test_override_cache_helper>>(loc_inst);
0123 
0124         int ret = p_obj->func();
0125 
0126         REQUIRE(ret == 42);
0127 
0128         loc_inst = locals["func2"]();
0129 
0130         p_obj2 = py::cast<std::shared_ptr<test_override_cache_helper>>(loc_inst);
0131 
0132         p_obj2->func();
0133     }
0134 }
0135 
0136 TEST_CASE("Import error handling") {
0137     REQUIRE_NOTHROW(py::module_::import("widget_module"));
0138     REQUIRE_THROWS_WITH(py::module_::import("throw_exception"), "ImportError: C++ Error");
0139     REQUIRE_THROWS_WITH(py::module_::import("throw_error_already_set"),
0140                         Catch::Contains("ImportError: initialization failed"));
0141 
0142     auto locals = py::dict("is_keyerror"_a = false, "message"_a = "not set");
0143     py::exec(R"(
0144         try:
0145             import throw_error_already_set
0146         except ImportError as e:
0147             is_keyerror = type(e.__cause__) == KeyError
0148             message = str(e.__cause__)
0149     )",
0150              py::globals(),
0151              locals);
0152     REQUIRE(locals["is_keyerror"].cast<bool>() == true);
0153     REQUIRE(locals["message"].cast<std::string>() == "'missing'");
0154 }
0155 
0156 TEST_CASE("There can be only one interpreter") {
0157     static_assert(std::is_move_constructible<py::scoped_interpreter>::value, "");
0158     static_assert(!std::is_move_assignable<py::scoped_interpreter>::value, "");
0159     static_assert(!std::is_copy_constructible<py::scoped_interpreter>::value, "");
0160     static_assert(!std::is_copy_assignable<py::scoped_interpreter>::value, "");
0161 
0162     REQUIRE_THROWS_WITH(py::initialize_interpreter(), "The interpreter is already running");
0163     REQUIRE_THROWS_WITH(py::scoped_interpreter(), "The interpreter is already running");
0164 
0165     py::finalize_interpreter();
0166     REQUIRE_NOTHROW(py::scoped_interpreter());
0167     {
0168         auto pyi1 = py::scoped_interpreter();
0169         auto pyi2 = std::move(pyi1);
0170     }
0171     py::initialize_interpreter();
0172 }
0173 
0174 #if PY_VERSION_HEX >= PYBIND11_PYCONFIG_SUPPORT_PY_VERSION_HEX
0175 TEST_CASE("Custom PyConfig") {
0176     py::finalize_interpreter();
0177     PyConfig config;
0178     PyConfig_InitPythonConfig(&config);
0179     REQUIRE_NOTHROW(py::scoped_interpreter{&config});
0180     {
0181         py::scoped_interpreter p{&config};
0182         REQUIRE(py::module_::import("widget_module").attr("add")(1, 41).cast<int>() == 42);
0183     }
0184     py::initialize_interpreter();
0185 }
0186 
0187 TEST_CASE("Custom PyConfig with argv") {
0188     py::finalize_interpreter();
0189     {
0190         PyConfig config;
0191         PyConfig_InitIsolatedConfig(&config);
0192         char *argv[] = {strdup("a.out")};
0193         py::scoped_interpreter argv_scope{&config, 1, argv};
0194         std::free(argv[0]);
0195         auto module = py::module::import("test_interpreter");
0196         auto py_widget = module.attr("DerivedWidget")("The question");
0197         const auto &cpp_widget = py_widget.cast<const Widget &>();
0198         REQUIRE(cpp_widget.argv0() == "a.out");
0199     }
0200     py::initialize_interpreter();
0201 }
0202 #endif
0203 
0204 TEST_CASE("Add program dir to path pre-PyConfig") {
0205     py::finalize_interpreter();
0206     size_t path_size_add_program_dir_to_path_false = 0;
0207     {
0208         py::scoped_interpreter scoped_interp{true, 0, nullptr, false};
0209         path_size_add_program_dir_to_path_false = get_sys_path_size();
0210     }
0211     {
0212         py::scoped_interpreter scoped_interp{};
0213         REQUIRE(get_sys_path_size() == path_size_add_program_dir_to_path_false + 1);
0214     }
0215     py::initialize_interpreter();
0216 }
0217 
0218 #if PY_VERSION_HEX >= PYBIND11_PYCONFIG_SUPPORT_PY_VERSION_HEX
0219 TEST_CASE("Add program dir to path using PyConfig") {
0220     py::finalize_interpreter();
0221     size_t path_size_add_program_dir_to_path_false = 0;
0222     {
0223         PyConfig config;
0224         PyConfig_InitPythonConfig(&config);
0225         py::scoped_interpreter scoped_interp{&config, 0, nullptr, false};
0226         path_size_add_program_dir_to_path_false = get_sys_path_size();
0227     }
0228     {
0229         PyConfig config;
0230         PyConfig_InitPythonConfig(&config);
0231         py::scoped_interpreter scoped_interp{&config};
0232         REQUIRE(get_sys_path_size() == path_size_add_program_dir_to_path_false + 1);
0233     }
0234     py::initialize_interpreter();
0235 }
0236 #endif
0237 
0238 bool has_pybind11_internals_builtin() {
0239     auto builtins = py::handle(PyEval_GetBuiltins());
0240     return builtins.contains(PYBIND11_INTERNALS_ID);
0241 };
0242 
0243 bool has_pybind11_internals_static() {
0244     auto **&ipp = py::detail::get_internals_pp();
0245     return (ipp != nullptr) && (*ipp != nullptr);
0246 }
0247 
0248 TEST_CASE("Restart the interpreter") {
0249     // Verify pre-restart state.
0250     REQUIRE(py::module_::import("widget_module").attr("add")(1, 2).cast<int>() == 3);
0251     REQUIRE(has_pybind11_internals_builtin());
0252     REQUIRE(has_pybind11_internals_static());
0253     REQUIRE(py::module_::import("external_module").attr("A")(123).attr("value").cast<int>()
0254             == 123);
0255 
0256     // local and foreign module internals should point to the same internals:
0257     REQUIRE(reinterpret_cast<uintptr_t>(*py::detail::get_internals_pp())
0258             == py::module_::import("external_module").attr("internals_at")().cast<uintptr_t>());
0259 
0260     // Restart the interpreter.
0261     py::finalize_interpreter();
0262     REQUIRE(Py_IsInitialized() == 0);
0263 
0264     py::initialize_interpreter();
0265     REQUIRE(Py_IsInitialized() == 1);
0266 
0267     // Internals are deleted after a restart.
0268     REQUIRE_FALSE(has_pybind11_internals_builtin());
0269     REQUIRE_FALSE(has_pybind11_internals_static());
0270     pybind11::detail::get_internals();
0271     REQUIRE(has_pybind11_internals_builtin());
0272     REQUIRE(has_pybind11_internals_static());
0273     REQUIRE(reinterpret_cast<uintptr_t>(*py::detail::get_internals_pp())
0274             == py::module_::import("external_module").attr("internals_at")().cast<uintptr_t>());
0275 
0276     // Make sure that an interpreter with no get_internals() created until finalize still gets the
0277     // internals destroyed
0278     py::finalize_interpreter();
0279     py::initialize_interpreter();
0280     bool ran = false;
0281     py::module_::import("__main__").attr("internals_destroy_test")
0282         = py::capsule(&ran, [](void *ran) {
0283               py::detail::get_internals();
0284               *static_cast<bool *>(ran) = true;
0285           });
0286     REQUIRE_FALSE(has_pybind11_internals_builtin());
0287     REQUIRE_FALSE(has_pybind11_internals_static());
0288     REQUIRE_FALSE(ran);
0289     py::finalize_interpreter();
0290     REQUIRE(ran);
0291     py::initialize_interpreter();
0292     REQUIRE_FALSE(has_pybind11_internals_builtin());
0293     REQUIRE_FALSE(has_pybind11_internals_static());
0294 
0295     // C++ modules can be reloaded.
0296     auto cpp_module = py::module_::import("widget_module");
0297     REQUIRE(cpp_module.attr("add")(1, 2).cast<int>() == 3);
0298 
0299     // C++ type information is reloaded and can be used in python modules.
0300     auto py_module = py::module_::import("test_interpreter");
0301     auto py_widget = py_module.attr("DerivedWidget")("Hello after restart");
0302     REQUIRE(py_widget.attr("the_message").cast<std::string>() == "Hello after restart");
0303 }
0304 
0305 TEST_CASE("Subinterpreter") {
0306     // Add tags to the modules in the main interpreter and test the basics.
0307     py::module_::import("__main__").attr("main_tag") = "main interpreter";
0308     {
0309         auto m = py::module_::import("widget_module");
0310         m.attr("extension_module_tag") = "added to module in main interpreter";
0311 
0312         REQUIRE(m.attr("add")(1, 2).cast<int>() == 3);
0313     }
0314     REQUIRE(has_pybind11_internals_builtin());
0315     REQUIRE(has_pybind11_internals_static());
0316 
0317     /// Create and switch to a subinterpreter.
0318     auto *main_tstate = PyThreadState_Get();
0319     auto *sub_tstate = Py_NewInterpreter();
0320 
0321     // Subinterpreters get their own copy of builtins. detail::get_internals() still
0322     // works by returning from the static variable, i.e. all interpreters share a single
0323     // global pybind11::internals;
0324     REQUIRE_FALSE(has_pybind11_internals_builtin());
0325     REQUIRE(has_pybind11_internals_static());
0326 
0327     // Modules tags should be gone.
0328     REQUIRE_FALSE(py::hasattr(py::module_::import("__main__"), "tag"));
0329     {
0330         auto m = py::module_::import("widget_module");
0331         REQUIRE_FALSE(py::hasattr(m, "extension_module_tag"));
0332 
0333         // Function bindings should still work.
0334         REQUIRE(m.attr("add")(1, 2).cast<int>() == 3);
0335     }
0336 
0337     // Restore main interpreter.
0338     Py_EndInterpreter(sub_tstate);
0339     PyThreadState_Swap(main_tstate);
0340 
0341     REQUIRE(py::hasattr(py::module_::import("__main__"), "main_tag"));
0342     REQUIRE(py::hasattr(py::module_::import("widget_module"), "extension_module_tag"));
0343 }
0344 
0345 TEST_CASE("Execution frame") {
0346     // When the interpreter is embedded, there is no execution frame, but `py::exec`
0347     // should still function by using reasonable globals: `__main__.__dict__`.
0348     py::exec("var = dict(number=42)");
0349     REQUIRE(py::globals()["var"]["number"].cast<int>() == 42);
0350 }
0351 
0352 TEST_CASE("Threads") {
0353     // Restart interpreter to ensure threads are not initialized
0354     py::finalize_interpreter();
0355     py::initialize_interpreter();
0356     REQUIRE_FALSE(has_pybind11_internals_static());
0357 
0358     constexpr auto num_threads = 10;
0359     auto locals = py::dict("count"_a = 0);
0360 
0361     {
0362         py::gil_scoped_release gil_release{};
0363 
0364         auto threads = std::vector<std::thread>();
0365         for (auto i = 0; i < num_threads; ++i) {
0366             threads.emplace_back([&]() {
0367                 py::gil_scoped_acquire gil{};
0368                 locals["count"] = locals["count"].cast<int>() + 1;
0369             });
0370         }
0371 
0372         for (auto &thread : threads) {
0373             thread.join();
0374         }
0375     }
0376 
0377     REQUIRE(locals["count"].cast<int>() == num_threads);
0378 }
0379 
// Scope exit utility https://stackoverflow.com/a/36644501/7255855
/// Runs the stored callable, if any, when the guard is destroyed (including
/// during stack unwinding).
///
/// Copying is deleted. A copy would hold the same callable and run the cleanup a
/// second time when it is destroyed. Since no move operations are declared, the
/// guard is not movable either.
struct scope_exit {
    std::function<void()> f_;
    explicit scope_exit(std::function<void()> f) noexcept : f_(std::move(f)) {}
    scope_exit(const scope_exit &) = delete;
    scope_exit &operator=(const scope_exit &) = delete;
    ~scope_exit() {
        if (f_) {
            f_();
        }
    }
};
0390 
0391 TEST_CASE("Reload module from file") {
0392     // Disable generation of cached bytecode (.pyc files) for this test, otherwise
0393     // Python might pick up an old version from the cache instead of the new versions
0394     // of the .py files generated below
0395     auto sys = py::module_::import("sys");
0396     bool dont_write_bytecode = sys.attr("dont_write_bytecode").cast<bool>();
0397     sys.attr("dont_write_bytecode") = true;
0398     // Reset the value at scope exit
0399     scope_exit reset_dont_write_bytecode(
0400         [&]() { sys.attr("dont_write_bytecode") = dont_write_bytecode; });
0401 
0402     std::string module_name = "test_module_reload";
0403     std::string module_file = module_name + ".py";
0404 
0405     // Create the module .py file
0406     std::ofstream test_module(module_file);
0407     test_module << "def test():\n";
0408     test_module << "    return 1\n";
0409     test_module.close();
0410     // Delete the file at scope exit
0411     scope_exit delete_module_file([&]() { std::remove(module_file.c_str()); });
0412 
0413     // Import the module from file
0414     auto module_ = py::module_::import(module_name.c_str());
0415     int result = module_.attr("test")().cast<int>();
0416     REQUIRE(result == 1);
0417 
0418     // Update the module .py file with a small change
0419     test_module.open(module_file);
0420     test_module << "def test():\n";
0421     test_module << "    return 2\n";
0422     test_module.close();
0423 
0424     // Reload the module
0425     module_.reload();
0426     result = module_.attr("test")().cast<int>();
0427     REQUIRE(result == 2);
0428 }
0429 
0430 TEST_CASE("sys.argv gets initialized properly") {
0431     py::finalize_interpreter();
0432     {
0433         py::scoped_interpreter default_scope;
0434         auto module = py::module::import("test_interpreter");
0435         auto py_widget = module.attr("DerivedWidget")("The question");
0436         const auto &cpp_widget = py_widget.cast<const Widget &>();
0437         REQUIRE(cpp_widget.argv0().empty());
0438     }
0439 
0440     {
0441         char *argv[] = {strdup("a.out")};
0442         py::scoped_interpreter argv_scope(true, 1, argv);
0443         std::free(argv[0]);
0444         auto module = py::module::import("test_interpreter");
0445         auto py_widget = module.attr("DerivedWidget")("The question");
0446         const auto &cpp_widget = py_widget.cast<const Widget &>();
0447         REQUIRE(cpp_widget.argv0() == "a.out");
0448     }
0449     py::initialize_interpreter();
0450 }
0451 
0452 TEST_CASE("make_iterator can be called before then after finalizing an interpreter") {
0453     // Reproduction of issue #2101 (https://github.com/pybind/pybind11/issues/2101)
0454     py::finalize_interpreter();
0455 
0456     std::vector<int> container;
0457     {
0458         pybind11::scoped_interpreter g;
0459         auto iter = pybind11::make_iterator(container.begin(), container.end());
0460     }
0461 
0462     REQUIRE_NOTHROW([&]() {
0463         pybind11::scoped_interpreter g;
0464         auto iter = pybind11::make_iterator(container.begin(), container.end());
0465     }());
0466 
0467     py::initialize_interpreter();
0468 }