Skip to content

Commit c336f71

Browse files
authored
Merge pull request Fannovel16#76 from Visionatrix/fix/mps-border-unsupported
MPS: Unsupported Border padding mode
2 parents 483dfe6 + 31deca2 commit c336f71

1 file changed

Lines changed: 6 additions & 1 deletion

File tree

vfi_models/rife/rife_arch.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,11 +56,16 @@ def warp(tenInput, tenFlow):
5656
if tenInput.type() == "torch.cuda.HalfTensor":
5757
g = g.half()
5858

59+
padding_mode = "border"
60+
if device.type == "mps":
61+
# https://github.com/pytorch/pytorch/issues/125098
62+
padding_mode = "zeros"
63+
g = g.clamp(-1, 1)
5964
return torch.nn.functional.grid_sample(
6065
input=tenInput,
6166
grid=g,
6267
mode="bilinear",
63-
padding_mode="border",
68+
padding_mode=padding_mode,
6469
align_corners=True,
6570
)
6671

0 commit comments

Comments
 (0)