Compare commits

...

1 Commit
dpo ... ExVideo

Author SHA1 Message Date
Artiprocher
a076adf592 ExVideo for AnimateDiff 2024-07-26 14:35:18 +08:00
7 changed files with 520 additions and 48 deletions

View File

@@ -0,0 +1,267 @@
import torch, json, os, imageio
from torchvision.transforms import v2
from einops import rearrange
import lightning as pl
from diffsynth import ModelManager, EnhancedDDIMScheduler, SDVideoPipeline, SDUNet, load_state_dict, SDMotionModel
def lets_dance(
    unet: SDUNet,
    motion_modules: SDMotionModel,
    sample,
    timestep,
    encoder_hidden_states,
    use_gradient_checkpointing=False,
):
    """Run one UNet denoising step with AnimateDiff motion modules interleaved.

    After each UNet block whose index appears in ``motion_modules.call_block_id``,
    the corresponding motion module is applied to the same state tuple.
    Returns the UNet's predicted noise tensor.
    """
    def run_module(module, *state):
        # Optionally wrap the call in activation checkpointing to trade
        # compute for VRAM during training.
        if use_gradient_checkpointing:
            def custom_forward(*inputs):
                return module(*inputs)
            return torch.utils.checkpoint.checkpoint(
                custom_forward, *state, use_reentrant=False,
            )
        return module(*state)

    # 1. ControlNet (skip)
    # 2. timestep embedding
    time_emb = unet.time_proj(timestep[None]).to(sample.dtype)
    time_emb = unet.time_embedding(time_emb)
    # 3. pre-process
    hidden_states = unet.conv_in(sample)
    text_emb = encoder_hidden_states
    res_stack = [hidden_states]
    # 4. UNet blocks, with motion modules interleaved at registered indices
    state = (hidden_states, time_emb, text_emb, res_stack)
    for block_id, block in enumerate(unet.blocks):
        state = run_module(block, *state)
        if block_id in motion_modules.call_block_id:
            motion_module_id = motion_modules.call_block_id[block_id]
            state = run_module(motion_modules.motion_modules[motion_module_id], *state)
    hidden_states, time_emb, text_emb, res_stack = state
    # 5. output projection
    hidden_states = unet.conv_norm_out(hidden_states)
    hidden_states = unet.conv_act(hidden_states)
    hidden_states = unet.conv_out(hidden_states)
    return hidden_states
class TextVideoDataset(torch.utils.data.Dataset):
    """Text-video pairs loaded from a metadata JSON file.

    Each metadata entry is a dict with a ``"path"`` (video file, relative to
    ``base_path``) and a ``"text"`` (caption string, or a list of captions to
    sample from).  Each training shape is a tuple
    ``(max_num_frames, interval, num_frames, height, width)``.
    """

    def __init__(self, base_path, metadata_path, steps_per_epoch=10000,
                 training_shapes=((128, 1, 128, 512, 512),)):
        # NOTE: default changed from a mutable list literal to an equivalent
        # tuple to avoid the shared-mutable-default pitfall; behavior is the same.
        with open(metadata_path, "r") as f:
            metadata = json.load(f)
        self.path = [os.path.join(base_path, i["path"]) for i in metadata]
        self.text = [i["text"] for i in metadata]
        self.steps_per_epoch = steps_per_epoch
        self.training_shapes = training_shapes
        # One resize/crop/normalize pipeline per training shape.
        self.frame_process = []
        for max_num_frames, interval, num_frames, height, width in training_shapes:
            self.frame_process.append(v2.Compose([
                v2.Resize(size=max(height, width), antialias=True),
                v2.CenterCrop(size=(height, width)),
                # Maps uint8 pixel values [0, 255] to roughly [-1, 1].
                v2.Normalize(mean=[127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5]),
            ]))

    def load_frames_using_imageio(self, file_path, max_num_frames, start_frame_id,
                                  interval, num_frames, frame_process):
        """Read ``num_frames`` frames starting at ``start_frame_id`` with the
        given stride.  Returns a (C, T, H, W) float tensor, or None if the
        video is too short for the requested window."""
        reader = imageio.get_reader(file_path)
        # Hoisted: count_frames() may rescan the container, so call it once.
        frame_count = reader.count_frames()
        if frame_count < max_num_frames or frame_count - 1 < start_frame_id + (num_frames - 1) * interval:
            reader.close()
            return None
        frames = []
        for frame_id in range(num_frames):
            frame = reader.get_data(start_frame_id + frame_id * interval)
            frame = torch.tensor(frame, dtype=torch.float32)
            frame = rearrange(frame, "H W C -> 1 C H W")
            frame = frame_process(frame)
            frames.append(frame)
        reader.close()
        frames = torch.concat(frames, dim=0)
        frames = rearrange(frames, "T C H W -> C T H W")
        return frames

    def load_video(self, file_path, training_shape_id):
        """Sample a random start frame and load one clip for the given shape.

        Returns a dict with ``frames_<id>`` and ``start_frame_id_<id>`` keys,
        or None when the video cannot satisfy the shape."""
        data = {}
        max_num_frames, interval, num_frames, height, width = self.training_shapes[training_shape_id]
        frame_process = self.frame_process[training_shape_id]
        start_frame_id = torch.randint(0, max_num_frames - (num_frames - 1) * interval, (1,))[0]
        frames = self.load_frames_using_imageio(file_path, max_num_frames, start_frame_id, interval, num_frames, frame_process)
        if frames is None:
            return None
        else:
            data[f"frames_{training_shape_id}"] = frames
            data[f"start_frame_id_{training_shape_id}"] = start_frame_id
            return data

    def __getitem__(self, index):
        """Sample one clip per training shape, retrying until a video loads.

        NOTE(review): this loops forever if no video in the dataset is
        loadable — acceptable for training, but worth knowing when debugging.
        """
        video_data = {}
        for training_shape_id in range(len(self.training_shapes)):
            while True:
                data_id = torch.randint(0, len(self.path), (1,))[0]
                data_id = (data_id + index) % len(self.path)  # For fixed seed.
                text = self.text[data_id]
                if isinstance(text, list):
                    text = text[torch.randint(0, len(text), (1,))[0]]
                video_file = self.path[data_id]
                try:
                    data = self.load_video(video_file, training_shape_id)
                except Exception:
                    # Narrowed from a bare `except:` so KeyboardInterrupt /
                    # SystemExit still propagate; a corrupt video is skipped.
                    data = None
                if data is not None:
                    data[f"text_{training_shape_id}"] = text
                    break
            video_data.update(data)
        return video_data

    def __len__(self):
        # Virtual epoch length: samples are drawn randomly, so the epoch size
        # is a training schedule knob, not the number of videos.
        return self.steps_per_epoch
