Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-01-18 10:17:56

0001 /*
0002     tests/test_operator_overloading.cpp -- operator overloading
0003 
0004     Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>
0005 
0006     All rights reserved. Use of this source code is governed by a
0007     BSD-style license that can be found in the LICENSE file.
0008 */
0009 
0010 #include <pybind11/operators.h>
0011 #include <pybind11/stl.h>
0012 
0013 #include "constructor_stats.h"
0014 #include "pybind11_tests.h"
0015 
0016 #include <functional>
0017 
0018 class Vector2 {
0019 public:
0020     Vector2(float x, float y) : x(x), y(y) { print_created(this, toString()); }
0021     Vector2(const Vector2 &v) : x(v.x), y(v.y) { print_copy_created(this); }
0022     Vector2(Vector2 &&v) noexcept : x(v.x), y(v.y) {
0023         print_move_created(this);
0024         v.x = v.y = 0;
0025     }
0026     Vector2 &operator=(const Vector2 &v) {
0027         x = v.x;
0028         y = v.y;
0029         print_copy_assigned(this);
0030         return *this;
0031     }
0032     Vector2 &operator=(Vector2 &&v) noexcept {
0033         x = v.x;
0034         y = v.y;
0035         v.x = v.y = 0;
0036         print_move_assigned(this);
0037         return *this;
0038     }
0039     ~Vector2() { print_destroyed(this); }
0040 
0041     std::string toString() const {
0042         return "[" + std::to_string(x) + ", " + std::to_string(y) + "]";
0043     }
0044 
0045     Vector2 operator-() const { return Vector2(-x, -y); }
0046     Vector2 operator+(const Vector2 &v) const { return Vector2(x + v.x, y + v.y); }
0047     Vector2 operator-(const Vector2 &v) const { return Vector2(x - v.x, y - v.y); }
0048     Vector2 operator-(float value) const { return Vector2(x - value, y - value); }
0049     Vector2 operator+(float value) const { return Vector2(x + value, y + value); }
0050     Vector2 operator*(float value) const { return Vector2(x * value, y * value); }
0051     Vector2 operator/(float value) const { return Vector2(x / value, y / value); }
0052     Vector2 operator*(const Vector2 &v) const { return Vector2(x * v.x, y * v.y); }
0053     Vector2 operator/(const Vector2 &v) const { return Vector2(x / v.x, y / v.y); }
0054     Vector2 &operator+=(const Vector2 &v) {
0055         x += v.x;
0056         y += v.y;
0057         return *this;
0058     }
0059     Vector2 &operator-=(const Vector2 &v) {
0060         x -= v.x;
0061         y -= v.y;
0062         return *this;
0063     }
0064     Vector2 &operator*=(float v) {
0065         x *= v;
0066         y *= v;
0067         return *this;
0068     }
0069     Vector2 &operator/=(float v) {
0070         x /= v;
0071         y /= v;
0072         return *this;
0073     }
0074     Vector2 &operator*=(const Vector2 &v) {
0075         x *= v.x;
0076         y *= v.y;
0077         return *this;
0078     }
0079     Vector2 &operator/=(const Vector2 &v) {
0080         x /= v.x;
0081         y /= v.y;
0082         return *this;
0083     }
0084 
0085     friend Vector2 operator+(float f, const Vector2 &v) { return Vector2(f + v.x, f + v.y); }
0086     friend Vector2 operator-(float f, const Vector2 &v) { return Vector2(f - v.x, f - v.y); }
0087     friend Vector2 operator*(float f, const Vector2 &v) { return Vector2(f * v.x, f * v.y); }
0088     friend Vector2 operator/(float f, const Vector2 &v) { return Vector2(f / v.x, f / v.y); }
0089 
0090     bool operator==(const Vector2 &v) const { return x == v.x && y == v.y; }
0091     bool operator!=(const Vector2 &v) const { return x != v.x || y != v.y; }
0092 
0093 private:
0094     float x, y;
0095 };
0096 
// Empty marker types: the only thing that distinguishes them is which
// `operator+` overload they select.
class C1 {};
class C2 {};

