Skip to content

Commit

Permalink
Update benchmark Python bindings for nanobind 2.0, and update to nano…
Browse files Browse the repository at this point in the history
…bind 2.0. (#1817)

Incorporates the nanobind_bazel change from #1795.

nanobind 2.0 reworked the nanobind::enum_ class so it uses a real Python enum or intenum rather than its previous hand-rolled implementation.
https://nanobind.readthedocs.io/en/latest/changelog.html#version-2-0-0-may-23-2024

As a consequence of that change, nanobind now checks when casting an integer to a enum value that the integer corresponds to a valid enum. Counter::Flags is a bitmask, and many combinations are not valid enum members.

This change:
a) sets nb::is_arithmetic(), which means Counter::Flags becomes an IntEnum that can be freely cast to an integer.
b) defines the | operator for flags to return an integer, not an enum, avoiding the error.
c) changes Counter's constructor to accept an int, not a Counter::Flags enum. Since Counter::Flags is an IntEnum now, it can be freely coerced to an int.

If wjakob/nanobind#599 is merged into nanobind, then we can perhaps use a flag enum here instead.
  • Loading branch information
hawkinsp authored Jul 18, 2024
1 parent a6ad7fb commit 64b5d8c
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 7 deletions.
2 changes: 1 addition & 1 deletion MODULE.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -38,4 +38,4 @@ use_repo(pip, "tools_pip_deps")

# -- bazel_dep definitions -- #

bazel_dep(name = "nanobind_bazel", version = "1.0.0", dev_dependency = True)
bazel_dep(name = "nanobind_bazel", version = "2.0.0", dev_dependency = True)
19 changes: 13 additions & 6 deletions bindings/python/google_benchmark/benchmark.cc
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ NB_MODULE(_benchmark, m) {
using benchmark::Counter;
nb::class_<Counter> py_counter(m, "Counter");

nb::enum_<Counter::Flags>(py_counter, "Flags")
nb::enum_<Counter::Flags>(py_counter, "Flags", nb::is_arithmetic())
.value("kDefaults", Counter::Flags::kDefaults)
.value("kIsRate", Counter::Flags::kIsRate)
.value("kAvgThreads", Counter::Flags::kAvgThreads)
Expand All @@ -130,18 +130,25 @@ NB_MODULE(_benchmark, m) {
.value("kAvgIterationsRate", Counter::Flags::kAvgIterationsRate)
.value("kInvert", Counter::Flags::kInvert)
.export_values()
.def(nb::self | nb::self);
.def("__or__", [](Counter::Flags a, Counter::Flags b) {
return static_cast<int>(a) | static_cast<int>(b);
});

nb::enum_<Counter::OneK>(py_counter, "OneK")
.value("kIs1000", Counter::OneK::kIs1000)
.value("kIs1024", Counter::OneK::kIs1024)
.export_values();

py_counter
.def(nb::init<double, Counter::Flags, Counter::OneK>(),
nb::arg("value") = 0., nb::arg("flags") = Counter::kDefaults,
nb::arg("k") = Counter::kIs1000)
.def("__init__", ([](Counter *c, double value) { new (c) Counter(value); }))
.def(
"__init__",
[](Counter* c, double value, int flags, Counter::OneK oneK) {
new (c) Counter(value, static_cast<Counter::Flags>(flags), oneK);
},
nb::arg("value") = 0., nb::arg("flags") = Counter::kDefaults,
nb::arg("k") = Counter::kIs1000)
.def("__init__",
([](Counter* c, double value) { new (c) Counter(value); }))
.def_rw("value", &Counter::value)
.def_rw("flags", &Counter::flags)
.def_rw("oneK", &Counter::oneK)
Expand Down

0 comments on commit 64b5d8c

Please sign in to comment.