Search visitor (#31)
* Improved search visitor handling.
* Added approximate versions of SearchNn, SearchRadius, and search_radius.
* Added support for Eigen::Map<const Eigen::Matrix<>>.
* Added RKdTree to pico_understory.
* Added the mnist example.
* Version bump.
Jaybro committed Aug 10, 2023
1 parent 9d83fa7 commit 9ac48f3
Showing 25 changed files with 929 additions and 230 deletions.
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -6,7 +6,7 @@ include(${CMAKE_CURRENT_SOURCE_DIR}/cmake/utils.cmake)

project(pico_tree
LANGUAGES CXX
VERSION 0.8.0
VERSION 0.8.1
DESCRIPTION "PicoTree is a C++ header only library for fast nearest neighbor searches and range searches using a KdTree."
HOMEPAGE_URL "https://github.com/Jaybro/pico_tree")

3 changes: 2 additions & 1 deletion README.md
@@ -11,7 +11,7 @@ PicoTree is a C++ header only library with [Python bindings](https://github.com/
| [Scikit-learn KDTree][skkd] 1.2.2 | ... | 6.2s | ... | 42.2s |
| [pykdtree][pykd] 1.3.7 | ... | 1.0s | ... | 6.6s |
| [OpenCV FLANN][cvfn] 4.6.0 | 1.9s | ... | 4.7s | ... |
| PicoTree KdTree v0.8.0 | 0.9s | 1.0s | 2.8s | 3.1s |
| PicoTree KdTree v0.8.1 | 0.9s | 1.0s | 2.8s | 3.1s |

Two [LiDAR](./docs/benchmark.md) based point clouds of sizes 7733372 and 7200863 were used to generate these numbers. The first point cloud was the input to the build algorithm and the second to the query algorithm. All benchmarks were run on a single thread with the following parameters: `max_leaf_size=10` and `knn=1`. A more detailed [C++ comparison](./docs/benchmark.md) of PicoTree is available with respect to [nanoflann][nano].

@@ -61,6 +61,7 @@ PicoTree can interface with different types of points and point sets through tra
* Creating a [custom search visitor](./examples/kd_tree/kd_tree_custom_search_visitor.cpp).
* [Saving and loading](./examples/kd_tree/kd_tree_save_and_load.cpp) a KdTree to and from a file.
* Support for [Eigen](./examples/eigen/eigen.cpp) and [OpenCV](./examples/opencv/opencv.cpp) data types.
* Running the KdTree on the [MNIST](./examples/mnist/mnist.cpp) [database](http://yann.lecun.com/exdb/mnist/).
* How to use the [KdTree with Python](./examples/python/kd_tree.py).

# Requirements
4 changes: 4 additions & 0 deletions examples/CMakeLists.txt
@@ -35,6 +35,10 @@ else()
message(STATUS "benchmark not found. PicoTree benchmarks skipped.")
endif()

if(Eigen3_FOUND)
add_subdirectory(mnist)
endif()

# The Python examples only get copied when the bindings module will be built.
if(TARGET _pyco_tree)
add_subdirectory(python)
5 changes: 3 additions & 2 deletions examples/benchmark/bm_opencv_flann.cpp
@@ -116,9 +116,10 @@ BENCHMARK_DEFINE_F(BmOpenCvFlann, KnnCt)(benchmark::State& state) {
// There is also the option to query them all at once, but this doesn't really
// change performance and this version looks more like the other benchmarks.
for (auto _ : state) {
std::vector<Index> indices(knn_count);
// The only supported index type is int.
std::vector<int> indices(knn_count);
std::vector<Scalar> distances(knn_count);
fl::Matrix<Index> mat_indices(indices.data(), 1, knn_count);
fl::Matrix<int> mat_indices(indices.data(), 1, knn_count);
fl::Matrix<Scalar> mat_distances(distances.data(), 1, knn_count);

for (auto& p : points_test_) {
33 changes: 16 additions & 17 deletions examples/kd_tree/kd_tree_custom_search_visitor.cpp
@@ -3,40 +3,39 @@
#include <pico_tree/kd_tree.hpp>
#include <pico_tree/vector_traits.hpp>

//! \brief Search visitor that counts how many points were considered as a
//! nearest neighbor.
// Search visitor that counts how many points were considered as a possible
// nearest neighbor.
template <typename Neighbor>
class SearchNnCounter {
public:
using NeighborType = Neighbor;
using IndexType = typename Neighbor::IndexType;
using ScalarType = typename Neighbor::ScalarType;

//! \brief Creates a visitor for approximate nearest neighbor searching.
//! \param nn Search result.
// Create a visitor for approximate nearest neighbor searching. The argument
// is the search result.
inline SearchNnCounter(Neighbor& nn) : count_(0), nn_(nn) {
// Initial search distance.
nn_.distance = std::numeric_limits<ScalarType>::max();
}

//! \brief Visit current point.
//! \details This method is required. The KdTree calls this function when it
//! finds a point that is closer to the query than the result of this
//! visitors' max() function. I.e., it found a new nearest neighbor.
//! \param idx Point index.
//! \param d Point distance (that depends on the metric).
// Visit current point. This method is required. The search algorithm calls
// this function for every point it encounters in the KdTree. The arguments of
// the method are respectively the index and distance of the visited point.
inline void operator()(IndexType const idx, ScalarType const dst) {
// Only update the nearest neighbor when the point we visit is actually
// closer to the query point.
if (max() > dst) {
nn_ = {idx, dst};
}
count_++;
nn_ = {idx, dst};
}

//! \brief Maximum search distance with respect to the query point.
//! \details This method is required.
// Maximum search distance with respect to the query point. This method is
// required. The nodes of the KdTree are filtered using this method.
inline ScalarType const& max() const { return nn_.distance; }

//! \brief Returns the number of points that were considered the nearest
//! neighbor.
//! \details This method is not required.
// The number of points visited during a query.
inline IndexType const& count() const { return count_; }

private:
@@ -62,7 +61,7 @@ int main() {
SearchNnCounter<Neighbor> v(nn);
tree.SearchNearest(q, v);

std::cout << "Custom visitor # nns considered: " << v.count() << std::endl;
std::cout << "Number of points visited: " << v.count() << std::endl;

return 0;
}
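
The visitor contract described in the comments above (a call operator that receives every visited point, plus a `max()` that reports the current search distance) can be tried in isolation. What follows is a minimal sketch under stated assumptions, not PicoTree code: the `Neighbor` struct, the `NnCounter` class, and the brute-force `VisitAll` driver are hypothetical stand-ins so that the example compiles without the library. A real KdTree search would additionally use `max()` to skip nodes that are too far away.

```cpp
#include <cassert>
#include <cstddef>
#include <limits>
#include <vector>

// Hypothetical stand-in for a neighbor type: an index and a distance.
struct Neighbor {
  int index;
  float distance;
};

// Visitor that keeps the closest point seen so far and counts every visit.
class NnCounter {
 public:
  explicit NnCounter(Neighbor& nn) : count_(0), nn_(nn) {
    nn_.index = -1;
    // Initial search distance.
    nn_.distance = std::numeric_limits<float>::max();
  }

  // Called for every visited point. Only a strictly closer point replaces
  // the stored result, but every call is counted.
  void operator()(int idx, float dst) {
    if (max() > dst) {
      nn_ = {idx, dst};
    }
    ++count_;
  }

  // Current search distance with respect to the query point.
  float max() const { return nn_.distance; }

  // Number of points visited so far.
  int count() const { return count_; }

 private:
  int count_;
  Neighbor& nn_;
};

// Hypothetical brute-force driver over 1D points. It visits all points and
// passes squared distances to the visitor.
template <typename Visitor>
void VisitAll(std::vector<float> const& points, float query, Visitor& v) {
  assert(!points.empty());
  for (std::size_t i = 0; i < points.size(); ++i) {
    float const d = points[i] - query;
    v(static_cast<int>(i), d * d);
  }
}
```

Calling `VisitAll` with a `NnCounter` leaves the closest point in the referenced `Neighbor`, while `count()` reports how many points were visited. Without the `max() > dst` check, the result would simply be the last point visited.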
3 changes: 3 additions & 0 deletions examples/mnist/CMakeLists.txt
@@ -0,0 +1,3 @@
add_executable(mnist mnist.cpp)
set_default_target_properties(mnist)
target_link_libraries(mnist PUBLIC pico_toolshed pico_understory Eigen3::Eigen)
122 changes: 122 additions & 0 deletions examples/mnist/mnist.cpp
@@ -0,0 +1,122 @@
#include <algorithm>
#include <filesystem>
#include <iostream>
#include <pico_toolshed/format/format_bin.hpp>
#include <pico_toolshed/format/format_mnist.hpp>
#include <pico_toolshed/scoped_timer.hpp>
#include <pico_tree/array_traits.hpp>
#include <pico_tree/kd_tree.hpp>
#include <pico_tree/vector_traits.hpp>
#include <pico_understory/rkd_tree.hpp>

template <typename U, typename T, std::size_t N>
std::array<U, N> Cast(std::array<T, N> const& i) {
std::array<U, N> c;
std::transform(i.begin(), i.end(), c.begin(), [](T a) -> U {
return static_cast<U>(a);
});
return c;
}

template <std::size_t N>
std::vector<std::array<float, N>> Cast(
std::vector<std::array<std::byte, N>> const& i) {
std::vector<std::array<float, N>> c;
std::transform(
i.begin(),
i.end(),
std::back_inserter(c),
[](std::array<std::byte, N> const& a) -> std::array<float, N> {
return Cast<float>(a);
});
return c;
}

int main(int argc, char** argv) {
using ImageByte = std::array<std::byte, 28 * 28>;
using ImageFloat = std::array<float, 28 * 28>;

std::string fn_images_train = "train-images.idx3-ubyte";
std::string fn_images_test = "t10k-images.idx3-ubyte";
std::string fn_mnist_nns_gt = "mnist_nns_gt.bin";

if (!std::filesystem::exists(fn_images_train)) {
std::cout << fn_images_train << " doesn't exist." << std::endl;
return 0;
}

if (!std::filesystem::exists(fn_images_test)) {
std::cout << fn_images_test << " doesn't exist." << std::endl;
return 0;
}

std::vector<ImageFloat> images_train;
{
std::vector<ImageByte> images_train_u8;
pico_tree::ReadMnistImages(fn_images_train, images_train_u8);
images_train = Cast(images_train_u8);
}

std::vector<ImageFloat> images_test;
{
std::vector<ImageByte> images_test_u8;
pico_tree::ReadMnistImages(fn_images_test, images_test_u8);
images_test = Cast(images_test_u8);
}

std::size_t max_leaf_size_ex = 16;
std::size_t max_leaf_size_rp = 128;
// With 16 trees we can get a precision of around 85-90%.
// With 32 trees we can get a precision of around 95-97%.
std::size_t forest_size = 2;
std::size_t count = images_test.size();
std::vector<pico_tree::Neighbor<int, float>> nns(count);

if (!std::filesystem::exists(fn_mnist_nns_gt)) {
auto kd_tree = [&images_train, &max_leaf_size_ex]() {
ScopedTimer t0("kd_tree build");
return pico_tree::KdTree<std::reference_wrapper<std::vector<ImageFloat>>>(
images_train, max_leaf_size_ex);
}();

{
ScopedTimer t1("kd_tree query");
for (std::size_t i = 0; i < nns.size(); ++i) {
kd_tree.SearchNn(images_test[i], nns[i]);
}
}

std::cout << "Writing " << fn_mnist_nns_gt << "." << std::endl;
pico_tree::WriteBin(fn_mnist_nns_gt, nns);
} else {
std::cout << "Reading " << fn_mnist_nns_gt << "." << std::endl;
pico_tree::ReadBin(fn_mnist_nns_gt, nns);
}

std::size_t equal = 0;

{
auto rkd_tree = [&images_train, &max_leaf_size_rp, &forest_size]() {
ScopedTimer t0("rkd_tree build");
return pico_tree::RKdTree<
std::reference_wrapper<std::vector<ImageFloat>>>(
images_train, max_leaf_size_rp, forest_size);
}();

ScopedTimer t1("rkd_tree query");
pico_tree::Neighbor<int, float> nn;
for (std::size_t i = 0; i < nns.size(); ++i) {
rkd_tree.SearchNn(images_test[i], nn);

if (nns[i].index == nn.index) {
++equal;
}
}
}

std::cout << "Precision: "
<< (static_cast<float>(equal) / static_cast<float>(count))
<< std::endl;

return 0;
}
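
The precision printed at the end of the example is the fraction of test queries for which the approximate search returns the same index as the exact search. A minimal sketch of that calculation, separated from the trees, might look as follows. The function name `Precision` and the use of plain index vectors are assumptions made for illustration; the example itself compares `Neighbor` objects inside the query loop.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Fraction of queries for which the approximate result matches the exact
// result. Both vectors hold one nearest neighbor index per query.
inline float Precision(
    std::vector<int> const& exact, std::vector<int> const& approx) {
  assert(exact.size() == approx.size());

  // Avoid dividing by zero when there are no queries.
  if (exact.empty()) {
    return 0.0f;
  }

  std::size_t equal = 0;
  for (std::size_t i = 0; i < exact.size(); ++i) {
    if (exact[i] == approx[i]) {
      ++equal;
    }
  }

  return static_cast<float>(equal) / static_cast<float>(exact.size());
}
```

This measure only counts exact index matches. An approximate neighbor at nearly the same distance as the true one still counts as a miss.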
14 changes: 12 additions & 2 deletions examples/pico_understory/CMakeLists.txt
@@ -3,6 +3,16 @@ target_include_directories(pico_understory INTERFACE ${CMAKE_CURRENT_LIST_DIR})
target_link_libraries(pico_understory INTERFACE PicoTree::PicoTree)
target_sources(pico_understory
INTERFACE
${CMAKE_CURRENT_LIST_DIR}/pico_understory/cover_tree.hpp
${CMAKE_CURRENT_LIST_DIR}/pico_understory/metric.hpp
${CMAKE_CURRENT_LIST_DIR}/pico_understory/internal/cover_tree_base.hpp
${CMAKE_CURRENT_LIST_DIR}/pico_understory/internal/cover_tree_builder.hpp
${CMAKE_CURRENT_LIST_DIR}/pico_understory/internal/cover_tree_data.hpp
${CMAKE_CURRENT_LIST_DIR}/pico_understory/internal/cover_tree_node.hpp
${CMAKE_CURRENT_LIST_DIR}/pico_understory/internal/cover_tree_search.hpp
${CMAKE_CURRENT_LIST_DIR}/pico_understory/internal/rkd_tree_builder.hpp
${CMAKE_CURRENT_LIST_DIR}/pico_understory/internal/rkd_tree_rr_data.hpp
${CMAKE_CURRENT_LIST_DIR}/pico_understory/internal/rkd_tree_search.hpp
${CMAKE_CURRENT_LIST_DIR}/pico_understory/internal/static_buffer.hpp
${CMAKE_CURRENT_LIST_DIR}/pico_understory/cover_tree.hpp
${CMAKE_CURRENT_LIST_DIR}/pico_understory/metric.hpp
${CMAKE_CURRENT_LIST_DIR}/pico_understory/rkd_tree.hpp
)