rdma: add separate request types for eager/ctrl rx buffers
These two request types currently share an underlying data structure.

Signed-off-by: Eric Raut <eraut@amazon.com>
rauteric committed Jan 10, 2025
1 parent d611611 commit ccb32bd
Showing 2 changed files with 78 additions and 19 deletions.
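For context, the change keeps both request kinds on the shared rdma_req_rx_buff_data_t payload but gives each its own enum value and allocator, and binds each rail to the allocator for the traffic it carries. Below is a minimal sketch of that pattern, using simplified hypothetical types rather than the plugin's real structs:

/*
 * Hypothetical, simplified sketch -- not the plugin's real API. Two
 * request kinds share one payload struct; each rail carries the
 * allocator for the kind of rx buffer it posts, so generic posting
 * code never branches on the buffer kind.
 */
#include <stdio.h>

typedef enum { REQ_CTRL_RX_BUFF, REQ_EAGER_RX_BUFF } req_type_t;

typedef struct {
	req_type_t type;   /* shared rx-buffer payload would follow */
} req_t;

typedef struct rail {
	req_t *(*rx_buff_req_alloc)(struct rail *rail);
} rail_t;

static req_t *ctrl_rx_buff_req_alloc(struct rail *rail)
{
	(void)rail;
	static req_t r = { REQ_CTRL_RX_BUFF };
	return &r;
}

static req_t *eager_rx_buff_req_alloc(struct rail *rail)
{
	(void)rail;
	static req_t r = { REQ_EAGER_RX_BUFF };
	return &r;
}

int main(void)
{
	rail_t ctrl_rail  = { ctrl_rx_buff_req_alloc };
	rail_t eager_rail = { eager_rx_buff_req_alloc };

	/* The posting path just asks the rail for the right request kind */
	printf("ctrl rail -> type %d\n",
	       (int)ctrl_rail.rx_buff_req_alloc(&ctrl_rail)->type);
	printf("eager rail -> type %d\n",
	       (int)eager_rail.rx_buff_req_alloc(&eager_rail)->type);
	return 0;
}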
include/nccl_ofi_rdma.h (10 changes: 8 additions & 2 deletions)
@@ -80,8 +80,10 @@ typedef enum nccl_net_ofi_rdma_req_type {
 	NCCL_OFI_RDMA_RECV_SEGMS,
 	/* Eager local copy request. Subrequest of NCCL_OFI_RDMA_RECV */
 	NCCL_OFI_RDMA_EAGER_COPY,
-	/* Rx buff post request */
-	NCCL_OFI_RDMA_RX_BUFF,
+	/* Ctrl rx buff post request */
+	NCCL_OFI_RDMA_CTRL_RX_BUFF,
+	/* Eager rx buff post request */
+	NCCL_OFI_RDMA_EAGER_RX_BUFF,
 	/* Flush request */
 	NCCL_OFI_RDMA_FLUSH,
 	/* Connect message send request */
@@ -689,6 +691,10 @@ struct nccl_net_ofi_ep_rail {
 	size_t max_rx_buff_posted;
 	/* Mutex for rx buffer operations */
 	pthread_mutex_t rx_buff_mutex;
+
+	/* Allocate a receive buffer request for this rail (eager or ctrl) */
+	nccl_net_ofi_rdma_req_t* (*rx_buff_req_alloc)(nccl_net_ofi_rdma_ep_t *ep,
+						      nccl_net_ofi_ep_rail_t *rail);
 };

 /*
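With this hook in place, the generic posting path doesn't need to know whether a rail carries control or eager traffic; as the call site later in this commit shows, it simply asks the rail:

	nccl_net_ofi_rdma_req_t *req = rail->rx_buff_req_alloc(ep, rail);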
src/nccl_ofi_rdma.c (87 changes: 70 additions & 17 deletions)
@@ -633,7 +633,8 @@ static inline int get_properties(nccl_net_ofi_rdma_device_t *base_dev,
  * @brief	Return rx data struct of rx request
  */
 static inline rdma_req_rx_buff_data_t *get_rx_buff_data(nccl_net_ofi_rdma_req_t *req) {
-	assert(req->type == NCCL_OFI_RDMA_RX_BUFF);
+	assert(req->type == NCCL_OFI_RDMA_CTRL_RX_BUFF ||
+	       req->type == NCCL_OFI_RDMA_EAGER_RX_BUFF);
 	return &req->rx_buff_data;
 }

@@ -1242,7 +1243,7 @@ static int finish_connect(nccl_net_ofi_rdma_send_comm_t *s_comm);

 static int handle_close_msg_recv(nccl_net_ofi_rdma_req_t *rx_buff_req)
 {
-	assert(rx_buff_req->type == NCCL_OFI_RDMA_RX_BUFF);
+	assert(rx_buff_req->type == NCCL_OFI_RDMA_CTRL_RX_BUFF);

 	rdma_req_rx_buff_data_t *rx_buff_data = get_rx_buff_data(rx_buff_req);

@@ -1287,7 +1288,8 @@ static inline int handle_rx_buff_recv(nccl_net_ofi_rdma_device_t *device, int ra
NCCL_OFI_WARN("RECV event had NULL ctx!");
return -EINVAL;
}
if (OFI_UNLIKELY(rx_buff_req->type != NCCL_OFI_RDMA_RX_BUFF)) {
if (OFI_UNLIKELY((eager && (rx_buff_req->type != NCCL_OFI_RDMA_EAGER_RX_BUFF))
|| ((!eager) && (rx_buff_req->type != NCCL_OFI_RDMA_CTRL_RX_BUFF)))) {
NCCL_OFI_WARN("Invalid non-rx_buff request as ctx!");
return -EINVAL;
}
@@ -1530,8 +1532,10 @@ static const char *req_type_str(nccl_net_ofi_rdma_req_type_t type)
return "SEND_CLOSE";
case NCCL_OFI_RDMA_RECV_SEGMS:
return "RECV_SEGMS";
case NCCL_OFI_RDMA_RX_BUFF:
return "RX_BUFF";
case NCCL_OFI_RDMA_EAGER_RX_BUFF:
return "EAGER_RX_BUFF";
case NCCL_OFI_RDMA_CTRL_RX_BUFF:
return "CTRL_RX_BUFF";
case NCCL_OFI_RDMA_FLUSH:
return "FLUSH";
case NCCL_OFI_RDMA_EAGER_COPY:
@@ -1656,7 +1660,8 @@ static inline int process_completions(struct fi_cq_data_entry *cq_entry, uint64_
 	case NCCL_OFI_RDMA_SEND_CLOSE:
 	case NCCL_OFI_RDMA_RECV_SEGMS:
 	case NCCL_OFI_RDMA_EAGER_COPY:
-	case NCCL_OFI_RDMA_RX_BUFF:
+	case NCCL_OFI_RDMA_CTRL_RX_BUFF:
+	case NCCL_OFI_RDMA_EAGER_RX_BUFF:
 	case NCCL_OFI_RDMA_FLUSH:
 	case NCCL_OFI_RDMA_SEND_CONN:
 	case NCCL_OFI_RDMA_RECV_CONN:
@@ -1691,7 +1696,8 @@ static inline int process_completions(struct fi_cq_data_entry *cq_entry, uint64_
 	case NCCL_OFI_RDMA_SEND_CTRL:
 	case NCCL_OFI_RDMA_SEND_CLOSE:
 	case NCCL_OFI_RDMA_RECV_SEGMS:
-	case NCCL_OFI_RDMA_RX_BUFF:
+	case NCCL_OFI_RDMA_CTRL_RX_BUFF:
+	case NCCL_OFI_RDMA_EAGER_RX_BUFF:
 	case NCCL_OFI_RDMA_SEND_CONN:
 	case NCCL_OFI_RDMA_RECV_CONN:
 	case NCCL_OFI_RDMA_RECV_CONN_RESP:
@@ -1774,7 +1780,7 @@ static inline int process_err_completion(nccl_net_ofi_rdma_device_t *device,
 			err_entry.prov_errno,
 			fi_cq_strerror(cq, err_entry.prov_errno, err_entry.err_data, NULL, 0),
 			(long)err_entry.len, nccl_net_ofi_req_str(req));
-	if (req->type == NCCL_OFI_RDMA_RX_BUFF) {
+	if (req->type == NCCL_OFI_RDMA_CTRL_RX_BUFF || req->type == NCCL_OFI_RDMA_EAGER_RX_BUFF) {
 		/* A rx buffer receive failed -- this is an internal error so bail out */
 		NCCL_OFI_WARN("Fatal: rx buffer recv completed with error");
 	} else {
@@ -1850,7 +1856,8 @@ static int receive_progress(nccl_net_ofi_rdma_req_t *req, bool add_to_pending)
 	case NCCL_OFI_RDMA_RECV:
 	case NCCL_OFI_RDMA_SEND:
 	case NCCL_OFI_RDMA_RECV_SEGMS:
-	case NCCL_OFI_RDMA_RX_BUFF:
+	case NCCL_OFI_RDMA_CTRL_RX_BUFF:
+	case NCCL_OFI_RDMA_EAGER_RX_BUFF:
 	case NCCL_OFI_RDMA_SEND_CONN:
 	case NCCL_OFI_RDMA_RECV_CONN:
 	case NCCL_OFI_RDMA_RECV_CONN_RESP:
@@ -1910,7 +1917,8 @@ static int process_pending_reqs(nccl_net_ofi_rdma_ep_t *ep)
 		switch (req->type) {
 		case NCCL_OFI_RDMA_WRITE:
 		case NCCL_OFI_RDMA_SEND:
-		case NCCL_OFI_RDMA_RX_BUFF:
+		case NCCL_OFI_RDMA_CTRL_RX_BUFF:
+		case NCCL_OFI_RDMA_EAGER_RX_BUFF:
 			rc = send_progress(req);
 			break;
 		case NCCL_OFI_RDMA_READ:
@@ -2290,7 +2298,7 @@ static inline int free_invalid(nccl_net_ofi_rdma_req_t *req,
 	return -EINVAL;
 }

-static inline int free_rx_buff_req(nccl_net_ofi_rdma_req_t *req,
+static inline int free_eager_rx_buff_req(nccl_net_ofi_rdma_req_t *req,
 					 bool dec_inflight_reqs)
 {
 	assert(!dec_inflight_reqs);
@@ -2303,16 +2311,58 @@ static inline int free_rx_buff_req(nccl_net_ofi_rdma_req_t *req,
 	return free_base_req(NULL, ep->rx_buff_reqs_fl, req, false);
 }

-static inline nccl_net_ofi_rdma_req_t *alloc_rx_buff_req(nccl_net_ofi_rdma_ep_t *ep,
-							  nccl_net_ofi_ep_rail_t *rail)
+static inline nccl_net_ofi_rdma_req_t *eager_rx_buff_req_alloc(nccl_net_ofi_rdma_ep_t *ep,
+							       nccl_net_ofi_ep_rail_t *rail)
 {
 	nccl_net_ofi_rdma_req_t *req = allocate_req(ep->rx_buff_reqs_fl);
 	if (!req) return NULL;

 	req->comm = NULL;
-	req->type = NCCL_OFI_RDMA_RX_BUFF;
+	req->type = NCCL_OFI_RDMA_EAGER_RX_BUFF;
 	req->dev_id = rdma_endpoint_get_device(ep)->base.dev_id;
-	req->free = free_rx_buff_req;
+	req->free = free_eager_rx_buff_req;
+
+	rdma_req_rx_buff_data_t *rx_buff_data = get_rx_buff_data(req);
+
+	nccl_ofi_freelist_elem_t *rx_buff_fl_elem =
+		nccl_ofi_freelist_entry_alloc(ep->rx_buff_fl);
+	if (!rx_buff_fl_elem) {
+		NCCL_OFI_WARN("Failed to allocate rx_buff_fl_elem");
+		req->free(req, false);
+		return NULL;
+	}
+	assert(NCCL_OFI_IS_PTR_ALIGNED(rx_buff_fl_elem->ptr, RX_BUFFER_ALIGNMENT));
+
+	rx_buff_data->rx_buff_fl_elem = rx_buff_fl_elem;
+	rx_buff_data->buff_len = ep->rx_buff_size;
+	rx_buff_data->rail = rail;
+	rx_buff_data->ep = ep;
+	return req;
+}
+
+static inline int free_ctrl_rx_buff_req(nccl_net_ofi_rdma_req_t *req,
+					bool dec_inflight_reqs)
+{
+	assert(!dec_inflight_reqs);
+	rdma_req_rx_buff_data_t *rx_buff_data = get_rx_buff_data(req);
+	nccl_net_ofi_rdma_ep_t *ep = rx_buff_data->ep;
+	/* Free buffer */
+	if (rx_buff_data->rx_buff_fl_elem) {
+		nccl_ofi_freelist_entry_free(ep->rx_buff_fl, rx_buff_data->rx_buff_fl_elem);
+	}
+	return free_base_req(NULL, ep->rx_buff_reqs_fl, req, false);
+}
+
+static inline nccl_net_ofi_rdma_req_t *ctrl_rx_buff_req_alloc(nccl_net_ofi_rdma_ep_t *ep,
+							      nccl_net_ofi_ep_rail_t *rail)
+{
+	nccl_net_ofi_rdma_req_t *req = allocate_req(ep->rx_buff_reqs_fl);
+	if (!req) return NULL;
+
+	req->comm = NULL;
+	req->type = NCCL_OFI_RDMA_CTRL_RX_BUFF;
+	req->dev_id = rdma_endpoint_get_device(ep)->base.dev_id;
+	req->free = free_ctrl_rx_buff_req;

 	rdma_req_rx_buff_data_t *rx_buff_data = get_rx_buff_data(req);

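Worth noting: each allocator installs its matching release hook (free_eager_rx_buff_req or free_ctrl_rx_buff_req), and both release paths return the buffer element to ep->rx_buff_fl and the request itself to ep->rx_buff_reqs_fl, so teardown stays uniform regardless of which kind of rail posted the buffer.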
@@ -2371,7 +2421,7 @@ static inline int post_rx_buffs_on_rail(nccl_net_ofi_rdma_ep_t *ep,
 	for (size_t i = 0; i < buffers_needed; ++i) {
 		bool is_last_req = (i == (buffers_needed - 1));
 		nccl_net_ofi_rdma_req_t *req =
-			alloc_rx_buff_req(ep, rail);
+			rail->rx_buff_req_alloc(ep, rail);
 		if (!req) {
 			NCCL_OFI_WARN("Failed to allocate rx_buff req");
 			return -ENOMEM;
@@ -5569,7 +5619,8 @@ static int send_progress(nccl_net_ofi_rdma_req_t *req)
 			// Successfully sent the xfer with this rail
 			rma_op_data->xferred_rail_id++;
 		}
-	} else if (req->type == NCCL_OFI_RDMA_RX_BUFF) { // Post rx Buffer
+	} else if (req->type == NCCL_OFI_RDMA_CTRL_RX_BUFF ||
+		   req->type == NCCL_OFI_RDMA_EAGER_RX_BUFF) { // Post rx Buffer
 		rdma_req_rx_buff_data_t *rx_buff_data = get_rx_buff_data(req);
 		/* Get ep rail information to xfer the req */
 		assert(rx_buff_data->rail != NULL);
@@ -6196,6 +6247,7 @@ static inline int init_rx_buffers(nccl_net_ofi_rdma_ep_t *ep)
 		);
 		rail->num_rx_buff_posted = 0;
 		nccl_net_ofi_mutex_init(&rail->rx_buff_mutex, NULL);
+		rail->rx_buff_req_alloc = ctrl_rx_buff_req_alloc;
 	}

 	for (int rail_id = 0; rail_id < ep->num_rails; ++rail_id) {
@@ -6208,6 +6260,7 @@ static inline int init_rx_buffers(nccl_net_ofi_rdma_ep_t *ep)
 		);
 		rail->num_rx_buff_posted = 0;
 		nccl_net_ofi_mutex_init(&rail->rx_buff_mutex, NULL);
+		rail->rx_buff_req_alloc = eager_rx_buff_req_alloc;
 	}

 	return ret;
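Binding the allocator once per rail here, rather than branching on every post, is what keeps post_rx_buffs_on_rail generic: the first loop (presumably over the endpoint's control rails; its bounds sit outside this hunk) installs ctrl_rx_buff_req_alloc, while the loop over ep->num_rails installs eager_rx_buff_req_alloc.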