// The result encodes the operand order: the tens digit names the left
// operand's class and the units digit names the right operand's class
// (C1 -> 1, C2 -> 2).
int operator+(const C1 & /*lhs*/, const C1 & /*rhs*/) { return 11; }
int operator+(const C1 & /*lhs*/, const C2 & /*rhs*/) { return 12; }
int operator+(const C2 & /*lhs*/, const C1 & /*rhs*/) { return 21; }
int operator+(const C2 & /*lhs*/, const C2 & /*rhs*/) { return 22; }
0104 
/// Value type whose identity is its `member` string.
struct HashMe {
    std::string member;
};

/// Two HashMe objects are equal exactly when their `member` strings are equal.
bool operator==(const HashMe &lhs, const HashMe &rhs) {
    const std::string &left = lhs.member;
    const std::string &right = rhs.member;
    return left.compare(right) == 0;
}
0110 
0111 // Note: Specializing explicit within `namespace std { ... }` is done due to a
0112 // bug in GCC<7. If you are supporting compilers later than this, consider
0113 // specializing `using template<> struct std::hash<...>` in the global
0114 // namespace instead, per this recommendation:
0115 // https://en.cppreference.com/w/cpp/language/extending_std#Adding_template_specializations
0116 namespace std {
0117 template <>
0118 struct hash<Vector2> {
0119     // Not a good hash function, but easy to test
0120     size_t operator()(const Vector2 &) { return 4; }
0121 };
0122 
0123 // HashMe has a hash function in C++ but no `__hash__` for Python.
0124 template <>
0125 struct hash<HashMe> {
0126     std::size_t operator()(const HashMe &selector) const {
0127         return std::hash<std::string>()(selector.member);
0128     }
0129 };
0130 } // namespace std
0131 
0132 // Not a good abs function, but easy to test.
0133 std::string abs(const Vector2 &) { return "abs(Vector2)"; }
0134 
// clang 7.0.0 and Apple LLVM 10.0.1 add `-Wself-assign-overloaded` to `-Wall`.
// That warning fires on the overloaded self-assignment expressions this file
// uses to bind in-place operators (e.g. `py::self += py::self`), so it is
// suppressed here: from Apple clang major version 10, or from upstream clang
// major version 7.
// Taken from: https://github.com/RobotLocomotion/drake/commit/aaf84b46
// TODO(eric): This could be resolved using a function / functor (e.g. `py::self()`).
#if defined(__clang__)                                                                            \
    && ((defined(__APPLE__) && __clang_major__ >= 10)                                             \
        || (!defined(__APPLE__) && __clang_major__ >= 7))