class LightningModel(pl.LightningModule):
    """Trains AnimateDiff motion modules on top of a frozen Stable Diffusion model."""

    def __init__(self, learning_rate=1e-5, sd_ckpt_path=None):
        super().__init__()
        # Load the Stable Diffusion backbone on CPU in fp16.
        model_manager = ModelManager(torch_dtype=torch.float16, device="cpu")
        model_manager.load_stable_diffusion(load_state_dict(sd_ckpt_path))
        # Freshly initialized motion modules — the only trainable component.
        model_manager.model["motion_modules"] = SDMotionModel().to(dtype=self.dtype, device=self.device)
        # Build the video pipeline from the loaded components.
        self.pipe = SDVideoPipeline.from_model_manager(model_manager)
        # Freeze everything except the motion modules.
        for frozen in (self.pipe.vae_encoder, self.pipe.vae_decoder,
                       self.pipe.text_encoder, self.pipe.unet):
            frozen.eval()
            frozen.requires_grad_(False)
        self.pipe.motion_modules.train()
        self.pipe.motion_modules.requires_grad_(True)
        # Reset the scheduler for training (full 1000-step schedule).
        self.pipe.scheduler = EnhancedDDIMScheduler(beta_schedule="scaled_linear")
        self.pipe.scheduler.set_timesteps(1000)
        self.learning_rate = learning_rate

    def encode_video_with_vae(self, video):
        """Encode a (C, T, H, W) video into per-frame latents of shape (T, C, H, W)."""
        video = video.to(device=self.device, dtype=self.dtype)
        video = video.unsqueeze(0)
        latents = self.pipe.vae_encoder.encode_video(video, batch_size=16)
        return rearrange(latents[0], "C T H W -> T C H W")

    def calculate_loss(self, prompt, frames):
        """Diffusion noise-prediction MSE loss for one captioned clip."""
        with torch.no_grad():
            # Frozen encoders: latents and prompt embeddings need no gradients.
            latents = self.encode_video_with_vae(frames)
            prompt_embs = self.pipe.prompter.encode_prompt(
                self.pipe.text_encoder, prompt, device=self.device, max_length=77)
            prompt_embs = prompt_embs.repeat(latents.shape[0], 1, 1)
            # Sample a single timestep for the whole clip and add noise.
            timestep = torch.randint(0, len(self.pipe.scheduler.timesteps), (1,), device=self.device)[0]
            noise = torch.randn_like(latents)
            noisy_latents = self.pipe.scheduler.add_noise(latents, noise, timestep)
        # Gradients flow only through the motion modules inside lets_dance.
        model_pred = lets_dance(
            self.pipe.unet, self.pipe.motion_modules,
            sample=noisy_latents, encoder_hidden_states=prompt_embs, timestep=timestep
        )
        return torch.nn.functional.mse_loss(model_pred.float(), noise.float(), reduction="mean")

    def training_step(self, batch, batch_idx):
        # Only the first (and, here, only) training shape is used per step.
        frames = batch["frames_0"][0]
        prompt = batch["text_0"][0]
        loss = self.calculate_loss(prompt, frames)
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def configure_optimizers(self):
        # Optimize the motion modules only; everything else is frozen.
        return torch.optim.AdamW(self.pipe.motion_modules.parameters(), lr=self.learning_rate)

    def on_save_checkpoint(self, checkpoint):
        # Record which parameters were trainable so downstream loaders can
        # extract just the motion-module weights from the checkpoint.
        checkpoint["trainable_param_names"] = [
            name for name, param in self.pipe.motion_modules.named_parameters()
            if param.requires_grad
        ]
