mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-19 14:58:12 +00:00
Compare commits
1 Commits
wan-lora-f
...
ExVideo
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a076adf592 |
267
ExVideo_animatediff_train.py
Normal file
267
ExVideo_animatediff_train.py
Normal file
@@ -0,0 +1,267 @@
|
|||||||
|
import torch, json, os, imageio
|
||||||
|
from torchvision.transforms import v2
|
||||||
|
from einops import rearrange
|
||||||
|
import lightning as pl
|
||||||
|
from diffsynth import ModelManager, EnhancedDDIMScheduler, SDVideoPipeline, SDUNet, load_state_dict, SDMotionModel
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def lets_dance(
    unet: SDUNet,
    motion_modules: SDMotionModel,
    sample,
    timestep,
    encoder_hidden_states,
    use_gradient_checkpointing=False,
):
    """Run one denoising forward pass through the UNet with AnimateDiff
    motion modules interleaved after selected UNet blocks.

    Args:
        unet: the Stable Diffusion UNet providing time/conv/block layers.
        motion_modules: temporal modules; ``call_block_id`` maps a UNet block
            index to the motion module that should run right after it.
        sample: noisy latent tensor fed to ``unet.conv_in``.
        timestep: scalar diffusion timestep tensor.
        encoder_hidden_states: text-encoder embeddings used as ``text_emb``.
        use_gradient_checkpointing: trade compute for memory during training.

    Returns:
        The predicted noise tensor from ``unet.conv_out``.
    """
    # 1. ControlNet (skip)

    # 2. time embedding — projected, then matched to the sample dtype.
    time_emb = unet.time_proj(timestep[None]).to(sample.dtype)
    time_emb = unet.time_embedding(time_emb)

    # 3. pre-process
    hidden_states = unet.conv_in(sample)
    text_emb = encoder_hidden_states
    res_stack = [hidden_states]

    def run_layer(layer, states):
        # Dispatch one layer, optionally through non-reentrant gradient
        # checkpointing so activations are recomputed in backward.
        if use_gradient_checkpointing:
            def wrapped(*inputs):
                return layer(*inputs)
            return torch.utils.checkpoint.checkpoint(
                wrapped, *states, use_reentrant=False,
            )
        return layer(*states)

    # 4. UNet blocks, each possibly followed by its AnimateDiff module.
    states = (hidden_states, time_emb, text_emb, res_stack)
    for block_id, block in enumerate(unet.blocks):
        # 4.1 UNet block
        states = run_layer(block, states)
        # 4.2 AnimateDiff motion module registered for this block, if any
        if block_id in motion_modules.call_block_id:
            module_index = motion_modules.call_block_id[block_id]
            states = run_layer(motion_modules.motion_modules[module_index], states)
    hidden_states, time_emb, text_emb, res_stack = states

    # 5. output head
    hidden_states = unet.conv_norm_out(hidden_states)
    hidden_states = unet.conv_act(hidden_states)
    hidden_states = unet.conv_out(hidden_states)

    return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class TextVideoDataset(torch.utils.data.Dataset):
    """Dataset of (video clip, caption) pairs described by a JSON metadata file.

    Each metadata entry has a ``"path"`` (relative to ``base_path``) and a
    ``"text"`` caption (a string, or a list of strings to sample from).
    ``__getitem__`` draws a random clip per configured training shape and
    resamples on unreadable or too-short videos.
    """

    def __init__(self, base_path, metadata_path, steps_per_epoch=10000, training_shapes=((128, 1, 128, 512, 512),)):
        """Load metadata and build one preprocessing pipeline per shape.

        Args:
            base_path: directory that metadata "path" entries are relative to.
            metadata_path: JSON file with a list of {"path", "text"} records.
            steps_per_epoch: virtual epoch length reported by ``__len__``.
            training_shapes: sequence of
                (max_num_frames, interval, num_frames, height, width) tuples.
                (Default is a tuple, not a list, to avoid a shared mutable
                default argument.)
        """
        with open(metadata_path, "r") as f:
            metadata = json.load(f)
        self.path = [os.path.join(base_path, item["path"]) for item in metadata]
        self.text = [item["text"] for item in metadata]
        self.steps_per_epoch = steps_per_epoch
        self.training_shapes = training_shapes

        # One torchvision transform pipeline per training shape.
        self.frame_process = []
        for max_num_frames, interval, num_frames, height, width in training_shapes:
            self.frame_process.append(v2.Compose([
                v2.Resize(size=max(height, width), antialias=True),
                v2.CenterCrop(size=(height, width)),
                # Maps pixel values from [0, 255] to [-1, 1].
                v2.Normalize(mean=[127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5]),
            ]))

    def load_frames_using_imageio(self, file_path, max_num_frames, start_frame_id, interval, num_frames, frame_process):
        """Read ``num_frames`` frames at the given stride starting from
        ``start_frame_id`` and return them as a (C, T, H, W) float tensor,
        or ``None`` when the video is too short for the requested clip."""
        reader = imageio.get_reader(file_path)
        total_frames = reader.count_frames()  # hoisted: count_frames() can be expensive
        if total_frames < max_num_frames or total_frames - 1 < start_frame_id + (num_frames - 1) * interval:
            reader.close()
            return None

        frames = []
        for frame_id in range(num_frames):
            frame = reader.get_data(start_frame_id + frame_id * interval)
            frame = torch.tensor(frame, dtype=torch.float32)
            frame = rearrange(frame, "H W C -> 1 C H W")
            frame = frame_process(frame)
            frames.append(frame)
        reader.close()

        frames = torch.concat(frames, dim=0)
        frames = rearrange(frames, "T C H W -> C T H W")
        return frames

    def load_video(self, file_path, training_shape_id):
        """Sample a random start frame and load one clip for the given
        training shape. Returns a dict with ``frames_<id>`` and
        ``start_frame_id_<id>`` keys, or ``None`` on failure."""
        data = {}
        max_num_frames, interval, num_frames, height, width = self.training_shapes[training_shape_id]
        frame_process = self.frame_process[training_shape_id]
        start_frame_id = torch.randint(0, max_num_frames - (num_frames - 1) * interval, (1,))[0]
        frames = self.load_frames_using_imageio(file_path, max_num_frames, start_frame_id, interval, num_frames, frame_process)
        if frames is None:
            return None
        data[f"frames_{training_shape_id}"] = frames
        data[f"start_frame_id_{training_shape_id}"] = start_frame_id
        return data

    def __getitem__(self, index):
        """Draw one clip (plus caption) per training shape, retrying with a
        fresh random video whenever loading fails.

        NOTE(review): this loops forever if no video in the corpus satisfies
        a training shape — acceptable for a curated corpus, but worth a retry
        cap if the data is noisy.
        """
        video_data = {}
        for training_shape_id in range(len(self.training_shapes)):
            while True:
                data_id = torch.randint(0, len(self.path), (1,))[0]
                data_id = (data_id + index) % len(self.path)  # For fixed seed.
                text = self.text[data_id]
                if isinstance(text, list):
                    # Multiple captions available: pick one at random.
                    text = text[torch.randint(0, len(text), (1,))[0]]
                video_file = self.path[data_id]
                try:
                    data = self.load_video(video_file, training_shape_id)
                except Exception:
                    # Narrowed from a bare `except:`, which also swallowed
                    # KeyboardInterrupt/SystemExit. Corrupt or unreadable
                    # videos are treated as a miss and resampled.
                    data = None
                if data is not None:
                    data[f"text_{training_shape_id}"] = text
                    break
            video_data.update(data)
        return video_data

    def __len__(self):
        """Virtual epoch length (sampling is random, not sequential)."""
        return self.steps_per_epoch
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class LightningModel(pl.LightningModule):
    """Lightning wrapper that trains only the AnimateDiff motion modules on
    top of a frozen Stable Diffusion v1.5 backbone (epsilon prediction)."""

    def __init__(self, learning_rate=1e-5, sd_ckpt_path=None):
        """Build the frozen SD pipeline plus freshly initialized, trainable
        motion modules.

        Args:
            learning_rate: AdamW learning rate for the motion modules.
            sd_ckpt_path: path to an SD v1.5 checkpoint; passed straight to
                ``load_state_dict`` (no default fallback — ``None`` will fail).
        """
        super().__init__()
        # Load models (on CPU in fp16; the trainer moves/cast them later)
        model_manager = ModelManager(torch_dtype=torch.float16, device="cpu")
        model_manager.load_stable_diffusion(load_state_dict(sd_ckpt_path))

        # Initialize motion modules from scratch (not from a pretrained
        # AnimateDiff checkpoint). NOTE(review): at this point self.dtype /
        # self.device are still Lightning's pre-setup defaults — confirm the
        # trainer re-casts the modules before training.
        model_manager.model["motion_modules"] = SDMotionModel().to(dtype=self.dtype, device=self.device)

        # Build pipeline
        self.pipe = SDVideoPipeline.from_model_manager(model_manager)
        # Freeze every backbone component; only motion modules get gradients.
        self.pipe.vae_encoder.eval()
        self.pipe.vae_encoder.requires_grad_(False)

        self.pipe.vae_decoder.eval()
        self.pipe.vae_decoder.requires_grad_(False)

        self.pipe.text_encoder.eval()
        self.pipe.text_encoder.requires_grad_(False)

        self.pipe.unet.eval()
        self.pipe.unet.requires_grad_(False)

        self.pipe.motion_modules.train()
        self.pipe.motion_modules.requires_grad_(True)

        # Reset the scheduler to the full 1000-step training schedule
        self.pipe.scheduler = EnhancedDDIMScheduler(beta_schedule="scaled_linear")
        self.pipe.scheduler.set_timesteps(1000)

        # Other parameters
        self.learning_rate = learning_rate

    def encode_video_with_vae(self, video):
        """Encode a (C, T, H, W) video into per-frame VAE latents of shape
        (T, C', H', W'), chunking the encode in batches of 16 frames."""
        video = video.to(device=self.device, dtype=self.dtype)
        video = video.unsqueeze(0)  # add batch dim: (1, C, T, H, W)
        latents = self.pipe.vae_encoder.encode_video(video, batch_size=16)
        latents = rearrange(latents[0], "C T H W -> T C H W")
        return latents

    def calculate_loss(self, prompt, frames):
        """Standard epsilon-prediction diffusion loss for one clip.

        Encoders and noising run under no_grad; only the UNet+motion-module
        forward (via ``lets_dance``) builds a graph, so gradients flow solely
        into the motion modules.
        """
        with torch.no_grad():
            # Call video encoder
            latents = self.encode_video_with_vae(frames)

            # Call text encoder; the prompt embedding is shared by all frames
            prompt_embs = self.pipe.prompter.encode_prompt(self.pipe.text_encoder, prompt, device=self.device, max_length=77)
            prompt_embs = prompt_embs.repeat(latents.shape[0], 1, 1)

            # Call scheduler: one random timestep for the whole clip
            timestep = torch.randint(0, len(self.pipe.scheduler.timesteps), (1,), device=self.device)[0]
            noise = torch.randn_like(latents)
            noisy_latents = self.pipe.scheduler.add_noise(latents, noise, timestep)

        # Calculate loss (gradient checkpointing off by default here)
        model_pred = lets_dance(
            self.pipe.unet, self.pipe.motion_modules,
            sample=noisy_latents, encoder_hidden_states=prompt_embs, timestep=timestep
        )
        loss = torch.nn.functional.mse_loss(model_pred.float(), noise.float(), reduction="mean")
        return loss

    def training_step(self, batch, batch_idx):
        """One optimization step on the first (and only) training shape."""
        # Loss — strip the DataLoader batch dim (batch_size=1)
        frames = batch["frames_0"][0]
        prompt = batch["text_0"][0]
        loss = self.calculate_loss(prompt, frames)

        # Record log
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def configure_optimizers(self):
        """AdamW over the motion-module parameters only."""
        optimizer = torch.optim.AdamW(self.pipe.motion_modules.parameters(), lr=self.learning_rate)
        return optimizer

    def on_save_checkpoint(self, checkpoint):
        """Record which motion-module parameter names were trainable so the
        checkpoint can later be filtered down to just those weights."""
        trainable_param_names = list(filter(lambda named_param: named_param[1].requires_grad, self.pipe.motion_modules.named_parameters()))
        trainable_param_names = [named_param[0] for named_param in trainable_param_names]
        checkpoint["trainable_param_names"] = trainable_param_names
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
    # dataset and data loader
    # NOTE(review): dataset/checkpoint paths are hard-coded to one machine;
    # parameterize via CLI args or env vars before reuse.
    dataset = TextVideoDataset(
        "/data/zhongjie/datasets/opensoraplan/data/processed",
        "/data/zhongjie/datasets/opensoraplan/data/processed/metadata.json",
        # (max_num_frames, interval, num_frames, height, width)
        training_shapes=[(16, 1, 16, 512, 512)],
        steps_per_epoch=7*10000,
    )
    train_loader = torch.utils.data.DataLoader(
        dataset,
        shuffle=True,
        batch_size=1,
        num_workers=4
    )

    # model
    model = LightningModel(
        learning_rate=1e-5,
        sd_ckpt_path="models/stable_diffusion/v1-5-pruned-emaonly.safetensors",
    )

    # train
    trainer = pl.Trainer(
        max_epochs=100000,
        accelerator="gpu",
        devices="auto",
        strategy="deepspeed_stage_1",  # DeepSpeed ZeRO stage 1 (optimizer-state sharding)
        precision="16-mixed",
        default_root_dir="/data/zhongjie/models/train_extended_animatediff",
        accumulate_grad_batches=1,
        # save_top_k=-1 keeps every epoch checkpoint instead of only the best
        callbacks=[pl.pytorch.callbacks.ModelCheckpoint(save_top_k=-1)]
    )
    trainer.fit(
        model=model,
        train_dataloaders=train_loader,
        ckpt_path=None
    )
