Skip to content

Commit 8da47c4

Browse files
Added basic auto-diff capabilities for local load/store and simple arithmetic. Also added type-checking during the semantic stage. (shader-slang#2303)
* Added JVPTranscriber to handle differentiation of load, store, var, param and return instructions, as well as conversion of data and function types * Changed class names to be more in line with convention. Added correct type checking for __jvp() and verified that simple calls with only loads and stores are processed correctly * Added logic to differentiate basic arithmetic and literals inside IRConstruct and fixed the way parameters are differentiated Co-authored-by: Yong He <yonghe@outlook.com>
1 parent 0229784 commit 8da47c4

12 files changed

+453
-63
lines changed

source/slang/slang-ast-expr.h

+3-3
Original file line numberDiff line numberDiff line change
@@ -372,10 +372,10 @@ class ExtractExistentialValueExpr: public Expr
372372
/// An expression of the form `__jvp(fn)` to access the
373373
/// forward-mode derivative version of the function `fn`
374374
///
375-
class JVPDerivativeOfExpr: public Expr
375+
class JVPDifferentiateExpr: public Expr
376376
{
377-
SLANG_AST_CLASS(JVPDerivativeOfExpr)
378-
Expr* baseFn;
377+
SLANG_AST_CLASS(JVPDifferentiateExpr)
378+
Expr* baseFunction;
379379
};
380380

381381
/// A type expression of the form `__TaggedUnion(A, ...)`.

source/slang/slang-check-expr.cpp

+35-5
Original file line numberDiff line numberDiff line change
@@ -1509,16 +1509,46 @@ namespace Slang
15091509
return expr;
15101510
}
15111511

1512-
Expr* SemanticsExprVisitor::visitJVPDerivativeOfExpr(JVPDerivativeOfExpr* expr)
1512+
Expr* SemanticsExprVisitor::visitJVPDifferentiateExpr(JVPDifferentiateExpr* expr)
15131513
{
15141514
// Check/Resolve inner function declaration.
1515-
expr->baseFn = CheckTerm(expr->baseFn);
1515+
expr->baseFunction = CheckTerm(expr->baseFunction);
15161516

1517-
if(auto funcType = as<FuncType>(expr->baseFn->type))
1517+
if(auto primalType = as<FuncType>(expr->baseFunction->type))
15181518
{
15191519
// Resolve JVP type here.
1520-
// Temporarily resolving to the same type as the original function.
1521-
expr->type = expr->baseFn->type;
1520+
// Note that this type checking needs to be in sync with
1521+
// the auto-generation logic in slang-ir-jvp-diff.cpp
1522+
1523+
auto astBuilder = this->getASTBuilder();
1524+
FuncType* jvpType = astBuilder->create<FuncType>();
1525+
1526+
// Only float types can be differentiated for now.
1527+
1528+
// The JVP return type is float if primal return type is float
1529+
// void otherwise.
1530+
//
1531+
if (primalType->resultType->equals(astBuilder->getFloatType()))
1532+
jvpType->resultType = astBuilder->getFloatType();
1533+
else
1534+
jvpType->resultType = astBuilder->getVoidType();
1535+
1536+
// No support for differentiating function that throw errors, for now.
1537+
SLANG_ASSERT(primalType->errorType->equals(astBuilder->getBottomType()));
1538+
jvpType->errorType = primalType->errorType;
1539+
1540+
for (UInt i = 0; i < primalType->getParamCount(); i++)
1541+
{
1542+
jvpType->paramTypes.add(primalType->getParamType(i));
1543+
}
1544+
1545+
for (UInt i = 0; i < primalType->getParamCount(); i++)
1546+
{
1547+
if(primalType->getParamType(i)->equals(astBuilder->getFloatType()))
1548+
jvpType->paramTypes.add(astBuilder->getFloatType());
1549+
}
1550+
1551+
expr->type = jvpType;
15221552
}
15231553
else
15241554
{

source/slang/slang-check-impl.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -1732,7 +1732,7 @@ namespace Slang
17321732
Expr* visitAndTypeExpr(AndTypeExpr* expr);
17331733
Expr* visitModifiedTypeExpr(ModifiedTypeExpr* expr);
17341734

1735-
Expr* visitJVPDerivativeOfExpr(JVPDerivativeOfExpr* expr);
1735+
Expr* visitJVPDifferentiateExpr(JVPDifferentiateExpr* expr);
17361736

17371737
/// Perform semantic checking on a `modifier` that is being applied to the given `type`
17381738
Val* checkTypeModifier(Modifier* modifier, Type* type);

source/slang/slang-ir-diff-call.cpp

+4-4
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,10 @@ struct DerivativeCallProcessContext
3434
do
3535
{
3636
auto nextChild = child->getNextInst();
37-
// Look for IRJVPDerivativeOf
38-
if (auto derivOf = as<IRJVPDerivativeOf>(child))
37+
// Look for IRJVPDifferentiate
38+
if (auto derivOf = as<IRJVPDifferentiate>(child))
3939
{
40-
processDerivativeOf(derivOf);
40+
processDifferentiate(derivOf);
4141
}
4242
child = nextChild;
4343
}
@@ -50,7 +50,7 @@ struct DerivativeCallProcessContext
5050

5151
// Perform forward-mode automatic differentiation on
5252
// the intstructions.
53-
void processDerivativeOf(IRJVPDerivativeOf* derivOfInst)
53+
void processDifferentiate(IRJVPDifferentiate* derivOfInst)
5454
{
5555
IRFunc* jvpFunc = nullptr;
5656

0 commit comments

Comments
 (0)