From a4ac6693996ec6049b61e9e61a0b34139838647c Mon Sep 17 00:00:00 2001 From: Guilhem Lavaux Date: Fri, 22 Mar 2024 17:13:15 +0100 Subject: [PATCH] Fix shared pointer support with xbuffer_adaptor, and allow it to move shared memory between xtensor_adaptor. --- include/xtensor/xbuffer_adaptor.hpp | 2 +- test/test_xbuffer_adaptor.cpp | 23 ++++++++++++ test/test_xtensor_adaptor.cpp | 58 +++++++++++++++++++++++++++++ 3 files changed, 82 insertions(+), 1 deletion(-) diff --git a/include/xtensor/xbuffer_adaptor.hpp b/include/xtensor/xbuffer_adaptor.hpp index 368da5867..e131fc8ce 100644 --- a/include/xtensor/xbuffer_adaptor.hpp +++ b/include/xtensor/xbuffer_adaptor.hpp @@ -111,7 +111,7 @@ namespace xt using size_type = typename allocator_traits::size_type; using difference_type = typename allocator_traits::difference_type; - xbuffer_smart_pointer(); + xbuffer_smart_pointer() = default; template xbuffer_smart_pointer(P&& data_ptr, size_type size, DT&& destruct); diff --git a/test/test_xbuffer_adaptor.cpp b/test/test_xbuffer_adaptor.cpp index f58e54620..8989fb82b 100644 --- a/test/test_xbuffer_adaptor.cpp +++ b/test/test_xbuffer_adaptor.cpp @@ -7,6 +7,8 @@ * The full license is in the file LICENSE, distributed with this software. * ****************************************************************************/ +#include + #include "xtensor/xbuffer_adaptor.hpp" #include "test_common_macros.hpp" @@ -216,4 +218,25 @@ namespace xt delete[] data; } + + TEST(xbuffer_adaptor, shared_owner) + { + using T = double; + using xbuf = xbuffer_adaptor>; + size_t size = 100; + auto data = std::shared_ptr(new T[size]); + + xbuf adapt(data.get(), size, data); + xbuf adapt2(adapt); + EXPECT_EQ(adapt.size(), adapt2.size()); + EXPECT_EQ(adapt.data(), adapt2.data()); + + xbuf adapt3(std::move(adapt2)); + EXPECT_EQ(adapt.size(), adapt3.size()); + EXPECT_EQ(adapt.data(), adapt3.data()); + + size_t size2 = 50; + XT_EXPECT_THROW(adapt.resize(size2), std::runtime_error); + XT_EXPECT_NO_THROW(adapt.resize(size)); + } } diff --git a/test/test_xtensor_adaptor.cpp b/test/test_xtensor_adaptor.cpp index 7cc404552..5655d6942 100644 --- a/test/test_xtensor_adaptor.cpp +++ b/test/test_xtensor_adaptor.cpp @@ -201,4 +201,62 @@ namespace xt test_iterator_types(); test_iterator_types(); } + + namespace xt_shared + { + template + using xtensor_buffer = xtensor_adaptor>, N, L>; + } + + TEST(xtensor_shared_buffer, share_shared_pointer) + { + using T = double; + using xtensor_type = xt_shared::xtensor_buffer; + using storage_type = xtensor_type::storage_type; + using inner_shape_type = typename xtensor_type::inner_shape_type; + using inner_strides_type = typename xtensor_type::inner_strides_type; + auto shape = xtensor_type::shape_type{100}; + auto size = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies()); + auto data = std::shared_ptr(new T[size]); + auto data2 = std::shared_ptr(new T[size]); + inner_shape_type inner_shape = shape; + inner_shape_type inner_shape2 = shape; + inner_strides_type inner_strides; + xt::compute_strides(inner_shape, XTENSOR_DEFAULT_LAYOUT, inner_strides); + + // Create storage with shared ownership + storage_type s(data.get(), size, data); + storage_type s2(data.get(), size, data); + // s3 has no shared ownership + storage_type s3(data2.get(), size, data); + // Now create the respective tensors + xtensor_type x(std::move(s), inner_shape_type(shape), inner_strides_type(inner_strides)); + xtensor_type x2(std::move(s2), inner_shape_type(shape), inner_strides_type(inner_strides)); + xtensor_type x3(std::move(s3), inner_shape_type(shape), inner_strides_type(inner_strides)); + + // Initialize both tensors (x & x3) to zero + x = xt::broadcast(double(0), {size}); + x3 = xt::broadcast(double(0), {size}); + + // Assign another shared memory tensor to x shared memory + xtensor_type y = x; + + // Modify the value in x + x(0) = 1.0; + + // Use x2 now + y = x2; + + // We do get 1.0 in y, because x and y share the same memory + EXPECT_EQ(y(0), 1.0); + + // Now assign x3 memory to y (i.e. unshare it with x) + y = x3; + // Change value in x2 + x2(0) = 2.0; + // We do not get 2.0 in y, because x2 and y do not share the same memory + EXPECT_EQ(y(0), 0.0); + // We do get 2.0 in x2 + EXPECT_EQ(x2(0), 2.0); + } }