Skip to content

Commit

Permalink
Improve MaskOutput dimension consistency (#7591)
Browse files Browse the repository at this point in the history
## Summary

This PR fixes an issue with mask dimension consistency. Prior to this
change, the following workflow would fail with `tuple out of range`
error:

<img width="1072" alt="image"
src="https://github.com/user-attachments/assets/d0a9e658-1d64-4db4-adee-973bbdaca745"
/>

### Before this PR

Dimension compatibility for invocations that take a mask input:
- `ApplyMaskTensorToImageInvocation`: 2 or 3
- `MaskTensorToImageInvocation`: 2 or 3
- `InvertTensorMaskInvocation`: 3

Mask dimension for invocations that produce a MaskOutput:
- `RectangleMaskInvocation`: 3
- `AlphaMaskToTensorInvocation`: 3
- `InvertTensorMaskInvocation`: 3
- `ImageMaskToTensorInvocation`: 3
- `SegmentAnythingInvocation`: 2

### After this PR (changes in bold)

Dimension compatibility for invocations that take a mask input:
- `ApplyMaskTensorToImageInvocation`: 2 or 3
- `MaskTensorToImageInvocation`: 2 or 3
- `InvertTensorMaskInvocation`: **2 or 3** <----------------

Mask dimension for invocations that produce a MaskOutput:
- `RectangleMaskInvocation`: 3
- `AlphaMaskToTensorInvocation`: 3
- `InvertTensorMaskInvocation`: 3
- `ImageMaskToTensorInvocation`: 3
- `SegmentAnythingInvocation`: **3** <-------------------


## QA Instructions

I tested the workflow in the PR description and this workflow:
<img width="872" alt="image"
src="https://github.com/user-attachments/assets/20496860-ce81-47c0-a46a-a611b73faa22"
/>


## Merge Plan

No special instructions.

## Checklist

- [x] _The PR has a short but descriptive title, suitable for a
changelog_
- [x] _Tests added / updated (if applicable)_
- [x] _Documentation added / updated (if applicable)_
- [ ] _Updated `What's New` copy (if doing a release after this PR)_
  • Loading branch information
RyanJDick authored Jan 28, 2025
2 parents 6efd108 + 80c3d8b commit 9d2f8b4
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 8 deletions.
11 changes: 10 additions & 1 deletion invokeai/app/invocations/mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def invoke(self, context: InvocationContext) -> MaskOutput:
title="Invert Tensor Mask",
tags=["conditioning"],
category="conditioning",
version="1.0.0",
version="1.1.0",
classification=Classification.Beta,
)
class InvertTensorMaskInvocation(BaseInvocation):
Expand All @@ -96,6 +96,15 @@ class InvertTensorMaskInvocation(BaseInvocation):

def invoke(self, context: InvocationContext) -> MaskOutput:
mask = context.tensors.load(self.mask.tensor_name)

# Verify dtype and shape.
assert mask.dtype == torch.bool
assert mask.dim() in [2, 3]

# Unsqueeze the channel dimension if it is missing. The MaskOutput type expects a single channel.
if mask.dim() == 2:
mask = mask.unsqueeze(0)

inverted = ~mask

return MaskOutput(
Expand Down
1 change: 1 addition & 0 deletions invokeai/app/invocations/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,7 @@ def invoke(self, context: InvocationContext) -> ColorOutput:
class MaskOutput(BaseInvocationOutput):
"""A torch mask tensor."""

# shape: [1, H, W], dtype: bool
mask: TensorField = OutputField(description="The mask.")
width: int = OutputField(description="The width of the mask in pixels.")
height: int = OutputField(description="The height of the mask in pixels.")
Expand Down
6 changes: 4 additions & 2 deletions invokeai/app/invocations/segment_anything.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def to_list(self) -> list[list[int]]:
title="Segment Anything",
tags=["prompt", "segmentation"],
category="segmentation",
version="1.1.0",
version="1.2.0",
)
class SegmentAnythingInvocation(BaseInvocation):
"""Runs a Segment Anything Model."""
Expand Down Expand Up @@ -96,8 +96,10 @@ def invoke(self, context: InvocationContext) -> MaskOutput:
# masks contains bool values, so we merge them via max-reduce.
combined_mask, _ = torch.stack(masks).max(dim=0)

# Unsqueeze the channel dimension.
combined_mask = combined_mask.unsqueeze(0)
mask_tensor_name = context.tensors.save(combined_mask)
height, width = combined_mask.shape
_, height, width = combined_mask.shape
return MaskOutput(mask=TensorField(tensor_name=mask_tensor_name), width=width, height=height)

@staticmethod
Expand Down
5 changes: 0 additions & 5 deletions invokeai/frontend/web/src/services/api/schema.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7716,11 +7716,6 @@ export type components = {
* @description Gets the bounding box of the given mask image.
*/
GetMaskBoundingBoxInvocation: {
/**
* @description Optional metadata to be saved with the image
* @default null
*/
metadata?: components["schemas"]["MetadataField"] | null;
/**
* Id
* @description The id of this instance of an invocation. Must be unique among all instances of invocations.
Expand Down

0 comments on commit 9d2f8b4

Please sign in to comment.