Skip to content

Commit 71885de

Browse files
authored
First pass support for half on vk (shader-slang#912)
* Look at getting half to work on vk. * Alter half test so can always produce consistent test results. * First pass working half on vk. * Improve comments for vulkan extensions around half. * Upgraded vulkan headers to v1.1.103 https://github.com/KhronosGroup/Vulkan-Headers * * Add getFeatures on Render interface * Vulkan renderer determines at startup if it can support half * Parse render-features on render-test * Small changes to half-calc.slang test. * Structured buffer half access works as expected for Vk, but isn't for dx12, so disable for now. * Require the half feature for renderers for the half-structured-buffer.slang test. * * Added ToolReturnCode to be more rigorous about how a return code is passed back from a tool * Added support for a tool being able to pass back an 'ignored' result. * Used enum codes to indicate meanings * Made spawnAndWait return a ToolReturnCode * Ignore tests that don't have required render-feature * Fix macro line continuation usage. * Check dx12 has half support. * Checking for half on dx12 - if CheckFeatureSupport fails, don't fail renderer initialization. * Fix typo.
1 parent 7004871 commit 71885de

17 files changed

+276
-15
lines changed

source/slang/emit.cpp

+44-2
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@ struct ExtensionUsageTracker
4747
StringBuilder glslExtensionRequireLines;
4848

4949
ProfileVersion profileVersion = ProfileVersion::GLSL_110;
50+
51+
bool hasHalfExtension = false;
5052
};
5153

5254
void requireGLSLExtension(
@@ -76,6 +78,21 @@ void requireGLSLVersionImpl(
7678
}
7779
}
7880

81+
// Requests the GLSL extensions used for 16-bit float ('half') support on the
// given tracker. The tracker's hasHalfExtension flag is checked first and set
// afterwards, so repeated calls on the same tracker only issue the
// requireGLSLExtension calls once.
void requireGLSLHalfExtension(ExtensionUsageTracker* tracker)
82+
{
83+
if (!tracker->hasHalfExtension)
84+
{
85+
// https://github.com/KhronosGroup/GLSL/blob/master/extensions/ext/GL_EXT_shader_16bit_storage.txt
86+
requireGLSLExtension(tracker, "GL_EXT_shader_16bit_storage");
87+
88+
// https://github.com/KhronosGroup/GLSL/blob/master/extensions/ext/GL_EXT_shader_explicit_arithmetic_types.txt
89+
// Use GL_KHX_shader_explicit_arithmetic_types because that is what appears defined in glslang
90+
requireGLSLExtension(tracker, "GL_KHX_shader_explicit_arithmetic_types");
91+
92+
// Remember that both extensions have been requested for this tracker.
tracker->hasHalfExtension = true;
93+
}
94+
}
95+
7996

8097
// Shared state for an entire emit session
8198
struct SharedEmitContext
@@ -810,7 +827,12 @@ struct EmitVisitor
810827

811828
case kIROp_BoolType: Emit("b"); break;
812829

813-
case kIROp_HalfType: Emit("f16"); break;
830+
case kIROp_HalfType:
831+
{
832+
_requireHalf();
833+
Emit("f16");
834+
break;
835+
}
814836
case kIROp_DoubleType: Emit("d"); break;
815837

816838
case kIROp_VectorType:
@@ -1197,6 +1219,14 @@ struct EmitVisitor
11971219
}
11981220
}
11991221

