@@ -363,46 +363,83 @@ void LevelZeroCompilerInDriver<TableExtension>::release(std::shared_ptr<const Ne
363
363
}
364
364
365
365
template <typename TableExtension>
366
- std::vector<uint8_t > LevelZeroCompilerInDriver<TableExtension>::getCompiledNetwork(
367
- std::shared_ptr<const NetworkDescription> networkDescription) {
368
- if (networkDescription->metadata .graphHandle != nullptr && networkDescription->compiledNetwork .size () == 0 ) {
366
+ template <typename T, std::enable_if_t <UseCopyForNativeBinary(T), bool >>
367
+ void LevelZeroCompilerInDriver<TableExtension>::getNativeBinary(ze_graph_dditable_ext_curr_t & graphDdiTableExt,
368
+ ze_graph_handle_t graphHandle,
369
+ std::vector<uint8_t >& blob,
370
+ uint8_t *& blobPtr,
371
+ size_t & blobSize) const {
372
+ // Get blob size first
373
+ auto result = _graphDdiTableExt.pfnGetNativeBinary (graphHandle, &blobSize, nullptr );
374
+ blob.resize (blobSize);
375
+
376
+ OPENVINO_ASSERT (result == ZE_RESULT_SUCCESS,
377
+ " Failed to compile network. L0 pfnGetNativeBinary get blob size" ,
378
+ " result: " ,
379
+ ze_result_to_string (result),
380
+ " , code 0x" ,
381
+ std::hex,
382
+ uint64_t (result),
383
+ " . " ,
384
+ getLatestBuildError ());
385
+
386
+ // Get blob data
387
+ result = _graphDdiTableExt.pfnGetNativeBinary (graphHandle, &blobSize, blob.data ());
388
+
389
+ OPENVINO_ASSERT (result == ZE_RESULT_SUCCESS,
390
+ " Failed to compile network. L0 pfnGetNativeBinary get blob data" ,
391
+ " result: " ,
392
+ ze_result_to_string (result),
393
+ " , code 0x" ,
394
+ std::hex,
395
+ uint64_t (result),
396
+ " . " ,
397
+ getLatestBuildError ());
398
+
399
+ blobPtr = blob.data ();
400
+ }
401
+
402
+ template <typename TableExtension>
403
+ template <typename T, std::enable_if_t <!UseCopyForNativeBinary(T), bool >>
404
+ void LevelZeroCompilerInDriver<TableExtension>::getNativeBinary(ze_graph_dditable_ext_curr_t & graphDdiTableExt,
405
+ ze_graph_handle_t graphHandle,
406
+ std::vector<uint8_t >& /* unusedBlob */ ,
407
+ uint8_t *& blobPtr,
408
+ size_t & blobSize) const {
409
+ // Get blob ptr and size
410
+ auto result = _graphDdiTableExt.pfnGetNativeBinary2 (graphHandle, &blobSize, &blobPtr);
411
+
412
+ OPENVINO_ASSERT (result == ZE_RESULT_SUCCESS,
413
+ " Failed to compile network. L0 pfnGetNativeBinary get blob size" ,
414
+ " result: " ,
415
+ ze_result_to_string (result),
416
+ " , code 0x" ,
417
+ std::hex,
418
+ uint64_t (result),
419
+ " . " ,
420
+ getLatestBuildError ());
421
+ }
422
+
423
+ template <typename TableExtension>
424
+ CompiledNetwork LevelZeroCompilerInDriver<TableExtension>::getCompiledNetwork(
425
+ const NetworkDescription& networkDescription) {
426
+ if (networkDescription.metadata .graphHandle != nullptr && networkDescription.compiledNetwork .size () == 0 ) {
369
427
_logger.info (" LevelZeroCompilerInDriver getCompiledNetwork get blob from graphHandle" );
370
- ze_graph_handle_t graphHandle = static_cast <ze_graph_handle_t >(networkDescription-> metadata .graphHandle );
428
+ ze_graph_handle_t graphHandle = static_cast <ze_graph_handle_t >(networkDescription. metadata .graphHandle );
371
429
372
- // Get blob size first
430
+ uint8_t * blobPtr = nullptr ;
373
431
size_t blobSize = -1 ;
432
+ std::vector<uint8_t > blob;
433
+
434
+ getNativeBinary (_graphDdiTableExt, graphHandle, blob, blobPtr, blobSize);
374
435
375
- auto result = _graphDdiTableExt.pfnGetNativeBinary (graphHandle, &blobSize, nullptr );
376
-
377
- OPENVINO_ASSERT (result == ZE_RESULT_SUCCESS,
378
- " Failed to compile network. L0 pfnGetNativeBinary get blob size" ,
379
- " result: " ,
380
- ze_result_to_string (result),
381
- " , code 0x" ,
382
- std::hex,
383
- uint64_t (result),
384
- " . " ,
385
- getLatestBuildError ());
386
-
387
- std::vector<uint8_t > blob (blobSize);
388
- // Get blob data
389
- result = _graphDdiTableExt.pfnGetNativeBinary (graphHandle, &blobSize, blob.data ());
390
-
391
- OPENVINO_ASSERT (result == ZE_RESULT_SUCCESS,
392
- " Failed to compile network. L0 pfnGetNativeBinary get blob data" ,
393
- " result: " ,
394
- ze_result_to_string (result),
395
- " , code 0x" ,
396
- std::hex,
397
- uint64_t (result),
398
- " . " ,
399
- getLatestBuildError ());
400
436
_logger.info (" LevelZeroCompilerInDriver getCompiledNetwork returning blob" );
401
- return blob;
402
- } else {
403
- _logger.info (" return the blob from network description" );
404
- return networkDescription->compiledNetwork ;
437
+ return CompiledNetwork (blobPtr, blobSize, std::move (blob));
405
438
}
439
+ _logger.info (" return the blob from network description" );
440
+ return CompiledNetwork (networkDescription.compiledNetwork .data (),
441
+ networkDescription.compiledNetwork .size (),
442
+ networkDescription.compiledNetwork );
406
443
}
407
444
408
445
template <typename TableExtension>
@@ -1201,6 +1238,7 @@ template class LevelZeroCompilerInDriver<ze_graph_dditable_ext_1_3_t>;
1201
1238
// Explicit instantiations of LevelZeroCompilerInDriver, one per graph DDI table extension type listed here.
template class LevelZeroCompilerInDriver <ze_graph_dditable_ext_1_4_t >;
1202
1239
template class LevelZeroCompilerInDriver <ze_graph_dditable_ext_1_5_t >;
1203
1240
template class LevelZeroCompilerInDriver <ze_graph_dditable_ext_1_6_t >;
1241
+ template class LevelZeroCompilerInDriver <ze_graph_dditable_ext_1_7_t >;
1204
1242
1205
1243
} // namespace driverCompilerAdapter
1206
1244
} // namespace intel_npu
0 commit comments