Skip to content

Commit

Permalink
Support for Sorting Drivers based on the devices provided
Browse files Browse the repository at this point in the history
- Added support for reporting drivers sorted based on the types of
  devices returned.
- By default, drivers will be sorted such that the ordering will be:
	- Drivers with Discrete GPUs only
        - Drivers with Discrete and Integrated GPUs
        - Drivers with Integrated GPUs
        - Drivers with Mixed Devices Types (ie GPU + NPU)
        - Drivers with Non GPU Devices Only
- If ZE_ENABLE_PCI_ID_DEVICE_ORDER is set, then the following ordering
  is provided:
	- Drivers with Integrated GPUs
        - Drivers with Discrete and Integrated GPUs
        - Drivers with Discrete GPUs only
        - Drivers with Mixed Devices Types (ie GPU + NPU)
        - Drivers with Non GPU Devices Only

- Expanded the zello_world sample to test both init paths for easy
  prototyping.

Signed-off-by: Neil R. Spruit <neil.r.spruit@intel.com>
  • Loading branch information
nrspruit committed Jan 14, 2025
1 parent 3969f34 commit 5a4efe4
Show file tree
Hide file tree
Showing 4 changed files with 233 additions and 61 deletions.
8 changes: 6 additions & 2 deletions samples/include/zello_init.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,15 @@ inline bool argparse( int argc, char *argv[],
}