1222+
// Called when a half type is about to be emitted. Only the GLSL target needs
// extra work here: it requests the half-related GLSL extensions through the
// shared extension usage tracker. Other targets fall through with no action.
void _requireHalf()
1223+
{
1224+
if (getTarget(context) == CodeGenTarget::GLSL)
1225+
{
1226+
requireGLSLHalfExtension(&context->shared->extensionUsageTracker);
1227+
}
1228+
}
1229+
12001230
void emitSimpleTypeImpl(IRType* type)
12011231
{
12021232
switch (type->op)
@@ -1217,7 +1247,19 @@ struct EmitVisitor
12171247
case kIROp_UIntType: Emit("uint"); return;
12181248
case kIROp_UInt64Type: Emit("uint64_t"); return;
12191249

1220-
case kIROp_HalfType: Emit("half"); return;
1250+
case kIROp_HalfType:
1251+
{
1252+
_requireHalf();
1253+
if (getTarget(context) == CodeGenTarget::GLSL)
1254+
{
1255+
Emit("float16_t");
1256+
}
1257+
else
1258+
{
1259+
Emit("half");
1260+
}
1261+
return;
1262+
}
12211263
case kIROp_FloatType: Emit("float"); return;
12221264
case kIROp_DoubleType: Emit("double"); return;
12231265

tests/compute/half-calc.slang

+32
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
//TEST(compute):COMPARE_COMPUTE:-dx12 -compute -use-dxil -profile cs_6_2 -render-features half
2+
//TEST(compute):COMPARE_COMPUTE:-vk -compute -use-dxil -profile cs_6_2 -render-features half
3+
//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):dxbinding(0),glbinding(0),out
4+
5+
// Test for doing a calculation using half
6+
7+
RWStructuredBuffer<float> outputBuffer;
8+
9+
// Compute entry point: each of the 4 threads computes tid * 3 + 1 using half
// arithmetic and stores the result into the float output buffer.
[numthreads(4, 1, 1)]
10+
void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
11+
{
12+
uint tid = dispatchThreadID.x;
13+
14+
//half2 v0 = { 1, -2 };
15+
//half2 v1 = { -2, 4 };
16+
17+
//half2 v2 = (v0 * 2.0f + v1);
18+
19+
// This should work (it compiles on dxc, but slang doesn't seem to have overloads yet)
20+
//half offset = length(v2);
21+
22+
//half offset = v2.x + v2.y;
23+
24+
// The vector-based offset above is commented out, so a constant zero is used instead.
half offset = 0.0f;
25+
26+
half v = tid;
27+
v *= half(3.0f);
28+
v += half(1.0f);
29+
v += offset;
30+
31+
// For tid 0..3 this stores 1, 4, 7, 10 (as floats: 3F800000, 40800000, 40E00000, 41200000).
outputBuffer[tid] = v;
32+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
3F800000
2+
40800000
3+
40E00000
4+
41200000
+30
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
//TEST(compute):COMPARE_COMPUTE:-vk -compute -profile cs_6_2 -render-features half
2+
//Disable on Dx12 for now - because writing to structured buffer produces unexpected results
3+
//DISABLE_TEST(compute):COMPARE_COMPUTE:-dx12 -compute -use-dxil -profile cs_6_2 -render-features half
4+
//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0], stride=16):dxbinding(0),glbinding(0),out
5+
6+
// Element type written to the structured buffer by this test. It mixes 32-bit
// members with a half4 so that half storage inside a struct is exercised.
// NOTE(review): the TEST_INPUT stride of 16 and the four-words-per-element
// expected output suggest a 4 + 4 + 8 byte layout - confirm per target.
struct Thing
7+
{
8+
uint pos;
9+
float radius;
10+
half4 color;
11+
};
12+
13+
RWStructuredBuffer<Thing> outputBuffer;
14+
15+
// Compute entry point: each of the 4 threads fills one Thing and writes it to
// outputBuffer at its own thread index.
[numthreads(4, 1, 1)]
16+
void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
17+
{
18+
uint tid = dispatchThreadID.x;
19+
20+
float v = float(tid);
21+
22+
// Each thread gets its own run of four consecutive color values starting at tid * 4.
float base = v* 4.0f;
23+
24+
Thing thing;
25+
thing.pos = tid;
26+
// The float arguments are converted to half by the half4 constructor.
thing.color = half4(base, base + 1.0f, base + 2.0f, base + 3.0f);
27+
thing.radius = v;
28+
29+
outputBuffer[tid] = thing;
30+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
0
2+
0
3+
3C000000
4+
42004000
5+
1
6+
3F800000
7+
45004400
8+
47004600
9+
2
10+
40000000
11+
48804800
12+
49804900
13+
3
14+
40400000
15+
4A804A00
16+
4B804B00

tools/gfx/render-d3d11.cpp

+3
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ class D3D11Renderer : public Renderer
5454

5555
// Renderer implementation
5656
virtual SlangResult initialize(const Desc& desc, void* inWindowHandle) override;
57+
/// Returns the list of feature names held by this renderer (m_features).
virtual const List<String>& getFeatures() override
{
    return m_features;
}
5758
virtual void setClearColor(const float color[4]) override;
5859
virtual void clearFrame() override;
5960
virtual void presentFrame() override;
@@ -329,6 +330,8 @@ class D3D11Renderer : public Renderer
329330
Desc m_desc;
330331

331332
float m_clearColor[4] = { 0, 0, 0, 0 };
333+
334+
List<String> m_features;
332335
};
333336

334337
Renderer* createD3D11Renderer()

tools/gfx/render-d3d12.cpp

+21
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ class D3D12Renderer : public Renderer
5555
public:
5656
// Renderer implementation
5757
virtual SlangResult initialize(const Desc& desc, void* inWindowHandle) override;
58+
/// Returns the list of feature names held by this renderer (m_features).
/// initialize() adds "half" to it when the device reports shader model 6.2 or higher.
virtual const List<String>& getFeatures() override
{
    return m_features;
}
5859
virtual void setClearColor(const float color[4]) override;
5960
virtual void clearFrame() override;
6061
virtual void presentFrame() override;
@@ -571,6 +572,8 @@ class D3D12Renderer : public Renderer
571572
PFN_D3D12_SERIALIZE_ROOT_SIGNATURE m_D3D12SerializeRootSignature = nullptr;
572573

573574
HWND m_hwnd = nullptr;
575+
576+
List<String> m_features;
574577
};
575578

576579
Renderer* createD3D12Renderer()
@@ -1397,6 +1400,8 @@ Result D3D12Renderer::initialize(const Desc& desc, void* inWindowHandle)
13971400
return SLANG_FAIL;
13981401
}
13991402

