mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
support animatediff on sdxl
This commit is contained in:
@@ -15,6 +15,7 @@ from .sdxl_vae_encoder import SDXLVAEEncoder
|
||||
from .sd_controlnet import SDControlNet
|
||||
|
||||
from .sd_motion import SDMotionModel
|
||||
from .sdxl_motion import SDXLMotionModel
|
||||
|
||||
from .svd_image_encoder import SVDImageEncoder
|
||||
from .svd_unet import SVDUNet
|
||||
@@ -61,6 +62,10 @@ class ModelManager:
|
||||
param_name = "mid_block.motion_modules.0.temporal_transformer.proj_out.weight"
|
||||
return param_name in state_dict
|
||||
|
||||
def is_animatediff_xl(self, state_dict):
    """Report whether `state_dict` carries the marker key used to recognise an SDXL AnimateDiff checkpoint.

    Returns True when the fixed parameter name below is a key of `state_dict`.
    """
    marker_key = "up_blocks.2.motion_modules.2.temporal_transformer.transformer_blocks.0.ff_norm.weight"
    found = marker_key in state_dict
    return found
|
||||
|
||||
def is_sd_lora(self, state_dict):
    """Report whether `state_dict` carries the marker key used to recognise a Stable Diffusion LoRA checkpoint.

    Returns True when the fixed parameter name below is a key of `state_dict`.
    """
    marker_key = "lora_unet_up_blocks_3_attentions_2_transformer_blocks_0_ff_net_2.lora_up.weight"
    found = marker_key in state_dict
    return found
|
||||
@@ -153,6 +158,14 @@ class ModelManager:
|
||||
self.model[component] = model
|
||||
self.model_path[component] = file_path
|
||||
|
||||
def load_animatediff_xl(self, state_dict, file_path=""):
    """Build an `SDXLMotionModel` from `state_dict` and register it under "motion_modules_xl".

    The raw checkpoint is first renamed by the model's own state-dict converter
    (`from_civitai`), then loaded, cast to `self.torch_dtype` and moved to
    `self.device`. The model and its source path are stored in `self.model`
    and `self.model_path`.
    """
    motion_model = SDXLMotionModel()
    converted_weights = motion_model.state_dict_converter().from_civitai(state_dict)
    motion_model.load_state_dict(converted_weights)
    motion_model.to(self.torch_dtype).to(self.device)

    registry_key = "motion_modules_xl"
    self.model[registry_key] = motion_model
    self.model_path[registry_key] = file_path