PYBIND11_WARNING_DISABLE_CLANG("-Wself-assign-overloaded")
#endif
0149 
0150 TEST_SUBMODULE(operators, m) {
0151 
0152     // test_operator_overloading
0153     py::class_<Vector2>(m, "Vector2")
0154         .def(py::init<float, float>())
0155         .def(py::self + py::self)
0156         .def(py::self + float())
0157         .def(py::self - py::self)
0158         .def(py::self - float())
0159         .def(py::self * float())
0160         .def(py::self / float())
0161         .def(py::self * py::self)
0162         .def(py::self / py::self)
0163         .def(py::self += py::self)
0164         .def(py::self -= py::self)
0165         .def(py::self *= float())
0166         .def(py::self /= float())
0167         .def(py::self *= py::self)
0168         .def(py::self /= py::self)
0169         .def(float() + py::self)
0170         .def(float() - py::self)
0171         .def(float() * py::self)
0172         .def(float() / py::self)
0173         .def(-py::self)
0174         .def("__str__", &Vector2::toString)
0175         .def("__repr__", &Vector2::toString)
0176         .def(py::self == py::self)
0177         .def(py::self != py::self)
0178         .def(py::hash(py::self))
0179         // N.B. See warning about usage of `py::detail::abs(py::self)` in
0180         // `operators.h`.
0181         .def("__abs__", [](const Vector2 &v) { return abs(v); });
0182 
0183     m.attr("Vector") = m.attr("Vector2");
0184 
0185     // test_operators_notimplemented
0186     // #393: need to return NotSupported to ensure correct arithmetic operator behavior
0187     py::class_<C1>(m, "C1").def(py::init<>()).def(py::self + py::self);
0188 
0189     py::class_<C2>(m, "C2")
0190         .def(py::init<>())
0191         .def(py::self + py::self)
0192         .def("__add__", [](const C2 &c2, const C1 &c1) { return c2 + c1; })
0193         .def("__radd__", [](const C2 &c2, const C1 &c1) { return c1 + c2; });
0194 
0195     // test_nested
0196     // #328: first member in a class can't be used in operators
0197     struct NestABase {
0198         int value = -2;
0199     };
0200     py::class_<NestABase>(m, "NestABase")
0201         .def(py::init<>())
0202         .def_readwrite("value", &NestABase::value);
0203 
0204     struct NestA : NestABase {
0205         int value = 3;
0206         NestA &operator+=(int i) {
0207             value += i;
0208             return *this;
0209         }
0210     };
0211     py::class_<NestA>(m, "NestA")
0212         .def(py::init<>())
0213         .def(py::self += int())
0214         .def(
0215             "as_base",
0216             [](NestA &a) -> NestABase & { return (NestABase &) a; },
0217             py::return_value_policy::reference_internal);
0218     m.def("get_NestA", [](const NestA &a) { return a.value; });
0219 
0220     struct NestB {
0221         NestA a;
0222         int value = 4;
0223         NestB &operator-=(int i) {
0224             value -= i;
0225             return *this;
0226         }
0227     };
0228     py::class_<NestB>(m, "NestB")
0229         .def(py::init<>())
0230         .def(py::self -= int())
0231         .def_readwrite("a", &NestB::a);
0232     m.def("get_NestB", [](const NestB &b) { return b.value; });
0233 
0234     struct NestC {
0235         NestB b;
0236         int value = 5;
0237         NestC &operator*=(int i) {
0238             value *= i;
0239             return *this;
0240         }
0241     };
0242     py::class_<NestC>(m, "NestC")
0243         .def(py::init<>())
0244         .def(py::self *= int())
0245         .def_readwrite("b", &NestC::b);
0246     m.def("get_NestC", [](const NestC &c) { return c.value; });
0247 
0248     // test_overriding_eq_reset_hash
0249     // #2191 Overriding __eq__ should set __hash__ to None
0250     struct Comparable {
0251         int value;
0252         bool operator==(const Comparable &rhs) const { return value == rhs.value; }
0253     };
0254 
0255     struct Hashable : Comparable {
0256         explicit Hashable(int value) : Comparable{value} {};
0257         size_t hash() const { return static_cast<size_t>(value); }
0258     };
0259 
0260     struct Hashable2 : Hashable {
0261         using Hashable::Hashable;
0262     };
0263 
0264     py::class_<Comparable>(m, "Comparable").def(py::init<int>()).def(py::self == py::self);
0265 
0266     py::class_<Hashable>(m, "Hashable")
0267         .def(py::init<int>())
0268         .def(py::self == py::self)
0269         .def("__hash__", &Hashable::hash);
0270 
0271     // define __hash__ before __eq__
0272     py::class_<Hashable2>(m, "Hashable2")
0273         .def("__hash__", &Hashable::hash)
0274         .def(py::init<int>())
0275         .def(py::self == py::self);
0276 
0277     // define __eq__ but not __hash__
0278     py::class_<HashMe>(m, "HashMe").def(py::self == py::self);
0279 
0280     m.def("get_unhashable_HashMe_set", []() { return std::unordered_set<HashMe>{{"one"}}; });
0281 }