1403+
1404+
14001405
FlagCombiner combiner;
14011406
// TODO: we should probably provide a command-line option
14021407
// to override UseDebug of default rather than leave it
@@ -1426,6 +1431,22 @@ Result D3D12Renderer::initialize(const Desc& desc, void* inWindowHandle)
14261431
return SLANG_FAIL;
14271432
}
14281433

1434+
// Find what features are supported
1435+
{
1436+
// Check this is how this is laid out...
1437+
SLANG_COMPILE_TIME_ASSERT(D3D_SHADER_MODEL_6_0 == 0x60);
1438+
1439+
D3D12_FEATURE_DATA_SHADER_MODEL featureShaderMode;
1440+
featureShaderMode.HighestShaderModel = D3D_SHADER_MODEL(0x62);
1441+
1442+
if (SLANG_SUCCEEDED(m_device->CheckFeatureSupport(D3D12_FEATURE_SHADER_MODEL, &featureShaderMode, sizeof(featureShaderMode))) &&
1443+
featureShaderMode.HighestShaderModel >= 0x62)
1444+
{
1445+
// With sm_6_2 we have half
1446+
m_features.Add("half");
1447+
}
1448+
}
1449+
14291450
m_numRenderFrames = 3;
14301451
m_numRenderTargets = 2;
14311452

tools/gfx/render-gl.cpp

+3
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ class GLRenderer : public Renderer
8181

8282
// Renderer implementation
8383
virtual SlangResult initialize(const Desc& desc, void* inWindowHandle) override;
84+
/// Returns the list of feature names held by this renderer (m_features).
virtual const List<String>& getFeatures() override
{
    return m_features;
}
8485
virtual void setClearColor(const float color[4]) override;
8586
virtual void clearFrame() override;
8687
virtual void presentFrame() override;
@@ -354,6 +355,8 @@ class GLRenderer : public Renderer
354355

355356
Desc m_desc;
356357

358+
List<String> m_features;
359+
357360
// Declare a function pointer for each OpenGL
358361
// extension function we need to load
359362
#define DECLARE_GL_EXTENSION_FUNC(NAME, TYPE) TYPE NAME;

tools/gfx/render-vk.cpp

+71-10
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ class VKRenderer : public Renderer
4242