//////////////////////////////////////////////////////////////////////////
inline bool init_ze( void )
inline bool init_ze( bool legacy_init , uint32_t &driverCount, ze_init_driver_type_desc_t &driverTypeDesc)
{
ze_result_t result;
// Initialize the driver
result = zeInit(0);
if (legacy_init) {
result = zeInit(0);
} else {
result = zeInitDrivers(&driverCount, nullptr, &driverTypeDesc);
}
if(result != ZE_RESULT_SUCCESS) {
std::cout << "Driver not initialized: " << to_string(result) << std::endl;
return false;
Expand Down
43 changes: 31 additions & 12 deletions samples/zello_world/zello_world.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ void print_loader_versions(){
int main( int argc, char *argv[] )
{
bool tracing_runtime_enabled = false;
bool legacy_init = false;
if( argparse( argc, argv, "-null", "--enable_null_driver" ) )
{
putenv_safe( const_cast<char *>( "ZE_ENABLE_NULL_DRIVER=1" ) );
Expand All @@ -61,13 +62,22 @@ int main( int argc, char *argv[] )
{
tracing_runtime_enabled = true;
}
if( argparse( argc, argv, "-legacy", "--enable_legacy_init" ) )
{
legacy_init = true;
}

ze_result_t status;
const ze_device_type_t type = ZE_DEVICE_TYPE_GPU;

ze_init_driver_type_desc_t driverTypeDesc = {ZE_STRUCTURE_TYPE_INIT_DRIVER_TYPE_DESC};
driverTypeDesc.flags = ZE_INIT_DRIVER_TYPE_FLAG_GPU;
driverTypeDesc.pNext = nullptr;

ze_driver_handle_t pDriver = nullptr;
ze_device_handle_t pDevice = nullptr;
if( init_ze() )
uint32_t driverCount = 0;
if( init_ze(legacy_init, driverCount, driverTypeDesc) )
{

print_loader_versions();
Expand All @@ -81,18 +91,27 @@ int main( int argc, char *argv[] )
}
}

uint32_t driverCount = 0;
status = zeDriverGet(&driverCount, nullptr);
if(status != ZE_RESULT_SUCCESS) {
std::cout << "zeDriverGet Failed with return code: " << to_string(status) << std::endl;
exit(1);
}
std::vector<ze_driver_handle_t> drivers;
if (legacy_init) {
status = zeDriverGet(&driverCount, nullptr);
if(status != ZE_RESULT_SUCCESS) {
std::cout << "zeDriverGet Failed with return code: " << to_string(status) << std::endl;
exit(1);
}

std::vector<ze_driver_handle_t> drivers( driverCount );
status = zeDriverGet( &driverCount, drivers.data() );
if(status != ZE_RESULT_SUCCESS) {
std::cout << "zeDriverGet Failed with return code: " << to_string(status) << std::endl;
exit(1);
drivers.resize(driverCount);
status = zeDriverGet( &driverCount, drivers.data() );
if(status != ZE_RESULT_SUCCESS) {
std::cout << "zeDriverGet Failed with return code: " << to_string(status) << std::endl;
exit(1);
}
} else {
drivers.resize(driverCount);
status = zeInitDrivers( &driverCount, drivers.data(), &driverTypeDesc);
if(status != ZE_RESULT_SUCCESS) {
std::cout << "zeInitDrivers Failed with return code: " << to_string(status) << std::endl;
exit(1);
}
}

for( uint32_t driver = 0; driver < driverCount; ++driver )
Expand Down
229 changes: 182 additions & 47 deletions source/loader/ze_loader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,47 @@ namespace loader
return flags_value;
}

bool driverSortComparitor(const driver_t &a, const driver_t &b) {
bool pciOrderingRequested = getenv_tobool( "ZE_ENABLE_PCI_ID_DEVICE_ORDER" );
if (pciOrderingRequested) {
if (a.driverType == ZEL_DRIVER_TYPE_OTHER) {
return false;
}
if (a.driverType == ZEL_DRIVER_TYPE_MIXED && b.driverType == ZEL_DRIVER_TYPE_OTHER) {
return true;
} else if(a.driverType == ZEL_DRIVER_TYPE_MIXED) {
return false;
}
return a.driverType > b.driverType;
}
return a.driverType < b.driverType;
}

/**
* @brief Checks and initializes drivers based on the provided flags and descriptors.
*
* This function performs the following operations:
* 1. If debug tracing is enabled, logs the input parameters.
* 2. If `zeInitDrivers` is not supported by the driver and it is called first, returns `ZE_RESULT_ERROR_UNINITIALIZED`.
* 3. Determines the appropriate driver vector (`zeDrivers` or `zesDrivers`) based on the input parameters.
* 4. Iterates over the drivers and attempts to initialize each driver:
* - If initialization fails and the driver is not in use, removes the driver from the list.
* - If the number of drivers becomes one and interception is not forced, sets the `requireDdiReinit` flag to true.
* - If the initialization fails and `return_first_driver_result` is true, returns the result immediately.
* - If initialization succeeds, marks the driver as in use.
* 5. Sorts the drivers in ascending order of driver type using `driverSortComparitor`.
* 6. Logs the sorted driver list if debug tracing is enabled.
* 7. If no drivers are left, returns `ZE_RESULT_ERROR_UNINITIALIZED`.
* 8. Returns `ZE_RESULT_SUCCESS` if at least one driver is successfully initialized.
*
* @param flags Initialization flags.
* @param desc Driver type descriptor (optional).
* @param globalInitStored Pointer to global DDI table for initialization.
* @param sysmanGlobalInitStored Pointer to Sysman global DDI table for initialization.
* @param requireDdiReinit Pointer to a boolean flag indicating if DDI reinitialization is required.
* @param sysmanOnly Boolean flag indicating if only Sysman drivers should be checked.
* @return `ZE_RESULT_SUCCESS` if at least one driver is successfully initialized, otherwise an appropriate error code.
*/
ze_result_t context_t::check_drivers(ze_init_flags_t flags, ze_init_driver_type_desc_t* desc, ze_global_dditable_t *globalInitStored, zes_global_dditable_t *sysmanGlobalInitStored, bool *requireDdiReinit, bool sysmanOnly) {
if (debugTraceEnabled) {
if (desc) {
Expand All @@ -147,8 +188,10 @@ namespace loader
}
// If zeInitDrivers is not supported by this driver, but zeInitDrivers is called first, then return uninitialized.
if (desc && !loader::context->initDriversSupport) {
std::string message = "zeInitDrivers called first, but not supported by driver, returning uninitialized.";
debug_trace_message(message, "");
if (debugTraceEnabled) {
std::string message = "zeInitDrivers called first, but not supported by driver, returning uninitialized.";
debug_trace_message(message, "");
}
return ZE_RESULT_ERROR_UNINITIALIZED;
}

Expand Down Expand Up @@ -199,14 +242,24 @@ namespace loader
}
}

// Sort drivers in ascending order of driver type unless ZE_ENABLE_PCI_ID_DEVICE_ORDER, then in decending order with MIXED and OTHER at the end.
std::sort(drivers->begin(), drivers->end(), driverSortComparitor);

if (debugTraceEnabled) {
std::string message = "Drivers after sorting:";
for (const auto& driver : *drivers) {
message += "\nDriver Type: " + std::to_string(driver.driverType) + " Driver Name: " + driver.name;
}
debug_trace_message(message, "");
}

if(drivers->size() == 0)
return ZE_RESULT_ERROR_UNINITIALIZED;

return ZE_RESULT_SUCCESS;
}

ze_result_t context_t::init_driver(driver_t &driver, ze_init_flags_t flags, ze_init_driver_type_desc_t* desc, ze_global_dditable_t *globalInitStored, zes_global_dditable_t *sysmanGlobalInitStored, bool sysmanOnly) {

if (sysmanOnly) {
auto getTable = reinterpret_cast<zes_pfnGetGlobalProcAddrTable_t>(
GET_FUNCTION_PTR(driver.handle, "zesGetGlobalProcAddrTable"));
Expand Down Expand Up @@ -248,36 +301,11 @@ namespace loader
}
return res;
} else {
auto getTable = reinterpret_cast<ze_pfnGetGlobalProcAddrTable_t>(
GET_FUNCTION_PTR(driver.handle, "zeGetGlobalProcAddrTable"));
if(!getTable) {
if (debugTraceEnabled) {
std::string errorMessage = "init driver " + driver.name + " failed, zeGetGlobalProcAddrTable function pointer null. Returning ";
debug_trace_message(errorMessage, loader::to_string(ZE_RESULT_ERROR_UNINITIALIZED));
}
return ZE_RESULT_ERROR_UNINITIALIZED;
}

ze_global_dditable_t global;
auto getTableResult = getTable(ZE_API_VERSION_CURRENT, &global);
if(getTableResult != ZE_RESULT_SUCCESS) {
if (debugTraceEnabled) {
std::string errorMessage = "init driver " + driver.name + " failed, zeGetGlobalProcAddrTable() failed with ";
debug_trace_message(errorMessage, loader::to_string(getTableResult));
}
return ZE_RESULT_ERROR_UNINITIALIZED;
}
uint32_t pCount = 0;
std::vector<ze_driver_handle_t> driverHandles;

if (!desc) {
if(nullptr == global.pfnInit) {
if (debugTraceEnabled) {
std::string errorMessage = "init driver " + driver.name + " failed, zeInit function pointer null. Returning ";
debug_trace_message(errorMessage, loader::to_string(ZE_RESULT_ERROR_UNINITIALIZED));
}
return ZE_RESULT_ERROR_UNINITIALIZED;
}

auto pfnInit = global.pfnInit;
auto pfnInit = driver.dditable.ze.Global.pfnInit;
if(nullptr == pfnInit || globalInitStored->pfnInit == nullptr) {
if (debugTraceEnabled) {
std::string errorMessage = "init driver " + driver.name + " failed, zeInit function pointer null. Returning ";
Expand All @@ -291,22 +319,33 @@ namespace loader
// Verify that this driver successfully init in the call above.
if (driver.initStatus != ZE_RESULT_SUCCESS) {
res = driver.initStatus;
if (debugTraceEnabled) {
std::string message = "init driver " + driver.name + " zeInit(" + loader::to_string(flags) + ") returning ";
debug_trace_message(message, loader::to_string(res));
}
return res;
}
if (debugTraceEnabled) {
std::string message = "init driver " + driver.name + " zeInit(" + loader::to_string(flags) + ") returning ";
debug_trace_message(message, loader::to_string(res));
res = driver.dditable.ze.Driver.pfnGet(&pCount, nullptr);
// Verify that this driver successfully init in the call above.
if (res != ZE_RESULT_SUCCESS) {
if (debugTraceEnabled) {
std::string message = "init driver " + driver.name + " zeDriverGet(" + loader::to_string(desc) + ") returning ";
debug_trace_message(message, loader::to_string(res));
}
return res;
}
return res;
} else {
if(nullptr == global.pfnInitDrivers) {
driverHandles.resize(pCount);
res = driver.dditable.ze.Driver.pfnGet(&pCount, driverHandles.data());
// Verify that this driver successfully init in the call above.
if (res != ZE_RESULT_SUCCESS) {
if (debugTraceEnabled) {
std::string errorMessage = "init driver " + driver.name + " failed, zeInitDrivers function pointer null. Returning ";
debug_trace_message(errorMessage, loader::to_string(ZE_RESULT_ERROR_UNINITIALIZED));
std::string message = "init driver " + driver.name + " zeDriverGet(" + loader::to_string(desc) + ") returning ";
debug_trace_message(message, loader::to_string(res));
}
return ZE_RESULT_ERROR_UNINITIALIZED;
return res;
}

auto pfnInitDrivers = global.pfnInitDrivers;
} else {
auto pfnInitDrivers = driver.dditable.ze.Global.pfnInitDrivers;
if(nullptr == pfnInitDrivers || globalInitStored->pfnInitDrivers == nullptr) {
if (debugTraceEnabled) {
std::string errorMessage = "init driver " + driver.name + " failed, pfnInitDrivers function pointer null. Returning ";
Expand All @@ -315,19 +354,115 @@ namespace loader
return ZE_RESULT_ERROR_UNINITIALIZED;
}

pCount = 0;
// Use the previously init ddi table pointer to zeInit to allow for intercept of the zeInit calls
uint32_t pCount = 0;
ze_result_t res = globalInitStored->pfnInitDrivers(&pCount, nullptr, desc);
// Verify that this driver successfully init in the call above.
if (driver.initDriversStatus != ZE_RESULT_SUCCESS) {
res = driver.initDriversStatus;
if (debugTraceEnabled) {
std::string message = "init driver " + driver.name + " zeInitDrivers(" + loader::to_string(desc) + ") returning ";
debug_trace_message(message, loader::to_string(res));
}
return res;
}
if (debugTraceEnabled) {
std::string message = "init driver " + driver.name + " zeInitDrivers(" + loader::to_string(desc) + ") returning ";
debug_trace_message(message, loader::to_string(res));

// Reset pCount to 0 when calling the driver init function from the driver's ddi table.
pCount = 0;
res = driver.dditable.ze.Global.pfnInitDrivers(&pCount, nullptr, desc);
// Verify that this driver successfully init in the call above.
if (res != ZE_RESULT_SUCCESS) {
if (debugTraceEnabled) {
std::string message = "init driver " + driver.name + " zeInitDrivers(" + loader::to_string(desc) + ") returning ";
debug_trace_message(message, loader::to_string(res));
}
return res;
}
driverHandles.resize(pCount);
// Use the driver's init function to query the driver handles and read the properties.
res = driver.dditable.ze.Global.pfnInitDrivers(&pCount, driverHandles.data(), desc);
// Verify that this driver successfully init in the call above.
if (res != ZE_RESULT_SUCCESS) {
if (debugTraceEnabled) {
std::string message = "init driver " + driver.name + " zeInitDrivers(" + loader::to_string(desc) + ") returning ";
debug_trace_message(message, loader::to_string(res));
}
return res;
}
return res;
}

for (auto handle : driverHandles) {
ze_driver_properties_t properties;
ze_result_t res = driver.dditable.ze.Driver.pfnGetProperties(handle, &properties);
if (res != ZE_RESULT_SUCCESS) {
if (debugTraceEnabled) {
std::string message = "init driver " + driver.name + " failed, zeDriverGetProperties returned ";
debug_trace_message(message, loader::to_string(res));
}
return res;
}
driver.properties = properties;
uint32_t deviceCount = 0;
res = driver.dditable.ze.Device.pfnGet( handle, &deviceCount, nullptr );
if( ZE_RESULT_SUCCESS != res ) {
if (debugTraceEnabled) {
std::string message = "init driver " + driver.name + " failed, zeDeviceGet returned ";
debug_trace_message(message, loader::to_string(res));
}
return res;
}
if (deviceCount == 0) {
if (debugTraceEnabled) {
std::string message = "init driver " + driver.name + " failed, zeDeviceGet returned 0 devices";
debug_trace_message(message, "");
}
return ZE_RESULT_ERROR_UNINITIALIZED;
}
std::vector<ze_device_handle_t> deviceHandles(deviceCount);
res = driver.dditable.ze.Device.pfnGet( handle, &deviceCount, deviceHandles.data() );
if( ZE_RESULT_SUCCESS != res ) {
if (debugTraceEnabled) {
std::string message = "init driver " + driver.name + " failed, zeDeviceGet returned ";
debug_trace_message(message, loader::to_string(res));
}
return res;
}
bool integratedGPU = false;
bool discreteGPU = false;
bool other = false;
for( auto device : deviceHandles ) {
ze_device_properties_t deviceProperties;
res = driver.dditable.ze.Device.pfnGetProperties(device, &deviceProperties);
if( ZE_RESULT_SUCCESS != res ) {
if (debugTraceEnabled) {
std::string message = "init driver " + driver.name + " failed, zeDeviceGetProperties returned ";
debug_trace_message(message, loader::to_string(res));
}
return res;
}
if (deviceProperties.type == ZE_DEVICE_TYPE_GPU) {
if (deviceProperties.flags & ZE_DEVICE_PROPERTY_FLAG_INTEGRATED) {
integratedGPU = true;
} else {
discreteGPU = true;
}
} else {
other = true;
}
}
if (integratedGPU && discreteGPU && other) {
driver.driverType = ZEL_DRIVER_TYPE_MIXED;
} else if (integratedGPU && discreteGPU) {
driver.driverType = ZEL_DRIVER_TYPE_GPU;
} else if (integratedGPU) {
driver.driverType = ZEL_DRIVER_TYPE_INTEGRATED_GPU;
} else if (discreteGPU) {
driver.driverType = ZEL_DRIVER_TYPE_DISCRETE_GPU;
} else if (other) {
driver.driverType = ZEL_DRIVER_TYPE_OTHER;
}
}
return ZE_RESULT_SUCCESS;
}
}

Expand Down
Loading

0 comments on commit 5a4efe4

Please sign in to comment.