From 3b5bbb577386b5eb13d5850dedcb13c183ffce9f Mon Sep 17 00:00:00 2001
From: Artiprocher
Date: Sat, 11 May 2024 22:48:02 +0800
Subject: [PATCH] support animatediff on sdxl

---
 diffsynth/models/__init__.py                  |  15 ++
 diffsynth/models/sdxl_motion.py               | 103 ++++++++++
 diffsynth/pipelines/__init__.py               |   3 +-
 diffsynth/pipelines/dancer.py                 |  67 +++++-
 diffsynth/pipelines/stable_diffusion_xl.py    |   7 +-
 .../pipelines/stable_diffusion_xl_video.py    | 190 ++++++++++++++++++
 examples/sdxl_text_to_video.py                |  28 +++
 7 files changed, 403 insertions(+), 10 deletions(-)
 create mode 100644 diffsynth/models/sdxl_motion.py
 create mode 100644 diffsynth/pipelines/stable_diffusion_xl_video.py
 create mode 100644 examples/sdxl_text_to_video.py

diff --git a/diffsynth/models/__init__.py b/diffsynth/models/__init__.py
index 889eec7..9f90505 100644
--- a/diffsynth/models/__init__.py
+++ b/diffsynth/models/__init__.py
@@ -15,6 +15,7 @@ from .sdxl_vae_encoder import SDXLVAEEncoder
 from .sd_controlnet import SDControlNet
 
 from .sd_motion import SDMotionModel
+from .sdxl_motion import SDXLMotionModel
 
 from .svd_image_encoder import SVDImageEncoder
 from .svd_unet import SVDUNet
@@ -61,6 +62,10 @@ class ModelManager:
         param_name = "mid_block.motion_modules.0.temporal_transformer.proj_out.weight"
         return param_name in state_dict
 
+    def is_animatediff_xl(self, state_dict):  # recognise the checkpoint by the presence of one parameter name
+        param_name = "up_blocks.2.motion_modules.2.temporal_transformer.transformer_blocks.0.ff_norm.weight"
+        return param_name in state_dict
+
     def is_sd_lora(self, state_dict):
         param_name = "lora_unet_up_blocks_3_attentions_2_transformer_blocks_0_ff_net_2.lora_up.weight"
         return param_name in state_dict
@@ -153,6 +158,14 @@ class ModelManager:
         self.model[component] = model
         self.model_path[component] = file_path
 
+    def load_animatediff_xl(self, state_dict, file_path=""):
+        component = "motion_modules_xl"  # key under which the pipeline later looks the model up
+        model = SDXLMotionModel()
+        model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
+        model.to(self.torch_dtype).to(self.device)
+        self.model[component] = model
+        self.model_path[component] = file_path
+
     def load_beautiful_prompt(self, state_dict, file_path=""):
         component = "beautiful_prompt"
         from transformers import AutoModelForCausalLM
@@ -218,6 +231,8 @@ class ModelManager:
                 self.load_stable_video_diffusion(state_dict, file_path=file_path)
             elif self.is_animatediff(state_dict):
                 self.load_animatediff(state_dict, file_path=file_path)
+            elif self.is_animatediff_xl(state_dict):
+                self.load_animatediff_xl(state_dict, file_path=file_path)
             elif self.is_controlnet(state_dict):
                 self.load_controlnet(state_dict, file_path=file_path)
             elif self.is_stabe_diffusion_xl(state_dict):
