diff --git a/oshmem/mca/sshmem/ucx/configure.m4 b/oshmem/mca/sshmem/ucx/configure.m4 index 3b88aaf5d93..cc938075003 100644 --- a/oshmem/mca/sshmem/ucx/configure.m4 +++ b/oshmem/mca/sshmem/ucx/configure.m4 @@ -29,10 +29,27 @@ AC_DEFUN([MCA_oshmem_sshmem_ucx_CONFIG],[ save_CPPFLAGS="$CPPFLAGS" alloc_dm_LDFLAGS=" -L$ompi_check_ucx_libdir/ucx" - alloc_dm_LIBS=" -luct_ib" CPPFLAGS+=" $sshmem_ucx_CPPFLAGS" LDFLAGS+=" $sshmem_ucx_LDFLAGS $alloc_dm_LDFLAGS" - LIBS+=" $sshmem_ucx_LIBS $alloc_dm_LIBS" + LIBS+=" $sshmem_ucx_LIBS" + + AC_LANG_PUSH([C]) + AC_LINK_IFELSE([AC_LANG_PROGRAM( + [[ + #include + ]], + [[ + ucs_memory_type_t mem_type = UCS_MEMORY_TYPE_RDMA; + ]])], + [ + AC_MSG_NOTICE([UCX device memory allocation is supported]) + AC_DEFINE([HAVE_UCX_DEVICE_MEM], [1], [Support for device memory allocation]) + ], + [ + AC_MSG_NOTICE([UCX device memory allocation is not supported]) + AC_DEFINE([HAVE_UCX_DEVICE_MEM], [0], [Support for device memory allocation]) + ]) + AC_LANG_POP([C]) CPPFLAGS="$save_CPPFLAGS" LDFLAGS="$save_LDFLAGS" diff --git a/oshmem/mca/sshmem/ucx/sshmem_ucx.h b/oshmem/mca/sshmem/ucx/sshmem_ucx.h index 9b308daa196..90d41ac002c 100644 --- a/oshmem/mca/sshmem/ucx/sshmem_ucx.h +++ b/oshmem/mca/sshmem/ucx/sshmem_ucx.h @@ -16,7 +16,6 @@ #include "oshmem/mca/sshmem/sshmem.h" #include -#include BEGIN_C_DECLS diff --git a/oshmem/mca/sshmem/ucx/sshmem_ucx_module.c b/oshmem/mca/sshmem/ucx/sshmem_ucx_module.c index 1af1fa30137..f44db04a80d 100644 --- a/oshmem/mca/sshmem/ucx/sshmem_ucx_module.c +++ b/oshmem/mca/sshmem/ucx/sshmem_ucx_module.c @@ -97,7 +97,7 @@ static segment_allocator_t sshmem_ucx_allocator = { static int segment_create_internal(map_segment_t *ds_buf, void *address, size_t size, - unsigned flags, long hint) + unsigned flags, ucs_memory_type_t mem_type) { mca_sshmem_ucx_segment_context_t *ctx; int rc = OSHMEM_SUCCESS; @@ -119,8 +119,7 @@ segment_create_internal(map_segment_t *ds_buf, void *address, size_t size, mem_map_params.address = address; mem_map_params.length = size; mem_map_params.flags = flags; - mem_map_params.memory_type = (hint & SHMEM_HINT_DEVICE_NIC_MEM) ? - UCS_MEMORY_TYPE_RDMA : UCS_MEMORY_TYPE_HOST; + mem_map_params.memory_type = mem_type; status = ucp_mem_map(spml->ucp_context, &mem_map_params, &mem_h); if (UCS_OK != status) { @@ -157,11 +156,7 @@ segment_create_internal(map_segment_t *ds_buf, void *address, size_t size, ds_buf->super.va_end = (void*)((uintptr_t)ds_buf->super.va_base + ds_buf->seg_size); ds_buf->context = ctx; ds_buf->type = MAP_SEGMENT_ALLOC_UCX; - ds_buf->alloc_hints = hint; ctx->ucp_memh = mem_h; - if (hint) { - ds_buf->allocator = &sshmem_ucx_allocator; - } out: OPAL_OUTPUT_VERBOSE( @@ -176,19 +171,39 @@ segment_create_internal(map_segment_t *ds_buf, void *address, size_t size, return rc; } +static int +segment_create_host_mem(map_segment_t *ds_buf, size_t size, unsigned flags) { + mca_spml_ucx_t *spml = (mca_spml_ucx_t*)mca_spml.self; + flags |= UCP_MEM_MAP_FIXED; + if (spml->heap_reg_nb) { + flags |= UCP_MEM_MAP_NONBLOCK; + } + return segment_create_internal(ds_buf, mca_sshmem_base_start_address, size, flags, UCS_MEMORY_TYPE_HOST); +} + static int segment_create(map_segment_t *ds_buf, const char *file_name, size_t size, long hint) { - mca_spml_ucx_t *spml = (mca_spml_ucx_t*)mca_spml.self; - unsigned flags = UCP_MEM_MAP_ALLOCATE | (spml->heap_reg_nb ? UCP_MEM_MAP_NONBLOCK : 0); - if (hint) { - return segment_create_internal(ds_buf, NULL, size, flags, hint); - } else { - return segment_create_internal(ds_buf, mca_sshmem_base_start_address, - size, flags | UCP_MEM_MAP_FIXED, hint); + unsigned flags = UCP_MEM_MAP_ALLOCATE; + int status = OSHMEM_SUCCESS; + +#if HAVE_UCX_DEVICE_MEM + if (hint & SHMEM_HINT_DEVICE_NIC_MEM) { + status = segment_create_internal(ds_buf, NULL, size, flags, UCS_MEMORY_TYPE_RDMA); + if (status == OSHMEM_SUCCESS) { + ds_buf->alloc_hints = hint; + if (hint) { + ds_buf->allocator = &sshmem_ucx_allocator; + } + return OSHMEM_SUCCESS; + } + /* Fallback - Try again using host memory*/ } +#endif + + return segment_create_host_mem(ds_buf, size, flags); } static void *