From 3d013a87bf462f1ec630d4bb438f1e1388aa7583 Mon Sep 17 00:00:00 2001
From: Nerogar
Date: Mon, 23 Dec 2024 17:35:29 +0100
Subject: [PATCH] fix a dtype issue when evaluating the sana transformer with
 a float16 autocast context

---
 src/diffusers/models/attention_processor.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py
index 30e160dd2408..6c28b48f06d6 100644
--- a/src/diffusers/models/attention_processor.py
+++ b/src/diffusers/models/attention_processor.py
@@ -6006,9 +6006,10 @@ def __call__(
 
         query, key, value = query.float(), key.float(), value.float()
 
-        value = F.pad(value, (0, 0, 0, 1), mode="constant", value=1.0)
-        scores = torch.matmul(value, key)
-        hidden_states = torch.matmul(scores, query)
+        with torch.autocast(device_type=hidden_states.device.type, enabled=False):
+            value = F.pad(value, (0, 0, 0, 1), mode="constant", value=1.0)
+            scores = torch.matmul(value, key)
+            hidden_states = torch.matmul(scores, query)
 
         hidden_states = hidden_states[:, :, :-1] / (hidden_states[:, :, -1:] + 1e-15)
         hidden_states = hidden_states.flatten(1, 2).transpose(1, 2)
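
A minimal sketch (assuming a PyTorch version with the device-generic
torch.autocast API, 1.10+) of the behavior this patch works around: inside
an enabled autocast context, torch.matmul is an autocast-eligible op, so it
casts its inputs back down to the autocast dtype even after an explicit
.float() upcast. Locally disabling autocast, as the hunk above does, keeps
the padding and both matmuls in float32:

    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"
    # float16 is the GPU case from the subject line; CPU autocast defaults
    # to bfloat16, but the dtype behavior demonstrated here is the same.
    amp_dtype = torch.float16 if device == "cuda" else torch.bfloat16

    a = torch.randn(8, 8, device=device)
    b = torch.randn(8, 8, device=device)

    with torch.autocast(device_type=device, dtype=amp_dtype):
        # The explicit upcasts are silently undone for the matmul:
        out = torch.matmul(a.float(), b.float())
        print(out.dtype)  # amp_dtype, not torch.float32

        # Locally disabling autocast preserves the float32 operands:
        with torch.autocast(device_type=device, enabled=False):
            out = torch.matmul(a.float(), b.float())
        print(out.dtype)  # torch.float32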