|
||||||
@@ -194,10 +194,10 @@ class ModelManager:
|
|||||||
self.model[component].append(model)
|
self.model[component].append(model)
|
||||||
self.model_path[component].append(file_path)
|
self.model_path[component].append(file_path)
|
||||||
|
|
||||||
def load_animatediff(self, state_dict, file_path=""):
|
def load_animatediff(self, state_dict, file_path="", add_positional_conv=None):
|
||||||
component = "motion_modules"
|
component = "motion_modules"
|
||||||
model = SDMotionModel()
|
model = SDMotionModel(add_positional_conv=add_positional_conv)
|
||||||
model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
|
model.load_state_dict(model.state_dict_converter().from_civitai(state_dict, add_positional_conv=add_positional_conv))
|
||||||
model.to(self.torch_dtype).to(self.device)
|
model.to(self.torch_dtype).to(self.device)
|
||||||
self.model[component] = model
|
self.model[component] = model
|
||||||
self.model_path[component] = file_path
|
self.model_path[component] = file_path
|
||||||
|
|||||||
@@ -1,20 +1,28 @@
|
|||||||
from .sd_unet import SDUNet, Attention, GEGLU
|
from .sd_unet import SDUNet, Attention, GEGLU
|
||||||
|
from .svd_unet import get_timestep_embedding
|
||||||
import torch
|
import torch
|
||||||
from einops import rearrange, repeat
|
from einops import rearrange, repeat
|
||||||
|
|
||||||
|
|
||||||
class TemporalTransformerBlock(torch.nn.Module):
|
class TemporalTransformerBlock(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self, dim, num_attention_heads, attention_head_dim, max_position_embeddings=32):
|
def __init__(self, dim, num_attention_heads, attention_head_dim, max_position_embeddings=32, add_positional_conv=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
self.add_positional_conv = add_positional_conv
|
||||||
|
|
||||||
# 1. Self-Attn
|
# 1. Self-Attn
|
||||||
self.pe1 = torch.nn.Parameter(torch.zeros(1, max_position_embeddings, dim))
|
emb = get_timestep_embedding(torch.arange(max_position_embeddings), dim, True, 0).reshape(1, max_position_embeddings, dim)
|
||||||
|
self.pe1 = torch.nn.Parameter(emb)
|
||||||
|
if add_positional_conv:
|
||||||
|
self.positional_conv_1 = torch.nn.Conv1d(dim, dim, kernel_size=3, padding=1, padding_mode="reflect")
|
||||||
self.norm1 = torch.nn.LayerNorm(dim, elementwise_affine=True)
|
self.norm1 = torch.nn.LayerNorm(dim, elementwise_affine=True)
|
||||||
self.attn1 = Attention(q_dim=dim, num_heads=num_attention_heads, head_dim=attention_head_dim, bias_out=True)
|
self.attn1 = Attention(q_dim=dim, num_heads=num_attention_heads, head_dim=attention_head_dim, bias_out=True)
|
||||||
|
|
||||||
# 2. Cross-Attn
|
# 2. Cross-Attn
|
||||||
self.pe2 = torch.nn.Parameter(torch.zeros(1, max_position_embeddings, dim))
|
emb = get_timestep_embedding(torch.arange(max_position_embeddings), dim, True, 0).reshape(1, max_position_embeddings, dim)
|
||||||
|
self.pe2 = torch.nn.Parameter(emb)
|
||||||
|
if add_positional_conv:
|
||||||
|
self.positional_conv_2 = torch.nn.Conv1d(dim, dim, kernel_size=3, padding=1, padding_mode="reflect")
|
||||||
self.norm2 = torch.nn.LayerNorm(dim, elementwise_affine=True)
|
self.norm2 = torch.nn.LayerNorm(dim, elementwise_affine=True)
|
||||||
self.attn2 = Attention(q_dim=dim, num_heads=num_attention_heads, head_dim=attention_head_dim, bias_out=True)
|
self.attn2 = Attention(q_dim=dim, num_heads=num_attention_heads, head_dim=attention_head_dim, bias_out=True)
|
||||||
|
|
||||||
@@ -24,19 +32,47 @@ class TemporalTransformerBlock(torch.nn.Module):
|
|||||||
self.ff = torch.nn.Linear(dim * 4, dim)
|
self.ff = torch.nn.Linear(dim * 4, dim)
|
||||||
|
|
||||||
|
|
||||||
|
def frame_id_to_position_id(self, frame_id, max_id, repeat_length):
|
||||||
|
if frame_id < max_id:
|
||||||
|
position_id = frame_id
|
||||||
|
else:
|
||||||
|
position_id = (frame_id - max_id) % (repeat_length * 2)
|
||||||
|
if position_id < repeat_length:
|
||||||
|
position_id = max_id - 2 - position_id
|
||||||
|
else:
|
||||||
|
position_id = max_id - 2 * repeat_length + position_id
|
||||||
|
return position_id
|
||||||
|
|
||||||
|
|
||||||
|
def positional_ids(self, num_frames):
|
||||||
|
max_id = self.pe1.shape[1]
|
||||||
|
positional_ids = torch.IntTensor([self.frame_id_to_position_id(i, max_id, max_id - 1) for i in range(num_frames)])
|
||||||
|
return positional_ids
|
||||||
|
|
||||||
|
|
||||||
def forward(self, hidden_states, batch_size=1):
|
def forward(self, hidden_states, batch_size=1):
|
||||||
|
|
||||||
# 1. Self-Attention
|
# 1. Self-Attention
|
||||||
norm_hidden_states = self.norm1(hidden_states)
|
norm_hidden_states = self.norm1(hidden_states)
|
||||||
norm_hidden_states = rearrange(norm_hidden_states, "(b f) h c -> (b h) f c", b=batch_size)
|
norm_hidden_states = rearrange(norm_hidden_states, "(b f) h c -> (b h) f c", b=batch_size)
|
||||||
attn_output = self.attn1(norm_hidden_states + self.pe1[:, :norm_hidden_states.shape[1]])
|
norm_hidden_states = norm_hidden_states + self.pe1[:, self.positional_ids(norm_hidden_states.shape[1])]
|
||||||
|
if self.add_positional_conv:
|
||||||
|
norm_hidden_states = rearrange(norm_hidden_states, "(b h) f c -> (b h) c f", b=batch_size)
|
||||||
|
norm_hidden_states = self.positional_conv_1(norm_hidden_states)
|
||||||
|
norm_hidden_states = rearrange(norm_hidden_states, "(b h) c f -> (b h) f c", b=batch_size)
|
||||||
|
attn_output = self.attn1(norm_hidden_states)
|
||||||
attn_output = rearrange(attn_output, "(b h) f c -> (b f) h c", b=batch_size)
|
attn_output = rearrange(attn_output, "(b h) f c -> (b f) h c", b=batch_size)
|
||||||
hidden_states = attn_output + hidden_states
|
hidden_states = attn_output + hidden_states
|
||||||
|
|
||||||
# 2. Cross-Attention
|
# 2. Cross-Attention
|
||||||
norm_hidden_states = self.norm2(hidden_states)
|
norm_hidden_states = self.norm2(hidden_states)
|
||||||
norm_hidden_states = rearrange(norm_hidden_states, "(b f) h c -> (b h) f c", b=batch_size)
|
norm_hidden_states = rearrange(norm_hidden_states, "(b f) h c -> (b h) f c", b=batch_size)
|
||||||
attn_output = self.attn2(norm_hidden_states + self.pe2[:, :norm_hidden_states.shape[1]])
|
norm_hidden_states = norm_hidden_states + self.pe2[:, self.positional_ids(norm_hidden_states.shape[1])]
|
||||||
|
if self.add_positional_conv:
|
||||||
|
norm_hidden_states = rearrange(norm_hidden_states, "(b h) f c -> (b h) c f", b=batch_size)
|
||||||
|
norm_hidden_states = self.positional_conv_2(norm_hidden_states)
|
||||||
|
norm_hidden_states = rearrange(norm_hidden_states, "(b h) c f -> (b h) f c", b=batch_size)
|
||||||
|
attn_output = self.attn2(norm_hidden_states)
|
||||||
attn_output = rearrange(attn_output, "(b h) f c -> (b f) h c", b=batch_size)
|
attn_output = rearrange(attn_output, "(b h) f c -> (b f) h c", b=batch_size)
|
||||||
hidden_states = attn_output + hidden_states
|
hidden_states = attn_output + hidden_states
|
||||||
|
|
||||||
@@ -51,7 +87,7 @@ class TemporalTransformerBlock(torch.nn.Module):
|
|||||||
|
|
||||||
class TemporalBlock(torch.nn.Module):
|
class TemporalBlock(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self, num_attention_heads, attention_head_dim, in_channels, num_layers=1, norm_num_groups=32, eps=1e-5):
|
def __init__(self, num_attention_heads, attention_head_dim, in_channels, num_layers=1, norm_num_groups=32, eps=1e-5, add_positional_conv=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
inner_dim = num_attention_heads * attention_head_dim
|
inner_dim = num_attention_heads * attention_head_dim
|
||||||
|
|
||||||
@@ -62,7 +98,9 @@ class TemporalBlock(torch.nn.Module):
|
|||||||
TemporalTransformerBlock(
|
TemporalTransformerBlock(
|
||||||
inner_dim,
|
inner_dim,
|
||||||
num_attention_heads,
|
num_attention_heads,
|
||||||
attention_head_dim
|
attention_head_dim,
|
||||||
|
max_position_embeddings=32 if add_positional_conv is None else add_positional_conv,
|
||||||
|
add_positional_conv=add_positional_conv
|
||||||
)
|
)
|
||||||
for d in range(num_layers)
|
for d in range(num_layers)
|
||||||
])
|
])
|
||||||
@@ -92,30 +130,30 @@ class TemporalBlock(torch.nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class SDMotionModel(torch.nn.Module):
|
class SDMotionModel(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self, add_positional_conv=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.motion_modules = torch.nn.ModuleList([
|
self.motion_modules = torch.nn.ModuleList([
|
||||||
TemporalBlock(8, 40, 320, eps=1e-6),
|
TemporalBlock(8, 40, 320, eps=1e-6, add_positional_conv=add_positional_conv),
|
||||||
TemporalBlock(8, 40, 320, eps=1e-6),
|
TemporalBlock(8, 40, 320, eps=1e-6, add_positional_conv=add_positional_conv),
|
||||||
TemporalBlock(8, 80, 640, eps=1e-6),
|
TemporalBlock(8, 80, 640, eps=1e-6, add_positional_conv=add_positional_conv),
|
||||||
TemporalBlock(8, 80, 640, eps=1e-6),
|
TemporalBlock(8, 80, 640, eps=1e-6, add_positional_conv=add_positional_conv),
|
||||||
TemporalBlock(8, 160, 1280, eps=1e-6),
|
TemporalBlock(8, 160, 1280, eps=1e-6, add_positional_conv=add_positional_conv),
|
||||||
TemporalBlock(8, 160, 1280, eps=1e-6),
|
TemporalBlock(8, 160, 1280, eps=1e-6, add_positional_conv=add_positional_conv),
|
||||||
TemporalBlock(8, 160, 1280, eps=1e-6),
|
TemporalBlock(8, 160, 1280, eps=1e-6, add_positional_conv=add_positional_conv),
|
||||||
TemporalBlock(8, 160, 1280, eps=1e-6),
|
TemporalBlock(8, 160, 1280, eps=1e-6, add_positional_conv=add_positional_conv),
|
||||||
TemporalBlock(8, 160, 1280, eps=1e-6),
|
TemporalBlock(8, 160, 1280, eps=1e-6, add_positional_conv=add_positional_conv),
|
||||||
TemporalBlock(8, 160, 1280, eps=1e-6),
|
TemporalBlock(8, 160, 1280, eps=1e-6, add_positional_conv=add_positional_conv),
|
||||||
TemporalBlock(8, 160, 1280, eps=1e-6),
|
TemporalBlock(8, 160, 1280, eps=1e-6, add_positional_conv=add_positional_conv),
|
||||||
TemporalBlock(8, 160, 1280, eps=1e-6),
|
TemporalBlock(8, 160, 1280, eps=1e-6, add_positional_conv=add_positional_conv),
|
||||||
TemporalBlock(8, 160, 1280, eps=1e-6),
|
TemporalBlock(8, 160, 1280, eps=1e-6, add_positional_conv=add_positional_conv),
|
||||||
TemporalBlock(8, 160, 1280, eps=1e-6),
|
TemporalBlock(8, 160, 1280, eps=1e-6, add_positional_conv=add_positional_conv),
|
||||||
TemporalBlock(8, 160, 1280, eps=1e-6),
|
TemporalBlock(8, 160, 1280, eps=1e-6, add_positional_conv=add_positional_conv),
|
||||||
TemporalBlock(8, 80, 640, eps=1e-6),
|
TemporalBlock(8, 80, 640, eps=1e-6, add_positional_conv=add_positional_conv),
|
||||||
TemporalBlock(8, 80, 640, eps=1e-6),
|
TemporalBlock(8, 80, 640, eps=1e-6, add_positional_conv=add_positional_conv),
|
||||||
TemporalBlock(8, 80, 640, eps=1e-6),
|
TemporalBlock(8, 80, 640, eps=1e-6, add_positional_conv=add_positional_conv),
|
||||||
TemporalBlock(8, 40, 320, eps=1e-6),
|
TemporalBlock(8, 40, 320, eps=1e-6, add_positional_conv=add_positional_conv),
|
||||||
TemporalBlock(8, 40, 320, eps=1e-6),
|
TemporalBlock(8, 40, 320, eps=1e-6, add_positional_conv=add_positional_conv),
|
||||||
TemporalBlock(8, 40, 320, eps=1e-6),
|
TemporalBlock(8, 40, 320, eps=1e-6, add_positional_conv=add_positional_conv),
|
||||||
])
|
])
|
||||||
self.call_block_id = {
|
self.call_block_id = {
|
||||||
1: 0,
|
1: 0,
|
||||||
@@ -152,7 +190,42 @@ class SDMotionModelStateDictConverter:
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def from_diffusers(self, state_dict):
|
def frame_id_to_position_id(self, frame_id, max_id, repeat_length):
|
||||||
|
if frame_id < max_id:
|
||||||
|
position_id = frame_id
|
||||||
|
else:
|
||||||
|
position_id = (frame_id - max_id) % (repeat_length * 2)
|
||||||
|
if position_id < repeat_length:
|
||||||
|
position_id = max_id - 2 - position_id
|
||||||
|
else:
|
||||||
|
position_id = max_id - 2 * repeat_length + position_id
|
||||||
|
return position_id
|
||||||
|
|
||||||
|
def process_positional_conv_parameters(self, state_dict, add_positional_conv):
|
||||||
|
ids = [self.frame_id_to_position_id(i, 16, 15) for i in range(add_positional_conv)]
|
||||||
|
for i in range(21):
|
||||||
|
# Extend positional embedding
|
||||||
|
name = f"motion_modules.{i}.transformer_blocks.0.pe1"
|
||||||
|
state_dict[name] = state_dict[name][:, ids]
|
||||||
|
name = f"motion_modules.{i}.transformer_blocks.0.pe2"
|
||||||
|
state_dict[name] = state_dict[name][:, ids]
|
||||||
|
# add post convolution
|
||||||
|
dim = state_dict[f"motion_modules.{i}.transformer_blocks.0.pe1"].shape[-1]
|
||||||
|
name = f"motion_modules.{i}.transformer_blocks.0.positional_conv_1.bias"
|
||||||
|
state_dict[name] = torch.zeros((dim,))
|
||||||
|
name = f"motion_modules.{i}.transformer_blocks.0.positional_conv_2.bias"
|
||||||
|
state_dict[name] = torch.zeros((dim,))
|
||||||
|
name = f"motion_modules.{i}.transformer_blocks.0.positional_conv_1.weight"
|
||||||
|
param = torch.zeros((dim, dim, 3))
|
||||||
|
param[:, :, 1] = torch.eye(dim, dim)
|
||||||
|
state_dict[name] = param
|
||||||
|
name = f"motion_modules.{i}.transformer_blocks.0.positional_conv_2.weight"
|
||||||
|
param = torch.zeros((dim, dim, 3))
|
||||||
|
param[:, :, 1] = torch.eye(dim, dim)
|
||||||
|
state_dict[name] = param
|
||||||
|
return state_dict
|
||||||
|
|
||||||
|
def from_diffusers(self, state_dict, add_positional_conv=None):
|
||||||
rename_dict = {
|
rename_dict = {
|
||||||
"norm": "norm",
|
"norm": "norm",
|
||||||
"proj_in": "proj_in",
|
"proj_in": "proj_in",
|
||||||
@@ -192,7 +265,9 @@ class SDMotionModelStateDictConverter:
|
|||||||
else:
|
else:
|
||||||
rename = ".".join(["motion_modules", str(module_id), rename_dict[middle_name], suffix])
|
rename = ".".join(["motion_modules", str(module_id), rename_dict[middle_name], suffix])
|
||||||
state_dict_[rename] = state_dict[name]
|
state_dict_[rename] = state_dict[name]
|
||||||
|
if add_positional_conv is not None:
|
||||||
|
state_dict_ = self.process_positional_conv_parameters(state_dict_, add_positional_conv)
|
||||||
return state_dict_
|
return state_dict_
|
||||||
|
|
||||||
def from_civitai(self, state_dict):
|
def from_civitai(self, state_dict, add_positional_conv=None):
|
||||||
return self.from_diffusers(state_dict)
|
return self.from_diffusers(state_dict, add_positional_conv=add_positional_conv)
|
||||||
|
|||||||
115
diffsynth/models/sd_motion_ex.py
Normal file
115
diffsynth/models/sd_motion_ex.py
Normal file
@@ -0,0 +1,115 @@
|
|||||||
|
from .attention import Attention
|
||||||
|
from .svd_unet import get_timestep_embedding
|
||||||
|
import torch
|
||||||
|
from einops import rearrange, repeat
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class ExVideoMotionBlock(torch.nn.Module):
    """Temporal self-attention block for ExVideo motion modeling.

    Adds a sinusoidal frame-position embedding (with reflective extension for
    frame counts beyond the table size), optionally mixes neighboring frames
    with a 3D convolution, then applies ``num_layers`` rounds of
    LayerNorm + temporal self-attention, with one residual connection around
    the whole block.
    """

    def __init__(self, num_attention_heads, attention_head_dim, in_channels, max_position_embeddings=16, num_layers=1, add_positional_conv=None):
        super().__init__()
        # Sinusoidal table stored as (T, C, 1, 1) so an indexed slice
        # broadcasts over the spatial dimensions of the feature map.
        emb = get_timestep_embedding(torch.arange(max_position_embeddings), in_channels, True, 0).reshape(max_position_embeddings, in_channels, 1, 1)
        self.positional_embedding = torch.nn.Parameter(emb)
        self.positional_conv = torch.nn.Conv3d(in_channels, in_channels, kernel_size=3, padding=1) if add_positional_conv is not None else None
        self.norms = torch.nn.ModuleList([torch.nn.LayerNorm(in_channels) for _ in range(num_layers)])
        self.attns = torch.nn.ModuleList([Attention(q_dim=in_channels, num_heads=num_attention_heads, head_dim=attention_head_dim, bias_out=True) for _ in range(num_layers)])

    def frame_id_to_position_id(self, frame_id, max_id, repeat_length):
        """Map a frame index to a position id in [0, max_id), reflecting
        back and forth (triangle wave) once frame_id exceeds the table."""
        if frame_id < max_id:
            position_id = frame_id
        else:
            position_id = (frame_id - max_id) % (repeat_length * 2)
            if position_id < repeat_length:
                position_id = max_id - 2 - position_id
            else:
                position_id = max_id - 2 * repeat_length + position_id
        return position_id

    def positional_ids(self, num_frames):
        """Position ids for ``num_frames`` frames against the embedding table."""
        max_id = self.positional_embedding.shape[0]
        positional_ids = torch.IntTensor([self.frame_id_to_position_id(i, max_id, max_id - 1) for i in range(num_frames)])
        return positional_ids

    def forward(self, hidden_states, time_emb, text_emb, res_stack, batch_size=1, **kwargs):
        """Apply temporal attention to ``hidden_states`` of shape
        ((batch_size * T), C, H, W); ``time_emb``/``text_emb``/``res_stack``
        are passed through untouched to match the UNet block interface."""
        batch, inner_dim, height, width = hidden_states.shape
        residual = hidden_states

        # Per-frame positional information, broadcast over H and W.
        pos_ids = self.positional_ids(batch // batch_size)
        pos_emb = self.positional_embedding[pos_ids]  # (T, C, 1, 1)
        # BUGFIX: Tensor.repeat needs one factor per tensor dimension; the
        # original `pos_emb.repeat(batch_size)` raised a RuntimeError on this
        # 4-D tensor. Tiling along dim 0 matches the (B T)-flattened layout.
        pos_emb = pos_emb.repeat(batch_size, 1, 1, 1)  # (B*T, C, 1, 1)
        hidden_states = hidden_states + pos_emb
        if self.positional_conv is not None:
            hidden_states = rearrange(hidden_states, "(B T) C H W -> B C T H W", B=batch_size)
            hidden_states = self.positional_conv(hidden_states)
            hidden_states = rearrange(hidden_states, "B C T H W -> (B H W) T C")
        else:
            hidden_states = rearrange(hidden_states, "(B T) C H W -> (B H W) T C", B=batch_size)

        # Pre-norm temporal self-attention stack with per-layer residuals.
        for norm, attn in zip(self.norms, self.attns):
            norm_hidden_states = norm(hidden_states)
            attn_output = attn(norm_hidden_states)
            hidden_states = hidden_states + attn_output

        hidden_states = rearrange(hidden_states, "(B H W) T C -> (B T) C H W", B=batch_size, H=height, W=width)
        hidden_states = hidden_states + residual
        return hidden_states, time_emb, text_emb, res_stack
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class ExVideoMotionModel(torch.nn.Module):
    """Collection of 21 ExVideoMotionBlocks, one per UNet insertion point.

    ``call_block_id`` maps a UNet block index to the index of the motion
    module that should run immediately after it. ``forward`` and
    ``state_dict_converter`` are placeholders.
    """

    def __init__(self, num_layers=2):
        super().__init__()
        # Channel width at each insertion point, following the UNet's
        # down / mid / up path. Attention always uses 8 heads, so the head
        # dimension is width // 8 (320 -> 40, 640 -> 80, 1280 -> 160).
        channel_plan = [320, 320, 640, 640] + [1280] * 11 + [640] * 3 + [320] * 3
        self.motion_modules = torch.nn.ModuleList([
            ExVideoMotionBlock(8, width // 8, width, num_layers=num_layers)
            for width in channel_plan
        ])
        # UNet block index -> index into self.motion_modules.
        self.call_block_id = {
            1: 0,
            4: 1,
            9: 2,
            12: 3,
            17: 4,
            20: 5,
            24: 6,
            26: 7,
            29: 8,
            32: 9,
            34: 10,
            36: 11,
            40: 12,
            43: 13,
            46: 14,
            50: 15,
            53: 16,
            56: 17,
            60: 18,
            63: 19,
            66: 20
        }

    def forward(self):
        pass

    def state_dict_converter(self):
        pass
|
||||||
@@ -10,6 +10,7 @@ import torch, os, json
|
|||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from einops import rearrange
|
||||||
|
|
||||||
|
|
||||||
def lets_dance_with_long_video(
|
def lets_dance_with_long_video(
|
||||||
@@ -150,6 +151,14 @@ class SDVideoPipeline(torch.nn.Module):
|
|||||||
return latents
|
return latents
|
||||||
|
|
||||||
|
|
||||||
|
def post_process_latents(self, latents, post_normalize=True, contrast_enhance_scale=1.0):
|
||||||
|
if post_normalize:
|
||||||
|
mean, std = latents.mean(), latents.std()
|
||||||
|
latents = (latents - latents.mean(dim=[1, 2, 3], keepdim=True)) / latents.std(dim=[1, 2, 3], keepdim=True) * std + mean
|
||||||
|
latents = latents * contrast_enhance_scale
|
||||||
|
return latents
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
@@ -172,6 +181,8 @@ class SDVideoPipeline(torch.nn.Module):
|
|||||||
smoother=None,
|
smoother=None,
|
||||||
smoother_progress_ids=[],
|
smoother_progress_ids=[],
|
||||||
vram_limit_level=0,
|
vram_limit_level=0,
|
||||||
|
post_normalize=False,
|
||||||
|
contrast_enhance_scale=1.0,
|
||||||
progress_bar_cmd=tqdm,
|
progress_bar_cmd=tqdm,
|
||||||
progress_bar_st=None,
|
progress_bar_st=None,
|
||||||
):
|
):
|
||||||
@@ -226,15 +237,18 @@ class SDVideoPipeline(torch.nn.Module):
|
|||||||
cross_frame_attention=cross_frame_attention,
|
cross_frame_attention=cross_frame_attention,
|
||||||
device=self.device, vram_limit_level=vram_limit_level
|
device=self.device, vram_limit_level=vram_limit_level
|
||||||
)
|
)
|
||||||
noise_pred_nega = lets_dance_with_long_video(
|
if cfg_scale != 1.0:
|
||||||
self.unet, motion_modules=self.motion_modules, controlnet=self.controlnet,
|
noise_pred_nega = lets_dance_with_long_video(
|
||||||
sample=latents, timestep=timestep, encoder_hidden_states=prompt_emb_nega, controlnet_frames=controlnet_frames,
|
self.unet, motion_modules=self.motion_modules, controlnet=self.controlnet,
|
||||||
animatediff_batch_size=animatediff_batch_size, animatediff_stride=animatediff_stride,
|
sample=latents, timestep=timestep, encoder_hidden_states=prompt_emb_nega, controlnet_frames=controlnet_frames,
|
||||||
unet_batch_size=unet_batch_size, controlnet_batch_size=controlnet_batch_size,
|
animatediff_batch_size=animatediff_batch_size, animatediff_stride=animatediff_stride,
|
||||||
cross_frame_attention=cross_frame_attention,
|
unet_batch_size=unet_batch_size, controlnet_batch_size=controlnet_batch_size,
|
||||||
device=self.device, vram_limit_level=vram_limit_level
|
cross_frame_attention=cross_frame_attention,
|
||||||
)
|
device=self.device, vram_limit_level=vram_limit_level
|
||||||
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
|
)
|
||||||
|
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
|
||||||
|
else:
|
||||||
|
noise_pred = noise_pred_posi
|
||||||
|
|
||||||
# DDIM and smoother
|
# DDIM and smoother
|
||||||
if smoother is not None and progress_id in smoother_progress_ids:
|
if smoother is not None and progress_id in smoother_progress_ids:
|
||||||
@@ -250,6 +264,7 @@ class SDVideoPipeline(torch.nn.Module):
|
|||||||
progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
|
progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
|
||||||
|
|
||||||
# Decode image
|
# Decode image
|
||||||
|
latents = self.post_process_latents(latents, post_normalize=post_normalize, contrast_enhance_scale=contrast_enhance_scale)
|
||||||
output_frames = self.decode_images(latents)
|
output_frames = self.decode_images(latents)
|
||||||
|
|
||||||
# Post-process
|
# Post-process
|
||||||
|
|||||||
@@ -8,9 +8,9 @@ class SDPrompter(Prompter):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path)
|
self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path)
|
||||||
|
|
||||||
def encode_prompt(self, text_encoder: SDTextEncoder, prompt, clip_skip=1, device="cuda", positive=True):
|
def encode_prompt(self, text_encoder: SDTextEncoder, prompt, clip_skip=1, device="cuda", positive=True, max_length=99999999):
|
||||||
prompt = self.process_prompt(prompt, positive=positive)
|
prompt = self.process_prompt(prompt, positive=positive)
|
||||||
input_ids = tokenize_long_prompt(self.tokenizer, prompt).to(device)
|
input_ids = tokenize_long_prompt(self.tokenizer, prompt, max_length=max_length).to(device)
|
||||||
prompt_emb = text_encoder(input_ids, clip_skip=clip_skip)
|
prompt_emb = text_encoder(input_ids, clip_skip=clip_skip)
|
||||||
prompt_emb = prompt_emb.reshape((1, prompt_emb.shape[0]*prompt_emb.shape[1], -1))
|
prompt_emb = prompt_emb.reshape((1, prompt_emb.shape[0]*prompt_emb.shape[1], -1))
|
||||||
|
|
||||||
|
|||||||
@@ -3,12 +3,12 @@ from ..models import ModelManager
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
|
|
||||||
def tokenize_long_prompt(tokenizer, prompt):
|
def tokenize_long_prompt(tokenizer, prompt, max_length=99999999):
|
||||||
# Get model_max_length from self.tokenizer
|
# Get model_max_length from self.tokenizer
|
||||||
length = tokenizer.model_max_length
|
length = tokenizer.model_max_length
|
||||||
|
|
||||||
# To avoid the warning. set self.tokenizer.model_max_length to +oo.
|
# To avoid the warning. set self.tokenizer.model_max_length to +oo.
|
||||||
tokenizer.model_max_length = 99999999
|
tokenizer.model_max_length = max_length
|
||||||
|
|
||||||
# Tokenize it!
|
# Tokenize it!
|
||||||
input_ids = tokenizer(prompt, return_tensors="pt").input_ids
|
input_ids = tokenizer(prompt, return_tensors="pt").input_ids
|
||||||
|
|||||||
Reference in New Issue
Block a user