Commit 1809a6d

Merge branch 'main' into save_ply

maturk committed Jan 10, 2025
2 parents 4c10e88 + 2df0a95

Showing 54 changed files with 1,693 additions and 977 deletions.
3 changes: 2 additions & 1 deletion .github/workflows/publish.yml
@@ -106,5 +106,6 @@ jobs:
env:
PYPI_TOKEN: ${{ secrets.PYPI_TOKEN }}
run: |
BUILD_NO_CUDA=1 python -m build
twine upload --username __token__ --password $PYPI_TOKEN dist/*
shell: bash
shell: bash
1 change: 1 addition & 0 deletions MANIFEST.in
@@ -1 +1,2 @@
recursive-include gsplat/cuda/csrc *
recursive-include gsplat/cuda/include *
30 changes: 11 additions & 19 deletions README.md
@@ -15,45 +15,37 @@ gsplat is an open-source library for CUDA accelerated rasterization of gaussians

**Dependency**: Please install [PyTorch](https://pytorch.org/get-started/locally/) first.




The easiest way is to install from PyPI. Installed this way, the CUDA code will be built **on the first run** (JIT).

```bash
pip install gsplat
```

Alternatively, you can install gsplat from Python wheels containing pre-compiled binaries for a specific PyTorch and CUDA version. These wheels are stored in the GitHub releases and can be found using simple index pages under https://docs.gsplat.studio/whl.
You obtain the wheel for a specific PyTorch and CUDA version from this simple index page by appending the version number after a `+` sign (the part referred to as a *local version*). For example, to install gsplat for PyTorch 2.0 and CUDA 11.8 you can use
```
pip install gsplat==1.2.0+pt20cu118 --index-url https://docs.gsplat.studio/whl
```
Alternatively, you can specify the PyTorch and CUDA version in the index URL, for example
```
pip install gsplat --index-url https://docs.gsplat.studio/whl/pt20cu118
```
This has the advantage that you do not have to pin a specific version of the package and automatically get the latest package version.

Alternatively, you can install gsplat from source. In this way it builds the CUDA code during installation.

```bash
pip install git+https://github.com/nerfstudio-project/gsplat.git
```

We also provide [pre-compiled wheels](https://docs.gsplat.studio/whl) for both Linux and Windows for certain Python-PyTorch-CUDA combinations (please check first which versions are supported). Note that this way you have to manually install [gsplat's dependencies](https://github.com/nerfstudio-project/gsplat/blob/6022cf45a19ee307803aaf1f19d407befad2a033/setup.py#L115). For example, to install gsplat for PyTorch 2.0 and CUDA 11.8 you can run
```
pip install ninja numpy jaxtyping rich
pip install gsplat --index-url https://docs.gsplat.studio/whl/pt20cu118
```

To install gsplat on Windows, please check [this instruction](docs/INSTALL_WIN.md).
To build gsplat from source on Windows, please check [these instructions](docs/INSTALL_WIN.md).

## Evaluation

This repo comes with a standalone script that reproduces the official Gaussian Splatting implementation with exactly the same performance on PSNR, SSIM, LPIPS, and converged number of Gaussians. Powered by gsplat’s efficient CUDA implementation, training takes up to **4x less GPU memory** and up to **15% less time** than the official implementation. The full report can be found [here](https://docs.gsplat.studio/main/tests/eval.html).

```bash
pip install -r examples/requirements.txt
cd examples
pip install -r requirements.txt
# download mipnerf_360 benchmark data
python examples/datasets/download_dataset.py
python datasets/download_dataset.py
# run batch evaluation
bash examples/benchmarks/basic.sh
bash benchmarks/basic.sh
```

## Examples
@@ -82,7 +74,7 @@ This project is developed by the following wonderful contributors (unordered):
- [Zhuoyang Pan](https://panzhy.com/) (ShanghaiTech University): Core developer.
- [Jianbo Ye](http://www.jianboye.org/) (Amazon): Core developer.

We also have made the mathematical supplement, with conventions and derivations, available [here](https://arxiv.org/abs/2409.06765). If you find this library useful in your projects or papers, please consider citing:
We also have a white paper about the project, with benchmarking and a mathematical supplement covering conventions and derivations, available [here](https://arxiv.org/abs/2409.06765). If you find this library useful in your projects or papers, please consider citing:

```
@article{ye2024gsplatopensourcelibrarygaussian,
4 changes: 2 additions & 2 deletions examples/datasets/colmap.py
@@ -266,8 +266,8 @@ def __init__(
+ params[2] * theta**6
+ params[3] * theta**8
)
mapx = fx * x1 * r + width // 2
mapy = fy * y1 * r + height // 2
mapx = (fx * x1 * r + width // 2).astype(np.float32)
mapy = (fy * y1 * r + height // 2).astype(np.float32)

# Use mask to define ROI
mask = np.logical_and(
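A note on the `.astype(np.float32)` casts above: OpenCV's `cv2.remap` requires `float32` (CV_32FC1) maps and rejects the `float64` arrays that NumPy arithmetic produces by default, which is presumably what this fix addresses. A minimal sketch (a placeholder image, with identity maps standing in for the fisheye undistortion maps):

```python
import cv2
import numpy as np

height, width = 480, 640
image = np.zeros((height, width, 3), dtype=np.uint8)  # placeholder image

# cv2.remap requires float32 maps, hence the .astype casts in this diff.
mapx, mapy = np.meshgrid(
    np.arange(width, dtype=np.float32),
    np.arange(height, dtype=np.float32),
)
undistorted = cv2.remap(image, mapx, mapy, interpolation=cv2.INTER_LINEAR)
```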
72 changes: 49 additions & 23 deletions examples/simple_trainer.py
@@ -41,6 +41,7 @@
from gsplat.distributed import cli
from gsplat.rendering import rasterization
from gsplat.strategy import DefaultStrategy, MCMCStrategy
from gsplat.optimizers import SelectiveAdam


@dataclass
@@ -118,6 +119,8 @@ class Config:
packed: bool = False
# Use sparse gradients for optimization. (experimental)
sparse_grad: bool = False
# Use visible adam from Taming 3DGS. (experimental)
visible_adam: bool = False
# Anti-aliasing in rasterization. Might slightly hurt quantitative metrics.
antialiased: bool = False

@@ -286,6 +289,7 @@ def create_splats_with_optimizers(
scene_scale: float = 1.0,
sh_degree: int = 3,
sparse_grad: bool = False,
visible_adam: bool = False,
batch_size: int = 1,
feature_dim: Optional[int] = None,
device: str = "cuda",
@@ -342,8 +346,15 @@
# Note that this would not make the training exactly equivalent, see
# https://arxiv.org/pdf/2402.18824v1
BS = batch_size * world_size
optimizer_class = None
if sparse_grad:
optimizer_class = torch.optim.SparseAdam
elif visible_adam:
optimizer_class = SelectiveAdam
else:
optimizer_class = torch.optim.Adam
optimizers = {
name: (torch.optim.SparseAdam if sparse_grad else torch.optim.Adam)(
name: optimizer_class(
[{"params": splats[name], "lr": lr * math.sqrt(BS), "name": name}],
eps=1e-15 / math.sqrt(BS),
# TODO: check betas logic when BS is larger than 10 betas[0] will be zero.
@@ -413,6 +424,7 @@ def __init__(
scene_scale=self.scene_scale,
sh_degree=cfg.sh_degree,
sparse_grad=cfg.sparse_grad,
visible_adam=cfg.visible_adam,
batch_size=cfg.batch_size,
feature_dim=feature_dim,
device=self.device,
@@ -835,27 +847,6 @@ def train(self):

save_ply(self.splats, f"{self.ply_dir}/point_cloud_{step}.ply", rgb)

if isinstance(self.cfg.strategy, DefaultStrategy):
self.cfg.strategy.step_post_backward(
params=self.splats,
optimizers=self.optimizers,
state=self.strategy_state,
step=step,
info=info,
packed=cfg.packed,
)
elif isinstance(self.cfg.strategy, MCMCStrategy):
self.cfg.strategy.step_post_backward(
params=self.splats,
optimizers=self.optimizers,
state=self.strategy_state,
step=step,
info=info,
lr=schedulers[0].get_last_lr()[0],
)
else:
assert_never(self.cfg.strategy)

# Turn Gradients into Sparse Tensor before running optimizer
if cfg.sparse_grad:
assert cfg.packed, "Sparse gradients only work with packed mode."
@@ -871,9 +862,22 @@
is_coalesced=len(Ks) == 1,
)

if cfg.visible_adam:
gaussian_cnt = self.splats.means.shape[0]
if cfg.packed:
visibility_mask = torch.zeros_like(
self.splats["opacities"], dtype=bool
)
visibility_mask.scatter_(0, info["gaussian_ids"], 1)
else:
visibility_mask = (info["radii"] > 0).any(0)

# optimize
for optimizer in self.optimizers.values():
optimizer.step()
if cfg.visible_adam:
optimizer.step(visibility_mask)
else:
optimizer.step()
optimizer.zero_grad(set_to_none=True)
for optimizer in self.pose_optimizers:
optimizer.step()
@@ -887,6 +891,28 @@
for scheduler in schedulers:
scheduler.step()

# Run post-backward steps after backward and optimizer
if isinstance(self.cfg.strategy, DefaultStrategy):
self.cfg.strategy.step_post_backward(
params=self.splats,
optimizers=self.optimizers,
state=self.strategy_state,
step=step,
info=info,
packed=cfg.packed,
)
elif isinstance(self.cfg.strategy, MCMCStrategy):
self.cfg.strategy.step_post_backward(
params=self.splats,
optimizers=self.optimizers,
state=self.strategy_state,
step=step,
info=info,
lr=schedulers[0].get_last_lr()[0],
)
else:
assert_never(self.cfg.strategy)

# eval the full set
if step in [i - 1 for i in cfg.eval_steps]:
self.eval(step)
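Taken together, the relocated `step_post_backward` call and the new `visible_adam` path change the per-iteration order to: backward, (optional) gradient sparsification, masked or plain optimizer step, schedulers, then the strategy's post-backward step. A condensed sketch of the masked step, assuming `SelectiveAdam` follows the constructor and `step(visibility_mask)` usage shown in this diff (tensor shapes are illustrative):

```python
import torch
from gsplat.optimizers import SelectiveAdam

N = 10_000  # number of Gaussians
opacities = torch.rand(N, device="cuda", requires_grad=True)
optimizer = SelectiveAdam(
    [{"params": opacities, "lr": 5e-2, "name": "opacities"}], eps=1e-15
)

loss = opacities.square().mean()  # placeholder loss
loss.backward()

# Non-packed mode: a Gaussian is visible if its radius is positive
# in any camera (stand-in for info["radii"] of shape [C, N]).
radii = torch.randint(0, 3, (1, N), device="cuda")
visibility_mask = (radii > 0).any(0)

optimizer.step(visibility_mask)  # update only visible Gaussians
optimizer.zero_grad(set_to_none=True)
```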
6 changes: 5 additions & 1 deletion examples/simple_trainer_2dgs.py
@@ -436,7 +436,7 @@ def rasterize_splats(
**kwargs,
)
elif self.model_type == "2dgs-inria":
render_colors, render_alphas, info = rasterization_2dgs_inria_wrapper(
renders, info = rasterization_2dgs_inria_wrapper(
means=means,
quats=quats,
scales=scales,
@@ -577,6 +577,10 @@ def train(self):
step=step,
info=info,
)
masks = data["mask"].to(device) if "mask" in data else None
if masks is not None:
pixels = pixels * masks[..., None]
colors = colors * masks[..., None]

# loss
l1loss = F.l1_loss(colors, pixels)
4 changes: 2 additions & 2 deletions examples/simple_viewer.py
@@ -68,7 +68,7 @@ def main(local_rank: int, world_rank, world_size: int, args):
quats, # [N, 4]
scales, # [N, 3]
opacities, # [N]
colors, # [N, 3]
colors, # [N, S, 3]
viewmats, # [C, 4, 4]
Ks, # [C, 3, 3]
width,
@@ -181,7 +181,7 @@ def viewer_render_fn(camera_state: nerfview.CameraState, img_wh: Tuple[int, int]
quats, # [N, 4]
scales, # [N, 3]
opacities, # [N]
colors, # [N, 3]
colors, # [N, S, 3]
viewmat[None], # [1, 4, 4]
K[None], # [1, 3, 3]
width,
4 changes: 3 additions & 1 deletion gsplat/__init__.py
@@ -1,6 +1,7 @@
import warnings

from .compression import PngCompression
from .optimizers import SelectiveAdam
from .cuda._torch_impl import accumulate
from .cuda._torch_impl_2dgs import accumulate_2dgs
from .cuda._wrapper import (
@@ -48,5 +49,6 @@
"rasterize_to_pixels_2dgs",
"rasterize_to_indices_in_range_2dgs",
"accumulate_2dgs",
"rasterization_2dgs_inria_wrapper" "__version__",
"rasterization_2dgs_inria_wrapper",
"__version__",
]
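The `__all__` fix above corrects a classic missing-comma bug: adjacent string literals concatenate implicitly, so the old list silently exported one bogus name instead of two. A minimal illustration:

```python
# Without the comma, Python fuses the adjacent literals into one string.
broken = ["rasterization_2dgs_inria_wrapper" "__version__"]
fixed = ["rasterization_2dgs_inria_wrapper", "__version__"]
assert broken == ["rasterization_2dgs_inria_wrapper__version__"]
assert len(fixed) == 2
```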
2 changes: 1 addition & 1 deletion gsplat/cuda/_backend.py
@@ -89,7 +89,7 @@ def cuda_toolkit_version():
current_dir = os.path.dirname(os.path.abspath(__file__))
glm_path = os.path.join(current_dir, "csrc", "third_party", "glm")

extra_include_paths = [os.path.join(PATH, "csrc/"), glm_path]
extra_include_paths = [os.path.join(PATH, "include/"), glm_path]
extra_cflags = ["-O3"]
if NO_FAST_MATH:
extra_cuda_cflags = ["-O3"]
18 changes: 18 additions & 0 deletions gsplat/cuda/_wrapper.py
@@ -16,6 +16,24 @@ def call_cuda(*args, **kwargs):
return call_cuda


def selective_adam_update(
param: Tensor,
param_grad: Tensor,
exp_avg: Tensor,
exp_avg_sq: Tensor,
tiles_touched: Tensor,
lr: float,
b1: float,
b2: float,
eps: float,
N: int,
M: int,
) -> None:
_make_lazy_cuda_func("selective_adam_update")(
param, param_grad, exp_avg, exp_avg_sq, tiles_touched, lr, b1, b2, eps, N, M
)


def _make_lazy_cuda_obj(name: str) -> Any:
# pylint: disable=import-outside-toplevel
from ._backend import _C
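A minimal sketch of exercising the new wrapper directly; the shapes are an assumption (N Gaussians with M values each, plus a per-Gaussian boolean visibility mask), and in practice the call is presumably issued through `SelectiveAdam` rather than by hand:

```python
import torch
from gsplat.cuda._wrapper import selective_adam_update

N, M = 1_000, 3
param = torch.randn(N, M, device="cuda")
grad = torch.randn(N, M, device="cuda")
exp_avg = torch.zeros(N, M, device="cuda")     # first-moment state
exp_avg_sq = torch.zeros(N, M, device="cuda")  # second-moment state
visible = torch.rand(N, device="cuda") > 0.5   # bool mask of visible Gaussians

# Updates param / exp_avg / exp_avg_sq in place, only where visible is True.
selective_adam_update(
    param, grad, exp_avg, exp_avg_sq, visible,
    lr=1e-3, b1=0.9, b2=0.999, eps=1e-15, N=N, M=M,
)
```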
81 changes: 81 additions & 0 deletions gsplat/cuda/csrc/adam.cu
@@ -0,0 +1,81 @@
#include "bindings.h"

#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
#include <cub/cub.cuh>
#include <cuda.h>
#include <cuda_runtime.h>

namespace gsplat {

namespace cg = cooperative_groups;

template<typename T>
__global__ void selective_adam_update_kernel(
T* __restrict__ param,
const T* __restrict__ param_grad,
T* __restrict__ exp_avg,
T* __restrict__ exp_avg_sq,
const bool* tiles_touched,
const float lr,
const float b1,
const float b2,
const float eps,
const uint32_t N,
const uint32_t M
) {
auto p_idx = cg::this_grid().thread_rank();
const uint32_t g_idx = p_idx / M;
if (g_idx >= N) return;
if (tiles_touched[g_idx]) {
T Register_param_grad = param_grad[p_idx];
T Register_exp_avg = exp_avg[p_idx];
T Register_exp_avg_sq = exp_avg_sq[p_idx];
Register_exp_avg = b1 * Register_exp_avg + (1.0f - b1) * Register_param_grad;
Register_exp_avg_sq = b2 * Register_exp_avg_sq + (1.0f - b2) * Register_param_grad * Register_param_grad;
T step = -lr * Register_exp_avg / (sqrt(Register_exp_avg_sq) + eps);

param[p_idx] += step;
exp_avg[p_idx] = Register_exp_avg;
exp_avg_sq[p_idx] = Register_exp_avg_sq;
}
}

void selective_adam_update(
torch::Tensor &param,
torch::Tensor &param_grad,
torch::Tensor &exp_avg,
torch::Tensor &exp_avg_sq,
torch::Tensor &tiles_touched,
const float lr,
const float b1,
const float b2,
const float eps,
const uint32_t N,
const uint32_t M
) {
GSPLAT_DEVICE_GUARD(param);
GSPLAT_CHECK_INPUT(param);
GSPLAT_CHECK_INPUT(param_grad);
GSPLAT_CHECK_INPUT(exp_avg);
GSPLAT_CHECK_INPUT(exp_avg_sq);
GSPLAT_CHECK_INPUT(tiles_touched);

const uint32_t cnt = N * M;
at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream();
selective_adam_update_kernel<float><<<(cnt + 255) / 256, 256, 0, stream>>>(
param.data_ptr<float>(),
param_grad.data_ptr<float>(),
exp_avg.data_ptr<float>(),
exp_avg_sq.data_ptr<float>(),
tiles_touched.data_ptr<bool>(),
lr,
b1,
b2,
eps,
N,
M
);
}

} // namespace gsplat
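For reference, the update the kernel applies to each parameter of a visible Gaussian is the standard Adam step, without the bias-correction terms:

```latex
m \leftarrow \beta_1 m + (1-\beta_1)\,g,\qquad
v \leftarrow \beta_2 v + (1-\beta_2)\,g^2,\qquad
\theta \leftarrow \theta - \mathrm{lr}\cdot\frac{m}{\sqrt{v}+\epsilon}
```

where g is the gradient and m, v are the running moments; Gaussians with `tiles_touched[g_idx] == false` are skipped entirely, leaving their parameters and moments untouched.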
1 change: 1 addition & 0 deletions gsplat/cuda/csrc/compute_sh_bwd.cu
@@ -1,4 +1,5 @@
#include "bindings.h"
#include "helpers.cuh"
#include "spherical_harmonics.cuh"
#include "types.cuh"
