Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support IDifferentiablePtrType #5031

Merged
merged 19 commits into from
Sep 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions source/slang/core.meta.slang
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,13 @@ interface IDifferentiable
static Differential dmul(T, Differential);
};

__magic_type(DifferentiablePtrType)
interface IDifferentiablePtrType
{
__builtin_requirement($( (int)BuiltinRequirementKind::DifferentialPtrType) )
associatedtype Differential : IDifferentiablePtrType;
};


/// Pair type that serves to wrap the primal and
/// differential types of an arbitrary type T.
Expand Down Expand Up @@ -357,6 +364,36 @@ struct DifferentialPair : IDifferentiable
}
};

__generic<T : IDifferentiablePtrType>
__magic_type(DifferentialPtrPairType)
__intrinsic_type($(kIROp_DifferentialPtrPairType))
struct DifferentialPtrPair : IDifferentiablePtrType
{
typedef DifferentialPtrPair<T.Differential> Differential;
typedef T.Differential DifferentialElementType;

__intrinsic_op($(kIROp_MakeDifferentialPtrPair))
__init(T _primal, T.Differential _differential);

property p : T
{
__intrinsic_op($(kIROp_DifferentialPtrPairGetPrimal))
get;
}

property v : T
{
__intrinsic_op($(kIROp_DifferentialPtrPairGetPrimal))
get;
}

property d : T.Differential
{
__intrinsic_op($(kIROp_DifferentialPtrPairGetDifferential))
get;
}
};


/// A type that uses a floating-point representation
[sealed]
Expand Down
23 changes: 21 additions & 2 deletions source/slang/slang-ast-builder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -408,18 +408,32 @@ MatrixExpressionType* ASTBuilder::getMatrixType(Type* elementType, IntVal* rowCo

DifferentialPairType* ASTBuilder::getDifferentialPairType(
Type* valueType,
Witness* primalIsDifferentialWitness)
Witness* diffTypeWitness)
{
Val* args[] = { valueType, primalIsDifferentialWitness };
Val* args[] = { valueType, diffTypeWitness };
return as<DifferentialPairType>(getSpecializedBuiltinType(makeArrayView(args), "DifferentialPairType"));
}

DifferentialPtrPairType* ASTBuilder::getDifferentialPtrPairType(
Type* valueType,
Witness* diffRefTypeWitness)
{
Val* args[] = { valueType, diffRefTypeWitness };
return as<DifferentialPtrPairType>(getSpecializedBuiltinType(makeArrayView(args), "DifferentialPtrPairType"));
}

DeclRef<InterfaceDecl> ASTBuilder::getDifferentiableInterfaceDecl()
{
DeclRef<InterfaceDecl> declRef = DeclRef<InterfaceDecl>(getBuiltinDeclRef("DifferentiableType", nullptr));
return declRef;
}

DeclRef<InterfaceDecl> ASTBuilder::getDifferentiableRefInterfaceDecl()
{
DeclRef<InterfaceDecl> declRef = DeclRef<InterfaceDecl>(getBuiltinDeclRef("DifferentiablePtrType", nullptr));
return declRef;
}

bool ASTBuilder::isDifferentiableInterfaceAvailable()
{
return (m_sharedASTBuilder->tryFindMagicDecl("DifferentiableType") != nullptr);
Expand Down Expand Up @@ -459,6 +473,11 @@ Type* ASTBuilder::getDifferentiableInterfaceType()
return DeclRefType::create(this, getDifferentiableInterfaceDecl());
}

Type* ASTBuilder::getDifferentiableRefInterfaceType()
{
return DeclRefType::create(this, getDifferentiableRefInterfaceDecl());
}

