Skip to content

Commit 4b35064

Browse files
authored
Fix autodiff issue for vector<T, N> (shader-slang#6275)
* Fix autodiff issue for vector<T, N> Close shader-slang#6154 We didn't implement correctly for vector<T, N> regarding the differentiablity. As we check differentiable before specialization, however according to the definition of vector, it has to be specialized to IFloat to know it's conformed to IDifferential type. Therefore for parameter type vector<T, N> will become no_diff. Therefore, we change the implementation a to make it explicit conform to IDifferential type. * fix typo
1 parent d3e5f39 commit 4b35064

File tree

1 file changed

+20
-6
lines changed

1 file changed

+20
-6
lines changed

source/slang/core.meta.slang

+20-6
Original file line numberDiff line numberDiff line change
@@ -1988,8 +1988,18 @@ extension vector<T,N> : IFloat
19881988
[OverloadRank(-1)]
19891989
[__unsafeForceInlineEarly] __init(float v) { this = vector<T,N>(T(v)); }
19901990

1991-
// IDifferentiable
1991+
}
1992+
1993+
__intrinsic_op($(kIROp_Add))
1994+
T __internal_add<T>(T a, T b);
19921995

1996+
__intrinsic_op($(kIROp_Mul))
1997+
T __internal_mul<T, U>(U a, T b);
1998+
1999+
__generic<T:IDifferentiable, let N : int>
2000+
extension vector<T,N> : IDifferentiable
2001+
{
2002+
// IDifferentiable
19932003
typedef vector<T, N> Differential;
19942004

19952005
[__unsafeForceInlineEarly]
@@ -2003,15 +2013,15 @@ extension vector<T,N> : IFloat
20032013
[BackwardDifferentiable]
20042014
static Differential dadd(Differential a, Differential b)
20052015
{
2006-
return a + b;
2016+
return __internal_add(a, b);
20072017
}
20082018

20092019
__generic<U : __BuiltinRealType>
20102020
[__unsafeForceInlineEarly]
20112021
[BackwardDifferentiable]
20122022
static Differential dmul(U a, Differential b)
20132023
{
2014-
return __realCast<T, U>(a) * b;
2024+
return __internal_mul(__realCast<float>(a), b);
20152025
}
20162026
}
20172027

@@ -2042,7 +2052,11 @@ extension matrix<T,N,M,L> : IFloat
20422052
[__unsafeForceInlineEarly]
20432053
__implicit_conversion($(kConversionCost_ScalarToMatrix))
20442054
__init(float v) { this = matrix<T,N,M>(T(v)); }
2055+
}
20452056

2057+
__generic<T:IDifferentiable, let N : int, let M : int, let L : int>
2058+
extension matrix<T,N,M,L> : IDifferentiable
2059+
{
20462060
// IDifferentiable.
20472061
typedef matrix<T, N,M,L> Differential;
20482062

@@ -2057,15 +2071,15 @@ extension matrix<T,N,M,L> : IFloat
20572071
[BackwardDifferentiable]
20582072
static Differential dadd(Differential a, Differential b)
20592073
{
2060-
return a + b;
2074+
return __internal_add(a, b);
20612075
}
2062-
2076+
20632077
__generic<U : __BuiltinRealType>
20642078
[__unsafeForceInlineEarly]
20652079
[BackwardDifferentiable]
20662080
static Differential dmul(U a, Differential b)
20672081
{
2068-
return __realCast<T, U>(a) * b;
2082+
return __internal_mul(__realCast<float>(a), b);
20692083
}
20702084
}
20712085

0 commit comments

Comments
 (0)