Skip to content

Commit c15e7ad

Browse files
authored
Atomics+Wave ops intrinsics fixes. (shader-slang#3542)
* Fix atomics intrinsics, increase kMaxDescriptorSets.
* Add SPIRVASM to known non-differentiable insts.
* Support fp16 wave ops when targeting glsl.
* Fixes.
* Fix vk validation errors.
* Fix.
* Add to allowed failures.
1 parent a67cb06 commit c15e7ad

21 files changed, with 227 additions (+227) and 53 deletions (-53).

slang.h

+2
Original file line number | Diff line number | Diff line change
@@ -4179,6 +4179,8 @@ namespace slang
41794179

41804180
virtual SLANG_NO_THROW void SLANG_MCALL setReportPerfBenchmark(bool value) = 0;
41814181

4182+
virtual SLANG_NO_THROW void SLANG_MCALL setSkipSPIRVValidation(bool value) = 0;
4183+
41824184
};
41834185

41844186
#define SLANG_UUID_ICompileRequest ICompileRequest::getTypeGuid()

source/slang/core.meta.slang

+14
Original file line number | Diff line number | Diff line change
@@ -217,6 +217,9 @@ attribute_syntax [Differentiable(order:int = 0)] : BackwardDifferentiableAttribu
217217
__intrinsic_op($(kIROp_RequirePrelude))
218218
void __requirePrelude(constexpr String preludeText);
219219

220+
__intrinsic_op($(kIROp_RequireGLSLExtension))
221+
void __requireGLSLExtension(constexpr String preludeText);
222+
220223
/// Interface to denote types as differentiable.
221224
/// Allows for user-specified differential types as
222225
/// well as automatic generation, for when the associated type
@@ -2254,13 +2257,24 @@ __generic<T>
22542257
__intrinsic_op($(kIROp_IsFloat))
22552258
bool __isFloat_impl(T t);
22562259

2260+
__generic<T>
2261+
__intrinsic_op($(kIROp_IsHalf))
2262+
bool __isHalf_impl(T t);
2263+
22572264
__generic<T>
22582265
[__unsafeForceInlineEarly]
22592266
bool __isFloat()
22602267
{
22612268
return __isFloat_impl(__declVal<T>());
22622269
}
22632270

2271+
__generic<T>
2272+
[__unsafeForceInlineEarly]
2273+
bool __isHalf()
2274+
{
2275+
return __isHalf_impl(__declVal<T>());
2276+
}
2277+
22642278
__generic<T>
22652279
__intrinsic_op($(kIROp_IsUnsignedInt))
22662280
bool __isUnsignedInt_impl(T t);

source/slang/hlsl.meta.slang

+84-30
Original file line number | Diff line number | Diff line change
@@ -1778,18 +1778,18 @@ float __atomicAdd(__ref float value, float amount)
17781778
}
17791779

