Skip to content

Commit

Permalink
MR: Set minimal access permissions for each MR based on its usage
Browse files Browse the repository at this point in the history
MR access permissions are currently set in the internal functions that
initialize MR attributes for all registered MRs (set_mr_req_attr() for
RDMA protocol and register_mr_buffers() for SENDRECV protocol).
Thus, these functions are required to consider all the possible usages of
registered MRs in the codebase to properly set the access permissions.

First, this is error-prone because one may overlook some MR usage
(see the previous commit, which fixes such a bug). Second, setting minimal
access permissions for each MR based on its usage provides us some
protection against mistakenly mixing MR handles: such a mix-up then results
in an MR access permission violation rather than a memory corruption, which
is easier to debug and root-cause.

Signed-off-by: Liran Alon <liran@amazon.com>
  • Loading branch information
liralon committed Aug 15, 2023
1 parent 0366982 commit 9a3f74d
Show file tree
Hide file tree
Showing 5 changed files with 90 additions and 55 deletions.
5 changes: 4 additions & 1 deletion include/nccl_ofi_freelist.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@ struct nccl_ofi_freelist_block_t {
* Note that the freelist lock will be held during this function. The
* caller must avoid a deadlock situation with this behavior.
*/
typedef int (*nccl_ofi_freelist_regmr_fn)(void *opaque, void *data, size_t size,
typedef int (*nccl_ofi_freelist_regmr_fn)(void *opaque,
void *data, size_t size, uint64_t access,
void **handle);

/*
Expand Down Expand Up @@ -100,6 +101,7 @@ struct nccl_ofi_freelist_t {
bool have_reginfo;
nccl_ofi_freelist_regmr_fn regmr_fn;
nccl_ofi_freelist_deregmr_fn deregmr_fn;
uint64_t mr_access;
void *regmr_opaque;
size_t reginfo_offset;

Expand Down Expand Up @@ -152,6 +154,7 @@ int nccl_ofi_freelist_init_mr(size_t entry_size,
size_t max_entry_count,
nccl_ofi_freelist_regmr_fn regmr_fn,
nccl_ofi_freelist_deregmr_fn deregmr_fn,
uint64_t mr_access,
void *regmr_opaque,
size_t reginfo_offset,
nccl_ofi_freelist_t **freelist_p);
Expand Down
10 changes: 8 additions & 2 deletions src/nccl_ofi_freelist.c
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ static int freelist_init_internal(size_t entry_size,
bool have_reginfo,
nccl_ofi_freelist_regmr_fn regmr_fn,
nccl_ofi_freelist_deregmr_fn deregmr_fn,
uint64_t mr_access,
void *regmr_opaque,
size_t reginfo_offset,
nccl_ofi_freelist_t **freelist_p)
Expand All @@ -42,6 +43,7 @@ static int freelist_init_internal(size_t entry_size,
freelist->have_reginfo = have_reginfo;
freelist->regmr_fn = regmr_fn;
freelist->deregmr_fn = deregmr_fn;
freelist->mr_access = mr_access;
freelist->regmr_opaque = regmr_opaque;
freelist->reginfo_offset = reginfo_offset;

Expand Down Expand Up @@ -78,6 +80,7 @@ int nccl_ofi_freelist_init(size_t entry_size,
false,
NULL,
NULL,
0,
NULL,
0,
freelist_p);
Expand All @@ -89,6 +92,7 @@ int nccl_ofi_freelist_init_mr(size_t entry_size,
size_t max_entry_count,
nccl_ofi_freelist_regmr_fn regmr_fn,
nccl_ofi_freelist_deregmr_fn deregmr_fn,
uint64_t mr_access,
void *regmr_opaque,
size_t reginfo_offset,
nccl_ofi_freelist_t **freelist_p)
Expand All @@ -100,6 +104,7 @@ int nccl_ofi_freelist_init_mr(size_t entry_size,
true,
regmr_fn,
deregmr_fn,
mr_access,
regmr_opaque,
reginfo_offset,
freelist_p);
Expand Down Expand Up @@ -177,8 +182,9 @@ int nccl_ofi_freelist_add(nccl_ofi_freelist_t *freelist,
freelist->blocks = block;

if (freelist->regmr_fn) {
ret = freelist->regmr_fn(freelist->regmr_opaque, block->memory,
freelist->entry_size * allocation_count,
ret = freelist->regmr_fn(freelist->regmr_opaque,
block->memory, freelist->entry_size * allocation_count,
freelist->mr_access,
&block->mr_handle);
if (ret != 0) {
NCCL_OFI_WARN("freelist extension registration failed: %d", ret);
Expand Down
70 changes: 37 additions & 33 deletions src/nccl_ofi_rdma.c
Original file line number Diff line number Diff line change
Expand Up @@ -534,7 +534,7 @@ static int write_topo_file(nccl_ofi_topo_t *topo)
* non-zero on error
*/
static ncclResult_t set_mr_req_attr(nccl_ofi_mr_keypool_t *key_pool, int dev_id,
void *data, size_t size, int type,
void *data, size_t size, int type, uint64_t access,
struct fi_mr_attr *mr_attr, struct iovec *iov)
{
ncclResult_t ret = ncclSuccess;
Expand All @@ -547,31 +547,14 @@ static ncclResult_t set_mr_req_attr(nccl_ofi_mr_keypool_t *key_pool, int dev_id,
mr_attr->mr_iov = iov;
mr_attr->iov_count = 1;

/*
* Communication buffer is used as a message source when sending data eagerly.
* Bounce buffer is used as a message target when receiving eager data or control message.
* Control message is used as a message source.
*/
mr_attr->access = FI_SEND | FI_RECV;
/* Communication buffer is used as a source/target of RMA writes */
mr_attr->access |= (FI_WRITE | FI_REMOTE_WRITE);
mr_attr->access = access;

switch (type) {
case NCCL_PTR_HOST:
/*
* Host buffer is used as a RMA read source on eager local copy
* and a RMA read target on GPU flush.
*/
mr_attr->access |= FI_READ | FI_REMOTE_READ;
mr_attr->iface = FI_HMEM_SYSTEM;
break;
#if HAVE_CUDA
case NCCL_PTR_CUDA:
/*
* CUDA buffer is used as a RMA read source on GPU flush
* and a RMA read target on eager local copy.
*/
mr_attr->access |= FI_READ | FI_REMOTE_READ;
mr_attr->iface = FI_HMEM_CUDA;

/* Get CUDA device ID */
Expand All @@ -583,8 +566,6 @@ static ncclResult_t set_mr_req_attr(nccl_ofi_mr_keypool_t *key_pool, int dev_id,
#endif
#if HAVE_NEURON
case NCCL_PTR_NEURON:
/* Neuron buffer is used as a RMA read target on eager local copy */
mr_attr->access |= FI_READ;
mr_attr->iface = FI_HMEM_NEURON;
/*
* Store a sentinel; libfabric requires this to be initialized Libfabric
Expand Down Expand Up @@ -2559,7 +2540,8 @@ bool valid_mr_buff_type(int type)
}

static ncclResult_t reg_mr_ep(nccl_net_ofi_rdma_ep_t *ep, void *data,
size_t size, int type, nccl_net_ofi_rdma_mr_handle_t **mhandle)
size_t size, int type, uint64_t access,
nccl_net_ofi_rdma_mr_handle_t **mhandle)
{
ncclResult_t ret = ncclSuccess;
struct fi_mr_attr mr_attr = {0};
Expand Down Expand Up @@ -2597,7 +2579,7 @@ static ncclResult_t reg_mr_ep(nccl_net_ofi_rdma_ep_t *ep, void *data,
}

/* Create memory registration request */
ret = set_mr_req_attr(key_pool, dev_id, data, size, type, &mr_attr, &iov);
ret = set_mr_req_attr(key_pool, dev_id, data, size, type, access, &mr_attr, &iov);
if (OFI_UNLIKELY(ret != ncclSuccess)) {
NCCL_OFI_WARN("Could not set registration request attributes, dev: %d",
dev_id);
Expand Down Expand Up @@ -2633,14 +2615,32 @@ static ncclResult_t reg_mr_send_comm(nccl_net_ofi_send_comm_t *send_comm, void *
size_t size, int type, void **mhandle)
{
nccl_net_ofi_rdma_ep_t *ep = (nccl_net_ofi_rdma_ep_t *) send_comm->base.ep;
return reg_mr_ep(ep, data, size, type, (nccl_net_ofi_rdma_mr_handle_t **)mhandle);
/* Send buffer is used as a RMA write source
* or a tagged message source in case it is sent eagerly */
const uint64_t access = FI_WRITE | FI_SEND;

return reg_mr_ep(ep, data,
size, type, access,
(nccl_net_ofi_rdma_mr_handle_t **)mhandle);
}

static ncclResult_t reg_mr_recv_comm(nccl_net_ofi_recv_comm_t *recv_comm, void *data,
size_t size, int type, void **mhandle)
{
nccl_net_ofi_rdma_ep_t *ep = (nccl_net_ofi_rdma_ep_t *) recv_comm->base.ep;
return reg_mr_ep(ep, data, size, type, (nccl_net_ofi_rdma_mr_handle_t **)mhandle);
/* Recv buffer is used as a RMA write target
* or a RMA read target in case of local eager copy */
uint64_t access = FI_REMOTE_WRITE | FI_READ;

/* GPU flush target is registered as recv communication buffer
* and is used as RMA read source */
if (type == NCCL_PTR_CUDA) {
access |= FI_REMOTE_READ;
}

return reg_mr_ep(ep, data,
size, type, access,
(nccl_net_ofi_rdma_mr_handle_t **)mhandle);
}

static ncclResult_t dereg_mr_ep(nccl_net_ofi_rdma_mr_handle_t *mr_handle,
Expand Down Expand Up @@ -2689,12 +2689,14 @@ typedef struct {
*
* This interface is suitable for use with freelist_init_mr.
*/
static int freelist_regmr_host_fn(void *ep_void_ptr, void *data, size_t size, void **handle)
static int freelist_regmr_host_fn(void *ep_void_ptr,
void *data, size_t size, uint64_t access,
void **handle)
{
nccl_net_ofi_rdma_ep_t *ep = ep_void_ptr;

nccl_net_ofi_rdma_mr_handle_t *mr_handle;
ncclResult_t ret = reg_mr_ep(ep, data, size, NCCL_PTR_HOST, &mr_handle);
ncclResult_t ret = reg_mr_ep(ep, data, size, NCCL_PTR_HOST, access, &mr_handle);

if (ret != ncclSuccess) {
NCCL_OFI_WARN("Failed call to reg_mr_ep: %d", ret);
Expand Down Expand Up @@ -3252,8 +3254,10 @@ static ncclResult_t alloc_and_reg_flush_buff(nccl_net_ofi_rdma_recv_comm_t *r_co
/* Check if provider requires registration of local buffers */
if (local_mr == true) {
/* Register flush dummy buffer for provider access */
reg_mr_ep((nccl_net_ofi_rdma_ep_t *)r_comm->base.base.ep, flush_buff->host_buffer, page_size,
NCCL_PTR_HOST, &mr_handle);
reg_mr_ep((nccl_net_ofi_rdma_ep_t *)r_comm->base.base.ep,
flush_buff->host_buffer, page_size,
NCCL_PTR_HOST, FI_READ,
&mr_handle);
if (OFI_UNLIKELY(ret != ncclSuccess)) {
NCCL_OFI_WARN("Could not register dummy buffer for flush, dev: %d",
dev_id);
Expand Down Expand Up @@ -3623,9 +3627,9 @@ static nccl_net_ofi_rdma_recv_comm_t *prepare_recv_comm(nccl_net_ofi_rdma_listen
}

ret = nccl_ofi_freelist_init_mr(sizeof(nccl_net_ofi_rdma_ctrl_fl_item_t), 8, 8,
NCCL_OFI_MAX_REQUESTS, freelist_regmr_host_fn,
freelist_deregmr_host_fn, ep, 0,
&r_comm->ctrl_buff_fl);
NCCL_OFI_MAX_REQUESTS,
freelist_regmr_host_fn, freelist_deregmr_host_fn,
FI_SEND, ep, 0, &r_comm->ctrl_buff_fl);
if (ret != 0) {
NCCL_OFI_WARN("Call to freelist_init_mr failed: %d", ret);
return NULL;
Expand Down Expand Up @@ -4892,7 +4896,7 @@ static inline int init_bounce_buffers(nccl_net_ofi_rdma_ep_t *ep)
ret = nccl_ofi_freelist_init_mr(sizeof(nccl_net_ofi_rdma_bounce_fl_item_t) + ep->bounce_buff_size,
ofi_nccl_rdma_min_posted_bounce_buffers(), 16, 0,
freelist_regmr_host_fn, freelist_deregmr_host_fn,
ep, 0, &ep->bounce_buff_fl);
FI_RECV | FI_REMOTE_READ, ep, 0, &ep->bounce_buff_fl);
if (ret != 0) {
NCCL_OFI_WARN("Failed to init bounce_buff_fl");
if (nccl_ofi_freelist_fini(ep->bounce_buff_reqs_fl))
Expand Down
45 changes: 27 additions & 18 deletions src/nccl_ofi_sendrecv.c
Original file line number Diff line number Diff line change
Expand Up @@ -462,7 +462,7 @@ static ssize_t post_recv_conn(nccl_net_ofi_sendrecv_listen_comm_t *l_comm,
*/
static ncclResult_t register_mr_buffers(struct fid_domain *domain, struct fid_ep *ep,
nccl_ofi_mr_keypool_t *key_pool, int dev_id,
void *data, size_t size,
void *data, size_t size, uint64_t access,
int type, struct fid_mr **mr_handle)
{
ncclResult_t ret = ncclSuccess;
Expand All @@ -489,20 +489,14 @@ static ncclResult_t register_mr_buffers(struct fid_domain *domain, struct fid_ep
/* Initialize MR attributes */
mr_attr.mr_iov = &iov;
mr_attr.iov_count = 1;

/* Communication buffer is used as a message source/target */
mr_attr.access = FI_SEND | FI_RECV;
mr_attr.access = access;

switch (type) {
case NCCL_PTR_HOST:
/* Host buffer is used as a RMA read target on GPU flush */
mr_attr.access |= FI_READ;
mr_attr.iface = FI_HMEM_SYSTEM;
break;
#if HAVE_CUDA
case NCCL_PTR_CUDA:
/* CUDA buffer is used as a RMA read source on GPU flush */
mr_attr.access |= FI_REMOTE_READ;
mr_attr.iface = FI_HMEM_CUDA;

/* Get CUDA device ID */
Expand Down Expand Up @@ -573,7 +567,7 @@ static ncclResult_t register_mr_buffers(struct fid_domain *domain, struct fid_ep

static ncclResult_t reg_mr_base(struct fid_domain *domain, struct fid_ep *ep,
nccl_ofi_mr_keypool_t *key_pool, int dev_id,
void *data, size_t size, int type,
void *data, size_t size, int type, uint64_t access,
void **mhandle)
{
/* Validate type of buffer */
Expand All @@ -591,12 +585,15 @@ static ncclResult_t reg_mr_base(struct fid_domain *domain, struct fid_ep *ep,
return ncclInternalError;
}

return register_mr_buffers(domain, ep, key_pool, dev_id, data, size, type,
(struct fid_mr **)mhandle);
return register_mr_buffers(domain, ep,
key_pool, dev_id,
data, size, access,
type, (struct fid_mr **)mhandle);
}

static ncclResult_t reg_mr_base_comm(nccl_net_ofi_comm_t *base_comm, void *data,
size_t size, int type, void **mhandle)
size_t size, int type, uint64_t access,
void **mhandle)
{
/* Retrieve and validate endpoint */
nccl_net_ofi_sendrecv_ep_t *ep =
Expand All @@ -616,20 +613,30 @@ static ncclResult_t reg_mr_base_comm(nccl_net_ofi_comm_t *base_comm, void *data,
int dev_id = device->base.dev_id;

nccl_ofi_mr_keypool_t *key_pool = &device->key_pool;
return reg_mr_base(device->domain, ep->ofi_ep, key_pool,
dev_id, data, size, type, mhandle);
return reg_mr_base(device->domain, ep->ofi_ep,
key_pool, dev_id,
data, size, type, access,
mhandle);
}

static ncclResult_t reg_mr_send_comm(nccl_net_ofi_send_comm_t *send_comm, void *data,
size_t size, int type, void **mhandle)
{
return reg_mr_base_comm(&send_comm->base, data, size, type, mhandle);
return reg_mr_base_comm(&send_comm->base, data, size, type, FI_SEND, mhandle);
}

static ncclResult_t reg_mr_recv_comm(nccl_net_ofi_recv_comm_t *recv_comm, void *data,
size_t size, int type, void **mhandle)
{
return reg_mr_base_comm(&recv_comm->base, data, size, type, mhandle);
uint64_t access = FI_RECV;

/* GPU flush target is registered as recv communication buffer
* and is used as RMA read source */
if (type == NCCL_PTR_CUDA) {
access |= FI_REMOTE_READ;
}

return reg_mr_base_comm(&recv_comm->base, data, size, type, access, mhandle);
}

static ncclResult_t dereg_mr_base_comm(struct fid_mr *mr_handle,
Expand Down Expand Up @@ -1062,8 +1069,10 @@ static int alloc_and_reg_flush_buff(struct fid_domain *domain, struct fid_ep *ep
}

/* Register flush dummy buffer for provider access */
ret = register_mr_buffers(domain, ep, key_pool, dev_id, flush_buff->host_buffer,
page_size, NCCL_PTR_HOST, &mr_handle);
ret = register_mr_buffers(domain, ep,
key_pool, dev_id,
flush_buff->host_buffer, page_size, FI_READ,
NCCL_PTR_HOST, &mr_handle);
if (OFI_UNLIKELY(ret != ncclSuccess)) {
NCCL_OFI_WARN("Could not register dummy buffer for flush, dev: %d",
dev_id);
Expand Down
15 changes: 14 additions & 1 deletion tests/unit/freelist.c
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,24 @@
#include "config.h"

#include <stdio.h>
#include <rdma/fabric.h>

#include "test-common.h"
#include "nccl_ofi_freelist.h"

void *simple_base;
size_t simple_size;
uint64_t simple_access;
void *simple_handle;

int regmr_simple(void *opaque, void *data, size_t size, void **handle)
int regmr_simple(void *opaque,
void *data, size_t size, uint64_t access,
void **handle)
{
*handle = simple_handle = opaque;
simple_base = data;
simple_size = size;
simple_access = access;

return ncclSuccess;
}
Expand All @@ -29,6 +34,7 @@ int deregmr_simple(void *handle)

simple_base = NULL;
simple_size = 0;
simple_access = 0;
simple_handle = NULL;

return ncclSuccess;
Expand Down Expand Up @@ -172,6 +178,7 @@ int main(int argc, char *argv[])
32,
regmr_simple,
deregmr_simple,
FI_REMOTE_WRITE,
(void *)0xdeadbeaf,
offsetof(struct random_freelisted_item, reginfo),
&freelist);
Expand Down Expand Up @@ -201,6 +208,12 @@ int main(int argc, char *argv[])
item->reginfo.base_offset);
exit(1);
}

if (simple_access != FI_REMOTE_WRITE) {
NCCL_OFI_WARN("MR access mismatch got=%lu expected=%lu",
simple_access, FI_REMOTE_WRITE);
exit(1);
}
}
nccl_ofi_freelist_fini(freelist);
if (simple_base) {
Expand Down

0 comments on commit 9a3f74d

Please sign in to comment.