4343
// Renderer implementation
4444
virtual SlangResult initialize(const Desc& desc, void* inWindowHandle) override;
45+
/// Returns the list of feature names held by this renderer (m_features).
/// initialize() adds "half" to it when the physical device reports shaderFloat16.
virtual const List<String>& getFeatures() override
{
    return m_features;
}
4546
virtual void setClearColor(const float color[4]) override;
4647
virtual void clearFrame() override;
4748
virtual void presentFrame() override;
@@ -488,6 +489,7 @@ class VKRenderer : public Renderer
488489
float m_clearColor[4] = { 0, 0, 0, 0 };
489490

490491
Desc m_desc;
492+
List<String> m_features;
491493
};
492494

493495
/* !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! VkRenderer::Buffer !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! */
@@ -898,6 +900,8 @@ SlangResult VKRenderer::initialize(const Desc& desc, void* inWindowHandle)
898900
{
899901
VK_KHR_SURFACE_EXTENSION_NAME,
900902

903+
VK_KHR_GET_PHYSICAL_DEVICE_PROPERTIES_2_EXTENSION_NAME,
904+
901905
#if SLANG_WINDOWS_FAMILY
902906
VK_KHR_WIN32_SURFACE_EXTENSION_NAME,
903907
#else
@@ -949,6 +953,71 @@ SlangResult VKRenderer::initialize(const Desc& desc, void* inWindowHandle)
949953

950954
SLANG_RETURN_ON_FAIL(m_api.initPhysicalDevice(physicalDevices[selectedDeviceIndex]));
951955

956+
List<const char*> deviceExtensions;
957+
deviceExtensions.Add(VK_KHR_SWAPCHAIN_EXTENSION_NAME);
958+
959+
VkDeviceCreateInfo deviceCreateInfo = { VK_STRUCTURE_TYPE_DEVICE_CREATE_INFO };
960+
deviceCreateInfo.queueCreateInfoCount = 1;
961+
962+
deviceCreateInfo.pEnabledFeatures = &m_api.m_deviceFeatures;
963+
964+
// Get the device features (doesn't use, but useful when debugging)
965+
if (m_api.vkGetPhysicalDeviceFeatures2)
966+
{
967+
VkPhysicalDeviceFeatures2 deviceFeatures2 = {};
968+
deviceFeatures2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2;
969+
m_api.vkGetPhysicalDeviceFeatures2(m_api.m_physicalDevice, &deviceFeatures2);
970+
}
971+
972+
VkPhysicalDeviceProperties basicProps = {};
973+
m_api.vkGetPhysicalDeviceProperties(m_api.m_physicalDevice, &basicProps);
974+
975+
// Get the API version
976+
const uint32_t majorVersion = VK_VERSION_MAJOR(basicProps.apiVersion);
977+
const uint32_t minorVersion = VK_VERSION_MINOR(basicProps.apiVersion);
978+
979+
// Float16 features
980+
// Need in this scope because it will be linked into the device creation (if it is available)
981+
VkPhysicalDeviceFloat16Int8FeaturesKHR float16Features = { VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FLOAT16_INT8_FEATURES_KHR };
982+
983+
// API version check, can't use vkGetPhysicalDeviceProperties2 yet since this device might not support it
984+
if (VK_MAKE_VERSION(majorVersion, minorVersion, 0) >= VK_API_VERSION_1_1 &&
985+
m_api.vkGetPhysicalDeviceProperties2 &&
986+
m_api.vkGetPhysicalDeviceFeatures2)
987+
{
988+
VkPhysicalDeviceProperties2 physicalDeviceProps2;
989+
990+
physicalDeviceProps2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PROPERTIES_2;
991+
physicalDeviceProps2.pNext = nullptr;
992+
physicalDeviceProps2.properties = {};
993+
994+
m_api.vkGetPhysicalDeviceProperties2(m_api.m_physicalDevice, &physicalDeviceProps2);
995+
996+
// Get device features
997+
VkPhysicalDeviceFeatures2 deviceFeatures2 = {};
998+
deviceFeatures2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2;
999+
1000+
// Link together for lookup
1001+
float16Features.pNext = deviceFeatures2.pNext;
1002+
deviceFeatures2.pNext = &float16Features;
1003+
1004+
m_api.vkGetPhysicalDeviceFeatures2(m_api.m_physicalDevice, &deviceFeatures2);
1005+
1006+
// If we have float16 features then enable
1007+
if (float16Features.shaderFloat16)
1008+
{
1009+
// Link into the creation features
1010+
float16Features.pNext = (void*)deviceCreateInfo.pNext;
1011+
deviceCreateInfo.pNext = &float16Features;
1012+
1013+
// Add the Float16 extension
1014+
deviceExtensions.Add(VK_KHR_SHADER_FLOAT16_INT8_EXTENSION_NAME);
1015+
1016+
// We have half support
1017+
m_features.Add("half");
1018+
}
1019+
}
1020+
9521021
int queueFamilyIndex = m_api.findQueue(VK_QUEUE_GRAPHICS_BIT | VK_QUEUE_COMPUTE_BIT);
9531022
assert(queueFamilyIndex >= 0);
9541023

@@ -958,18 +1027,10 @@ SlangResult VKRenderer::initialize(const Desc& desc, void* inWindowHandle)
9581027
queueCreateInfo.queueCount = 1;
9591028
queueCreateInfo.pQueuePriorities = &queuePriority;
9601029

961-
char const* const deviceExtensions[] =
962-
{
963-
VK_KHR_SWAPCHAIN_EXTENSION_NAME,
964-
};
965-
966-
VkDeviceCreateInfo deviceCreateInfo = { VK_STRUCTURE_TYPE_DEVICE_CREATE_INFO };
967-
deviceCreateInfo.queueCreateInfoCount = 1;
9681030
deviceCreateInfo.pQueueCreateInfos = &queueCreateInfo;
969-
deviceCreateInfo.pEnabledFeatures = &m_api.m_deviceFeatures;
9701031

971-
deviceCreateInfo.enabledExtensionCount = SLANG_COUNT_OF(deviceExtensions);
972-
deviceCreateInfo.ppEnabledExtensionNames = deviceExtensions;
1032+
deviceCreateInfo.enabledExtensionCount = uint32_t(deviceExtensions.Count());
1033+
deviceCreateInfo.ppEnabledExtensionNames = deviceExtensions.Buffer();
9731034

9741035
SLANG_VK_RETURN_ON_FAIL(m_api.vkCreateDevice(m_api.m_physicalDevice, &deviceCreateInfo, nullptr, &m_device));
9751036
SLANG_RETURN_ON_FAIL(m_api.initDeviceProcs(m_device));

tools/gfx/render.h

+3
Original file line numberDiff line numberDiff line change
@@ -785,6 +785,9 @@ class Renderer: public Slang::RefObject
785785

786786
virtual SlangResult initialize(const Desc& desc, void* inWindowHandle) = 0;
787787

788+
/// Returns true when `feature` appears in the list returned by getFeatures().
bool hasFeature(const Slang::UnownedStringSlice& feature)
{
    // The slice is converted to a String so it can be looked up in the List<String>.
    const Slang::String featureName = Slang::String(feature);

    // UInt(-1) is treated as the "not found" result of IndexOf.
    return getFeatures().IndexOf(featureName) != UInt(-1);
}
789+
/// Returns the names of the features this renderer reports; implemented by each
/// concrete renderer. hasFeature() searches this list.
virtual const Slang::List<Slang::String>& getFeatures() = 0;
790+
788791
virtual void setClearColor(const float color[4]) = 0;
789792
virtual void clearFrame() = 0;
790793

tools/gfx/vk-api.cpp

+4
Original file line numberDiff line numberDiff line change
@@ -49,11 +49,15 @@ Slang::Result VulkanApi::initInstanceProcs(VkInstance instance)
4949

5050
VK_API_ALL_INSTANCE_PROCS(VK_API_GET_INSTANCE_PROC)
5151

52+
// Get optional
53+
VK_API_INSTANCE_PROCS_OPT(VK_API_GET_INSTANCE_PROC)
54+
5255
if (!areDefined(ProcType::Instance))
5356
{
5457
return SLANG_FAIL;
5558
}
5659

60+
5761
m_instance = instance;
5862
return SLANG_OK;
5963
}

0 commit comments

Comments
 (0)