17801780
__glsl_version(430)
1781-
__glsl_extension(GL_EXT_shader_atomic_float2)
1782-
half __atomicAdd(__ref half value, half amount)
1781+
__glsl_extension(GL_NV_shader_atomic_fp16_vector)
1782+
half2 __atomicAdd(__ref half2 value, half2 amount)
17831783
{
17841784
__target_switch
17851785
{
17861786
case glsl: __intrinsic_asm "atomicAdd($0, $1)";
17871787
case spirv:
17881788
return spirv_asm
17891789
{
1790-
OpExtension "SPV_EXT_shader_atomic_float16_add";
1791-
OpCapability AtomicFloat16AddEXT;
1792-
result:$$half = OpAtomicFAddEXT &value Device None $amount
1790+
OpExtension "SPV_EXT_shader_atomic_float_add";
1791+
OpCapability AtomicFloat32AddEXT;
1792+
result:$$half2 = OpAtomicFAddEXT &value Device None $amount
17931793
};
17941794
}
17951795
}
@@ -2337,7 +2337,7 @@ ${{{{
23372337
__target_switch
23382338
{
23392339
case hlsl:
2340-
__intrinsic_asm "NvInterlockedAddFp32($0, $1, $2))";
2340+
__intrinsic_asm "NvInterlockedAddFp16x2($0, $1, $2))";
23412341
}
23422342
}
23432343

@@ -2364,8 +2364,15 @@ ${{{{
23642364
case glsl:
23652365
case spirv:
23662366
{
2367-
let buf = __getEquivalentStructuredBuffer<half>(this);
2368-
originalValue = __atomicAdd(buf[byteAddress / 2], value);
2367+
let buf = __getEquivalentStructuredBuffer<half2>(this);
2368+
if ((byteAddress & 2) == 0)
2369+
{
2370+
originalValue = __atomicAdd(buf[byteAddress/4], half2(value, half(0.0))).x;
2371+
}
2372+
else
2373+
{
2374+
originalValue = __atomicAdd(buf[byteAddress/4], half2(half(0.0), value)).y;
2375+
}
23692376
return;
23702377
}
23712378
}
@@ -7555,14 +7562,19 @@ __target_intrinsic(cuda, "_waveProductMultiple($0, $1)")
75557562
__target_intrinsic(hlsl, "WaveActiveProduct($1)")
75567563
matrix<T,N,M> WaveMaskProduct(WaveMask mask, matrix<T,N,M> expr);
75577564

7565+
__intrinsic_op($(kIROp_RequireGLSLExtension))
7566+
void __requireGLSLExtension(String extensionName);
7567+
75587568
__generic<T : __BuiltinArithmeticType>
75597569
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
75607570
__spirv_version(1.3)
75617571
T WaveMaskSum(WaveMask mask, T expr)
75627572
{
75637573
__target_switch
75647574
{
7565-
case glsl: __intrinsic_asm "subgroupAdd($1)";
7575+
case glsl:
7576+
if (__isHalf<T>()) __requireGLSLExtension("GL_EXT_shader_subgroup_extended_types_float16");
7577+
__intrinsic_asm "subgroupAdd($1)";
75667578
case cuda: __intrinsic_asm "_waveSum($0, $1)";
75677579
case hlsl: __intrinsic_asm "WaveActiveSum($1)";
75687580
case spirv:
@@ -7591,7 +7603,9 @@ vector<T,N> WaveMaskSum(WaveMask mask, vector<T,N> expr)
75917603
{
75927604
__target_switch
75937605
{
7594-
case glsl: __intrinsic_asm "subgroupAdd($1)";
7606+
case glsl:
7607+
if (__isHalf<T>()) __requireGLSLExtension("GL_EXT_shader_subgroup_extended_types_float16");
7608+
__intrinsic_asm "subgroupAdd($1)";
75957609
case cuda: __intrinsic_asm "_waveSumMultiple($0, $1)";
75967610
case hlsl: __intrinsic_asm "WaveActiveSum($1)";
75977611
case spirv:
@@ -7627,6 +7641,7 @@ bool WaveMaskAllEqual(WaveMask mask, T value)
76277641
__target_switch
76287642
{
76297643
case glsl:
7644+
if (__isHalf<T>()) __requireGLSLExtension("GL_EXT_shader_subgroup_extended_types_float16");
76307645
__intrinsic_asm "subgroupAllEqual($1)";
76317646
case hlsl:
76327647
__intrinsic_asm "WaveActiveAllEqual($1)";
@@ -7651,6 +7666,7 @@ bool WaveMaskAllEqual(WaveMask mask, vector<T,N> value)
76517666
__target_switch
76527667
{
76537668
case glsl:
7669+
if (__isHalf<T>()) __requireGLSLExtension("GL_EXT_shader_subgroup_extended_types_float16");
76547670
__intrinsic_asm "subgroupAllEqual($1)";
76557671
case hlsl:
76567672
__intrinsic_asm "WaveActiveAllEqual($1)";
@@ -7681,7 +7697,9 @@ T WaveMaskPrefixProduct(WaveMask mask, T expr)
76817697
{
76827698
__target_switch
76837699
{
7684-
case glsl: __intrinsic_asm "subgroupExclusiveMul($1)";
7700+
case glsl:
7701+
if (__isHalf<T>()) __requireGLSLExtension("GL_EXT_shader_subgroup_extended_types_float16");
7702+
__intrinsic_asm "subgroupExclusiveMul($1)";
76857703
case cuda: __intrinsic_asm "_wavePrefixProduct($0, $1)";
76867704
case hlsl: __intrinsic_asm "WavePrefixProduct($1)";
76877705
case spirv:
@@ -7710,7 +7728,9 @@ vector<T,N> WaveMaskPrefixProduct(WaveMask mask, vector<T,N> expr)
77107728
{
77117729
__target_switch
77127730
{
7713-
case glsl: __intrinsic_asm "subgroupExclusiveMul($1)";
7731+
case glsl:
7732+
if (__isHalf<T>()) __requireGLSLExtension("GL_EXT_shader_subgroup_extended_types_float16");
7733+
__intrinsic_asm "subgroupExclusiveMul($1)";
77147734
case cuda: __intrinsic_asm "_wavePrefixProductMultiple($0, $1)";
77157735
case hlsl: __intrinsic_asm "WavePrefixProduct($1)";
77167736
case spirv:
@@ -7744,7 +7764,9 @@ T WaveMaskPrefixSum(WaveMask mask, T expr)
77447764
{
77457765
__target_switch
77467766
{
7747-
case glsl: __intrinsic_asm "subgroupExclusiveAdd($1)";
7767+
case glsl:
7768+
if (__isHalf<T>()) __requireGLSLExtension("GL_EXT_shader_subgroup_extended_types_float16");
7769+
__intrinsic_asm "subgroupExclusiveAdd($1)";
77487770
case cuda: __intrinsic_asm "_wavePrefixSum($0, $1)";
77497771
case hlsl: __intrinsic_asm "WavePrefixSum($1)";
77507772
case spirv:
@@ -7774,7 +7796,9 @@ vector<T,N> WaveMaskPrefixSum(WaveMask mask, vector<T,N> expr)
77747796
{
77757797
__target_switch
77767798
{
7777-
case glsl: __intrinsic_asm "subgroupExclusiveAdd($1)";
7799+
case glsl:
7800+
if (__isHalf<T>()) __requireGLSLExtension("GL_EXT_shader_subgroup_extended_types_float16");
7801+
__intrinsic_asm "subgroupExclusiveAdd($1)";
77787802
case cuda: __intrinsic_asm "_wavePrefixSumMultiple($0, $1)";
77797803
case hlsl: __intrinsic_asm "WavePrefixSum($1)";
77807804
case spirv:
@@ -8281,7 +8305,9 @@ T WaveActive$(opName.hlslName)(T expr)
82818305
{
82828306
__target_switch
82838307
{
8284-
case glsl: __intrinsic_asm "subgroup$(opName.glslName)($0)";
8308+
case glsl:
8309+
if (__isHalf<T>()) __requireGLSLExtension("GL_EXT_shader_subgroup_extended_types_float16");
8310+
__intrinsic_asm "subgroup$(opName.glslName)($0)";
82858311
case hlsl: __intrinsic_asm "WaveActive$(opName.hlslName)";
82868312
case spirv:
82878313
if (__isFloat<T>())
@@ -8320,7 +8346,9 @@ vector<T,N> WaveActive$(opName.hlslName)(vector<T,N> expr)
83208346
{
83218347
__target_switch
83228348
{
8323-
case glsl: __intrinsic_asm "subgroup$(opName.glslName)($0)";
8349+
case glsl:
8350+
if (__isHalf<T>()) __requireGLSLExtension("GL_EXT_shader_subgroup_extended_types_float16");
8351+
__intrinsic_asm "subgroup$(opName.glslName)($0)";
83248352
case hlsl: __intrinsic_asm "WaveActive$(opName.hlslName)";
83258353
case spirv:
83268354
if (__isFloat<T>())
@@ -8574,7 +8602,9 @@ T WavePrefixProduct(T expr)
85748602
{
85758603
__target_switch
85768604
{
8577-
case glsl: __intrinsic_asm "subgroupExclusiveMul($0)";
8605+
case glsl:
8606+
if (__isHalf<T>()) __requireGLSLExtension("GL_EXT_shader_subgroup_extended_types_float16");
8607+
__intrinsic_asm "subgroupExclusiveMul($0)";
85788608
case hlsl: __intrinsic_asm "WavePrefixProduct";
85798609
case spirv:
85808610
if (__isFloat<T>())
@@ -8609,7 +8639,9 @@ vector<T,N> WavePrefixProduct(vector<T,N> expr)
86098639
{
86108640
__target_switch
86118641
{
8612-
case glsl: __intrinsic_asm "subgroupExclusiveMul($0)";
8642+
case glsl:
8643+
if (__isHalf<T>()) __requireGLSLExtension("GL_EXT_shader_subgroup_extended_types_float16");
8644+
__intrinsic_asm "subgroupExclusiveMul($0)";
86138645
case hlsl: __intrinsic_asm "WavePrefixProduct";
86148646
case spirv:
86158647
if (__isFloat<T>())
@@ -8647,7 +8679,9 @@ T WavePrefixSum(T expr)
86478679
{
86488680
__target_switch
86498681
{
8650-
case glsl: __intrinsic_asm "subgroupExclusiveAdd($0)";
8682+
case glsl:
8683+
if (__isHalf<T>()) __requireGLSLExtension("GL_EXT_shader_subgroup_extended_types_float16");
8684+
__intrinsic_asm "subgroupExclusiveAdd($0)";
86518685
case hlsl: __intrinsic_asm "WavePrefixSum";
86528686
case spirv:
86538687
if (__isFloat<T>())
@@ -8678,7 +8712,9 @@ vector<T,N> WavePrefixSum(vector<T,N> expr)
86788712
{
86798713
__target_switch
86808714
{
8681-
case glsl: __intrinsic_asm "subgroupExclusiveAdd($0)";
8715+
case glsl:
8716+
if (__isHalf<T>()) __requireGLSLExtension("GL_EXT_shader_subgroup_extended_types_float16");
8717+
__intrinsic_asm "subgroupExclusiveAdd($0)";
86828718
case hlsl: __intrinsic_asm "WavePrefixSum";
86838719
case spirv:
86848720
if (__isFloat<T>())
@@ -8716,7 +8752,9 @@ T WaveReadLaneFirst(T expr)
87168752
{
87178753
__target_switch
87188754
{
8719-
case glsl: __intrinsic_asm "subgroupBroadcastFirst($0)";
8755+
case glsl:
8756+
if (__isHalf<T>()) __requireGLSLExtension("GL_EXT_shader_subgroup_extended_types_float16");
8757+
__intrinsic_asm "subgroupBroadcastFirst($0)";
87208758
case hlsl: __intrinsic_asm "WaveReadLaneFirst";
87218759
case spirv:
87228760
return spirv_asm {OpCapability GroupNonUniformBallot; OpGroupNonUniformBroadcastFirst $$T result Subgroup $expr};
@@ -8732,7 +8770,9 @@ vector<T,N> WaveReadLaneFirst(vector<T,N> expr)
87328770
{
87338771
__target_switch
87348772
{
8735-
case glsl: __intrinsic_asm "subgroupBroadcastFirst($0)";
8773+
case glsl:
8774+
if (__isHalf<T>()) __requireGLSLExtension("GL_EXT_shader_subgroup_extended_types_float16");
8775+
__intrinsic_asm "subgroupBroadcastFirst($0)";
87368776
case hlsl: __intrinsic_asm "WaveReadLaneFirst";
87378777
case spirv:
87388778
return spirv_asm {OpCapability GroupNonUniformBallot; OpGroupNonUniformBroadcastFirst $$vector<T,N> result Subgroup $expr};
@@ -8761,7 +8801,9 @@ T WaveBroadcastLaneAt(T value, constexpr int lane)
87618801
{
87628802
__target_switch
87638803
{
8764-
case glsl: __intrinsic_asm "subgroupBroadcast($0, $1)";
8804+
case glsl:
8805+
if (__isHalf<T>()) __requireGLSLExtension("GL_EXT_shader_subgroup_extended_types_float16");
8806+
__intrinsic_asm "subgroupBroadcast($0, $1)";
87658807
case hlsl: __intrinsic_asm "WaveReadLaneAt";
87668808
case spirv:
87678809
let ulane = uint(lane);
@@ -8778,7 +8820,9 @@ vector<T,N> WaveBroadcastLaneAt(vector<T,N> value, constexpr int lane)
87788820
{
87798821
__target_switch
87808822
{
8781-
case glsl: __intrinsic_asm "subgroupBroadcast($0, $1)";
8823+
case glsl:
8824+
if (__isHalf<T>()) __requireGLSLExtension("GL_EXT_shader_subgroup_extended_types_float16");
8825+
__intrinsic_asm "subgroupBroadcast($0, $1)";
87828826
case hlsl: __intrinsic_asm "WaveReadLaneAt";
87838827
case spirv:
87848828
let ulane = uint(lane);
@@ -8805,7 +8849,9 @@ T WaveReadLaneAt(T value, int lane)
88058849
{
88068850
__target_switch
88078851
{
8808-
case glsl: __intrinsic_asm "subgroupShuffle($0, $1)";
8852+
case glsl:
8853+
if (__isHalf<T>()) __requireGLSLExtension("GL_EXT_shader_subgroup_extended_types_float16");
8854+
__intrinsic_asm "subgroupShuffle($0, $1)";
88098855
case hlsl: __intrinsic_asm "WaveReadLaneAt";
88108856
case spirv:
88118857
let ulane = uint(lane);
@@ -8822,7 +8868,9 @@ vector<T,N> WaveReadLaneAt(vector<T,N> value, int lane)
88228868
{
88238869
__target_switch
88248870
{
8825-
case glsl: __intrinsic_asm "subgroupShuffle($0, $1)";
8871+
case glsl:
8872+
if (__isHalf<T>()) __requireGLSLExtension("GL_EXT_shader_subgroup_extended_types_float16");
8873+
__intrinsic_asm "subgroupShuffle($0, $1)";
88268874
case hlsl: __intrinsic_asm "WaveReadLaneAt";
88278875
case spirv:
88288876
let ulane = uint(lane);
@@ -8850,7 +8898,9 @@ T WaveShuffle(T value, int lane)
88508898
{
88518899
__target_switch
88528900
{
8853-
case glsl: __intrinsic_asm "subgroupShuffle($0, $1)";
8901+
case glsl:
8902+
if (__isHalf<T>()) __requireGLSLExtension("GL_EXT_shader_subgroup_extended_types_float16");
8903+
__intrinsic_asm "subgroupShuffle($0, $1)";
88548904
case hlsl: __intrinsic_asm "WaveReadLaneAt";
88558905
case spirv:
88568906
let ulane = uint(lane);
@@ -8867,7 +8917,9 @@ vector<T,N> WaveShuffle(vector<T,N> value, int lane)
88678917
{
88688918
__target_switch
88698919
{
8870-
case glsl: __intrinsic_asm "subgroupShuffle($0, $1)";
8920+
case glsl:
8921+
if (__isHalf<T>()) __requireGLSLExtension("GL_EXT_shader_subgroup_extended_types_float16");
8922+
__intrinsic_asm "subgroupShuffle($0, $1)";
88718923
case hlsl: __intrinsic_asm "WaveReadLaneAt";
88728924
case spirv:
88738925
let ulane = uint(lane);
@@ -8890,7 +8942,8 @@ uint WavePrefixCountBits(bool value)
88908942
{
88918943
__target_switch
88928944
{
8893-
case glsl: __intrinsic_asm "subgroupBallotExclusiveBitCount(subgroupBallot($0))";
8945+
case glsl:
8946+
__intrinsic_asm "subgroupBallotExclusiveBitCount(subgroupBallot($0))";
88948947
case hlsl: __intrinsic_asm "WavePrefixCountBits($0)";
88958948
case spirv:
88968949
return spirv_asm
@@ -8910,7 +8963,8 @@ uint4 WaveGetConvergedMulti()
89108963
{
89118964
__target_switch
89128965
{
8913-
case glsl: __intrinsic_asm "subgroupBallot(true)";
8966+
case glsl:
8967+
__intrinsic_asm "subgroupBallot(true)";
89148968
case hlsl: __intrinsic_asm "WaveActiveBallot(true)";
89158969
case cuda: __intrinsic_asm "make_uint4(__activemask(), 0, 0, 0)";
89168970
case spirv:

source/slang/slang-compiler.cpp

+11
Original file line number | Diff line number | Diff line change
@@ -2329,6 +2329,17 @@ namespace Slang
23292329
return false;
23302330
}
23312331

2332+
bool CodeGenContext::shouldSkipSPIRVValidation()
2333+
{
2334+
if (auto endToEndReq = isEndToEndCompile())
2335+
{
2336+
if (endToEndReq->m_skipSPIRVValidation)
2337+
return true;
2338+
}
2339+
2340+
return false;
2341+
}
2342+
23322343
bool CodeGenContext::shouldDumpIR()
23332344
{
23342345
if (getTargetReq()->getTargetFlags() & SLANG_TARGET_FLAG_DUMP_IR)

0 commit comments

Comments
 (0)