|
16 | 16 | */
|
17 | 17 | #include "ucx_backend.h"
|
18 | 18 | #include <serdes/serdes.h>
|
| 19 | +#include <nixl_log.h> |
19 | 20 |
|
20 |
| -nixl_status_t |
21 |
| -nixlUcxEngine::vramUpdateCtx(void *address, uint64_t devId, bool &restart_reqd) |
22 |
| -{ |
23 |
| - restart_reqd = false; |
24 |
| - |
25 |
| - if (!nixlCudaPtrCtx::vramIsSupported()) { |
26 |
| - // NIXL doesn't support CUDA but VRAM pointer was given |
27 |
| - return NIXL_ERR_INVALID_PARAM; |
28 |
| - } |
29 |
| - |
30 |
| - // IF the workaround is globally disabled |
31 |
| - if(!cudaAddrWA) { |
32 |
| - // Nothing to do |
33 |
| - return NIXL_SUCCESS; |
34 |
| - } |
35 |
| - |
36 |
| - std::unique_ptr<nixlCudaPtrCtx> ctx = |
37 |
| - nixlCudaPtrCtx::nixlCudaPtrCtxInit(address); |
38 |
| - |
39 |
| - switch(ctx->getMemType()) { |
40 |
| - case nixlCudaPtrCtx::MEM_HOST: |
41 |
| - // Nothing is required for host |
42 |
| - return NIXL_SUCCESS; |
43 |
| - case nixlCudaPtrCtx::MEM_DEV: |
44 |
| - case nixlCudaPtrCtx::MEM_VMM_DEV: |
45 |
| - // Continue with setting the context |
46 |
| - break; |
47 |
| - case nixlCudaPtrCtx::MEM_VMM_HOST: |
48 |
| - default: |
49 |
| - // TODO: figure out how to handle VMM |
50 |
| - return NIXL_ERR_INVALID_PARAM; |
51 |
| - } |
52 |
| - |
53 |
| - if (ctx->getDevId() != devId) { |
54 |
| - // TODO: log error |
55 |
| - return NIXL_ERR_INVALID_PARAM; |
56 |
| - } |
57 |
| - |
58 |
| - if (nullptr == cudaPtrCtx.get()) { |
59 |
| - // The context was not previously set |
60 |
| - // Set it now and indicate that an update |
61 |
| - // is required |
62 |
| - cudaPtrCtx.swap(ctx); |
63 |
| - restart_reqd = true; |
64 |
| - return NIXL_SUCCESS; |
65 |
| - } |
66 |
| - |
67 |
| - // The context was set previously |
68 |
| - // Check that it is consistent with the new address |
69 |
| - if (! (*ctx == *cudaPtrCtx)) { |
70 |
| - // TODO: log out error that for UCX that requires CUDA context to be set |
71 |
| - // addresses from different contexts are used |
72 |
| - return NIXL_ERR_NOT_SUPPORTED; |
73 |
| - } |
74 |
| - |
75 |
| - return NIXL_SUCCESS; |
76 |
| -} |
77 |
| - |
78 |
| -nixl_status_t nixlUcxEngine::vramApplyCtx() |
79 |
| -{ |
80 |
| - auto ctx = cudaPtrCtx.get(); |
81 |
| - if (ctx) { |
82 |
| - return ctx->setMemCtx(); |
83 |
| - } |
84 |
| - return NIXL_SUCCESS; |
85 |
| -} |
86 | 21 |
|
87 | 22 | /****************************************
|
88 | 23 | * UCX request management
|
@@ -231,7 +166,7 @@ class nixlUcxBackendH : public nixlBackendReqH {
|
231 | 166 | void nixlUcxEngine::progressFunc()
|
232 | 167 | {
|
233 | 168 | using namespace nixlTime;
|
234 |
| - vramApplyCtx(); |
| 169 | + cudaMemCtx->set(); |
235 | 170 |
|
236 | 171 | pthrActive = 1;
|
237 | 172 |
|
@@ -333,13 +268,8 @@ nixlUcxEngine::nixlUcxEngine (const nixlBackendInitParams* init_params)
|
333 | 268 | pthrOn = false;
|
334 | 269 | }
|
335 | 270 |
|
336 |
| - // Temp fixup |
337 |
| - if (getenv("NIXL_DISABLE_CUDA_ADDR_WA")) { |
338 |
| - std::cout << "WARNING: disabling CUDA address workaround" << std::endl; |
339 |
| - cudaAddrWA = false; |
340 |
| - } else { |
341 |
| - cudaAddrWA = true; |
342 |
| - } |
| 271 | + cudaMemCtx = nixlCudaMemCtx::nixlCudaMemCtxInit(); |
| 272 | + |
343 | 273 | progressThreadStart();
|
344 | 274 | }
|
345 | 275 |
|
@@ -580,15 +510,19 @@ nixl_status_t nixlUcxEngine::registerMem (const nixlBlobDesc &mem,
|
580 | 510 | size_t rkey_size;
|
581 | 511 |
|
582 | 512 | if (nixl_mem == VRAM_SEG) {
|
583 |
| - bool need_restart; |
584 |
| - nixl_status_t status; |
585 |
| - status = vramUpdateCtx((void*)mem.addr, mem.devId, need_restart); |
586 |
| - if (NIXL_SUCCESS != status) { |
587 |
| - //TODO Add to logging |
588 |
| - return status; |
589 |
| - } |
590 |
| - if (need_restart) { |
| 513 | + nixl_status_t status = cudaMemCtx->enableAddr((void*)mem.addr, mem.devId); |
| 514 | + if (NIXL_IN_PROG == status) { |
| 515 | + // The context was updated and must be set in the progress thread |
| 516 | + // This procedure ensures that access to cudaMemCtx is serialized |
591 | 517 | progressThreadRestart();
|
| 518 | + } else if (NIXL_SUCCESS != status) { |
| 519 | + NIXL_ERROR << "Address " << std::hex << mem.addr << std::dec |
| 520 | + << " is not supported by the UCX backend"; |
| 521 | + NIXL_ERROR << "Returned status is " << status; |
| 522 | + // TODO use nixlEnumStrings::statusStr(status); once circ dep between libnixl & utils is resolved |
| 523 | + NIXL_ERROR << "The likely reason is that your UCX supports only one CUDA device per UCP context"; |
| 524 | + NIXL_ERROR << "Consider upgrading to UCX 1.19 or using UCX_MO backend"; |
| 525 | + return status; |
592 | 526 | }
|
593 | 527 | }
|
594 | 528 |
|
|
0 commit comments