|
||||
|
||||
def load_beautiful_prompt(self, state_dict, file_path=""):
|
||||
component = "beautiful_prompt"
|
||||
from transformers import AutoModelForCausalLM
|
||||
@@ -218,6 +231,8 @@ class ModelManager:
|
||||
self.load_stable_video_diffusion(state_dict, file_path=file_path)
|
||||
elif self.is_animatediff(state_dict):
|
||||
self.load_animatediff(state_dict, file_path=file_path)
|
||||
elif self.is_animatediff_xl(state_dict):
|
||||
self.load_animatediff_xl(state_dict, file_path=file_path)
|
||||
elif self.is_controlnet(state_dict):
|
||||
self.load_controlnet(state_dict, file_path=file_path)
|
||||
elif self.is_stabe_diffusion_xl(state_dict):
|
||||
|
||||
103
diffsynth/models/sdxl_motion.py
Normal file
103
diffsynth/models/sdxl_motion.py
Normal file
@@ -0,0 +1,103 @@
|
||||
from .sd_motion import TemporalBlock
|
||||
import torch
|
||||
|
||||
|
||||
|
||||
class SDXLMotionModel(torch.nn.Module):
    """Container for the 15 temporal blocks of the SDXL AnimateDiff motion module.

    `motion_modules` holds the blocks in order; `call_block_id` maps a UNet
    block index to the index of the temporal block that runs right after it.
    """

    def __init__(self):
        super().__init__()
        # Channel width of each temporal block, in module order.
        # NOTE(review): the grouping presumably follows the UNet down/up stages — confirm.
        channel_widths = (320,) * 2 + (640,) * 2 + (1280,) * 5 + (640,) * 3 + (320,) * 3
        self.motion_modules = torch.nn.ModuleList(
            [TemporalBlock(8, width // 8, width, eps=1e-6) for width in channel_widths]
        )
        # UNet block indices after which a temporal block is applied; the
        # position in this tuple is the index into `self.motion_modules`.
        host_block_ids = (0, 2, 7, 10, 15, 18, 25, 28, 31, 35, 38, 41, 44, 46, 48)
        self.call_block_id = {
            block_id: module_index for module_index, block_id in enumerate(host_block_ids)
        }

    def forward(self):
        # Intentionally a no-op: this class does not run the blocks itself.
        pass

    def state_dict_converter(self):
        """Return the converter that renames checkpoint keys for this model."""
        return SDMotionModelStateDictConverter()
|
||||
|
||||
|
||||
class SDMotionModelStateDictConverter:
    """Rename motion-module checkpoint keys to the flat `motion_modules.<i>.*` layout."""

    # Sub-module path inside a `temporal_transformer` -> path inside a temporal block.
    _RENAME_MAP = {
        "norm": "norm",
        "proj_in": "proj_in",
        "transformer_blocks.0.attention_blocks.0.to_q": "transformer_blocks.0.attn1.to_q",
        "transformer_blocks.0.attention_blocks.0.to_k": "transformer_blocks.0.attn1.to_k",
        "transformer_blocks.0.attention_blocks.0.to_v": "transformer_blocks.0.attn1.to_v",
        "transformer_blocks.0.attention_blocks.0.to_out.0": "transformer_blocks.0.attn1.to_out",
        "transformer_blocks.0.attention_blocks.0.pos_encoder": "transformer_blocks.0.pe1",
        "transformer_blocks.0.attention_blocks.1.to_q": "transformer_blocks.0.attn2.to_q",
        "transformer_blocks.0.attention_blocks.1.to_k": "transformer_blocks.0.attn2.to_k",
        "transformer_blocks.0.attention_blocks.1.to_v": "transformer_blocks.0.attn2.to_v",
        "transformer_blocks.0.attention_blocks.1.to_out.0": "transformer_blocks.0.attn2.to_out",
        "transformer_blocks.0.attention_blocks.1.pos_encoder": "transformer_blocks.0.pe2",
        "transformer_blocks.0.norms.0": "transformer_blocks.0.norm1",
        "transformer_blocks.0.norms.1": "transformer_blocks.0.norm2",
        "transformer_blocks.0.ff.net.0.proj": "transformer_blocks.0.act_fn.proj",
        "transformer_blocks.0.ff.net.2": "transformer_blocks.0.ff",
        "transformer_blocks.0.ff_norm": "transformer_blocks.0.norm3",
        "proj_out": "proj_out",
    }

    # Key prefixes are visited in this order; keys with any other prefix are dropped.
    _STAGE_PREFIXES = ("down_blocks.", "mid_block.", "up_blocks.")

    def __init__(self):
        pass

    def from_diffusers(self, state_dict):
        """Return a new dict with every motion-module key renamed.

        Keys are visited stage by stage (down, mid, up), sorted within each
        stage. Each distinct `...temporal_transformer` prefix gets the next
        module index, starting at 0. Keys containing a `pos_encoder` segment
        lose their final name component; all others keep it as the suffix.

        Raises ValueError if a selected key has no `temporal_transformer`
        segment, and KeyError if its sub-module path is not in the rename map.
        """
        ordered_names = []
        for stage_prefix in self._STAGE_PREFIXES:
            ordered_names.extend(sorted(key for key in state_dict if key.startswith(stage_prefix)))

        converted = {}
        previous_prefix = ""
        module_index = -1
        for full_name in ordered_names:
            parts = full_name.split(".")
            split_at = parts.index("temporal_transformer") + 1
            current_prefix = ".".join(parts[:split_at])
            if current_prefix != previous_prefix:
                # A new temporal transformer starts here.
                previous_prefix = current_prefix
                module_index += 1
            target_path = self._RENAME_MAP[".".join(parts[split_at:-1])]
            pieces = ["motion_modules", str(module_index), target_path]
            if "pos_encoder" not in parts:
                pieces.append(parts[-1])
            converted[".".join(pieces)] = state_dict[full_name]
        return converted

    def from_civitai(self, state_dict):
        """Convert a Civitai-style checkpoint; handled exactly like `from_diffusers`."""
        return self.from_diffusers(state_dict)
|
||||
@@ -1,4 +1,5 @@
|
||||
from .stable_diffusion import SDImagePipeline
|
||||
from .stable_diffusion_xl import SDXLImagePipeline
|
||||
from .stable_diffusion_video import SDVideoPipeline, SDVideoPipelineRunner
|
||||
from .stable_video_diffusion import SVDVideoPipeline
|
||||
from .stable_diffusion_xl_video import SDXLVideoPipeline
|
||||
from .stable_video_diffusion import SVDVideoPipeline
|
||||
@@ -1,7 +1,6 @@
|
||||
import torch
|
||||
from ..models import SDUNet, SDMotionModel
|
||||
from ..models.sd_unet import PushBlock, PopBlock, ResnetBlock, AttentionBlock
|
||||
from ..models.tiler import TileWorker
|
||||
from ..models import SDUNet, SDMotionModel, SDXLUNet, SDXLMotionModel
|
||||
from ..models.sd_unet import PushBlock, PopBlock
|
||||
from ..controlnets import MultiControlNetManager
|
||||
|
||||
|
||||
@@ -107,3 +106,65 @@ def lets_dance(
|
||||
hidden_states = unet.conv_out(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
|
||||
|
||||
def lets_dance_xl(
    unet: SDXLUNet,
    motion_modules: SDXLMotionModel = None,
    controlnet: MultiControlNetManager = None,
    sample = None,
    add_time_id = None,
    add_text_embeds = None,
    timestep = None,
    encoder_hidden_states = None,
    controlnet_frames = None,
    unet_batch_size = 1,
    controlnet_batch_size = 1,
    cross_frame_attention = False,
    tiled=False,
    tile_size=64,
    tile_stride=32,
    device = "cuda",
    vram_limit_level = 0,
):
    """Run one SDXL UNet pass over `sample`, optionally interleaving AnimateDiff temporal blocks.

    The UNet blocks are run in order. When `motion_modules` is given, the
    temporal block registered for a UNet block index in
    `motion_modules.call_block_id` is applied right after that UNet block.

    `controlnet`, `controlnet_frames`, `unet_batch_size`,
    `controlnet_batch_size`, `cross_frame_attention`, `device` and
    `vram_limit_level` are accepted but not used in this function.

    Returns the output of `unet.conv_out`.
    """
    # Timestep embedding.
    timestep_emb = unet.time_embedding(unet.time_proj(timestep[None]).to(sample.dtype))

    # Additional conditioning: projected `add_time_id` concatenated with `add_text_embeds`.
    projected_ids = unet.add_time_proj(add_time_id).reshape((add_text_embeds.shape[0], -1))
    extra_cond = torch.concat([add_text_embeds, projected_ids], dim=-1).to(sample.dtype)
    time_emb = timestep_emb + unet.add_time_embedding(extra_cond)

    # Pre-process. `height` and `width` are read but currently unused in this function.
    height, width = sample.shape[2], sample.shape[3]
    hidden_states = unet.conv_in(sample)
    text_emb = encoder_hidden_states
    res_stack = [hidden_states]

    # Blocks.
    use_motion = motion_modules is not None
    for block_id, block in enumerate(unet.blocks):
        hidden_states, time_emb, text_emb, res_stack = block(
            hidden_states, time_emb, text_emb, res_stack,
            tiled=tiled, tile_size=tile_size, tile_stride=tile_stride
        )
        # AnimateDiff: run the temporal block attached to this UNet block, if any.
        if use_motion and block_id in motion_modules.call_block_id:
            temporal_block = motion_modules.motion_modules[motion_modules.call_block_id[block_id]]
            hidden_states, time_emb, text_emb, res_stack = temporal_block(
                hidden_states, time_emb, text_emb, res_stack,
                batch_size=1
            )

    # Output head.
    normalized = unet.conv_norm_out(hidden_states)
    activated = unet.conv_act(normalized)
    return unet.conv_out(activated)
|
||||
@@ -30,8 +30,6 @@ class SDXLImagePipeline(torch.nn.Module):
|
||||
self.unet = model_manager.unet
|
||||
self.vae_decoder = model_manager.vae_decoder
|
||||
self.vae_encoder = model_manager.vae_encoder
|
||||
# load textual inversion
|
||||
self.prompter.load_textual_inversion(model_manager.textual_inversion_dict)
|
||||
|
||||
|
||||
def fetch_controlnet_models(self, model_manager: ModelManager, **kwargs):
|
||||
@@ -117,10 +115,7 @@ class SDXLImagePipeline(torch.nn.Module):
|
||||
device=self.device,
|
||||
positive=False,
|
||||
)
|
||||
|
||||
# Prepare scheduler
|
||||
self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
|
||||
|
||||
|
||||
# Prepare positional id
|
||||
add_time_id = torch.tensor([height, width, 0, 0, height, width], device=self.device)
|
||||
|
||||
|
||||
190
diffsynth/pipelines/stable_diffusion_xl_video.py
Normal file
190
diffsynth/pipelines/stable_diffusion_xl_video.py
Normal file
@@ -0,0 +1,190 @@
|
||||
from ..models import ModelManager, SDXLTextEncoder, SDXLTextEncoder2, SDXLUNet, SDXLVAEDecoder, SDXLVAEEncoder, SDXLMotionModel
|
||||
from .dancer import lets_dance_xl
|
||||
# TODO: SDXL ControlNet
|
||||
from ..prompts import SDXLPrompter
|
||||
from ..schedulers import EnhancedDDIMScheduler
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
|
||||
|
||||
class SDXLVideoPipeline(torch.nn.Module):
    """Text/video-to-video pipeline for SDXL with optional AnimateDiff motion modules.

    Models are not created here; they are fetched from a `ModelManager` via
    `from_model_manager` or the individual `fetch_*` methods.
    """

    def __init__(self, device="cuda", torch_dtype=torch.float16, use_animatediff=True):
        super().__init__()
        # The scheduler uses a "linear" beta schedule with AnimateDiff, "scaled_linear" otherwise.
        self.scheduler = EnhancedDDIMScheduler(beta_schedule="linear" if use_animatediff else "scaled_linear")
        self.prompter = SDXLPrompter()
        self.device = device
        self.torch_dtype = torch_dtype
        # models
        self.text_encoder: SDXLTextEncoder = None
        self.text_encoder_2: SDXLTextEncoder2 = None
        self.unet: SDXLUNet = None
        self.vae_decoder: SDXLVAEDecoder = None
        self.vae_encoder: SDXLVAEEncoder = None
        # TODO: SDXL ControlNet
        self.motion_modules: SDXLMotionModel = None


    def fetch_main_models(self, model_manager: ModelManager):
        """Take the text encoders, UNet and VAE from `model_manager`."""
        self.text_encoder = model_manager.text_encoder
        self.text_encoder_2 = model_manager.text_encoder_2
        self.unet = model_manager.unet
        self.vae_decoder = model_manager.vae_decoder
        self.vae_encoder = model_manager.vae_encoder


    def fetch_controlnet_models(self, model_manager: ModelManager, **kwargs):
        """Placeholder; ControlNet is not supported for SDXL yet."""
        # TODO: SDXL ControlNet
        pass


    def fetch_motion_modules(self, model_manager: ModelManager):
        """Take the SDXL motion modules from `model_manager` if they were loaded."""
        if "motion_modules_xl" in model_manager.model:
            self.motion_modules = model_manager.motion_modules_xl


    def fetch_prompter(self, model_manager: ModelManager):
        """Let the prompter pull what it needs from `model_manager`."""
        self.prompter.load_from_model_manager(model_manager)


    @staticmethod
    def from_model_manager(model_manager: ModelManager, controlnet_config_units=None, **kwargs):
        """Build a pipeline on the manager's device/dtype and fetch all models.

        AnimateDiff mode is enabled when "motion_modules_xl" is present in
        `model_manager.model`. `controlnet_config_units` defaults to an empty
        list; it is forwarded to `fetch_controlnet_models`, which ignores it.
        """
        # Avoid a shared mutable default argument.
        if controlnet_config_units is None:
            controlnet_config_units = []
        pipe = SDXLVideoPipeline(
            device=model_manager.device,
            torch_dtype=model_manager.torch_dtype,
            use_animatediff="motion_modules_xl" in model_manager.model
        )
        pipe.fetch_main_models(model_manager)
        pipe.fetch_motion_modules(model_manager)
        pipe.fetch_prompter(model_manager)
        pipe.fetch_controlnet_models(model_manager, controlnet_config_units=controlnet_config_units)
        return pipe


    def preprocess_image(self, image):
        """Convert an image to a float tensor with a leading batch axis, mapping 0..255 to -1..1."""
        # assumes `image` converts to a (height, width, channels) array — TODO confirm
        image = torch.Tensor(np.array(image, dtype=np.float32) * (2 / 255) - 1).permute(2, 0, 1).unsqueeze(0)
        return image


    def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=32):
        """Decode one latent with the VAE decoder and return it as a PIL image."""
        image = self.vae_decoder(latent.to(self.device), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
        image = image.cpu().permute(1, 2, 0).numpy()
        # Map -1..1 back to 0..255, clipping out-of-range values.
        image = Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8"))
        return image


    def decode_images(self, latents, tiled=False, tile_size=64, tile_stride=32):
        """Decode the latents one frame at a time; returns a list of PIL images."""
        images = [
            self.decode_image(latents[frame_id: frame_id+1], tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
            for frame_id in range(latents.shape[0])
        ]
        return images


    def encode_images(self, processed_images, tiled=False, tile_size=64, tile_stride=32):
        """Preprocess and VAE-encode each image; returns the latents concatenated on dim 0, on CPU."""
        latents = []
        for image in processed_images:
            image = self.preprocess_image(image).to(device=self.device, dtype=self.torch_dtype)
            latent = self.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).cpu()
            latents.append(latent)
        latents = torch.concat(latents, dim=0)
        return latents


    @torch.no_grad()
    def __call__(
        self,
        prompt,
        negative_prompt="",
        cfg_scale=7.5,
        clip_skip=1,
        clip_skip_2=2,
        num_frames=None,
        input_frames=None,
        controlnet_frames=None,
        denoising_strength=1.0,
        height=512,
        width=512,
        num_inference_steps=20,
        animatediff_batch_size = 16,
        animatediff_stride = 8,
        unet_batch_size = 1,
        controlnet_batch_size = 1,
        cross_frame_attention = False,
        smoother=None,
        smoother_progress_ids=None,
        vram_limit_level=0,
        progress_bar_cmd=tqdm,
        progress_bar_st=None,
    ):
        """Generate a list of PIL frames for `prompt`.

        `num_frames` may be omitted when `input_frames` is given; it then
        defaults to `len(input_frames)`. When `input_frames` is given and
        `denoising_strength` is not 1.0, the frames are encoded and noised
        instead of starting from pure noise. Classifier-free guidance is
        skipped when `cfg_scale` is exactly 1.0.

        `animatediff_batch_size`, `animatediff_stride`, `unet_batch_size`,
        `controlnet_batch_size`, `smoother` and `smoother_progress_ids` are
        accepted but not used by this method.

        Raises:
            ValueError: if neither `num_frames` nor `input_frames` is provided.
        """
        # Resolve the frame count up front instead of failing inside torch.randn.
        if num_frames is None:
            if input_frames is None:
                raise ValueError("`num_frames` must be specified when `input_frames` is not provided.")
            num_frames = len(input_frames)

        # Prepare scheduler
        self.scheduler.set_timesteps(num_inference_steps, denoising_strength)

        # Prepare latent tensors
        if self.motion_modules is None:
            # Without motion modules every frame starts from the same noise.
            noise = torch.randn((1, 4, height//8, width//8), device="cpu", dtype=self.torch_dtype).repeat(num_frames, 1, 1, 1)
        else:
            # Use the pipeline's device rather than a hard-coded "cuda".
            noise = torch.randn((num_frames, 4, height//8, width//8), device=self.device, dtype=self.torch_dtype)
        if input_frames is None or denoising_strength == 1.0:
            latents = noise
        else:
            # encode_images returns CPU tensors; align them with the noise before mixing.
            latents = self.encode_images(input_frames).to(device=noise.device, dtype=noise.dtype)
            latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
        # The UNet pass below receives the latents directly, so they must live on the pipeline device.
        latents = latents.to(self.device)

        # Encode prompts
        add_prompt_emb_posi, prompt_emb_posi = self.prompter.encode_prompt(
            self.text_encoder,
            self.text_encoder_2,
            prompt,
            clip_skip=clip_skip, clip_skip_2=clip_skip_2,
            device=self.device,
            positive=True,
        )
        if cfg_scale != 1.0:
            add_prompt_emb_nega, prompt_emb_nega = self.prompter.encode_prompt(
                self.text_encoder,
                self.text_encoder_2,
                negative_prompt,
                clip_skip=clip_skip, clip_skip_2=clip_skip_2,
                device=self.device,
                positive=False,
            )

        # Prepare positional id
        add_time_id = torch.tensor([height, width, 0, 0, height, width], device=self.device)

        # Denoise
        for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
            timestep = torch.IntTensor((timestep,))[0].to(self.device)

            # Classifier-free guidance
            noise_pred_posi = lets_dance_xl(
                self.unet, motion_modules=self.motion_modules, controlnet=None,
                sample=latents, add_time_id=add_time_id, add_text_embeds=add_prompt_emb_posi,
                timestep=timestep, encoder_hidden_states=prompt_emb_posi, controlnet_frames=controlnet_frames,
                cross_frame_attention=cross_frame_attention,
                device=self.device, vram_limit_level=vram_limit_level
            )
            if cfg_scale != 1.0:
                noise_pred_nega = lets_dance_xl(
                    self.unet, motion_modules=self.motion_modules, controlnet=None,
                    sample=latents, add_time_id=add_time_id, add_text_embeds=add_prompt_emb_nega,
                    timestep=timestep, encoder_hidden_states=prompt_emb_nega, controlnet_frames=controlnet_frames,
                    cross_frame_attention=cross_frame_attention,
                    device=self.device, vram_limit_level=vram_limit_level
                )
                noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
            else:
                noise_pred = noise_pred_posi

            latents = self.scheduler.step(noise_pred, timestep, latents)

            if progress_bar_st is not None:
                progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))

        # Decode image
        image = self.decode_images(latents.to(torch.float32))

        return image
|
||||
28
examples/sdxl_text_to_video.py
Normal file
28
examples/sdxl_text_to_video.py
Normal file
@@ -0,0 +1,28 @@
|
||||
from diffsynth import ModelManager, SDXLVideoPipeline, save_video
|
||||
import torch
|
||||
|
||||
|
||||
# Download models
|
||||
# `models/stable_diffusion_xl/sd_xl_base_1.0.safetensors`: [link](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/resolve/main/sd_xl_base_1.0.safetensors)
|
||||
# `models/AnimateDiff/mm_sdxl_v10_beta.ckpt`: [link](https://huggingface.co/guoyww/animatediff/resolve/main/mm_sdxl_v10_beta.ckpt)
|
||||
|
||||
|
||||
# Load the SDXL base checkpoint together with the SDXL AnimateDiff motion module.
model_paths = [
    "models/stable_diffusion_xl/sd_xl_base_1.0.safetensors",
    "models/AnimateDiff/mm_sdxl_v10_beta.ckpt",
]
model_manager = ModelManager(torch_dtype=torch.float16, device="cuda")
model_manager.load_models(model_paths)
pipe = SDXLVideoPipeline.from_model_manager(model_manager)

prompt = "A panda standing on a surfboard in the ocean in sunset, 4k, high resolution.Realistic, Cinematic, high resolution"
negative_prompt = ""

# Sampling settings for the 16-frame 1024x1024 clip.
generation_settings = dict(
    prompt=prompt,
    negative_prompt=negative_prompt,
    cfg_scale=8.5,
    height=1024,
    width=1024,
    num_frames=16,
    num_inference_steps=100,
)

# Fix the seed before sampling so the run is repeatable.
torch.manual_seed(0)
video = pipe(**generation_settings)
save_video(video, "video.mp4", fps=16)
|
||||
Reference in New Issue
Block a user