if __name__ == '__main__':
    # Dataset and data loader: one 16-frame 512x512 clip per sample.
    dataset = TextVideoDataset(
        "/data/zhongjie/datasets/opensoraplan/data/processed",
        "/data/zhongjie/datasets/opensoraplan/data/processed/metadata.json",
        training_shapes=[(16, 1, 16, 512, 512)],
        steps_per_epoch=7*10000,
    )
    train_loader = torch.utils.data.DataLoader(
        dataset,
        shuffle=True,
        batch_size=1,
        num_workers=4
    )

    # Model: frozen SD 1.5 backbone plus trainable motion modules.
    model = LightningModel(
        learning_rate=1e-5,
        sd_ckpt_path="models/stable_diffusion/v1-5-pruned-emaonly.safetensors",
    )

    # Trainer: DeepSpeed ZeRO stage 1, fp16 mixed precision, keep every checkpoint.
    trainer = pl.Trainer(
        max_epochs=100000,
        accelerator="gpu",
        devices="auto",
        strategy="deepspeed_stage_1",
        precision="16-mixed",
        default_root_dir="/data/zhongjie/models/train_extended_animatediff",
        accumulate_grad_batches=1,
        callbacks=[pl.pytorch.callbacks.ModelCheckpoint(save_top_k=-1)]
    )
    trainer.fit(
        model=model,
        train_dataloaders=train_loader,
        ckpt_path=None
    )

View File

@@ -194,10 +194,10 @@ class ModelManager:
self.model[component].append(model) self.model[component].append(model)
self.model_path[component].append(file_path) self.model_path[component].append(file_path)
def load_animatediff(self, state_dict, file_path=""): def load_animatediff(self, state_dict, file_path="", add_positional_conv=None):
component = "motion_modules" component = "motion_modules"
model = SDMotionModel() model = SDMotionModel(add_positional_conv=add_positional_conv)
model.load_state_dict(model.state_dict_converter().from_civitai(state_dict)) model.load_state_dict(model.state_dict_converter().from_civitai(state_dict, add_positional_conv=add_positional_conv))
model.to(self.torch_dtype).to(self.device) model.to(self.torch_dtype).to(self.device)
self.model[component] = model self.model[component] = model
self.model_path[component] = file_path self.model_path[component] = file_path

View File

