Skip to content

Commit f0dc651

Browse files
nikhilaravifacebook-github-bot
authored andcommitted
Shader API more consistent naming
Summary: Renamed shaders to be prefixed with Hard/Soft depending on if they use a probabalistic blending (Soft) or use the closest face (Hard). There is some code duplication but I thought it would be cleaner to have separate shaders for each task rather than: - inheritance (which we discussed previously that we want to avoid) - boolean (hard/soft) or a string (hard/soft) - new blending functions other than the ones provided would need if statements in the current shaders which might get messy. Also added a `flat_shading` function and a `FlatShader` - I could make this into a tutorial as it was really easy to add a new shader and it might be a nice showcase. NOTE: There are a few more places where the naming will need to change (e.g the tutorials) but I wanted to reach a consensus on this before changing it everywhere. Reviewed By: jcjohnson Differential Revision: D19761036 fbshipit-source-id: f972f6530c7f66dc5550b0284c191abc4a7f6fc4
1 parent 60f3c4e commit f0dc651

9 files changed

+293
-82
lines changed

.gitignore

+4
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,10 @@ build/
22
dist/
33
*.egg-info/
44
**/__pycache__/
5+
*-checkpoint.ipynb
6+
**/.ipynb_checkpoints
7+
**/.ipynb_checkpoints/**
8+
59

610
# Docusaurus site
711
website/yarn.lock

docs/notes/renderer_getting_started.md

+22
Original file line numberDiff line numberDiff line change
@@ -84,3 +84,25 @@ renderer = MeshRenderer(
8484
shader=PhongShader(device=device, cameras=cameras)
8585
)
8686
```
87+
88+
### A custom shader
89+
90+
Shaders are the most flexible part of the PyTorch3D rendering API. We have created some examples of shaders in `shaders.py` but this is a non exhaustive set.
91+
92+
A shader can incorporate several steps:
93+
- **texturing** (e.g interpolation of vertex RGB colors or interpolation of vertex UV coordinates followed by sampling from a texture map (interpolation uses barycentric coordinates output from rasterization))
94+
- **lighting/shading** (e.g. ambient, diffuse, specular lighting, Phong, Gourad, Flat)
95+
- **blending** (e.g. hard blending using only the closest face for each pixel, or soft blending using a weighted sum of the top K faces per pixel)
96+
97+
We have examples of several combinations of these functions based on the texturing/shading/blending support we have currently. These are summarised in this table below. Many other combinations are possible and we plan to expand the options available for texturing, shading and blending.
98+
99+
100+
|Example Shaders | Vertex Textures| Texture Map| Flat Shading| Gourad Shading| Phong Shading | Hard blending | Soft Blending |
101+
| ------------- |:-------------: | :--------------:| :--------------:| :--------------:| :--------------:|:--------------:|:--------------:|
102+
| HardPhongShader | :heavy_check_mark: |||| :heavy_check_mark: | :heavy_check_mark:||
103+
| SoftPhongShader | :heavy_check_mark: |||| :heavy_check_mark: | | :heavy_check_mark:|
104+
| HardGouradShader | :heavy_check_mark: ||| :heavy_check_mark: || :heavy_check_mark:||
105+
| SoftGouradShader | :heavy_check_mark: ||| :heavy_check_mark: ||| :heavy_check_mark:|
106+
| TexturedSoftPhongShader || :heavy_check_mark: ||| :heavy_check_mark: || :heavy_check_mark:|
107+
| HardFlatShader | :heavy_check_mark: || :heavy_check_mark: ||| :heavy_check_mark:||
108+
| SoftSilhouetteShader ||||||| :heavy_check_mark:|

docs/tutorials/camera_position_optimization_with_differentiable_rendering.ipynb

+20-35
Large diffs are not rendered by default.

docs/tutorials/render_textured_meshes.ipynb

+17-19
Large diffs are not rendered by default.

pytorch3d/renderer/__init__.py

+7-4
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,16 @@
1717
from .lighting import DirectionalLights, PointLights, diffuse, specular
1818
from .materials import Materials
1919
from .mesh import (
20-
GouradShader,
20+
HardFlatShader,
21+
HardGouradShader,
22+
HardPhongShader,
2123
MeshRasterizer,
2224
MeshRenderer,
23-
PhongShader,
2425
RasterizationSettings,
25-
SilhouetteShader,
26-
TexturedPhongShader,
26+
SoftGouradShader,
27+
SoftPhongShader,
28+
SoftSilhouetteShader,
29+
TexturedSoftPhongShader,
2730
gourad_shading,
2831
interpolate_face_attributes,
2932
interpolate_texture_map,

pytorch3d/renderer/mesh/__init__.py

+7-4
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,13 @@
44
from .rasterizer import MeshRasterizer, RasterizationSettings
55
from .renderer import MeshRenderer
66
from .shader import (
7-
GouradShader,
8-
PhongShader,
9-
SilhouetteShader,
10-
TexturedPhongShader,
7+
HardFlatShader,
8+
HardGouradShader,
9+
HardPhongShader,
10+
SoftGouradShader,
11+
SoftPhongShader,
12+
SoftSilhouetteShader,
13+
TexturedSoftPhongShader,
1114
)
1215
from .shading import gourad_shading, phong_shading
1316
from .texturing import ( # isort: skip

pytorch3d/renderer/mesh/shader.py

+171-11
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from ..cameras import OpenGLPerspectiveCameras
1515
from ..lighting import PointLights
1616
from ..materials import Materials
17-
from .shading import gourad_shading, phong_shading
17+
from .shading import flat_shading, gourad_shading, phong_shading
1818
from .texturing import interpolate_texture_map, interpolate_vertex_colors
1919

2020
# A Shader should take as input fragments from the output of rasterization
@@ -26,17 +26,18 @@
2626
# - blend colors across top K faces per pixel.
2727

2828

29-
class PhongShader(nn.Module):
29+
class HardPhongShader(nn.Module):
3030
"""
31-
Per pixel lighting. Apply the lighting model using the interpolated coords
32-
and normals for each pixel.
31+
Per pixel lighting - the lighting model is applied using the interpolated
32+
coordinates and normals for each pixel. The blending function hard assigns
33+
the color of the closest face for each pixel.
3334
3435
To use the default values, simply initialize the shader with the desired
3536
device e.g.
3637
3738
.. code-block::
3839
39-
shader = PhongShader(device=torch.device("cuda:0"))
40+
shader = HardPhongShader(device=torch.device("cuda:0"))
4041
"""
4142

4243
def __init__(self, device="cpu", cameras=None, lights=None, materials=None):
@@ -70,17 +71,74 @@ def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
7071
return images
7172

7273

73-
class GouradShader(nn.Module):
74+
class SoftPhongShader(nn.Module):
7475
"""
75-
Per vertex lighting. Apply the lighting model to the vertex colors and then
76-
interpolate using the barycentric coordinates to get colors for each pixel.
76+
Per pixel lighting - the lighting model is applied using the interpolated
77+
coordinates and normals for each pixel. The blending function returns the
78+
soft aggregated color using all the faces per pixel.
7779
7880
To use the default values, simply initialize the shader with the desired
7981
device e.g.
8082
8183
.. code-block::
8284
83-
shader = GouradShader(device=torch.device("cuda:0"))
85+
shader = SoftPhongShader(device=torch.device("cuda:0"))
86+
"""
87+
88+
def __init__(
89+
self,
90+
device="cpu",
91+
cameras=None,
92+
lights=None,
93+
materials=None,
94+
blend_params=None,
95+
):
96+
super().__init__()
97+
self.lights = (
98+
lights if lights is not None else PointLights(device=device)
99+
)
100+
self.materials = (
101+
materials if materials is not None else Materials(device=device)
102+
)
103+
self.cameras = (
104+
cameras
105+
if cameras is not None
106+
else OpenGLPerspectiveCameras(device=device)
107+
)
108+
self.blend_params = (
109+
blend_params if blend_params is not None else BlendParams()
110+
)
111+
112+
def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
113+
texels = interpolate_vertex_colors(fragments, meshes)
114+
cameras = kwargs.get("cameras", self.cameras)
115+
lights = kwargs.get("lights", self.lights)
116+
materials = kwargs.get("materials", self.materials)
117+
colors = phong_shading(
118+
meshes=meshes,
119+
fragments=fragments,
120+
texels=texels,
121+
lights=lights,
122+
cameras=cameras,
123+
materials=materials,
124+
)
125+
images = softmax_rgb_blend(colors, fragments, self.blend_params)
126+
return images
127+
128+
129+
class HardGouradShader(nn.Module):
130+
"""
131+
Per vertex lighting - the lighting model is applied to the vertex colors and
132+
the colors are then interpolated using the barycentric coordinates to
133+
obtain the colors for each pixel. The blending function hard assigns
134+
the color of the closest face for each pixel.
135+
136+
To use the default values, simply initialize the shader with the desired
137+
device e.g.
138+
139+
.. code-block::
140+
141+
shader = HardGouradShader(device=torch.device("cuda:0"))
84142
"""
85143

86144
def __init__(self, device="cpu", cameras=None, lights=None, materials=None):
@@ -112,12 +170,69 @@ def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
112170
return images
113171

114172

115-
class TexturedPhongShader(nn.Module):
173+
class SoftGouradShader(nn.Module):
174+
"""
175+
Per vertex lighting - the lighting model is applied to the vertex colors and
176+
the colors are then interpolated using the barycentric coordinates to
177+
obtain the colors for each pixel. The blending function returns the
178+
soft aggregated color using all the faces per pixel.
179+
180+
To use the default values, simply initialize the shader with the desired
181+
device e.g.
182+
183+
.. code-block::
184+
185+
shader = SoftGouradShader(device=torch.device("cuda:0"))
186+
"""
187+
188+
def __init__(
189+
self,
190+
device="cpu",
191+
cameras=None,
192+
lights=None,
193+
materials=None,
194+
blend_params=None,
195+
):
196+
super().__init__()
197+
self.lights = (
198+
lights if lights is not None else PointLights(device=device)
199+
)
200+
self.materials = (
201+
materials if materials is not None else Materials(device=device)
202+
)
203+
self.cameras = (
204+
cameras
205+
if cameras is not None
206+
else OpenGLPerspectiveCameras(device=device)
207+
)
208+
self.blend_params = (
209+
blend_params if blend_params is not None else BlendParams()
210+
)
211+
212+
def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
213+
cameras = kwargs.get("cameras", self.cameras)
214+
lights = kwargs.get("lights", self.lights)
215+
materials = kwargs.get("materials", self.materials)
216+
pixel_colors = gourad_shading(
217+
meshes=meshes,
218+
fragments=fragments,
219+
lights=lights,
220+
cameras=cameras,
221+
materials=materials,
222+
)
223+
images = softmax_rgb_blend(pixel_colors, fragments, self.blend_params)
224+
return images
225+
226+
227+
class TexturedSoftPhongShader(nn.Module):
116228
"""
117229
Per pixel lighting applied to a texture map. First interpolate the vertex
118230
uv coordinates and sample from a texture map. Then apply the lighting model
119231
using the interpolated coords and normals for each pixel.
120232
233+
The blending function returns the soft aggregated color using all
234+
the faces per pixel.
235+
121236
To use the default values, simply initialize the shader with the desired
122237
device e.g.
123238
@@ -167,7 +282,52 @@ def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
167282
return images
168283

169284

170-
class SilhouetteShader(nn.Module):
285+
class HardFlatShader(nn.Module):
286+
"""
287+
Per face lighting - the lighting model is applied using the average face
288+
position and the face normal. The blending function hard assigns
289+
the color of the closest face for each pixel.
290+
291+
To use the default values, simply initialize the shader with the desired
292+
device e.g.
293+
294+
.. code-block::
295+
296+
shader = HardFlatShader(device=torch.device("cuda:0"))
297+
"""
298+
299+
def __init__(self, device="cpu", cameras=None, lights=None, materials=None):
300+
super().__init__()
301+
self.lights = (
302+
lights if lights is not None else PointLights(device=device)
303+
)
304+
self.materials = (
305+
materials if materials is not None else Materials(device=device)
306+
)
307+
self.cameras = (
308+
cameras
309+
if cameras is not None
310+
else OpenGLPerspectiveCameras(device=device)
311+
)
312+
313+
def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
314+
texels = interpolate_vertex_colors(fragments, meshes)
315+
cameras = kwargs.get("cameras", self.cameras)
316+
lights = kwargs.get("lights", self.lights)
317+
materials = kwargs.get("materials", self.materials)
318+
colors = flat_shading(
319+
meshes=meshes,
320+
fragments=fragments,
321+
texels=texels,
322+
lights=lights,
323+
cameras=cameras,
324+
materials=materials,
325+
)
326+
images = hard_rgb_blend(colors, fragments)
327+
return images
328+
329+
330+
class SoftSilhouetteShader(nn.Module):
171331
"""
172332
Calculate the silhouette by blending the top K faces for each pixel based
173333
on the 2d euclidean distance of the centre of the pixel to the mesh face.

pytorch3d/renderer/mesh/shading.py

+36
Original file line numberDiff line numberDiff line change
@@ -124,3 +124,39 @@ def gourad_shading(
124124
face_colors = verts_colors_shaded[faces]
125125
colors = interpolate_face_attributes(fragments, face_colors)
126126
return colors
127+
128+
129+
def flat_shading(
130+
meshes, fragments, lights, cameras, materials, texels
131+
) -> torch.Tensor:
132+
"""
133+
Apply per face shading. Use the average face position and the face normals
134+
to compute the ambient, diffuse and specular lighting. Apply the ambient
135+
and diffuse color to the pixel color and add the specular component to
136+
determine the final pixel color.
137+
138+
Args:
139+
meshes: Batch of meshes
140+
fragments: Fragments named tuple with the outputs of rasterization
141+
lights: Lights class containing a batch of lights parameters
142+
cameras: Cameras class containing a batch of cameras parameters
143+
materials: Materials class containing a batch of material properties
144+
texels: texture per pixel of shape (N, H, W, K, 3)
145+
146+
Returns:
147+
colors: (N, H, W, K, 3)
148+
"""
149+
verts = meshes.verts_packed() # (V, 3)
150+
faces = meshes.faces_packed() # (F, 3)
151+
face_normals = meshes.faces_normals_packed() # (V, 3)
152+
faces_verts = verts[faces]
153+
face_coords = faces_verts.mean(dim=-2) # (F, 3, XYZ) mean xyz across verts
154+
pixel_coords = face_coords[fragments.pix_to_face]
155+
pixel_normals = face_normals[fragments.pix_to_face]
156+
157+
# Calculate the illumination at each face
158+
ambient, diffuse, specular = _apply_lighting(
159+
pixel_coords, pixel_normals, lights, cameras, materials
160+
)
161+
colors = (ambient + diffuse) * texels + specular
162+
return colors

0 commit comments

Comments
 (0)