DeclRef<Decl> ASTBuilder::getBuiltinDeclRef(const char* builtinMagicTypeName, Val* genericArg)
{
auto decl = m_sharedASTBuilder->findMagicDecl(builtinMagicTypeName);
Expand Down
9 changes: 8 additions & 1 deletion source/slang/slang-ast-builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -489,10 +489,17 @@ class ASTBuilder : public RefObject

DifferentialPairType* getDifferentialPairType(
Type* valueType,
Witness* primalIsDifferentialWitness);
Witness* diffTypeWitness);

DifferentialPtrPairType* getDifferentialPtrPairType(
Type* valueType,
Witness* diffRefTypeWitness);

DeclRef<InterfaceDecl> getDifferentiableInterfaceDecl();
DeclRef<InterfaceDecl> getDifferentiableRefInterfaceDecl();

Type* getDifferentiableInterfaceType();
Type* getDifferentiableRefInterfaceType();

bool isDifferentiableInterfaceAvailable();

Expand Down
3 changes: 2 additions & 1 deletion source/slang/slang-ast-support-types.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

#include "slang-profile.h"
#include "slang-type-system-shared.h"
#include "slang.h"
#include "../../include/slang.h"

#include "../core/slang-semantic-version.h"

Expand Down Expand Up @@ -1606,6 +1606,7 @@ namespace Slang
DefaultInitializableConstructor, ///< The `IDefaultInitializable.__init()` method

DifferentialType, ///< The `IDifferentiable.Differential` associated type requirement
DifferentialPtrType, ///< The `IDifferentiable.DifferentialPtr` associated type requirement
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can remove DMulFunc down below?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm doing dmul removal in a separate patch

DZeroFunc, ///< The `IDifferentiable.dzero` function requirement
DAddFunc, ///< The `IDifferentiable.dadd` function requirement
DMulFunc, ///< The `IDifferentiable.dmul` function requirement
Expand Down
11 changes: 11 additions & 0 deletions source/slang/slang-ast-type.h
Original file line number Diff line number Diff line change
Expand Up @@ -462,11 +462,22 @@ class DifferentialPairType : public ArithmeticExpressionType
Type* getPrimalType();
};

class DifferentialPtrPairType : public ArithmeticExpressionType
{
SLANG_AST_CLASS(DifferentialPtrPairType)
Type* getPrimalRefType();
};

class DifferentiableType : public BuiltinType
{
SLANG_AST_CLASS(DifferentiableType)
};

class DifferentiablePtrType : public BuiltinType
{
SLANG_AST_CLASS(DifferentiablePtrType)
};

class DefaultInitializableType : public BuiltinType
{
SLANG_AST_CLASS(DefaultInitializableType);
Expand Down
9 changes: 7 additions & 2 deletions source/slang/slang-check-conformance.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -274,9 +274,14 @@ namespace Slang
return isInterfaceType(type);
}

bool SemanticsVisitor::isTypeDifferentiable(Type* type)
SubtypeWitness* SemanticsVisitor::isTypeDifferentiable(Type* type)
{
return isSubtype(type, m_astBuilder->getDiffInterfaceType(), IsSubTypeOptions::None);
if (auto valueWitness = isSubtype(type, m_astBuilder->getDiffInterfaceType(), IsSubTypeOptions::None))
return valueWitness;
else if (auto ptrWitness = isSubtype(type, m_astBuilder->getDifferentiableRefInterfaceType(), IsSubTypeOptions::None))
return ptrWitness;

return nullptr;
}

