Skip to content

Commit

Permalink
Add render lock mechanism (#64)
Browse files Browse the repository at this point in the history
* Update try-slang.js

* Re-structure renderer to fix render-lock issues

* Update try-slang.js

* Add stretch-free UV calculation to splatter
  • Loading branch information
saipraveenb25 authored Nov 20, 2024
1 parent da4c16b commit dfbb318
Show file tree
Hide file tree
Showing 2 changed files with 231 additions and 121 deletions.
75 changes: 55 additions & 20 deletions demos/gsplat2d-diff.slang
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,7 @@ struct Gaussian2D : IDifferentiable

if (abs(b) < 1e-6 || abs(c) < 1e-6)
{
// The covariance matrix is diagonal (or close enough..) , so the eigenvectors are the x and y axes.
// The covariance matrix is diagonal (or close enough..), so the eigenvectors are the x and y axes.
float2x2 eigenvectors = float2x2(float2(1, 0), float2(0, 1));
float2 scale = float2(sqrt(a), sqrt(d));

Expand Down Expand Up @@ -442,24 +442,14 @@ float4 fineRasterize(SortedShortList, uint localIdx, no_diff float2 uv)
return pixelState.value;
}

[ForceInline]
uint workgroupUniformLoad(Ptr<uint> val)
{
__target_switch
{
case wgsl:
__intrinsic_asm "workgroupUniformLoad(&blobCount_0)";
}
}

void fineRasterize_bwd(SortedShortList, uint localIdx, float2 uv, float4 dOut)
{
GroupMemoryBarrierWithGroupSync();

PixelState pixelState = { finalVal[localIdx], maxCount[localIdx] };

PixelState.Differential dColor = { dOut };
uint count = workgroupUniformLoad(&blobCount);
uint count = workgroupUniformLoad(blobCount);
for (uint _i = count; _i > 0; _i--)
{
uint i = _i - 1;
Expand Down Expand Up @@ -490,17 +480,62 @@ InitializedShortList initShortList(uint2 dispatchThreadID)
return { 0 };
}

[Differentiable]
float4 splatBlobs(uint2 dispatchThreadID, int2 imageSize)
// Calculates a 'stretch-free' mapping from the requested render-target dimensions to the
// image in the texture
//
float2 calcUV(uint2 dispatchThreadID, int2 renderSize, int2 imageSize)
{
uint globalID = dispatchThreadID.x + dispatchThreadID.y * imageSize.x;
// Easy case.
if (all(renderSize == imageSize))
return ((float2)dispatchThreadID) / renderSize;

float aspectRatioRT = ((float) renderSize.x) / renderSize.y;
float aspectRatioTEX = ((float) imageSize.x) / imageSize.y;

float2 uv = float2(dispatchThreadID) / float2(imageSize);
if (aspectRatioRT > aspectRatioTEX)
{
// Render target is wider than the texture.
// Match the widths.
//
float xCoord = ((float) dispatchThreadID.x) / renderSize.x;
float yCoord = ((float) dispatchThreadID.y * aspectRatioTEX) / renderSize.x;

uint2 tileID = uint2(dispatchThreadID.x / WG_X, dispatchThreadID.y / WG_Y);
// We'll re-center the y-coord around 0.5.
float yCoordMax = aspectRatioTEX / aspectRatioRT;
yCoord = yCoord + (1.0 - yCoordMax) / 2.0f;
return float2(xCoord, yCoord);
}
else
{
// Render target is taller than the texture.
// Match the heights.
//
float yCoord = ((float) dispatchThreadID.y) / renderSize.y;
float xCoord = ((float) dispatchThreadID.x) / (renderSize.y * aspectRatioTEX);

// We'll recenter the x-coord around 0.5.
float xCoordMax = aspectRatioRT / aspectRatioTEX;
xCoord = xCoord + (1.0 - xCoordMax) / 2.0f;
return float2(xCoord, yCoord);
}
}


[Differentiable]
float4 splatBlobs(uint2 dispatchThreadID, int2 dispatchSize)
{
uint globalID = dispatchThreadID.x + dispatchThreadID.y * dispatchSize.x;

int texWidth;
int texHeight;
targetTexture.GetDimensions(texWidth, texHeight);
int2 texSize = int2(texWidth, texHeight);
float2 uv = calcUV(dispatchThreadID, dispatchSize, texSize);

uint2 tileCoords = uint2(dispatchThreadID.x / WG_X, dispatchThreadID.y / WG_Y);

float2 tileLow = float2(tileID) * float2(WG_X, WG_Y) / float2(imageSize);
float2 tileHigh = tileLow + (float2(WG_X, WG_Y) / float2(imageSize));
float2 tileLow = calcUV(tileCoords * uint2(WG_X, WG_Y), dispatchSize, texSize);
float2 tileHigh = calcUV((tileCoords + 1) * uint2(WG_X, WG_Y), dispatchSize, texSize);

float2 tileCenter = (tileLow + tileHigh) / 2;
float2x2 tileRotation = float2x2(1, 0, 0, 1);
Expand All @@ -513,7 +548,7 @@ float4 splatBlobs(uint2 dispatchThreadID, int2 imageSize)
uint2 localID = dispatchThreadID % uint2(WG_X, WG_Y);
uint localIdx = localID.x + localID.y * WG_X;

// Short-list blobs that overlap with the current pixel
// Short-list blobs that overlap with the local tile.
FilledShortList filledSList = coarseRasterize(sList, tileBounds, localIdx);

// Pad the unused space in the buffer
Expand Down
Loading

0 comments on commit dfbb318

Please sign in to comment.