support CogVideoX-5B (#184)

* support cogvideo

* update examples
This commit is contained in:
Zhongjie Duan
2024-09-03 11:37:54 +08:00
committed by GitHub
parent fe485b3fa1
commit d154bee18a
22 changed files with 2653 additions and 107 deletions

View File

@@ -15,6 +15,7 @@ DiffSynth Studio is a Diffusion engine. We have restructured architectures inclu
Until now, DiffSynth Studio has supported the following models:
* [CogVideo](https://huggingface.co/THUDM/CogVideoX-5b)
* [FLUX](https://huggingface.co/black-forest-labs/FLUX.1-dev)
* [ExVideo](https://huggingface.co/ECNU-CILab/ExVideo-SVD-128f-v1)
* [Kolors](https://huggingface.co/Kwai-Kolors/Kolors)
@@ -31,10 +32,16 @@ Until now, DiffSynth Studio has supported the following models:
## News
- **August 22, 2024** We have implemented an interesting painter that supports all text-to-image models. Now you can create stunning images using the painter, with assistance from AI!
- **August 22, 2024.** CogVideoX-5B is supported in this project. See [here](/examples/video_synthesis/). We provide several interesting features for this text-to-video model, including
- Text to video
- Video editing
- Self-upscaling
- Video interpolation
- **August 22, 2024.** We have implemented an interesting painter that supports all text-to-image models. Now you can create stunning images using the painter, with assistance from AI!
- Use it in our [WebUI](#usage-in-webui).
- **August 21, 2024** FLUX is supported in DiffSynth-Studio.
- **August 21, 2024.** FLUX is supported in DiffSynth-Studio.
- Enable CFG and highres-fix to improve visual quality. See [here](/examples/image_synthesis/README.md)
- LoRA, ControlNet, and additional models will be available soon.
@@ -120,6 +127,14 @@ download_from_huggingface("Kwai-Kolors/Kolors", "vae/diffusion_pytorch_model.fp1
### Video Synthesis
#### Text-to-video using CogVideoX-5B
CogVideoX-5B is released by ZhiPu. We provide an improved pipeline, supporting text-to-video, video editing, self-upscaling and video interpolation. [`examples/video_synthesis`](./examples/video_synthesis/)
The video on the left is generated using the original text-to-video pipeline, while the video on the right is the result after editing and frame interpolation.
https://github.com/user-attachments/assets/26b044c1-4a60-44a4-842f-627ff289d006
#### Long Video Synthesis
We trained an extended video synthesis model, which can generate 128 frames. [`examples/ExVideo`](./examples/ExVideo/)

View File

@@ -32,11 +32,16 @@ from ..models.sdxl_ipadapter import SDXLIpAdapter, IpAdapterXLCLIPImageEmbedder
from ..models.hunyuan_dit_text_encoder import HunyuanDiTCLIPTextEncoder, HunyuanDiTT5TextEncoder
from ..models.hunyuan_dit import HunyuanDiT
from ..models.flux_dit import FluxDiT
from ..models.flux_text_encoder import FluxTextEncoder1, FluxTextEncoder2
from ..models.flux_vae import FluxVAEEncoder, FluxVAEDecoder
from ..models.cog_vae import CogVAEEncoder, CogVAEDecoder
from ..models.cog_dit import CogDiT
from ..extensions.RIFE import IFNet
from ..extensions.ESRGAN import RRDBNet
model_loader_configs = [
@@ -70,7 +75,10 @@ model_loader_configs = [
(None, "94eefa3dac9cec93cb1ebaf1747d7b78", ["flux_text_encoder_1"], [FluxTextEncoder1], "diffusers"),
(None, "1aafa3cc91716fb6b300cc1cd51b85a3", ["flux_vae_encoder", "flux_vae_decoder"], [FluxVAEEncoder, FluxVAEDecoder], "diffusers"),
(None, "21ea55f476dfc4fd135587abb59dfe5d", ["flux_vae_encoder", "flux_vae_decoder"], [FluxVAEEncoder, FluxVAEDecoder], "civitai"),
(None, "a29710fea6dddb0314663ee823598e50", ["flux_dit"], [FluxDiT], "civitai")
(None, "a29710fea6dddb0314663ee823598e50", ["flux_dit"], [FluxDiT], "civitai"),
(None, "280189ee084bca10f70907bf6ce1649d", ["cog_vae_encoder", "cog_vae_decoder"], [CogVAEEncoder, CogVAEDecoder], "diffusers"),
(None, "9b9313d104ac4df27991352fec013fd4", ["rife"], [IFNet], "civitai"),
(None, "6b7116078c4170bfbeaedc8fe71f6649", ["esrgan"], [RRDBNet], "civitai"),
]
huggingface_model_loader_configs = [
# These configs are provided for detecting model type automatically.
@@ -79,6 +87,7 @@ huggingface_model_loader_configs = [
("MarianMTModel", "transformers.models.marian.modeling_marian", "translator", None),
("BloomForCausalLM", "transformers.models.bloom.modeling_bloom", "beautiful_prompt", None),
("T5EncoderModel", "diffsynth.models.flux_text_encoder", "flux_text_encoder_2", "FluxTextEncoder2"),
("CogVideoXTransformer3DModel", "diffsynth.models.cog_dit", "cog_dit", "CogDiT"),
]
patch_model_loader_configs = [
# These configs are provided for detecting model type automatically.
@@ -264,7 +273,27 @@ preset_models_on_modelscope = {
("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model.safetensors.index.json", "models/FLUX/FLUX.1-dev/text_encoder_2"),
("AI-ModelScope/FLUX.1-dev", "ae.safetensors", "models/FLUX/FLUX.1-dev"),
("AI-ModelScope/FLUX.1-dev", "flux1-dev.safetensors", "models/FLUX/FLUX.1-dev"),
]
],
# ESRGAN
"ESRGAN_x4": [
("AI-ModelScope/Real-ESRGAN", "RealESRGAN_x4.pth", "models/ESRGAN"),
],
# RIFE
"RIFE": [
("AI-ModelScope/RIFE", "flownet.pkl", "models/RIFE"),
],
# CogVideo
"CogVideoX-5B": [
("ZhipuAI/CogVideoX-5b", "text_encoder/config.json", "models/CogVideo/CogVideoX-5b/text_encoder"),
("ZhipuAI/CogVideoX-5b", "text_encoder/model.safetensors.index.json", "models/CogVideo/CogVideoX-5b/text_encoder"),
("ZhipuAI/CogVideoX-5b", "text_encoder/model-00001-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/text_encoder"),
("ZhipuAI/CogVideoX-5b", "text_encoder/model-00002-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/text_encoder"),
("ZhipuAI/CogVideoX-5b", "transformer/config.json", "models/CogVideo/CogVideoX-5b/transformer"),
("ZhipuAI/CogVideoX-5b", "transformer/diffusion_pytorch_model.safetensors.index.json", "models/CogVideo/CogVideoX-5b/transformer"),
("ZhipuAI/CogVideoX-5b", "transformer/diffusion_pytorch_model-00001-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/transformer"),
("ZhipuAI/CogVideoX-5b", "transformer/diffusion_pytorch_model-00002-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/transformer"),
("ZhipuAI/CogVideoX-5b", "vae/diffusion_pytorch_model.safetensors", "models/CogVideo/CogVideoX-5b/vae"),
],
}
Preset_model_id: TypeAlias = Literal[
"HunyuanDiT",
@@ -296,4 +325,7 @@ Preset_model_id: TypeAlias = Literal[
"ControlNet_union_sdxl_promax",
"FLUX.1-dev",
"SDXL_lora_zyd232_ChineseInkStyle_SDXL_v1_0",
"ESRGAN_x4",
"RIFE",
"CogVideoX-5B",
]

View File

@@ -41,7 +41,7 @@ class RRDB(torch.nn.Module):
class RRDBNet(torch.nn.Module):
def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32):
def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, **kwargs):
super(RRDBNet, self).__init__()
self.conv_first = torch.nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
self.body = torch.torch.nn.Sequential(*[RRDB(num_feat=num_feat, num_grow_ch=num_grow_ch) for _ in range(num_block)])
@@ -65,6 +65,21 @@ class RRDBNet(torch.nn.Module):
feat = self.lrelu(self.conv_up2(feat))
out = self.conv_last(self.lrelu(self.conv_hr(feat)))
return out
@staticmethod
def state_dict_converter():
return RRDBNetStateDictConverter()
class RRDBNetStateDictConverter:
def __init__(self):
pass
def from_diffusers(self, state_dict):
return state_dict, {"upcast_to_float32": True}
def from_civitai(self, state_dict):
return state_dict, {"upcast_to_float32": True}
class ESRGAN(torch.nn.Module):
@@ -73,12 +88,8 @@ class ESRGAN(torch.nn.Module):
self.model = model
@staticmethod
def from_pretrained(model_path):
model = RRDBNet()
state_dict = torch.load(model_path, map_location="cpu")["params_ema"]
model.load_state_dict(state_dict)
model.eval()
return ESRGAN(model)
def from_model_manager(model_manager):
return ESRGAN(model_manager.fetch_model("esrgan"))
def process_image(self, image):
image = torch.Tensor(np.array(image, dtype=np.float32) / 255).permute(2, 0, 1)

View File

@@ -58,7 +58,7 @@ class IFBlock(nn.Module):
class IFNet(nn.Module):
def __init__(self):
def __init__(self, **kwargs):
super(IFNet, self).__init__()
self.block0 = IFBlock(7+4, c=90)
self.block1 = IFBlock(7+4, c=90)
@@ -113,7 +113,7 @@ class IFNetStateDictConverter:
return state_dict_
def from_civitai(self, state_dict):
return self.from_diffusers(state_dict)
return self.from_diffusers(state_dict), {"upcast_to_float32": True}
class RIFEInterpolater:
@@ -125,7 +125,7 @@ class RIFEInterpolater:
@staticmethod
def from_model_manager(model_manager):
return RIFEInterpolater(model_manager.RIFE, device=model_manager.device)
return RIFEInterpolater(model_manager.fetch_model("rife"), device=model_manager.device)
def process_image(self, image):
width, height = image.size
@@ -203,7 +203,7 @@ class RIFESmoother(RIFEInterpolater):
@staticmethod
def from_model_manager(model_manager):
return RIFESmoother(model_manager.RIFE, device=model_manager.device)
return RIFEInterpolater(model_manager.fetch_model("rife"), device=model_manager.device)
def process_tensors(self, input_tensor, scale=1.0, batch_size=4):
output_tensor = []

395
diffsynth/models/cog_dit.py Normal file
View File

@@ -0,0 +1,395 @@
import torch
from einops import rearrange, repeat
from .sd3_dit import TimestepEmbeddings
from .attention import Attention
from .utils import load_state_dict_from_folder
from .tiler import TileWorker2Dto3D
import numpy as np
class CogPatchify(torch.nn.Module):
def __init__(self, dim_in, dim_out, patch_size) -> None:
super().__init__()
self.proj = torch.nn.Conv3d(dim_in, dim_out, kernel_size=(1, patch_size, patch_size), stride=(1, patch_size, patch_size))
def forward(self, hidden_states):
hidden_states = self.proj(hidden_states)
hidden_states = rearrange(hidden_states, "B C T H W -> B (T H W) C")
return hidden_states
class CogAdaLayerNorm(torch.nn.Module):
def __init__(self, dim, dim_cond, single=False):
super().__init__()
self.single = single
self.linear = torch.nn.Linear(dim_cond, dim * (2 if single else 6))
self.norm = torch.nn.LayerNorm(dim, elementwise_affine=True, eps=1e-5)
def forward(self, hidden_states, prompt_emb, emb):
emb = self.linear(torch.nn.functional.silu(emb))
if self.single:
shift, scale = emb.unsqueeze(1).chunk(2, dim=2)
hidden_states = self.norm(hidden_states) * (1 + scale) + shift
return hidden_states
else:
shift_a, scale_a, gate_a, shift_b, scale_b, gate_b = emb.unsqueeze(1).chunk(6, dim=2)
hidden_states = self.norm(hidden_states) * (1 + scale_a) + shift_a
prompt_emb = self.norm(prompt_emb) * (1 + scale_b) + shift_b
return hidden_states, prompt_emb, gate_a, gate_b
class CogDiTBlock(torch.nn.Module):
def __init__(self, dim, dim_cond, num_heads):
super().__init__()
self.norm1 = CogAdaLayerNorm(dim, dim_cond)
self.attn1 = Attention(q_dim=dim, num_heads=48, head_dim=dim//num_heads, bias_q=True, bias_kv=True, bias_out=True)
self.norm_q = torch.nn.LayerNorm((dim//num_heads,), eps=1e-06, elementwise_affine=True)
self.norm_k = torch.nn.LayerNorm((dim//num_heads,), eps=1e-06, elementwise_affine=True)
self.norm2 = CogAdaLayerNorm(dim, dim_cond)
self.ff = torch.nn.Sequential(
torch.nn.Linear(dim, dim*4),
torch.nn.GELU(approximate="tanh"),
torch.nn.Linear(dim*4, dim)
)
def apply_rotary_emb(self, x, freqs_cis):
cos, sin = freqs_cis # [S, D]
cos = cos[None, None]
sin = sin[None, None]
cos, sin = cos.to(x.device), sin.to(x.device)
x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
return out
def process_qkv(self, q, k, v, image_rotary_emb, text_seq_length):
q = self.norm_q(q)
k = self.norm_k(k)
q[:, :, text_seq_length:] = self.apply_rotary_emb(q[:, :, text_seq_length:], image_rotary_emb)
k[:, :, text_seq_length:] = self.apply_rotary_emb(k[:, :, text_seq_length:], image_rotary_emb)
return q, k, v
def forward(self, hidden_states, prompt_emb, time_emb, image_rotary_emb):
# Attention
norm_hidden_states, norm_encoder_hidden_states, gate_a, gate_b = self.norm1(
hidden_states, prompt_emb, time_emb
)
attention_io = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1)
attention_io = self.attn1(
attention_io,
qkv_preprocessor=lambda q, k, v: self.process_qkv(q, k, v, image_rotary_emb, prompt_emb.shape[1])
)
hidden_states = hidden_states + gate_a * attention_io[:, prompt_emb.shape[1]:]
prompt_emb = prompt_emb + gate_b * attention_io[:, :prompt_emb.shape[1]]
# Feed forward
norm_hidden_states, norm_encoder_hidden_states, gate_a, gate_b = self.norm2(
hidden_states, prompt_emb, time_emb
)
ff_io = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1)
ff_io = self.ff(ff_io)
hidden_states = hidden_states + gate_a * ff_io[:, prompt_emb.shape[1]:]
prompt_emb = prompt_emb + gate_b * ff_io[:, :prompt_emb.shape[1]]
return hidden_states, prompt_emb
class CogDiT(torch.nn.Module):
def __init__(self):
super().__init__()
self.patchify = CogPatchify(16, 3072, 2)
self.time_embedder = TimestepEmbeddings(3072, 512)
self.context_embedder = torch.nn.Linear(4096, 3072)
self.blocks = torch.nn.ModuleList([CogDiTBlock(3072, 512, 48) for _ in range(42)])
self.norm_final = torch.nn.LayerNorm((3072,), eps=1e-05, elementwise_affine=True)
self.norm_out = CogAdaLayerNorm(3072, 512, single=True)
self.proj_out = torch.nn.Linear(3072, 64, bias=True)
def get_resize_crop_region_for_grid(self, src, tgt_width, tgt_height):
tw = tgt_width
th = tgt_height
h, w = src
r = h / w
if r > (th / tw):
resize_height = th
resize_width = int(round(th / h * w))
else:
resize_width = tw
resize_height = int(round(tw / w * h))
crop_top = int(round((th - resize_height) / 2.0))
crop_left = int(round((tw - resize_width) / 2.0))
return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
def get_3d_rotary_pos_embed(
self, embed_dim, crops_coords, grid_size, temporal_size, theta: int = 10000, use_real: bool = True
):
start, stop = crops_coords
grid_h = np.linspace(start[0], stop[0], grid_size[0], endpoint=False, dtype=np.float32)
grid_w = np.linspace(start[1], stop[1], grid_size[1], endpoint=False, dtype=np.float32)
grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32)
# Compute dimensions for each axis
dim_t = embed_dim // 4
dim_h = embed_dim // 8 * 3
dim_w = embed_dim // 8 * 3
# Temporal frequencies
freqs_t = 1.0 / (theta ** (torch.arange(0, dim_t, 2).float() / dim_t))
grid_t = torch.from_numpy(grid_t).float()
freqs_t = torch.einsum("n , f -> n f", grid_t, freqs_t)
freqs_t = freqs_t.repeat_interleave(2, dim=-1)
# Spatial frequencies for height and width
freqs_h = 1.0 / (theta ** (torch.arange(0, dim_h, 2).float() / dim_h))
freqs_w = 1.0 / (theta ** (torch.arange(0, dim_w, 2).float() / dim_w))
grid_h = torch.from_numpy(grid_h).float()
grid_w = torch.from_numpy(grid_w).float()
freqs_h = torch.einsum("n , f -> n f", grid_h, freqs_h)
freqs_w = torch.einsum("n , f -> n f", grid_w, freqs_w)
freqs_h = freqs_h.repeat_interleave(2, dim=-1)
freqs_w = freqs_w.repeat_interleave(2, dim=-1)
# Broadcast and concatenate tensors along specified dimension
def broadcast(tensors, dim=-1):
num_tensors = len(tensors)
shape_lens = {len(t.shape) for t in tensors}
assert len(shape_lens) == 1, "tensors must all have the same number of dimensions"
shape_len = list(shape_lens)[0]
dim = (dim + shape_len) if dim < 0 else dim
dims = list(zip(*(list(t.shape) for t in tensors)))
expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
assert all(
[*(len(set(t[1])) <= 2 for t in expandable_dims)]
), "invalid dimensions for broadcastable concatenation"
max_dims = [(t[0], max(t[1])) for t in expandable_dims]
expanded_dims = [(t[0], (t[1],) * num_tensors) for t in max_dims]
expanded_dims.insert(dim, (dim, dims[dim]))
expandable_shapes = list(zip(*(t[1] for t in expanded_dims)))
tensors = [t[0].expand(*t[1]) for t in zip(tensors, expandable_shapes)]
return torch.cat(tensors, dim=dim)
freqs = broadcast((freqs_t[:, None, None, :], freqs_h[None, :, None, :], freqs_w[None, None, :, :]), dim=-1)
t, h, w, d = freqs.shape
freqs = freqs.view(t * h * w, d)
# Generate sine and cosine components
sin = freqs.sin()
cos = freqs.cos()
if use_real:
return cos, sin
else:
freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
return freqs_cis
def prepare_rotary_positional_embeddings(
self,
height: int,
width: int,
num_frames: int,
device: torch.device,
):
grid_height = height // 2
grid_width = width // 2
base_size_width = 720 // (8 * 2)
base_size_height = 480 // (8 * 2)
grid_crops_coords = self.get_resize_crop_region_for_grid(
(grid_height, grid_width), base_size_width, base_size_height
)
freqs_cos, freqs_sin = self.get_3d_rotary_pos_embed(
embed_dim=64,
crops_coords=grid_crops_coords,
grid_size=(grid_height, grid_width),
temporal_size=num_frames,
use_real=True,
)
freqs_cos = freqs_cos.to(device=device)
freqs_sin = freqs_sin.to(device=device)
return freqs_cos, freqs_sin
def unpatchify(self, hidden_states, height, width):
hidden_states = rearrange(hidden_states, "B (T H W) (C P Q) -> B C T (H P) (W Q)", P=2, Q=2, H=height//2, W=width//2)
return hidden_states
def build_mask(self, T, H, W, dtype, device, is_bound):
t = repeat(torch.arange(T), "T -> T H W", T=T, H=H, W=W)
h = repeat(torch.arange(H), "H -> T H W", T=T, H=H, W=W)
w = repeat(torch.arange(W), "W -> T H W", T=T, H=H, W=W)
border_width = (H + W) // 4
pad = torch.ones_like(h) * border_width
mask = torch.stack([
pad if is_bound[0] else t + 1,
pad if is_bound[1] else T - t,
pad if is_bound[2] else h + 1,
pad if is_bound[3] else H - h,
pad if is_bound[4] else w + 1,
pad if is_bound[5] else W - w
]).min(dim=0).values
mask = mask.clip(1, border_width)
mask = (mask / border_width).to(dtype=dtype, device=device)
mask = rearrange(mask, "T H W -> 1 1 T H W")
return mask
def tiled_forward(self, hidden_states, timestep, prompt_emb, tile_size=(60, 90), tile_stride=(30, 45)):
B, C, T, H, W = hidden_states.shape
value = torch.zeros((B, C, T, H, W), dtype=hidden_states.dtype, device=hidden_states.device)
weight = torch.zeros((B, C, T, H, W), dtype=hidden_states.dtype, device=hidden_states.device)
# Split tasks
tasks = []
for h in range(0, H, tile_stride):
for w in range(0, W, tile_stride):
if (h-tile_stride >= 0 and h-tile_stride+tile_size >= H) or (w-tile_stride >= 0 and w-tile_stride+tile_size >= W):
continue
h_, w_ = h + tile_size, w + tile_size
if h_ > H: h, h_ = max(H - tile_size, 0), H
if w_ > W: w, w_ = max(W - tile_size, 0), W
tasks.append((h, h_, w, w_))
# Run
for hl, hr, wl, wr in tasks:
mask = self.build_mask(
value.shape[2], (hr-hl), (wr-wl),
hidden_states.dtype, hidden_states.device,
is_bound=(True, True, hl==0, hr>=H, wl==0, wr>=W)
)
model_output = self.forward(hidden_states[:, :, :, hl:hr, wl:wr], timestep, prompt_emb)
value[:, :, :, hl:hr, wl:wr] += model_output * mask
weight[:, :, :, hl:hr, wl:wr] += mask
value = value / weight
return value
def forward(self, hidden_states, timestep, prompt_emb, image_rotary_emb=None, tiled=False, tile_size=90, tile_stride=30):
if tiled:
return TileWorker2Dto3D().tiled_forward(
forward_fn=lambda x: self.forward(x, timestep, prompt_emb),
model_input=hidden_states,
tile_size=tile_size, tile_stride=tile_stride,
tile_device=hidden_states.device, tile_dtype=hidden_states.dtype,
computation_device=self.context_embedder.weight.device, computation_dtype=self.context_embedder.weight.dtype
)
num_frames, height, width = hidden_states.shape[-3:]
if image_rotary_emb is None:
image_rotary_emb = self.prepare_rotary_positional_embeddings(height, width, num_frames, device=self.context_embedder.weight.device)
hidden_states = self.patchify(hidden_states)
time_emb = self.time_embedder(timestep, dtype=hidden_states.dtype)
prompt_emb = self.context_embedder(prompt_emb)
for block in self.blocks:
hidden_states, prompt_emb = block(hidden_states, prompt_emb, time_emb, image_rotary_emb)
hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)
hidden_states = self.norm_final(hidden_states)
hidden_states = hidden_states[:, prompt_emb.shape[1]:]
hidden_states = self.norm_out(hidden_states, prompt_emb, time_emb)
hidden_states = self.proj_out(hidden_states)
hidden_states = self.unpatchify(hidden_states, height, width)
return hidden_states
@staticmethod
def state_dict_converter():
return CogDiTStateDictConverter()
@staticmethod
def from_pretrained(file_path, torch_dtype=torch.bfloat16):
model = CogDiT().to(torch_dtype)
state_dict = load_state_dict_from_folder(file_path, torch_dtype=torch_dtype)
state_dict = CogDiT.state_dict_converter().from_diffusers(state_dict)
model.load_state_dict(state_dict)
return model
class CogDiTStateDictConverter:
def __init__(self):
pass
def from_diffusers(self, state_dict):
rename_dict = {
"patch_embed.proj.weight": "patchify.proj.weight",
"patch_embed.proj.bias": "patchify.proj.bias",
"patch_embed.text_proj.weight": "context_embedder.weight",
"patch_embed.text_proj.bias": "context_embedder.bias",
"time_embedding.linear_1.weight": "time_embedder.timestep_embedder.0.weight",
"time_embedding.linear_1.bias": "time_embedder.timestep_embedder.0.bias",
"time_embedding.linear_2.weight": "time_embedder.timestep_embedder.2.weight",
"time_embedding.linear_2.bias": "time_embedder.timestep_embedder.2.bias",
"norm_final.weight": "norm_final.weight",
"norm_final.bias": "norm_final.bias",
"norm_out.linear.weight": "norm_out.linear.weight",
"norm_out.linear.bias": "norm_out.linear.bias",
"norm_out.norm.weight": "norm_out.norm.weight",
"norm_out.norm.bias": "norm_out.norm.bias",
"proj_out.weight": "proj_out.weight",
"proj_out.bias": "proj_out.bias",
}
suffix_dict = {
"norm1.linear.weight": "norm1.linear.weight",
"norm1.linear.bias": "norm1.linear.bias",
"norm1.norm.weight": "norm1.norm.weight",
"norm1.norm.bias": "norm1.norm.bias",
"attn1.norm_q.weight": "norm_q.weight",
"attn1.norm_q.bias": "norm_q.bias",
"attn1.norm_k.weight": "norm_k.weight",
"attn1.norm_k.bias": "norm_k.bias",
"attn1.to_q.weight": "attn1.to_q.weight",
"attn1.to_q.bias": "attn1.to_q.bias",
"attn1.to_k.weight": "attn1.to_k.weight",
"attn1.to_k.bias": "attn1.to_k.bias",
"attn1.to_v.weight": "attn1.to_v.weight",
"attn1.to_v.bias": "attn1.to_v.bias",
"attn1.to_out.0.weight": "attn1.to_out.weight",
"attn1.to_out.0.bias": "attn1.to_out.bias",
"norm2.linear.weight": "norm2.linear.weight",
"norm2.linear.bias": "norm2.linear.bias",
"norm2.norm.weight": "norm2.norm.weight",
"norm2.norm.bias": "norm2.norm.bias",
"ff.net.0.proj.weight": "ff.0.weight",
"ff.net.0.proj.bias": "ff.0.bias",
"ff.net.2.weight": "ff.2.weight",
"ff.net.2.bias": "ff.2.bias",
}
state_dict_ = {}
for name, param in state_dict.items():
if name in rename_dict:
if name == "patch_embed.proj.weight":
param = param.unsqueeze(2)
state_dict_[rename_dict[name]] = param
else:
names = name.split(".")
if names[0] == "transformer_blocks":
suffix = ".".join(names[2:])
state_dict_[f"blocks.{names[1]}." + suffix_dict[suffix]] = param
return state_dict_
def from_civitai(self, state_dict):
return self.from_diffusers(state_dict)

518
diffsynth/models/cog_vae.py Normal file
View File

@@ -0,0 +1,518 @@
import torch
from einops import rearrange, repeat
from .tiler import TileWorker2Dto3D
class Downsample3D(torch.nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int = 3,
stride: int = 2,
padding: int = 0,
compress_time: bool = False,
):
super().__init__()
self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
self.compress_time = compress_time
def forward(self, x: torch.Tensor, xq: torch.Tensor) -> torch.Tensor:
if self.compress_time:
batch_size, channels, frames, height, width = x.shape
# (batch_size, channels, frames, height, width) -> (batch_size, height, width, channels, frames) -> (batch_size * height * width, channels, frames)
x = x.permute(0, 3, 4, 1, 2).reshape(batch_size * height * width, channels, frames)
if x.shape[-1] % 2 == 1:
x_first, x_rest = x[..., 0], x[..., 1:]
if x_rest.shape[-1] > 0:
# (batch_size * height * width, channels, frames - 1) -> (batch_size * height * width, channels, (frames - 1) // 2)
x_rest = torch.nn.functional.avg_pool1d(x_rest, kernel_size=2, stride=2)
x = torch.cat([x_first[..., None], x_rest], dim=-1)
# (batch_size * height * width, channels, (frames // 2) + 1) -> (batch_size, height, width, channels, (frames // 2) + 1) -> (batch_size, channels, (frames // 2) + 1, height, width)
x = x.reshape(batch_size, height, width, channels, x.shape[-1]).permute(0, 3, 4, 1, 2)
else:
# (batch_size * height * width, channels, frames) -> (batch_size * height * width, channels, frames // 2)
x = torch.nn.functional.avg_pool1d(x, kernel_size=2, stride=2)
# (batch_size * height * width, channels, frames // 2) -> (batch_size, height, width, channels, frames // 2) -> (batch_size, channels, frames // 2, height, width)
x = x.reshape(batch_size, height, width, channels, x.shape[-1]).permute(0, 3, 4, 1, 2)
# Pad the tensor
pad = (0, 1, 0, 1)
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
batch_size, channels, frames, height, width = x.shape
# (batch_size, channels, frames, height, width) -> (batch_size, frames, channels, height, width) -> (batch_size * frames, channels, height, width)
x = x.permute(0, 2, 1, 3, 4).reshape(batch_size * frames, channels, height, width)
x = self.conv(x)
# (batch_size * frames, channels, height, width) -> (batch_size, frames, channels, height, width) -> (batch_size, channels, frames, height, width)
x = x.reshape(batch_size, frames, x.shape[1], x.shape[2], x.shape[3]).permute(0, 2, 1, 3, 4)
return x
class Upsample3D(torch.nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int = 3,
stride: int = 1,
padding: int = 1,
compress_time: bool = False,
) -> None:
super().__init__()
self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
self.compress_time = compress_time
def forward(self, inputs: torch.Tensor, xq: torch.Tensor) -> torch.Tensor:
if self.compress_time:
if inputs.shape[2] > 1 and inputs.shape[2] % 2 == 1:
# split first frame
x_first, x_rest = inputs[:, :, 0], inputs[:, :, 1:]
x_first = torch.nn.functional.interpolate(x_first, scale_factor=2.0)
x_rest = torch.nn.functional.interpolate(x_rest, scale_factor=2.0)
x_first = x_first[:, :, None, :, :]
inputs = torch.cat([x_first, x_rest], dim=2)
elif inputs.shape[2] > 1:
inputs = torch.nn.functional.interpolate(inputs, scale_factor=2.0)
else:
inputs = inputs.squeeze(2)
inputs = torch.nn.functional.interpolate(inputs, scale_factor=2.0)
inputs = inputs[:, :, None, :, :]
else:
# only interpolate 2D
b, c, t, h, w = inputs.shape
inputs = inputs.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
inputs = torch.nn.functional.interpolate(inputs, scale_factor=2.0)
inputs = inputs.reshape(b, t, c, *inputs.shape[2:]).permute(0, 2, 1, 3, 4)
b, c, t, h, w = inputs.shape
inputs = inputs.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
inputs = self.conv(inputs)
inputs = inputs.reshape(b, t, *inputs.shape[1:]).permute(0, 2, 1, 3, 4)
return inputs
class CogVideoXSpatialNorm3D(torch.nn.Module):
def __init__(self, f_channels, zq_channels, groups):
super().__init__()
self.norm_layer = torch.nn.GroupNorm(num_channels=f_channels, num_groups=groups, eps=1e-6, affine=True)
self.conv_y = torch.nn.Conv3d(zq_channels, f_channels, kernel_size=1, stride=1)
self.conv_b = torch.nn.Conv3d(zq_channels, f_channels, kernel_size=1, stride=1)
def forward(self, f: torch.Tensor, zq: torch.Tensor) -> torch.Tensor:
if f.shape[2] > 1 and f.shape[2] % 2 == 1:
f_first, f_rest = f[:, :, :1], f[:, :, 1:]
f_first_size, f_rest_size = f_first.shape[-3:], f_rest.shape[-3:]
z_first, z_rest = zq[:, :, :1], zq[:, :, 1:]
z_first = torch.nn.functional.interpolate(z_first, size=f_first_size)
z_rest = torch.nn.functional.interpolate(z_rest, size=f_rest_size)
zq = torch.cat([z_first, z_rest], dim=2)
else:
zq = torch.nn.functional.interpolate(zq, size=f.shape[-3:])
norm_f = self.norm_layer(f)
new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
return new_f
class Resnet3DBlock(torch.nn.Module):
def __init__(self, in_channels, out_channels, spatial_norm_dim, groups, eps=1e-6, use_conv_shortcut=False):
super().__init__()
self.nonlinearity = torch.nn.SiLU()
if spatial_norm_dim is None:
self.norm1 = torch.nn.GroupNorm(num_channels=in_channels, num_groups=groups, eps=eps)
self.norm2 = torch.nn.GroupNorm(num_channels=out_channels, num_groups=groups, eps=eps)
else:
self.norm1 = CogVideoXSpatialNorm3D(in_channels, spatial_norm_dim, groups)
self.norm2 = CogVideoXSpatialNorm3D(out_channels, spatial_norm_dim, groups)
self.conv1 = CachedConv3d(in_channels, out_channels, kernel_size=3, padding=(0, 1, 1))
self.conv2 = CachedConv3d(out_channels, out_channels, kernel_size=3, padding=(0, 1, 1))
if in_channels != out_channels:
if use_conv_shortcut:
self.conv_shortcut = CachedConv3d(in_channels, out_channels, kernel_size=3, padding=(0, 1, 1))
else:
self.conv_shortcut = torch.nn.Conv3d(in_channels, out_channels, kernel_size=1)
else:
self.conv_shortcut = lambda x: x
def forward(self, hidden_states, zq):
residual = hidden_states
hidden_states = self.norm1(hidden_states, zq) if isinstance(self.norm1, CogVideoXSpatialNorm3D) else self.norm1(hidden_states)
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.conv1(hidden_states)
hidden_states = self.norm2(hidden_states, zq) if isinstance(self.norm2, CogVideoXSpatialNorm3D) else self.norm2(hidden_states)
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.conv2(hidden_states)
hidden_states = hidden_states + self.conv_shortcut(residual)
return hidden_states
class CachedConv3d(torch.nn.Conv3d):
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
super().__init__(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
self.cached_tensor = None
def clear_cache(self):
self.cached_tensor = None
def forward(self, input: torch.Tensor, use_cache = True) -> torch.Tensor:
if use_cache:
if self.cached_tensor is None:
self.cached_tensor = torch.concat([input[:, :, :1]] * 2, dim=2)
input = torch.concat([self.cached_tensor, input], dim=2)
self.cached_tensor = input[:, :, -2:]
return super().forward(input)
class CogVAEDecoder(torch.nn.Module):
def __init__(self):
super().__init__()
self.scaling_factor = 0.7
self.conv_in = CachedConv3d(16, 512, kernel_size=3, stride=1, padding=(0, 1, 1))
self.blocks = torch.nn.ModuleList([
Resnet3DBlock(512, 512, 16, 32),
Resnet3DBlock(512, 512, 16, 32),
Resnet3DBlock(512, 512, 16, 32),
Resnet3DBlock(512, 512, 16, 32),
Resnet3DBlock(512, 512, 16, 32),
Resnet3DBlock(512, 512, 16, 32),
Upsample3D(512, 512, compress_time=True),
Resnet3DBlock(512, 256, 16, 32),
Resnet3DBlock(256, 256, 16, 32),
Resnet3DBlock(256, 256, 16, 32),
Resnet3DBlock(256, 256, 16, 32),
Upsample3D(256, 256, compress_time=True),
Resnet3DBlock(256, 256, 16, 32),
Resnet3DBlock(256, 256, 16, 32),
Resnet3DBlock(256, 256, 16, 32),
Resnet3DBlock(256, 256, 16, 32),
Upsample3D(256, 256, compress_time=False),
Resnet3DBlock(256, 128, 16, 32),
Resnet3DBlock(128, 128, 16, 32),
Resnet3DBlock(128, 128, 16, 32),
Resnet3DBlock(128, 128, 16, 32),
])
self.norm_out = CogVideoXSpatialNorm3D(128, 16, 32)
self.conv_act = torch.nn.SiLU()
self.conv_out = CachedConv3d(128, 3, kernel_size=3, stride=1, padding=(0, 1, 1))
def forward(self, sample):
sample = sample / self.scaling_factor
hidden_states = self.conv_in(sample)
for block in self.blocks:
hidden_states = block(hidden_states, sample)
hidden_states = self.norm_out(hidden_states, sample)
hidden_states = self.conv_act(hidden_states)
hidden_states = self.conv_out(hidden_states)
return hidden_states
def decode_video(self, sample, tiled=True, tile_size=(60, 90), tile_stride=(30, 45), progress_bar=lambda x:x):
if tiled:
B, C, T, H, W = sample.shape
return TileWorker2Dto3D().tiled_forward(
forward_fn=lambda x: self.decode_small_video(x),
model_input=sample,
tile_size=tile_size, tile_stride=tile_stride,
tile_device=sample.device, tile_dtype=sample.dtype,
computation_device=sample.device, computation_dtype=sample.dtype,
scales=(3/16, (T//2*8+T%2)/T, 8, 8),
progress_bar=progress_bar
)
else:
return self.decode_small_video(sample)
def decode_small_video(self, sample):
B, C, T, H, W = sample.shape
computation_device = self.conv_in.weight.device
computation_dtype = self.conv_in.weight.dtype
value = []
for i in range(T//2):
tl = i*2 + T%2 - (T%2 and i==0)
tr = i*2 + 2 + T%2
model_input = sample[:, :, tl: tr, :, :].to(dtype=computation_dtype, device=computation_device)
model_output = self.forward(model_input).to(dtype=sample.dtype, device=sample.device)
value.append(model_output)
value = torch.concat(value, dim=2)
for name, module in self.named_modules():
if isinstance(module, CachedConv3d):
module.clear_cache()
return value
@staticmethod
def state_dict_converter():
return CogVAEDecoderStateDictConverter()
class CogVAEEncoder(torch.nn.Module):
def __init__(self):
super().__init__()
self.scaling_factor = 0.7
self.conv_in = CachedConv3d(3, 128, kernel_size=3, stride=1, padding=(0, 1, 1))
self.blocks = torch.nn.ModuleList([
Resnet3DBlock(128, 128, None, 32),
Resnet3DBlock(128, 128, None, 32),
Resnet3DBlock(128, 128, None, 32),
Downsample3D(128, 128, compress_time=True),
Resnet3DBlock(128, 256, None, 32),
Resnet3DBlock(256, 256, None, 32),
Resnet3DBlock(256, 256, None, 32),
Downsample3D(256, 256, compress_time=True),
Resnet3DBlock(256, 256, None, 32),
Resnet3DBlock(256, 256, None, 32),
Resnet3DBlock(256, 256, None, 32),
Downsample3D(256, 256, compress_time=False),
Resnet3DBlock(256, 512, None, 32),
Resnet3DBlock(512, 512, None, 32),
Resnet3DBlock(512, 512, None, 32),
Resnet3DBlock(512, 512, None, 32),
Resnet3DBlock(512, 512, None, 32),
])
self.norm_out = torch.nn.GroupNorm(32, 512, eps=1e-06, affine=True)
self.conv_act = torch.nn.SiLU()
self.conv_out = CachedConv3d(512, 32, kernel_size=3, stride=1, padding=(0, 1, 1))
def forward(self, sample):
hidden_states = self.conv_in(sample)
for block in self.blocks:
hidden_states = block(hidden_states, sample)
hidden_states = self.norm_out(hidden_states)
hidden_states = self.conv_act(hidden_states)
hidden_states = self.conv_out(hidden_states)[:, :16]
hidden_states = hidden_states * self.scaling_factor
return hidden_states
def encode_video(self, sample, tiled=True, tile_size=(60, 90), tile_stride=(30, 45), progress_bar=lambda x:x):
if tiled:
B, C, T, H, W = sample.shape
return TileWorker2Dto3D().tiled_forward(
forward_fn=lambda x: self.encode_small_video(x),
model_input=sample,
tile_size=(i * 8 for i in tile_size), tile_stride=(i * 8 for i in tile_stride),
tile_device=sample.device, tile_dtype=sample.dtype,
computation_device=sample.device, computation_dtype=sample.dtype,
scales=(16/3, (T//4+T%2)/T, 1/8, 1/8),
progress_bar=progress_bar
)
else:
return self.encode_small_video(sample)
def encode_small_video(self, sample):
B, C, T, H, W = sample.shape
computation_device = self.conv_in.weight.device
computation_dtype = self.conv_in.weight.dtype
value = []
for i in range(T//8):
t = i*8 + T%2 - (T%2 and i==0)
t_ = i*8 + 8 + T%2
model_input = sample[:, :, t: t_, :, :].to(dtype=computation_dtype, device=computation_device)
model_output = self.forward(model_input).to(dtype=sample.dtype, device=sample.device)
value.append(model_output)
value = torch.concat(value, dim=2)
for name, module in self.named_modules():
if isinstance(module, CachedConv3d):
module.clear_cache()
return value
@staticmethod
def state_dict_converter():
return CogVAEEncoderStateDictConverter()
class CogVAEEncoderStateDictConverter:
def __init__(self):
pass
def from_diffusers(self, state_dict):
rename_dict = {
"encoder.conv_in.conv.weight": "conv_in.weight",
"encoder.conv_in.conv.bias": "conv_in.bias",
"encoder.down_blocks.0.downsamplers.0.conv.weight": "blocks.3.conv.weight",
"encoder.down_blocks.0.downsamplers.0.conv.bias": "blocks.3.conv.bias",
"encoder.down_blocks.1.downsamplers.0.conv.weight": "blocks.7.conv.weight",
"encoder.down_blocks.1.downsamplers.0.conv.bias": "blocks.7.conv.bias",
"encoder.down_blocks.2.downsamplers.0.conv.weight": "blocks.11.conv.weight",
"encoder.down_blocks.2.downsamplers.0.conv.bias": "blocks.11.conv.bias",
"encoder.norm_out.weight": "norm_out.weight",
"encoder.norm_out.bias": "norm_out.bias",
"encoder.conv_out.conv.weight": "conv_out.weight",
"encoder.conv_out.conv.bias": "conv_out.bias",
}
prefix_dict = {
"encoder.down_blocks.0.resnets.0.": "blocks.0.",
"encoder.down_blocks.0.resnets.1.": "blocks.1.",
"encoder.down_blocks.0.resnets.2.": "blocks.2.",
"encoder.down_blocks.1.resnets.0.": "blocks.4.",
"encoder.down_blocks.1.resnets.1.": "blocks.5.",
"encoder.down_blocks.1.resnets.2.": "blocks.6.",
"encoder.down_blocks.2.resnets.0.": "blocks.8.",
"encoder.down_blocks.2.resnets.1.": "blocks.9.",
"encoder.down_blocks.2.resnets.2.": "blocks.10.",
"encoder.down_blocks.3.resnets.0.": "blocks.12.",
"encoder.down_blocks.3.resnets.1.": "blocks.13.",
"encoder.down_blocks.3.resnets.2.": "blocks.14.",
"encoder.mid_block.resnets.0.": "blocks.15.",
"encoder.mid_block.resnets.1.": "blocks.16.",
}
suffix_dict = {
"norm1.norm_layer.weight": "norm1.norm_layer.weight",
"norm1.norm_layer.bias": "norm1.norm_layer.bias",
"norm1.conv_y.conv.weight": "norm1.conv_y.weight",
"norm1.conv_y.conv.bias": "norm1.conv_y.bias",
"norm1.conv_b.conv.weight": "norm1.conv_b.weight",
"norm1.conv_b.conv.bias": "norm1.conv_b.bias",
"norm2.norm_layer.weight": "norm2.norm_layer.weight",
"norm2.norm_layer.bias": "norm2.norm_layer.bias",
"norm2.conv_y.conv.weight": "norm2.conv_y.weight",
"norm2.conv_y.conv.bias": "norm2.conv_y.bias",
"norm2.conv_b.conv.weight": "norm2.conv_b.weight",
"norm2.conv_b.conv.bias": "norm2.conv_b.bias",
"conv1.conv.weight": "conv1.weight",
"conv1.conv.bias": "conv1.bias",
"conv2.conv.weight": "conv2.weight",
"conv2.conv.bias": "conv2.bias",
"conv_shortcut.weight": "conv_shortcut.weight",
"conv_shortcut.bias": "conv_shortcut.bias",
"norm1.weight": "norm1.weight",
"norm1.bias": "norm1.bias",
"norm2.weight": "norm2.weight",
"norm2.bias": "norm2.bias",
}
state_dict_ = {}
for name, param in state_dict.items():
if name in rename_dict:
state_dict_[rename_dict[name]] = param
else:
for prefix in prefix_dict:
if name.startswith(prefix):
suffix = name[len(prefix):]
state_dict_[prefix_dict[prefix] + suffix_dict[suffix]] = param
return state_dict_
def from_civitai(self, state_dict):
return self.from_diffusers(state_dict)
class CogVAEDecoderStateDictConverter:
def __init__(self):
pass
def from_diffusers(self, state_dict):
rename_dict = {
"decoder.conv_in.conv.weight": "conv_in.weight",
"decoder.conv_in.conv.bias": "conv_in.bias",
"decoder.up_blocks.0.upsamplers.0.conv.weight": "blocks.6.conv.weight",
"decoder.up_blocks.0.upsamplers.0.conv.bias": "blocks.6.conv.bias",
"decoder.up_blocks.1.upsamplers.0.conv.weight": "blocks.11.conv.weight",
"decoder.up_blocks.1.upsamplers.0.conv.bias": "blocks.11.conv.bias",
"decoder.up_blocks.2.upsamplers.0.conv.weight": "blocks.16.conv.weight",
"decoder.up_blocks.2.upsamplers.0.conv.bias": "blocks.16.conv.bias",
"decoder.norm_out.norm_layer.weight": "norm_out.norm_layer.weight",
"decoder.norm_out.norm_layer.bias": "norm_out.norm_layer.bias",
"decoder.norm_out.conv_y.conv.weight": "norm_out.conv_y.weight",
"decoder.norm_out.conv_y.conv.bias": "norm_out.conv_y.bias",
"decoder.norm_out.conv_b.conv.weight": "norm_out.conv_b.weight",
"decoder.norm_out.conv_b.conv.bias": "norm_out.conv_b.bias",
"decoder.conv_out.conv.weight": "conv_out.weight",
"decoder.conv_out.conv.bias": "conv_out.bias"
}
prefix_dict = {
"decoder.mid_block.resnets.0.": "blocks.0.",
"decoder.mid_block.resnets.1.": "blocks.1.",
"decoder.up_blocks.0.resnets.0.": "blocks.2.",
"decoder.up_blocks.0.resnets.1.": "blocks.3.",
"decoder.up_blocks.0.resnets.2.": "blocks.4.",
"decoder.up_blocks.0.resnets.3.": "blocks.5.",
"decoder.up_blocks.1.resnets.0.": "blocks.7.",
"decoder.up_blocks.1.resnets.1.": "blocks.8.",
"decoder.up_blocks.1.resnets.2.": "blocks.9.",
"decoder.up_blocks.1.resnets.3.": "blocks.10.",
"decoder.up_blocks.2.resnets.0.": "blocks.12.",
"decoder.up_blocks.2.resnets.1.": "blocks.13.",
"decoder.up_blocks.2.resnets.2.": "blocks.14.",
"decoder.up_blocks.2.resnets.3.": "blocks.15.",
"decoder.up_blocks.3.resnets.0.": "blocks.17.",
"decoder.up_blocks.3.resnets.1.": "blocks.18.",
"decoder.up_blocks.3.resnets.2.": "blocks.19.",
"decoder.up_blocks.3.resnets.3.": "blocks.20.",
}
suffix_dict = {
"norm1.norm_layer.weight": "norm1.norm_layer.weight",
"norm1.norm_layer.bias": "norm1.norm_layer.bias",
"norm1.conv_y.conv.weight": "norm1.conv_y.weight",
"norm1.conv_y.conv.bias": "norm1.conv_y.bias",
"norm1.conv_b.conv.weight": "norm1.conv_b.weight",
"norm1.conv_b.conv.bias": "norm1.conv_b.bias",
"norm2.norm_layer.weight": "norm2.norm_layer.weight",
"norm2.norm_layer.bias": "norm2.norm_layer.bias",
"norm2.conv_y.conv.weight": "norm2.conv_y.weight",
"norm2.conv_y.conv.bias": "norm2.conv_y.bias",
"norm2.conv_b.conv.weight": "norm2.conv_b.weight",
"norm2.conv_b.conv.bias": "norm2.conv_b.bias",
"conv1.conv.weight": "conv1.weight",
"conv1.conv.bias": "conv1.bias",
"conv2.conv.weight": "conv2.weight",
"conv2.conv.bias": "conv2.bias",
"conv_shortcut.weight": "conv_shortcut.weight",
"conv_shortcut.bias": "conv_shortcut.bias",
}
state_dict_ = {}
for name, param in state_dict.items():
if name in rename_dict:
state_dict_[rename_dict[name]] = param
else:
for prefix in prefix_dict:
if name.startswith(prefix):
suffix = name[len(prefix):]
state_dict_[prefix_dict[prefix] + suffix_dict[suffix]] = param
return state_dict_
def from_civitai(self, state_dict):
return self.from_diffusers(state_dict)

View File

@@ -43,93 +43,17 @@ from .flux_dit import FluxDiT
from .flux_text_encoder import FluxTextEncoder1, FluxTextEncoder2
from .flux_vae import FluxVAEEncoder, FluxVAEDecoder
from .cog_vae import CogVAEEncoder, CogVAEDecoder
from .cog_dit import CogDiT
from ..extensions.RIFE import IFNet
from ..extensions.ESRGAN import RRDBNet
from ..configs.model_config import model_loader_configs, huggingface_model_loader_configs, patch_model_loader_configs
from .utils import load_state_dict
def load_state_dict(file_path, torch_dtype=None):
if file_path.endswith(".safetensors"):
return load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype)
else:
return load_state_dict_from_bin(file_path, torch_dtype=torch_dtype)
def load_state_dict_from_safetensors(file_path, torch_dtype=None):
state_dict = {}
with safe_open(file_path, framework="pt", device="cpu") as f:
for k in f.keys():
state_dict[k] = f.get_tensor(k)
if torch_dtype is not None:
state_dict[k] = state_dict[k].to(torch_dtype)
return state_dict
def load_state_dict_from_bin(file_path, torch_dtype=None):
state_dict = torch.load(file_path, map_location="cpu")
if torch_dtype is not None:
for i in state_dict:
if isinstance(state_dict[i], torch.Tensor):
state_dict[i] = state_dict[i].to(torch_dtype)
return state_dict
def search_for_embeddings(state_dict):
embeddings = []
for k in state_dict:
if isinstance(state_dict[k], torch.Tensor):
embeddings.append(state_dict[k])
elif isinstance(state_dict[k], dict):
embeddings += search_for_embeddings(state_dict[k])
return embeddings
def search_parameter(param, state_dict):
for name, param_ in state_dict.items():
if param.numel() == param_.numel():
if param.shape == param_.shape:
if torch.dist(param, param_) < 1e-3:
return name
else:
if torch.dist(param.flatten(), param_.flatten()) < 1e-3:
return name
return None
def build_rename_dict(source_state_dict, target_state_dict, split_qkv=False):
matched_keys = set()
with torch.no_grad():
for name in source_state_dict:
rename = search_parameter(source_state_dict[name], target_state_dict)
if rename is not None:
print(f'"{name}": "{rename}",')
matched_keys.add(rename)
elif split_qkv and len(source_state_dict[name].shape)>=1 and source_state_dict[name].shape[0]%3==0:
length = source_state_dict[name].shape[0] // 3
rename = []
for i in range(3):
rename.append(search_parameter(source_state_dict[name][i*length: i*length+length], target_state_dict))
if None not in rename:
print(f'"{name}": {rename},')
for rename_ in rename:
matched_keys.add(rename_)
for name in target_state_dict:
if name not in matched_keys:
print("Cannot find", name, target_state_dict[name].shape)
def search_for_files(folder, extensions):
files = []
if os.path.isdir(folder):
for file in sorted(os.listdir(folder)):
files += search_for_files(os.path.join(folder, file), extensions)
elif os.path.isfile(folder):
for extension in extensions:
if folder.endswith(extension):
files.append(folder)
break
return files
def convert_state_dict_keys_to_single_str(state_dict, with_shape=True):
keys = []
for key, value in state_dict.items():
@@ -356,7 +280,7 @@ class ModelDetectorFromHuggingfaceFolder:
return False
with open(os.path.join(file_path, "config.json"), "r") as f:
config = json.load(f)
if "architectures" not in config:
if "architectures" not in config and "_class_name" not in config:
return False
return True
@@ -365,7 +289,8 @@ class ModelDetectorFromHuggingfaceFolder:
with open(os.path.join(file_path, "config.json"), "r") as f:
config = json.load(f)
loaded_model_names, loaded_models = [], []
for architecture in config["architectures"]:
architectures = config["architectures"] if "architectures" in config else [config["_class_name"]]
for architecture in architectures:
huggingface_lib, model_name, redirected_architecture = self.architecture_dict[architecture]
if redirected_architecture is not None:
architecture = redirected_architecture

View File

@@ -103,4 +103,78 @@ class TileWorker:
# Done!
model_output = model_output.to(device=inference_device, dtype=inference_dtype)
return model_output
return model_output
class TileWorker2Dto3D:
"""
Process 3D tensors, but only enable TileWorker on 2D.
"""
def __init__(self):
pass
def build_mask(self, T, H, W, dtype, device, is_bound, border_width):
t = repeat(torch.arange(T), "T -> T H W", T=T, H=H, W=W)
h = repeat(torch.arange(H), "H -> T H W", T=T, H=H, W=W)
w = repeat(torch.arange(W), "W -> T H W", T=T, H=H, W=W)
border_width = (H + W) // 4 if border_width is None else border_width
pad = torch.ones_like(h) * border_width
mask = torch.stack([
pad if is_bound[0] else t + 1,
pad if is_bound[1] else T - t,
pad if is_bound[2] else h + 1,
pad if is_bound[3] else H - h,
pad if is_bound[4] else w + 1,
pad if is_bound[5] else W - w
]).min(dim=0).values
mask = mask.clip(1, border_width)
mask = (mask / border_width).to(dtype=dtype, device=device)
mask = rearrange(mask, "T H W -> 1 1 T H W")
return mask
def tiled_forward(
self,
forward_fn,
model_input,
tile_size, tile_stride,
tile_device="cpu", tile_dtype=torch.float32,
computation_device="cuda", computation_dtype=torch.float32,
border_width=None, scales=[1, 1, 1, 1],
progress_bar=lambda x:x
):
B, C, T, H, W = model_input.shape
scale_C, scale_T, scale_H, scale_W = scales
tile_size_H, tile_size_W = tile_size
tile_stride_H, tile_stride_W = tile_stride
value = torch.zeros((B, int(C*scale_C), int(T*scale_T), int(H*scale_H), int(W*scale_W)), dtype=tile_dtype, device=tile_device)
weight = torch.zeros((1, 1, int(T*scale_T), int(H*scale_H), int(W*scale_W)), dtype=tile_dtype, device=tile_device)
# Split tasks
tasks = []
for h in range(0, H, tile_stride_H):
for w in range(0, W, tile_stride_W):
if (h-tile_stride_H >= 0 and h-tile_stride_H+tile_size_H >= H) or (w-tile_stride_W >= 0 and w-tile_stride_W+tile_size_W >= W):
continue
h_, w_ = h + tile_size_H, w + tile_size_W
if h_ > H: h, h_ = max(H - tile_size_H, 0), H
if w_ > W: w, w_ = max(W - tile_size_W, 0), W
tasks.append((h, h_, w, w_))
# Run
for hl, hr, wl, wr in progress_bar(tasks):
mask = self.build_mask(
int(T*scale_T), int((hr-hl)*scale_H), int((wr-wl)*scale_W),
tile_dtype, tile_device,
is_bound=(True, True, hl==0, hr>=H, wl==0, wr>=W),
border_width=border_width
)
grid_input = model_input[:, :, :, hl:hr, wl:wr].to(dtype=computation_dtype, device=computation_device)
grid_output = forward_fn(grid_input).to(dtype=tile_dtype, device=tile_device)
value[:, :, :, int(hl*scale_H):int(hr*scale_H), int(wl*scale_W):int(wr*scale_W)] += grid_output * mask
weight[:, :, :, int(hl*scale_H):int(hr*scale_H), int(wl*scale_W):int(wr*scale_W)] += mask
value = value / weight
return value

96
diffsynth/models/utils.py Normal file
View File

@@ -0,0 +1,96 @@
import torch, os
from safetensors import safe_open
def load_state_dict_from_folder(file_path, torch_dtype=None):
state_dict = {}
for file_name in os.listdir(file_path):
if "." in file_name and file_name.split(".")[-1] in [
"safetensors", "bin", "ckpt", "pth", "pt"
]:
state_dict.update(load_state_dict(os.path.join(file_path, file_name), torch_dtype=torch_dtype))
return state_dict
def load_state_dict(file_path, torch_dtype=None):
if file_path.endswith(".safetensors"):
return load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype)
else:
return load_state_dict_from_bin(file_path, torch_dtype=torch_dtype)
def load_state_dict_from_safetensors(file_path, torch_dtype=None):
state_dict = {}
with safe_open(file_path, framework="pt", device="cpu") as f:
for k in f.keys():
state_dict[k] = f.get_tensor(k)
if torch_dtype is not None:
state_dict[k] = state_dict[k].to(torch_dtype)
return state_dict
def load_state_dict_from_bin(file_path, torch_dtype=None):
state_dict = torch.load(file_path, map_location="cpu")
if torch_dtype is not None:
for i in state_dict:
if isinstance(state_dict[i], torch.Tensor):
state_dict[i] = state_dict[i].to(torch_dtype)
return state_dict
def search_for_embeddings(state_dict):
embeddings = []
for k in state_dict:
if isinstance(state_dict[k], torch.Tensor):
embeddings.append(state_dict[k])
elif isinstance(state_dict[k], dict):
embeddings += search_for_embeddings(state_dict[k])
return embeddings
def search_parameter(param, state_dict):
for name, param_ in state_dict.items():
if param.numel() == param_.numel():
if param.shape == param_.shape:
if torch.dist(param, param_) < 1e-3:
return name
else:
if torch.dist(param.flatten(), param_.flatten()) < 1e-3:
return name
return None
def build_rename_dict(source_state_dict, target_state_dict, split_qkv=False):
matched_keys = set()
with torch.no_grad():
for name in source_state_dict:
rename = search_parameter(source_state_dict[name], target_state_dict)
if rename is not None:
print(f'"{name}": "{rename}",')
matched_keys.add(rename)
elif split_qkv and len(source_state_dict[name].shape)>=1 and source_state_dict[name].shape[0]%3==0:
length = source_state_dict[name].shape[0] // 3
rename = []
for i in range(3):
rename.append(search_parameter(source_state_dict[name][i*length: i*length+length], target_state_dict))
if None not in rename:
print(f'"{name}": {rename},')
for rename_ in rename:
matched_keys.add(rename_)
for name in target_state_dict:
if name not in matched_keys:
print("Cannot find", name, target_state_dict[name].shape)
def search_for_files(folder, extensions):
files = []
if os.path.isdir(folder):
for file in sorted(os.listdir(folder)):
files += search_for_files(os.path.join(folder, file), extensions)
elif os.path.isfile(folder):
for extension in extensions:
if folder.endswith(extension):
files.append(folder)
break
return files

View File

@@ -6,5 +6,6 @@ from .sd3_image import SD3ImagePipeline
from .hunyuan_image import HunyuanDiTImagePipeline
from .svd_video import SVDVideoPipeline
from .flux_image import FluxImagePipeline
from .cog_video import CogVideoPipeline
from .pipeline_runner import SDVideoPipelineRunner
KolorsImagePipeline = SDXLImagePipeline

View File

@@ -0,0 +1,131 @@
from ..models import ModelManager, FluxTextEncoder2, CogDiT, CogVAEEncoder, CogVAEDecoder
from ..prompters import CogPrompter
from ..schedulers import EnhancedDDIMScheduler
from .base import BasePipeline
import torch
from tqdm import tqdm
from PIL import Image
import numpy as np
from einops import rearrange
class CogVideoPipeline(BasePipeline):
def __init__(self, device="cuda", torch_dtype=torch.float16):
super().__init__(device=device, torch_dtype=torch_dtype)
self.scheduler = EnhancedDDIMScheduler(rescale_zero_terminal_snr=True, prediction_type="v_prediction")
self.prompter = CogPrompter()
# models
self.text_encoder: FluxTextEncoder2 = None
self.dit: CogDiT = None
self.vae_encoder: CogVAEEncoder = None
self.vae_decoder: CogVAEDecoder = None
def fetch_models(self, model_manager: ModelManager, prompt_refiner_classes=[]):
self.text_encoder = model_manager.fetch_model("flux_text_encoder_2")
self.dit = model_manager.fetch_model("cog_dit")
self.vae_encoder = model_manager.fetch_model("cog_vae_encoder")
self.vae_decoder = model_manager.fetch_model("cog_vae_decoder")
self.prompter.fetch_models(self.text_encoder)
self.prompter.load_prompt_refiners(model_manager, prompt_refiner_classes)
@staticmethod
def from_model_manager(model_manager: ModelManager, prompt_refiner_classes=[]):
pipe = CogVideoPipeline(
device=model_manager.device,
torch_dtype=model_manager.torch_dtype
)
pipe.fetch_models(model_manager, prompt_refiner_classes)
return pipe
def tensor2video(self, frames):
frames = rearrange(frames, "C T H W -> T H W C")
frames = ((frames.float() + 1) * 127.5).clip(0, 255).cpu().numpy().astype(np.uint8)
frames = [Image.fromarray(frame) for frame in frames]
return frames
def encode_prompt(self, prompt, positive=True):
prompt_emb = self.prompter.encode_prompt(prompt, device=self.device, positive=positive)
return {"prompt_emb": prompt_emb}
def prepare_extra_input(self, latents):
return {"image_rotary_emb": self.dit.prepare_rotary_positional_embeddings(latents.shape[3], latents.shape[4], latents.shape[2], device=self.device)}
@torch.no_grad()
def __call__(
self,
prompt,
negative_prompt="",
input_video=None,
cfg_scale=7.0,
denoising_strength=1.0,
num_frames=49,
height=480,
width=720,
num_inference_steps=20,
tiled=False,
tile_size=(60, 90),
tile_stride=(30, 45),
progress_bar_cmd=tqdm,
progress_bar_st=None,
):
# Tiler parameters
tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
# Prepare scheduler
self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength)
# Prepare latent tensors
noise = torch.randn((1, 16, num_frames // 4 + 1, height//8, width//8), device="cpu", dtype=self.torch_dtype)
if denoising_strength == 1.0:
latents = noise.clone()
else:
input_video = self.preprocess_images(input_video)
input_video = torch.stack(input_video, dim=2)
latents = self.vae_encoder.encode_video(input_video, **tiler_kwargs, progress_bar=progress_bar_cmd).to(dtype=self.torch_dtype)
latents = self.scheduler.add_noise(latents, noise, self.scheduler.timesteps[0])
if not tiled: latents = latents.to(self.device)
# Encode prompt
prompt_emb_posi = self.encode_prompt(prompt, positive=True)
if cfg_scale != 1.0:
prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False)
# Extra input
extra_input = self.prepare_extra_input(latents)
# Denoise
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
timestep = timestep.unsqueeze(0).to(self.device)
# Classifier-free guidance
noise_pred_posi = self.dit(
latents, timestep=timestep, **prompt_emb_posi, **tiler_kwargs, **extra_input
)
if cfg_scale != 1.0:
noise_pred_nega = self.dit(
latents, timestep=timestep, **prompt_emb_nega, **tiler_kwargs, **extra_input
)
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
else:
noise_pred = noise_pred_posi
# DDIM
latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents)
# Update progress bar
if progress_bar_st is not None:
progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
# Decode image
video = self.vae_decoder.decode_video(latents.to("cpu"), **tiler_kwargs, progress_bar=progress_bar_cmd)
video = self.tensor2video(video[0])
return video

View File

@@ -5,3 +5,4 @@ from .sd3_prompter import SD3Prompter
from .hunyuan_dit_prompter import HunyuanDiTPrompter
from .kolors_prompter import KolorsPrompter
from .flux_prompter import FluxPrompter
from .cog_prompter import CogPrompter

View File

@@ -0,0 +1,46 @@
from .base_prompter import BasePrompter
from ..models.flux_text_encoder import FluxTextEncoder2
from transformers import T5TokenizerFast
import os
class CogPrompter(BasePrompter):
def __init__(
self,
tokenizer_path=None
):
if tokenizer_path is None:
base_path = os.path.dirname(os.path.dirname(__file__))
tokenizer_path = os.path.join(base_path, "tokenizer_configs/cog/tokenizer")
super().__init__()
self.tokenizer = T5TokenizerFast.from_pretrained(tokenizer_path)
self.text_encoder: FluxTextEncoder2 = None
def fetch_models(self, text_encoder: FluxTextEncoder2 = None):
self.text_encoder = text_encoder
def encode_prompt_using_t5(self, prompt, text_encoder, tokenizer, max_length, device):
input_ids = tokenizer(
prompt,
return_tensors="pt",
padding="max_length",
max_length=max_length,
truncation=True,
).input_ids.to(device)
prompt_emb = text_encoder(input_ids)
prompt_emb = prompt_emb.reshape((1, prompt_emb.shape[0]*prompt_emb.shape[1], -1))
return prompt_emb
def encode_prompt(
self,
prompt,
positive=True,
device="cuda"
):
prompt = self.process_prompt(prompt, positive=positive)
prompt_emb = self.encode_prompt_using_t5(prompt, self.text_encoder, self.tokenizer, 226, device)
return prompt_emb

View File

@@ -1,5 +1,5 @@
from .base_prompter import BasePrompter, tokenize_long_prompt
from ..models.model_manager import ModelManager, load_state_dict, search_for_embeddings
from ..models.utils import load_state_dict, search_for_embeddings
from ..models import SDTextEncoder
from transformers import CLIPTokenizer
import torch, os

View File

@@ -3,7 +3,7 @@ import torch, math
class EnhancedDDIMScheduler():
def __init__(self, num_train_timesteps=1000, beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", prediction_type="epsilon"):
def __init__(self, num_train_timesteps=1000, beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", prediction_type="epsilon", rescale_zero_terminal_snr=False):
self.num_train_timesteps = num_train_timesteps
if beta_schedule == "scaled_linear":
betas = torch.square(torch.linspace(math.sqrt(beta_start), math.sqrt(beta_end), num_train_timesteps, dtype=torch.float32))
@@ -11,11 +11,33 @@ class EnhancedDDIMScheduler():
betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
else:
raise NotImplementedError(f"{beta_schedule} is not implemented")
self.alphas_cumprod = torch.cumprod(1.0 - betas, dim=0).tolist()
self.alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
if rescale_zero_terminal_snr:
self.alphas_cumprod = self.rescale_zero_terminal_snr(self.alphas_cumprod)
self.alphas_cumprod = self.alphas_cumprod.tolist()
self.set_timesteps(10)
self.prediction_type = prediction_type
def rescale_zero_terminal_snr(self, alphas_cumprod):
alphas_bar_sqrt = alphas_cumprod.sqrt()
# Store old values.
alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
# Shift so the last timestep is zero.
alphas_bar_sqrt -= alphas_bar_sqrt_T
# Scale so the first timestep is back to the old value.
alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
# Convert alphas_bar_sqrt to betas
alphas_bar = alphas_bar_sqrt.square() # Revert sqrt
return alphas_bar
def set_timesteps(self, num_inference_steps, denoising_strength=1.0):
# The timesteps are aligned to 999...0, which is different from other implementations,
# but I think this implementation is more reasonable in theory.

View File

@@ -0,0 +1,102 @@
{
"<extra_id_0>": 32099,
"<extra_id_10>": 32089,
"<extra_id_11>": 32088,
"<extra_id_12>": 32087,
"<extra_id_13>": 32086,
"<extra_id_14>": 32085,
"<extra_id_15>": 32084,
"<extra_id_16>": 32083,
"<extra_id_17>": 32082,
"<extra_id_18>": 32081,
"<extra_id_19>": 32080,
"<extra_id_1>": 32098,
"<extra_id_20>": 32079,
"<extra_id_21>": 32078,
"<extra_id_22>": 32077,
"<extra_id_23>": 32076,
"<extra_id_24>": 32075,
"<extra_id_25>": 32074,
"<extra_id_26>": 32073,
"<extra_id_27>": 32072,
"<extra_id_28>": 32071,
"<extra_id_29>": 32070,
"<extra_id_2>": 32097,
"<extra_id_30>": 32069,
"<extra_id_31>": 32068,
"<extra_id_32>": 32067,
"<extra_id_33>": 32066,
"<extra_id_34>": 32065,
"<extra_id_35>": 32064,
"<extra_id_36>": 32063,
"<extra_id_37>": 32062,
"<extra_id_38>": 32061,
"<extra_id_39>": 32060,
"<extra_id_3>": 32096,
"<extra_id_40>": 32059,
"<extra_id_41>": 32058,
"<extra_id_42>": 32057,
"<extra_id_43>": 32056,
"<extra_id_44>": 32055,
"<extra_id_45>": 32054,
"<extra_id_46>": 32053,
"<extra_id_47>": 32052,
"<extra_id_48>": 32051,
"<extra_id_49>": 32050,
"<extra_id_4>": 32095,
"<extra_id_50>": 32049,
"<extra_id_51>": 32048,
"<extra_id_52>": 32047,
"<extra_id_53>": 32046,
"<extra_id_54>": 32045,
"<extra_id_55>": 32044,
"<extra_id_56>": 32043,
"<extra_id_57>": 32042,
"<extra_id_58>": 32041,
"<extra_id_59>": 32040,
"<extra_id_5>": 32094,
"<extra_id_60>": 32039,
"<extra_id_61>": 32038,
"<extra_id_62>": 32037,
"<extra_id_63>": 32036,
"<extra_id_64>": 32035,
"<extra_id_65>": 32034,
"<extra_id_66>": 32033,
"<extra_id_67>": 32032,
"<extra_id_68>": 32031,
"<extra_id_69>": 32030,
"<extra_id_6>": 32093,
"<extra_id_70>": 32029,
"<extra_id_71>": 32028,
"<extra_id_72>": 32027,
"<extra_id_73>": 32026,
"<extra_id_74>": 32025,
"<extra_id_75>": 32024,
"<extra_id_76>": 32023,
"<extra_id_77>": 32022,
"<extra_id_78>": 32021,
"<extra_id_79>": 32020,
"<extra_id_7>": 32092,
"<extra_id_80>": 32019,
"<extra_id_81>": 32018,
"<extra_id_82>": 32017,
"<extra_id_83>": 32016,
"<extra_id_84>": 32015,
"<extra_id_85>": 32014,
"<extra_id_86>": 32013,
"<extra_id_87>": 32012,
"<extra_id_88>": 32011,
"<extra_id_89>": 32010,
"<extra_id_8>": 32091,
"<extra_id_90>": 32009,
"<extra_id_91>": 32008,
"<extra_id_92>": 32007,
"<extra_id_93>": 32006,
"<extra_id_94>": 32005,
"<extra_id_95>": 32004,
"<extra_id_96>": 32003,
"<extra_id_97>": 32002,
"<extra_id_98>": 32001,
"<extra_id_99>": 32000,
"<extra_id_9>": 32090
}

View File

@@ -0,0 +1,125 @@
{
"additional_special_tokens": [
"<extra_id_0>",
"<extra_id_1>",
"<extra_id_2>",
"<extra_id_3>",
"<extra_id_4>",
"<extra_id_5>",
"<extra_id_6>",
"<extra_id_7>",
"<extra_id_8>",
"<extra_id_9>",
"<extra_id_10>",
"<extra_id_11>",
"<extra_id_12>",
"<extra_id_13>",
"<extra_id_14>",
"<extra_id_15>",
"<extra_id_16>",
"<extra_id_17>",
"<extra_id_18>",
"<extra_id_19>",
"<extra_id_20>",
"<extra_id_21>",
"<extra_id_22>",
"<extra_id_23>",
"<extra_id_24>",
"<extra_id_25>",
"<extra_id_26>",
"<extra_id_27>",
"<extra_id_28>",
"<extra_id_29>",
"<extra_id_30>",
"<extra_id_31>",
"<extra_id_32>",
"<extra_id_33>",
"<extra_id_34>",
"<extra_id_35>",
"<extra_id_36>",
"<extra_id_37>",
"<extra_id_38>",
"<extra_id_39>",
"<extra_id_40>",
"<extra_id_41>",
"<extra_id_42>",
"<extra_id_43>",
"<extra_id_44>",
"<extra_id_45>",
"<extra_id_46>",
"<extra_id_47>",
"<extra_id_48>",
"<extra_id_49>",
"<extra_id_50>",
"<extra_id_51>",
"<extra_id_52>",
"<extra_id_53>",
"<extra_id_54>",
"<extra_id_55>",
"<extra_id_56>",
"<extra_id_57>",
"<extra_id_58>",
"<extra_id_59>",
"<extra_id_60>",
"<extra_id_61>",
"<extra_id_62>",
"<extra_id_63>",
"<extra_id_64>",
"<extra_id_65>",
"<extra_id_66>",
"<extra_id_67>",
"<extra_id_68>",
"<extra_id_69>",
"<extra_id_70>",
"<extra_id_71>",
"<extra_id_72>",
"<extra_id_73>",
"<extra_id_74>",
"<extra_id_75>",
"<extra_id_76>",
"<extra_id_77>",
"<extra_id_78>",
"<extra_id_79>",
"<extra_id_80>",
"<extra_id_81>",
"<extra_id_82>",
"<extra_id_83>",
"<extra_id_84>",
"<extra_id_85>",
"<extra_id_86>",
"<extra_id_87>",
"<extra_id_88>",
"<extra_id_89>",
"<extra_id_90>",
"<extra_id_91>",
"<extra_id_92>",
"<extra_id_93>",
"<extra_id_94>",
"<extra_id_95>",
"<extra_id_96>",
"<extra_id_97>",
"<extra_id_98>",
"<extra_id_99>"
],
"eos_token": {
"content": "</s>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false
},
"pad_token": {
"content": "<pad>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false
},
"unk_token": {
"content": "<unk>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false
}
}

Binary file not shown.

View File

@@ -0,0 +1,940 @@
{
"add_prefix_space": true,
"added_tokens_decoder": {
"0": {
"content": "<pad>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"1": {
"content": "</s>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"2": {
"content": "<unk>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32000": {
"content": "<extra_id_99>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32001": {
"content": "<extra_id_98>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32002": {
"content": "<extra_id_97>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32003": {
"content": "<extra_id_96>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32004": {
"content": "<extra_id_95>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32005": {
"content": "<extra_id_94>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32006": {
"content": "<extra_id_93>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32007": {
"content": "<extra_id_92>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32008": {
"content": "<extra_id_91>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32009": {
"content": "<extra_id_90>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32010": {
"content": "<extra_id_89>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32011": {
"content": "<extra_id_88>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32012": {
"content": "<extra_id_87>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32013": {
"content": "<extra_id_86>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32014": {
"content": "<extra_id_85>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32015": {
"content": "<extra_id_84>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32016": {
"content": "<extra_id_83>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32017": {
"content": "<extra_id_82>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32018": {
"content": "<extra_id_81>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32019": {
"content": "<extra_id_80>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32020": {
"content": "<extra_id_79>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32021": {
"content": "<extra_id_78>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32022": {
"content": "<extra_id_77>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32023": {
"content": "<extra_id_76>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32024": {
"content": "<extra_id_75>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32025": {
"content": "<extra_id_74>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32026": {
"content": "<extra_id_73>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32027": {
"content": "<extra_id_72>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32028": {
"content": "<extra_id_71>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32029": {
"content": "<extra_id_70>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32030": {
"content": "<extra_id_69>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32031": {
"content": "<extra_id_68>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32032": {
"content": "<extra_id_67>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32033": {
"content": "<extra_id_66>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32034": {
"content": "<extra_id_65>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32035": {
"content": "<extra_id_64>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32036": {
"content": "<extra_id_63>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32037": {
"content": "<extra_id_62>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32038": {
"content": "<extra_id_61>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32039": {
"content": "<extra_id_60>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32040": {
"content": "<extra_id_59>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32041": {
"content": "<extra_id_58>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32042": {
"content": "<extra_id_57>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32043": {
"content": "<extra_id_56>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32044": {
"content": "<extra_id_55>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32045": {
"content": "<extra_id_54>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32046": {
"content": "<extra_id_53>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32047": {
"content": "<extra_id_52>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32048": {
"content": "<extra_id_51>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32049": {
"content": "<extra_id_50>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32050": {
"content": "<extra_id_49>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32051": {
"content": "<extra_id_48>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32052": {
"content": "<extra_id_47>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32053": {
"content": "<extra_id_46>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32054": {
"content": "<extra_id_45>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32055": {
"content": "<extra_id_44>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32056": {
"content": "<extra_id_43>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32057": {
"content": "<extra_id_42>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32058": {
"content": "<extra_id_41>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32059": {
"content": "<extra_id_40>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32060": {
"content": "<extra_id_39>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32061": {
"content": "<extra_id_38>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32062": {
"content": "<extra_id_37>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32063": {
"content": "<extra_id_36>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32064": {
"content": "<extra_id_35>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32065": {
"content": "<extra_id_34>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32066": {
"content": "<extra_id_33>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32067": {
"content": "<extra_id_32>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32068": {
"content": "<extra_id_31>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32069": {
"content": "<extra_id_30>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32070": {
"content": "<extra_id_29>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32071": {
"content": "<extra_id_28>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32072": {
"content": "<extra_id_27>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32073": {
"content": "<extra_id_26>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32074": {
"content": "<extra_id_25>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32075": {
"content": "<extra_id_24>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32076": {
"content": "<extra_id_23>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32077": {
"content": "<extra_id_22>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32078": {
"content": "<extra_id_21>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32079": {
"content": "<extra_id_20>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32080": {
"content": "<extra_id_19>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32081": {
"content": "<extra_id_18>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32082": {
"content": "<extra_id_17>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32083": {
"content": "<extra_id_16>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32084": {
"content": "<extra_id_15>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32085": {
"content": "<extra_id_14>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32086": {
"content": "<extra_id_13>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32087": {
"content": "<extra_id_12>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32088": {
"content": "<extra_id_11>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32089": {
"content": "<extra_id_10>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32090": {
"content": "<extra_id_9>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32091": {
"content": "<extra_id_8>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32092": {
"content": "<extra_id_7>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32093": {
"content": "<extra_id_6>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32094": {
"content": "<extra_id_5>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32095": {
"content": "<extra_id_4>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32096": {
"content": "<extra_id_3>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32097": {
"content": "<extra_id_2>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32098": {
"content": "<extra_id_1>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
},
"32099": {
"content": "<extra_id_0>",
"lstrip": true,
"normalized": false,
"rstrip": true,
"single_word": false,
"special": true
}
},
"additional_special_tokens": [
"<extra_id_0>",
"<extra_id_1>",
"<extra_id_2>",
"<extra_id_3>",
"<extra_id_4>",
"<extra_id_5>",
"<extra_id_6>",
"<extra_id_7>",
"<extra_id_8>",
"<extra_id_9>",
"<extra_id_10>",
"<extra_id_11>",
"<extra_id_12>",
"<extra_id_13>",
"<extra_id_14>",
"<extra_id_15>",
"<extra_id_16>",
"<extra_id_17>",
"<extra_id_18>",
"<extra_id_19>",
"<extra_id_20>",
"<extra_id_21>",
"<extra_id_22>",
"<extra_id_23>",
"<extra_id_24>",
"<extra_id_25>",
"<extra_id_26>",
"<extra_id_27>",
"<extra_id_28>",
"<extra_id_29>",
"<extra_id_30>",
"<extra_id_31>",
"<extra_id_32>",
"<extra_id_33>",
"<extra_id_34>",
"<extra_id_35>",
"<extra_id_36>",
"<extra_id_37>",
"<extra_id_38>",
"<extra_id_39>",
"<extra_id_40>",
"<extra_id_41>",
"<extra_id_42>",
"<extra_id_43>",
"<extra_id_44>",
"<extra_id_45>",
"<extra_id_46>",
"<extra_id_47>",
"<extra_id_48>",
"<extra_id_49>",
"<extra_id_50>",
"<extra_id_51>",
"<extra_id_52>",
"<extra_id_53>",
"<extra_id_54>",
"<extra_id_55>",
"<extra_id_56>",
"<extra_id_57>",
"<extra_id_58>",
"<extra_id_59>",
"<extra_id_60>",
"<extra_id_61>",
"<extra_id_62>",
"<extra_id_63>",
"<extra_id_64>",
"<extra_id_65>",
"<extra_id_66>",
"<extra_id_67>",
"<extra_id_68>",
"<extra_id_69>",
"<extra_id_70>",
"<extra_id_71>",
"<extra_id_72>",
"<extra_id_73>",
"<extra_id_74>",
"<extra_id_75>",
"<extra_id_76>",
"<extra_id_77>",
"<extra_id_78>",
"<extra_id_79>",
"<extra_id_80>",
"<extra_id_81>",
"<extra_id_82>",
"<extra_id_83>",
"<extra_id_84>",
"<extra_id_85>",
"<extra_id_86>",
"<extra_id_87>",
"<extra_id_88>",
"<extra_id_89>",
"<extra_id_90>",
"<extra_id_91>",
"<extra_id_92>",
"<extra_id_93>",
"<extra_id_94>",
"<extra_id_95>",
"<extra_id_96>",
"<extra_id_97>",
"<extra_id_98>",
"<extra_id_99>"
],
"clean_up_tokenization_spaces": true,
"eos_token": "</s>",
"extra_ids": 100,
"legacy": true,
"model_max_length": 226,
"pad_token": "<pad>",
"sp_model_kwargs": {},
"tokenizer_class": "T5Tokenizer",
"unk_token": "<unk>"
}

View File

@@ -1,8 +1,46 @@
# Text to Video
In DiffSynth Studio, we can use AnimateDiff and SVD to generate videos. However, these models usually generate terrible contents. We do not recommend users to use these models, until a more powerful video model emerges.
In DiffSynth Studio, we can use some video models to generate videos.
### Example 7: Text to Video
### Example: Text-to-Video using CogVideoX-5B (Experimental)
See [cogvideo_text_to_video.py](cogvideo_text_to_video.py).
First, we generate a video using prompt "an astronaut riding a horse on Mars".
https://github.com/user-attachments/assets/4c91c1cd-e4a0-471a-bd8d-24d761262941
Then, we convert the astronaut to a robot.
https://github.com/user-attachments/assets/225a00a4-2bc8-4740-8e86-a64b460a29ec
Upscale the video using the model itself.
https://github.com/user-attachments/assets/c02cb30c-de60-473c-8242-32c67b3155ad
Make the video look smoother by interpolating frames.
https://github.com/user-attachments/assets/f0e465b4-45df-4435-ab10-7a084ca2b0a0
Here is another example.
First, we generate a video using prompt "a dog is running".
https://github.com/user-attachments/assets/e3696297-99f5-4d0c-a5ca-1d1566db85b4
Then, we add a blue collar to the dog.
https://github.com/user-attachments/assets/7ff22be7-4390-4d33-ae6c-53f6f056e18d
Upscale the video using the model itself.
https://github.com/user-attachments/assets/a909c32c-0b7d-495c-a53c-d23a99a3d3e9
Make the video look smoother by interpolating frames.
https://github.com/user-attachments/assets/ea37c150-97a0-4858-8003-0c2e5eef3331
### Example: Text-to-Video using AnimateDiff
Generate a video using a Stable Diffusion model and an AnimateDiff model. We can break the limitation of number of frames! See [sd_text_to_video.py](./sd_text_to_video.py).

View File

@@ -0,0 +1,73 @@
from diffsynth import ModelManager, save_video, VideoData, download_models, CogVideoPipeline
from diffsynth.extensions.RIFE import RIFEInterpolater
import torch, os
os.environ["TOKENIZERS_PARALLELISM"] = "True"
def text_to_video(model_manager, prompt, seed, output_path):
pipe = CogVideoPipeline.from_model_manager(model_manager)
torch.manual_seed(seed)
video = pipe(
prompt=prompt,
height=480, width=720,
cfg_scale=7.0, num_inference_steps=200
)
save_video(video, output_path, fps=8, quality=5)
def edit_video(model_manager, prompt, seed, input_path, output_path):
pipe = CogVideoPipeline.from_model_manager(model_manager)
input_video = VideoData(video_file=input_path)
torch.manual_seed(seed)
video = pipe(
prompt=prompt,
height=480, width=720,
cfg_scale=7.0, num_inference_steps=200,
input_video=input_video, denoising_strength=0.7
)
save_video(video, output_path, fps=8, quality=5)
def self_upscale(model_manager, prompt, seed, input_path, output_path):
pipe = CogVideoPipeline.from_model_manager(model_manager)
input_video = VideoData(video_file=input_path, height=480*2, width=720*2).raw_data()
torch.manual_seed(seed)
video = pipe(
prompt=prompt,
height=480*2, width=720*2,
cfg_scale=7.0, num_inference_steps=30,
input_video=input_video, denoising_strength=0.4, tiled=True
)
save_video(video, output_path, fps=8, quality=7)
def interpolate_video(model_manager, input_path, output_path):
rife = RIFEInterpolater.from_model_manager(model_manager)
video = VideoData(video_file=input_path).raw_data()
video = rife.interpolate(video, num_iter=2)
save_video(video, output_path, fps=32, quality=5)
download_models(["CogVideoX-5B", "RIFE"])
model_manager = ModelManager(torch_dtype=torch.bfloat16)
model_manager.load_models([
"models/CogVideo/CogVideoX-5b/text_encoder",
"models/CogVideo/CogVideoX-5b/transformer",
"models/CogVideo/CogVideoX-5b/vae/diffusion_pytorch_model.safetensors",
"models/RIFE/flownet.pkl",
])
# Example 1
text_to_video(model_manager, "an astronaut riding a horse on Mars.", 0, "1_video_1.mp4")
edit_video(model_manager, "a white robot riding a horse on Mars.", 1, "1_video_1.mp4", "1_video_2.mp4")
self_upscale(model_manager, "a white robot riding a horse on Mars.", 2, "1_video_2.mp4", "1_video_3.mp4")
interpolate_video(model_manager, "1_video_3.mp4", "1_video_4.mp4")
# Example 2
text_to_video(model_manager, "a dog is running.", 1, "2_video_1.mp4")
edit_video(model_manager, "a dog with blue collar.", 2, "2_video_1.mp4", "2_video_2.mp4")
self_upscale(model_manager, "a dog with blue collar.", 3, "2_video_2.mp4", "2_video_3.mp4")
interpolate_video(model_manager, "2_video_3.mp4", "2_video_4.mp4")

View File

@@ -7,4 +7,5 @@ imageio[ffmpeg]
safetensors
einops
sentencepiece
protobuf
modelscope