Skip to content

Commit

Permalink
Allow Optional, Tuple and bool to be used in varying input/outp…
Browse files Browse the repository at this point in the history
…ut. (#5889)

* Allow `Optional` and `Tuple` to be used in varying input/output.

* Fix.

* format code

* Fix.

* Fix test.

* Fix.

* enhance test.

* Fix.

* format code

---------

Co-authored-by: slangbot <186143334+slangbot@users.noreply.github.com>
  • Loading branch information
csyonghe and slangbot authored Dec 18, 2024
1 parent 41c627f commit ae04e60
Show file tree
Hide file tree
Showing 8 changed files with 245 additions and 51 deletions.
13 changes: 10 additions & 3 deletions source/slang/slang-emit-spirv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1279,8 +1279,10 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
case AddressSpace::Uniform:
return SpvStorageClassUniform;
case AddressSpace::Input:
case AddressSpace::BuiltinInput:
return SpvStorageClassInput;
case AddressSpace::Output:
case AddressSpace::BuiltinOutput:
return SpvStorageClassOutput;
case AddressSpace::TaskPayloadWorkgroup:
return SpvStorageClassTaskPayloadWorkgroupEXT;
Expand Down Expand Up @@ -2688,7 +2690,10 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
IRBuilder builder(spvAsmBuiltinVar);
builder.setInsertBefore(spvAsmBuiltinVar);
auto varInst = getBuiltinGlobalVar(
builder.getPtrType(kIROp_PtrType, spvAsmBuiltinVar->getDataType(), AddressSpace::Input),
builder.getPtrType(
kIROp_PtrType,
spvAsmBuiltinVar->getDataType(),
AddressSpace::BuiltinInput),
kind,
spvAsmBuiltinVar);
registerInst(spvAsmBuiltinVar, varInst);
Expand Down Expand Up @@ -4214,7 +4219,9 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
{
auto addrSpace = ptrType->getAddressSpace();
if (addrSpace != AddressSpace::Input &&
addrSpace != AddressSpace::Output)
addrSpace != AddressSpace::Output &&
addrSpace != AddressSpace::BuiltinInput &&
addrSpace != AddressSpace::BuiltinOutput)
continue;
}
}
Expand Down Expand Up @@ -4995,7 +5002,7 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
if (!ptrType)
return;
auto addrSpace = ptrType->getAddressSpace();
if (addrSpace == AddressSpace::Input)
if (addrSpace == AddressSpace::Input || addrSpace == AddressSpace::BuiltinInput)
{
if (isIntegralScalarOrCompositeType(ptrType->getValueType()))
{
Expand Down
150 changes: 106 additions & 44 deletions source/slang/slang-ir-lower-buffer-element-type.cpp

Large diffs are not rendered by default.

17 changes: 15 additions & 2 deletions source/slang/slang-ir-spirv-legalize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,8 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
// Skip load's for referenced `Input` variables since a ref implies
// passing as is, which needs to be a pointer (pass as is).
if (user->getDataType() && user->getDataType()->getOp() == kIROp_RefType &&
addressSpace == AddressSpace::Input)
(addressSpace == AddressSpace::Input ||
addressSpace == AddressSpace::BuiltinInput))
{
builder.replaceOperand(use, addr);
continue;
Expand Down Expand Up @@ -431,7 +432,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
String semanticName = systemValueAttr->getName();
semanticName = semanticName.toLower();
if (semanticName == "sv_pointsize")
addressSpace = AddressSpace::Input;
addressSpace = AddressSpace::BuiltinInput;
}
}

Expand Down Expand Up @@ -661,6 +662,18 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
"resolve a storage class address space.");
}
}

switch (result)
{
case AddressSpace::Input:
if (varLayout->findSystemValueSemanticAttr())
result = AddressSpace::BuiltinInput;
break;
case AddressSpace::Output:
if (varLayout->findSystemValueSemanticAttr())
result = AddressSpace::BuiltinOutput;
break;
}
return result;
}