@@ -1,20 +1,28 @@
from .sd_unet import SDUNet, Attention, GEGLU from .sd_unet import SDUNet, Attention, GEGLU
from .svd_unet import get_timestep_embedding
import torch import torch
from einops import rearrange, repeat from einops import rearrange, repeat
class TemporalTransformerBlock(torch.nn.Module): class TemporalTransformerBlock(torch.nn.Module):
def __init__(self, dim, num_attention_heads, attention_head_dim, max_position_embeddings=32): def __init__(self, dim, num_attention_heads, attention_head_dim, max_position_embeddings=32, add_positional_conv=None):
super().__init__() super().__init__()
self.add_positional_conv = add_positional_conv
# 1. Self-Attn # 1. Self-Attn
self.pe1 = torch.nn.Parameter(torch.zeros(1, max_position_embeddings, dim)) emb = get_timestep_embedding(torch.arange(max_position_embeddings), dim, True, 0).reshape(1, max_position_embeddings, dim)
self.pe1 = torch.nn.Parameter(emb)
if add_positional_conv:
self.positional_conv_1 = torch.nn.Conv1d(dim, dim, kernel_size=3, padding=1, padding_mode="reflect")
self.norm1 = torch.nn.LayerNorm(dim, elementwise_affine=True) self.norm1 = torch.nn.LayerNorm(dim, elementwise_affine=True)
self.attn1 = Attention(q_dim=dim, num_heads=num_attention_heads, head_dim=attention_head_dim, bias_out=True) self.attn1 = Attention(q_dim=dim, num_heads=num_attention_heads, head_dim=attention_head_dim, bias_out=True)
# 2. Cross-Attn # 2. Cross-Attn
self.pe2 = torch.nn.Parameter(torch.zeros(1, max_position_embeddings, dim)) emb = get_timestep_embedding(torch.arange(max_position_embeddings), dim, True, 0).reshape(1, max_position_embeddings, dim)
self.pe2 = torch.nn.Parameter(emb)
if add_positional_conv:
self.positional_conv_2 = torch.nn.Conv1d(dim, dim, kernel_size=3, padding=1, padding_mode="reflect")
self.norm2 = torch.nn.LayerNorm(dim, elementwise_affine=True) self.norm2 = torch.nn.LayerNorm(dim, elementwise_affine=True)
self.attn2 = Attention(q_dim=dim, num_heads=num_attention_heads, head_dim=attention_head_dim, bias_out=True) self.attn2 = Attention(q_dim=dim, num_heads=num_attention_heads, head_dim=attention_head_dim, bias_out=True)
@@ -24,19 +32,47 @@ class TemporalTransformerBlock(torch.nn.Module):
self.ff = torch.nn.Linear(dim * 4, dim) self.ff = torch.nn.Linear(dim * 4, dim)
def frame_id_to_position_id(self, frame_id, max_id, repeat_length):
if frame_id < max_id:
position_id = frame_id
else:
position_id = (frame_id - max_id) % (repeat_length * 2)
if position_id < repeat_length:
position_id = max_id - 2 - position_id
else:
position_id = max_id - 2 * repeat_length + position_id
return position_id
def positional_ids(self, num_frames):
max_id = self.pe1.shape[1]
positional_ids = torch.IntTensor([self.frame_id_to_position_id(i, max_id, max_id - 1) for i in range(num_frames)])
return positional_ids
def forward(self, hidden_states, batch_size=1): def forward(self, hidden_states, batch_size=1):
# 1. Self-Attention # 1. Self-Attention
norm_hidden_states = self.norm1(hidden_states) norm_hidden_states = self.norm1(hidden_states)
norm_hidden_states = rearrange(norm_hidden_states, "(b f) h c -> (b h) f c", b=batch_size) norm_hidden_states = rearrange(norm_hidden_states, "(b f) h c -> (b h) f c", b=batch_size)
attn_output = self.attn1(norm_hidden_states + self.pe1[:, :norm_hidden_states.shape[1]]) norm_hidden_states = norm_hidden_states + self.pe1[:, self.positional_ids(norm_hidden_states.shape[1])]
if self.add_positional_conv:
norm_hidden_states = rearrange(norm_hidden_states, "(b h) f c -> (b h) c f", b=batch_size)
norm_hidden_states = self.positional_conv_1(norm_hidden_states)
norm_hidden_states = rearrange(norm_hidden_states, "(b h) c f -> (b h) f c", b=batch_size)
attn_output = self.attn1(norm_hidden_states)
attn_output = rearrange(attn_output, "(b h) f c -> (b f) h c", b=batch_size) attn_output = rearrange(attn_output, "(b h) f c -> (b f) h c", b=batch_size)
hidden_states = attn_output + hidden_states hidden_states = attn_output + hidden_states
# 2. Cross-Attention # 2. Cross-Attention
norm_hidden_states = self.norm2(hidden_states) norm_hidden_states = self.norm2(hidden_states)
norm_hidden_states = rearrange(norm_hidden_states, "(b f) h c -> (b h) f c", b=batch_size) norm_hidden_states = rearrange(norm_hidden_states, "(b f) h c -> (b h) f c", b=batch_size)
attn_output = self.attn2(norm_hidden_states + self.pe2[:, :norm_hidden_states.shape[1]]) norm_hidden_states = norm_hidden_states + self.pe2[:, self.positional_ids(norm_hidden_states.shape[1])]
if self.add_positional_conv:
norm_hidden_states = rearrange(norm_hidden_states, "(b h) f c -> (b h) c f", b=batch_size)
norm_hidden_states = self.positional_conv_2(norm_hidden_states)
norm_hidden_states = rearrange(norm_hidden_states, "(b h) c f -> (b h) f c", b=batch_size)
attn_output = self.attn2(norm_hidden_states)
attn_output = rearrange(attn_output, "(b h) f c -> (b f) h c", b=batch_size) attn_output = rearrange(attn_output, "(b h) f c -> (b f) h c", b=batch_size)
hidden_states = attn_output + hidden_states hidden_states = attn_output + hidden_states
@@ -51,7 +87,7 @@ class TemporalTransformerBlock(torch.nn.Module):
class TemporalBlock(torch.nn.Module): class TemporalBlock(torch.nn.Module):
def __init__(self, num_attention_heads, attention_head_dim, in_channels, num_layers=1, norm_num_groups=32, eps=1e-5): def __init__(self, num_attention_heads, attention_head_dim, in_channels, num_layers=1, norm_num_groups=32, eps=1e-5, add_positional_conv=None):
super().__init__() super().__init__()
inner_dim = num_attention_heads * attention_head_dim inner_dim = num_attention_heads * attention_head_dim
@@ -62,7 +98,9 @@ class TemporalBlock(torch.nn.Module):
TemporalTransformerBlock( TemporalTransformerBlock(
inner_dim, inner_dim,
num_attention_heads, num_attention_heads,
attention_head_dim attention_head_dim,
max_position_embeddings=32 if add_positional_conv is None else add_positional_conv,
add_positional_conv=add_positional_conv
) )
for d in range(num_layers) for d in range(num_layers)
]) ])
@@ -92,30 +130,30 @@ class TemporalBlock(torch.nn.Module):
class SDMotionModel(torch.nn.Module): class SDMotionModel(torch.nn.Module):
def __init__(self): def __init__(self, add_positional_conv=None):
super().__init__() super().__init__()
self.motion_modules = torch.nn.ModuleList([ self.motion_modules = torch.nn.ModuleList([
TemporalBlock(8, 40, 320, eps=1e-6), TemporalBlock(8, 40, 320, eps=1e-6, add_positional_conv=add_positional_conv),
TemporalBlock(8, 40, 320, eps=1e-6), TemporalBlock(8, 40, 320, eps=1e-6, add_positional_conv=add_positional_conv),
TemporalBlock(8, 80, 640, eps=1e-6), TemporalBlock(8, 80, 640, eps=1e-6, add_positional_conv=add_positional_conv),
TemporalBlock(8, 80, 640, eps=1e-6), TemporalBlock(8, 80, 640, eps=1e-6, add_positional_conv=add_positional_conv),
TemporalBlock(8, 160, 1280, eps=1e-6), TemporalBlock(8, 160, 1280, eps=1e-6, add_positional_conv=add_positional_conv),
TemporalBlock(8, 160, 1280, eps=1e-6), TemporalBlock(8, 160, 1280, eps=1e-6, add_positional_conv=add_positional_conv),
TemporalBlock(8, 160, 1280, eps=1e-6), TemporalBlock(8, 160, 1280, eps=1e-6, add_positional_conv=add_positional_conv),
TemporalBlock(8, 160, 1280, eps=1e-6), TemporalBlock(8, 160, 1280, eps=1e-6, add_positional_conv=add_positional_conv),
TemporalBlock(8, 160, 1280, eps=1e-6), TemporalBlock(8, 160, 1280, eps=1e-6, add_positional_conv=add_positional_conv),
TemporalBlock(8, 160, 1280, eps=1e-6), TemporalBlock(8, 160, 1280, eps=1e-6, add_positional_conv=add_positional_conv),
TemporalBlock(8, 160, 1280, eps=1e-6), TemporalBlock(8, 160, 1280, eps=1e-6, add_positional_conv=add_positional_conv),
TemporalBlock(8, 160, 1280, eps=1e-6), TemporalBlock(8, 160, 1280, eps=1e-6, add_positional_conv=add_positional_conv),
TemporalBlock(8, 160, 1280, eps=1e-6), TemporalBlock(8, 160, 1280, eps=1e-6, add_positional_conv=add_positional_conv),
TemporalBlock(8, 160, 1280, eps=1e-6), TemporalBlock(8, 160, 1280, eps=1e-6, add_positional_conv=add_positional_conv),
TemporalBlock(8, 160, 1280, eps=1e-6), TemporalBlock(8, 160, 1280, eps=1e-6, add_positional_conv=add_positional_conv),
TemporalBlock(8, 80, 640, eps=1e-6), TemporalBlock(8, 80, 640, eps=1e-6, add_positional_conv=add_positional_conv),
TemporalBlock(8, 80, 640, eps=1e-6), TemporalBlock(8, 80, 640, eps=1e-6, add_positional_conv=add_positional_conv),
TemporalBlock(8, 80, 640, eps=1e-6), TemporalBlock(8, 80, 640, eps=1e-6, add_positional_conv=add_positional_conv),
TemporalBlock(8, 40, 320, eps=1e-6), TemporalBlock(8, 40, 320, eps=1e-6, add_positional_conv=add_positional_conv),
TemporalBlock(8, 40, 320, eps=1e-6), TemporalBlock(8, 40, 320, eps=1e-6, add_positional_conv=add_positional_conv),
TemporalBlock(8, 40, 320, eps=1e-6), TemporalBlock(8, 40, 320, eps=1e-6, add_positional_conv=add_positional_conv),
]) ])
self.call_block_id = { self.call_block_id = {
1: 0, 1: 0,
@@ -152,7 +190,42 @@ class SDMotionModelStateDictConverter:
def __init__(self): def __init__(self):
pass pass
def from_diffusers(self, state_dict): def frame_id_to_position_id(self, frame_id, max_id, repeat_length):
if frame_id < max_id:
position_id = frame_id
else:
position_id = (frame_id - max_id) % (repeat_length * 2)
if position_id < repeat_length:
position_id = max_id - 2 - position_id
else:
position_id = max_id - 2 * repeat_length + position_id
return position_id
def process_positional_conv_parameters(self, state_dict, add_positional_conv):
ids = [self.frame_id_to_position_id(i, 16, 15) for i in range(add_positional_conv)]
for i in range(21):
# Extend positional embedding
name = f"motion_modules.{i}.transformer_blocks.0.pe1"
state_dict[name] = state_dict[name][:, ids]
name = f"motion_modules.{i}.transformer_blocks.0.pe2"
state_dict[name] = state_dict[name][:, ids]
# add post convolution
dim = state_dict[f"motion_modules.{i}.transformer_blocks.0.pe1"].shape[-1]
name = f"motion_modules.{i}.transformer_blocks.0.positional_conv_1.bias"
state_dict[name] = torch.zeros((dim,))
name = f"motion_modules.{i}.transformer_blocks.0.positional_conv_2.bias"
state_dict[name] = torch.zeros((dim,))
name = f"motion_modules.{i}.transformer_blocks.0.positional_conv_1.weight"
param = torch.zeros((dim, dim, 3))
param[:, :, 1] = torch.eye(dim, dim)
state_dict[name] = param
name = f"motion_modules.{i}.transformer_blocks.0.positional_conv_2.weight"
param = torch.zeros((dim, dim, 3))
param[:, :, 1] = torch.eye(dim, dim)
state_dict[name] = param
return state_dict
def from_diffusers(self, state_dict, add_positional_conv=None):
rename_dict = { rename_dict = {
"norm": "norm", "norm": "norm",
"proj_in": "proj_in", "proj_in": "proj_in",
@@ -192,7 +265,9 @@ class SDMotionModelStateDictConverter:
else: else:
rename = ".".join(["motion_modules", str(module_id), rename_dict[middle_name], suffix]) rename = ".".join(["motion_modules", str(module_id), rename_dict[middle_name], suffix])
state_dict_[rename] = state_dict[name] state_dict_[rename] = state_dict[name]
if add_positional_conv is not None:
state_dict_ = self.process_positional_conv_parameters(state_dict_, add_positional_conv)
return state_dict_ return state_dict_
def from_civitai(self, state_dict): def from_civitai(self, state_dict, add_positional_conv=None):
return self.from_diffusers(state_dict) return self.from_diffusers(state_dict, add_positional_conv=add_positional_conv)

View File

@@ -0,0 +1,115 @@
from .attention import Attention
from .svd_unet import get_timestep_embedding
import torch
from einops import rearrange, repeat
class ExVideoMotionBlock(torch.nn.Module):
    """Temporal self-attention block with ping-pong positional embeddings.

    Frames beyond ``max_position_embeddings`` reuse positions by reflecting
    back and forth across the embedding table, so the block extrapolates to
    longer videos than it was trained on.
    """

    def __init__(self, num_attention_heads, attention_head_dim, in_channels,
                 max_position_embeddings=16, num_layers=1, add_positional_conv=None):
        super().__init__()
        # Sinusoidal positional table, shaped (T, C, 1, 1) so it broadcasts
        # over the spatial dims of (B*T, C, H, W) feature maps.
        emb = get_timestep_embedding(torch.arange(max_position_embeddings), in_channels, True, 0).reshape(max_position_embeddings, in_channels, 1, 1)
        self.positional_embedding = torch.nn.Parameter(emb)
        self.positional_conv = torch.nn.Conv3d(in_channels, in_channels, kernel_size=3, padding=1) if add_positional_conv is not None else None
        self.norms = torch.nn.ModuleList([torch.nn.LayerNorm(in_channels) for _ in range(num_layers)])
        self.attns = torch.nn.ModuleList([Attention(q_dim=in_channels, num_heads=num_attention_heads, head_dim=attention_head_dim, bias_out=True) for _ in range(num_layers)])

    def frame_id_to_position_id(self, frame_id, max_id, repeat_length):
        """Map a frame index to a position index, reflecting past ``max_id``.

        Positions run 0..max_id-1, then bounce back down and up again with
        period ``2 * repeat_length`` (a triangle wave over the table)."""
        if frame_id < max_id:
            position_id = frame_id
        else:
            position_id = (frame_id - max_id) % (repeat_length * 2)
            if position_id < repeat_length:
                position_id = max_id - 2 - position_id
            else:
                position_id = max_id - 2 * repeat_length + position_id
        return position_id

    def positional_ids(self, num_frames):
        """Position index tensor for ``num_frames`` frames (CPU IntTensor)."""
        max_id = self.positional_embedding.shape[0]
        positional_ids = torch.IntTensor([self.frame_id_to_position_id(i, max_id, max_id - 1) for i in range(num_frames)])
        return positional_ids

    def forward(self, hidden_states, time_emb, text_emb, res_stack, batch_size=1, **kwargs):
        # hidden_states: (batch_size * num_frames, C, H, W), frames contiguous
        # per batch element (batch-major flattening).
        batch, inner_dim, height, width = hidden_states.shape
        residual = hidden_states
        pos_emb = self.positional_ids(batch // batch_size)
        pos_emb = self.positional_embedding[pos_emb]
        # BUG FIX: the original `pos_emb.repeat(batch_size)` raises at runtime
        # (repeat needs at least as many dims as the tensor). Tile the (T, C,
        # 1, 1) table once per batch element to match the (B*T, C, H, W) input.
        pos_emb = pos_emb.repeat(batch_size, 1, 1, 1)
        hidden_states = hidden_states + pos_emb
        if self.positional_conv is not None:
            hidden_states = rearrange(hidden_states, "(B T) C H W -> B C T H W", B=batch_size)
            hidden_states = self.positional_conv(hidden_states)
            hidden_states = rearrange(hidden_states, "B C T H W -> (B H W) T C")
        else:
            hidden_states = rearrange(hidden_states, "(B T) C H W -> (B H W) T C", B=batch_size)
        # Pre-norm temporal self-attention with residual connections.
        for norm, attn in zip(self.norms, self.attns):
            norm_hidden_states = norm(hidden_states)
            attn_output = attn(norm_hidden_states)
            hidden_states = hidden_states + attn_output
        hidden_states = rearrange(hidden_states, "(B H W) T C -> (B T) C H W", B=batch_size, H=height, W=width)
        hidden_states = hidden_states + residual
        return hidden_states, time_emb, text_emb, res_stack
class ExVideoMotionModel(torch.nn.Module):
    """Stack of 21 ExVideo motion blocks matching the SD UNet channel layout."""

    def __init__(self, num_layers=2):
        super().__init__()
        # (num_heads, head_dim, channels) for each insertion point, in UNet
        # order: down path (320, 640, 1280), middle, up path (1280, 640, 320).
        specs = (
            [(8, 40, 320)] * 2
            + [(8, 80, 640)] * 2
            + [(8, 160, 1280)] * 11
            + [(8, 80, 640)] * 3
            + [(8, 40, 320)] * 3
        )
        self.motion_modules = torch.nn.ModuleList([
            ExVideoMotionBlock(heads, head_dim, channels, num_layers=num_layers)
            for heads, head_dim, channels in specs
        ])
        # UNet block index -> index into self.motion_modules.
        self.call_block_id = {
            1: 0,
            4: 1,
            9: 2,
            12: 3,
            17: 4,
            20: 5,
            24: 6,
            26: 7,
            29: 8,
            32: 9,
            34: 10,
            36: 11,
            40: 12,
            43: 13,
            46: 14,
            50: 15,
            53: 16,
            56: 17,
            60: 18,
            63: 19,
            66: 20
        }

    def forward(self):
        # Placeholder: blocks are invoked individually via call_block_id.
        pass

    def state_dict_converter(self):
        # Placeholder: no state-dict converter is implemented for this model.
        pass

View File

@@ -10,6 +10,7 @@ import torch, os, json
from tqdm import tqdm from tqdm import tqdm
from PIL import Image from PIL import Image
import numpy as np import numpy as np
from einops import rearrange
def lets_dance_with_long_video( def lets_dance_with_long_video(
@@ -150,6 +151,14 @@ class SDVideoPipeline(torch.nn.Module):
return latents return latents
def post_process_latents(self, latents, post_normalize=True, contrast_enhance_scale=1.0):
if post_normalize:
mean, std = latents.mean(), latents.std()
latents = (latents - latents.mean(dim=[1, 2, 3], keepdim=True)) / latents.std(dim=[1, 2, 3], keepdim=True) * std + mean
latents = latents * contrast_enhance_scale
return latents
@torch.no_grad() @torch.no_grad()
def __call__( def __call__(
self, self,
@@ -172,6 +181,8 @@ class SDVideoPipeline(torch.nn.Module):
smoother=None, smoother=None,
smoother_progress_ids=[], smoother_progress_ids=[],
vram_limit_level=0, vram_limit_level=0,
post_normalize=False,
contrast_enhance_scale=1.0,
progress_bar_cmd=tqdm, progress_bar_cmd=tqdm,
progress_bar_st=None, progress_bar_st=None,
): ):
@@ -226,15 +237,18 @@ class SDVideoPipeline(torch.nn.Module):
cross_frame_attention=cross_frame_attention, cross_frame_attention=cross_frame_attention,
device=self.device, vram_limit_level=vram_limit_level device=self.device, vram_limit_level=vram_limit_level
) )
noise_pred_nega = lets_dance_with_long_video( if cfg_scale != 1.0:
self.unet, motion_modules=self.motion_modules, controlnet=self.controlnet, noise_pred_nega = lets_dance_with_long_video(
sample=latents, timestep=timestep, encoder_hidden_states=prompt_emb_nega, controlnet_frames=controlnet_frames, self.unet, motion_modules=self.motion_modules, controlnet=self.controlnet,
animatediff_batch_size=animatediff_batch_size, animatediff_stride=animatediff_stride, sample=latents, timestep=timestep, encoder_hidden_states=prompt_emb_nega, controlnet_frames=controlnet_frames,
unet_batch_size=unet_batch_size, controlnet_batch_size=controlnet_batch_size, animatediff_batch_size=animatediff_batch_size, animatediff_stride=animatediff_stride,
cross_frame_attention=cross_frame_attention, unet_batch_size=unet_batch_size, controlnet_batch_size=controlnet_batch_size,
device=self.device, vram_limit_level=vram_limit_level cross_frame_attention=cross_frame_attention,
) device=self.device, vram_limit_level=vram_limit_level
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega) )
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
else:
noise_pred = noise_pred_posi
# DDIM and smoother # DDIM and smoother
if smoother is not None and progress_id in smoother_progress_ids: if smoother is not None and progress_id in smoother_progress_ids:
@@ -250,6 +264,7 @@ class SDVideoPipeline(torch.nn.Module):
progress_bar_st.progress(progress_id / len(self.scheduler.timesteps)) progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
# Decode image # Decode image
latents = self.post_process_latents(latents, post_normalize=post_normalize, contrast_enhance_scale=contrast_enhance_scale)
output_frames = self.decode_images(latents) output_frames = self.decode_images(latents)
# Post-process # Post-process

View File

@@ -8,9 +8,9 @@ class SDPrompter(Prompter):
super().__init__() super().__init__()
self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path) self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path)
def encode_prompt(self, text_encoder: SDTextEncoder, prompt, clip_skip=1, device="cuda", positive=True): def encode_prompt(self, text_encoder: SDTextEncoder, prompt, clip_skip=1, device="cuda", positive=True, max_length=99999999):
prompt = self.process_prompt(prompt, positive=positive) prompt = self.process_prompt(prompt, positive=positive)
input_ids = tokenize_long_prompt(self.tokenizer, prompt).to(device) input_ids = tokenize_long_prompt(self.tokenizer, prompt, max_length=max_length).to(device)
prompt_emb = text_encoder(input_ids, clip_skip=clip_skip) prompt_emb = text_encoder(input_ids, clip_skip=clip_skip)
prompt_emb = prompt_emb.reshape((1, prompt_emb.shape[0]*prompt_emb.shape[1], -1)) prompt_emb = prompt_emb.reshape((1, prompt_emb.shape[0]*prompt_emb.shape[1], -1))

View File

@@ -3,12 +3,12 @@ from ..models import ModelManager
import os import os
def tokenize_long_prompt(tokenizer, prompt): def tokenize_long_prompt(tokenizer, prompt, max_length=99999999):
# Get model_max_length from self.tokenizer # Get model_max_length from self.tokenizer
length = tokenizer.model_max_length length = tokenizer.model_max_length
# To avoid the warning. set self.tokenizer.model_max_length to +oo. # To avoid the warning. set self.tokenizer.model_max_length to +oo.
tokenizer.model_max_length = 99999999 tokenizer.model_max_length = max_length
# Tokenize it! # Tokenize it!
input_ids = tokenizer(prompt, return_tensors="pt").input_ids input_ids = tokenizer(prompt, return_tensors="pt").input_ids