Skip to content

Commit e8eafb8

Browse files
tstamlermkhazraee
andcommitted
API: checkRemoteMD call (#203)
* API: checkRemoteMD call Signed-off-by: Timothy Stamler <tstamler@nvidia.com> * API: added descs to checkRemoteMD * As well as minor reordering of the API Signed-off-by: Timothy Stamler <tstamler@nvidia.com> Signed-off-by: Moein Khazraee <moein@nvidia.com> Co-authored-by: Moein Khazraee <moein@nvidia.com>
1 parent 3810b36 commit e8eafb8

File tree

10 files changed

+244
-178
lines changed

10 files changed

+244
-178
lines changed

.gitlab/test_python.sh

+2-2
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ pytest test/python
4343

4444
echo "==== Running python example ===="
4545
cd examples/python
46-
python3 partial_md_example.py
4746
python3 blocking_send_recv_example.py --mode="target" --ip=127.0.0.1 --port=1234&
48-
sleep 1
47+
sleep 5
4948
python3 blocking_send_recv_example.py --mode="initiator" --ip=127.0.0.1 --port=1234
49+
python3 partial_md_example.py

examples/python/blocking_send_recv_example.py

+12-15
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
import torch
2121

2222
from nixl._api import nixl_agent, nixl_agent_config
23-
from nixl._bindings import nixlNotFoundError
2423

2524

2625
def parse_args():
@@ -68,12 +67,11 @@ def parse_args():
6867

6968
# Send desc list to initiator when metadata is ready
7069
while not ready:
71-
try:
72-
agent.send_notif("initiator", target_desc_str)
73-
except nixlNotFoundError:
74-
ready = False
75-
else:
76-
ready = True
70+
ready = agent.check_remote_metadata("initiator")
71+
72+
agent.send_notif("initiator", target_desc_str)
73+
74+
print("Waiting for transfer")
7775

7876
# Waiting for transfer
7977
# For now the notification is just UUID, could be any python bytes.
@@ -98,14 +96,13 @@ def parse_args():
9896
# Ensure remote metadata has arrived from fetch
9997
ready = False
10098
while not ready:
101-
try:
102-
xfer_handle = agent.initialize_xfer(
103-
"READ", initiator_descs, target_descs, "target", "UUID"
104-
)
105-
except nixlNotFoundError:
106-
ready = False
107-
else:
108-
ready = True
99+
ready = agent.check_remote_metadata("target")
100+
101+
print("Ready for transfer")
102+
103+
xfer_handle = agent.initialize_xfer(
104+
"READ", initiator_descs, target_descs, "target", "UUID"
105+
)
109106

110107
if not xfer_handle:
111108
print("Creating transfer failed.")

examples/python/partial_md_example.py

+9-10
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
# limitations under the License.
1717

1818
import os
19+
import time
1920

2021
import nixl._utils as nixl_utils
2122
from nixl._api import nixl_agent, nixl_agent_config
@@ -78,17 +79,13 @@
7879

7980
# Wait for metadata to be loaded
8081
ready = False
81-
xfer_handle_1 = 0
82+
8283
while not ready:
83-
try:
84-
# initialize transfer mode
85-
xfer_handle_1 = init_agent.initialize_xfer(
86-
"READ", init_xfer_descs, target_xfer_descs1, "target", b"UUID1"
87-
)
88-
except nixlNotFoundError:
89-
ready = False
90-
else:
91-
ready = True
84+
ready = init_agent.check_remote_metadata("target", target_xfer_descs1)
85+
86+
xfer_handle_1 = init_agent.initialize_xfer(
87+
"READ", init_xfer_descs, target_xfer_descs1, "target", b"UUID1"
88+
)
9289

9390
state = init_agent.transfer(xfer_handle_1)
9491
assert state != "ERR"
@@ -171,4 +168,6 @@
171168
for addr in malloc_addrs:
172169
nixl_utils.free_passthru(addr)
173170

171+
# Give sockets time to close
172+
time.sleep(1)
174173
print("Test Complete.")

src/api/cpp/nixl.h

+61-48
Original file line numberDiff line numberDiff line change
@@ -316,6 +316,59 @@ class nixlAgent {
316316
const nixl_blob_t &msg,
317317
const nixl_opt_args_t* extra_params = nullptr);
318318

319+
/*** Metadata handling through side channel ***/
320+
/**
321+
* @brief Get metadata blob for this agent, to be given to other agents.
322+
*
323+
* @param str [out] The serialized metadata blob
324+
* @return nixl_status_t Error code if call was not successful
325+
*/
326+
nixl_status_t
327+
getLocalMD (nixl_blob_t &str) const;
328+
329+
/**
330+
* @brief Get partial metadata blob for this agent, to be given to other agents.
331+
* If `descs` is empty, only backends' connection info is included in the metadata,
332+
* regardless of the value of `extra_params->includeConnInfo` and `descs` memory type.
333+
* If `descs` is non-empty, the metadata of the descriptors in the list are included,
334+
* and if `extra_params->includeConnInfo` is true, the connection info of the
335+
* backends supporting the memory type is also included.
336+
* If `extra_params->backends` is non-empty, only the descriptors supported by the
337+
* backends in the list and the backends' connection info are included in the metadata.
338+
*
339+
* @param descs [in] Descriptor list to include in the metadata
340+
* @param str [out] The serialized metadata blob
341+
* @param extra_params [in] Optional extra parameters used in getting partial metadata
342+
* @return nixl_status_t Error code if call was not successful
343+
*/
344+
nixl_status_t
345+
getLocalPartialMD(const nixl_reg_dlist_t &descs,
346+
nixl_blob_t &str,
347+
const nixl_opt_args_t* extra_params = nullptr) const;
348+
349+
/**
350+
* @brief Load other agent's metadata and unpack it internally. Now the local
351+
* agent can initiate transfers towards the remote agent.
352+
*
353+
* @param remote_metadata Serialized metadata blob to be loaded
354+
* @param agent_name [out] Agent name extracted from the loaded metadata blob
355+
* @return nixl_status_t Error code if call was not successful
356+
*/
357+
nixl_status_t
358+
loadRemoteMD (const nixl_blob_t &remote_metadata,
359+
std::string &agent_name);
360+
361+
/**
362+
* @brief Invalidate the remote agent metadata cached locally. This will
363+
* disconnect from that agent if already connected, and no more
364+
* transfers can be initiated towards that agent.
365+
*
366+
* @param remote_agent Remote agent name to invalidate its metadata blob
367+
* @return nixl_status_t Error code if call was not successful
368+
*/
369+
nixl_status_t
370+
invalidateRemoteMD (const std::string &remote_agent);
371+
319372
/*** Metadata handling through direct channels (p2p socket and ETCD) ***/
320373
/**
321374
* @brief Send your own agent metadata to a remote location.
@@ -348,7 +401,7 @@ class nixlAgent {
348401
* @return nixl_status_t Error code if call was not successful
349402
*/
350403
nixl_status_t
351-
sendLocalPartialMD(nixl_reg_dlist_t &descs,
404+
sendLocalPartialMD(const nixl_reg_dlist_t &descs,
352405
const nixl_opt_args_t* extra_params = nullptr) const;
353406

354407
/**
@@ -379,58 +432,18 @@ class nixlAgent {
379432
nixl_status_t
380433
invalidateLocalMD (const nixl_opt_args_t* extra_params = nullptr) const;
381434

382-
/*** Metadata handling through side channel ***/
383435
/**
384-
* @brief Get metadata blob for this agent, to be given to other agents.
436+
* @brief Check if metadata is available for a remote agent.
437+
* For partial metadata methods are used, the descriptor list in question
438+
* can be specified; otherwise, empty `descs` can be passed.
385439
*
386-
* @param str [out] The serialized metadata blob
387-
* @return nixl_status_t Error code if call was not successful
440+
* @param str Remote agent to check for
441+
* @return nixl_status_t Error code, NOT_FOUND if metadata not found
388442
*/
389443
nixl_status_t
390-
getLocalMD (nixl_blob_t &str) const;
444+
checkRemoteMD (const std::string remote_name,
445+
const nixl_xfer_dlist_t &descs) const;
391446

392-
/**
393-
* @brief Get partial metadata blob for this agent, to be given to other agents.
394-
* If `descs` is empty, only backends' connection info is included in the metadata,
395-
* regardless of the value of `extra_params->includeConnInfo` and `descs` memory type.
396-
* If `descs` is non-empty, the metadata of the descriptors in the list are included,
397-
* and if `extra_params->includeConnInfo` is true, the connection info of the
398-
* backends supporting the memory type is also included.
399-
* If `extra_params->backends` is non-empty, only the descriptors supported by the
400-
* backends in the list and the backends' connection info are included in the metadata.
401-
*
402-
* @param descs [in] Descriptor list to include in the metadata
403-
* @param str [out] The serialized metadata blob
404-
* @param extra_params [in] Optional extra parameters used in getting partial metadata
405-
* @return nixl_status_t Error code if call was not successful
406-
*/
407-
nixl_status_t
408-
getLocalPartialMD(nixl_reg_dlist_t &descs,
409-
nixl_blob_t &str,
410-
const nixl_opt_args_t* extra_params = nullptr) const;
411-
412-
/**
413-
* @brief Load other agent's metadata and unpack it internally. Now the local
414-
* agent can initiate transfers towards the remote agent.
415-
*
416-
* @param remote_metadata Serialized metadata blob to be loaded
417-
* @param agent_name [out] Agent name extracted from the loaded metadata blob
418-
* @return nixl_status_t Error code if call was not successful
419-
*/
420-
nixl_status_t
421-
loadRemoteMD (const nixl_blob_t &remote_metadata,
422-
std::string &agent_name);
423-
424-
/**
425-
* @brief Invalidate the remote agent metadata cached locally. This will
426-
* disconnect from that agent if already connected, and no more
427-
* transfers can be initiated towards that agent.
428-
*
429-
* @param remote_agent Remote agent name to invalidate its metadata blob
430-
* @return nixl_status_t Error code if call was not successful
431-
*/
432-
nixl_status_t
433-
invalidateRemoteMD (const std::string &remote_agent);
434447
};
435448

436449
#endif

src/api/cpp/nixl_params.h

+2-2
Original file line numberDiff line numberDiff line change
@@ -70,9 +70,9 @@ class nixlAgentConfig {
7070
nixlAgentConfig (const bool use_prog_thread,
7171
const bool use_listen_thread=false,
7272
const int port=0,
73+
nixl_thread_sync_t sync_mode=nixl_thread_sync_t::NIXL_THREAD_SYNC_DEFAULT,
7374
const uint64_t pthr_delay_us=0,
74-
const uint64_t lthr_delay_us = 100000,
75-
nixl_thread_sync_t sync_mode=nixl_thread_sync_t::NIXL_THREAD_SYNC_DEFAULT) :
75+
const uint64_t lthr_delay_us = 100000) :
7676
useProgThread(use_prog_thread),
7777
useListenThread(use_listen_thread),
7878
listenPort(port),

src/api/python/_api.py

+46-18
Original file line numberDiff line numberDiff line change
@@ -81,9 +81,18 @@ def __init__(
8181
if not nixl_conf:
8282
nixl_conf = nixl_agent_config() # Using defaults set in nixl_agent_config
8383

84+
thread_config = (
85+
nixlBind.NIXL_THREAD_SYNC_STRICT
86+
if nixl_conf.enable_listen
87+
else nixlBind.NIXL_THREAD_SYNC_NONE
88+
)
89+
8490
# Set agent config and instantiate an agent
8591
agent_config = nixlBind.nixlAgentConfig(
86-
nixl_conf.enable_pthread, nixl_conf.enable_listen, nixl_conf.port
92+
nixl_conf.enable_pthread,
93+
nixl_conf.enable_listen,
94+
nixl_conf.port,
95+
thread_config,
8796
)
8897
self.agent = nixlBind.nixlAgent(agent_name, agent_config)
8998

@@ -612,6 +621,29 @@ def get_partial_agent_metadata(
612621
handle_list.append(self.backends[backend_string])
613622
return self.agent.getLocalPartialMD(descs, inc_conn_info, handle_list)
614623

624+
"""
625+
@brief Add a remote agent using its metadata. After this call, current agent can
626+
initiate transfers towards the remote agent.
627+
628+
@param metadata Metadata of the remote agent, received out-of-band in bytes.
629+
@return Name of the added remote agent.
630+
"""
631+
632+
def add_remote_agent(self, metadata: bytes) -> str:
633+
agent_name = self.agent.loadRemoteMD(metadata)
634+
return agent_name
635+
636+
"""
637+
@brief Remove a remote agent. After this call, current agent cannot initiate
638+
transfers towards the remote agent specified in the call anymore.
639+
This call will also result in a disconnect between the two agents.
640+
641+
@param agent Name of the remote agent.
642+
"""
643+
644+
def remove_remote_agent(self, agent: str):
645+
self.agent.invalidateRemoteMD(agent)
646+
615647
"""
616648
@brief Send all of your metadata to a peer or central metadata server.
617649
@@ -674,27 +706,23 @@ def invalidate_local_metadata(
674706
self.agent.invalidateLocalMD(ip_addr, port)
675707

676708
"""
677-
@brief Add a remote agent using its metadata. After this call, current agent can
678-
initiate transfers towards the remote agent.
679-
680-
@param metadata Metadata of the remote agent, received out-of-band in bytes.
681-
@return Name of the added remote agent.
682-
"""
683-
684-
def add_remote_agent(self, metadata: bytes) -> str:
685-
agent_name = self.agent.loadRemoteMD(metadata)
686-
return agent_name
687-
688-
"""
689-
@brief Remove a remote agent. After this call, current agent cannot initiate
690-
transfers towards the remote agent specified in the call anymore.
691-
This call will also result in a disconnect between the two agents.
709+
@brief Check if the remote metadata for a specific agent is available.
710+
When partial metadata methods are used, the descriptor list in question can be specified.
692711
693712
@param agent Name of the remote agent.
713+
714+
@return True if available, False otherwise
694715
"""
695716

696-
def remove_remote_agent(self, agent: str):
697-
self.agent.invalidateRemoteMD(agent)
717+
def check_remote_metadata(
718+
self, agent: str, descs: nixlBind.nixlXferDList = None
719+
) -> bool:
720+
if descs is None: # Just empty list, mem_type not important
721+
descs = nixlBind.nixlXferDList(nixlBind.DRAM_SEG)
722+
if self.agent.checkRemoteMD(agent, descs) == nixlBind.NIXL_SUCCESS:
723+
return True
724+
else:
725+
return False
698726

699727
"""
700728
@brief Get nixlXferDList from different input types:

src/bindings/python/nixl_bindings.cpp

+16-8
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,12 @@ PYBIND11_MODULE(_bindings, m) {
123123
m.attr("DEFAULT_COMM_PORT") = default_comm_port;
124124

125125
//cast types
126+
py::enum_<nixl_thread_sync_t>(m, "nixl_thread_sync_t")
127+
.value("NIXL_THREAD_SYNC_NONE", nixl_thread_sync_t::NIXL_THREAD_SYNC_NONE)
128+
.value("NIXL_THREAD_SYNC_STRICT", nixl_thread_sync_t::NIXL_THREAD_SYNC_STRICT)
129+
.value("NIXL_THREAD_SYNC_DEFAULT", nixl_thread_sync_t::NIXL_THREAD_SYNC_DEFAULT)
130+
.export_values();
131+
126132
py::enum_<nixl_mem_t>(m, "nixl_mem_t")
127133
.value("DRAM_SEG", DRAM_SEG)
128134
.value("VRAM_SEG", VRAM_SEG)
@@ -282,7 +288,8 @@ PYBIND11_MODULE(_bindings, m) {
282288
//implicit constructor
283289
.def(py::init<bool>())
284290
.def(py::init<bool, bool>())
285-
.def(py::init<bool, bool, int>());
291+
.def(py::init<bool, bool, int>())
292+
.def(py::init<bool, bool, int, nixl_thread_sync_t>());
286293

287294
//note: pybind will automatically convert notif_map to python types:
288295
//so, a Dictionary of string: List<string>
@@ -504,6 +511,13 @@ PYBIND11_MODULE(_bindings, m) {
504511
throw_nixl_exception(agent.getLocalPartialMD(descs, ret_str, &extra_params));
505512
return py::bytes(ret_str);
506513
}, py::arg("descs"), py::arg("inc_conn_info") = false, py::arg("backends") = std::vector<uintptr_t>({}))
514+
.def("loadRemoteMD", [](nixlAgent &agent, const std::string &remote_metadata) -> py::bytes {
515+
//python can only interpret text strings
516+
std::string remote_name("");
517+
throw_nixl_exception(agent.loadRemoteMD(remote_metadata, remote_name));
518+
return py::bytes(remote_name);
519+
})
520+
.def("invalidateRemoteMD", &nixlAgent::invalidateRemoteMD)
507521
.def("sendLocalMD", [](nixlAgent &agent, std::string ip_addr, int port){
508522
nixl_opt_args_t extra_params;
509523

@@ -542,11 +556,5 @@ PYBIND11_MODULE(_bindings, m) {
542556

543557
throw_nixl_exception(agent.invalidateLocalMD(&extra_params));
544558
}, py::arg("ip_addr") = std::string(""), py::arg("port") = 0 )
545-
.def("loadRemoteMD", [](nixlAgent &agent, const std::string &remote_metadata) -> py::bytes {
546-
//python can only interpret text strings
547-
std::string remote_name("");
548-
throw_nixl_exception(agent.loadRemoteMD(remote_metadata, remote_name));
549-
return py::bytes(remote_name);
550-
})
551-
.def("invalidateRemoteMD", &nixlAgent::invalidateRemoteMD);
559+
.def("checkRemoteMD", &nixlAgent::checkRemoteMD);
552560
}

src/core/meson.build

+2-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16-
nixl_lib_deps = [nixl_infra, serdes_interface, stream_interface, dl_dep]
16+
# Add dependency on the common utility library which brings in logging deps
17+
nixl_lib_deps = [nixl_infra, serdes_interface, stream_interface, dl_dep, nixl_common_dep, thread_dep]
1718

1819
if 'UCX' in static_plugins
1920
nixl_lib_deps += [ ucx_backend_interface, cuda_dep ]

0 commit comments

Comments
 (0)