diff --git a/source/lib/ze_lib.cpp b/source/lib/ze_lib.cpp index ff564dd2..c7707830 100644 --- a/source/lib/ze_lib.cpp +++ b/source/lib/ze_lib.cpp @@ -69,16 +69,32 @@ namespace ze_lib typedef HMODULE (ZE_APICALL *getTracing_t)(); auto getTracing = reinterpret_cast( GET_FUNCTION_PTR(loader, "zeLoaderGetTracingHandle") ); + if (getTracing == nullptr) { + printf("missing getTracing\n"); + return ZE_RESULT_ERROR_UNINITIALIZED; + } tracing_lib = getTracing(); typedef ze_result_t (ZE_APICALL *zelLoaderDriverCheck_t)(ze_init_flags_t flags, ze_init_driver_type_desc_t* desc, ze_global_dditable_t *globalInitStored, zes_global_dditable_t *sysmanGlobalInitStored, bool *requireDdiReinit, bool sysmanOnly); auto loaderDriverCheck = reinterpret_cast( GET_FUNCTION_PTR(loader, "zelLoaderDriverCheck") ); + if (loaderDriverCheck == nullptr) { + printf("missing loaderDriverCheck\n"); + return ZE_RESULT_ERROR_UNINITIALIZED; + } typedef ze_result_t (ZE_APICALL *zelLoaderTracingLayerInit_t)(std::atomic &zeDdiTable); auto loaderTracingLayerInit = reinterpret_cast( GET_FUNCTION_PTR(loader, "zelLoaderTracingLayerInit") ); + if (loaderTracingLayerInit == nullptr) { + printf("missing loaderTracingLayerInit\n"); + return ZE_RESULT_ERROR_UNINITIALIZED; + } typedef loader::context_t * (ZE_APICALL *zelLoaderGetContext_t)(); auto loaderGetContext = reinterpret_cast( GET_FUNCTION_PTR(loader, "zelLoaderGetContext") ); + if (loaderGetContext == nullptr) { + printf("missing zelLoaderGetContext\n"); + } + printf("done loading functions\n"); #else result = zeLoaderInit(); if( ZE_RESULT_SUCCESS == result ) { @@ -95,10 +111,15 @@ namespace ze_lib } // Given zesInit, then zesDrivers needs to be used as the sysmanInstanceDrivers; + bool loaderContextAccessAllowed = true; #ifdef DYNAMIC_LOAD_LOADER - loader::context = loaderGetContext(); + if (loaderGetContext == nullptr) { + loaderContextAccessAllowed = false; + } else { + loader::context = loaderGetContext(); + } #endif - if (sysmanOnly) { + if (sysmanOnly && loaderContextAccessAllowed) { loader::context->sysmanInstanceDrivers = &loader::context->zesDrivers; } @@ -147,7 +168,7 @@ namespace ze_lib result = zelLoaderDriverCheck(flags, desc, &ze_lib::context->initialzeDdiTable.Global, &ze_lib::context->initialzesDdiTable.Global, &requireDdiReinit, sysmanOnly); #endif // If a driver was removed from the driver list, then the ddi tables need to be reinit to allow for passthru directly to the driver. - if (requireDdiReinit) { + if (requireDdiReinit && loaderContextAccessAllowed) { // If a user has already called the core apis, then ddi table reinit is not possible due to handles already being read by the user. if (!sysmanOnly && !ze_lib::context->zeInuse) { // reInit the ZE DDI Tables diff --git a/source/loader/ze_loader_internal.h b/source/loader/ze_loader_internal.h index 087adef9..6de49357 100644 --- a/source/loader/ze_loader_internal.h +++ b/source/loader/ze_loader_internal.h @@ -131,6 +131,7 @@ namespace loader bool tracingLayerEnabled = false; dditable_t tracing_dditable = {}; std::shared_ptr zel_logger; + uint32_t init_counter = 0; }; extern context_t *context;