Expand Down
53 changes: 52 additions & 1 deletion source/slang/slang-parameter-binding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2262,12 +2262,63 @@ static RefPtr<TypeLayout> processEntryPointVaryingParameter(

return ptrTypeLayout;
}
else if (auto optionalType = as<OptionalType>(type))
{
Array<Type*, 2> types =
makeArray(optionalType->getValueType(), context->getASTBuilder()->getBoolType());
auto tupleType = context->getASTBuilder()->getTupleType(types.getView());
return processEntryPointVaryingParameter(context, tupleType, state, varLayout);
}
else if (auto tupleType = as<TupleType>(type))
{
RefPtr<StructTypeLayout> structLayout = new StructTypeLayout();
structLayout->type = type;
for (Index i = 0; i < tupleType->getMemberCount(); i++)
{
auto fieldType = tupleType->getMember(i);
RefPtr<VarLayout> fieldVarLayout = new VarLayout();

// We don't really have a "field" decl, so just use the tuple-typed decl
// itself as the varDecl of the elements.
auto fieldDecl = (VarDeclBase*)varLayout->varDecl.getDecl();
fieldVarLayout->varDecl = fieldDecl;

structLayout->fields.add(fieldVarLayout);

auto fieldTypeLayout = processEntryPointVaryingParameterDecl(
context,
fieldDecl,
fieldType,
state,
fieldVarLayout);

if (!fieldTypeLayout)
{
getSink(context)->diagnose(
varLayout->varDecl,
Diagnostics::notValidVaryingParameter,
fieldType);
continue;
}
fieldVarLayout->typeLayout = fieldTypeLayout;

// Assign offsets in var layout for each resource kind of the type.
for (auto fieldTypeResInfo : fieldTypeLayout->resourceInfos)
{
auto kind = fieldTypeResInfo.kind;
auto structTypeResInfo = structLayout->findOrAddResourceInfo(kind);
auto fieldResInfo = fieldVarLayout->findOrAddResourceInfo(kind);
fieldResInfo->index = structTypeResInfo->count.getFiniteValue();
structTypeResInfo->count += fieldTypeResInfo.count;
}
}
return structLayout;
}
// Catch declaration-reference types late in the sequence, since
// otherwise they will include all of the above cases...
else if (auto declRefType = as<DeclRefType>(type))
{
auto declRef = declRefType->getDeclRef();

if (auto structDeclRef = declRef.as<StructDecl>())
{
RefPtr<StructTypeLayout> structLayout = new StructTypeLayout();
Expand Down
10 changes: 9 additions & 1 deletion source/slang/slang-type-layout.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4785,10 +4785,18 @@ static TypeLayoutResult _createTypeLayout(TypeLayoutContext& context, Type* type
type,
rules);
}
else if (auto optionalType = as<OptionalType>(type))
{
// OptionalType should be laid out the same way as Tuple<T, bool>.
Array<Type*, 2> types =
makeArray(optionalType->getValueType(), context.astBuilder->getBoolType());
auto tupleType = context.astBuilder->getTupleType(types.getView());
return _createTypeLayout(context, tupleType);
}
else if (auto tupleType = as<TupleType>(type))
{
// A `Tuple` type is laid out exactly the same way as a `struct` type,
// except that we want have a declref to the field.
// except that we won't have a declref to the field.

StructTypeLayoutBuilder typeLayoutBuilder;
StructTypeLayoutBuilder pendingDataTypeLayoutBuilder;
Expand Down
4 changes: 4 additions & 0 deletions source/slang/slang-type-system-shared.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,12 @@ enum class AddressSpace : uint64_t
MetalObjectData,
// Corresponds to SPIR-V's SpvStorageClassInput
Input,
// Same as `Input`, but used for builtin input variables.
BuiltinInput,
// Corresponds to SPIR-V's SpvStorageClassOutput
Output,
// Same as `Output`, but used for builtin output variables.
BuiltinOutput,
// Corresponds to SPIR-V's SpvStorageClassTaskPayloadWorkgroupEXT
TaskPayloadWorkgroup,
// Corresponds to SPIR-V's SpvStorageClassFunction
Expand Down
23 changes: 23 additions & 0 deletions tests/spirv/matrix-vertex-input.slang
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
//TEST:SIMPLE(filecheck=CHECK): -target spirv
// CHECK: OpVectorTimesMatrix

struct Vertex
{
float4x4 m;
float4 pos;
}

struct VertexOut
{
float4 pos : SV_Position;
float4 color;
}

[shader("vertex")]
VertexOut vertMain(Vertex v)
{
VertexOut o;
o.pos = mul(v.m, v.pos);
o.color = v.pos;
return o;
}
26 changes: 26 additions & 0 deletions tests/spirv/optional-vertex-output.slang
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
//TEST:SIMPLE(filecheck=CHECK): -target spirv

// Test that we can use Optional<T> or bool types in varying input or outputs.

// CHECK: OpDecorate %i_inA_value Location 0
// CHECK: OpDecorate %i_inA_hasValue Location 1
// CHECK: OpDecorate %entryPointParam_vertMain_a_value Location 0
// CHECK: OpDecorate %entryPointParam_vertMain_a_hasValue Location 1

struct VIn {
Optional<float> inA;
}

struct VSOut {
Optional<float> a;
bool outputValues[3];
};

[shader("vertex")]
VSOut vertMain(VIn i)
{
VSOut o;
o.a = i.inA;
o.outputValues = { true, false, true };
return o;
}

0 comments on commit ae04e60

Please sign in to comment.