Skip to content

Commit 79e7223

Browse files
authored
Preliminary CUDA Half support (shader-slang#1808)
* #include an absolute path didn't work - because paths were taken to always be relative. * WIP CUDA half support. * Working support for half on CUDA - requires cuda_fp16.h and associated files can be found. * Fix for win32 for unused funcs. * Fix for Clang. * Hack to disable unused local function warning.
1 parent a47e775 commit 79e7223

12 files changed

+593
-247
lines changed

prelude/slang-cuda-prelude.h

+39-1
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,17 @@
1+
// Define SLANG_CUDA_ENABLE_HALF to use the cuda_fp16 include to add half support.
2+
// For this to work NVRTC needs to have the path to the CUDA SDK.
3+
//
4+
// As it stands the includes paths defined for Slang are passed down to NVRTC. Similarly defines defined for the Slang compile
5+
// are passed down.
6+
7+
#ifdef SLANG_CUDA_ENABLE_HALF
8+
#include <cuda_fp16.h>
9+
#endif
10+
111
#ifdef SLANG_CUDA_ENABLE_OPTIX
212
#include <optix.h>
313
#endif
414

5-
615
// Must be large enough to cause overflow and therefore infinity
716
#ifndef SLANG_INFINITY
817
# define SLANG_INFINITY ((float)(1e+300 * 1e+300))
@@ -42,6 +51,35 @@
4251
// Here we don't have the index zeroing behavior, as such bounds checks are generally not on GPU targets either.
4352
#ifndef SLANG_CUDA_FIXED_ARRAY_BOUND_CHECK
4453
# define SLANG_CUDA_FIXED_ARRAY_BOUND_CHECK(index, count) SLANG_PRELUDE_ASSERT(index < count);
54+
#endif
55+
56+
//
57+
// Half support
58+
//
59+
60+
#if SLANG_CUDA_ENABLE_HALF
61+
62+
// Add the other vector half types
63+
struct __half3 { __half2 xy; __half z; };
64+
struct __half4 { __half2 xy; __half2 zw; };
65+
66+
// Mechanism to make half vectors
67+
SLANG_FORCE_INLINE SLANG_CUDA_CALL __half2 make___half2(__half x, __half y) { return __halves2half2(x, y); }
68+
SLANG_FORCE_INLINE SLANG_CUDA_CALL __half3 make___half3(__half x, __half y, __half z) { __half3 o; o.xy = __halves2half2(x, y); o.z = z; return o; }
69+
SLANG_FORCE_INLINE SLANG_CUDA_CALL __half4 make___half4(__half x, __half y, __half z, __half w) { __half4 o; o.xy = __halves2half2(x, y); o.zw = __halves2half2(z, w); return o; }
70+
71+
// Use the round nearest as the default - it is the only one defined
72+
SLANG_FORCE_INLINE SLANG_CUDA_CALL __half2 __float22half2(const float2 a) { return __float22half2_rn(a); }
73+
74+
// Implement the vector versions
75+
SLANG_FORCE_INLINE SLANG_CUDA_CALL __half2 __float2half(float2 a) { return __float22half2(a); }
76+
SLANG_FORCE_INLINE SLANG_CUDA_CALL __half3 __float2half(float3 a) { __half3 o; o.xy = __float22half2(make_float2(a.x, a.y)); o.z = __float2half(a.z); return o; }
77+
SLANG_FORCE_INLINE SLANG_CUDA_CALL __half4 __float2half(float4 a) { __half4 o; o.xy = __float22half2(make_float2(a.x, a.y)); o.zw = __float22half2(make_float2(a.z, a.w)); return o; }
78+
79+
SLANG_FORCE_INLINE SLANG_CUDA_CALL float2 __half2float(__half2 a) { return __half22float2(a); }
80+
SLANG_FORCE_INLINE SLANG_CUDA_CALL float3 __half2float(__half3 a) { float2 xy = __half22float2(a.xy); float z = __half2float(a.z); return make_float3(xy.x, xy.y, z); }
81+
SLANG_FORCE_INLINE SLANG_CUDA_CALL float4 __half2float(__half4 a) { float2 xy = __half22float2(a.xy); float2 zw = __half22float2(a.zw); return make_float4(xy.x, xy.y, zw.x, zw.y); }
82+
4583
#endif
4684

4785
// This macro handles how out-of-range surface coordinates are handled;

source/compiler-core/slang-downstream-compiler.h

+7-6
Original file line numberDiff line numberDiff line change
@@ -140,9 +140,9 @@ class DownstreamCompiler: public RefObject
140140
{
141141
enum Enum : SourceLanguageFlags
142142
{
143-
Unknown = SourceLanguageFlags(1) << SLANG_SOURCE_LANGUAGE_UNKNOWN,
144-
Slang = SourceLanguageFlags(1) << SLANG_SOURCE_LANGUAGE_SLANG,
145-
HLSL = SourceLanguageFlags(1) << SLANG_SOURCE_LANGUAGE_HLSL,
143+
Unknown = SourceLanguageFlags(1) << SLANG_SOURCE_LANGUAGE_UNKNOWN,
144+
Slang = SourceLanguageFlags(1) << SLANG_SOURCE_LANGUAGE_SLANG,
145+
HLSL = SourceLanguageFlags(1) << SLANG_SOURCE_LANGUAGE_HLSL,
146146
GLSL = SourceLanguageFlags(1) << SLANG_SOURCE_LANGUAGE_GLSL,
147147
C = SourceLanguageFlags(1) << SLANG_SOURCE_LANGUAGE_C,
148148
CPP = SourceLanguageFlags(1) << SLANG_SOURCE_LANGUAGE_CPP,
@@ -247,9 +247,10 @@ class DownstreamCompiler: public RefObject
247247
{
248248
enum Enum : Flags
249249
{
250-
EnableExceptionHandling = 0x01,
251-
Verbose = 0x02,
252-
EnableSecurityChecks = 0x04,
250+
EnableExceptionHandling = 0x01, ///< Enables exception handling support (say as optionally supported by C++)
251+
Verbose = 0x02, ///< Give more verbose diagnostics
252+
EnableSecurityChecks = 0x04, ///< Enable runtime security checks (such as for buffer overruns) - enabling typically decreases performance
253+
EnableFloat16 = 0x08, ///< If set compiles with support for float16/half
253254
};
254255
};
255256

0 commit comments

Comments
 (0)