From a076adf59217495d1c3d13407112e8200660a5b9 Mon Sep 17 00:00:00 2001 From: Artiprocher Date: Fri, 26 Jul 2024 14:35:18 +0800 Subject: [PATCH] ExVideo for AnimateDiff --- ExVideo_animatediff_train.py | 267 ++++++++++++++++++ diffsynth/models/__init__.py | 6 +- diffsynth/models/sd_motion.py | 139 ++++++--- diffsynth/models/sd_motion_ex.py | 115 ++++++++ diffsynth/pipelines/stable_diffusion_video.py | 33 ++- diffsynth/prompts/sd_prompter.py | 4 +- diffsynth/prompts/utils.py | 4 +- 7 files changed, 520 insertions(+), 48 deletions(-) create mode 100644 ExVideo_animatediff_train.py create mode 100644 diffsynth/models/sd_motion_ex.py diff --git a/ExVideo_animatediff_train.py b/ExVideo_animatediff_train.py new file mode 100644 index 0000000..4e50c7f --- /dev/null +++ b/ExVideo_animatediff_train.py @@ -0,0 +1,267 @@ +import torch, json, os, imageio +from torchvision.transforms import v2 +from einops import rearrange +import lightning as pl +from diffsynth import ModelManager, EnhancedDDIMScheduler, SDVideoPipeline, SDUNet, load_state_dict, SDMotionModel + + + +def lets_dance( + unet: SDUNet, + motion_modules: SDMotionModel, + sample, + timestep, + encoder_hidden_states, + use_gradient_checkpointing=False, +): + # 1. ControlNet (skip) + # 2. time + time_emb = unet.time_proj(timestep[None]).to(sample.dtype) + time_emb = unet.time_embedding(time_emb) + + # 3. pre-process + hidden_states = unet.conv_in(sample) + text_emb = encoder_hidden_states + res_stack = [hidden_states] + + # 4. blocks + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + return custom_forward + for block_id, block in enumerate(unet.blocks): + # 4.1 UNet + if use_gradient_checkpointing: + hidden_states, time_emb, text_emb, res_stack = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, time_emb, text_emb, res_stack, + use_reentrant=False, + ) + else: + hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack) + # 4.2 AnimateDiff + if block_id in motion_modules.call_block_id: + motion_module_id = motion_modules.call_block_id[block_id] + if use_gradient_checkpointing: + hidden_states, time_emb, text_emb, res_stack = torch.utils.checkpoint.checkpoint( + create_custom_forward(motion_modules.motion_modules[motion_module_id]), + hidden_states, time_emb, text_emb, res_stack, + use_reentrant=False, + ) + else: + hidden_states, time_emb, text_emb, res_stack = motion_modules.motion_modules[motion_module_id](hidden_states, time_emb, text_emb, res_stack) + + # 5. output + hidden_states = unet.conv_norm_out(hidden_states) + hidden_states = unet.conv_act(hidden_states) + hidden_states = unet.conv_out(hidden_states) + + return hidden_states + + + +class TextVideoDataset(torch.utils.data.Dataset): + def __init__(self, base_path, metadata_path, steps_per_epoch=10000, training_shapes=[(128, 1, 128, 512, 512)]): + with open(metadata_path, "r") as f: + metadata = json.load(f) + self.path = [os.path.join(base_path, i["path"]) for i in metadata] + self.text = [i["text"] for i in metadata] + self.steps_per_epoch = steps_per_epoch + self.training_shapes = training_shapes + + self.frame_process = [] + for max_num_frames, interval, num_frames, height, width in training_shapes: + self.frame_process.append(v2.Compose([ + v2.Resize(size=max(height, width), antialias=True), + v2.CenterCrop(size=(height, width)), + v2.Normalize(mean=[127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5]), + ])) + + + def load_frames_using_imageio(self, file_path, max_num_frames, start_frame_id, interval, num_frames, frame_process): + reader = imageio.get_reader(file_path) + if reader.count_frames() < max_num_frames or reader.count_frames() - 1 < start_frame_id + (num_frames - 1) * interval: + reader.close() + return None + + frames = [] + for frame_id in range(num_frames): + frame = reader.get_data(start_frame_id + frame_id * interval) + frame = torch.tensor(frame, dtype=torch.float32) + frame = rearrange(frame, "H W C -> 1 C H W") + frame = frame_process(frame) + frames.append(frame) + reader.close() + + frames = torch.concat(frames, dim=0) + frames = rearrange(frames, "T C H W -> C T H W") + + return frames + + + def load_video(self, file_path, training_shape_id): + data = {} + max_num_frames, interval, num_frames, height, width = self.training_shapes[training_shape_id] + frame_process = self.frame_process[training_shape_id] + start_frame_id = torch.randint(0, max_num_frames - (num_frames - 1) * interval, (1,))[0] + frames = self.load_frames_using_imageio(file_path, max_num_frames, start_frame_id, interval, num_frames, frame_process) + if frames is None: + return None + else: + data[f"frames_{training_shape_id}"] = frames + data[f"start_frame_id_{training_shape_id}"] = start_frame_id + return data + + + def __getitem__(self, index): + video_data = {} + for training_shape_id in range(len(self.training_shapes)): + while True: + data_id = torch.randint(0, len(self.path), (1,))[0] + data_id = (data_id + index) % len(self.path) # For fixed seed. + text = self.text[data_id] + if isinstance(text, list): + text = text[torch.randint(0, len(text), (1,))[0]] + video_file = self.path[data_id] + try: + data = self.load_video(video_file, training_shape_id) + except: + data = None + if data is not None: + data[f"text_{training_shape_id}"] = text + break + video_data.update(data) + return video_data + + + def __len__(self): + return self.steps_per_epoch + + + +class LightningModel(pl.LightningModule): + def __init__(self, learning_rate=1e-5, sd_ckpt_path=None): + super().__init__() + # Load models + model_manager = ModelManager(torch_dtype=torch.float16, device="cpu") + model_manager.load_stable_diffusion(load_state_dict(sd_ckpt_path)) + + # Initialize motion modules + model_manager.model["motion_modules"] = SDMotionModel().to(dtype=self.dtype, device=self.device) + + # Build pipeline + self.pipe = SDVideoPipeline.from_model_manager(model_manager) + self.pipe.vae_encoder.eval() + self.pipe.vae_encoder.requires_grad_(False) + + self.pipe.vae_decoder.eval() + self.pipe.vae_decoder.requires_grad_(False) + + self.pipe.text_encoder.eval() + self.pipe.text_encoder.requires_grad_(False) + + self.pipe.unet.eval() + self.pipe.unet.requires_grad_(False) + + self.pipe.motion_modules.train() + self.pipe.motion_modules.requires_grad_(True) + + # Reset the scheduler + self.pipe.scheduler = EnhancedDDIMScheduler(beta_schedule="scaled_linear") + self.pipe.scheduler.set_timesteps(1000) + + # Other parameters + self.learning_rate = learning_rate + + + def encode_video_with_vae(self, video): + video = video.to(device=self.device, dtype=self.dtype) + video = video.unsqueeze(0) + latents = self.pipe.vae_encoder.encode_video(video, batch_size=16) + latents = rearrange(latents[0], "C T H W -> T C H W") + return latents + + + def calculate_loss(self, prompt, frames): + with torch.no_grad(): + # Call video encoder + latents = self.encode_video_with_vae(frames) + + # Call text encoder + prompt_embs = self.pipe.prompter.encode_prompt(self.pipe.text_encoder, prompt, device=self.device, max_length=77) + prompt_embs = prompt_embs.repeat(latents.shape[0], 1, 1) + + # Call scheduler + timestep = torch.randint(0, len(self.pipe.scheduler.timesteps), (1,), device=self.device)[0] + noise = torch.randn_like(latents) + noisy_latents = self.pipe.scheduler.add_noise(latents, noise, timestep) + + # Calculate loss + model_pred = lets_dance( + self.pipe.unet, self.pipe.motion_modules, + sample=noisy_latents, encoder_hidden_states=prompt_embs, timestep=timestep + ) + loss = torch.nn.functional.mse_loss(model_pred.float(), noise.float(), reduction="mean") + return loss + + + def training_step(self, batch, batch_idx): + # Loss + frames = batch["frames_0"][0] + prompt = batch["text_0"][0] + loss = self.calculate_loss(prompt, frames) + + # Record log + self.log("train_loss", loss, prog_bar=True) + return loss + + + def configure_optimizers(self): + optimizer = torch.optim.AdamW(self.pipe.motion_modules.parameters(), lr=self.learning_rate) + return optimizer + + + def on_save_checkpoint(self, checkpoint): + trainable_param_names = list(filter(lambda named_param: named_param[1].requires_grad, self.pipe.motion_modules.named_parameters())) + trainable_param_names = [named_param[0] for named_param in trainable_param_names] + checkpoint["trainable_param_names"] = trainable_param_names + + + +if __name__ == '__main__': + # dataset and data loader + dataset = TextVideoDataset( + "/data/zhongjie/datasets/opensoraplan/data/processed", + "/data/zhongjie/datasets/opensoraplan/data/processed/metadata.json", + training_shapes=[(16, 1, 16, 512, 512)], + steps_per_epoch=7*10000, + ) + train_loader = torch.utils.data.DataLoader( + dataset, + shuffle=True, + batch_size=1, + num_workers=4 + ) + + # model + model = LightningModel( + learning_rate=1e-5, + sd_ckpt_path="models/stable_diffusion/v1-5-pruned-emaonly.safetensors", + ) + + # train + trainer = pl.Trainer( + max_epochs=100000, + accelerator="gpu", + devices="auto", + strategy="deepspeed_stage_1", + precision="16-mixed", + default_root_dir="/data/zhongjie/models/train_extended_animatediff", + accumulate_grad_batches=1, + callbacks=[pl.pytorch.callbacks.ModelCheckpoint(save_top_k=-1)] + ) + trainer.fit( + model=model, + train_dataloaders=train_loader, + ckpt_path=None + ) diff --git a/diffsynth/models/__init__.py b/diffsynth/models/__init__.py index 21757e9..29e0450 100644 --- a/diffsynth/models/__init__.py +++ b/diffsynth/models/__init__.py @@ -194,10 +194,10 @@ class ModelManager: self.model[component].append(model) self.model_path[component].append(file_path) - def load_animatediff(self, state_dict, file_path=""): + def load_animatediff(self, state_dict, file_path="", add_positional_conv=None): component = "motion_modules" - model = SDMotionModel() - model.load_state_dict(model.state_dict_converter().from_civitai(state_dict)) + model = SDMotionModel(add_positional_conv=add_positional_conv) + model.load_state_dict(model.state_dict_converter().from_civitai(state_dict, add_positional_conv=add_positional_conv)) model.to(self.torch_dtype).to(self.device) self.model[component] = model self.model_path[component] = file_path diff --git a/diffsynth/models/sd_motion.py b/diffsynth/models/sd_motion.py index b313e62..0692736 100644 --- a/diffsynth/models/sd_motion.py +++ b/diffsynth/models/sd_motion.py @@ -1,20 +1,28 @@ from .sd_unet import SDUNet, Attention, GEGLU +from .svd_unet import get_timestep_embedding import torch from einops import rearrange, repeat class TemporalTransformerBlock(torch.nn.Module): - def __init__(self, dim, num_attention_heads, attention_head_dim, max_position_embeddings=32): + def __init__(self, dim, num_attention_heads, attention_head_dim, max_position_embeddings=32, add_positional_conv=None): super().__init__() + self.add_positional_conv = add_positional_conv # 1. Self-Attn - self.pe1 = torch.nn.Parameter(torch.zeros(1, max_position_embeddings, dim)) + emb = get_timestep_embedding(torch.arange(max_position_embeddings), dim, True, 0).reshape(1, max_position_embeddings, dim) + self.pe1 = torch.nn.Parameter(emb) + if add_positional_conv: + self.positional_conv_1 = torch.nn.Conv1d(dim, dim, kernel_size=3, padding=1, padding_mode="reflect") self.norm1 = torch.nn.LayerNorm(dim, elementwise_affine=True) self.attn1 = Attention(q_dim=dim, num_heads=num_attention_heads, head_dim=attention_head_dim, bias_out=True) # 2. Cross-Attn - self.pe2 = torch.nn.Parameter(torch.zeros(1, max_position_embeddings, dim)) + emb = get_timestep_embedding(torch.arange(max_position_embeddings), dim, True, 0).reshape(1, max_position_embeddings, dim) + self.pe2 = torch.nn.Parameter(emb) + if add_positional_conv: + self.positional_conv_2 = torch.nn.Conv1d(dim, dim, kernel_size=3, padding=1, padding_mode="reflect") self.norm2 = torch.nn.LayerNorm(dim, elementwise_affine=True) self.attn2 = Attention(q_dim=dim, num_heads=num_attention_heads, head_dim=attention_head_dim, bias_out=True) @@ -24,19 +32,47 @@ class TemporalTransformerBlock(torch.nn.Module): self.ff = torch.nn.Linear(dim * 4, dim) + def frame_id_to_position_id(self, frame_id, max_id, repeat_length): + if frame_id < max_id: + position_id = frame_id + else: + position_id = (frame_id - max_id) % (repeat_length * 2) + if position_id < repeat_length: + position_id = max_id - 2 - position_id + else: + position_id = max_id - 2 * repeat_length + position_id + return position_id + + + def positional_ids(self, num_frames): + max_id = self.pe1.shape[1] + positional_ids = torch.IntTensor([self.frame_id_to_position_id(i, max_id, max_id - 1) for i in range(num_frames)]) + return positional_ids + + def forward(self, hidden_states, batch_size=1): # 1. Self-Attention norm_hidden_states = self.norm1(hidden_states) norm_hidden_states = rearrange(norm_hidden_states, "(b f) h c -> (b h) f c", b=batch_size) - attn_output = self.attn1(norm_hidden_states + self.pe1[:, :norm_hidden_states.shape[1]]) + norm_hidden_states = norm_hidden_states + self.pe1[:, self.positional_ids(norm_hidden_states.shape[1])] + if self.add_positional_conv: + norm_hidden_states = rearrange(norm_hidden_states, "(b h) f c -> (b h) c f", b=batch_size) + norm_hidden_states = self.positional_conv_1(norm_hidden_states) + norm_hidden_states = rearrange(norm_hidden_states, "(b h) c f -> (b h) f c", b=batch_size) + attn_output = self.attn1(norm_hidden_states) attn_output = rearrange(attn_output, "(b h) f c -> (b f) h c", b=batch_size) hidden_states = attn_output + hidden_states # 2. Cross-Attention norm_hidden_states = self.norm2(hidden_states) norm_hidden_states = rearrange(norm_hidden_states, "(b f) h c -> (b h) f c", b=batch_size) - attn_output = self.attn2(norm_hidden_states + self.pe2[:, :norm_hidden_states.shape[1]]) + norm_hidden_states = norm_hidden_states + self.pe2[:, self.positional_ids(norm_hidden_states.shape[1])] + if self.add_positional_conv: + norm_hidden_states = rearrange(norm_hidden_states, "(b h) f c -> (b h) c f", b=batch_size) + norm_hidden_states = self.positional_conv_2(norm_hidden_states) + norm_hidden_states = rearrange(norm_hidden_states, "(b h) c f -> (b h) f c", b=batch_size) + attn_output = self.attn2(norm_hidden_states) attn_output = rearrange(attn_output, "(b h) f c -> (b f) h c", b=batch_size) hidden_states = attn_output + hidden_states @@ -51,7 +87,7 @@ class TemporalTransformerBlock(torch.nn.Module): class TemporalBlock(torch.nn.Module): - def __init__(self, num_attention_heads, attention_head_dim, in_channels, num_layers=1, norm_num_groups=32, eps=1e-5): + def __init__(self, num_attention_heads, attention_head_dim, in_channels, num_layers=1, norm_num_groups=32, eps=1e-5, add_positional_conv=None): super().__init__() inner_dim = num_attention_heads * attention_head_dim @@ -62,7 +98,9 @@ class TemporalBlock(torch.nn.Module): TemporalTransformerBlock( inner_dim, num_attention_heads, - attention_head_dim + attention_head_dim, + max_position_embeddings=32 if add_positional_conv is None else add_positional_conv, + add_positional_conv=add_positional_conv ) for d in range(num_layers) ]) @@ -92,30 +130,30 @@ class TemporalBlock(torch.nn.Module): class SDMotionModel(torch.nn.Module): - def __init__(self): + def __init__(self, add_positional_conv=None): super().__init__() self.motion_modules = torch.nn.ModuleList([ - TemporalBlock(8, 40, 320, eps=1e-6), - TemporalBlock(8, 40, 320, eps=1e-6), - TemporalBlock(8, 80, 640, eps=1e-6), - TemporalBlock(8, 80, 640, eps=1e-6), - TemporalBlock(8, 160, 1280, eps=1e-6), - TemporalBlock(8, 160, 1280, eps=1e-6), - TemporalBlock(8, 160, 1280, eps=1e-6), - TemporalBlock(8, 160, 1280, eps=1e-6), - TemporalBlock(8, 160, 1280, eps=1e-6), - TemporalBlock(8, 160, 1280, eps=1e-6), - TemporalBlock(8, 160, 1280, eps=1e-6), - TemporalBlock(8, 160, 1280, eps=1e-6), - TemporalBlock(8, 160, 1280, eps=1e-6), - TemporalBlock(8, 160, 1280, eps=1e-6), - TemporalBlock(8, 160, 1280, eps=1e-6), - TemporalBlock(8, 80, 640, eps=1e-6), - TemporalBlock(8, 80, 640, eps=1e-6), - TemporalBlock(8, 80, 640, eps=1e-6), - TemporalBlock(8, 40, 320, eps=1e-6), - TemporalBlock(8, 40, 320, eps=1e-6), - TemporalBlock(8, 40, 320, eps=1e-6), + TemporalBlock(8, 40, 320, eps=1e-6, add_positional_conv=add_positional_conv), + TemporalBlock(8, 40, 320, eps=1e-6, add_positional_conv=add_positional_conv), + TemporalBlock(8, 80, 640, eps=1e-6, add_positional_conv=add_positional_conv), + TemporalBlock(8, 80, 640, eps=1e-6, add_positional_conv=add_positional_conv), + TemporalBlock(8, 160, 1280, eps=1e-6, add_positional_conv=add_positional_conv), + TemporalBlock(8, 160, 1280, eps=1e-6, add_positional_conv=add_positional_conv), + TemporalBlock(8, 160, 1280, eps=1e-6, add_positional_conv=add_positional_conv), + TemporalBlock(8, 160, 1280, eps=1e-6, add_positional_conv=add_positional_conv), + TemporalBlock(8, 160, 1280, eps=1e-6, add_positional_conv=add_positional_conv), + TemporalBlock(8, 160, 1280, eps=1e-6, add_positional_conv=add_positional_conv), + TemporalBlock(8, 160, 1280, eps=1e-6, add_positional_conv=add_positional_conv), + TemporalBlock(8, 160, 1280, eps=1e-6, add_positional_conv=add_positional_conv), + TemporalBlock(8, 160, 1280, eps=1e-6, add_positional_conv=add_positional_conv), + TemporalBlock(8, 160, 1280, eps=1e-6, add_positional_conv=add_positional_conv), + TemporalBlock(8, 160, 1280, eps=1e-6, add_positional_conv=add_positional_conv), + TemporalBlock(8, 80, 640, eps=1e-6, add_positional_conv=add_positional_conv), + TemporalBlock(8, 80, 640, eps=1e-6, add_positional_conv=add_positional_conv), + TemporalBlock(8, 80, 640, eps=1e-6, add_positional_conv=add_positional_conv), + TemporalBlock(8, 40, 320, eps=1e-6, add_positional_conv=add_positional_conv), + TemporalBlock(8, 40, 320, eps=1e-6, add_positional_conv=add_positional_conv), + TemporalBlock(8, 40, 320, eps=1e-6, add_positional_conv=add_positional_conv), ]) self.call_block_id = { 1: 0, @@ -152,7 +190,42 @@ class SDMotionModelStateDictConverter: def __init__(self): pass - def from_diffusers(self, state_dict): + def frame_id_to_position_id(self, frame_id, max_id, repeat_length): + if frame_id < max_id: + position_id = frame_id + else: + position_id = (frame_id - max_id) % (repeat_length * 2) + if position_id < repeat_length: + position_id = max_id - 2 - position_id + else: + position_id = max_id - 2 * repeat_length + position_id + return position_id + + def process_positional_conv_parameters(self, state_dict, add_positional_conv): + ids = [self.frame_id_to_position_id(i, 16, 15) for i in range(add_positional_conv)] + for i in range(21): + # Extend positional embedding + name = f"motion_modules.{i}.transformer_blocks.0.pe1" + state_dict[name] = state_dict[name][:, ids] + name = f"motion_modules.{i}.transformer_blocks.0.pe2" + state_dict[name] = state_dict[name][:, ids] + # add post convolution + dim = state_dict[f"motion_modules.{i}.transformer_blocks.0.pe1"].shape[-1] + name = f"motion_modules.{i}.transformer_blocks.0.positional_conv_1.bias" + state_dict[name] = torch.zeros((dim,)) + name = f"motion_modules.{i}.transformer_blocks.0.positional_conv_2.bias" + state_dict[name] = torch.zeros((dim,)) + name = f"motion_modules.{i}.transformer_blocks.0.positional_conv_1.weight" + param = torch.zeros((dim, dim, 3)) + param[:, :, 1] = torch.eye(dim, dim) + state_dict[name] = param + name = f"motion_modules.{i}.transformer_blocks.0.positional_conv_2.weight" + param = torch.zeros((dim, dim, 3)) + param[:, :, 1] = torch.eye(dim, dim) + state_dict[name] = param + return state_dict + + def from_diffusers(self, state_dict, add_positional_conv=None): rename_dict = { "norm": "norm", "proj_in": "proj_in", @@ -192,7 +265,9 @@ class SDMotionModelStateDictConverter: else: rename = ".".join(["motion_modules", str(module_id), rename_dict[middle_name], suffix]) state_dict_[rename] = state_dict[name] + if add_positional_conv is not None: + state_dict_ = self.process_positional_conv_parameters(state_dict_, add_positional_conv) return state_dict_ - def from_civitai(self, state_dict): - return self.from_diffusers(state_dict) + def from_civitai(self, state_dict, add_positional_conv=None): + return self.from_diffusers(state_dict, add_positional_conv=add_positional_conv) diff --git a/diffsynth/models/sd_motion_ex.py b/diffsynth/models/sd_motion_ex.py new file mode 100644 index 0000000..b1d0de0 --- /dev/null +++ b/diffsynth/models/sd_motion_ex.py @@ -0,0 +1,115 @@ +from .attention import Attention +from .svd_unet import get_timestep_embedding +import torch +from einops import rearrange, repeat + + + +class ExVideoMotionBlock(torch.nn.Module): + + def __init__(self, num_attention_heads, attention_head_dim, in_channels, max_position_embeddings=16, num_layers=1, add_positional_conv=None): + super().__init__() + + emb = get_timestep_embedding(torch.arange(max_position_embeddings), in_channels, True, 0).reshape(max_position_embeddings, in_channels, 1, 1) + self.positional_embedding = torch.nn.Parameter(emb) + self.positional_conv = torch.nn.Conv3d(in_channels, in_channels, kernel_size=3, padding=1) if add_positional_conv is not None else None + self.norms = torch.nn.ModuleList([torch.nn.LayerNorm(in_channels) for _ in range(num_layers)]) + self.attns = torch.nn.ModuleList([Attention(q_dim=in_channels, num_heads=num_attention_heads, head_dim=attention_head_dim, bias_out=True) for _ in range(num_layers)]) + + def frame_id_to_position_id(self, frame_id, max_id, repeat_length): + if frame_id < max_id: + position_id = frame_id + else: + position_id = (frame_id - max_id) % (repeat_length * 2) + if position_id < repeat_length: + position_id = max_id - 2 - position_id + else: + position_id = max_id - 2 * repeat_length + position_id + return position_id + + def positional_ids(self, num_frames): + max_id = self.positional_embedding.shape[0] + positional_ids = torch.IntTensor([self.frame_id_to_position_id(i, max_id, max_id - 1) for i in range(num_frames)]) + return positional_ids + + def forward(self, hidden_states, time_emb, text_emb, res_stack, batch_size=1, **kwargs): + batch, inner_dim, height, width = hidden_states.shape + residual = hidden_states + + pos_emb = self.positional_ids(batch // batch_size) + pos_emb = self.positional_embedding[pos_emb] + pos_emb = pos_emb.repeat(batch_size) + hidden_states = hidden_states + pos_emb + if self.positional_conv is not None: + hidden_states = rearrange(hidden_states, "(B T) C H W -> B C T H W", B=batch_size) + hidden_states = self.positional_conv(hidden_states) + hidden_states = rearrange(hidden_states, "B C T H W -> (B H W) T C") + else: + hidden_states = rearrange(hidden_states, "(B T) C H W -> (B H W) T C", B=batch_size) + + for norm, attn in zip(self.norms, self.attns): + norm_hidden_states = norm(hidden_states) + attn_output = attn(norm_hidden_states) + hidden_states = hidden_states + attn_output + + hidden_states = rearrange(hidden_states, "(B H W) T C -> (B T) C H W", B=batch_size, H=height, W=width) + hidden_states = hidden_states + residual + return hidden_states, time_emb, text_emb, res_stack + + + +class ExVideoMotionModel(torch.nn.Module): + def __init__(self, num_layers=2): + super().__init__() + self.motion_modules = torch.nn.ModuleList([ + ExVideoMotionBlock(8, 40, 320, num_layers=num_layers), + ExVideoMotionBlock(8, 40, 320, num_layers=num_layers), + ExVideoMotionBlock(8, 80, 640, num_layers=num_layers), + ExVideoMotionBlock(8, 80, 640, num_layers=num_layers), + ExVideoMotionBlock(8, 160, 1280, num_layers=num_layers), + ExVideoMotionBlock(8, 160, 1280, num_layers=num_layers), + ExVideoMotionBlock(8, 160, 1280, num_layers=num_layers), + ExVideoMotionBlock(8, 160, 1280, num_layers=num_layers), + ExVideoMotionBlock(8, 160, 1280, num_layers=num_layers), + ExVideoMotionBlock(8, 160, 1280, num_layers=num_layers), + ExVideoMotionBlock(8, 160, 1280, num_layers=num_layers), + ExVideoMotionBlock(8, 160, 1280, num_layers=num_layers), + ExVideoMotionBlock(8, 160, 1280, num_layers=num_layers), + ExVideoMotionBlock(8, 160, 1280, num_layers=num_layers), + ExVideoMotionBlock(8, 160, 1280, num_layers=num_layers), + ExVideoMotionBlock(8, 80, 640, num_layers=num_layers), + ExVideoMotionBlock(8, 80, 640, num_layers=num_layers), + ExVideoMotionBlock(8, 80, 640, num_layers=num_layers), + ExVideoMotionBlock(8, 40, 320, num_layers=num_layers), + ExVideoMotionBlock(8, 40, 320, num_layers=num_layers), + ExVideoMotionBlock(8, 40, 320, num_layers=num_layers), + ]) + self.call_block_id = { + 1: 0, + 4: 1, + 9: 2, + 12: 3, + 17: 4, + 20: 5, + 24: 6, + 26: 7, + 29: 8, + 32: 9, + 34: 10, + 36: 11, + 40: 12, + 43: 13, + 46: 14, + 50: 15, + 53: 16, + 56: 17, + 60: 18, + 63: 19, + 66: 20 + } + + def forward(self): + pass + + def state_dict_converter(self): + pass diff --git a/diffsynth/pipelines/stable_diffusion_video.py b/diffsynth/pipelines/stable_diffusion_video.py index b204cad..f01edf7 100644 --- a/diffsynth/pipelines/stable_diffusion_video.py +++ b/diffsynth/pipelines/stable_diffusion_video.py @@ -10,6 +10,7 @@ import torch, os, json from tqdm import tqdm from PIL import Image import numpy as np +from einops import rearrange def lets_dance_with_long_video( @@ -150,6 +151,14 @@ class SDVideoPipeline(torch.nn.Module): return latents + def post_process_latents(self, latents, post_normalize=True, contrast_enhance_scale=1.0): + if post_normalize: + mean, std = latents.mean(), latents.std() + latents = (latents - latents.mean(dim=[1, 2, 3], keepdim=True)) / latents.std(dim=[1, 2, 3], keepdim=True) * std + mean + latents = latents * contrast_enhance_scale + return latents + + @torch.no_grad() def __call__( self, @@ -172,6 +181,8 @@ class SDVideoPipeline(torch.nn.Module): smoother=None, smoother_progress_ids=[], vram_limit_level=0, + post_normalize=False, + contrast_enhance_scale=1.0, progress_bar_cmd=tqdm, progress_bar_st=None, ): @@ -226,15 +237,18 @@ class SDVideoPipeline(torch.nn.Module): cross_frame_attention=cross_frame_attention, device=self.device, vram_limit_level=vram_limit_level ) - noise_pred_nega = lets_dance_with_long_video( - self.unet, motion_modules=self.motion_modules, controlnet=self.controlnet, - sample=latents, timestep=timestep, encoder_hidden_states=prompt_emb_nega, controlnet_frames=controlnet_frames, - animatediff_batch_size=animatediff_batch_size, animatediff_stride=animatediff_stride, - unet_batch_size=unet_batch_size, controlnet_batch_size=controlnet_batch_size, - cross_frame_attention=cross_frame_attention, - device=self.device, vram_limit_level=vram_limit_level - ) - noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega) + if cfg_scale != 1.0: + noise_pred_nega = lets_dance_with_long_video( + self.unet, motion_modules=self.motion_modules, controlnet=self.controlnet, + sample=latents, timestep=timestep, encoder_hidden_states=prompt_emb_nega, controlnet_frames=controlnet_frames, + animatediff_batch_size=animatediff_batch_size, animatediff_stride=animatediff_stride, + unet_batch_size=unet_batch_size, controlnet_batch_size=controlnet_batch_size, + cross_frame_attention=cross_frame_attention, + device=self.device, vram_limit_level=vram_limit_level + ) + noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega) + else: + noise_pred = noise_pred_posi # DDIM and smoother if smoother is not None and progress_id in smoother_progress_ids: @@ -250,6 +264,7 @@ class SDVideoPipeline(torch.nn.Module): progress_bar_st.progress(progress_id / len(self.scheduler.timesteps)) # Decode image + latents = self.post_process_latents(latents, post_normalize=post_normalize, contrast_enhance_scale=contrast_enhance_scale) output_frames = self.decode_images(latents) # Post-process diff --git a/diffsynth/prompts/sd_prompter.py b/diffsynth/prompts/sd_prompter.py index 6d4407c..ae3e02a 100644 --- a/diffsynth/prompts/sd_prompter.py +++ b/diffsynth/prompts/sd_prompter.py @@ -8,9 +8,9 @@ class SDPrompter(Prompter): super().__init__() self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path) - def encode_prompt(self, text_encoder: SDTextEncoder, prompt, clip_skip=1, device="cuda", positive=True): + def encode_prompt(self, text_encoder: SDTextEncoder, prompt, clip_skip=1, device="cuda", positive=True, max_length=99999999): prompt = self.process_prompt(prompt, positive=positive) - input_ids = tokenize_long_prompt(self.tokenizer, prompt).to(device) + input_ids = tokenize_long_prompt(self.tokenizer, prompt, max_length=max_length).to(device) prompt_emb = text_encoder(input_ids, clip_skip=clip_skip) prompt_emb = prompt_emb.reshape((1, prompt_emb.shape[0]*prompt_emb.shape[1], -1)) diff --git a/diffsynth/prompts/utils.py b/diffsynth/prompts/utils.py index f041228..32282f3 100644 --- a/diffsynth/prompts/utils.py +++ b/diffsynth/prompts/utils.py @@ -3,12 +3,12 @@ from ..models import ModelManager import os -def tokenize_long_prompt(tokenizer, prompt): +def tokenize_long_prompt(tokenizer, prompt, max_length=99999999): # Get model_max_length from self.tokenizer length = tokenizer.model_max_length # To avoid the warning. set self.tokenizer.model_max_length to +oo. - tokenizer.model_max_length = 99999999 + tokenizer.model_max_length = max_length # Tokenize it! input_ids = tokenizer(prompt, return_tensors="pt").input_ids