diff --git a/gsplat/cuda/_torch_impl.py b/gsplat/cuda/_torch_impl.py index 892c6a66..5de19f1f 100644 --- a/gsplat/cuda/_torch_impl.py +++ b/gsplat/cuda/_torch_impl.py @@ -306,20 +306,27 @@ def _fully_fused_projection( depths = means_c[..., 2] # [C, N] + # scale factor for 3 sigma in 2 dimension for a 2d multivariate gaussian + # distribution + scale_factor = 3.4086 + # calculating magnitude of major and and minor axis b = (covars2d[..., 0, 0] + covars2d[..., 1, 1]) / 2 # (...,) v1 = b + torch.sqrt(torch.clamp(b**2 - det, min=0.01)) # (...,) - radius = torch.ceil(3.0 * torch.sqrt(v1)) # (...,) + radius = torch.ceil(scale_factor * torch.sqrt(v1)) # (...,) # v2 = b - torch.sqrt(torch.clamp(b**2 - det, min=0.01)) # (...,) - # radius = torch.ceil(3.0 * torch.sqrt(torch.max(v1, v2))) # (...,) + # radius = torch.ceil(scale_factor * torch.sqrt(torch.max(v1, v2))) # (...,) + + x_proj = scale_factor * torch.sqrt(covars2d[..., 0, 0]) # (...,) + y_proj = scale_factor * torch.sqrt(covars2d[..., 1, 1]) # (...,) valid = (det > 0) & (depths > near_plane) & (depths < far_plane) radius[~valid] = 0.0 inside = ( - (means2d[..., 0] + radius > 0) - & (means2d[..., 0] - radius < width) - & (means2d[..., 1] + radius > 0) - & (means2d[..., 1] - radius < height) + (means2d[..., 0] + x_proj > 0) + & (means2d[..., 0] - x_proj < width) + & (means2d[..., 1] + y_proj > 0) + & (means2d[..., 1] - y_proj < height) ) radius[~inside] = 0.0 diff --git a/gsplat/cuda/csrc/fully_fused_projection_fwd.cu b/gsplat/cuda/csrc/fully_fused_projection_fwd.cu index c651e803..8e70cefc 100644 --- a/gsplat/cuda/csrc/fully_fused_projection_fwd.cu +++ b/gsplat/cuda/csrc/fully_fused_projection_fwd.cu @@ -162,11 +162,14 @@ __global__ void fully_fused_projection_fwd_kernel( inverse(covar2d, covar2d_inv); // take 3 sigma as the radius (non differentiable) + T scale_factor = 3.4086; T b = 0.5f * (covar2d[0][0] + covar2d[1][1]); T v1 = b + sqrt(max(0.01f, b * b - det)); - T radius = ceil(3.f * sqrt(v1)); + T radius = ceil(scale_factor * sqrt(v1)); // T v2 = b - sqrt(max(0.1f, b * b - det)); // T radius = ceil(3.f * sqrt(max(v1, v2))); + T x_proj = scale_factor * sqrt(covar2d[0][0]); + T y_proj = scale_factor * sqrt(covar2d[1][1]); if (radius <= radius_clip) { radii[idx] = 0; @@ -174,8 +177,8 @@ __global__ void fully_fused_projection_fwd_kernel( } // mask out gaussians outside the image region - if (mean2d.x + radius <= 0 || mean2d.x - radius >= image_width || - mean2d.y + radius <= 0 || mean2d.y - radius >= image_height) { + if (mean2d.x + x_proj <= 0 || mean2d.x - x_proj >= image_width || + mean2d.y + y_proj <= 0 || mean2d.y - y_proj >= image_height) { radii[idx] = 0; return; } diff --git a/gsplat/cuda/csrc/fully_fused_projection_packed_fwd.cu b/gsplat/cuda/csrc/fully_fused_projection_packed_fwd.cu index 4d8609f0..213df79b 100644 --- a/gsplat/cuda/csrc/fully_fused_projection_packed_fwd.cu +++ b/gsplat/cuda/csrc/fully_fused_projection_packed_fwd.cu @@ -178,18 +178,21 @@ __global__ void fully_fused_projection_packed_fwd_kernel( T radius; if (valid) { // take 3 sigma as the radius (non differentiable) + T scale_factor = 3.4086; T b = 0.5f * (covar2d[0][0] + covar2d[1][1]); T v1 = b + sqrt(max(0.1f, b * b - det)); T v2 = b - sqrt(max(0.1f, b * b - det)); - radius = ceil(3.f * sqrt(max(v1, v2))); + radius = ceil(scale_factor * sqrt(max(v1, v2))); + T x_proj = scale_factor * sqrt(covar2d[0][0]); + T y_proj = scale_factor * sqrt(covar2d[1][1]); if (radius <= radius_clip) { valid = false; } // mask out gaussians outside the image region - if (mean2d.x + radius <= 0 || mean2d.x - radius >= image_width || - mean2d.y + radius <= 0 || mean2d.y - radius >= image_height) { + if (mean2d.x + x_proj <= 0 || mean2d.x - x_proj >= image_width || + mean2d.y + y_proj <= 0 || mean2d.y - y_proj >= image_height) { valid = false; } }