diff --git a/diffsynth/models/sdxl_motion.py b/diffsynth/models/sdxl_motion.py
new file mode 100644
index 0000000..329b1c6
--- /dev/null
+++ b/diffsynth/models/sdxl_motion.py
@@ -0,0 +1,103 @@
+from .sd_motion import TemporalBlock
+import torch
+
+
+
+class SDXLMotionModel(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.motion_modules = torch.nn.ModuleList([
+            TemporalBlock(8, 320//8, 320, eps=1e-6),
+            TemporalBlock(8, 320//8, 320, eps=1e-6),
+
+            TemporalBlock(8, 640//8, 640, eps=1e-6),
+            TemporalBlock(8, 640//8, 640, eps=1e-6),
+
+            TemporalBlock(8, 1280//8, 1280, eps=1e-6),
+            TemporalBlock(8, 1280//8, 1280, eps=1e-6),
+
+            TemporalBlock(8, 1280//8, 1280, eps=1e-6),
+            TemporalBlock(8, 1280//8, 1280, eps=1e-6),
+            TemporalBlock(8, 1280//8, 1280, eps=1e-6),
+
+            TemporalBlock(8, 640//8, 640, eps=1e-6),
+            TemporalBlock(8, 640//8, 640, eps=1e-6),
+            TemporalBlock(8, 640//8, 640, eps=1e-6),
+
+            TemporalBlock(8, 320//8, 320, eps=1e-6),
+            TemporalBlock(8, 320//8, 320, eps=1e-6),
+            TemporalBlock(8, 320//8, 320, eps=1e-6),
+        ])
+        self.call_block_id = {  # UNet block index -> index into self.motion_modules
+            0: 0,
+            2: 1,
+            7: 2,
+            10: 3,
+            15: 4,
+            18: 5,
+            25: 6,
+            28: 7,
+            31: 8,
+            35: 9,
+            38: 10,
+            41: 11,
+            44: 12,
+            46: 13,
+            48: 14,
+        }
+
+    def forward(self):  # no-op; lets_dance_xl invokes the individual motion modules instead
+        pass
+
+    def state_dict_converter(self):
+        return SDMotionModelStateDictConverter()
+
+
+class SDMotionModelStateDictConverter:
+    def __init__(self):
+        pass
+
+    def from_diffusers(self, state_dict):
+        rename_dict = {  # name inside a temporal_transformer -> name inside a TemporalBlock
+            "norm": "norm",
+            "proj_in": "proj_in",
+            "transformer_blocks.0.attention_blocks.0.to_q": "transformer_blocks.0.attn1.to_q",
+            "transformer_blocks.0.attention_blocks.0.to_k": "transformer_blocks.0.attn1.to_k",
+            "transformer_blocks.0.attention_blocks.0.to_v": "transformer_blocks.0.attn1.to_v",
+            "transformer_blocks.0.attention_blocks.0.to_out.0": "transformer_blocks.0.attn1.to_out",
+            "transformer_blocks.0.attention_blocks.0.pos_encoder": "transformer_blocks.0.pe1",
+            "transformer_blocks.0.attention_blocks.1.to_q": "transformer_blocks.0.attn2.to_q",
+            "transformer_blocks.0.attention_blocks.1.to_k": "transformer_blocks.0.attn2.to_k",
+            "transformer_blocks.0.attention_blocks.1.to_v": "transformer_blocks.0.attn2.to_v",
+            "transformer_blocks.0.attention_blocks.1.to_out.0": "transformer_blocks.0.attn2.to_out",
+            "transformer_blocks.0.attention_blocks.1.pos_encoder": "transformer_blocks.0.pe2",
+            "transformer_blocks.0.norms.0": "transformer_blocks.0.norm1",
+            "transformer_blocks.0.norms.1": "transformer_blocks.0.norm2",
+            "transformer_blocks.0.ff.net.0.proj": "transformer_blocks.0.act_fn.proj",
+            "transformer_blocks.0.ff.net.2": "transformer_blocks.0.ff",
+            "transformer_blocks.0.ff_norm": "transformer_blocks.0.norm3",
+            "proj_out": "proj_out",
+        }
+        name_list = sorted([i for i in state_dict if i.startswith("down_blocks.")])
+        name_list += sorted([i for i in state_dict if i.startswith("mid_block.")])
+        name_list += sorted([i for i in state_dict if i.startswith("up_blocks.")])
+        state_dict_ = {}
+        last_prefix, module_id = "", -1
+        for name in name_list:
+            names = name.split(".")
+            prefix_index = names.index("temporal_transformer") + 1
+            prefix = ".".join(names[:prefix_index])
+            if prefix != last_prefix:  # a new temporal_transformer prefix starts the next motion module
+                last_prefix = prefix
+                module_id += 1
+            middle_name = ".".join(names[prefix_index:-1])
+            suffix = names[-1]
+            if "pos_encoder" in names:  # positional-encoding keys drop their final name component
+                rename = ".".join(["motion_modules", str(module_id), rename_dict[middle_name]])
+            else:
+                rename = ".".join(["motion_modules", str(module_id), rename_dict[middle_name], suffix])
+            state_dict_[rename] = state_dict[name]
+        return state_dict_
+
+    def from_civitai(self, state_dict):
+        return self.from_diffusers(state_dict)
diff --git a/diffsynth/pipelines/__init__.py b/diffsynth/pipelines/__init__.py
index 80dda9f..8e7515d 100644
--- a/diffsynth/pipelines/__init__.py
+++ b/diffsynth/pipelines/__init__.py
@@ -1,4 +1,5 @@
 from .stable_diffusion import SDImagePipeline
 from .stable_diffusion_xl import SDXLImagePipeline
 from .stable_diffusion_video import SDVideoPipeline, SDVideoPipelineRunner
-from .stable_video_diffusion import SVDVideoPipeline
+from .stable_diffusion_xl_video import SDXLVideoPipeline
+from .stable_video_diffusion import SVDVideoPipeline
\ No newline at end of file
diff --git a/diffsynth/pipelines/dancer.py b/diffsynth/pipelines/dancer.py
index 91c2fa7..d19e746 100644
--- a/diffsynth/pipelines/dancer.py
+++ b/diffsynth/pipelines/dancer.py
@@ -1,7 +1,6 @@
 import torch
-from ..models import SDUNet, SDMotionModel
-from ..models.sd_unet import PushBlock, PopBlock, ResnetBlock, AttentionBlock
-from ..models.tiler import TileWorker
+from ..models import SDUNet, SDMotionModel, SDXLUNet, SDXLMotionModel
+from ..models.sd_unet import PushBlock, PopBlock
 from ..controlnets import MultiControlNetManager
 
 
@@ -107,3 +106,65 @@ def lets_dance(
     hidden_states = unet.conv_out(hidden_states)
 
     return hidden_states
+
+
+
+
+def lets_dance_xl(
+    unet: SDXLUNet,
+    motion_modules: SDXLMotionModel = None,
+    controlnet: MultiControlNetManager = None,
+    sample = None,
+    add_time_id = None,
+    add_text_embeds = None,
+    timestep = None,
+    encoder_hidden_states = None,
+    controlnet_frames = None,
+    unet_batch_size = 1,
+    controlnet_batch_size = 1,
+    cross_frame_attention = False,
+    tiled=False,
+    tile_size=64,
+    tile_stride=32,
+    device = "cuda",
+    vram_limit_level = 0,
+):
+    # 2. time
+    t_emb = unet.time_proj(timestep[None]).to(sample.dtype)
+    t_emb = unet.time_embedding(t_emb)
+
+    time_embeds = unet.add_time_proj(add_time_id)
+    time_embeds = time_embeds.reshape((add_text_embeds.shape[0], -1))
+    add_embeds = torch.concat([add_text_embeds, time_embeds], dim=-1)
+    add_embeds = add_embeds.to(sample.dtype)
+    add_embeds = unet.add_time_embedding(add_embeds)
+
+    time_emb = t_emb + add_embeds
+
+    # 3. pre-process
+    height, width = sample.shape[2], sample.shape[3]
+    hidden_states = unet.conv_in(sample)  # sample is used as-is, so it must already be on the UNet's device
+    text_emb = encoder_hidden_states
+    res_stack = [hidden_states]
+
+    # 4. blocks
+    for block_id, block in enumerate(unet.blocks):
+        hidden_states, time_emb, text_emb, res_stack = block(
+            hidden_states, time_emb, text_emb, res_stack,
+            tiled=tiled, tile_size=tile_size, tile_stride=tile_stride
+        )
+        # 4.2 AnimateDiff
+        if motion_modules is not None:
+            if block_id in motion_modules.call_block_id:
+                motion_module_id = motion_modules.call_block_id[block_id]
+                hidden_states, time_emb, text_emb, res_stack = motion_modules.motion_modules[motion_module_id](
+                    hidden_states, time_emb, text_emb, res_stack,
+                    batch_size=1
+                )
+
+    # 5. output
+    hidden_states = unet.conv_norm_out(hidden_states)
+    hidden_states = unet.conv_act(hidden_states)
+    hidden_states = unet.conv_out(hidden_states)
+
+    return hidden_states
\ No newline at end of file
diff --git a/diffsynth/pipelines/stable_diffusion_xl.py b/diffsynth/pipelines/stable_diffusion_xl.py
index 0fec886..246a361 100644
--- a/diffsynth/pipelines/stable_diffusion_xl.py
+++ b/diffsynth/pipelines/stable_diffusion_xl.py
@@ -30,8 +30,6 @@ class SDXLImagePipeline(torch.nn.Module):
         self.unet = model_manager.unet
         self.vae_decoder = model_manager.vae_decoder
         self.vae_encoder = model_manager.vae_encoder
-        # load textual inversion
-        self.prompter.load_textual_inversion(model_manager.textual_inversion_dict)
 
 
     def fetch_controlnet_models(self, model_manager: ModelManager, **kwargs):
@@ -117,10 +115,7 @@ class SDXLImagePipeline(torch.nn.Module):
                 device=self.device,
                 positive=False,
             )
-
-        # Prepare scheduler
-        self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
-
+        
         # Prepare positional id
         add_time_id = torch.tensor([height, width, 0, 0, height, width], device=self.device)
 
diff --git a/diffsynth/pipelines/stable_diffusion_xl_video.py b/diffsynth/pipelines/stable_diffusion_xl_video.py
new file mode 100644
index 0000000..ceb4c67
--- /dev/null
+++ b/diffsynth/pipelines/stable_diffusion_xl_video.py
@@ -0,0 +1,190 @@
+from ..models import ModelManager, SDXLTextEncoder, SDXLTextEncoder2, SDXLUNet, SDXLVAEDecoder, SDXLVAEEncoder, SDXLMotionModel
+from .dancer import lets_dance_xl
+# TODO: SDXL ControlNet
+from ..prompts import SDXLPrompter
+from ..schedulers import EnhancedDDIMScheduler
+import torch
+from tqdm import tqdm
+from PIL import Image
+import numpy as np
+
+
+class SDXLVideoPipeline(torch.nn.Module):
+
+    def __init__(self, device="cuda", torch_dtype=torch.float16, use_animatediff=True):
+        super().__init__()
+        self.scheduler = EnhancedDDIMScheduler(beta_schedule="linear" if use_animatediff else "scaled_linear")
+        self.prompter = SDXLPrompter()
+        self.device = device
+        self.torch_dtype = torch_dtype
+        # models
+        self.text_encoder: SDXLTextEncoder = None
+        self.text_encoder_2: SDXLTextEncoder2 = None
+        self.unet: SDXLUNet = None
+        self.vae_decoder: SDXLVAEDecoder = None
+        self.vae_encoder: SDXLVAEEncoder = None
+        # TODO: SDXL ControlNet
+        self.motion_modules: SDXLMotionModel = None
+
+
+    def fetch_main_models(self, model_manager: ModelManager):
+        self.text_encoder = model_manager.text_encoder
+        self.text_encoder_2 = model_manager.text_encoder_2
+        self.unet = model_manager.unet
+        self.vae_decoder = model_manager.vae_decoder
+        self.vae_encoder = model_manager.vae_encoder
+
+
+    def fetch_controlnet_models(self, model_manager: ModelManager, **kwargs):
+        # TODO: SDXL ControlNet
+        pass
+
+
+    def fetch_motion_modules(self, model_manager: ModelManager):
+        if "motion_modules_xl" in model_manager.model:  # optional; self.motion_modules stays None otherwise
+            self.motion_modules = model_manager.motion_modules_xl
+
+
+    def fetch_prompter(self, model_manager: ModelManager):
+        self.prompter.load_from_model_manager(model_manager)
+
+
+    @staticmethod
+    def from_model_manager(model_manager: ModelManager, controlnet_config_units = [], **kwargs):
+        pipe = SDXLVideoPipeline(
+            device=model_manager.device,
+            torch_dtype=model_manager.torch_dtype,
+            use_animatediff="motion_modules_xl" in model_manager.model
+        )
+        pipe.fetch_main_models(model_manager)
+        pipe.fetch_motion_modules(model_manager)
+        pipe.fetch_prompter(model_manager)
+        pipe.fetch_controlnet_models(model_manager, controlnet_config_units=controlnet_config_units)
+        return pipe
+
+
+    def preprocess_image(self, image):
+        image = torch.Tensor(np.array(image, dtype=np.float32) * (2 / 255) - 1).permute(2, 0, 1).unsqueeze(0)  # [0, 255] HWC -> [-1, 1] 1CHW
+        return image
+
+
+    def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=32):
+        image = self.vae_decoder(latent.to(self.device), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
+        image = image.cpu().permute(1, 2, 0).numpy()
+        image = Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8"))
+        return image
+
+
+    def decode_images(self, latents, tiled=False, tile_size=64, tile_stride=32):
+        images = [
+            self.decode_image(latents[frame_id: frame_id+1], tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
+            for frame_id in range(latents.shape[0])
+        ]
+        return images
+
+
+    def encode_images(self, processed_images, tiled=False, tile_size=64, tile_stride=32):
+        latents = []
+        for image in processed_images:
+            image = self.preprocess_image(image).to(device=self.device, dtype=self.torch_dtype)
+            latent = self.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).cpu()
+            latents.append(latent)
+        latents = torch.concat(latents, dim=0)
+        return latents
+
+
+    @torch.no_grad()
+    def __call__(
+        self,
+        prompt,
+        negative_prompt="",
+        cfg_scale=7.5,
+        clip_skip=1,
+        clip_skip_2=2,
+        num_frames=None,  # must be set by the caller: used as the frame count of the noise tensor
+        input_frames=None,
+        controlnet_frames=None,
+        denoising_strength=1.0,
+        height=512,
+        width=512,
+        num_inference_steps=20,
+        animatediff_batch_size = 16,
+        animatediff_stride = 8,
+        unet_batch_size = 1,
+        controlnet_batch_size = 1,
+        cross_frame_attention = False,
+        smoother=None,
+        smoother_progress_ids=[],
+        vram_limit_level=0,
+        progress_bar_cmd=tqdm,
+        progress_bar_st=None,
+    ):
+        # Prepare scheduler
+        self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
+
+        # Prepare latent tensors
+        if self.motion_modules is None:
+            noise = torch.randn((1, 4, height//8, width//8), device="cpu", dtype=self.torch_dtype).repeat(num_frames, 1, 1, 1).to(self.device)  # one CPU-sampled frame repeated for every frame, then moved to the pipeline device
+        else:
+            noise = torch.randn((num_frames, 4, height//8, width//8), device=self.device, dtype=self.torch_dtype)  # follow the configured device instead of a hard-coded "cuda"
+        if input_frames is None or denoising_strength == 1.0:
+            latents = noise
+        else:
+            latents = self.encode_images(input_frames).to(self.device)  # encode_images returns CPU tensors
+            latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
+
+        # Encode prompts
+        add_prompt_emb_posi, prompt_emb_posi = self.prompter.encode_prompt(
+            self.text_encoder,
+            self.text_encoder_2,
+            prompt,
+            clip_skip=clip_skip, clip_skip_2=clip_skip_2,
+            device=self.device,
+            positive=True,
+        )
+        if cfg_scale != 1.0:  # the negative branch is only needed when guidance is active
+            add_prompt_emb_nega, prompt_emb_nega = self.prompter.encode_prompt(
+                self.text_encoder,
+                self.text_encoder_2,
+                negative_prompt,
+                clip_skip=clip_skip, clip_skip_2=clip_skip_2,
+                device=self.device,
+                positive=False,
+            )
+
+        # Prepare positional id
+        add_time_id = torch.tensor([height, width, 0, 0, height, width], device=self.device)
+
+        # Denoise
+        for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
+            timestep = torch.IntTensor((timestep,))[0].to(self.device)
+
+            # Classifier-free guidance
+            noise_pred_posi = lets_dance_xl(
+                self.unet, motion_modules=self.motion_modules, controlnet=None,
+                sample=latents, add_time_id=add_time_id, add_text_embeds=add_prompt_emb_posi,
+                timestep=timestep, encoder_hidden_states=prompt_emb_posi, controlnet_frames=controlnet_frames,
+                cross_frame_attention=cross_frame_attention,
+                device=self.device, vram_limit_level=vram_limit_level
+            )
+            if cfg_scale != 1.0:
+                noise_pred_nega = lets_dance_xl(
+                    self.unet, motion_modules=self.motion_modules, controlnet=None,
+                    sample=latents, add_time_id=add_time_id, add_text_embeds=add_prompt_emb_nega,
+                    timestep=timestep, encoder_hidden_states=prompt_emb_nega, controlnet_frames=controlnet_frames,
+                    cross_frame_attention=cross_frame_attention,
+                    device=self.device, vram_limit_level=vram_limit_level
+                )
+                noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
+            else:
+                noise_pred = noise_pred_posi
+
+            latents = self.scheduler.step(noise_pred, timestep, latents)
+
+            if progress_bar_st is not None:
+                progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
+
+        # Decode image
+        image = self.decode_images(latents.to(torch.float32))
+
+        return image
diff --git a/examples/sdxl_text_to_video.py b/examples/sdxl_text_to_video.py
new file mode 100644
index 0000000..96fc7a5
--- /dev/null
+++ b/examples/sdxl_text_to_video.py
@@ -0,0 +1,28 @@
+from diffsynth import ModelManager, SDXLVideoPipeline, save_video
+import torch
+
+
+# Download models
+# `models/stable_diffusion_xl/sd_xl_base_1.0.safetensors`: [link](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/resolve/main/sd_xl_base_1.0.safetensors)
+# `models/AnimateDiff/mm_sdxl_v10_beta.ckpt`: [link](https://huggingface.co/guoyww/animatediff/resolve/main/mm_sdxl_v10_beta.ckpt)
+
+
+model_manager = ModelManager(torch_dtype=torch.float16, device="cuda")
+model_manager.load_models([
+    "models/stable_diffusion_xl/sd_xl_base_1.0.safetensors",
+    "models/AnimateDiff/mm_sdxl_v10_beta.ckpt"
+])
+pipe = SDXLVideoPipeline.from_model_manager(model_manager)
+
+prompt = "A panda standing on a surfboard in the ocean in sunset, 4k, high resolution.Realistic, Cinematic, high resolution"
+negative_prompt = ""
+
+torch.manual_seed(0)
+video = pipe(
+    prompt=prompt,
+    negative_prompt=negative_prompt,
+    cfg_scale=8.5,
+    height=1024, width=1024, num_frames=16,
+    num_inference_steps=100,
+)
+save_video(video, "video.mp4", fps=16)