diff --git a/diffsynth/models/flux2_dit.py b/diffsynth/models/flux2_dit.py index 316cf08..a1bd02a 100644 --- a/diffsynth/models/flux2_dit.py +++ b/diffsynth/models/flux2_dit.py @@ -407,6 +407,7 @@ class Flux2AttnProcessor: query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1) key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1) + query, key, value = query.to(hidden_states.dtype), key.to(hidden_states.dtype), value.to(hidden_states.dtype) hidden_states = attention_forward( query, key, @@ -536,6 +537,7 @@ class Flux2ParallelSelfAttnProcessor: query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1) key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1) + query, key, value = query.to(hidden_states.dtype), key.to(hidden_states.dtype), value.to(hidden_states.dtype) hidden_states = attention_forward( query, key,