From 4e5e81efe10bac88db404a41a1154f49e1068df3 Mon Sep 17 00:00:00 2001
From: Devendar Bureddy
Date: Mon, 2 Oct 2023 13:45:03 -0700
Subject: [PATCH] Reduce-scatter and Allgather APIs

---
 include/net.h        |   1 +
 include/net_v7.h     |  75 +----------------------
 include/net_v8.h     | 134 ++++++++++++++++++++++++++++++++++++++++++
 src/ib_plugin.c      |  22 +++++++
 src/p2p_plugin.c     |  19 ++++++
 src/sharp_plugin.c   | 138 ++++++++++++++++++++++++++++++++++++++++++-
 src/ucx_plugin.c     |  22 +++++++
 src/ucx_rma_plugin.c |  23 +++++++-
 8 files changed, 358 insertions(+), 76 deletions(-)
 create mode 100644 include/net_v8.h

diff --git a/include/net.h b/include/net.h
index 8b656bba..1e60bad8 100644
--- a/include/net.h
+++ b/include/net.h
@@ -22,6 +22,7 @@ typedef enum {NCCL_INIT=1, NCCL_COLL=2, NCCL_P2P=4, NCCL_SHM=8, NCCL_NET=16, NCC
 
 typedef void (*ncclDebugLogger_t)(ncclDebugLogLevel level, unsigned long flags, const char *file, int line, const char *fmt, ...);
 
+#include "net_v8.h"
 #include "net_v7.h"
 #include "net_v6.h"
 #include "net_v5.h"
diff --git a/include/net_v7.h b/include/net_v7.h
index ac5e6654..fc91e8fa 100644
--- a/include/net_v7.h
+++ b/include/net_v7.h
@@ -5,79 +5,8 @@
 #ifndef NCCL_NET_V7_H_
 #define NCCL_NET_V7_H_
 
-#include "net_device.h"
-
-typedef struct {
-  char* name;                      // Used mostly for logging.
-  char* pciPath;                   // Path to the PCI device in /sys.
-  uint64_t guid;                   // Unique identifier for the NIC chip. Important for
-                                   // cards with multiple PCI functions (Physical or virtual).
-  int ptrSupport;                  // [NCCL_PTR_HOST|NCCL_PTR_CUDA|NCCL_PTR_DMABUF]
-  int speed;                       // Port speed in Mbps.
-  int port;                        // Port number.
-  float latency;                   // Network latency
-  int maxComms;                    // Maximum number of comms we can create
-  int maxRecvs;                    // Maximum number of grouped receives.
-  ncclNetDeviceType netDeviceType; // Network offload type
-  int netDeviceVersion;            // Version number for network offload
-} ncclNetProperties_v7_t;
-
-typedef ncclNetProperties_v7_t ncclNetProperties_t;
-
-typedef struct {
-  // Name of the network (mainly for logs)
-  const char* name;
-  // Initialize the network.
-  ncclResult_t (*init)(ncclDebugLogger_t logFunction);
-  // Return the number of adapters.
-  ncclResult_t (*devices)(int* ndev);
-  // Get various device properties.
-  ncclResult_t (*getProperties)(int dev, ncclNetProperties_v7_t* props);
-  // Create a receiving object and provide a handle to connect to it. The
-  // handle can be up to NCCL_NET_HANDLE_MAXSIZE bytes and will be exchanged
-  // between ranks to create a connection.
-  ncclResult_t (*listen)(int dev, void* handle, void** listenComm);
-  // Connect to a handle and return a sending comm object for that peer.
-  // This call must not block for the connection to be established, and instead
-  // should return successfully with sendComm == NULL with the expectation that
-  // it will be called again until sendComm != NULL.
-  // If *sendDevComm points to a valid object, then NCCL is requesting device offload for this connection
-  ncclResult_t (*connect)(int dev, void* handle, void** sendComm, ncclNetDeviceHandle_v7_t** sendDevComm);
-  // Finalize connection establishment after remote peer has called connect.
-  // This call must not block for the connection to be established, and instead
-  // should return successfully with recvComm == NULL with the expectation that
-  // it will be called again until recvComm != NULL.
-  // If *recvDevComm points to a valid object, then NCCL is requesting device offload for this connection
-  ncclResult_t (*accept)(void* listenComm, void** recvComm, ncclNetDeviceHandle_v7_t** recvDevComm);
-  // Register/Deregister memory. Comm can be either a sendComm or a recvComm.
-  // Type is either NCCL_PTR_HOST or NCCL_PTR_CUDA.
-  ncclResult_t (*regMr)(void* comm, void* data, int size, int type, void** mhandle);
-  /* DMA-BUF support */
-  ncclResult_t (*regMrDmaBuf)(void* comm, void* data, size_t size, int type, uint64_t offset, int fd, void** mhandle);
-  ncclResult_t (*deregMr)(void* comm, void* mhandle);
-  // Asynchronous send to a peer.
-  // May return request == NULL if the call cannot be performed (or would block)
-  ncclResult_t (*isend)(void* sendComm, void* data, int size, int tag, void* mhandle, void** request);
-  // Asynchronous recv from a peer.
-  // May return request == NULL if the call cannot be performed (or would block)
-  ncclResult_t (*irecv)(void* recvComm, int n, void** data, int* sizes, int* tags, void** mhandles, void** request);
-  // Perform a flush/fence to make sure all data received with NCCL_PTR_CUDA is
-  // visible to the GPU
-  ncclResult_t (*iflush)(void* recvComm, int n, void** data, int* sizes, void** mhandles, void** request);
-  // Test whether a request is complete. If size is not NULL, it returns the
-  // number of bytes sent/received.
-  ncclResult_t (*test)(void* request, int* done, int* sizes);
-  // Close and free send/recv comm objects
-  ncclResult_t (*closeSend)(void* sendComm);
-  ncclResult_t (*closeRecv)(void* recvComm);
-  ncclResult_t (*closeListen)(void* listenComm);
-
-  // Copy the given mhandle to a dptr in a format usable by this plugin's device code
-  ncclResult_t (*getDeviceMr)(void* comm, void* mhandle, void** dptr_mhandle);
-
-  // Notify the plugin that a recv has completed by the device
-  ncclResult_t (*irecvConsumed)(void* recvComm, int n, void* request);
-} ncclNet_v7_t;
+typedef ncclNetProperties_v8_t ncclNetProperties_v7_t;
+typedef ncclNet_v8_t ncclNet_v7_t;
 
 // v7 struct for backwards compatibility
 typedef struct {
diff --git a/include/net_v8.h b/include/net_v8.h
new file mode 100644
index 00000000..59d1dcb0
--- /dev/null
+++ b/include/net_v8.h
@@ -0,0 +1,134 @@
+/*
+ * Copyright (c) 2017-2023, NVIDIA CORPORATION. All rights reserved.
+ */
+
+#ifndef NCCL_NET_V8_H_
+#define NCCL_NET_V8_H_
+#include "net_device.h"
+
+typedef struct {
+  char* name;                      // Used mostly for logging.
+  char* pciPath;                   // Path to the PCI device in /sys.
+  uint64_t guid;                   // Unique identifier for the NIC chip. Important for
+                                   // cards with multiple PCI functions (Physical or virtual).
+  int ptrSupport;                  // [NCCL_PTR_HOST|NCCL_PTR_CUDA|NCCL_PTR_DMABUF]
+  int speed;                       // Port speed in Mbps.
+  int port;                        // Port number.
+  float latency;                   // Network latency
+  int maxComms;                    // Maximum number of comms we can create
+  int maxRecvs;                    // Maximum number of grouped receives.
+  ncclNetDeviceType netDeviceType; // Network offload type
+  int netDeviceVersion;            // Version number for network offload
+} ncclNetProperties_v8_t;
+
+typedef ncclNetProperties_v8_t ncclNetProperties_t;
+
+typedef struct {
+  // Name of the network (mainly for logs)
+  const char* name;
+  // Initialize the network.
+  ncclResult_t (*init)(ncclDebugLogger_t logFunction);
+  // Return the number of adapters.
+  ncclResult_t (*devices)(int* ndev);
+  // Get various device properties.
+  ncclResult_t (*getProperties)(int dev, ncclNetProperties_v8_t* props);
+  // Create a receiving object and provide a handle to connect to it. The
+  // handle can be up to NCCL_NET_HANDLE_MAXSIZE bytes and will be exchanged
+  // between ranks to create a connection.
+  ncclResult_t (*listen)(int dev, void* handle, void** listenComm);
+  // Connect to a handle and return a sending comm object for that peer.
+  // This call must not block for the connection to be established, and instead
+  // should return successfully with sendComm == NULL with the expectation that
+  // it will be called again until sendComm != NULL.
+  // If *sendDevComm points to a valid object, then NCCL is requesting device offload for this connection
+  ncclResult_t (*connect)(int dev, void* handle, void** sendComm, ncclNetDeviceHandle_v7_t** sendDevComm);
+  // Finalize connection establishment after remote peer has called connect.
+  // This call must not block for the connection to be established, and instead
+  // should return successfully with recvComm == NULL with the expectation that
+  // it will be called again until recvComm != NULL.
+  // If *recvDevComm points to a valid object, then NCCL is requesting device offload for this connection
+  ncclResult_t (*accept)(void* listenComm, void** recvComm, ncclNetDeviceHandle_v7_t** recvDevComm);
+  // Register/Deregister memory. Comm can be either a sendComm or a recvComm.
+  // Type is either NCCL_PTR_HOST or NCCL_PTR_CUDA.
+  ncclResult_t (*regMr)(void* comm, void* data, int size, int type, void** mhandle);
+  /* DMA-BUF support */
+  ncclResult_t (*regMrDmaBuf)(void* comm, void* data, size_t size, int type, uint64_t offset, int fd, void** mhandle);
+  ncclResult_t (*deregMr)(void* comm, void* mhandle);
+  // Asynchronous send to a peer.
+  // May return request == NULL if the call cannot be performed (or would block)
+  ncclResult_t (*isend)(void* sendComm, void* data, int size, int tag, void* mhandle, void** request);
+  // Asynchronous recv from a peer.
+  // May return request == NULL if the call cannot be performed (or would block)
+  ncclResult_t (*irecv)(void* recvComm, int n, void** data, int* sizes, int* tags, void** mhandles, void** request);
+  // Perform a flush/fence to make sure all data received with NCCL_PTR_CUDA is
+  // visible to the GPU
+  ncclResult_t (*iflush)(void* recvComm, int n, void** data, int* sizes, void** mhandles, void** request);
+  // Test whether a request is complete. If size is not NULL, it returns the
+  // number of bytes sent/received.
+  ncclResult_t (*test)(void* request, int* done, int* sizes);
+  // Close and free send/recv comm objects
+  ncclResult_t (*closeSend)(void* sendComm);
+  ncclResult_t (*closeRecv)(void* recvComm);
+  ncclResult_t (*closeListen)(void* listenComm);
+
+  // Copy the given mhandle to a dptr in a format usable by this plugin's device code
+  ncclResult_t (*getDeviceMr)(void* comm, void* mhandle, void** dptr_mhandle);
+
+  // Notify the plugin that a recv has completed by the device
+  ncclResult_t (*irecvConsumed)(void* recvComm, int n, void* request);
+} ncclNet_v8_t;
+
+typedef struct {
+  void* mhandle;
+  void* address;
+  uint32_t size;
+} ncclNetSGE_v8_t;
+
+typedef struct {
+  // Name of the collective network (mainly for logs)
+  const char* name;
+  // Initialize the collective network.
+  ncclResult_t (*init)(ncclDebugLogger_t logFunction);
+  // Return the number of adapters capable of doing collective operations.
+  // If ndev returns 0, all other functions might be set to NULL.
+  ncclResult_t (*devices)(int* ndev);
+  // Get various device properties.
+  ncclResult_t (*getProperties)(int dev, ncclNetProperties_v8_t* props);
+  // Create a receiving object and provide a handle to connect to it. The
+  // handle can be up to NCCL_NET_HANDLE_MAXSIZE bytes and will be exchanged
+  // between ranks to create connections.
+  ncclResult_t (*listen)(int dev, void* handle, void** listenComm);
+  // Create a group for collective operations. handles have been created
+  // using listen() above. rank indicates caller's rank in the collective network.
+  ncclResult_t (*connect)(void* handles[], int nranks, int rank, void* listenComm, void** collComm);
+  // Returns whether a reduction operation on a data type is supported.
+  // 1 for supported, 0 otherwise.
+  ncclResult_t (*reduceSupport)(ncclDataType_t dataType, ncclRedOp_t redOp, int* supported);
+  // Register/Deregister memory. Type is either NCCL_PTR_HOST or NCCL_PTR_CUDA.
+  ncclResult_t (*regMr)(void* collComm, void* data, int size, int type, void** mhandle);
+  /* DMA-BUF support */
+  ncclResult_t (*regMrDmaBuf)(void* collComm, void* data, size_t size, int type, uint64_t offset, int fd, void** mhandle);
+  ncclResult_t (*deregMr)(void* collComm, void* mhandle);
+  // Performs an asynchronous allreduce operation on the collective group.
+  // May return request == NULL if the call cannot be performed (or would block).
+  ncclResult_t (*iallreduce)(void* collComm, void* sendData, void* recvData, int count,
+      ncclDataType_t dataType, ncclRedOp_t redOp, void* sendMhandle, void* recvMhandle, void** request);
+  ncclResult_t (*iallgather)(void* collComm, void* sendData, int nRecvParts, ncclNetSGE_v8_t* recvParts,
+      size_t bytesPerRank, size_t windowOffset, size_t windowBytes,
+      void* sendMhandle, void** request);
+  ncclResult_t (*ireducescatter)(void* collComm, int nSendParts, ncclNetSGE_v8_t* sendParts, void* recvData,
+      size_t bytesPerRank, size_t windowOffset, size_t windowBytes,
+      ncclDataType_t dataType, ncclRedOp_t redOp,
+      void* recvMhandle, void** request);
+  // Perform a flush/fence to make sure all data received with NCCL_PTR_CUDA is
+  // visible to the GPU
+  ncclResult_t (*iflush)(void* collComm, void* data, int size, void* mhandle, void** request);
+  // Test whether a request is complete. If size is not NULL, it returns the
+  // number of bytes sent/received.
+  ncclResult_t (*test)(void* request, int* done, int* size);
+  // Close and free collective comm objects
+  ncclResult_t (*closeColl)(void* collComm);
+  ncclResult_t (*closeListen)(void* listenComm);
+} ncclCollNet_v8_t;
+
+#endif // end include guard
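The header above fully defines the new v8 surface. For illustration, here is a hedged, caller-side sketch of the iallgather entry point plus the usual request/test completion path. Everything named here (demoAllgather, collNet, collComm, the buffers and mhandles) is hypothetical glue, not part of the patch; NCCLCHECK is the error-checking macro already used throughout these sources.

// Illustrative only: "collNet" stands for whichever ncclCollNet_v8_t instance was
// loaded (e.g. ncclCollNetPlugin_v8); collComm, the buffers and the mhandles are
// assumed to have been set up via connect()/regMr() as defined above.
static ncclResult_t demoAllgather(ncclCollNet_v8_t* collNet, void* collComm,
                                  void* sendBuf, void* recvBuf,
                                  void* sendMh, void* recvMh,
                                  size_t bytesPerRank, int nranks) {
  // The whole allgather result is described as a single SGE covering all ranks.
  ncclNetSGE_v8_t recvParts;
  recvParts.mhandle = recvMh;
  recvParts.address = recvBuf;
  recvParts.size    = (uint32_t)(bytesPerRank * nranks);

  void* request = NULL;
  NCCLCHECK(collNet->iallgather(collComm, sendBuf, /*nRecvParts=*/1, &recvParts,
                                bytesPerRank, /*windowOffset=*/0,
                                /*windowBytes=*/bytesPerRank * nranks,
                                sendMh, &request));

  // The operation is asynchronous; completion is detected with test().
  int done = 0, size = 0;
  while (!done) NCCLCHECK(collNet->test(request, &done, &size));
  return ncclSuccess;
}

ireducescatter is driven the same way, with the SGE describing the send side, a flat buffer on the receive side, and a dataType/redOp pair added.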
diff --git a/src/ib_plugin.c b/src/ib_plugin.c
index 35e514e0..0a3344b4 100644
--- a/src/ib_plugin.c
+++ b/src/ib_plugin.c
@@ -1144,6 +1144,28 @@ ncclResult_t ncclIbCloseListen(void* listenComm) {
   return ncclSuccess;
 }
 
+const ncclNet_v8_t ibPlugin_v8 = {
+  .name = "IBext_v8",
+  .init = ncclIbInit,
+  .devices = ncclIbDevices,
+  .getProperties = ncclIbGetProperties,
+  .listen = ncclIbListen,
+  .connect = ncclIbConnect,
+  .accept = ncclIbAccept,
+  .regMr = ncclIbRegMr,
+  .regMrDmaBuf = ncclIbRegMrDmaBuf,
+  .deregMr = ncclIbDeregMr,
+  .isend = ncclIbIsend,
+  .irecv = ncclIbIrecv,
+  .iflush = ncclIbIflush,
+  .test = ncclIbTest,
+  .closeSend = ncclIbCloseSend,
+  .closeRecv = ncclIbCloseRecv,
+  .closeListen = ncclIbCloseListen,
+  NULL /* getDeviceMr */,
+  NULL /* irecvConsumed */
+};
+
 const ncclNet_v7_t ibPlugin_v7 = {
   .name = "IBext_v7",
   .init = ncclIbInit,
diff --git a/src/p2p_plugin.c b/src/p2p_plugin.c
index 956ffc76..220d628c 100644
--- a/src/p2p_plugin.c
+++ b/src/p2p_plugin.c
@@ -15,14 +15,17 @@
 #include "p2p_plugin.h"
 
 #ifdef HAVE_UCX_PLUGIN
+extern ncclNet_v8_t ucxPlugin_v8;
 extern ncclNet_v7_t ucxPlugin_v7;
 extern ncclNet_v6_t ucxPlugin_v6;
 extern ncclNet_v5_t ucxPlugin_v5;
+extern ncclNet_v8_t ucxRmaPlugin_v8;
 extern ncclNet_v7_t ucxRmaPlugin_v7;
 extern ncclNet_v6_t ucxRmaPlugin_v6;
 extern ncclNet_v5_t ucxRmaPlugin_v5;
 #endif
 
+extern ncclNet_v8_t ibPlugin_v8;
 extern ncclNet_v7_t ibPlugin_v7;
 extern ncclNet_v6_t ibPlugin_v6;
 extern ncclNet_v5_t ibPlugin_v5;
@@ -40,10 +43,16 @@ extern int ncclIbRelaxedOrderingEnabled;
 NCCL_PARAM(SharpMaxComms, "SHARP_MAX_COMMS", 1);
 NCCL_PARAM(IbAdaptiveRouting, "IB_ADAPTIVE_ROUTING", -2);
 
+ncclResult_t pluginInit_v8(ncclDebugLogger_t logFunction);
 ncclResult_t pluginInit_v7(ncclDebugLogger_t logFunction);
 ncclResult_t pluginInit_v6(ncclDebugLogger_t logFunction);
 ncclResult_t pluginInit_v5(ncclDebugLogger_t logFunction);
 
+ncclNet_v8_t ncclNetPlugin_v8 = {
+  "NCCL RDMA Plugin v8",
+  pluginInit_v8,
+};
+
 ncclNet_v7_t ncclNetPlugin_v7 = {
   "NCCL RDMA Plugin v7",
   pluginInit_v7,
@@ -85,17 +94,20 @@ static void pluginSetup()
   }
   switch (p2p_plugin) {
     case NCCL_P2P_IB:
+      ncclNetPlugin_v8 = ibPlugin_v8;
       ncclNetPlugin_v7 = ibPlugin_v7;
       ncclNetPlugin_v6 = ibPlugin_v6;
       ncclNetPlugin_v5 = ibPlugin_v5;
       break;
 #ifdef HAVE_UCX_PLUGIN
     case NCCL_P2P_UCX:
+      ncclNetPlugin_v8 = ucxPlugin_v8;
       ncclNetPlugin_v7 = ucxPlugin_v7;
       ncclNetPlugin_v6 = ucxPlugin_v6;
       ncclNetPlugin_v5 = ucxPlugin_v5;
       break;
     case NCCL_P2P_UCX_RMA:
+      ncclNetPlugin_v8 = ucxRmaPlugin_v8;
       ncclNetPlugin_v7 = ucxRmaPlugin_v7;
       ncclNetPlugin_v6 = ucxRmaPlugin_v6;
       ncclNetPlugin_v5 = ucxRmaPlugin_v5;
@@ -105,6 +117,13 @@ static void pluginSetup()
 
 }
 
+ncclResult_t pluginInit_v8(ncclDebugLogger_t logFunction) {
+  pluginLogFunction = logFunction;
+  pluginSetup();
+  INFO(NCCL_INIT|NCCL_NET, "P2P plugin %s", ncclNetPlugin_v8.name);
+  return ncclNetPlugin_v8.init(logFunction);
+}
+
 ncclResult_t pluginInit_v7(ncclDebugLogger_t logFunction) {
   pluginLogFunction = logFunction;
   pluginSetup();
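The new ncclNetPlugin_v8 symbol sits next to the v7/v6/v5 tables so an older NCCL keeps resolving the older entry points unchanged. A simplified sketch of how a loader might pick the newest available vtable is shown below; the library name and fallback order are assumptions for illustration, not something this patch adds.

// Simplified loader-side sketch (not part of this patch): resolve the most recent
// versioned vtable exported by the plugin shared object, falling back to older ones.
#include <dlfcn.h>
#include <stdio.h>

void* load_net_vtable(void) {
  void* lib = dlopen("libnccl-net.so", RTLD_NOW | RTLD_LOCAL);
  if (lib == NULL) return NULL;
  void* net = dlsym(lib, "ncclNetPlugin_v8");              // new in this patch
  if (net == NULL) net = dlsym(lib, "ncclNetPlugin_v7");   // fall back to v7
  if (net == NULL) net = dlsym(lib, "ncclNetPlugin_v6");   // then v6
  if (net == NULL) fprintf(stderr, "no compatible ncclNetPlugin_* symbol found\n");
  return net;
}

The same pattern applies to the collective-network side with the ncclCollNetPlugin_v8 symbol added further below.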
diff --git a/src/sharp_plugin.c b/src/sharp_plugin.c
index 07380c5d..a9e473b2 100644
--- a/src/sharp_plugin.c
+++ b/src/sharp_plugin.c
@@ -20,6 +20,7 @@
 #include "sharp/api/sharp_coll.h"
 #include "utils.h"
 
+extern ncclNet_v8_t ncclNetPlugin_v8;
 extern ncclNet_v7_t ncclNetPlugin_v7;
 extern ncclNet_v6_t ncclNetPlugin_v6;
 extern ncclNet_v5_t ncclNetPlugin_v5;
@@ -211,16 +212,18 @@ ncclResult_t ncclSharpDevices(int* ndev) {
   return ncclSuccess;
 }
 
+ncclResult_t ncclSharpGetProperties_v8(int dev, ncclNetProperties_v8_t* props) {
+  return ncclNetPlugin_v8.getProperties(dev, props);
+}
+
 ncclResult_t ncclSharpGetProperties_v7(int dev, ncclNetProperties_v7_t* props) {
   return ncclNetPlugin_v7.getProperties(dev, props);
 }
-
 ncclResult_t ncclSharpGetProperties_v6(int dev, ncclNetProperties_v6_t* props) {
   return ncclNetPlugin_v6.getProperties(dev, props);
 }
-
 ncclResult_t ncclSharpGetProperties_v5(int dev, ncclNetProperties_v5_t* props) {
   return ncclNetPlugin_v5.getProperties(dev, props);
 }
@@ -497,6 +500,117 @@ ncclResult_t ncclSharpIallreduce(void* collComm, void* sendData, void* recvData,
   return ncclSuccess;
 }
 
+ncclResult_t ncclSharpIallgather(void* collComm, void* sendData, int nRecvParts, ncclNetSGE_v8_t* recvParts,
+                                 size_t bytesPerRank, size_t windowOffset, size_t windowBytes,
+                                 void* sendMhandle, void** request)
+{
+  struct ncclSharpCollComm* cComm = (struct ncclSharpCollComm*)collComm;
+  struct ncclSharpMemHandle *send_mh = (struct ncclSharpMemHandle*)sendMhandle;
+  struct ncclSharpMemHandle *recv_mh = (struct ncclSharpMemHandle*)recvParts[0].mhandle;
+  struct ncclSharpRequest* req;
+  NCCLCHECK(ncclSharpGetRequest(cComm->reqs, &req));
+
+  assert(nRecvParts == 1);
+
+  struct sharp_coll_gather_spec gather_spec;
+
+  gather_spec.sbuf_desc.type = SHARP_DATA_BUFFER;
+  gather_spec.sbuf_desc.buffer.ptr = sendData;
+  gather_spec.sbuf_desc.buffer.length = bytesPerRank;
+  gather_spec.sbuf_desc.buffer.mem_handle = send_mh->mr;
+
+  gather_spec.rbuf_desc.type = SHARP_DATA_BUFFER;
+  gather_spec.rbuf_desc.buffer.ptr = recvParts[0].address;
+  gather_spec.rbuf_desc.buffer.length = recvParts[0].size;
+  gather_spec.rbuf_desc.buffer.mem_handle = recv_mh->mr;
+
+  gather_spec.dtype = SHARP_DTYPE_INT8;
+  gather_spec.size = recvParts[0].size;
+  gather_spec.offset = windowOffset;
+
+#if BLOCKING==0
+  if (SHARP_COLL_SUCCESS != sharp_coll_do_allgather_nb(cComm->sharpCollComm, &gather_spec, &req->sharpRequest)) {
+    WARN("SHARP Allgather failed\n");
+  }
+  req->size = recvParts[0].size;
+#else
+  if (SHARP_COLL_SUCCESS != sharp_coll_do_allgather(cComm->sharpCollComm, &gather_spec)) {
+    WARN("SHARP Allgather failed\n");
+  }
+  req->sharpRequest = (void *) 0xabababab;
+  req->size = recvParts[0].size;
+#endif
+  req->requestType = NCCL_SHARP_REQ_SHARP_COLL;
+  *request = req;
+  return ncclSuccess;
+}
+
+ncclResult_t ncclSharpIreducescatter(void* collComm, int nSendParts, ncclNetSGE_v8_t* sendParts, void* recvData,
+                                     size_t bytesPerRank, size_t windowOffset, size_t windowBytes,
+                                     ncclDataType_t dataType, ncclRedOp_t redOp,
+                                     void* recvMhandle, void** request)
+{
+  struct ncclSharpCollComm* cComm = (struct ncclSharpCollComm*)collComm;
+
+  enum sharp_datatype sharp_type = typeConvert(dataType);
+  if (sharp_type == SHARP_DTYPE_NULL) {
+    WARN("SHARP: unsupported data type\n");
+    return ncclInternalError;
+  }
+
+  enum sharp_reduce_op op_type = opConvert(redOp);
+  if (op_type == SHARP_OP_NULL) {
+    WARN("SHARP: unsupported reduce operation\n");
+    return ncclInternalError;
+  }
+
+  assert(nSendParts == 1);
+
+  int dt_size = typeSize(dataType);
+  struct ncclSharpMemHandle *mr_sbuf = (struct ncclSharpMemHandle*)sendParts[0].mhandle;
+  struct ncclSharpMemHandle *mr_rbuf = (struct ncclSharpMemHandle*)recvMhandle;
+
+  struct ncclSharpRequest* req;
+  NCCLCHECK(ncclSharpGetRequest(cComm->reqs, &req));
+
+  struct sharp_coll_reduce_spec reduce_spec;
+
+  reduce_spec.sbuf_desc.buffer.ptr = sendParts[0].address;
+  reduce_spec.sbuf_desc.buffer.length = sendParts[0].size;
+  reduce_spec.sbuf_desc.buffer.mem_handle = mr_sbuf->mr;
+  reduce_spec.sbuf_desc.type = SHARP_DATA_BUFFER;
+  reduce_spec.sbuf_desc.mem_type = (mr_sbuf->type == NCCL_PTR_CUDA ? SHARP_MEM_TYPE_CUDA:SHARP_MEM_TYPE_HOST);
+
+  reduce_spec.rbuf_desc.buffer.ptr = recvData;
+  reduce_spec.rbuf_desc.buffer.length = bytesPerRank;
+  reduce_spec.rbuf_desc.buffer.mem_handle = mr_rbuf->mr;
+  reduce_spec.rbuf_desc.type = SHARP_DATA_BUFFER;
+  reduce_spec.rbuf_desc.mem_type = (mr_rbuf->type == NCCL_PTR_CUDA ? SHARP_MEM_TYPE_CUDA:SHARP_MEM_TYPE_HOST);
+
+  reduce_spec.length = sendParts[0].size / dt_size;
+  reduce_spec.offset = windowOffset;
+  reduce_spec.dtype = sharp_type;
+  reduce_spec.op = op_type;
+  reduce_spec.aggr_mode = SHARP_AGGREGATION_NONE;
+
+#if BLOCKING==0
+  if (SHARP_COLL_SUCCESS != sharp_coll_do_reduce_scatter_nb(cComm->sharpCollComm, &reduce_spec, &req->sharpRequest)) {
+    WARN("SHARP reduce_scatter failed\n");
+  }
+  req->size = bytesPerRank;
+#else
+  if (SHARP_COLL_SUCCESS != sharp_coll_do_reduce_scatter(cComm->sharpCollComm, &reduce_spec)) {
+    WARN("SHARP reduce_scatter failed\n");
+  }
+  req->sharpRequest = (void *) 0xabababab;
+  req->size = bytesPerRank;
+#endif
+  req->requestType = NCCL_SHARP_REQ_SHARP_COLL;
+  *request = req;
+  return ncclSuccess;
+}
+
 ncclResult_t ncclSharpIflush(void* collComm, void* data, int size, void* mhandle, void **request) {
   struct ncclSharpCollComm *cComm = (struct ncclSharpCollComm*)collComm;
   struct ncclSharpMemHandle *mh = (struct ncclSharpMemHandle *)mhandle;
@@ -568,6 +682,26 @@ ncclResult_t ncclSharpCloseListen(void* listenComm) {
   return status;
 }
 
+ncclCollNet_v8_t ncclCollNetPlugin_v8 = {
+  "SHARP",
+  ncclSharpInit,
+  ncclSharpDevices,
+  ncclSharpGetProperties_v8,
+  ncclSharpListen,
+  ncclSharpConnect,
+  ncclSharpReduceSupport,
+  ncclSharpRegMr,
+  ncclSharpRegMrDmaBuf,
+  ncclSharpDeregMr,
+  ncclSharpIallreduce,
+  ncclSharpIallgather,
+  ncclSharpIreducescatter,
+  ncclSharpIflush,
+  ncclSharpTest,
+  ncclSharpCloseColl,
+  ncclSharpCloseListen
+};
+
 ncclCollNet_v7_t ncclCollNetPlugin_v7 = {
   "SHARP",
   ncclSharpInit,
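To make the new size parameters concrete, here is a small, self-contained arithmetic sketch of one consistent reading of bytesPerRank / windowOffset / windowBytes, based on how the SHARP code above consumes them: iallgather sends bytesPerRank and fills the whole window, ireducescatter consumes the whole window and delivers a bytesPerRank slice. The numbers and the per-rank offset convention are assumptions for illustration, not taken from the patch.

#include <assert.h>
#include <stddef.h>

int main(void) {
  const size_t nranks       = 8;
  const size_t bytesPerRank = 64 * 1024;              // per-rank slice (64 KiB)
  const size_t windowBytes  = nranks * bytesPerRank;  // whole registered window
  const size_t rank         = 3;
  const size_t windowOffset = rank * bytesPerRank;    // assumed: this rank's slice offset

  const size_t dt_size      = 4;                      // e.g. typeSize() of a 4-byte float
  const size_t totalElems   = windowBytes / dt_size;  // what reduce_spec.length works out to
  const size_t elemsPerRank = bytesPerRank / dt_size;

  assert(windowOffset + bytesPerRank <= windowBytes);  // the slice stays inside the window
  assert(totalElems == nranks * elemsPerRank);         // 131072 == 8 * 16384
  return 0;
}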
diff --git a/src/ucx_plugin.c b/src/ucx_plugin.c
index 466face8..0d980cf5 100644
--- a/src/ucx_plugin.c
+++ b/src/ucx_plugin.c
@@ -878,6 +878,28 @@ ncclResult_t nccl_ucx_close_listen(void *listen_comm) {
   return ncclSuccess;
 }
 
+ncclNet_v8_t ucxPlugin_v8 = {
+  .name = "UCX",
+  .init = nccl_ucx_init,
+  .devices = nccl_ucx_devices,
+  .getProperties = nccl_ucx_get_properties,
+  .listen = nccl_ucx_listen,
+  .connect = nccl_ucx_connect,
+  .accept = nccl_ucx_accept,
+  .regMr = nccl_ucx_regmr,
+  .regMrDmaBuf = nccl_ucx_regmr_dmabuf,
+  .deregMr = nccl_ucx_deregmr,
+  .isend = nccl_ucx_isend,
+  .irecv = nccl_ucx_irecv,
+  .iflush = nccl_ucx_iflush,
+  .test = nccl_ucx_test,
+  .closeSend = nccl_ucx_close_send,
+  .closeRecv = nccl_ucx_close_recv,
+  .closeListen = nccl_ucx_close_listen,
+  NULL /* getDeviceMr */,
+  NULL /* irecvConsumed */
+};
+
 ncclNet_v7_t ucxPlugin_v7 = {
   .name = "UCX",
   .init = nccl_ucx_init,
diff --git a/src/ucx_rma_plugin.c b/src/ucx_rma_plugin.c
index 3e921a0c..41e98e5f 100644
--- a/src/ucx_rma_plugin.c
+++ b/src/ucx_rma_plugin.c
@@ -1164,7 +1164,7 @@ ncclResult_t nccl_ucx_rma_close_listen(void *listen_comm)
   return ncclSuccess;
 }
 
-ncclNet_v7_t ucxRmaPlugin_v7 = {
+ncclNet_v8_t ucxRmaPlugin_v8 = {
   .name = "UCX-RMA",
   .init = nccl_ucx_rma_init,
   .devices = nccl_ucx_rma_devices,
@@ -1184,7 +1184,28 @@ ncclNet_v7_t ucxRmaPlugin_v7 = {
   .closeListen = nccl_ucx_rma_close_listen,
   NULL /* getDeviceMr */,
   NULL /* irecvConsumed */
+};
+ncclNet_v7_t ucxRmaPlugin_v7 = {
+  .name = "UCX-RMA",
+  .init = nccl_ucx_rma_init,
+  .devices = nccl_ucx_rma_devices,
+  .getProperties = nccl_ucx_rma_get_properties,
+  .listen = nccl_ucx_rma_listen,
+  .connect = nccl_ucx_rma_connect,
+  .accept = nccl_ucx_rma_accept,
+  .regMr = nccl_ucx_rma_regmr,
+  .regMrDmaBuf = nccl_ucx_rma_regmr_dmabuf,
+  .deregMr = nccl_ucx_rma_deregmr,
+  .isend = nccl_ucx_rma_isend,
+  .irecv = nccl_ucx_rma_irecv,
+  .iflush = nccl_ucx_rma_iflush,
+  .test = nccl_ucx_rma_test,
+  .closeSend = nccl_ucx_rma_close_send,
+  .closeRecv = nccl_ucx_rma_close_recv,
+  .closeListen = nccl_ucx_rma_close_listen,
+  NULL /* getDeviceMr */,
+  NULL /* irecvConsumed */
 };
 
 ncclNet_v6_t ucxRmaPlugin_v6 = {