@@ -1778,18 +1778,18 @@ float __atomicAdd(__ref float value, float amount)
1778
1778
}
1779
1779
1780
1780
__glsl_version(430 )
1781
- __glsl_extension(GL_EXT_shader_atomic_float2 )
1782
- half __atomicAdd(__ref half value, half amount)
1781
+ __glsl_extension(GL_NV_shader_atomic_fp16_vector )
1782
+ half2 __atomicAdd(__ref half2 value, half2 amount)
1783
1783
{
1784
1784
__target_switch
1785
1785
{
1786
1786
case glsl: __intrinsic_asm " atomicAdd($0, $1)" ;
1787
1787
case spirv:
1788
1788
return spirv_asm
1789
1789
{
1790
- OpExtension " SPV_EXT_shader_atomic_float16_add " ;
1791
- OpCapability AtomicFloat16AddEXT ;
1792
- result:$$half = OpAtomicFAddEXT & value Device None $amount
1790
+ OpExtension " SPV_EXT_shader_atomic_float_add " ;
1791
+ OpCapability AtomicFloat32AddEXT ;
1792
+ result:$$half2 = OpAtomicFAddEXT & value Device None $amount
1793
1793
};
1794
1794
}
1795
1795
}
@@ -2337,7 +2337,7 @@ ${{{{
2337
2337
__target_switch
2338
2338
{
2339
2339
case hlsl:
2340
- __intrinsic_asm " NvInterlockedAddFp32 ($0, $1, $2))" ;
2340
+ __intrinsic_asm " NvInterlockedAddFp16x2 ($0, $1, $2))" ;
2341
2341
}
2342
2342
}
2343
2343
@@ -2364,8 +2364,15 @@ ${{{{
2364
2364
case glsl:
2365
2365
case spirv:
2366
2366
{
2367
- let buf = __getEquivalentStructuredBuffer< half > (this );
2368
- originalValue = __atomicAdd(buf [byteAddress / 2 ], value);
2367
+ let buf = __getEquivalentStructuredBuffer< half2 > (this );
2368
+ if ((byteAddress & 2 ) == 0 )
2369
+ {
2370
+ originalValue = __atomicAdd(buf [byteAddress/ 4 ], half2(value, half (0 . 0 ))).x ;
2371
+ }
2372
+ else
2373
+ {
2374
+ originalValue = __atomicAdd(buf [byteAddress/ 4 ], half2(half (0 . 0 ), value)).y ;
2375
+ }
2369
2376
return ;
2370
2377
}
2371
2378
}
@@ -7555,14 +7562,19 @@ __target_intrinsic(cuda, "_waveProductMultiple($0, $1)")
7555
7562
__target_intrinsic (hlsl, " WaveActiveProduct($1)" )
7556
7563
matrix< T,N,M> WaveMaskProduct(WaveMask mask, matrix< T,N,M> expr);
7557
7564
7565
+ __intrinsic_op($(kIROp_RequireGLSLExtension ))
7566
+ void __requireGLSLExtension(String extensionName);
7567
+
7558
7568
__generic < T : __BuiltinArithmeticType>
7559
7569
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
7560
7570
__spirv_version(1 . 3 )
7561
7571
T WaveMaskSum(WaveMask mask, T expr)
7562
7572
{
7563
7573
__target_switch
7564
7574
{
7565
- case glsl: __intrinsic_asm " subgroupAdd($1)" ;
7575
+ case glsl:
7576
+ if (__isHalf< T> ()) __requireGLSLExtension(" GL_EXT_shader_subgroup_extended_types_float16" );
7577
+ __intrinsic_asm " subgroupAdd($1)" ;
7566
7578
case cuda: __intrinsic_asm " _waveSum($0, $1)" ;
7567
7579
case hlsl: __intrinsic_asm " WaveActiveSum($1)" ;
7568
7580
case spirv:
@@ -7591,7 +7603,9 @@ vector<T,N> WaveMaskSum(WaveMask mask, vector<T,N> expr)
7591
7603
{
7592
7604
__target_switch
7593
7605
{
7594
- case glsl: __intrinsic_asm " subgroupAdd($1)" ;
7606
+ case glsl:
7607
+ if (__isHalf< T> ()) __requireGLSLExtension(" GL_EXT_shader_subgroup_extended_types_float16" );
7608
+ __intrinsic_asm " subgroupAdd($1)" ;
7595
7609
case cuda: __intrinsic_asm " _waveSumMultiple($0, $1)" ;
7596
7610
case hlsl: __intrinsic_asm " WaveActiveSum($1)" ;
7597
7611
case spirv:
@@ -7627,6 +7641,7 @@ bool WaveMaskAllEqual(WaveMask mask, T value)
7627
7641
__target_switch
7628
7642
{
7629
7643
case glsl:
7644
+ if (__isHalf< T> ()) __requireGLSLExtension(" GL_EXT_shader_subgroup_extended_types_float16" );
7630
7645
__intrinsic_asm " subgroupAllEqual($1)" ;
7631
7646
case hlsl:
7632
7647
__intrinsic_asm " WaveActiveAllEqual($1)" ;
@@ -7651,6 +7666,7 @@ bool WaveMaskAllEqual(WaveMask mask, vector<T,N> value)
7651
7666
__target_switch
7652
7667
{
7653
7668
case glsl:
7669
+ if (__isHalf< T> ()) __requireGLSLExtension(" GL_EXT_shader_subgroup_extended_types_float16" );
7654
7670
__intrinsic_asm " subgroupAllEqual($1)" ;
7655
7671
case hlsl:
7656
7672
__intrinsic_asm " WaveActiveAllEqual($1)" ;
@@ -7681,7 +7697,9 @@ T WaveMaskPrefixProduct(WaveMask mask, T expr)
7681
7697
{
7682
7698
__target_switch
7683
7699
{
7684
- case glsl: __intrinsic_asm " subgroupExclusiveMul($1)" ;
7700
+ case glsl:
7701
+ if (__isHalf< T> ()) __requireGLSLExtension(" GL_EXT_shader_subgroup_extended_types_float16" );
7702
+ __intrinsic_asm " subgroupExclusiveMul($1)" ;
7685
7703
case cuda: __intrinsic_asm " _wavePrefixProduct($0, $1)" ;
7686
7704
case hlsl: __intrinsic_asm " WavePrefixProduct($1)" ;
7687
7705
case spirv:
@@ -7710,7 +7728,9 @@ vector<T,N> WaveMaskPrefixProduct(WaveMask mask, vector<T,N> expr)
7710
7728
{
7711
7729
__target_switch
7712
7730
{
7713
- case glsl: __intrinsic_asm " subgroupExclusiveMul($1)" ;
7731
+ case glsl:
7732
+ if (__isHalf< T> ()) __requireGLSLExtension(" GL_EXT_shader_subgroup_extended_types_float16" );
7733
+ __intrinsic_asm " subgroupExclusiveMul($1)" ;
7714
7734
case cuda: __intrinsic_asm " _wavePrefixProductMultiple($0, $1)" ;
7715
7735
case hlsl: __intrinsic_asm " WavePrefixProduct($1)" ;
7716
7736
case spirv:
@@ -7744,7 +7764,9 @@ T WaveMaskPrefixSum(WaveMask mask, T expr)
7744
7764
{
7745
7765
__target_switch
7746
7766
{
7747
- case glsl: __intrinsic_asm " subgroupExclusiveAdd($1)" ;
7767
+ case glsl:
7768
+ if (__isHalf< T> ()) __requireGLSLExtension(" GL_EXT_shader_subgroup_extended_types_float16" );
7769
+ __intrinsic_asm " subgroupExclusiveAdd($1)" ;
7748
7770
case cuda: __intrinsic_asm " _wavePrefixSum($0, $1)" ;
7749
7771
case hlsl: __intrinsic_asm " WavePrefixSum($1)" ;
7750
7772
case spirv:
@@ -7774,7 +7796,9 @@ vector<T,N> WaveMaskPrefixSum(WaveMask mask, vector<T,N> expr)
7774
7796
{
7775
7797
__target_switch
7776
7798
{
7777
- case glsl: __intrinsic_asm " subgroupExclusiveAdd($1)" ;
7799
+ case glsl:
7800
+ if (__isHalf< T> ()) __requireGLSLExtension(" GL_EXT_shader_subgroup_extended_types_float16" );
7801
+ __intrinsic_asm " subgroupExclusiveAdd($1)" ;
7778
7802
case cuda: __intrinsic_asm " _wavePrefixSumMultiple($0, $1)" ;
7779
7803
case hlsl: __intrinsic_asm " WavePrefixSum($1)" ;
7780
7804
case spirv:
@@ -8281,7 +8305,9 @@ T WaveActive$(opName.hlslName)(T expr)
8281
8305
{
8282
8306
__target_switch
8283
8307
{
8284
- case glsl: __intrinsic_asm " subgroup$(opName.glslName)($0)" ;
8308
+ case glsl:
8309
+ if (__isHalf< T> ()) __requireGLSLExtension(" GL_EXT_shader_subgroup_extended_types_float16" );
8310
+ __intrinsic_asm " subgroup$(opName.glslName)($0)" ;
8285
8311
case hlsl: __intrinsic_asm " WaveActive$(opName.hlslName)" ;
8286
8312
case spirv:
8287
8313
if (__isFloat< T> ())
@@ -8320,7 +8346,9 @@ vector<T,N> WaveActive$(opName.hlslName)(vector<T,N> expr)
8320
8346
{
8321
8347
__target_switch
8322
8348
{
8323
- case glsl: __intrinsic_asm " subgroup$(opName.glslName)($0)" ;
8349
+ case glsl:
8350
+ if (__isHalf< T> ()) __requireGLSLExtension(" GL_EXT_shader_subgroup_extended_types_float16" );
8351
+ __intrinsic_asm " subgroup$(opName.glslName)($0)" ;
8324
8352
case hlsl: __intrinsic_asm " WaveActive$(opName.hlslName)" ;
8325
8353
case spirv:
8326
8354
if (__isFloat< T> ())
@@ -8574,7 +8602,9 @@ T WavePrefixProduct(T expr)
8574
8602
{
8575
8603
__target_switch
8576
8604
{
8577
- case glsl: __intrinsic_asm " subgroupExclusiveMul($0)" ;
8605
+ case glsl:
8606
+ if (__isHalf< T> ()) __requireGLSLExtension(" GL_EXT_shader_subgroup_extended_types_float16" );
8607
+ __intrinsic_asm " subgroupExclusiveMul($0)" ;
8578
8608
case hlsl: __intrinsic_asm " WavePrefixProduct" ;
8579
8609
case spirv:
8580
8610
if (__isFloat< T> ())
@@ -8609,7 +8639,9 @@ vector<T,N> WavePrefixProduct(vector<T,N> expr)
8609
8639
{
8610
8640
__target_switch
8611
8641
{
8612
- case glsl: __intrinsic_asm " subgroupExclusiveMul($0)" ;
8642
+ case glsl:
8643
+ if (__isHalf< T> ()) __requireGLSLExtension(" GL_EXT_shader_subgroup_extended_types_float16" );
8644
+ __intrinsic_asm " subgroupExclusiveMul($0)" ;
8613
8645
case hlsl: __intrinsic_asm " WavePrefixProduct" ;
8614
8646
case spirv:
8615
8647
if (__isFloat< T> ())
@@ -8647,7 +8679,9 @@ T WavePrefixSum(T expr)
8647
8679
{
8648
8680
__target_switch
8649
8681
{
8650
- case glsl: __intrinsic_asm " subgroupExclusiveAdd($0)" ;
8682
+ case glsl:
8683
+ if (__isHalf< T> ()) __requireGLSLExtension(" GL_EXT_shader_subgroup_extended_types_float16" );
8684
+ __intrinsic_asm " subgroupExclusiveAdd($0)" ;
8651
8685
case hlsl: __intrinsic_asm " WavePrefixSum" ;
8652
8686
case spirv:
8653
8687
if (__isFloat< T> ())
@@ -8678,7 +8712,9 @@ vector<T,N> WavePrefixSum(vector<T,N> expr)
8678
8712
{
8679
8713
__target_switch
8680
8714
{
8681
- case glsl: __intrinsic_asm " subgroupExclusiveAdd($0)" ;
8715
+ case glsl:
8716
+ if (__isHalf< T> ()) __requireGLSLExtension(" GL_EXT_shader_subgroup_extended_types_float16" );
8717
+ __intrinsic_asm " subgroupExclusiveAdd($0)" ;
8682
8718
case hlsl: __intrinsic_asm " WavePrefixSum" ;
8683
8719
case spirv:
8684
8720
if (__isFloat< T> ())
@@ -8716,7 +8752,9 @@ T WaveReadLaneFirst(T expr)
8716
8752
{
8717
8753
__target_switch
8718
8754
{
8719
- case glsl: __intrinsic_asm " subgroupBroadcastFirst($0)" ;
8755
+ case glsl:
8756
+ if (__isHalf< T> ()) __requireGLSLExtension(" GL_EXT_shader_subgroup_extended_types_float16" );
8757
+ __intrinsic_asm " subgroupBroadcastFirst($0)" ;
8720
8758
case hlsl: __intrinsic_asm " WaveReadLaneFirst" ;
8721
8759
case spirv:
8722
8760
return spirv_asm {OpCapability GroupNonUniformBallot; OpGroupNonUniformBroadcastFirst $$T result Subgroup $expr};
@@ -8732,7 +8770,9 @@ vector<T,N> WaveReadLaneFirst(vector<T,N> expr)
8732
8770
{
8733
8771
__target_switch
8734
8772
{
8735
- case glsl: __intrinsic_asm " subgroupBroadcastFirst($0)" ;
8773
+ case glsl:
8774
+ if (__isHalf< T> ()) __requireGLSLExtension(" GL_EXT_shader_subgroup_extended_types_float16" );
8775
+ __intrinsic_asm " subgroupBroadcastFirst($0)" ;
8736
8776
case hlsl: __intrinsic_asm " WaveReadLaneFirst" ;
8737
8777
case spirv:
8738
8778
return spirv_asm {OpCapability GroupNonUniformBallot; OpGroupNonUniformBroadcastFirst $$vector< T,N> result Subgroup $expr};
@@ -8761,7 +8801,9 @@ T WaveBroadcastLaneAt(T value, constexpr int lane)
8761
8801
{
8762
8802
__target_switch
8763
8803
{
8764
- case glsl: __intrinsic_asm " subgroupBroadcast($0, $1)" ;
8804
+ case glsl:
8805
+ if (__isHalf< T> ()) __requireGLSLExtension(" GL_EXT_shader_subgroup_extended_types_float16" );
8806
+ __intrinsic_asm " subgroupBroadcast($0, $1)" ;
8765
8807
case hlsl: __intrinsic_asm " WaveReadLaneAt" ;
8766
8808
case spirv:
8767
8809
let ulane = uint (lane);
@@ -8778,7 +8820,9 @@ vector<T,N> WaveBroadcastLaneAt(vector<T,N> value, constexpr int lane)
8778
8820
{
8779
8821
__target_switch
8780
8822
{
8781
- case glsl: __intrinsic_asm " subgroupBroadcast($0, $1)" ;
8823
+ case glsl:
8824
+ if (__isHalf< T> ()) __requireGLSLExtension(" GL_EXT_shader_subgroup_extended_types_float16" );
8825
+ __intrinsic_asm " subgroupBroadcast($0, $1)" ;
8782
8826
case hlsl: __intrinsic_asm " WaveReadLaneAt" ;
8783
8827
case spirv:
8784
8828
let ulane = uint (lane);
@@ -8805,7 +8849,9 @@ T WaveReadLaneAt(T value, int lane)
8805
8849
{
8806
8850
__target_switch
8807
8851
{
8808
- case glsl: __intrinsic_asm " subgroupShuffle($0, $1)" ;
8852
+ case glsl:
8853
+ if (__isHalf< T> ()) __requireGLSLExtension(" GL_EXT_shader_subgroup_extended_types_float16" );
8854
+ __intrinsic_asm " subgroupShuffle($0, $1)" ;
8809
8855
case hlsl: __intrinsic_asm " WaveReadLaneAt" ;
8810
8856
case spirv:
8811
8857
let ulane = uint (lane);
@@ -8822,7 +8868,9 @@ vector<T,N> WaveReadLaneAt(vector<T,N> value, int lane)
8822
8868
{
8823
8869
__target_switch
8824
8870
{
8825
- case glsl: __intrinsic_asm " subgroupShuffle($0, $1)" ;
8871
+ case glsl:
8872
+ if (__isHalf< T> ()) __requireGLSLExtension(" GL_EXT_shader_subgroup_extended_types_float16" );
8873
+ __intrinsic_asm " subgroupShuffle($0, $1)" ;
8826
8874
case hlsl: __intrinsic_asm " WaveReadLaneAt" ;
8827
8875
case spirv:
8828
8876
let ulane = uint (lane);
@@ -8850,7 +8898,9 @@ T WaveShuffle(T value, int lane)
8850
8898
{
8851
8899
__target_switch
8852
8900
{
8853
- case glsl: __intrinsic_asm " subgroupShuffle($0, $1)" ;
8901
+ case glsl:
8902
+ if (__isHalf< T> ()) __requireGLSLExtension(" GL_EXT_shader_subgroup_extended_types_float16" );
8903
+ __intrinsic_asm " subgroupShuffle($0, $1)" ;
8854
8904
case hlsl: __intrinsic_asm " WaveReadLaneAt" ;
8855
8905
case spirv:
8856
8906
let ulane = uint (lane);
@@ -8867,7 +8917,9 @@ vector<T,N> WaveShuffle(vector<T,N> value, int lane)
8867
8917
{
8868
8918
__target_switch
8869
8919
{
8870
- case glsl: __intrinsic_asm " subgroupShuffle($0, $1)" ;
8920
+ case glsl:
8921
+ if (__isHalf< T> ()) __requireGLSLExtension(" GL_EXT_shader_subgroup_extended_types_float16" );
8922
+ __intrinsic_asm " subgroupShuffle($0, $1)" ;
8871
8923
case hlsl: __intrinsic_asm " WaveReadLaneAt" ;
8872
8924
case spirv:
8873
8925
let ulane = uint (lane);
@@ -8890,7 +8942,8 @@ uint WavePrefixCountBits(bool value)
8890
8942
{
8891
8943
__target_switch
8892
8944
{
8893
- case glsl: __intrinsic_asm " subgroupBallotExclusiveBitCount(subgroupBallot($0))" ;
8945
+ case glsl:
8946
+ __intrinsic_asm " subgroupBallotExclusiveBitCount(subgroupBallot($0))" ;
8894
8947
case hlsl: __intrinsic_asm " WavePrefixCountBits($0)" ;
8895
8948
case spirv:
8896
8949
return spirv_asm
@@ -8910,7 +8963,8 @@ uint4 WaveGetConvergedMulti()
8910
8963
{
8911
8964
__target_switch
8912
8965
{
8913
- case glsl: __intrinsic_asm " subgroupBallot(true)" ;
8966
+ case glsl:
8967
+ __intrinsic_asm " subgroupBallot(true)" ;
8914
8968
case hlsl: __intrinsic_asm " WaveActiveBallot(true)" ;
8915
8969
case cuda: __intrinsic_asm " make_uint4(__activemask(), 0, 0, 0)" ;
8916
8970
case spirv:
0 commit comments