Skip to content

Commit 87a3bf4

Browse files
committed
Address comments
1 parent e7df641 commit 87a3bf4

File tree

10 files changed

+327
-311
lines changed

10 files changed

+327
-311
lines changed

examples/cpp/meson.build

+2-2
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,14 @@
1515

1616
nixl_example = executable('nixl_example',
1717
'nixl_example.cpp',
18-
dependencies: [nixl_dep, nixl_infra, ucx_backend_interface],
18+
dependencies: [nixl_dep, nixl_infra],
1919
include_directories: [nixl_inc_dirs, utils_inc_dirs],
2020
link_with: [serdes_lib],
2121
install: true)
2222
if etcd_dep.found()
2323
etcd_example = executable('nixl_etcd_example',
2424
'nixl_etcd_example.cpp',
25-
dependencies: [nixl_dep, nixl_infra, ucx_backend_interface],
25+
dependencies: [nixl_dep, nixl_infra],
2626
include_directories: [nixl_inc_dirs, utils_inc_dirs],
2727
link_with: [serdes_lib],
2828
install: true)

src/plugins/ucx/meson.build

+2-1
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@
1616

1717
compile_flags = []
1818

19-
ucx_backend_dep_list = [nixl_infra, ucx_utils_interface, cuda_utils_interface, serdes_interface, thread_dep]
19+
ucx_backend_dep_list = [nixl_infra, nixl_common_dep, ucx_utils_interface,
20+
cuda_utils_interface, serdes_interface, thread_dep]
2021

