mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-20 15:48:20 +00:00
support animatediff on sdxl
This commit is contained in:
@@ -1,7 +1,6 @@
|
||||
import torch
|
||||
from ..models import SDUNet, SDMotionModel
|
||||
from ..models.sd_unet import PushBlock, PopBlock, ResnetBlock, AttentionBlock
|
||||
from ..models.tiler import TileWorker
|
||||
from ..models import SDUNet, SDMotionModel, SDXLUNet, SDXLMotionModel
|
||||
from ..models.sd_unet import PushBlock, PopBlock
|
||||
from ..controlnets import MultiControlNetManager
|
||||
|
||||
|
||||
@@ -107,3 +106,65 @@ def lets_dance(
|
||||
hidden_states = unet.conv_out(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
|
||||
|
||||
def lets_dance_xl(
|
||||
unet: SDXLUNet,
|
||||
motion_modules: SDXLMotionModel = None,
|
||||
controlnet: MultiControlNetManager = None,
|
||||
sample = None,
|
||||
add_time_id = None,
|
||||
add_text_embeds = None,
|
||||
timestep = None,
|
||||
encoder_hidden_states = None,
|
||||
controlnet_frames = None,
|
||||
unet_batch_size = 1,
|
||||
controlnet_batch_size = 1,
|
||||
cross_frame_attention = False,
|
||||
tiled=False,
|
||||
tile_size=64,
|
||||
tile_stride=32,
|
||||
device = "cuda",
|
||||
vram_limit_level = 0,
|
||||
):
|
||||
# 2. time
|
||||
t_emb = unet.time_proj(timestep[None]).to(sample.dtype)
|
||||
t_emb = unet.time_embedding(t_emb)
|
||||
|
||||
time_embeds = unet.add_time_proj(add_time_id)
|
||||
time_embeds = time_embeds.reshape((add_text_embeds.shape[0], -1))
|
||||
add_embeds = torch.concat([add_text_embeds, time_embeds], dim=-1)
|
||||
add_embeds = add_embeds.to(sample.dtype)
|
||||
add_embeds = unet.add_time_embedding(add_embeds)
|
||||
|
||||
time_emb = t_emb + add_embeds
|
||||
|
||||
# 3. pre-process
|
||||
height, width = sample.shape[2], sample.shape[3]
|
||||
hidden_states = unet.conv_in(sample)
|
||||
text_emb = encoder_hidden_states
|
||||
res_stack = [hidden_states]
|
||||
|
||||
# 4. blocks
|
||||
for block_id, block in enumerate(unet.blocks):
|
||||
hidden_states, time_emb, text_emb, res_stack = block(
|
||||
hidden_states, time_emb, text_emb, res_stack,
|
||||
tiled=tiled, tile_size=tile_size, tile_stride=tile_stride
|
||||
)
|
||||
# 4.2 AnimateDiff
|
||||
if motion_modules is not None:
|
||||
if block_id in motion_modules.call_block_id:
|
||||
motion_module_id = motion_modules.call_block_id[block_id]
|
||||
hidden_states, time_emb, text_emb, res_stack = motion_modules.motion_modules[motion_module_id](
|
||||
hidden_states, time_emb, text_emb, res_stack,
|
||||
batch_size=1
|
||||
)
|
||||
|
||||
# 5. output
|
||||
hidden_states = unet.conv_norm_out(hidden_states)
|
||||
hidden_states = unet.conv_act(hidden_states)
|
||||
hidden_states = unet.conv_out(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
Reference in New Issue
Block a user