Skip to content

Commit

Permalink
[Feature] Upgrade differentiable renderers (#107)
Browse files Browse the repository at this point in the history
- Refactor differentiable renderer
  
- Refactor visualize_smpl
  
- Body_models
  
- visualize_kp2d
  
- Readme
  
- Cameras

- Mesh_utils

- Meshes
  
- Tests
  • Loading branch information
WenjiaWang0312 authored Mar 22, 2022
1 parent 6fbbf98 commit 96517ae
Show file tree
Hide file tree
Showing 45 changed files with 3,338 additions and 1,385 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,8 @@ logs/
*.png
*.gif
*.jpg
*.obj
*.ply
!demo/resources/*

# Resources as exception
Expand Down
158 changes: 68 additions & 90 deletions configs/render/smpl.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
base_directional_light = {
'type': 'directional',
'direction': [[10.0, 10.0, 10.0]],
'direction': [[1, 1, 1]],
'ambient_color': [[0.5, 0.5, 0.5]],
'diffuse_color': [[0.5, 0.5, 0.5]],
'specular_color': [[0.5, 0.5, 0.5]],
}

base_point_light = {
'type': 'point',
'ambient_color': [[0.5, 0.5, 0.5]],
'ambient_color': [[1, 1, 1]],
'diffuse_color': [[0.3, 0.3, 0.3]],
'specular_color': [[0.5, 0.5, 0.5]],
'location': [[2.0, 2.0, -2.0]],
Expand All @@ -22,7 +22,7 @@
base_material = {
'ambient_color': [[1, 1, 1]],
'diffuse_color': [[0.5, 0.5, 0.5]],
'specular_color': [[0.5, 0.5, 0.5]],
'specular_color': [[0.15, 0.15, 0.15]],
'shininess': 60.0,
}

Expand All @@ -33,139 +33,117 @@
'shininess': 1.0,
}

empty_light = None

empty_material = {}

white_blend_params = {'background_color': (1.0, 1.0, 1.0)}

black_blend_params = {'background_color': (0.0, 0.0, 0.0)}

RENDER_CONFIGS = {
# low quality
'lq': {
'renderer_type': 'base',
'shader_type': 'flat',
'texture_type': 'vertex',
'raster_type': 'mesh',
'light': base_directional_light,
'material': base_material,
'raster_setting': {
'type': 'mesh',
'shader': {
'type': 'hard_flat'
},
'lights': base_directional_light,
'materials': base_material,
'rasterizer': {
'bin_size': 0,
'blur_radius': 0.0,
'faces_per_pixel': 1,
'cull_to_frustum': True,
'cull_backfaces': True,
'perspective_correct': False,
},
'blend': white_blend_params,
'blend_params': white_blend_params,
},
# medium quality
'mq': {
'renderer_type': 'base',
'shader_type': 'gouraud',
'texture_type': 'vertex',
'raster_type': 'mesh',
'light': base_directional_light,
'material': base_material,
'raster_setting': {
'type': 'mesh',
'shader': {
'type': 'soft_gouraud'
},
'lights': base_directional_light,
'materials': base_material,
'rasterizer': {
'bin_size': 0,
'blur_radius': 0.0,
'faces_per_pixel': 1,
'cull_to_frustum': True,
'cull_backfaces': True,
'perspective_correct': False,
},
'blend': white_blend_params,
'blend_params': white_blend_params,
},
# high quality
'hq': {
'renderer_type': 'base',
'shader_type': 'phong',
'texture_type': 'vertex',
'raster_type': 'mesh',
'light': base_directional_light,
'material': base_material,
'raster_setting': {
'type': 'mesh',
'shader': {
'type': 'soft_phong'
},
'lights': base_directional_light,
'materials': base_material,
'rasterizer': {
'bin_size': 0,
'blur_radius': 0.0,
'faces_per_pixel': 1,
'cull_to_frustum': False,
'cull_backfaces': False,
'bin_size': 0,
'perspective_correct': False,
},
'blend': white_blend_params,
'blend_params': white_blend_params,
},
'silhouette': {
'renderer_type': 'silhouette',
'shader_type': 'silhouette',
'texture_type': 'vertex',
'raster_type': 'mesh',
'material': silhouete_material,
'raster_setting': {
'type': 'silhouette',
'lights': None,
'materials': silhouete_material,
'rasterizer': {
'bin_size': 0,
'blur_radius': 0.0,
'faces_per_pixel': 1,
'blur_radius': 2e-5,
'faces_per_pixel': 50,
'perspective_correct': False,
},
'blend': black_blend_params,
'blend_params': black_blend_params,
},
'part_silhouette': {
'renderer_type': 'segmentation',
'shader_type': 'nolight',
'texture_type': 'closest',
'raster_type': 'mesh',
'light': base_directional_light,
'type': 'segmentation',
'material': base_material,
'raster_setting': {
'rasterizer': {
'bin_size': 0,
'blur_radius': 0.0,
'faces_per_pixel': 1,
'cull_to_frustum': False,
'cull_backfaces': True,
'perspective_correct': False,
},
'blend': black_blend_params,
'blend_params': black_blend_params,
},
'depth': {
'renderer_type': 'depth',
'shader_type': 'nolight',
'texture_type': 'vertex',
'raster_type': 'mesh',
'light': empty_light,
'material': empty_material,
'raster_setting': {
'type': 'depth',
'rasterizer': {
'bin_size': 0,
'blur_radius': 0.0,
'faces_per_pixel': 1,
'cull_to_frustum': False,
'cull_backfaces': False,
'perspective_correct': False,
},
'blend': black_blend_params,
'blend_params': black_blend_params,
},
'normal': {
'renderer_type': 'normal',
'shader_type': 'nolight',
'texture_type': 'vertex',
'raster_type': 'mesh',
'light': empty_light,
'material': empty_material,
'raster_setting': {
'type': 'normal',
'rasterizer': {
'bin_size': 0,
'blur_radius': 0.0,
'faces_per_pixel': 1,
'cull_to_frustum': False,
'cull_backfaces': False,
'perspective_correct': False,
},
'blend': white_blend_params,
'blend_params': white_blend_params,
},
'pointcloud': {
'renderer_type': 'pointcloud',
'shader_type': 'nolight',
'raster_type': 'point',
'light': empty_light,
'material': empty_material,
'blend': white_blend_params,
'bg_color': [
1.0,
1.0,
1.0,
0.0,
],
'points_per_pixel': 10,
'radius': 0.003
'type': 'pointcloud',
'compositor': {
'background_color': [
1.0,
1.0,
1.0,
0.0,
],
},
'rasterizer': {
'points_per_pixel': 10,
'radius': 0.003,
'bin_size': None,
'max_points_per_bin': None,
}
}
}
150 changes: 150 additions & 0 deletions docs/render.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
# Render Meshes

## Renderer Initialization

We follow `Pytorch3D` renderer. We initialize the renderer with rasterizer, shader and other settings. Ours is compatible with `Pytorch3D` renderer initializations, but more flexible and functional. E.g., you can initialize a renderer just like `Pytorch3D` by passing the rasterizer and shader modules, or you can pass setting `dicts`, or use default settings.
In `mmhuman3d`, we provide `MeshRenderer`, `DepthRenderer`, `NormalRenderer`, `PointCloudRenderer`, `SegmentationRenderer`, `SilhouetteRenderer` and `UVRenderer`. In these renderers, `UVRenderer` is special and please refer to the last chapter [UVRenderer](#uvrenderer).

All of these renderers could be initialized by MMCV.Registry. It is convenient to store the renderers configs by dicts.

- **comparison between `pytorch3d` and `mmhuman3d`:**
```python
### initialized by Pytorch3D
import torch
from pytorch3d.renderer import MeshRenderer, RasterizationSettings
from pytorch3d.renderer.lighting import PointLights
from pytorch3d.renderer.cameras import FoVPerspectiveCameras

device = torch.device('cuda')
R, T = look_at_view_transform(dist=2.7, elev=0, azim=0)
cameras = FoVPerspectiveCameras(device=device, R=R, T=T)

lights = PointLights(
device=device,
ambient_color=((0.5, 0.5, 0.5), ),
diffuse_color=((0.3, 0.3, 0.3), ),
specular_color=((0.2, 0.2, 0.2), ),
direction=((0, 1, 0), ),
)
raster_settings = RasterizationSettings(
image_size=128,
blur_radius=0.0,
faces_per_pixel=1,
)
renderer = MeshRenderer(
rasterizer=MeshRasterizer(
cameras=cameras, raster_settings=raster_settings),
shader=SoftPhongShader(device=device, cameras=cameras, lights=lights))

### initialized by mmhuman3d
from mmhuman3d.core.visualization.renderer import MeshRenderer
# rasterizer could be passed by nn.Module or dict
rasterizer = dict(
image_size=128,
blur_radius=0.0,
faces_per_pixel=1,
)
# lights could be passed by nn.Module or dict
lights = dict(type='point', ambient_color=((0.5, 0.5, 0.5), ),
diffuse_color=((0.3, 0.3, 0.3), ),
specular_color=((0.2, 0.2, 0.2), ),
direction=((0, 1, 0), ),)

# rasterizer could be passed by cameras or dict
cameras = dict(type='fovperspective', R=R, T=T, device=device)

# shader could be passed by nn.Module or dict
shader = dict(type='SoftPhongShader')
```

These two methods are equal.
```python
import torch.nn as nn
from mmhuman3d.core.visualization.renderer import MeshRenderer, build_renderer

renderer = MeshRenderer(shader=shader, device=device, rasterizer=rasterizer, resolution=resolution)
renderer = build_renderer(dict(type='mesh', device=device, shader=shader, rasterizer=rasterizer, resolution=resolution))

# Use default raster and shader settings
renderer = build_renderer(dict(type='mesh', device=device, resolution=resolution))
assert isinstance(renderer.rasterizer, nn.Module)
assert isinstance(renderer.shader, nn.Module)
```

We provide `tensor2rgba` function for visualization, the returned tensor will be a colorful image for visualization.
This function is different for different renderers. E.g., the rendered tensor of `DepthRenderer` is shape of (N, H, W, 1) of depth, and we will repeat it as a (N, H, W, 4) image tensor. And the rendered tensor of `SegmentationRenderer` is shape of (N, H, W, C) LongTensor, and we will convert it as a (N, H, W, 4) colorful image tensor according to a colormap. The rendered tensor of `NormalRenderer` is a (N, H, W, 4), its range is [-1, 1] and the `tensor2rgba` will normalize it to [0, 1].

The operation is simple:
```python
import torch
from mmhuman3d.core.visualization.renderer import build_renderer

renderer = build_renderer(dict(type='mesh', device=device, resolution=resolution))
rendered_tensor = renderer(meshes=meshes, cameras=cameras, lights=lights)
rendered_rgba = renderer.tensor2rgba(rendered_tensor)
```

Moreover, our renderer could set output settings and provide file I/O operations.
These writed images or videos are converted by the mentioned function `tensor2rgba`.

```python
# will write a video
renderer = build_renderer(dict(type='mesh', device=device, resolution=resolution, output_path='test.mp4'))
backgrounds = torch.Tensor(N, H, W, 3)
rendered_tensor = renderer(meshes=meshes, cameras=cameras, lights=lights, backgrounds=backgrounds)
renderer.export() # needed for a video

# will write a folder of images
renderer = build_renderer(dict(type='mesh', device=device, resolution=resolution, output_path='test_folder', out_img_format='%06d.png'))
backgrounds = torch.Tensor(N, H, W, 3)
rendered_tensor = renderer(meshes=meshes, cameras=cameras, lights=lights, backgrounds=backgrounds)
```


## Use render_runner

You could pass your data by `render_runner` to render a series batch of render. It will use a for loop to render the tensor by batch so you can render a long sequence of video without `CUDA out of memory`.

```python
import torch
from mmhuman3d.core.visualization.renderer import render_runner

render_data = dict(cameras=cameras, lights=lights, meshes=meshes, backgrounds=backgrounds)
# no_grad=True for non-differentiable render
rendered_tensor = render_runner.render(renderer=renderer, output_path=output_path, resolution=resolution, batch_size=batch_size, device=device, no_grad=True, return_tensor=True, **render_data)
```

## UVRenderer

Our `UVRenderer` is different from the above renderers. It is actually a smpl uv topology defined wrapper and sampler. It has two main utilities: wrapping vertex attributes to a map, sampling vertex attributes from a map.

### Initialize
The UV information is stored in the `smpl_uv.npz` file.
```python
uv_renderer = build_renderer(dict(type='uv', resolution=resolution, device=device, model_type='smpl', uv_param_path='data/body_models/smpl/smpl_uv.npz'))
```
### warping
Warp a gray texture image to smpl_mesh.

```python
import torch
from mmhuman3d.models import build_body_model
from pytorch3d.structures import meshes
from mmhuman3d.core.visualization.renderer import build_renderer
body_model = build_body_model(dict(type='smpl', model_path=model_path)).to(device)
pose_dict = body_model.tensor2dict(torch.zeros(1, 72))
verts = body_model(**pose_dict)['vertices']
faces = body_model.faces_tensor[None]
smpl_mesh = Meshes(verts=verts, faces=faces)
texture_image = torch.ones(1, 512, 512, 3) * 0.5
smpl_mesh.textures = uv_renderer.warp_texture(texture_image=texture_image)
```
### sampling
Sample vertex normal from a normal_map.

```python
normal_map = torch.ones(1, 512, 512, 3)
vertex_normals = uv_renderer.vertex_resample(normal_map)
assert vertex_normals.shape == (1 ,6890, 3)
assert (vertex_normals == 1).all()
```
Loading

0 comments on commit 96517ae

Please sign in to comment.