2122
if 'UCX' in static_plugins
2223
ucx_backend_lib = static_library('UCX',

src/plugins/ucx/ucx_backend.cpp

+16-82
Original file line numberDiff line numberDiff line change
@@ -16,73 +16,8 @@
1616
*/
1717
#include "ucx_backend.h"
1818
#include <serdes/serdes.h>
19+
#include <nixl_log.h>
1920

20-
nixl_status_t
21-
nixlUcxEngine::vramUpdateCtx(void *address, uint64_t devId, bool &restart_reqd)
22-
{
23-
restart_reqd = false;
24-
25-
if (!nixlCudaPtrCtx::vramIsSupported()) {
26-
// NIXL doesn't support CUDA but VRAM pointer was given
27-
return NIXL_ERR_INVALID_PARAM;
28-
}
29-
30-
// IF the workaround is globally disabled
31-
if(!cudaAddrWA) {
32-
// Nothing to do
33-
return NIXL_SUCCESS;
34-
}
35-
36-
std::unique_ptr<nixlCudaPtrCtx> ctx =
37-
nixlCudaPtrCtx::nixlCudaPtrCtxInit(address);
38-
39-
switch(ctx->getMemType()) {
40-
case nixlCudaPtrCtx::MEM_HOST:
41-
// Nothing is required for host
42-
return NIXL_SUCCESS;
43-
case nixlCudaPtrCtx::MEM_DEV:
44-
case nixlCudaPtrCtx::MEM_VMM_DEV:
45-
// Continue with setting the context
46-
break;
47-
case nixlCudaPtrCtx::MEM_VMM_HOST:
48-
default:
49-
// TODO: figure out how to handle VMM
50-
return NIXL_ERR_INVALID_PARAM;
51-
}
52-
53-
if (ctx->getDevId() != devId) {
54-
// TODO: log error
55-
return NIXL_ERR_INVALID_PARAM;
56-
}
57-
58-
if (nullptr == cudaPtrCtx.get()) {
59-
// The context was not previously set
60-
// Set it now and indicate that an update
61-
// is required
62-
cudaPtrCtx.swap(ctx);
63-
restart_reqd = true;
64-
return NIXL_SUCCESS;
65-
}
66-
67-
// The context was set previously
68-
// Check that it is consistent with the new address
69-
if (! (*ctx == *cudaPtrCtx)) {
70-
// TODO: log out error that for UCX that requires CUDA context to be set
71-
// addresses from different contexts are used
72-
return NIXL_ERR_NOT_SUPPORTED;
73-
}
74-
75-
return NIXL_SUCCESS;
76-
}
77-
78-
nixl_status_t nixlUcxEngine::vramApplyCtx()
79-
{
80-
auto ctx = cudaPtrCtx.get();
81-
if (ctx) {
82-
return ctx->setMemCtx();
83-
}
84-
return NIXL_SUCCESS;
85-
}
8621

8722
/****************************************
8823
* UCX request management
@@ -231,7 +166,7 @@ class nixlUcxBackendH : public nixlBackendReqH {
231166
void nixlUcxEngine::progressFunc()
232167
{
233168
using namespace nixlTime;
234-
vramApplyCtx();
169+
cudaMemCtx->set();
235170

236171
pthrActive = 1;
237172

@@ -333,13 +268,8 @@ nixlUcxEngine::nixlUcxEngine (const nixlBackendInitParams* init_params)
333268
pthrOn = false;
334269
}
335270

336-
// Temp fixup
337-
if (getenv("NIXL_DISABLE_CUDA_ADDR_WA")) {
338-
std::cout << "WARNING: disabling CUDA address workaround" << std::endl;
339-
cudaAddrWA = false;
340-
} else {
341-
cudaAddrWA = true;
342-
}
271+
cudaMemCtx = nixlCudaMemCtx::nixlCudaMemCtxInit();
272+
343273
progressThreadStart();
344274
}
345275

@@ -580,15 +510,19 @@ nixl_status_t nixlUcxEngine::registerMem (const nixlBlobDesc &mem,
580510
size_t rkey_size;
581511

582512
if (nixl_mem == VRAM_SEG) {
583-
bool need_restart;
584-
nixl_status_t status;
585-
status = vramUpdateCtx((void*)mem.addr, mem.devId, need_restart);
586-
if (NIXL_SUCCESS != status) {
587-
//TODO Add to logging
588-
return status;
589-
}
590-
if (need_restart) {
513+
nixl_status_t status = cudaMemCtx->enableAddr((void*)mem.addr, mem.devId);
514+
if (NIXL_IN_PROG == status) {
515+
// The context was updated and must be set in the progress thread
516+
// This procedure ensures that access to cudaMemCtx is serialized
591517
progressThreadRestart();
518+
} else if (NIXL_SUCCESS != status) {
519+
NIXL_ERROR << "Address " << std::hex << mem.addr << std::dec
520+
<< " is not supported by the UCX backend";
521+
NIXL_ERROR << "Returned status is " << status;
522+
// TODO use nixlEnumStrings::statusStr(status); once circ dep between libnixl & utils is resolved
523+
NIXL_ERROR << "The likely reason is that your UCX supports only one CUDA device per UCP context";
524+
NIXL_ERROR << "Consider upgrading to UCX 1.19 or using UCX_MO backend";
525+
return status;
592526
}
593527
}
594528

src/plugins/ucx/ucx_backend.h

+1-7
Original file line numberDiff line numberDiff line change
@@ -107,9 +107,7 @@ class nixlUcxEngine : public nixlBackendEngine {
107107
nixlTime::us_t pthrDelay;
108108

109109
/* CUDA data*/
110-
std::unique_ptr<nixlCudaPtrCtx> cudaPtrCtx;
111-
//nixlUcxCudaCtx *cudaCtx;
112-
bool cudaAddrWA;
110+
std::unique_ptr<nixlCudaMemCtx> cudaMemCtx;
113111

114112
/* Notifications */
115113
notif_list_t notifMainList;
@@ -120,10 +118,6 @@ class nixlUcxEngine : public nixlBackendEngine {
120118
std::unordered_map<std::string, nixlUcxConnection,
121119
std::hash<std::string>, strEqual> remoteConnMap;
122120

123-
124-
nixl_status_t vramUpdateCtx(void *address, uint64_t devId, bool &restart_reqd);
125-
nixl_status_t vramApplyCtx();
126-
127121
// Threading infrastructure
128122
// TODO: move the thread management one outside of NIXL common infra
129123
void progressFunc();

src/plugins/ucx_mo/meson.build

-5
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,6 @@
1616
ucx_mo_dep_list = [nixl_infra, ucx_backend_interface, serdes_interface]
1717

1818
compile_flags = []
19-
if cuda_dep.found()
20-
compile_flags = [ '-DHAVE_CUDA' ]
21-
else
22-
compile_flags = [ '-UHAVE_CUDA' ]
23-
endif
2419

2520
if 'UCX_MO' in static_plugins
2621
ucx_mo_backend_lib = static_library('UCX_MO',

0 commit comments

Comments
 (0)