mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-22 00:38:11 +00:00
train
This commit is contained in:
@@ -680,6 +680,7 @@ def lets_dance_flux(
|
||||
step1x_mask=None,
|
||||
step1x_reference_latents=None,
|
||||
tea_cache: TeaCache = None,
|
||||
use_gradient_checkpointing=False,
|
||||
**kwargs
|
||||
):
|
||||
if tiled:
|
||||
@@ -774,20 +775,32 @@ def lets_dance_flux(
|
||||
tea_cache_update = tea_cache.check(dit, hidden_states, conditioning)
|
||||
else:
|
||||
tea_cache_update = False
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs)
|
||||
return custom_forward
|
||||
|
||||
if tea_cache_update:
|
||||
hidden_states = tea_cache.update(hidden_states)
|
||||
else:
|
||||
# Joint Blocks
|
||||
for block_id, block in enumerate(dit.blocks):
|
||||
hidden_states, prompt_emb = block(
|
||||
hidden_states,
|
||||
prompt_emb,
|
||||
conditioning,
|
||||
image_rotary_emb,
|
||||
attention_mask,
|
||||
ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id, None)
|
||||
)
|
||||
if use_gradient_checkpointing:
|
||||
hidden_states, prompt_emb = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask, ipadapter_kwargs_list.get(block_id, None),
|
||||
use_reentrant=False,
|
||||
)
|
||||
else:
|
||||
hidden_states, prompt_emb = block(
|
||||
hidden_states,
|
||||
prompt_emb,
|
||||
conditioning,
|
||||
image_rotary_emb,
|
||||
attention_mask,
|
||||
ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id, None)
|
||||
)
|
||||
# ControlNet
|
||||
if controlnet is not None and controlnet_frames is not None:
|
||||
hidden_states = hidden_states + controlnet_res_stack[block_id]
|
||||
@@ -796,14 +809,21 @@ def lets_dance_flux(
|
||||
hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)
|
||||
num_joint_blocks = len(dit.blocks)
|
||||
for block_id, block in enumerate(dit.single_blocks):
|
||||
hidden_states, prompt_emb = block(
|
||||
hidden_states,
|
||||
prompt_emb,
|
||||
conditioning,
|
||||
image_rotary_emb,
|
||||
attention_mask,
|
||||
ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id + num_joint_blocks, None)
|
||||
)
|
||||
if use_gradient_checkpointing:
|
||||
hidden_states, prompt_emb = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask, ipadapter_kwargs_list.get(block_id + num_joint_blocks, None),
|
||||
use_reentrant=False,
|
||||
)
|
||||
else:
|
||||
hidden_states, prompt_emb = block(
|
||||
hidden_states,
|
||||
prompt_emb,
|
||||
conditioning,
|
||||
image_rotary_emb,
|
||||
attention_mask,
|
||||
ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id + num_joint_blocks, None)
|
||||
)
|
||||
# ControlNet
|
||||
if controlnet is not None and controlnet_frames is not None:
|
||||
hidden_states[:, prompt_emb.shape[1]:] = hidden_states[:, prompt_emb.shape[1]:] + controlnet_single_res_stack[block_id]
|
||||
|
||||
Reference in New Issue
Block a user