bool SemanticsVisitor::doesTypeHaveTag(Type* type, TypeTag tag)
Expand Down
8 changes: 7 additions & 1 deletion source/slang/slang-check-decl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10204,7 +10204,8 @@ namespace Slang
bool isDiffParam = (!param->findModifier<NoDiffModifier>());
if (isDiffParam)
{
if (auto pairType = as<DifferentialPairType>(visitor->getDifferentialPairType(param->getType())))
auto diffPair = visitor->getDifferentialPairType(param->getType());
if (auto pairType = as<DifferentialPairType>(diffPair))
{
arg->type.type = pairType;
arg->type.isLeftValue = true;
Expand All @@ -10225,6 +10226,11 @@ namespace Slang
direction = ParameterDirection::kParameterDirection_InOut;
}
}
else if (auto refPairType = as<DifferentialPtrPairType>(diffPair))
{
// no need to change direction of ref-pairs.
arg->type.type = refPairType;
}
else
{
isDiffParam = false;
Expand Down
32 changes: 25 additions & 7 deletions source/slang/slang-check-expr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1131,15 +1131,19 @@ namespace Slang
{
if (auto builtinRequirement = declRefType->getDeclRef().getDecl()->findModifier<BuiltinRequirementModifier>())
{
if (builtinRequirement->kind == BuiltinRequirementKind::DifferentialType)
if (builtinRequirement->kind == BuiltinRequirementKind::DifferentialType
|| builtinRequirement->kind == BuiltinRequirementKind::DifferentialPtrType)
{
// We are trying to get differential type from a differential type.
// The result is itself.
return type;
}
}
type = resolveType(type);
if (const auto witness = as<SubtypeWitness>(tryGetInterfaceConformanceWitness(type, builder->getDifferentiableInterfaceType())))
auto witness = as<SubtypeWitness>(tryGetInterfaceConformanceWitness(type, builder->getDifferentiableInterfaceType()));
if (!witness)
witness = as<SubtypeWitness>(tryGetInterfaceConformanceWitness(type, builder->getDifferentiableRefInterfaceType()));
if (witness)
{
auto diffTypeLookupResult = lookUpMember(
getASTBuilder(),
Expand Down Expand Up @@ -1367,6 +1371,13 @@ namespace Slang
{
addDifferentiableTypeToDiffTypeRegistry((DeclRefType*)type, subtypeWitness);
}

if (auto subtypeWitness = as<SubtypeWitness>(
tryGetInterfaceConformanceWitness(type, getASTBuilder()->getDifferentiableRefInterfaceType())))
{
addDifferentiableTypeToDiffTypeRegistry((DeclRefType*)type, subtypeWitness);
}

if (auto aggTypeDeclRef = declRefType->getDeclRef().as<AggTypeDecl>())
{
foreachDirectOrExtensionMemberOfType<InheritanceDecl>(this, aggTypeDeclRef, [&](DeclRef<InheritanceDecl> member)
Expand Down Expand Up @@ -2899,18 +2910,25 @@ namespace Slang
return m_astBuilder->getExpandType(diffPairEachType, makeArrayViewSingle(primalType));
}
}

// Get a reference to the builtin 'IDifferentiable' interface
auto differentiableInterface = getASTBuilder()->getDifferentiableInterfaceType();
auto differentiableRefInterface = getASTBuilder()->getDifferentiableRefInterfaceType();

auto conformanceWitness = as<Witness>(isSubtype(primalType, differentiableInterface, IsSubTypeOptions::None));
// Check if the provided type inherits from IDifferentiable.
// If not, return the original type.
if (conformanceWitness)
if (auto conformanceWitness = isTypeDifferentiable(primalType))
{
return m_astBuilder->getDifferentialPairType(primalType, conformanceWitness);
if (conformanceWitness->getSup() == differentiableInterface)
{
return m_astBuilder->getDifferentialPairType(primalType, conformanceWitness);
}
else if (conformanceWitness->getSup() == differentiableRefInterface)
{
return m_astBuilder->getDifferentialPtrPairType(primalType, conformanceWitness);
}
}
else
return primalType;
return primalType;
}

Type* SemanticsVisitor::getForwardDiffFuncType(FuncType* originalType)
Expand Down
2 changes: 1 addition & 1 deletion source/slang/slang-check-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -2208,7 +2208,7 @@ namespace Slang

bool isValidGenericConstraintType(Type* type);

bool isTypeDifferentiable(Type* type);
SubtypeWitness* isTypeDifferentiable(Type* type);

bool doesTypeHaveTag(Type* type, TypeTag tag);

Expand Down
Loading
Loading