Skip to content

Commit 2f4029a

Browse files
authored
Adapter selection for Renderer (shader-slang#923)
* * Make adapter used selectable on the command line * Added 'adapter' to Renderer::Desc with dx11, dx12, vk honoring it * GL will check that the renderer matches, but cannot select a specific device * Share functionality on dx adapter selection in D3DUtil Note - that on tests that use OpenGL and the adapter doesn't match it will ignore the test (and display a message that the appropriate device couldn't be started) * Small function name improvement. * Variable rename to match type. * Fix typo in Dx12 device selection. * * Add checking if an adapter is warp * Improve some comments
1 parent 5bdc3ef commit 2f4029a

14 files changed

+336
-85
lines changed

tools/gfx/d3d-util.cpp

+147
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33

44
#include <d3dcompiler.h>
55

6+
#include <dxgi1_4.h>
7+
68
// We will use the C standard library just for printing error messages.
79
#include <stdio.h>
810

@@ -304,4 +306,149 @@ bool D3DUtil::isTypeless(DXGI_FORMAT format)
304306
}
305307
}
306308

309+
/* static */HMODULE D3DUtil::getDxgiModule()
310+
{
311+
static HMODULE s_dxgiModule = LoadLibraryA("dxgi.dll");
312+
if (!s_dxgiModule)
313+
{
314+
fprintf(stderr, "error: failed load 'dxgi.dll'\n");
315+
return nullptr;
316+
}
317+
318+
return s_dxgiModule;
319+
}
320+
321+
/* static */SlangResult D3DUtil::createFactory(DeviceCheckFlags flags, ComPtr<IDXGIFactory>& outFactory)
322+
{
323+
auto dxgiModule = getDxgiModule();
324+
if (!dxgiModule)
325+
{
326+
return SLANG_FAIL;
327+
}
328+
329+
typedef HRESULT(WINAPI *PFN_DXGI_CREATE_FACTORY)(REFIID riid, void **ppFactory);
330+
typedef HRESULT(WINAPI *PFN_DXGI_CREATE_FACTORY_2)(UINT Flags, REFIID riid, _COM_Outptr_ void **ppFactory);
331+
332+
{
333+
auto createFactory2 = (PFN_DXGI_CREATE_FACTORY_2)::GetProcAddress(dxgiModule, "CreateDXGIFactory2");
334+
if (createFactory2)
335+
{
336+
UINT dxgiFlags = 0;
337+
338+
if (flags & DeviceCheckFlag::UseDebug)
339+
{
340+
dxgiFlags |= DXGI_CREATE_FACTORY_DEBUG;
341+
}
342+
343+
ComPtr<IDXGIFactory4> factory;
344+
SLANG_RETURN_ON_FAIL(createFactory2(dxgiFlags, IID_PPV_ARGS(factory.writeRef())));
345+
346+
outFactory = factory;
347+
return SLANG_OK;
348+
}
349+
}
350+
351+
{
352+
auto createFactory = (PFN_DXGI_CREATE_FACTORY)::GetProcAddress(dxgiModule, "CreateDXGIFactory");
353+
if (!createFactory)
354+
{
355+
fprintf(stderr, "error: failed load symbol '%s'\n", "CreateDXGIFactory");
356+
return SLANG_FAIL;
357+
}
358+
return createFactory(IID_PPV_ARGS(outFactory.writeRef()));
359+
}
360+
}
361+
362+
/* static */SlangResult D3DUtil::findAdapters(DeviceCheckFlags flags, const Slang::UnownedStringSlice& adapaterName, List<ComPtr<IDXGIAdapter>>& outDxgiAdapters)
363+
{
364+
ComPtr<IDXGIFactory> factory;
365+
SLANG_RETURN_ON_FAIL(createFactory(flags, factory));
366+
return findAdapters(flags, adapaterName, factory, outDxgiAdapters);
367+
}
368+
369+
static bool _isMatch(IDXGIAdapter* adapter, const Slang::UnownedStringSlice& lowerAdapaterName)
370+
{
371+
if (lowerAdapaterName.size() == 0)
372+
{
373+
return true;
374+
}
375+
376+
DXGI_ADAPTER_DESC desc;
377+
adapter->GetDesc(&desc);
378+
379+
String descName = String::FromWString(desc.Description).ToLower();
380+
381+
return descName.IndexOf(lowerAdapaterName) != UInt(-1);
382+
}
383+
384+
/* static */bool D3DUtil::isWarp(IDXGIFactory* dxgiFactory, IDXGIAdapter* adapterIn)
385+
{
386+
ComPtr<IDXGIFactory4> dxgiFactory4;
387+
if (SLANG_SUCCEEDED(dxgiFactory->QueryInterface(IID_PPV_ARGS(dxgiFactory4.writeRef()))))
388+
{
389+
ComPtr<IDXGIAdapter> warpAdapter;
390+
dxgiFactory4->EnumWarpAdapter(IID_PPV_ARGS(warpAdapter.writeRef()));
391+
392+
return adapterIn == warpAdapter;
393+
}
394+
395+
return false;
396+
}
397+
398+
/* static */SlangResult D3DUtil::findAdapters(DeviceCheckFlags flags, const UnownedStringSlice& adapterName, IDXGIFactory* dxgiFactory, List<ComPtr<IDXGIAdapter>>& outDxgiAdapters)
399+
{
400+
String lowerAdapterName = String(adapterName).ToLower();
401+
402+
outDxgiAdapters.Clear();
403+
404+
ComPtr<IDXGIAdapter> warpAdapter;
405+
if ((flags & DeviceCheckFlag::UseHardwareDevice) == 0)
406+
{
407+
ComPtr<IDXGIFactory4> dxgiFactory4;
408+
if (SLANG_SUCCEEDED(dxgiFactory->QueryInterface(IID_PPV_ARGS(dxgiFactory4.writeRef()))))
409+
{
410+
dxgiFactory4->EnumWarpAdapter(IID_PPV_ARGS(warpAdapter.writeRef()));
411+
if (_isMatch(warpAdapter, lowerAdapterName.getUnownedSlice()))
412+
{
413+
outDxgiAdapters.Add(warpAdapter);
414+
}
415+
}
416+
}
417+
418+
for (UINT adapterIndex = 0; true; adapterIndex++)
419+
{
420+
ComPtr<IDXGIAdapter> dxgiAdapter;
421+
if (dxgiFactory->EnumAdapters(adapterIndex, dxgiAdapter.writeRef()) == DXGI_ERROR_NOT_FOUND)
422+
break;
423+
424+
// Skip if warp (as we will have already added it)
425+
if (dxgiAdapter == warpAdapter)
426+
{
427+
continue;
428+
}
429+
if (!_isMatch(dxgiAdapter, lowerAdapterName.getUnownedSlice()))
430+
{
431+
continue;
432+
}
433+
434+
// Get if it's software
435+
UINT deviceFlags = 0;
436+
ComPtr<IDXGIAdapter1> dxgiAdapter1;
437+
if (SLANG_SUCCEEDED(dxgiAdapter->QueryInterface(IID_PPV_ARGS(dxgiAdapter1.writeRef()))))
438+
{
439+
DXGI_ADAPTER_DESC1 desc;
440+
dxgiAdapter1->GetDesc1(&desc);
441+
deviceFlags = desc.Flags;
442+
}
443+
444+
// If the right type then add it
445+
if ((deviceFlags & DXGI_ADAPTER_FLAG_SOFTWARE) == 0 && (flags & DeviceCheckFlag::UseHardwareDevice) != 0)
446+
{
447+
outDxgiAdapters.Add(dxgiAdapter);
448+
}
449+
}
450+
451+
return SLANG_OK;
452+
}
453+
307454
} // renderer_test

