Implement PBR Maps Node #6720

Open · wants to merge 1 commit into base: main
57 changes: 57 additions & 0 deletions invokeai/app/invocations/pbr_maps.py
@@ -0,0 +1,57 @@
import pathlib
from typing import Literal

from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
from invokeai.app.invocations.fields import ImageField, InputField, OutputField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.image_util.pbr_maps.architecture.pbr_rrdb_net import PBR_RRDB_Net
from invokeai.backend.image_util.pbr_maps.pbr_maps import NORMAL_MAP_MODEL, OTHER_MAP_MODEL, PBRMapsGenerator
from invokeai.backend.util.devices import TorchDevice


@invocation_output("pbr_maps-output")
class PBRMapsOutput(BaseInvocationOutput):
    normal_map: ImageField = OutputField(description="The generated normal map")
    roughness_map: ImageField = OutputField(description="The generated roughness map")
    displacement_map: ImageField = OutputField(description="The generated displacement map")


@invocation("pbr_maps", title="PBR Maps", tags=["image", "material"], category="image", version="1.0.0")
class PBRMapsInvocation(BaseInvocation):
"""Generate Normal, Displacement and Roughness Map from a given image"""

    image: ImageField = InputField(description="The input image")
    tile_size: int = InputField(default=512, description="Tile size used when splitting large images for inference")
border_mode: Literal["none", "seamless", "mirror", "replicate"] = InputField(
default="none", description="Border mode to apply to eliminate any artifacts or seams"
)

def invoke(self, context: InvocationContext) -> PBRMapsOutput:
image_pil = context.images.get_pil(self.image.image_name, mode="RGB")

def loader(model_path: pathlib.Path):
return PBRMapsGenerator.load_model(model_path, TorchDevice.choose_torch_device())

with (
context.models.load_remote_model(NORMAL_MAP_MODEL, loader) as normal_map_model,
context.models.load_remote_model(OTHER_MAP_MODEL, loader) as other_map_model,
):
assert isinstance(normal_map_model, PBR_RRDB_Net)
assert isinstance(other_map_model, PBR_RRDB_Net)
pbr_pipeline = PBRMapsGenerator(normal_map_model, other_map_model, TorchDevice.choose_torch_device())
normal_map, roughness_map, displacement_map = pbr_pipeline.generate_maps(
image_pil, self.tile_size, self.border_mode
)

normal_map = context.images.save(normal_map)
normal_map_field = ImageField(image_name=normal_map.image_name)

roughness_map = context.images.save(roughness_map)
roughness_map_field = ImageField(image_name=roughness_map.image_name)

displacement_map = context.images.save(displacement_map)
        displacement_map_field = ImageField(image_name=displacement_map.image_name)

return PBRMapsOutput(
            normal_map=normal_map_field, roughness_map=roughness_map_field, displacement_map=displacement_map_field
)
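As a quick smoke test, the node validates like any other pydantic model; a minimal sketch, assuming a recent InvokeAI where BaseInvocation supplies a default id (the image name below is a hypothetical placeholder):

from invokeai.app.invocations.fields import ImageField
from invokeai.app.invocations.pbr_maps import PBRMapsInvocation

# Field validation only; actually running invoke() needs a live InvocationContext
# provided by the InvokeAI session runtime.
node = PBRMapsInvocation(
    image=ImageField(image_name="some_texture.png"),  # hypothetical image name
    tile_size=512,
    border_mode="seamless",
)
print(node.border_mode)  # "seamless"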
367 changes: 367 additions & 0 deletions invokeai/backend/image_util/pbr_maps/architecture/block.py
@@ -0,0 +1,367 @@
# Original: https://github.com/joeyballentine/Material-Map-Generator
# Adapted and optimized for InvokeAI

from collections import OrderedDict
from typing import Any, List, Literal, Optional

import torch
import torch.nn as nn

ACTIVATION_LAYER_TYPE = Literal["relu", "leakyrelu", "prelu"]
NORMALIZATION_LAYER_TYPE = Literal["batch", "instance"]
PADDING_LAYER_TYPE = Literal["zero", "reflect", "replicate"]
BLOCK_MODE = Literal["CNA", "NAC", "CNAC"]
UPCONV_BLOCK_MODE = Literal["nearest", "linear", "bilinear", "bicubic", "trilinear"]


def act(act_type: ACTIVATION_LAYER_TYPE, inplace: bool = True, neg_slope: float = 0.2, n_prelu: int = 1):
"""Helper to select Activation Layer"""
if act_type == "relu":
layer = nn.ReLU(inplace)
elif act_type == "leakyrelu":
layer = nn.LeakyReLU(neg_slope, inplace)
elif act_type == "prelu":
layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope)
    else:
        raise NotImplementedError(f"Activation layer [{act_type}] is not implemented")
    return layer


def norm(norm_type: NORMALIZATION_LAYER_TYPE, nc: int):
"""Helper to select Normalization Layer"""
if norm_type == "batch":
layer = nn.BatchNorm2d(nc, affine=True)
elif norm_type == "instance":
layer = nn.InstanceNorm2d(nc, affine=False)
    else:
        raise NotImplementedError(f"Normalization layer [{norm_type}] is not implemented")
    return layer


def pad(pad_type: PADDING_LAYER_TYPE, padding: int):
"""Helper to select Padding Layer"""
if padding == 0 or pad_type == "zero":
return None
if pad_type == "reflect":
layer = nn.ReflectionPad2d(padding)
elif pad_type == "replicate":
layer = nn.ReplicationPad2d(padding)
    else:
        raise NotImplementedError(f"Padding layer [{pad_type}] is not implemented")
    return layer


def get_valid_padding(kernel_size: int, dilation: int):
kernel_size = kernel_size + (kernel_size - 1) * (dilation - 1)
padding = (kernel_size - 1) // 2
return padding
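
# Worked example (illustrative sketch, not part of this diff): with kernel_size=3
# and dilation=2 the effective kernel spans 3 + (3 - 1) * (2 - 1) = 5 pixels, so
# get_valid_padding returns (5 - 1) // 2 = 2 and a stride-1 conv preserves size:
#
#   conv = nn.Conv2d(8, 8, kernel_size=3, dilation=2, padding=get_valid_padding(3, 2))
#   conv(torch.randn(1, 8, 16, 16)).shape  # torch.Size([1, 8, 16, 16])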


def sequential(*args: Any):
# Flatten Sequential. It unwraps nn.Sequential.
if len(args) == 1:
if isinstance(args[0], OrderedDict):
raise NotImplementedError("sequential does not support OrderedDict input.")
return args[0] # No sequential is needed.
modules: List[nn.Module] = []
for module in args:
if isinstance(module, nn.Sequential):
for submodule in module.children():
modules.append(submodule)
elif isinstance(module, nn.Module):
modules.append(module)
return nn.Sequential(*modules)


def conv_block(
in_nc: int,
out_nc: int,
kernel_size: int,
stride: int = 1,
dilation: int = 1,
groups: int = 1,
bias: bool = True,
pad_type: Optional[PADDING_LAYER_TYPE] = "zero",
norm_type: Optional[NORMALIZATION_LAYER_TYPE] = None,
act_type: Optional[ACTIVATION_LAYER_TYPE] = "relu",
mode: BLOCK_MODE = "CNA",
):
"""
Conv layer with padding, normalization, activation
mode: CNA --> Conv -> Norm -> Act
NAC --> Norm -> Act --> Conv (Identity Mappings in Deep Residual Networks, ECCV16)
"""
assert mode in ["CNA", "NAC", "CNAC"], f"Wrong conv mode [{mode}]"
padding = get_valid_padding(kernel_size, dilation)
p = pad(pad_type, padding) if pad_type else None
padding = padding if pad_type == "zero" else 0

c = nn.Conv2d(
in_nc,
out_nc,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
bias=bias,
groups=groups,
)
a = act(act_type) if act_type else None
    match mode:
        case "CNA":
            n = norm(norm_type, out_nc) if norm_type else None
            return sequential(p, c, n, a)
        case "NAC" | "CNAC":
            if norm_type is None and act_type is not None:
                # Avoid an inplace activation here so the residual branch still
                # sees the unmodified input.
                a = act(act_type, inplace=False)
            n = norm(norm_type, in_nc) if norm_type else None
            return sequential(n, a, p, c)
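
# Example (illustrative sketch, not part of this diff): mode only changes the
# layer ordering. conv_block(3, 64, 3, norm_type="batch", act_type="relu",
# mode="CNA") builds Sequential(Conv2d, BatchNorm2d(64), ReLU), while the same
# call with mode="NAC" builds Sequential(BatchNorm2d(3), ReLU, Conv2d), i.e. the
# norm acts on the input channels instead of the output channels.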


class ConcatBlock(nn.Module):
# Concat the output of a submodule to its input
def __init__(self, submodule: nn.Module):
super(ConcatBlock, self).__init__()
self.sub = submodule

def forward(self, x: torch.Tensor):
output = torch.cat((x, self.sub(x)), dim=1)
return output

def __repr__(self):
tmpstr = "Identity .. \n|"
modstr = self.sub.__repr__().replace("\n", "\n|")
tmpstr = tmpstr + modstr
return tmpstr


class ShortcutBlock(nn.Module):
# Elementwise sum the output of a submodule to its input
def __init__(self, submodule: nn.Module):
super(ShortcutBlock, self).__init__()
self.sub = submodule

def forward(self, x: torch.Tensor):
output = x + self.sub(x)
return output

def __repr__(self):
tmpstr = "Identity + \n|"
modstr = self.sub.__repr__().replace("\n", "\n|")
tmpstr = tmpstr + modstr
return tmpstr


class ShortcutBlockSPSR(nn.Module):
# Elementwise sum the output of a submodule to its input
def __init__(self, submodule: nn.Module):
super(ShortcutBlockSPSR, self).__init__()
self.sub = submodule

def forward(self, x: torch.Tensor):
return x, self.sub

def __repr__(self):
tmpstr = "Identity + \n|"
modstr = self.sub.__repr__().replace("\n", "\n|")
tmpstr = tmpstr + modstr
return tmpstr


class ResNetBlock(nn.Module):
"""
ResNet Block, 3-3 style
with extra residual scaling used in EDSR
(Enhanced Deep Residual Networks for Single Image Super-Resolution, CVPRW 17)
"""

def __init__(
self,
in_nc: int,
mid_nc: int,
out_nc: int,
kernel_size: int = 3,
stride: int = 1,
dilation: int = 1,
groups: int = 1,
bias: bool = True,
pad_type: PADDING_LAYER_TYPE = "zero",
norm_type: Optional[NORMALIZATION_LAYER_TYPE] = None,
act_type: Optional[ACTIVATION_LAYER_TYPE] = "relu",
mode: BLOCK_MODE = "CNA",
res_scale: int = 1,
):
super(ResNetBlock, self).__init__()
conv0 = conv_block(
in_nc, mid_nc, kernel_size, stride, dilation, groups, bias, pad_type, norm_type, act_type, mode
)
if mode == "CNA":
act_type = None
if mode == "CNAC": # Residual path: |-CNAC-|
act_type = None
norm_type = None
conv1 = conv_block(
mid_nc, out_nc, kernel_size, stride, dilation, groups, bias, pad_type, norm_type, act_type, mode
)

self.res = sequential(conv0, conv1)
self.res_scale = res_scale

def forward(self, x: torch.Tensor):
res = self.res(x).mul(self.res_scale)
return x + res


class ResidualDenseBlock_5C(nn.Module):
"""
Residual Dense Block
style: 5 convs
The core module of paper: (Residual Dense Network for Image Super-Resolution, CVPR 18)
"""

def __init__(
self,
nc: int,
kernel_size: int = 3,
gc: int = 32,
stride: int = 1,
bias: bool = True,
pad_type: PADDING_LAYER_TYPE = "zero",
norm_type: Optional[NORMALIZATION_LAYER_TYPE] = None,
act_type: ACTIVATION_LAYER_TYPE = "leakyrelu",
mode: BLOCK_MODE = "CNA",
):
super(ResidualDenseBlock_5C, self).__init__()
# gc: growth channel, i.e. intermediate channels
self.conv1 = conv_block(
nc, gc, kernel_size, stride, bias=bias, pad_type=pad_type, norm_type=norm_type, act_type=act_type, mode=mode
)
self.conv2 = conv_block(
nc + gc,
gc,
kernel_size,
stride,
bias=bias,
pad_type=pad_type,
norm_type=norm_type,
act_type=act_type,
mode=mode,
)
self.conv3 = conv_block(
nc + 2 * gc,
gc,
kernel_size,
stride,
bias=bias,
pad_type=pad_type,
norm_type=norm_type,
act_type=act_type,
mode=mode,
)
self.conv4 = conv_block(
nc + 3 * gc,
gc,
kernel_size,
stride,
bias=bias,
pad_type=pad_type,
norm_type=norm_type,
act_type=act_type,
mode=mode,
)
if mode == "CNA":
last_act = None
else:
last_act = act_type
self.conv5 = conv_block(
nc + 4 * gc, nc, 3, stride, bias=bias, pad_type=pad_type, norm_type=norm_type, act_type=last_act, mode=mode
)

def forward(self, x: torch.Tensor):
x1 = self.conv1(x)
x2 = self.conv2(torch.cat((x, x1), 1))
x3 = self.conv3(torch.cat((x, x1, x2), 1))
x4 = self.conv4(torch.cat((x, x1, x2, x3), 1))
x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
return x5.mul(0.2) + x
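
# Shape check (illustrative sketch, not part of this diff): each conv receives
# the block input concatenated with every previous output, so with nc=64, gc=32
# the five convs see 64, 96, 128, 160 and 192 input channels; conv5 projects
# back to nc so the 0.2-scaled residual can be added:
#
#   rdb = ResidualDenseBlock_5C(nc=64, gc=32)
#   rdb(torch.randn(1, 64, 24, 24)).shape  # torch.Size([1, 64, 24, 24])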


class RRDB(nn.Module):
"""
Residual in Residual Dense Block
(ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks)
"""

def __init__(
self,
nc: int,
kernel_size: int = 3,
gc: int = 32,
stride: int = 1,
bias: bool = True,
pad_type: PADDING_LAYER_TYPE = "zero",
norm_type: Optional[NORMALIZATION_LAYER_TYPE] = None,
act_type: ACTIVATION_LAYER_TYPE = "leakyrelu",
mode: BLOCK_MODE = "CNA",
):
super(RRDB, self).__init__()
self.RDB1 = ResidualDenseBlock_5C(nc, kernel_size, gc, stride, bias, pad_type, norm_type, act_type, mode)
self.RDB2 = ResidualDenseBlock_5C(nc, kernel_size, gc, stride, bias, pad_type, norm_type, act_type, mode)
self.RDB3 = ResidualDenseBlock_5C(nc, kernel_size, gc, stride, bias, pad_type, norm_type, act_type, mode)

def forward(self, x: torch.Tensor):
out = self.RDB1(x)
out = self.RDB2(out)
out = self.RDB3(out)
return out.mul(0.2) + x


# Upsampler
def pixelshuffle_block(
in_nc: int,
out_nc: int,
upscale_factor: int = 2,
kernel_size: int = 3,
stride: int = 1,
bias: bool = True,
pad_type: PADDING_LAYER_TYPE = "zero",
norm_type: Optional[NORMALIZATION_LAYER_TYPE] = None,
act_type: ACTIVATION_LAYER_TYPE = "relu",
):
"""
Pixel shuffle layer
(Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional
Neural Network, CVPR17)
"""
conv = conv_block(
in_nc,
out_nc * (upscale_factor**2),
kernel_size,
stride,
bias=bias,
pad_type=pad_type,
norm_type=None,
act_type=None,
)
pixel_shuffle = nn.PixelShuffle(upscale_factor)

n = norm(norm_type, out_nc) if norm_type else None
a = act(act_type) if act_type else None
return sequential(conv, pixel_shuffle, n, a)
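
# Example (illustrative sketch, not part of this diff): the conv widens to
# out_nc * upscale_factor**2 channels and nn.PixelShuffle rearranges them into
# spatial resolution, so a factor-2 block doubles height and width:
#
#   up = pixelshuffle_block(64, 64, upscale_factor=2)
#   up(torch.randn(1, 64, 16, 16)).shape  # torch.Size([1, 64, 32, 32])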


def upconv_block(
in_nc: int,
out_nc: int,
upscale_factor: int = 2,
kernel_size: int = 3,
stride: int = 1,
bias: bool = True,
pad_type: PADDING_LAYER_TYPE = "zero",
norm_type: Optional[NORMALIZATION_LAYER_TYPE] = None,
act_type: ACTIVATION_LAYER_TYPE = "relu",
mode: UPCONV_BLOCK_MODE = "nearest",
):
    # Upsample-then-conv avoids checkerboard artifacts; see https://distill.pub/2016/deconv-checkerboard/
upsample = nn.Upsample(scale_factor=upscale_factor, mode=mode)
conv = conv_block(
in_nc, out_nc, kernel_size, stride, bias=bias, pad_type=pad_type, norm_type=norm_type, act_type=act_type
)
return sequential(upsample, conv)
70 changes: 70 additions & 0 deletions invokeai/backend/image_util/pbr_maps/architecture/pbr_rrdb_net.py
@@ -0,0 +1,70 @@
# Original: https://github.com/joeyballentine/Material-Map-Generator
# Adapted and optimized for InvokeAI

import math
from typing import Literal, Optional

import torch
import torch.nn as nn

import invokeai.backend.image_util.pbr_maps.architecture.block as B

UPSCALE_MODE = Literal["upconv", "pixelshuffle"]


class PBR_RRDB_Net(nn.Module):
def __init__(
self,
in_nc: int,
out_nc: int,
nf: int,
nb: int,
gc: int = 32,
upscale: int = 4,
norm_type: Optional[B.NORMALIZATION_LAYER_TYPE] = None,
act_type: B.ACTIVATION_LAYER_TYPE = "leakyrelu",
mode: B.BLOCK_MODE = "CNA",
res_scale: int = 1,
upsample_mode: UPSCALE_MODE = "upconv",
):
super(PBR_RRDB_Net, self).__init__()
n_upscale = int(math.log(upscale, 2))
if upscale == 3:
n_upscale = 1

fea_conv = B.conv_block(in_nc, nf, kernel_size=3, norm_type=None, act_type=None)
rb_blocks = [
B.RRDB(
nf,
kernel_size=3,
gc=32,
stride=1,
bias=True,
pad_type="zero",
norm_type=norm_type,
act_type=act_type,
mode="CNA",
)
for _ in range(nb)
]
LR_conv = B.conv_block(nf, nf, kernel_size=3, norm_type=norm_type, act_type=None, mode=mode)

        if upsample_mode == "upconv":
            upsample_block = B.upconv_block
        elif upsample_mode == "pixelshuffle":
            upsample_block = B.pixelshuffle_block
        else:
            raise NotImplementedError(f"Upsample mode [{upsample_mode}] is not implemented")

if upscale == 3:
upsampler = upsample_block(nf, nf, 3, act_type=act_type)
else:
upsampler = [upsample_block(nf, nf, act_type=act_type) for _ in range(n_upscale)]

HR_conv0 = B.conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=act_type)
HR_conv1 = B.conv_block(nf, out_nc, kernel_size=3, norm_type=None, act_type=None)

self.model = B.sequential(
fea_conv, B.ShortcutBlock(B.sequential(*rb_blocks, LR_conv)), *upsampler, HR_conv0, HR_conv1
)

def forward(self, x: torch.Tensor):
return self.model(x)
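For orientation, the checkpoints this PR loads are 1x (no upscaling) networks; a minimal sketch with random weights, using the same hyperparameters as PBRMapsGenerator.load_model, confirms the model maps an RGB tensor to an RGB tensor of the same size:

import torch

from invokeai.backend.image_util.pbr_maps.architecture.pbr_rrdb_net import PBR_RRDB_Net

# 3 -> 3 channels, nf=32 features, nb=12 RRDB blocks, upscale=1 (n_upscale is 0,
# so no upsampler blocks are inserted).
model = PBR_RRDB_Net(3, 3, 32, 12, gc=32, upscale=1, upsample_mode="upconv").eval()
with torch.no_grad():
    out = model(torch.randn(1, 3, 64, 64))
print(out.shape)  # torch.Size([1, 3, 64, 64])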
104 changes: 104 additions & 0 deletions invokeai/backend/image_util/pbr_maps/pbr_maps.py
@@ -0,0 +1,104 @@
# Original: https://github.com/joeyballentine/Material-Map-Generator
# Adapted and optimized for InvokeAI

import pathlib
from typing import Any, Literal

import cv2
import numpy as np
import numpy.typing as npt
import torch
from PIL import Image

from invokeai.backend.image_util.pbr_maps.architecture.pbr_rrdb_net import PBR_RRDB_Net
from invokeai.backend.image_util.pbr_maps.utils.image_ops import crop_seamless, esrgan_launcher_split_merge

NORMAL_MAP_MODEL = "https://github.com/joeyballentine/Material-Map-Generator/blob/master/utils/models/1x_NormalMapGenerator-CX-Lite_200000_G.pth"
OTHER_MAP_MODEL = "https://github.com/joeyballentine/Material-Map-Generator/blob/master/utils/models/1x_FrankenMapGenerator-CX-Lite_215000_G.pth"


class PBRMapsGenerator:
def __init__(self, normal_map_model: PBR_RRDB_Net, other_map_model: PBR_RRDB_Net, device: torch.device) -> None:
self.normal_map_model = normal_map_model
self.other_map_model = other_map_model
self.device = device

@staticmethod
def load_model(model_path: pathlib.Path, device: torch.device) -> PBR_RRDB_Net:
state_dict = torch.load(model_path.as_posix(), map_location="cpu")

model = PBR_RRDB_Net(
3,
3,
32,
12,
gc=32,
upscale=1,
norm_type=None,
act_type="leakyrelu",
mode="CNA",
res_scale=1,
upsample_mode="upconv",
)

model.load_state_dict(state_dict, strict=False)
del state_dict
model.eval()

for _, v in model.named_parameters():
v.requires_grad = False

return model.to(device)

    def process(self, img: npt.NDArray[Any], model: PBR_RRDB_Net):
        # Scale to [0, 1] using the full range of the input integer dtype.
        img = img.astype(np.float32) / np.iinfo(img.dtype).max
        # RGB -> BGR, matching the cv2 channel order the upstream checkpoints expect.
        img = img[..., ::-1].copy()
        # HWC -> NCHW tensor on the target device.
        tensor_img = torch.tensor(img).permute(2, 0, 1).unsqueeze(0).to(self.device)

        with torch.no_grad():
            output = model(tensor_img).data.squeeze(0).float().cpu().clamp_(0, 1).numpy()
        # BGR -> RGB, CHW -> HWC, then back to the 8-bit range.
        output = output[[2, 1, 0], :, :]
        output = np.transpose(output, (1, 2, 0))
        output = (output * 255.0).round()
        return output

def _cv2_to_pil(self, image: npt.NDArray[Any]):
return Image.fromarray(cv2.cvtColor(image.astype(np.uint8), cv2.COLOR_RGB2BGR))

def generate_maps(
self,
image: Image.Image,
tile_size: int = 512,
border_mode: Literal["none", "seamless", "mirror", "replicate"] = "none",
):
models = [self.normal_map_model, self.other_map_model]
np_image = np.array(image).astype(np.uint8)

match border_mode:
case "seamless":
np_image = cv2.copyMakeBorder(np_image, 16, 16, 16, 16, cv2.BORDER_WRAP)
case "mirror":
np_image = cv2.copyMakeBorder(np_image, 16, 16, 16, 16, cv2.BORDER_REFLECT_101)
case "replicate":
np_image = cv2.copyMakeBorder(np_image, 16, 16, 16, 16, cv2.BORDER_REPLICATE)
case "none":
pass

img_height, img_width = np_image.shape[:2]

        # Run tiled inference when the image is larger than the tile size.
do_split = img_height > tile_size or img_width > tile_size

if do_split:
rlts = esrgan_launcher_split_merge(np_image, self.process, models, scale_factor=1, tile_size=tile_size)
else:
rlts = [self.process(np_image, model) for model in models]

if border_mode != "none":
rlts = [crop_seamless(rlt) for rlt in rlts]

        # The "franken" checkpoint packs both maps into one image:
        # channel 0 is displacement, channel 1 is roughness.
        normal_map = self._cv2_to_pil(rlts[0])
        roughness = self._cv2_to_pil(rlts[1][:, :, 1])
        displacement = self._cv2_to_pil(rlts[1][:, :, 0])

return normal_map, roughness, displacement
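A hedged end-to-end sketch of driving the generator directly, outside the node wrapper, assuming the two checkpoints referenced above have already been downloaded (local paths and the input file are hypothetical):

import pathlib

import torch
from PIL import Image

from invokeai.backend.image_util.pbr_maps.pbr_maps import PBRMapsGenerator

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Hypothetical local copies of NORMAL_MAP_MODEL and OTHER_MAP_MODEL.
normal_model = PBRMapsGenerator.load_model(pathlib.Path("1x_NormalMapGenerator-CX-Lite_200000_G.pth"), device)
other_model = PBRMapsGenerator.load_model(pathlib.Path("1x_FrankenMapGenerator-CX-Lite_215000_G.pth"), device)

pipeline = PBRMapsGenerator(normal_model, other_model, device)
normal, roughness, displacement = pipeline.generate_maps(
    Image.open("texture.png").convert("RGB"), tile_size=512, border_mode="seamless"
)
normal.save("texture_normal.png")
roughness.save("texture_roughness.png")
displacement.save("texture_displacement.png")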
93 changes: 93 additions & 0 deletions invokeai/backend/image_util/pbr_maps/utils/image_ops.py
@@ -0,0 +1,93 @@
# Original: https://github.com/joeyballentine/Material-Map-Generator
# Adapted and optimized for InvokeAI

import math
from typing import Any, Callable, List

import numpy as np
import numpy.typing as npt

from invokeai.backend.image_util.pbr_maps.architecture.pbr_rrdb_net import PBR_RRDB_Net


def crop_seamless(img: npt.NDArray[Any]):
img_height, img_width = img.shape[:2]
y, x = 16, 16
h, w = img_height - 32, img_width - 32
img = img[y : y + h, x : x + w]
return img


# from https://github.com/ata4/esrgan-launcher/blob/master/upscale.py
def esrgan_launcher_split_merge(
input_image: npt.NDArray[Any],
upscale_function: Callable[[npt.NDArray[Any], PBR_RRDB_Net], npt.NDArray[Any]],
models: List[PBR_RRDB_Net],
scale_factor: int = 4,
tile_size: int = 512,
tile_padding: float = 0.125,
):
    # NOTE: numpy shape order is (rows, cols, depth); axis 0 is named "width" here,
    # but the tiling math is symmetric, so the naming is only cosmetic.
    width, height, depth = input_image.shape
output_width = width * scale_factor
output_height = height * scale_factor
output_shape = (output_width, output_height, depth)

# start with black image
output_images = [np.zeros(output_shape, np.uint8) for _ in range(len(models))]

tile_padding = math.ceil(tile_size * tile_padding)
tile_size = math.ceil(tile_size / scale_factor)

tiles_x = math.ceil(width / tile_size)
tiles_y = math.ceil(height / tile_size)

for y in range(tiles_y):
for x in range(tiles_x):
# extract tile from input image
ofs_x = x * tile_size
ofs_y = y * tile_size

# input tile area on total image
input_start_x = ofs_x
input_end_x = min(ofs_x + tile_size, width)

input_start_y = ofs_y
input_end_y = min(ofs_y + tile_size, height)

# input tile area on total image with padding
input_start_x_pad = max(input_start_x - tile_padding, 0)
input_end_x_pad = min(input_end_x + tile_padding, width)

input_start_y_pad = max(input_start_y - tile_padding, 0)
input_end_y_pad = min(input_end_y + tile_padding, height)

# input tile dimensions
input_tile_width = input_end_x - input_start_x
input_tile_height = input_end_y - input_start_y

input_tile = input_image[input_start_x_pad:input_end_x_pad, input_start_y_pad:input_end_y_pad]

for idx, model in enumerate(models):
# upscale tile
output_tile = upscale_function(input_tile, model)

# output tile area on total image
output_start_x = input_start_x * scale_factor
output_end_x = input_end_x * scale_factor

output_start_y = input_start_y * scale_factor
output_end_y = input_end_y * scale_factor

# output tile area without padding
output_start_x_tile = (input_start_x - input_start_x_pad) * scale_factor
output_end_x_tile = output_start_x_tile + input_tile_width * scale_factor

output_start_y_tile = (input_start_y - input_start_y_pad) * scale_factor
output_end_y_tile = output_start_y_tile + input_tile_height * scale_factor

# put tile into output image
output_images[idx][output_start_x:output_end_x, output_start_y:output_end_y] = output_tile[
output_start_x_tile:output_end_x_tile, output_start_y_tile:output_end_y_tile
]

return output_images
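The tile geometry is easiest to see with concrete numbers; a small sketch of the bookkeeping at this PR's defaults (scale_factor=1, tile_size=512, tile_padding=0.125) for a hypothetical 1600x1200 input:

import math

tile_size, tile_padding, scale_factor = 512, 0.125, 1
pad = math.ceil(tile_size * tile_padding)        # 64 px of context around each tile
tile_size = math.ceil(tile_size / scale_factor)  # 512 at scale 1
print(math.ceil(1600 / tile_size))               # 4 tiles along axis 0
print(math.ceil(1200 / tile_size))               # 3 tiles along axis 1
# Each tile is inferred with up to 64 px of surrounding context, and only the
# unpadded core is written back, which hides seams at tile boundaries.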