Skip to content

Commit 1388f85

Browse files
blessedcoolanthipsterusername
authored andcommitted
initial: Implement PBR Maps node
A node to generate normal, roughness and displacement maps from a single image. Based on: https://github.com/joeyballentine/Material-Map-Generator
1 parent 8ecf728 commit 1388f85

File tree

5 files changed

+691
-0
lines changed

5 files changed

+691
-0
lines changed

invokeai/app/invocations/pbr_maps.py

+57
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
import pathlib
2+
from typing import Literal
3+
4+
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
5+
from invokeai.app.invocations.fields import ImageField, InputField, OutputField
6+
from invokeai.app.services.shared.invocation_context import InvocationContext
7+
from invokeai.backend.image_util.pbr_maps.architecture.pbr_rrdb_net import PBR_RRDB_Net
8+
from invokeai.backend.image_util.pbr_maps.pbr_maps import NORMAL_MAP_MODEL, OTHER_MAP_MODEL, PBRMapsGenerator
9+
from invokeai.backend.util.devices import TorchDevice
10+
11+
12+
@invocation_output("pbr_maps-output")
13+
class PBRMapsOutput(BaseInvocationOutput):
14+
normal_map: ImageField = OutputField(default=None, description="The generated normal map")
15+
roughness_map: ImageField = OutputField(default=None, description="The generated roughness map")
16+
displacement_map: ImageField = OutputField(default=None, description="The generated displacement map")
17+
18+
19+
@invocation("pbr_maps", title="PBR Maps", tags=["image", "material"], category="image", version="1.0.0")
20+
class PBRMapsInvocation(BaseInvocation):
21+
"""Generate Normal, Displacement and Roughness Map from a given image"""
22+
23+
image: ImageField = InputField(default=None, description="Input image")
24+
tile_size: int = InputField(default=512, description="Tile size")
25+
border_mode: Literal["none", "seamless", "mirror", "replicate"] = InputField(
26+
default="none", description="Border mode to apply to eliminate any artifacts or seams"
27+
)
28+
29+
def invoke(self, context: InvocationContext) -> PBRMapsOutput:
30+
image_pil = context.images.get_pil(self.image.image_name, mode="RGB")
31+
32+
def loader(model_path: pathlib.Path):
33+
return PBRMapsGenerator.load_model(model_path, TorchDevice.choose_torch_device())
34+
35+
with (
36+
context.models.load_remote_model(NORMAL_MAP_MODEL, loader) as normal_map_model,
37+
context.models.load_remote_model(OTHER_MAP_MODEL, loader) as other_map_model,
38+
):
39+
assert isinstance(normal_map_model, PBR_RRDB_Net)
40+
assert isinstance(other_map_model, PBR_RRDB_Net)
41+
pbr_pipeline = PBRMapsGenerator(normal_map_model, other_map_model, TorchDevice.choose_torch_device())
42+
normal_map, roughness_map, displacement_map = pbr_pipeline.generate_maps(
43+
image_pil, self.tile_size, self.border_mode
44+
)
45+
46+
normal_map = context.images.save(normal_map)
47+
normal_map_field = ImageField(image_name=normal_map.image_name)
48+
49+
roughness_map = context.images.save(roughness_map)
50+
roughness_map_field = ImageField(image_name=roughness_map.image_name)
51+
52+
displacement_map = context.images.save(displacement_map)
53+
displacement_map_map_field = ImageField(image_name=displacement_map.image_name)
54+
55+
return PBRMapsOutput(
56+
normal_map=normal_map_field, roughness_map=roughness_map_field, displacement_map=displacement_map_map_field
57+
)

0 commit comments

Comments
 (0)