tools/gfx/d3d-util.h

+18
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,13 @@
88
#include "../../slang-com-ptr.h"
99
#include "../../source/core/list.h"
1010

11+
#include "flag-combiner.h"
12+
1113
#include "render.h"
1214

1315
#include <D3Dcommon.h>
1416
#include <DXGIFormat.h>
17+
#include <dxgi.h>
1518

1619
namespace gfx {
1720

@@ -56,6 +59,21 @@ class D3DUtil
5659

5760
/// Append text in in, into wide char array
5861
static void appendWideChars(const char* in, Slang::List<wchar_t>& out);
62+
63+
64+
static SlangResult createFactory(DeviceCheckFlags flags, Slang::ComPtr<IDXGIFactory>& outFactory);
65+
66+
/// Get the dxgiModule
67+
static HMODULE getDxgiModule();
68+
69+
/// Find adapters
70+
static SlangResult findAdapters(DeviceCheckFlags flags, const Slang::UnownedStringSlice& adapaterName, IDXGIFactory* dxgiFactory, Slang::List<Slang::ComPtr<IDXGIAdapter>>& dxgiAdapters);
71+
/// Find adapters
72+
static SlangResult findAdapters(DeviceCheckFlags flags, const Slang::UnownedStringSlice& adapaterName, Slang::List<Slang::ComPtr<IDXGIAdapter>>& dxgiAdapters);
73+
74+
/// True if the adapter is warp
75+
static bool isWarp(IDXGIFactory* dxgiFactory, IDXGIAdapter* adapter);
76+
5977
};
6078

