|
| 1 | +/* |
| 2 | + * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. |
| 3 | + * SPDX-License-Identifier: Apache-2.0 |
| 4 | + * |
| 5 | + * Licensed under the Apache License, Version 2.0 (the "License"); |
| 6 | + * you may not use this file except in compliance with the License. |
| 7 | + * You may obtain a copy of the License at |
| 8 | + * |
| 9 | + * http://www.apache.org/licenses/LICENSE-2.0 |
| 10 | + * |
| 11 | + * Unless required by applicable law or agreed to in writing, software |
| 12 | + * distributed under the License is distributed on an "AS IS" BASIS, |
| 13 | + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 14 | + * See the License for the specific language governing permissions and |
| 15 | + * limitations under the License. |
| 16 | + */ |
| 17 | + |
| 18 | +#include <iostream> |
| 19 | +#include <liburing.h> |
| 20 | +#include <algorithm> |
| 21 | +#include <cmath> |
| 22 | +#include "posix_backend.h" |
| 23 | +#include <unordered_map> |
| 24 | + |
namespace {
    // Per-ring capacity expressed as a power of two: 2^6 = 64 entries.
    // Names in an unnamed namespace already have internal linkage.
    constexpr int max_posix_ring_log = 6;
    constexpr int max_posix_ring_size = 1 << max_posix_ring_log;
}
| 29 | + |
| 30 | +void nixlPosixBackendReqH::createUringParams() { |
| 31 | + // We only need 2 params, as only the last one will be the remainder, and others will be max_posix_ring_size |
| 32 | + num_uring_params = std::min(num_urings, 2); |
| 33 | + uring_params.resize(num_uring_params); |
| 34 | + |
| 35 | + // Calculate the remainder, but if it's 0, set to max_posix_ring_size instead |
| 36 | + unsigned int remainder = (this->local_desc_count - 1) % max_posix_ring_size + 1; |
| 37 | + |
| 38 | + switch (num_uring_params) { |
| 39 | + case 2: |
| 40 | + uring_params[1] = { |
| 41 | + .sq_entries = max_posix_ring_size, |
| 42 | + .cq_entries = max_posix_ring_size, |
| 43 | + }; |
| 44 | + case 1: |
| 45 | + uring_params[0] = { |
| 46 | + .sq_entries = remainder, |
| 47 | + .cq_entries = remainder, |
| 48 | + }; |
| 49 | + break; |
| 50 | + default: |
| 51 | + throw UringInitError(); |
| 52 | + } |
| 53 | +} |
| 54 | + |
| 55 | +void nixlPosixBackendReqH::exitUrings(int num_urings) { |
| 56 | + for (int i = 0; i < num_urings; ++i) |
| 57 | + io_uring_queue_exit(&uring[i]); |
| 58 | +} |
| 59 | + |
| 60 | +int nixlPosixBackendReqH::createUringsAndNumEntries() { |
| 61 | + num_entries.resize(num_urings); |
| 62 | + uring.resize(num_urings); |
| 63 | + // Calculate the remainder, but if it's 0, set to max_posix_ring_size instead |
| 64 | + unsigned int remainder = (this->local_desc_count - 1) % max_posix_ring_size + 1; |
| 65 | + |
| 66 | + for (int i = 0; i < num_urings - 1; ++i) { |
| 67 | + num_entries[i] = max_posix_ring_size; |
| 68 | + int uring_init_status = io_uring_queue_init_params(max_posix_ring_size, &uring[i], &uring_params[1]); |
| 69 | + if (uring_init_status != 0) { |
| 70 | + std::cerr << "Error in initializing io_uring " << i << "\n"; |
| 71 | + this->exitUrings(i); |
| 72 | + return uring_init_status; |
| 73 | + } |
| 74 | + } |
| 75 | + |
| 76 | + num_entries[num_urings - 1] = remainder; |
| 77 | + int uring_init_status = io_uring_queue_init_params(remainder, &uring[num_urings - 1], &uring_params[0]); |
| 78 | + if (uring_init_status != 0) { |
| 79 | + std::cerr << "Error in initializing io_uring " << num_urings - 1 << "\n"; |
| 80 | + this->exitUrings(num_urings); |
| 81 | + } |
| 82 | + |
| 83 | + return uring_init_status; |
| 84 | +} |
| 85 | + |
| 86 | +const std::unordered_map<nixl_xfer_op_t, nixlPosixBackendReqH::io_uring_prep_func_t> nixlPosixBackendReqH::op_to_func = { |
| 87 | + {NIXL_READ, (io_uring_prep_func_t)io_uring_prep_read}, |
| 88 | + {NIXL_WRITE, (io_uring_prep_func_t)io_uring_prep_write} |
| 89 | +}; |
| 90 | + |
| 91 | +nixlPosixBackendReqH::nixlPosixBackendReqH(const nixl_xfer_op_t &operation, |
| 92 | + const nixl_meta_dlist_t &local, |
| 93 | + const nixl_meta_dlist_t &remote, |
| 94 | + const nixl_opt_b_args_t* opt_args) |
| 95 | + : local_desc_count(local.descCount()), operation(operation), local(local), |
| 96 | + remote(remote), opt_args(opt_args), is_prepped(false), erred(false) |
| 97 | +{ |
| 98 | + // Initialize io_uring_prep_func based on the operation |
| 99 | + auto it = this->op_to_func.find(operation); |
| 100 | + if (it == this->op_to_func.end()) { |
| 101 | + throw OperationError(); |
| 102 | + } |
| 103 | + this->io_uring_prep_func = it->second; |
| 104 | + |
| 105 | + // Calculate the number of io_urings needed based on the number of descriptors |
| 106 | + this->num_urings = static_cast<int>(std::ceil(static_cast<double>(local_desc_count) / max_posix_ring_size)); |
| 107 | + |
| 108 | + this->num_completed.resize(this->num_urings, 0); |
| 109 | + |
| 110 | + this->createUringParams(); |
| 111 | + |
| 112 | + if (this->createUringsAndNumEntries() != 0) |
| 113 | + throw UringInitError(); |
| 114 | +} |
| 115 | + |
| 116 | +nixlPosixBackendReqH::~nixlPosixBackendReqH() { |
| 117 | + this->exitUrings(this->num_urings); |
| 118 | +} |
| 119 | + |
| 120 | +bool nixlPosixBackendReqH::isErred() const { |
| 121 | + return erred; |
| 122 | +} |
| 123 | + |
| 124 | +nixl_status_t nixlPosixBackendReqH::updateNumCompleted(unsigned int uring_idx) { |
| 125 | + if (this->isErred()) // TODO: Talk with team if this is the correct error code |
| 126 | + return NIXL_ERR_INVALID_PARAM; |
| 127 | + |
| 128 | + if (this->num_completed[uring_idx] == this->num_entries[uring_idx]) |
| 129 | + return NIXL_SUCCESS; |
| 130 | + |
| 131 | + struct io_uring_cqe *cqes[max_posix_ring_size]; |
| 132 | + |
| 133 | + int num_ret_cqes = io_uring_peek_batch_cqe(&this->uring[uring_idx], cqes, max_posix_ring_size); |
| 134 | + for (int i = 0; i < num_ret_cqes; ++i) { |
| 135 | + int res = cqes[i]->res; |
| 136 | + if (res < 0) { |
| 137 | + io_uring_cqe_seen(&this->uring[uring_idx], cqes[i]); |
| 138 | + this->erred = true; |
| 139 | + return NIXL_ERR_BACKEND; |
| 140 | + } |
| 141 | + io_uring_cqe_seen(&this->uring[uring_idx], cqes[i]); |
| 142 | + } |
| 143 | + |
| 144 | + this->num_completed[uring_idx] += num_ret_cqes; |
| 145 | + return NIXL_SUCCESS; |
| 146 | +} |
| 147 | + |
| 148 | +bool nixlPosixBackendReqH::isPrepped() const { |
| 149 | + return this->is_prepped; |
| 150 | +} |
| 151 | + |
| 152 | +nixl_status_t nixlPosixBackendReqH::prepXfer() { |
| 153 | + if (this->isErred()) // TODO: Talk with team if this is the correct error code |
| 154 | + return NIXL_ERR_INVALID_PARAM; |
| 155 | + |
| 156 | + if (this->is_prepped) |
| 157 | + return NIXL_SUCCESS; |
| 158 | + |
| 159 | + struct io_uring_sqe *entry; |
| 160 | + |
| 161 | + for (int global_entry_index = 0; global_entry_index < this->local_desc_count; ++global_entry_index) { |
| 162 | + // Calculate the index of the io_uring to use for this entry |
| 163 | + int uring_index = global_entry_index / max_posix_ring_size; |
| 164 | + entry = io_uring_get_sqe(&this->uring[uring_index]); |
| 165 | + if (!entry) { |
| 166 | + std::cerr << "Error in getting sqe\n"; |
| 167 | + return NIXL_ERR_BACKEND; |
| 168 | + } |
| 169 | + io_uring_prep_func(entry, this->remote[global_entry_index].devId, |
| 170 | + (void *)this->local[global_entry_index].addr, |
| 171 | + this->remote[global_entry_index].len, |
| 172 | + this->remote[global_entry_index].addr); |
| 173 | + } |
| 174 | + |
| 175 | + this->is_prepped = true; |
| 176 | + return NIXL_SUCCESS; |
| 177 | +} |
| 178 | + |
| 179 | +nixl_status_t nixlPosixBackendReqH::checkXfer() { |
| 180 | + if (this->isErred()) // TODO: Talk with team if this is the correct error code |
| 181 | + return NIXL_ERR_INVALID_PARAM; |
| 182 | + |
| 183 | + nixl_status_t status; |
| 184 | + for (int i = 0; i < this->num_urings; ++i) { |
| 185 | + status = this->updateNumCompleted(i); |
| 186 | + if (status != NIXL_SUCCESS) |
| 187 | + return status; |
| 188 | + if (this->num_completed[i] != this->num_entries[i]) |
| 189 | + return NIXL_IN_PROG; |
| 190 | + } |
| 191 | + |
| 192 | + return NIXL_SUCCESS; |
| 193 | +} |
| 194 | + |
| 195 | +nixl_status_t nixlPosixBackendReqH::postXfer() { |
| 196 | + for (int i = 0; i < this->num_urings; ++i) { |
| 197 | + int ret = io_uring_submit(&this->uring[i]); |
| 198 | + if (ret != this->num_entries[i]) { |
| 199 | + std::cerr << "Error in submitting io_uring " << i << " - "; |
| 200 | + std::cerr << "ret: " << ret << " - num_entries: " << this->num_entries[i] << std::endl; |
| 201 | + return NIXL_ERR_BACKEND; |
| 202 | + } |
| 203 | + } |
| 204 | + return NIXL_IN_PROG; |
| 205 | +} |
| 206 | + |
| 207 | +const nixl_mem_list_t nixlPosixEngine::supported_mems = { |
| 208 | + FILE_SEG, |
| 209 | + DRAM_SEG |
| 210 | +}; |
| 211 | + |
| 212 | +nixl_status_t nixlPosixEngine::registerMem(const nixlBlobDesc &mem, |
| 213 | + const nixl_mem_t &nixl_mem, |
| 214 | + nixlBackendMD* &out) { |
| 215 | + if (std::find(supported_mems.begin(), supported_mems.end(), nixl_mem) != supported_mems.end()) |
| 216 | + return NIXL_SUCCESS; |
| 217 | + |
| 218 | + return NIXL_ERR_NOT_SUPPORTED; |
| 219 | +} |
| 220 | + |
| 221 | +nixl_status_t nixlPosixEngine::deregisterMem(nixlBackendMD *) { |
| 222 | + return NIXL_SUCCESS; |
| 223 | +} |
| 224 | + |
| 225 | +bool isPrepXferParamsValid(const nixl_xfer_op_t &operation, |
| 226 | + const nixl_meta_dlist_t &local, |
| 227 | + const nixl_meta_dlist_t &remote) { |
| 228 | + bool is_valid = true; |
| 229 | + if (local.getType() != DRAM_SEG) { |
| 230 | + std::cerr << "Error: local memory type must be DRAM_SEG\n"; |
| 231 | + is_valid = false; |
| 232 | + } |
| 233 | + |
| 234 | + // TODO: Consult with team if to change these to else ifs (right now it tells all the errors) |
| 235 | + if (remote.getType() != FILE_SEG) { |
| 236 | + std::cerr << "Error: remote memory type must be FILE_SEG\n"; |
| 237 | + is_valid = false; |
| 238 | + } |
| 239 | + |
| 240 | + if (local.descCount() != remote.descCount()) { |
| 241 | + std::cerr <<"Error in count of local and remote descriptors\n"; |
| 242 | + is_valid = false; |
| 243 | + } |
| 244 | + |
| 245 | + if (local.getType() != DRAM_SEG) { |
| 246 | + std::cerr << "Error: local memory type must be DRAM_SEG\n"; |
| 247 | + is_valid = false; |
| 248 | + } |
| 249 | + |
| 250 | + if (remote.getType() != FILE_SEG) { |
| 251 | + std::cerr << "Error: remote memory type must be FILE_SEG\n"; |
| 252 | + is_valid = false; |
| 253 | + } |
| 254 | + |
| 255 | + return is_valid; |
| 256 | +} |
| 257 | + |
| 258 | +nixl_status_t nixlPosixEngine::prepXfer(const nixl_xfer_op_t &operation, |
| 259 | + const nixl_meta_dlist_t &local, |
| 260 | + const nixl_meta_dlist_t &remote, |
| 261 | + const std::string &remote_agent __attribute__((unused)), |
| 262 | + nixlBackendReqH* &handle, |
| 263 | + const nixl_opt_b_args_t* opt_args) { |
| 264 | + if (!isPrepXferParamsValid(operation, local, remote)) |
| 265 | + return NIXL_ERR_INVALID_PARAM; |
| 266 | + |
| 267 | + try { |
| 268 | + auto posix_handle = std::make_unique<nixlPosixBackendReqH>(operation, local, remote, opt_args); |
| 269 | + |
| 270 | + nixl_status_t status = posix_handle->prepXfer(); |
| 271 | + if (status != NIXL_SUCCESS) |
| 272 | + return status; |
| 273 | + |
| 274 | + handle = posix_handle.release(); |
| 275 | + return NIXL_SUCCESS; |
| 276 | + } catch (const nixlPosixBackendReqH::OperationError& e) { |
| 277 | + std::cerr << "Operation error: " << e.what() << std::endl; |
| 278 | + return NIXL_ERR_INVALID_PARAM; |
| 279 | + } catch (const nixlPosixBackendReqH::UringInitError& e) { |
| 280 | + std::cerr << "Uring initialization error: " << e.what() << std::endl; |
| 281 | + return NIXL_ERR_BACKEND; |
| 282 | + } catch (const std::exception& e) { |
| 283 | + std::cerr << "Unexpected error: " << e.what() << std::endl; |
| 284 | + return NIXL_ERR_BACKEND; |
| 285 | + } |
| 286 | +} |
| 287 | + |
| 288 | +nixl_status_t nixlPosixEngine::postXfer(const nixl_xfer_op_t &operation, |
| 289 | + const nixl_meta_dlist_t &local, |
| 290 | + const nixl_meta_dlist_t &remote, |
| 291 | + const std::string &remote_agent __attribute__((unused)), |
| 292 | + nixlBackendReqH* &handle, |
| 293 | + const nixl_opt_b_args_t* opt_args) { |
| 294 | + nixl_status_t status = NIXL_SUCCESS; |
| 295 | + |
| 296 | + nixlPosixBackendReqH *posix_handle = (nixlPosixBackendReqH *)handle; |
| 297 | + if (!posix_handle->isPrepped()) { |
| 298 | + status = this->prepXfer(operation, local, remote, remote_agent, handle, opt_args); |
| 299 | + if (status != NIXL_SUCCESS) |
| 300 | + return status; |
| 301 | + } |
| 302 | + |
| 303 | + status = posix_handle->postXfer(); |
| 304 | + if (status != NIXL_SUCCESS && status != NIXL_IN_PROG) |
| 305 | + std::cerr << "Error in submitting io_uring\n"; |
| 306 | + |
| 307 | + return status; |
| 308 | +} |
| 309 | + |
| 310 | +nixl_status_t nixlPosixEngine::checkXfer(nixlBackendReqH* handle) { |
| 311 | + nixlPosixBackendReqH *posix_handle = (nixlPosixBackendReqH *)handle; |
| 312 | + return posix_handle->checkXfer(); |
| 313 | +} |
| 314 | + |
| 315 | +nixl_status_t nixlPosixEngine::releaseReqH(nixlBackendReqH* handle) { |
| 316 | + nixlPosixBackendReqH *posix_handle = (nixlPosixBackendReqH *)handle; |
| 317 | + delete posix_handle; |
| 318 | + return NIXL_SUCCESS; |
| 319 | +} |
0 commit comments