3232# - sample colors from a texture map
3333# - apply per pixel lighting
3434# - blend colors across top K faces per pixel.
35-
36-
37- class HardPhongShader (nn .Module ):
38- """
39- Per pixel lighting - the lighting model is applied using the interpolated
40- coordinates and normals for each pixel. The blending function hard assigns
41- the color of the closest face for each pixel.
42-
43- To use the default values, simply initialize the shader with the desired
44- device e.g.
45-
46- .. code-block::
47-
48- shader = HardPhongShader(device=torch.device("cuda:0"))
49- """
50-
35+ class ShaderBase (nn .Module ):
5136 def __init__ (
5237 self ,
5338 device : Device = "cpu" ,
@@ -74,6 +59,21 @@ def to(self, device: Device):
7459 self .lights = self .lights .to (device )
7560 return self
7661
62+
63+ class HardPhongShader (ShaderBase ):
64+ """
65+ Per pixel lighting - the lighting model is applied using the interpolated
66+ coordinates and normals for each pixel. The blending function hard assigns
67+ the color of the closest face for each pixel.
68+
69+ To use the default values, simply initialize the shader with the desired
70+ device e.g.
71+
72+ .. code-block::
73+
74+ shader = HardPhongShader(device=torch.device("cuda:0"))
75+ """
76+
7777 def forward (self , fragments : Fragments , meshes : Meshes , ** kwargs ) -> torch .Tensor :
7878 cameras = kwargs .get ("cameras" , self .cameras )
7979 if cameras is None :
@@ -97,7 +97,7 @@ def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tenso
9797 return images
9898
9999
100- class SoftPhongShader (nn . Module ):
100+ class SoftPhongShader (ShaderBase ):
101101 """
102102 Per pixel lighting - the lighting model is applied using the interpolated
103103 coordinates and normals for each pixel. The blending function returns the
@@ -111,32 +111,6 @@ class SoftPhongShader(nn.Module):
111111 shader = SoftPhongShader(device=torch.device("cuda:0"))
112112 """
113113
114- def __init__ (
115- self ,
116- device : Device = "cpu" ,
117- cameras : Optional [TensorProperties ] = None ,
118- lights : Optional [TensorProperties ] = None ,
119- materials : Optional [Materials ] = None ,
120- blend_params : Optional [BlendParams ] = None ,
121- ) -> None :
122- super ().__init__ ()
123- self .lights = lights if lights is not None else PointLights (device = device )
124- self .materials = (
125- materials if materials is not None else Materials (device = device )
126- )
127- self .cameras = cameras
128- self .blend_params = blend_params if blend_params is not None else BlendParams ()
129-
130- # pyre-fixme[14]: `to` overrides method defined in `Module` inconsistently.
131- def to (self , device : Device ):
132- # Manually move to device modules which are not subclasses of nn.Module
133- cameras = self .cameras
134- if cameras is not None :
135- self .cameras = cameras .to (device )
136- self .materials = self .materials .to (device )
137- self .lights = self .lights .to (device )
138- return self
139-
140114 def forward (self , fragments : Fragments , meshes : Meshes , ** kwargs ) -> torch .Tensor :
141115 cameras = kwargs .get ("cameras" , self .cameras )
142116 if cameras is None :
@@ -164,7 +138,7 @@ def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tenso
164138 return images
165139
166140
167- class HardGouraudShader (nn . Module ):
141+ class HardGouraudShader (ShaderBase ):
168142 """
169143 Per vertex lighting - the lighting model is applied to the vertex colors and
170144 the colors are then interpolated using the barycentric coordinates to
@@ -179,32 +153,6 @@ class HardGouraudShader(nn.Module):
179153 shader = HardGouraudShader(device=torch.device("cuda:0"))
180154 """
181155
182- def __init__ (
183- self ,
184- device : Device = "cpu" ,
185- cameras : Optional [TensorProperties ] = None ,
186- lights : Optional [TensorProperties ] = None ,
187- materials : Optional [Materials ] = None ,
188- blend_params : Optional [BlendParams ] = None ,
189- ) -> None :
190- super ().__init__ ()
191- self .lights = lights if lights is not None else PointLights (device = device )
192- self .materials = (
193- materials if materials is not None else Materials (device = device )
194- )
195- self .cameras = cameras
196- self .blend_params = blend_params if blend_params is not None else BlendParams ()
197-
198- # pyre-fixme[14]: `to` overrides method defined in `Module` inconsistently.
199- def to (self , device : Device ):
200- # Manually move to device modules which are not subclasses of nn.Module
201- cameras = self .cameras
202- if cameras is not None :
203- self .cameras = cameras .to (device )
204- self .materials = self .materials .to (device )
205- self .lights = self .lights .to (device )
206- return self
207-
208156 def forward (self , fragments : Fragments , meshes : Meshes , ** kwargs ) -> torch .Tensor :
209157 cameras = kwargs .get ("cameras" , self .cameras )
210158 if cameras is None :
@@ -231,7 +179,7 @@ def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tenso
231179 return images
232180
233181
234- class SoftGouraudShader (nn . Module ):
182+ class SoftGouraudShader (ShaderBase ):
235183 """
236184 Per vertex lighting - the lighting model is applied to the vertex colors and
237185 the colors are then interpolated using the barycentric coordinates to
@@ -246,32 +194,6 @@ class SoftGouraudShader(nn.Module):
246194 shader = SoftGouraudShader(device=torch.device("cuda:0"))
247195 """
248196
249- def __init__ (
250- self ,
251- device : Device = "cpu" ,
252- cameras : Optional [TensorProperties ] = None ,
253- lights : Optional [TensorProperties ] = None ,
254- materials : Optional [Materials ] = None ,
255- blend_params : Optional [BlendParams ] = None ,
256- ) -> None :
257- super ().__init__ ()
258- self .lights = lights if lights is not None else PointLights (device = device )
259- self .materials = (
260- materials if materials is not None else Materials (device = device )
261- )
262- self .cameras = cameras
263- self .blend_params = blend_params if blend_params is not None else BlendParams ()
264-
265- # pyre-fixme[14]: `to` overrides method defined in `Module` inconsistently.
266- def to (self , device : Device ):
267- # Manually move to device modules which are not subclasses of nn.Module
268- cameras = self .cameras
269- if cameras is not None :
270- self .cameras = cameras .to (device )
271- self .materials = self .materials .to (device )
272- self .lights = self .lights .to (device )
273- return self
274-
275197 def forward (self , fragments : Fragments , meshes : Meshes , ** kwargs ) -> torch .Tensor :
276198 cameras = kwargs .get ("cameras" , self .cameras )
277199 if cameras is None :
@@ -320,7 +242,7 @@ def TexturedSoftPhongShader(
320242 )
321243
322244
323- class HardFlatShader (nn . Module ):
245+ class HardFlatShader (ShaderBase ):
324246 """
325247 Per face lighting - the lighting model is applied using the average face
326248 position and the face normal. The blending function hard assigns
@@ -334,32 +256,6 @@ class HardFlatShader(nn.Module):
334256 shader = HardFlatShader(device=torch.device("cuda:0"))
335257 """
336258
337- def __init__ (
338- self ,
339- device : Device = "cpu" ,
340- cameras : Optional [TensorProperties ] = None ,
341- lights : Optional [TensorProperties ] = None ,
342- materials : Optional [Materials ] = None ,
343- blend_params : Optional [BlendParams ] = None ,
344- ) -> None :
345- super ().__init__ ()
346- self .lights = lights if lights is not None else PointLights (device = device )
347- self .materials = (
348- materials if materials is not None else Materials (device = device )
349- )
350- self .cameras = cameras
351- self .blend_params = blend_params if blend_params is not None else BlendParams ()
352-
353- # pyre-fixme[14]: `to` overrides method defined in `Module` inconsistently.
354- def to (self , device : Device ):
355- # Manually move to device modules which are not subclasses of nn.Module
356- cameras = self .cameras
357- if cameras is not None :
358- self .cameras = cameras .to (device )
359- self .materials = self .materials .to (device )
360- self .lights = self .lights .to (device )
361- return self
362-
363259 def forward (self , fragments : Fragments , meshes : Meshes , ** kwargs ) -> torch .Tensor :
364260 cameras = kwargs .get ("cameras" , self .cameras )
365261 if cameras is None :
0 commit comments