Add dithering augmentation #1545
def dither(img: np.ndarray, nc: int) -> np.ndarray:
    img = img.copy()
    height = np.shape(img)[0]
    is_rgb = True if len(np.shape(img)) == 3 else False
It would be better to use the is_rgb_image function from albumentations/augmentations/utils.
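To illustrate why a dedicated check can matter, here is a minimal sketch. The function `looks_rgb` is a hypothetical stand-in written for this example, not the library's is_rgb_image; it assumes "RGB" means three dimensions with exactly three channels in the last one.

```python
import numpy as np


def looks_rgb(img: np.ndarray) -> bool:
    """Hypothetical stand-in: three dimensions and exactly three channels."""
    return img.ndim == 3 and img.shape[-1] == 3


rgb = np.zeros((4, 4, 3))
single_channel = np.zeros((4, 4, 1))

# The check in the PR only counts dimensions, so it accepts both arrays.
pr_check = [len(np.shape(a)) == 3 for a in (rgb, single_channel)]
# The stricter check also looks at the size of the channel axis.
strict_check = [looks_rgb(a) for a in (rgb, single_channel)]
```

Under that assumption, an (H, W, 1) array is treated as RGB by the dimension count alone but not by the stricter check.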
        always_apply: bool = False,
        p: float = 0.5,
    ):
        super().__init__(always_apply, p)
nc must be positive; we should have a check here.
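A minimal sketch of such a check, assuming values below 2 should be rejected: the quantization step in this PR divides by nc - 1, so nc = 1 would divide by zero. The helper name `validate_nc` is hypothetical and is not part of the PR or the library.

```python
def validate_nc(nc: int) -> int:
    """Hypothetical helper: reject palette sizes the dithering formula cannot handle."""
    # bool is a subclass of int, so exclude it explicitly.
    if isinstance(nc, bool) or not isinstance(nc, int):
        raise TypeError(f"nc must be an int, got {type(nc).__name__}")
    # round(v * (nc - 1)) / (nc - 1) divides by zero when nc == 1.
    if nc < 2:
        raise ValueError(f"nc must be at least 2, got {nc}")
    return nc
```

Something like this could be called at the top of `__init__` before the value is stored.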
    def __init__(
        self,
        nc: int = 2,
nc is not a very intuitive name; num_colors would be better.
@preserve_shape
def dither(img: np.ndarray, nc: int) -> np.ndarray:
    img = img.copy()
    height = np.shape(img)[0]
I think it is better to use the array attribute img.shape[0]; it is more readable.
#
# Use `tolist()` since operating on individual elements of an ndarray
# is very slow compared to a normal list.
channels = np.transpose(oldrow).tolist()
Why do we need to convert to a list? Why not keep an np.ndarray, i.e. channels = oldrow.transpose()? The result would be the same, but faster.
img[y], quant_errors = _apply_dithering_to_channel(img[y].tolist(), nc)

if y < height - 1:
    zero_or_zeros = 0 if np.shape(quant_errors[-1]) == () else np.zeros_like(quant_errors[-1])
Just use np.zeros_like(quant_errors[-1]); this if/else is useless.
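np.zeros_like accepts scalars as well as sequences, which is why the branch looks unnecessary. A quick check, assuming quant_errors[-1] is either a Python float (grayscale) or a list of per-channel floats (RGB), as the tolist() calls in the PR suggest:

```python
import numpy as np

# Grayscale case: the last quantization error is a plain scalar.
scalar_pad = np.zeros_like(0.25)
# RGB case: the last quantization error holds one value per channel.
vector_pad = np.zeros_like([0.25, -0.1, 0.0])

print(scalar_pad.shape, vector_pad.shape)
```

The scalar case yields a zero-dimensional array rather than the int 0, which behaves the same way in the arithmetic that follows.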
    is_rgb = True if len(np.shape(img)) == 3 else False

    for y in range(height):
        oldrow = img[y].copy()
We don't need a copy here.
@clipped
@preserve_shape
def dither(img: np.ndarray, nc: int) -> np.ndarray:
    img = img.copy()
It looks better to use np.empty_like:
result = np.empty_like(img)
result[0] = img[0]
It is much faster.
    return img


def _apply_dithering_to_channel(ch, nc):
Please add type annotations.
for x in range(width - 1):
    oldval = ch[x]
    newval = round(oldval * (nc - 1)) / (nc - 1)
    ch[x] = newval
    quant_error[x] = oldval - newval
    ch[x + 1] += quant_error[x] * (7 / 16)
Let's vectorize this:
new_val = np.round(ch[:-1] * (nc - 1)) * (1 / (nc - 1))
quant_error = ch[:-1] - new_val
ch[:-1] = new_val
ch[1:] += quant_error * (7 / 16)
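One caveat that may be worth checking before vectorizing: in the loop under review, each pixel is quantized after it has received the error from its left neighbour, so the iterations depend on each other. The sketch below compares that loop with a one-shot array version on a tiny row. The function names are hypothetical, and only the 7/16 right-neighbour term from the snippet is modelled.

```python
import numpy as np


def diffuse_sequential(ch, nc):
    """Row-wise error diffusion as in the loop under review (right neighbour only)."""
    ch = list(ch)
    for x in range(len(ch) - 1):
        oldval = ch[x]
        newval = round(oldval * (nc - 1)) / (nc - 1)
        ch[x] = newval
        # The next pixel receives this error before it is quantized itself.
        ch[x + 1] += (oldval - newval) * (7 / 16)
    return ch


def diffuse_one_shot(ch, nc):
    """Quantize every pixel from its original value, then spread the errors once."""
    ch = np.asarray(ch, dtype=float).copy()
    new_val = np.round(ch[:-1] * (nc - 1)) / (nc - 1)
    quant_error = ch[:-1] - new_val
    ch[:-1] = new_val
    ch[1:] += quant_error * (7 / 16)
    return ch.tolist()


row = [0.4, 0.4, 0.4]
sequential = diffuse_sequential(row, nc=2)
one_shot = diffuse_one_shot(row, nc=2)
print(sequential)
print(one_shot)
```

On this input the two versions disagree: the sequential loop quantizes the second pixel to 1.0 because it has already absorbed the first pixel's error, while the one-shot version quantizes it from its original value and then adds the error on top, leaving it unquantized. If the sequential behaviour is what the augmentation intends, the inner loop may not vectorize this directly.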
No description provided.