Skip to content

Allow LHS of where to be any type. #6333

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Feb 12, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions source/slang/hlsl.meta.slang
Original file line number Diff line number Diff line change
Expand Up @@ -524,6 +524,13 @@ interface ITexelElement
__init(Element x);
}

extension<T:__BuiltinArithmeticType> T : ITexelElement
{
typealias Element = T;
static const int elementCount = 1;
__intrinsic_op(0) __init(Element x);
}

${{{
// Scalar types that can be used as texel element.
const char* texeElementScalarTypes[] = {
Expand All @@ -539,12 +546,6 @@ const char* texeElementScalarTypes[] = {
for (auto elementType : texeElementScalarTypes)
{
}}}
extension $(elementType) : ITexelElement
{
typealias Element = $(elementType);
static const int elementCount = 1;
__intrinsic_op(0) __init(Element x);
}
extension<int N> vector<$(elementType), N> : ITexelElement
{
typealias Element = $(elementType);
Expand Down
133 changes: 90 additions & 43 deletions source/slang/slang-check-decl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9789,6 +9789,46 @@ void SemanticsVisitor::validateArraySizeForVariable(VarDeclBase* varDecl)
}
}

bool getExtensionTargetDeclList(
ASTBuilder* astBuilder,
DeclRefType* targetDeclRefType,
ExtensionDecl* extDecl,
ShortList<AggTypeDecl*>& targetDecls)
{
if (auto aggTypeDeclRef = targetDeclRefType->getDeclRef().as<AggTypeDecl>())
{
auto aggTypeDecl = aggTypeDeclRef.getDecl();

targetDecls.add(aggTypeDecl);
return true;
}

auto genericParamDeclRef = targetDeclRefType->getDeclRef().as<GenericTypeParamDeclBase>();
if (!genericParamDeclRef)
return false;

auto genericParent = as<GenericDecl>(genericParamDeclRef.getParent().getDecl());
if (!genericParent)
return false;

if (genericParent != extDecl->parentDecl)
return false;

for (auto member : getMembersOfType<GenericTypeConstraintDecl>(astBuilder, genericParent))
{
if (getSub(astBuilder, member) == targetDeclRefType)
{
auto baseType = getSup(astBuilder, member);
if (auto baseTypeDecl = isDeclRefTypeOf<AggTypeDecl>(baseType))
{
targetDecls.add(baseTypeDecl.getDecl());
}
}
}
return targetDecls.getCount() != 0;
}


void SemanticsDeclBasesVisitor::_validateExtensionDeclTargetType(ExtensionDecl* decl)
{
if (auto targetDeclRefType = as<DeclRefType>(decl->targetType))
Expand Down Expand Up @@ -11582,8 +11622,8 @@ void checkDerivativeAttributeImpl(
auto derivativeFuncThisType = getTypeForThisExpr(visitor, calleeFuncDeclRef);

// If the function is a member function, we need to check that the
// `this` type matches the expected type. This will ensure that after lowering to
// IR, the two functions are compatible.
// `this` type matches the expected type. This will ensure that after lowering
// to IR, the two functions are compatible.
//
if (!areTypesCompatibile(visitor, funcThisType, derivativeFuncThisType))
{
Expand Down Expand Up @@ -11971,8 +12011,9 @@ void checkDerivativeOfAttributeImpl(

if (as<ErrorType>(resolved->type.type))
{
// If we can't resolve a type, something went wrong. If we're working with a generic
// decl, the most likely cause is a failure of generic argument inference.
// If we can't resolve a type, something went wrong. If we're working with a
// generic decl, the most likely cause is a failure of generic argument
// inference.
//
visitor->getSink()->diagnose(
derivativeOfAttr,
Expand Down Expand Up @@ -12284,8 +12325,8 @@ bool SemanticsDeclAttributesVisitor::collectInitializableMembers(
// Find the base type's members first
for (auto inheritanceMember : structDecl->getMembersOfType<InheritanceDecl>())
{
// For base types, we need to pick their parameters of the constructor to the derived type's
// constructor
// For base types, we need to pick their parameters of the constructor to the derived
// type's constructor
if (auto baseTypeDeclRef = isDeclRefTypeOf<StructDecl>(inheritanceMember->base.type))
{
// We should only find the member initialization constructor because it is the
Expand All @@ -12294,15 +12335,15 @@ bool SemanticsDeclAttributesVisitor::collectInitializableMembers(
baseTypeDeclRef.getDecl(),
ConstructorDecl::ConstructorFlavor::SynthesizedMemberInit);

// The constructor has to have higher or equal visibility level than the struct itself,
// otherwise, it's not accessible so we will not pick up.
// The constructor has to have higher or equal visibility level than the struct
// itself, otherwise, it's not accessible so we will not pick up.
if (ctor && getDeclVisibility(ctor) >= ctorVisibility)
{
for (ParamDecl* param : ctor->getParameters())
{
// Because the parameters in the ctor must have the higher or equal visibility
// than the ctor itself, we don't need to check the visibility level of the
// parameter.
// Because the parameters in the ctor must have the higher or equal
// visibility than the ctor itself, we don't need to check the visibility
// level of the parameter.
resultMembers.add(param);
}
}
Expand Down Expand Up @@ -12342,10 +12383,9 @@ static Expr* _getParamDefaultValue(SemanticsVisitor* visitor, VarDeclBase* varDe

bool SemanticsDeclAttributesVisitor::_synthesizeCtorSignature(StructDecl* structDecl)
{
// If a type or its base type already defines any explicit constructors, do not synthesize any
// constructors.
// See
// https://github.com/shader-slang/spec/blob/main/proposals/004-initialization.md#inheritance-initialization
// If a type or its base type already defines any explicit constructors, do not synthesize
// any constructors. see:
// https://github.com/shader-slang/slang/blob/master/docs/proposals/004-initialization.md#inheritance-initialization
if (_hasExplicitConstructor(structDecl, true))
return false;

Expand Down Expand Up @@ -12397,9 +12437,9 @@ bool SemanticsDeclAttributesVisitor::_synthesizeCtorSignature(StructDecl* struct
ctorParam->loc = ctor->loc;
ctor->members.add(ctorParam);

// We need to ensure member is `no_diff` if it cannot be differentiated, `ctor` modifiers do
// not matter in this case since member-wise ctor is always differentiable or "treat as
// differentiable".
// We need to ensure member is `no_diff` if it cannot be differentiated, `ctor`
// modifiers do not matter in this case since member-wise ctor is always differentiable
// or "treat as differentiable".
if (!isTypeDifferentiable(member->getType()) || member->hasModifier<NoDiffModifier>())
{
auto noDiffMod = m_astBuilder->create<NoDiffModifier>();
Expand Down Expand Up @@ -12559,7 +12599,8 @@ void SemanticsDeclAttributesVisitor::visitStructDecl(StructDecl* structDecl)
totalWidth += int(thisFieldWidth);
groupInfo.add({memberIndex, int(thisFieldWidth), t, bfm});
}
// If the struct ended with a bitpacked member, then make sure we don't forget the last group
// If the struct ended with a bitpacked member, then make sure we don't forget the last
// group
dispatchSomeBitPackedMembers();
}

Expand Down Expand Up @@ -12630,8 +12671,8 @@ static void _propagateRequirement(
if (!isAnyInvalid && resultCaps.isInvalid())
{
// If joining the referenced decl's requirements results an invalid capability set,
// then the decl is using things that require conflicting set of capabilities, and we should
// diagnose an error.
// then the decl is using things that require conflicting set of capabilities, and we
// should diagnose an error.
if (referencedDecl && decl)
{
maybeDiagnose(
Expand Down Expand Up @@ -12736,17 +12777,17 @@ struct CapabilityDeclReferenceVisitor
// `calling_functions_targets`:
// ``` default_target = calling_functions_targets-{other_case_targets} ```
//
// * `calling_functions_capability` = `requirement attribute` of the calling function;
// if missing
// * `calling_functions_capability` = `requirement attribute` of the calling
// function; if missing
// we can assume it is `any_target`
//
// * `{other_case_targets}` = set of all capabilities all `case` statments target inside
// the `__target_switch`
// * `{other_case_targets}` = set of all capabilities all `case` statments target
// inside the `__target_switch`

// If we do not handle `default:`, the codegen will fail when trying to find a specific
// codegen target not handled explicitly by a `case` statment.
// We must also ensure the `default` case is last so we have priority to hit `case`
// statments and can preprocess `case` statments before the `default` case.
// If we do not handle `default:`, the codegen will fail when trying to find a
// specific codegen target not handled explicitly by a `case` statment. We must also
// ensure the `default` case is last so we have priority to hit `case` statments and
// can preprocess `case` statments before the `default` case.
CapabilitySet targetCap;
if (CapabilityName(stmt->targetCases[targetCaseIndex]->capability) ==
CapabilityName::Invalid)
Expand Down Expand Up @@ -12901,7 +12942,8 @@ CapabilitySet SemanticsDeclCapabilityVisitor::getDeclaredCapabilitySet(Decl* dec
// For every existing target, we want to join their requirements together.
// If the the parent defines additional targets, we want to add them to the disjunction set.
// For example:
// [require(glsl)] struct Parent { [require(glsl, glsl_ext_1)] [require(spirv)] void foo(); }
// [require(glsl)] struct Parent { [require(glsl, glsl_ext_1)] [require(spirv)] void
// foo(); }
// The requirement for `foo` should be glsl+glsl_ext_1 | spirv.
//
CapabilitySet declaredCaps;
Expand Down Expand Up @@ -12990,8 +13032,8 @@ static inline void _dispatchCapabilitiesVisitorOfFunctionDecl(

void SemanticsDeclCapabilityVisitor::visitFunctionDeclBase(FunctionDeclBase* funcDecl)
{
// If the function is an entrypoint and specifies a target stage, add the capabilities to our
// function capabilities.
// If the function is an entrypoint and specifies a target stage, add the capabilities to
// our function capabilities.
_dispatchCapabilitiesVisitorOfFunctionDecl(
this,
funcDecl,
Expand All @@ -13012,8 +13054,8 @@ void SemanticsDeclCapabilityVisitor::visitFunctionDeclBase(FunctionDeclBase* fun

auto vis = getDeclVisibility(funcDecl);

// If 0 capabilities were annotated on a function, capabilities are inferred from the function
// body
// If 0 capabilities were annotated on a function, capabilities are inferred from the
// function body
if (declaredCaps.isEmpty())
{
declaredCaps = funcDecl->inferredCapabilityRequirements;
Expand Down Expand Up @@ -13133,7 +13175,8 @@ DeclVisibility getDeclVisibility(Decl* decl)
: parentModule->defaultVisibility;
}

// Members of other agg type decls will have their default visibility capped to the parents'.
// Members of other agg type decls will have their default visibility capped to the
// parents'.
if (as<NamespaceDecl>(decl))
{
return DeclVisibility::Public;
Expand Down Expand Up @@ -13345,15 +13388,16 @@ void SemanticsDeclCapabilityVisitor::diagnoseUndeclaredCapability(

// There are two causes for why type checking failed on failedAvailableSet.
// The first scenario is that failedAvailableSet defines a set of capabilities on a
// compilation target (e.g. hlsl) that isn't defined by some callees, for example, if we have
// a function:
// compilation target (e.g. hlsl) that isn't defined by some callees, for example, if we
// have a function:
// [require(hlsl)] // <-- failedAvailableSet
// [require(cpp)]
// void caller()
// {
// printf(); // assume this is defined for (cpp | cuda).
// }
// In this case we should diagnose error reporting printf isn't defined on a required target.
// In this case we should diagnose error reporting printf isn't defined on a required
// target.
//
// Now, we detect if we are case 1.

Expand All @@ -13370,8 +13414,8 @@ void SemanticsDeclCapabilityVisitor::diagnoseUndeclaredCapability(
decl,
outFailedAtom);

// Anything defined on a non-failed target atom may be the culprit to why we fail having
// a target capability. Print out all possible culprits.
// Anything defined on a non-failed target atom may be the culprit to why we fail
// having a target capability. Print out all possible culprits.
CapabilityAtomSet failedAtomSet;
failedAtomSet.add((UInt)outFailedAtom);
CapabilityAtomSet targetsNotUsedSet;
Expand All @@ -13395,9 +13439,12 @@ void SemanticsDeclCapabilityVisitor::diagnoseUndeclaredCapability(
}
}

//// The second scenario is when the callee is using a capability that is not provided by the
/// requirement. / For example: / [require(hlsl,b,c)] / void caller() / { / useD();
///// require capability (hlsl,d) / } / In this case we should report that useD() is using a
//// The second scenario is when the callee is using a capability that is not provided by
/// the
/// requirement. / For example: / [require(hlsl,b,c)] / void caller() / { /
/// useD();
///// require capability (hlsl,d) / } / In this case we should report that useD() is
/// using a
/// capability that is not declared by caller.
////

Expand Down
13 changes: 13 additions & 0 deletions source/slang/slang-check-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -3066,4 +3066,17 @@ bool resolveStageOfProfileWithEntryPoint(
const List<RefPtr<TargetRequest>>& targets,
FuncDecl* entryPointFuncDecl,
DiagnosticSink* sink);

// For an extensions decl, collect a list of decls on which the extension might be applying to.
// For example, if we see a `extension Foo`, return a `Decl*` that represents `struct Foo`.
// In the case of free-form generic extensions i.e. `extension<T:IFoo> T : IBar`, return `IFoo`.
// These are the decls that we need to register the extension with in
// `mapTypeToCandidateExtensions`.
// Returns true when any base decls are found.
bool getExtensionTargetDeclList(
ASTBuilder* astBuilder,
DeclRefType* targetDeclRefType,
ExtensionDecl* extDeclRef,
ShortList<AggTypeDecl*>& targetDecls);

} // namespace Slang
2 changes: 1 addition & 1 deletion source/slang/slang-ir-hlsl-legalize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ void searchChildrenForForceVarIntoStructTemporarily(IRModule* module, IRInst* in
continue;
auto forceStructArg = arg->getOperand(0);
auto forceStructBaseType =
as<IRType>(forceStructArg->getDataType()->getOperand(0));
(IRType*)(forceStructArg->getDataType()->getOperand(0));
IRBuilder builder(call);
if (forceStructBaseType->getOp() == kIROp_StructType)
{
Expand Down
2 changes: 1 addition & 1 deletion source/slang/slang-ir-insts.h
Original file line number Diff line number Diff line change
Expand Up @@ -1143,7 +1143,7 @@ struct IRMixedDifferentialInstDecoration : IRAutodiffInstDecoration
IRUse pairType;
IR_LEAF_ISA(MixedDifferentialInstDecoration)

IRType* getPairType() { return as<IRType>(getOperand(0)); }
IRType* getPairType() { return (IRType*)(getOperand(0)); }
};

struct IRRecomputeBlockDecoration : IRAutodiffInstDecoration
Expand Down
4 changes: 2 additions & 2 deletions source/slang/slang-ir-link.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -614,8 +614,8 @@ IRWitnessTable* cloneWitnessTableImpl(
IRWitnessTable* clonedTable = dstTable;
if (!clonedTable)
{
auto clonedBaseType = cloneType(context, as<IRType>(originalTable->getConformanceType()));
auto clonedSubType = cloneType(context, as<IRType>(originalTable->getConcreteType()));
auto clonedBaseType = cloneType(context, (IRType*)(originalTable->getConformanceType()));
auto clonedSubType = cloneType(context, (IRType*)(originalTable->getConcreteType()));
clonedTable = builder->createWitnessTable(clonedBaseType, clonedSubType);
}
cloneSimpleGlobalValueImpl(context, originalTable, originalValues, clonedTable, registerValue);
Expand Down
2 changes: 1 addition & 1 deletion source/slang/slang-ir-use-uninitialized-values.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ static bool canIgnoreType(IRType* type, IRType* upper)
if (auto spec = as<IRSpecialize>(type))
{
IRInst* inner = getResolvedInstForDecorations(spec);
IRType* innerType = as<IRType>(inner);
IRType* innerType = (IRType*)(inner);
return canIgnoreType(innerType, upper);
}

Expand Down
6 changes: 3 additions & 3 deletions source/slang/slang-lower-to-ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9453,10 +9453,10 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
// should handle propgation of value-size information from constraints
// back to generic parameters?
//
if (auto declRefType = as<DeclRefType>(constraintDecl->sub.type))
if (auto genParamDeclRef =
isDeclRefTypeOf<GenericTypeParamDeclBase>(constraintDecl->sub.type))
{
auto typeParamDeclVal =
subContext->findLoweredDecl(declRefType->getDeclRef().getDecl());
auto typeParamDeclVal = subContext->findLoweredDecl(genParamDeclRef.getDecl());
SLANG_ASSERT(typeParamDeclVal && typeParamDeclVal->val);
subBuilder->addTypeConstraintDecoration(typeParamDeclVal->val, supType);
}
Expand Down
16 changes: 9 additions & 7 deletions source/slang/slang-serialize-container.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include "../core/slang-math.h"
#include "../core/slang-stream.h"
#include "../core/slang-text-io.h"
#include "slang-check-impl.h"
#include "slang-compiler.h"
#include "slang-mangled-lexer.h"
#include "slang-parser.h"
Expand Down Expand Up @@ -813,15 +814,16 @@ static List<ExtensionDecl*>& _getCandidateExtensionList(
if (auto targetDeclRefType =
as<DeclRefType>(extensionDecl->targetType))
{
// Attach our extension to that type as a candidate...
if (auto aggTypeDeclRef =
targetDeclRefType->getDeclRef()
.as<AggTypeDecl>())
ShortList<AggTypeDecl*> baseDecls;
getExtensionTargetDeclList(
astBuilder,
targetDeclRefType,
extensionDecl,
baseDecls);
for (auto baseDecl : baseDecls)
{
auto aggTypeDecl = aggTypeDeclRef.getDecl();

_getCandidateExtensionList(
aggTypeDecl,
baseDecl,
moduleDecl->mapTypeToCandidateExtensions)
.add(extensionDecl);
}
Expand Down
Loading
Loading