Skip to content

Commit

Permalink
Add support for older ddi tables
Browse files Browse the repository at this point in the history
- Added support for older dii tables such that the version of the apis
  supported by that dynamic loader are requested and not newer versions.
- Added check for backwards compatability for older init apis.

Signed-off-by: Neil R. Spruit <neil.r.spruit@intel.com>
  • Loading branch information
nrspruit committed Nov 15, 2024
1 parent 0cef9a7 commit b53eff8
Show file tree
Hide file tree
Showing 10 changed files with 193 additions and 119 deletions.
2 changes: 1 addition & 1 deletion source/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ configure_file(
@ONLY)

include(GNUInstallDirs)
if (DYNAMIC_LOAD_LOADER)
if (BUILD_STATIC)
message(STATUS "Building loader as static library")
add_library(${TARGET_LOADER_NAME}
STATIC
Expand Down
2 changes: 0 additions & 2 deletions source/lib/linux/lib_init.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,9 @@ namespace ze_lib
{
#ifndef DYNAMIC_LOAD_LOADER
void __attribute__((constructor)) createLibContext() {
printf("Context created context lib dynamic\n");
context = new context_t;
}
void __attribute__((destructor)) deleteLibContext() {
printf("Context destroyed context lib dynamic\n");
delete context;
}
#endif
Expand Down
128 changes: 100 additions & 28 deletions source/lib/ze_lib.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,7 @@ namespace ze_lib

///////////////////////////////////////////////////////////////////////////////
context_t::context_t()
{
#ifdef DYNAMIC_LOAD_LOADER
printf("Context created static\n");
#else
printf("Context created dynamic\n");
#endif
};
{};

///////////////////////////////////////////////////////////////////////////////
context_t::~context_t()
Expand All @@ -41,10 +35,17 @@ namespace ze_lib
ze_lib::destruction = true;
};

//////////////////////////////////////////////////////////////////////////
void debug_trace_message(const std::string &message, const std::string &extra)
{
std::cerr << message << " " << extra << std::endl;
}

//////////////////////////////////////////////////////////////////////////
__zedlllocal ze_result_t context_t::Init(ze_init_flags_t flags, bool sysmanOnly, ze_init_driver_type_desc_t* desc)
{
ze_result_t result;
ze_api_version_t version = ZE_API_VERSION_CURRENT;
#ifdef DYNAMIC_LOAD_LOADER
std::string loaderLibraryPath;
#ifdef _WIN32
Expand All @@ -54,47 +55,78 @@ namespace ze_lib
loader = LOAD_DRIVER_LIBRARY(loaderFullLibraryPath.c_str());

if( NULL == loader ) {
printf("loader load failed\n");
std::string message = "ze_lib Context Init() Loader Library Load Failed";
debug_trace_message(message, "");
return ZE_RESULT_ERROR_UNINITIALIZED;
}

typedef ze_result_t (ZE_APICALL *loaderInit_t)();
auto loaderInit = reinterpret_cast<loaderInit_t>(
GET_FUNCTION_PTR(loader, "zeLoaderInit") );
printf("calling loader init static\n");
result = loaderInit();
if( ZE_RESULT_SUCCESS != result ) {
std::string message = "ze_lib Context Init() Loader Init Failed";
debug_trace_message(message, "");
return result;
}
typedef HMODULE (ZE_APICALL *getTracing_t)();
auto getTracing = reinterpret_cast<getTracing_t>(
GET_FUNCTION_PTR(loader, "zeLoaderGetTracingHandle") );
if (getTracing == nullptr) {
printf("missing getTracing\n");
std::string message = "ze_lib Context Init() zeLoaderGetTracingHandle missing";
debug_trace_message(message, "");
return ZE_RESULT_ERROR_UNINITIALIZED;
}
tracing_lib = getTracing();
typedef ze_result_t (ZE_APICALL *zelLoaderDriverCheck_t)(ze_init_flags_t flags, ze_init_driver_type_desc_t* desc, ze_global_dditable_t *globalInitStored, zes_global_dditable_t *sysmanGlobalInitStored, bool *requireDdiReinit, bool sysmanOnly);
auto loaderDriverCheck = reinterpret_cast<zelLoaderDriverCheck_t>(
GET_FUNCTION_PTR(loader, "zelLoaderDriverCheck") );
if (loaderDriverCheck == nullptr) {
printf("missing loaderDriverCheck\n");
return ZE_RESULT_ERROR_UNINITIALIZED;
}
typedef ze_result_t (ZE_APICALL *zelLoaderTracingLayerInit_t)(std::atomic<ze_dditable_t *> &zeDdiTable);
auto loaderTracingLayerInit = reinterpret_cast<zelLoaderTracingLayerInit_t>(
GET_FUNCTION_PTR(loader, "zelLoaderTracingLayerInit") );
if (loaderTracingLayerInit == nullptr) {
printf("missing loaderTracingLayerInit\n");
std::string message = "ze_lib Context Init() zelLoaderTracingLayerInit missing";
debug_trace_message(message, "");
return ZE_RESULT_ERROR_UNINITIALIZED;
}
typedef loader::context_t * (ZE_APICALL *zelLoaderGetContext_t)();
auto loaderGetContext = reinterpret_cast<zelLoaderGetContext_t>(
GET_FUNCTION_PTR(loader, "zelLoaderGetContext") );
if (loaderGetContext == nullptr) {
printf("missing zelLoaderGetContext\n");
std::string message = "ze_lib Context Init() zelLoaderGetContext missing";
debug_trace_message(message, "");
}

size_t size = 0;
result = zelLoaderGetVersions(&size, nullptr);
if (ZE_RESULT_SUCCESS != result) {
std::string message = "ze_lib Context Init() zelLoaderGetVersions Failed";
debug_trace_message(message, "");
return result;
}

std::vector<zel_component_version_t> versions(size);
result = zelLoaderGetVersions(&size, versions.data());
if (ZE_RESULT_SUCCESS != result) {
std::string message = "ze_lib Context Init() zelLoaderGetVersions Failed to read component versions";
debug_trace_message(message, "");
return result;
}
bool zeInitDriversSupport = true;
const std::string loader_name = "loader";
for (auto &component : versions) {
if (loader_name == component.component_name) {
version = component.spec_version;
if(component.component_lib_version.major == 1) {
if (component.component_lib_version.minor < 18) {
std::string message = "ze_lib Context Init() Version Does not support zeInitDrivers";
debug_trace_message(message, "");
zeInitDriversSupport = false;
}
} else {
std::string message = "ze_lib Context Init() Loader version is too new";
debug_trace_message(message, "");
return ZE_RESULT_ERROR_UNSUPPORTED_VERSION;
}
}
}
printf("done loading functions\n");
#else
result = zeLoaderInit();
if( ZE_RESULT_SUCCESS == result ) {
Expand Down Expand Up @@ -128,22 +160,38 @@ namespace ze_lib
// Init the ZE DDI Tables
if( ZE_RESULT_SUCCESS == result )
{
result = zeDdiTableInit();
result = zeDdiTableInit(version);
if (result != ZE_RESULT_SUCCESS) {
std::string message = "ze_lib Context Init() zeDdiTableInit failed";
debug_trace_message(message, "");
}
}
// Init the ZET DDI Tables
if( ZE_RESULT_SUCCESS == result )
{
result = zetDdiTableInit();
result = zetDdiTableInit(version);
if( ZE_RESULT_SUCCESS != result ) {
std::string message = "ze_lib Context Init() zetDdiTableInit failed";
debug_trace_message(message, "");
}
}
// Init the ZES DDI Tables
if( ZE_RESULT_SUCCESS == result )
{
result = zesDdiTableInit();
result = zesDdiTableInit(version);
if (result != ZE_RESULT_SUCCESS) {
std::string message = "ze_lib Context Init() zesDdiTableInit failed";
debug_trace_message(message, "");
}
}
// Init the Tracing API DDI Tables
if( ZE_RESULT_SUCCESS == result )
{
result = zelTracingDdiTableInit();
result = zelTracingDdiTableInit(version);
if (result != ZE_RESULT_SUCCESS) {
std::string message = "ze_lib Context Init() zelTracingDdiTableInit failed";
debug_trace_message(message, "");
}
}
// Init the stored ddi tables for the tracing layer
if( ZE_RESULT_SUCCESS == result )
Expand All @@ -163,23 +211,47 @@ namespace ze_lib
// No need to check if only initializing sysman
bool requireDdiReinit = false;
#ifdef DYNAMIC_LOAD_LOADER
result = loaderDriverCheck(flags, desc, &ze_lib::context->initialzeDdiTable.Global, &ze_lib::context->initialzesDdiTable.Global, &requireDdiReinit, sysmanOnly);
if (zeInitDriversSupport) {
typedef ze_result_t (ZE_APICALL *zelLoaderDriverCheck_t)(ze_init_flags_t flags, ze_init_driver_type_desc_t* desc, ze_global_dditable_t *globalInitStored, zes_global_dditable_t *sysmanGlobalInitStored, bool *requireDdiReinit, bool sysmanOnly);
auto loaderDriverCheck = reinterpret_cast<zelLoaderDriverCheck_t>(
GET_FUNCTION_PTR(loader, "zelLoaderDriverCheck") );
if (loaderDriverCheck == nullptr) {
std::string message = "ze_lib Context Init() zelLoaderDriverCheck missing";
debug_trace_message(message, "");
return ZE_RESULT_ERROR_UNINITIALIZED;
}
result = loaderDriverCheck(flags, desc, &ze_lib::context->initialzeDdiTable.Global, &ze_lib::context->initialzesDdiTable.Global, &requireDdiReinit, sysmanOnly);
} else {
typedef ze_result_t (ZE_APICALL *zelLoaderDriverCheck_t)(ze_init_flags_t flags, ze_global_dditable_t *globalInitStored, zes_global_dditable_t *sysmanGlobalInitStored, bool *requireDdiReinit, bool sysmanOnly);
auto loaderDriverCheck = reinterpret_cast<zelLoaderDriverCheck_t>(
GET_FUNCTION_PTR(loader, "zelLoaderDriverCheck") );
if (loaderDriverCheck == nullptr) {
std::string message = "ze_lib Context Init() zelLoaderDriverCheck missing";
debug_trace_message(message, "");
return ZE_RESULT_ERROR_UNINITIALIZED;
}
result = loaderDriverCheck(flags, &ze_lib::context->initialzeDdiTable.Global, &ze_lib::context->initialzesDdiTable.Global, &requireDdiReinit, sysmanOnly);
}
#else
result = zelLoaderDriverCheck(flags, desc, &ze_lib::context->initialzeDdiTable.Global, &ze_lib::context->initialzesDdiTable.Global, &requireDdiReinit, sysmanOnly);
#endif
if (result != ZE_RESULT_SUCCESS) {
std::string message = "ze_lib Context Init() zelLoaderDriverCheck failed";
debug_trace_message(message, "");
}
// If a driver was removed from the driver list, then the ddi tables need to be reinit to allow for passthru directly to the driver.
if (requireDdiReinit && loaderContextAccessAllowed) {
// If a user has already called the core apis, then ddi table reinit is not possible due to handles already being read by the user.
if (!sysmanOnly && !ze_lib::context->zeInuse) {
// reInit the ZE DDI Tables
if( ZE_RESULT_SUCCESS == result )
{
result = zeDdiTableInit();
result = zeDdiTableInit(version);
}
// reInit the ZET DDI Tables
if( ZE_RESULT_SUCCESS == result )
{
result = zetDdiTableInit();
result = zetDdiTableInit(version);
}
// If ze/zet ddi tables have been reinit and no longer use the intercept layer, then handles passed to zelLoaderTranslateHandleInternal do not require translation.
// Setting intercept_enabled==false changes the behavior of zelLoaderTranslateHandleInternal to avoid translation.
Expand All @@ -191,7 +263,7 @@ namespace ze_lib
// reInit the ZES DDI Tables
if( ZE_RESULT_SUCCESS == result )
{
result = zesDdiTableInit();
result = zesDdiTableInit(version);
}
}
}
Expand Down
17 changes: 13 additions & 4 deletions source/lib/ze_lib.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,22 +34,31 @@ namespace ze_lib
context_t();
~context_t();

///////////////////////////////////////////////////////////////////////////////
template <typename T, typename TableType>
ze_result_t getTableWithCheck(T getTable, ze_api_version_t version, TableType* table) {
if (getTable == nullptr) {
return ZE_RESULT_ERROR_UNINITIALIZED;
}
return getTable(version, table);
}

std::once_flag initOnce;
std::once_flag initOnceDrivers;
std::once_flag initOnceSysMan;

ze_result_t Init(ze_init_flags_t flags, bool sysmanOnly, ze_init_driver_type_desc_t* desc);

ze_result_t zeDdiTableInit();
ze_result_t zeDdiTableInit(ze_api_version_t version);
std::atomic<ze_dditable_t *> zeDdiTable = {nullptr};

ze_result_t zetDdiTableInit();
ze_result_t zetDdiTableInit(ze_api_version_t version);
std::atomic<zet_dditable_t *> zetDdiTable = {nullptr};

ze_result_t zesDdiTableInit();
ze_result_t zesDdiTableInit(ze_api_version_t version);
std::atomic<zes_dditable_t *> zesDdiTable = {nullptr};

ze_result_t zelTracingDdiTableInit();
ze_result_t zelTracingDdiTableInit(ze_api_version_t version);
zel_tracing_dditable_t zelTracingDdiTable = {};
std::atomic<ze_dditable_t *> pTracingZeDdiTable = {nullptr};
ze_dditable_t initialzeDdiTable;
Expand Down
Loading

0 comments on commit b53eff8

Please sign in to comment.