fix base model

This commit is contained in:
Artiprocher
2025-04-14 13:11:45 +08:00
parent 7a06a58f49
commit 2a5355b7cb
4 changed files with 47 additions and 9 deletions

View File

@@ -709,6 +709,7 @@ def lets_dance_flux(
hidden_states_ref = dit.patchify(hidden_states_ref)
hidden_states_ref = dit.x_embedder(hidden_states_ref)
hidden_states_ref = rearrange(hidden_states_ref, "B L C -> 1 (B L) C")
hidden_states_ref = reference_embedder.proj(hidden_states_ref)
hidden_states = torch.cat((hidden_states, hidden_states_ref), dim=1)
# TeaCache