6179
} // renderer_test

tools/gfx/render-d3d11.cpp

+23-2
Original file line numberDiff line numberDiff line change
@@ -472,12 +472,33 @@ SlangResult D3D11Renderer::initialize(const Desc& desc, void* inWindowHandle)
472472
for (int i = 0; i < numCombinations; ++i)
473473
{
474474
const auto deviceCheckFlags = combiner.getCombination(i);
475-
const D3D_DRIVER_TYPE driverType = (deviceCheckFlags & DeviceCheckFlag::UseHardwareDevice) ? D3D_DRIVER_TYPE_HARDWARE : D3D_DRIVER_TYPE_REFERENCE;
475+
476+
// If we have an adapter set on the desc, look it up. We only need to do so for hardware
477+
ComPtr<IDXGIAdapter> adapter;
478+
if (desc.adapter.Length() && (deviceCheckFlags & DeviceCheckFlag::UseHardwareDevice))
479+
{
480+
List<ComPtr<IDXGIAdapter>> dxgiAdapters;
481+
D3DUtil::findAdapters(deviceCheckFlags, desc.adapter.getUnownedSlice(), dxgiAdapters);
482+
if (dxgiAdapters.Count() == 0)
483+
{
484+
continue;
485+
}
486+
adapter = dxgiAdapters[0];
487+
}
488+
489+
// The adapter can be nullptr - that just means 'default', but when so we need to select the driver type
490+
D3D_DRIVER_TYPE driverType = D3D_DRIVER_TYPE_UNKNOWN;
491+
if (adapter == nullptr)
492+
{
493+
// If we don't have an adapter, select directly
494+
driverType = (deviceCheckFlags & DeviceCheckFlag::UseHardwareDevice) ? D3D_DRIVER_TYPE_HARDWARE : D3D_DRIVER_TYPE_REFERENCE;
495+
}
496+
476497
const int startFeatureIndex = (deviceCheckFlags & DeviceCheckFlag::UseFullFeatureLevel) ? 0 : 1;
477498
const UINT deviceFlags = (deviceCheckFlags & DeviceCheckFlag::UseDebug) ? D3D11_CREATE_DEVICE_DEBUG : 0;
478499

479500
res = D3D11CreateDeviceAndSwapChain_(
480-
nullptr, // adapter (use default)
501+
adapter,
481502
driverType,
482503
nullptr, // software
483504
deviceFlags,

0 commit comments

Comments
 (0)