Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MR: Set minimal access permissions for each MR based on its usage #240

Draft
wants to merge 2 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion include/nccl_ofi_freelist.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@ struct nccl_ofi_freelist_block_t {
* Note that the freelist lock will be held during this function. The
* caller must avoid a deadlock situation with this behavior.
*/
typedef int (*nccl_ofi_freelist_regmr_fn)(void *opaque, void *data, size_t size,
typedef int (*nccl_ofi_freelist_regmr_fn)(void *opaque,
void *data, size_t size, uint64_t access,
void **handle);

/*
Expand Down Expand Up @@ -100,6 +101,7 @@ struct nccl_ofi_freelist_t {
bool have_reginfo;
nccl_ofi_freelist_regmr_fn regmr_fn;
nccl_ofi_freelist_deregmr_fn deregmr_fn;
uint64_t mr_access;
void *regmr_opaque;
size_t reginfo_offset;

Expand Down Expand Up @@ -152,6 +154,7 @@ int nccl_ofi_freelist_init_mr(size_t entry_size,
size_t max_entry_count,
nccl_ofi_freelist_regmr_fn regmr_fn,
nccl_ofi_freelist_deregmr_fn deregmr_fn,
uint64_t mr_access,
void *regmr_opaque,
size_t reginfo_offset,
nccl_ofi_freelist_t **freelist_p);
Expand Down
10 changes: 8 additions & 2 deletions src/nccl_ofi_freelist.c
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ static int freelist_init_internal(size_t entry_size,
bool have_reginfo,
nccl_ofi_freelist_regmr_fn regmr_fn,
nccl_ofi_freelist_deregmr_fn deregmr_fn,
uint64_t mr_access,
void *regmr_opaque,
size_t reginfo_offset,
nccl_ofi_freelist_t **freelist_p)
Expand All @@ -42,6 +43,7 @@ static int freelist_init_internal(size_t entry_size,
freelist->have_reginfo = have_reginfo;
freelist->regmr_fn = regmr_fn;
freelist->deregmr_fn = deregmr_fn;
freelist->mr_access = mr_access;
freelist->regmr_opaque = regmr_opaque;
freelist->reginfo_offset = reginfo_offset;

Expand Down Expand Up @@ -78,6 +80,7 @@ int nccl_ofi_freelist_init(size_t entry_size,
false,
NULL,
NULL,
0,
NULL,
0,
freelist_p);
Expand All @@ -89,6 +92,7 @@ int nccl_ofi_freelist_init_mr(size_t entry_size,
size_t max_entry_count,
nccl_ofi_freelist_regmr_fn regmr_fn,
nccl_ofi_freelist_deregmr_fn deregmr_fn,
uint64_t mr_access,
void *regmr_opaque,
size_t reginfo_offset,
nccl_ofi_freelist_t **freelist_p)
Expand All @@ -100,6 +104,7 @@ int nccl_ofi_freelist_init_mr(size_t entry_size,
true,
regmr_fn,
deregmr_fn,
mr_access,
regmr_opaque,
reginfo_offset,
freelist_p);
Expand Down Expand Up @@ -176,8 +181,9 @@ int nccl_ofi_freelist_add(nccl_ofi_freelist_t *freelist,
block->next = freelist->blocks;

if (freelist->regmr_fn) {
ret = freelist->regmr_fn(freelist->regmr_opaque, block->memory,
freelist->entry_size * allocation_count,
ret = freelist->regmr_fn(freelist->regmr_opaque,
block->memory, freelist->entry_size * allocation_count,
freelist->mr_access,
&block->mr_handle);
if (ret != 0) {
NCCL_OFI_WARN("freelist extension registration failed: %d", ret);
Expand Down
59 changes: 38 additions & 21 deletions src/nccl_ofi_rdma.c
Original file line number Diff line number Diff line change
Expand Up @@ -532,9 +532,9 @@ static int write_topo_file(nccl_ofi_topo_t *topo)
* @return Populated I/O vector, on success
* @return 0 on success
* non-zero on error
*/
*/
static ncclResult_t set_mr_req_attr(nccl_ofi_mr_keypool_t *key_pool, int dev_id,
void *data, size_t size, int type,
void *data, size_t size, int type, uint64_t access,
struct fi_mr_attr *mr_attr, struct iovec *iov)
{
ncclResult_t ret = ncclSuccess;
Expand All @@ -546,20 +546,15 @@ static ncclResult_t set_mr_req_attr(nccl_ofi_mr_keypool_t *key_pool, int dev_id,
/* Initialize MR attributes */
mr_attr->mr_iov = iov;
mr_attr->iov_count = 1;
mr_attr->access = FI_SEND | FI_RECV;

/* Add FI_WRITE (source of fi_write) and FI_REMOTE_WRITE (target of fi_write)
for RDMA send/recv buffers */
mr_attr->access |= (FI_WRITE | FI_REMOTE_WRITE);
mr_attr->access = access;

switch (type) {
case NCCL_PTR_HOST:
mr_attr->access |= FI_READ;
mr_attr->iface = FI_HMEM_SYSTEM;
break;
#if HAVE_CUDA
case NCCL_PTR_CUDA:
mr_attr->access |= FI_REMOTE_READ;
mr_attr->iface = FI_HMEM_CUDA;

/* Get CUDA device ID */
Expand All @@ -571,7 +566,6 @@ static ncclResult_t set_mr_req_attr(nccl_ofi_mr_keypool_t *key_pool, int dev_id,
#endif
#if HAVE_NEURON
case NCCL_PTR_NEURON:
mr_attr->access |= FI_REMOTE_READ;
mr_attr->iface = FI_HMEM_NEURON;
/*
* Store a sentinel; libfabric requires this to be initialized Libfabric
Expand Down Expand Up @@ -2548,7 +2542,8 @@ bool valid_mr_buff_type(int type)
}

static ncclResult_t reg_mr_ep(nccl_net_ofi_rdma_ep_t *ep, void *data,
size_t size, int type, nccl_net_ofi_rdma_mr_handle_t **mhandle)
size_t size, int type, uint64_t access,
nccl_net_ofi_rdma_mr_handle_t **mhandle)
{
ncclResult_t ret = ncclSuccess;
struct fi_mr_attr mr_attr = {0};
Expand Down Expand Up @@ -2586,7 +2581,7 @@ static ncclResult_t reg_mr_ep(nccl_net_ofi_rdma_ep_t *ep, void *data,
}

/* Create memory registration request */
ret = set_mr_req_attr(key_pool, dev_id, data, size, type, &mr_attr, &iov);
ret = set_mr_req_attr(key_pool, dev_id, data, size, type, access, &mr_attr, &iov);
if (OFI_UNLIKELY(ret != ncclSuccess)) {
NCCL_OFI_WARN("Could not set registration request attributes, dev: %d",
dev_id);
Expand Down Expand Up @@ -2622,14 +2617,32 @@ static ncclResult_t reg_mr_send_comm(nccl_net_ofi_send_comm_t *send_comm, void *
size_t size, int type, void **mhandle)
{
nccl_net_ofi_rdma_ep_t *ep = (nccl_net_ofi_rdma_ep_t *) send_comm->base.ep;
return reg_mr_ep(ep, data, size, type, (nccl_net_ofi_rdma_mr_handle_t **)mhandle);
/* Send buffer is used as a RMA write source
* or a tagged message source in case it is sent eagerly */
const uint64_t access = FI_WRITE | FI_SEND;

return reg_mr_ep(ep, data,
size, type, access,
(nccl_net_ofi_rdma_mr_handle_t **)mhandle);
}

static ncclResult_t reg_mr_recv_comm(nccl_net_ofi_recv_comm_t *recv_comm, void *data,
size_t size, int type, void **mhandle)
{
nccl_net_ofi_rdma_ep_t *ep = (nccl_net_ofi_rdma_ep_t *) recv_comm->base.ep;
return reg_mr_ep(ep, data, size, type, (nccl_net_ofi_rdma_mr_handle_t **)mhandle);
/* Recv buffer is used as a RMA write target
* or a RMA read target in case of local eager copy */
uint64_t access = FI_REMOTE_WRITE | FI_READ;

/* GPU flush target is registered as recv communication buffer
* and is used as RMA read source */
if (type == NCCL_PTR_CUDA) {
access |= FI_REMOTE_READ;
}

return reg_mr_ep(ep, data,
size, type, access,
(nccl_net_ofi_rdma_mr_handle_t **)mhandle);
}

static ncclResult_t dereg_mr_ep(nccl_net_ofi_rdma_mr_handle_t *mr_handle,
Expand Down Expand Up @@ -2678,12 +2691,14 @@ typedef struct {
*
* This interface is suitable for use with freelist_init_mr.
*/
static int freelist_regmr_host_fn(void *ep_void_ptr, void *data, size_t size, void **handle)
static int freelist_regmr_host_fn(void *ep_void_ptr,
void *data, size_t size, uint64_t access,
void **handle)
{
nccl_net_ofi_rdma_ep_t *ep = ep_void_ptr;

nccl_net_ofi_rdma_mr_handle_t *mr_handle;
ncclResult_t ret = reg_mr_ep(ep, data, size, NCCL_PTR_HOST, &mr_handle);
ncclResult_t ret = reg_mr_ep(ep, data, size, NCCL_PTR_HOST, access, &mr_handle);

if (ret != ncclSuccess) {
NCCL_OFI_WARN("Failed call to reg_mr_ep: %d", ret);
Expand Down Expand Up @@ -3241,8 +3256,10 @@ static ncclResult_t alloc_and_reg_flush_buff(nccl_net_ofi_rdma_recv_comm_t *r_co
/* Check if provider requires registration of local buffers */
if (local_mr == true) {
/* Register flush dummy buffer for provider access */
reg_mr_ep((nccl_net_ofi_rdma_ep_t *)r_comm->base.base.ep, flush_buff->host_buffer, page_size,
NCCL_PTR_HOST, &mr_handle);
reg_mr_ep((nccl_net_ofi_rdma_ep_t *)r_comm->base.base.ep,
flush_buff->host_buffer, page_size,
NCCL_PTR_HOST, FI_READ,
&mr_handle);
if (OFI_UNLIKELY(ret != ncclSuccess)) {
NCCL_OFI_WARN("Could not register dummy buffer for flush, dev: %d",
dev_id);
Expand Down Expand Up @@ -3612,9 +3629,9 @@ static nccl_net_ofi_rdma_recv_comm_t *prepare_recv_comm(nccl_net_ofi_rdma_listen
}

ret = nccl_ofi_freelist_init_mr(sizeof(nccl_net_ofi_rdma_ctrl_fl_item_t), 8, 8,
NCCL_OFI_MAX_REQUESTS, freelist_regmr_host_fn,
freelist_deregmr_host_fn, ep, 0,
&r_comm->ctrl_buff_fl);
NCCL_OFI_MAX_REQUESTS,
freelist_regmr_host_fn, freelist_deregmr_host_fn,
FI_SEND, ep, 0, &r_comm->ctrl_buff_fl);
if (ret != 0) {
NCCL_OFI_WARN("Call to freelist_init_mr failed: %d", ret);
return NULL;
Expand Down Expand Up @@ -4881,7 +4898,7 @@ static inline int init_bounce_buffers(nccl_net_ofi_rdma_ep_t *ep)
ret = nccl_ofi_freelist_init_mr(sizeof(nccl_net_ofi_rdma_bounce_fl_item_t) + ep->bounce_buff_size,
ofi_nccl_rdma_min_posted_bounce_buffers(), 16, 0,
freelist_regmr_host_fn, freelist_deregmr_host_fn,
ep, 0, &ep->bounce_buff_fl);
FI_RECV | FI_REMOTE_READ, ep, 0, &ep->bounce_buff_fl);
if (ret != 0) {
NCCL_OFI_WARN("Failed to init bounce_buff_fl");
if (nccl_ofi_freelist_fini(ep->bounce_buff_reqs_fl))
Expand Down
42 changes: 27 additions & 15 deletions src/nccl_ofi_sendrecv.c
Original file line number Diff line number Diff line change
Expand Up @@ -462,7 +462,7 @@ static ssize_t post_recv_conn(nccl_net_ofi_sendrecv_listen_comm_t *l_comm,
*/
static ncclResult_t register_mr_buffers(struct fid_domain *domain, struct fid_ep *ep,
nccl_ofi_mr_keypool_t *key_pool, int dev_id,
void *data, size_t size,
void *data, size_t size, uint64_t access,
int type, struct fid_mr **mr_handle)
{
ncclResult_t ret = ncclSuccess;
Expand All @@ -489,16 +489,14 @@ static ncclResult_t register_mr_buffers(struct fid_domain *domain, struct fid_ep
/* Initialize MR attributes */
mr_attr.mr_iov = &iov;
mr_attr.iov_count = 1;
mr_attr.access = FI_SEND | FI_RECV;
mr_attr.access = access;

switch (type) {
case NCCL_PTR_HOST:
mr_attr.access |= FI_READ;
mr_attr.iface = FI_HMEM_SYSTEM;
break;
#if HAVE_CUDA
case NCCL_PTR_CUDA:
mr_attr.access |= FI_REMOTE_READ;
mr_attr.iface = FI_HMEM_CUDA;

/* Get CUDA device ID */
Expand All @@ -510,7 +508,6 @@ static ncclResult_t register_mr_buffers(struct fid_domain *domain, struct fid_ep
#endif
#if HAVE_NEURON
case NCCL_PTR_NEURON:
mr_attr.access |= FI_REMOTE_READ;
mr_attr.iface = FI_HMEM_NEURON;
/*
* Store a sentinel; libfabric requires this to be initialized Libfabric
Expand Down Expand Up @@ -570,7 +567,7 @@ static ncclResult_t register_mr_buffers(struct fid_domain *domain, struct fid_ep

static ncclResult_t reg_mr_base(struct fid_domain *domain, struct fid_ep *ep,
nccl_ofi_mr_keypool_t *key_pool, int dev_id,
void *data, size_t size, int type,
void *data, size_t size, int type, uint64_t access,
void **mhandle)
{
/* Validate type of buffer */
Expand All @@ -588,12 +585,15 @@ static ncclResult_t reg_mr_base(struct fid_domain *domain, struct fid_ep *ep,
return ncclInternalError;
}

return register_mr_buffers(domain, ep, key_pool, dev_id, data, size, type,
(struct fid_mr **)mhandle);
return register_mr_buffers(domain, ep,
key_pool, dev_id,
data, size, access,
type, (struct fid_mr **)mhandle);
}

static ncclResult_t reg_mr_base_comm(nccl_net_ofi_comm_t *base_comm, void *data,
size_t size, int type, void **mhandle)
size_t size, int type, uint64_t access,
void **mhandle)
{
/* Retrieve and validate endpoint */
nccl_net_ofi_sendrecv_ep_t *ep =
Expand All @@ -613,20 +613,30 @@ static ncclResult_t reg_mr_base_comm(nccl_net_ofi_comm_t *base_comm, void *data,
int dev_id = device->base.dev_id;

nccl_ofi_mr_keypool_t *key_pool = &device->key_pool;
return reg_mr_base(device->domain, ep->ofi_ep, key_pool,
dev_id, data, size, type, mhandle);
return reg_mr_base(device->domain, ep->ofi_ep,
key_pool, dev_id,
data, size, type, access,
mhandle);
}

static ncclResult_t reg_mr_send_comm(nccl_net_ofi_send_comm_t *send_comm, void *data,
size_t size, int type, void **mhandle)
{
return reg_mr_base_comm(&send_comm->base, data, size, type, mhandle);
return reg_mr_base_comm(&send_comm->base, data, size, type, FI_SEND, mhandle);
}

static ncclResult_t reg_mr_recv_comm(nccl_net_ofi_recv_comm_t *recv_comm, void *data,
size_t size, int type, void **mhandle)
{
return reg_mr_base_comm(&recv_comm->base, data, size, type, mhandle);
uint64_t access = FI_RECV;

/* GPU flush target is registered as recv communication buffer
* and is used as RMA read source */
if (type == NCCL_PTR_CUDA) {
access |= FI_REMOTE_READ;
}

return reg_mr_base_comm(&recv_comm->base, data, size, type, access, mhandle);
}

static ncclResult_t dereg_mr_base_comm(struct fid_mr *mr_handle,
Expand Down Expand Up @@ -1059,8 +1069,10 @@ static int alloc_and_reg_flush_buff(struct fid_domain *domain, struct fid_ep *ep
}

/* Register flush dummy buffer for provider access */
ret = register_mr_buffers(domain, ep, key_pool, dev_id, flush_buff->host_buffer,
page_size, NCCL_PTR_HOST, &mr_handle);
ret = register_mr_buffers(domain, ep,
key_pool, dev_id,
flush_buff->host_buffer, page_size, FI_READ,
NCCL_PTR_HOST, &mr_handle);
if (OFI_UNLIKELY(ret != ncclSuccess)) {
NCCL_OFI_WARN("Could not register dummy buffer for flush, dev: %d",
dev_id);
Expand Down
15 changes: 14 additions & 1 deletion tests/unit/freelist.c
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,24 @@
#include "config.h"

#include <stdio.h>
#include <rdma/fabric.h>

#include "test-common.h"
#include "nccl_ofi_freelist.h"

void *simple_base;
size_t simple_size;
uint64_t simple_access;
void *simple_handle;

int regmr_simple(void *opaque, void *data, size_t size, void **handle)
int regmr_simple(void *opaque,
void *data, size_t size, uint64_t access,
void **handle)
{
*handle = simple_handle = opaque;
simple_base = data;
simple_size = size;
simple_access = access;

return ncclSuccess;
}
Expand All @@ -29,6 +34,7 @@ int deregmr_simple(void *handle)

simple_base = NULL;
simple_size = 0;
simple_access = 0;
simple_handle = NULL;

return ncclSuccess;
Expand Down Expand Up @@ -172,6 +178,7 @@ int main(int argc, char *argv[])
32,
regmr_simple,
deregmr_simple,
FI_REMOTE_WRITE,
(void *)0xdeadbeaf,
offsetof(struct random_freelisted_item, reginfo),
&freelist);
Expand Down Expand Up @@ -201,6 +208,12 @@ int main(int argc, char *argv[])
item->reginfo.base_offset);
exit(1);
}

if (simple_access != FI_REMOTE_WRITE) {
NCCL_OFI_WARN("MR access mismatch got=%lu expected=%lu",
simple_access, FI_REMOTE_WRITE);
exit(1);
}
}
nccl_ofi_freelist_fini(freelist);
if (simple_base) {
Expand Down