diff --git a/README.md b/README.md index e16f1e3..c29a12e 100644 --- a/README.md +++ b/README.md @@ -15,6 +15,7 @@ DiffSynth Studio is a Diffusion engine. We have restructured architectures inclu Until now, DiffSynth Studio has supported the following models: +* [CogVideo](https://huggingface.co/THUDM/CogVideoX-5b) * [FLUX](https://huggingface.co/black-forest-labs/FLUX.1-dev) * [ExVideo](https://huggingface.co/ECNU-CILab/ExVideo-SVD-128f-v1) * [Kolors](https://huggingface.co/Kwai-Kolors/Kolors) @@ -31,10 +32,16 @@ Until now, DiffSynth Studio has supported the following models: ## News -- **August 22, 2024** We have implemented an interesting painter that supports all text-to-image models. Now you can create stunning images using the painter, with assistance from AI! +- **August 22, 2024.** CogVideoX-5B is supported in this project. See [here](/examples/video_synthesis/). We provide several interesting features for this text-to-video model, including + - Text to video + - Video editing + - Self-upscaling + - Video interpolation + +- **August 22, 2024.** We have implemented an interesting painter that supports all text-to-image models. Now you can create stunning images using the painter, with assistance from AI! - Use it in our [WebUI](#usage-in-webui). -- **August 21, 2024** FLUX is supported in DiffSynth-Studio. +- **August 21, 2024.** FLUX is supported in DiffSynth-Studio. - Enable CFG and highres-fix to improve visual quality. See [here](/examples/image_synthesis/README.md) - LoRA, ControlNet, and additional models will be available soon. @@ -120,6 +127,14 @@ download_from_huggingface("Kwai-Kolors/Kolors", "vae/diffusion_pytorch_model.fp1 ### Video Synthesis +#### Text-to-video using CogVideoX-5B + +CogVideoX-5B is released by ZhiPu. We provide an improved pipeline, supporting text-to-video, video editing, self-upscaling and video interpolation. 
[`examples/video_synthesis`](./examples/video_synthesis/) + +The video on the left is generated using the original text-to-video pipeline, while the video on the right is the result after editing and frame interpolation. + +https://github.com/user-attachments/assets/26b044c1-4a60-44a4-842f-627ff289d006 + #### Long Video Synthesis We trained an extended video synthesis model, which can generate 128 frames. [`examples/ExVideo`](./examples/ExVideo/) diff --git a/diffsynth/configs/model_config.py b/diffsynth/configs/model_config.py index 3e492d5..2184338 100644 --- a/diffsynth/configs/model_config.py +++ b/diffsynth/configs/model_config.py @@ -32,11 +32,16 @@ from ..models.sdxl_ipadapter import SDXLIpAdapter, IpAdapterXLCLIPImageEmbedder from ..models.hunyuan_dit_text_encoder import HunyuanDiTCLIPTextEncoder, HunyuanDiTT5TextEncoder from ..models.hunyuan_dit import HunyuanDiT - from ..models.flux_dit import FluxDiT from ..models.flux_text_encoder import FluxTextEncoder1, FluxTextEncoder2 from ..models.flux_vae import FluxVAEEncoder, FluxVAEDecoder +from ..models.cog_vae import CogVAEEncoder, CogVAEDecoder +from ..models.cog_dit import CogDiT + +from ..extensions.RIFE import IFNet +from ..extensions.ESRGAN import RRDBNet + model_loader_configs = [ @@ -70,7 +75,10 @@ model_loader_configs = [ (None, "94eefa3dac9cec93cb1ebaf1747d7b78", ["flux_text_encoder_1"], [FluxTextEncoder1], "diffusers"), (None, "1aafa3cc91716fb6b300cc1cd51b85a3", ["flux_vae_encoder", "flux_vae_decoder"], [FluxVAEEncoder, FluxVAEDecoder], "diffusers"), (None, "21ea55f476dfc4fd135587abb59dfe5d", ["flux_vae_encoder", "flux_vae_decoder"], [FluxVAEEncoder, FluxVAEDecoder], "civitai"), - (None, "a29710fea6dddb0314663ee823598e50", ["flux_dit"], [FluxDiT], "civitai") + (None, "a29710fea6dddb0314663ee823598e50", ["flux_dit"], [FluxDiT], "civitai"), + (None, "280189ee084bca10f70907bf6ce1649d", ["cog_vae_encoder", "cog_vae_decoder"], [CogVAEEncoder, CogVAEDecoder], "diffusers"), + (None, 
"9b9313d104ac4df27991352fec013fd4", ["rife"], [IFNet], "civitai"), + (None, "6b7116078c4170bfbeaedc8fe71f6649", ["esrgan"], [RRDBNet], "civitai"), ] huggingface_model_loader_configs = [ # These configs are provided for detecting model type automatically. @@ -79,7 +87,9 @@ huggingface_model_loader_configs = [ ("MarianMTModel", "transformers.models.marian.modeling_marian", "translator", None), ("BloomForCausalLM", "transformers.models.bloom.modeling_bloom", "beautiful_prompt", None), ("Qwen2ForCausalLM", "transformers.models.qwen2.modeling_qwen2", "qwen_prompt", None), + ("LlamaForCausalLM", "transformers.models.llama.modeling_llama", "omost_prompt", None), ("T5EncoderModel", "diffsynth.models.flux_text_encoder", "flux_text_encoder_2", "FluxTextEncoder2"), + ("CogVideoXTransformer3DModel", "diffsynth.models.cog_dit", "cog_dit", "CogDiT"), ] patch_model_loader_configs = [ # These configs are provided for detecting model type automatically. @@ -230,6 +240,18 @@ preset_models_on_modelscope = { ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "tokenizer.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"), ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "tokenizer_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"), ], + # Omost prompt + "OmostPrompt":[ + ("Omost/omost-llama-3-8b-4bits", "model-00001-of-00002.safetensors", "models/OmostPrompt/omost-llama-3-8b-4bits"), + ("Omost/omost-llama-3-8b-4bits", "model-00002-of-00002.safetensors", "models/OmostPrompt/omost-llama-3-8b-4bits"), + ("Omost/omost-llama-3-8b-4bits", "tokenizer.json", "models/OmostPrompt/omost-llama-3-8b-4bits"), + ("Omost/omost-llama-3-8b-4bits", "tokenizer_config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"), + ("Omost/omost-llama-3-8b-4bits", "config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"), + ("Omost/omost-llama-3-8b-4bits", "generation_config.json", "models/OmostPrompt/omost-llama-3-8b-4bits"), + ("Omost/omost-llama-3-8b-4bits", "model.safetensors.index.json", 
"models/OmostPrompt/omost-llama-3-8b-4bits"), + ("Omost/omost-llama-3-8b-4bits", "special_tokens_map.json", "models/OmostPrompt/omost-llama-3-8b-4bits"), + ], + # Translator "opus-mt-zh-en": [ ("moxying/opus-mt-zh-en", "config.json", "models/translator/opus-mt-zh-en"), @@ -276,7 +298,27 @@ preset_models_on_modelscope = { ("AI-ModelScope/FLUX.1-dev", "text_encoder_2/model.safetensors.index.json", "models/FLUX/FLUX.1-dev/text_encoder_2"), ("AI-ModelScope/FLUX.1-dev", "ae.safetensors", "models/FLUX/FLUX.1-dev"), ("AI-ModelScope/FLUX.1-dev", "flux1-dev.safetensors", "models/FLUX/FLUX.1-dev"), - ] + ], + # ESRGAN + "ESRGAN_x4": [ + ("AI-ModelScope/Real-ESRGAN", "RealESRGAN_x4.pth", "models/ESRGAN"), + ], + # RIFE + "RIFE": [ + ("AI-ModelScope/RIFE", "flownet.pkl", "models/RIFE"), + ], + # CogVideo + "CogVideoX-5B": [ + ("ZhipuAI/CogVideoX-5b", "text_encoder/config.json", "models/CogVideo/CogVideoX-5b/text_encoder"), + ("ZhipuAI/CogVideoX-5b", "text_encoder/model.safetensors.index.json", "models/CogVideo/CogVideoX-5b/text_encoder"), + ("ZhipuAI/CogVideoX-5b", "text_encoder/model-00001-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/text_encoder"), + ("ZhipuAI/CogVideoX-5b", "text_encoder/model-00002-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/text_encoder"), + ("ZhipuAI/CogVideoX-5b", "transformer/config.json", "models/CogVideo/CogVideoX-5b/transformer"), + ("ZhipuAI/CogVideoX-5b", "transformer/diffusion_pytorch_model.safetensors.index.json", "models/CogVideo/CogVideoX-5b/transformer"), + ("ZhipuAI/CogVideoX-5b", "transformer/diffusion_pytorch_model-00001-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/transformer"), + ("ZhipuAI/CogVideoX-5b", "transformer/diffusion_pytorch_model-00002-of-00002.safetensors", "models/CogVideo/CogVideoX-5b/transformer"), + ("ZhipuAI/CogVideoX-5b", "vae/diffusion_pytorch_model.safetensors", "models/CogVideo/CogVideoX-5b/vae"), + ], } Preset_model_id: TypeAlias = Literal[ "HunyuanDiT", @@ -309,4 +351,8 @@ Preset_model_id: 
TypeAlias = Literal[ "FLUX.1-dev", "SDXL_lora_zyd232_ChineseInkStyle_SDXL_v1_0", "QwenPrompt", + "OmostPrompt", + "ESRGAN_x4", + "RIFE", + "CogVideoX-5B", ] \ No newline at end of file diff --git a/diffsynth/extensions/ESRGAN/__init__.py b/diffsynth/extensions/ESRGAN/__init__.py index e71cd3f..00b90d1 100644 --- a/diffsynth/extensions/ESRGAN/__init__.py +++ b/diffsynth/extensions/ESRGAN/__init__.py @@ -41,7 +41,7 @@ class RRDB(torch.nn.Module): class RRDBNet(torch.nn.Module): - def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32): + def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, **kwargs): super(RRDBNet, self).__init__() self.conv_first = torch.nn.Conv2d(num_in_ch, num_feat, 3, 1, 1) self.body = torch.torch.nn.Sequential(*[RRDB(num_feat=num_feat, num_grow_ch=num_grow_ch) for _ in range(num_block)]) @@ -65,6 +65,21 @@ class RRDBNet(torch.nn.Module): feat = self.lrelu(self.conv_up2(feat)) out = self.conv_last(self.lrelu(self.conv_hr(feat))) return out + + @staticmethod + def state_dict_converter(): + return RRDBNetStateDictConverter() + + +class RRDBNetStateDictConverter: + def __init__(self): + pass + + def from_diffusers(self, state_dict): + return state_dict, {"upcast_to_float32": True} + + def from_civitai(self, state_dict): + return state_dict, {"upcast_to_float32": True} class ESRGAN(torch.nn.Module): @@ -73,12 +88,8 @@ class ESRGAN(torch.nn.Module): self.model = model @staticmethod - def from_pretrained(model_path): - model = RRDBNet() - state_dict = torch.load(model_path, map_location="cpu")["params_ema"] - model.load_state_dict(state_dict) - model.eval() - return ESRGAN(model) + def from_model_manager(model_manager): + return ESRGAN(model_manager.fetch_model("esrgan")) def process_image(self, image): image = torch.Tensor(np.array(image, dtype=np.float32) / 255).permute(2, 0, 1) diff --git a/diffsynth/extensions/RIFE/__init__.py b/diffsynth/extensions/RIFE/__init__.py index 
086a51c..e76c391 100644 --- a/diffsynth/extensions/RIFE/__init__.py +++ b/diffsynth/extensions/RIFE/__init__.py @@ -58,7 +58,7 @@ class IFBlock(nn.Module): class IFNet(nn.Module): - def __init__(self): + def __init__(self, **kwargs): super(IFNet, self).__init__() self.block0 = IFBlock(7+4, c=90) self.block1 = IFBlock(7+4, c=90) @@ -113,7 +113,7 @@ class IFNetStateDictConverter: return state_dict_ def from_civitai(self, state_dict): - return self.from_diffusers(state_dict) + return self.from_diffusers(state_dict), {"upcast_to_float32": True} class RIFEInterpolater: @@ -125,7 +125,7 @@ class RIFEInterpolater: @staticmethod def from_model_manager(model_manager): - return RIFEInterpolater(model_manager.RIFE, device=model_manager.device) + return RIFEInterpolater(model_manager.fetch_model("rife"), device=model_manager.device) def process_image(self, image): width, height = image.size @@ -203,7 +203,7 @@ class RIFESmoother(RIFEInterpolater): @staticmethod def from_model_manager(model_manager): - return RIFESmoother(model_manager.RIFE, device=model_manager.device) + return RIFESmoother(model_manager.fetch_model("rife"), device=model_manager.device) def process_tensors(self, input_tensor, scale=1.0, batch_size=4): output_tensor = [] diff --git a/diffsynth/models/cog_dit.py b/diffsynth/models/cog_dit.py new file mode 100644 index 0000000..c6d0cc6 --- /dev/null +++ b/diffsynth/models/cog_dit.py @@ -0,0 +1,395 @@ +import torch +from einops import rearrange, repeat +from .sd3_dit import TimestepEmbeddings +from .attention import Attention +from .utils import load_state_dict_from_folder +from .tiler import TileWorker2Dto3D +import numpy as np + + + +class CogPatchify(torch.nn.Module): + def __init__(self, dim_in, dim_out, patch_size) -> None: + super().__init__() + self.proj = torch.nn.Conv3d(dim_in, dim_out, kernel_size=(1, patch_size, patch_size), stride=(1, patch_size, patch_size)) + + def forward(self, hidden_states): + hidden_states = self.proj(hidden_states) + 
hidden_states = rearrange(hidden_states, "B C T H W -> B (T H W) C") + return hidden_states + + + +class CogAdaLayerNorm(torch.nn.Module): + def __init__(self, dim, dim_cond, single=False): + super().__init__() + self.single = single + self.linear = torch.nn.Linear(dim_cond, dim * (2 if single else 6)) + self.norm = torch.nn.LayerNorm(dim, elementwise_affine=True, eps=1e-5) + + + def forward(self, hidden_states, prompt_emb, emb): + emb = self.linear(torch.nn.functional.silu(emb)) + if self.single: + shift, scale = emb.unsqueeze(1).chunk(2, dim=2) + hidden_states = self.norm(hidden_states) * (1 + scale) + shift + return hidden_states + else: + shift_a, scale_a, gate_a, shift_b, scale_b, gate_b = emb.unsqueeze(1).chunk(6, dim=2) + hidden_states = self.norm(hidden_states) * (1 + scale_a) + shift_a + prompt_emb = self.norm(prompt_emb) * (1 + scale_b) + shift_b + return hidden_states, prompt_emb, gate_a, gate_b + + + +class CogDiTBlock(torch.nn.Module): + def __init__(self, dim, dim_cond, num_heads): + super().__init__() + self.norm1 = CogAdaLayerNorm(dim, dim_cond) + self.attn1 = Attention(q_dim=dim, num_heads=num_heads, head_dim=dim//num_heads, bias_q=True, bias_kv=True, bias_out=True) + self.norm_q = torch.nn.LayerNorm((dim//num_heads,), eps=1e-06, elementwise_affine=True) + self.norm_k = torch.nn.LayerNorm((dim//num_heads,), eps=1e-06, elementwise_affine=True) + + self.norm2 = CogAdaLayerNorm(dim, dim_cond) + self.ff = torch.nn.Sequential( + torch.nn.Linear(dim, dim*4), + torch.nn.GELU(approximate="tanh"), + torch.nn.Linear(dim*4, dim) + ) + + + def apply_rotary_emb(self, x, freqs_cis): + cos, sin = freqs_cis # [S, D] + cos = cos[None, None] + sin = sin[None, None] + cos, sin = cos.to(x.device), sin.to(x.device) + x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2] + x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3) + out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype) + return out + + + def process_qkv(self, q, k, v, 
image_rotary_emb, text_seq_length): + q = self.norm_q(q) + k = self.norm_k(k) + q[:, :, text_seq_length:] = self.apply_rotary_emb(q[:, :, text_seq_length:], image_rotary_emb) + k[:, :, text_seq_length:] = self.apply_rotary_emb(k[:, :, text_seq_length:], image_rotary_emb) + return q, k, v + + + def forward(self, hidden_states, prompt_emb, time_emb, image_rotary_emb): + # Attention + norm_hidden_states, norm_encoder_hidden_states, gate_a, gate_b = self.norm1( + hidden_states, prompt_emb, time_emb + ) + attention_io = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1) + attention_io = self.attn1( + attention_io, + qkv_preprocessor=lambda q, k, v: self.process_qkv(q, k, v, image_rotary_emb, prompt_emb.shape[1]) + ) + + hidden_states = hidden_states + gate_a * attention_io[:, prompt_emb.shape[1]:] + prompt_emb = prompt_emb + gate_b * attention_io[:, :prompt_emb.shape[1]] + + # Feed forward + norm_hidden_states, norm_encoder_hidden_states, gate_a, gate_b = self.norm2( + hidden_states, prompt_emb, time_emb + ) + ff_io = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1) + ff_io = self.ff(ff_io) + + hidden_states = hidden_states + gate_a * ff_io[:, prompt_emb.shape[1]:] + prompt_emb = prompt_emb + gate_b * ff_io[:, :prompt_emb.shape[1]] + + return hidden_states, prompt_emb + + + +class CogDiT(torch.nn.Module): + def __init__(self): + super().__init__() + self.patchify = CogPatchify(16, 3072, 2) + self.time_embedder = TimestepEmbeddings(3072, 512) + self.context_embedder = torch.nn.Linear(4096, 3072) + self.blocks = torch.nn.ModuleList([CogDiTBlock(3072, 512, 48) for _ in range(42)]) + self.norm_final = torch.nn.LayerNorm((3072,), eps=1e-05, elementwise_affine=True) + self.norm_out = CogAdaLayerNorm(3072, 512, single=True) + self.proj_out = torch.nn.Linear(3072, 64, bias=True) + + + def get_resize_crop_region_for_grid(self, src, tgt_width, tgt_height): + tw = tgt_width + th = tgt_height + h, w = src + r = h / w + if r > (th / tw): + 
resize_height = th + resize_width = int(round(th / h * w)) + else: + resize_width = tw + resize_height = int(round(tw / w * h)) + + crop_top = int(round((th - resize_height) / 2.0)) + crop_left = int(round((tw - resize_width) / 2.0)) + + return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width) + + + def get_3d_rotary_pos_embed( + self, embed_dim, crops_coords, grid_size, temporal_size, theta: int = 10000, use_real: bool = True + ): + start, stop = crops_coords + grid_h = np.linspace(start[0], stop[0], grid_size[0], endpoint=False, dtype=np.float32) + grid_w = np.linspace(start[1], stop[1], grid_size[1], endpoint=False, dtype=np.float32) + grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32) + + # Compute dimensions for each axis + dim_t = embed_dim // 4 + dim_h = embed_dim // 8 * 3 + dim_w = embed_dim // 8 * 3 + + # Temporal frequencies + freqs_t = 1.0 / (theta ** (torch.arange(0, dim_t, 2).float() / dim_t)) + grid_t = torch.from_numpy(grid_t).float() + freqs_t = torch.einsum("n , f -> n f", grid_t, freqs_t) + freqs_t = freqs_t.repeat_interleave(2, dim=-1) + + # Spatial frequencies for height and width + freqs_h = 1.0 / (theta ** (torch.arange(0, dim_h, 2).float() / dim_h)) + freqs_w = 1.0 / (theta ** (torch.arange(0, dim_w, 2).float() / dim_w)) + grid_h = torch.from_numpy(grid_h).float() + grid_w = torch.from_numpy(grid_w).float() + freqs_h = torch.einsum("n , f -> n f", grid_h, freqs_h) + freqs_w = torch.einsum("n , f -> n f", grid_w, freqs_w) + freqs_h = freqs_h.repeat_interleave(2, dim=-1) + freqs_w = freqs_w.repeat_interleave(2, dim=-1) + + # Broadcast and concatenate tensors along specified dimension + def broadcast(tensors, dim=-1): + num_tensors = len(tensors) + shape_lens = {len(t.shape) for t in tensors} + assert len(shape_lens) == 1, "tensors must all have the same number of dimensions" + shape_len = list(shape_lens)[0] + dim = (dim + shape_len) if dim < 0 else dim + dims = 
list(zip(*(list(t.shape) for t in tensors))) + expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim] + assert all( + [*(len(set(t[1])) <= 2 for t in expandable_dims)] + ), "invalid dimensions for broadcastable concatenation" + max_dims = [(t[0], max(t[1])) for t in expandable_dims] + expanded_dims = [(t[0], (t[1],) * num_tensors) for t in max_dims] + expanded_dims.insert(dim, (dim, dims[dim])) + expandable_shapes = list(zip(*(t[1] for t in expanded_dims))) + tensors = [t[0].expand(*t[1]) for t in zip(tensors, expandable_shapes)] + return torch.cat(tensors, dim=dim) + + freqs = broadcast((freqs_t[:, None, None, :], freqs_h[None, :, None, :], freqs_w[None, None, :, :]), dim=-1) + + t, h, w, d = freqs.shape + freqs = freqs.view(t * h * w, d) + + # Generate sine and cosine components + sin = freqs.sin() + cos = freqs.cos() + + if use_real: + return cos, sin + else: + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) + return freqs_cis + + + def prepare_rotary_positional_embeddings( + self, + height: int, + width: int, + num_frames: int, + device: torch.device, + ): + grid_height = height // 2 + grid_width = width // 2 + base_size_width = 720 // (8 * 2) + base_size_height = 480 // (8 * 2) + + grid_crops_coords = self.get_resize_crop_region_for_grid( + (grid_height, grid_width), base_size_width, base_size_height + ) + freqs_cos, freqs_sin = self.get_3d_rotary_pos_embed( + embed_dim=64, + crops_coords=grid_crops_coords, + grid_size=(grid_height, grid_width), + temporal_size=num_frames, + use_real=True, + ) + + freqs_cos = freqs_cos.to(device=device) + freqs_sin = freqs_sin.to(device=device) + return freqs_cos, freqs_sin + + + def unpatchify(self, hidden_states, height, width): + hidden_states = rearrange(hidden_states, "B (T H W) (C P Q) -> B C T (H P) (W Q)", P=2, Q=2, H=height//2, W=width//2) + return hidden_states + + + def build_mask(self, T, H, W, dtype, device, is_bound): + t = repeat(torch.arange(T), "T -> T H W", T=T, H=H, W=W) + h = 
repeat(torch.arange(H), "H -> T H W", T=T, H=H, W=W) + w = repeat(torch.arange(W), "W -> T H W", T=T, H=H, W=W) + border_width = (H + W) // 4 + pad = torch.ones_like(h) * border_width + mask = torch.stack([ + pad if is_bound[0] else t + 1, + pad if is_bound[1] else T - t, + pad if is_bound[2] else h + 1, + pad if is_bound[3] else H - h, + pad if is_bound[4] else w + 1, + pad if is_bound[5] else W - w + ]).min(dim=0).values + mask = mask.clip(1, border_width) + mask = (mask / border_width).to(dtype=dtype, device=device) + mask = rearrange(mask, "T H W -> 1 1 T H W") + return mask + + + def tiled_forward(self, hidden_states, timestep, prompt_emb, tile_size=90, tile_stride=30): + B, C, T, H, W = hidden_states.shape + value = torch.zeros((B, C, T, H, W), dtype=hidden_states.dtype, device=hidden_states.device) + weight = torch.zeros((B, C, T, H, W), dtype=hidden_states.dtype, device=hidden_states.device) + + # Split tasks + tasks = [] + for h in range(0, H, tile_stride): + for w in range(0, W, tile_stride): + if (h-tile_stride >= 0 and h-tile_stride+tile_size >= H) or (w-tile_stride >= 0 and w-tile_stride+tile_size >= W): + continue + h_, w_ = h + tile_size, w + tile_size + if h_ > H: h, h_ = max(H - tile_size, 0), H + if w_ > W: w, w_ = max(W - tile_size, 0), W + tasks.append((h, h_, w, w_)) + + # Run + for hl, hr, wl, wr in tasks: + mask = self.build_mask( + value.shape[2], (hr-hl), (wr-wl), + hidden_states.dtype, hidden_states.device, + is_bound=(True, True, hl==0, hr>=H, wl==0, wr>=W) + ) + model_output = self.forward(hidden_states[:, :, :, hl:hr, wl:wr], timestep, prompt_emb) + value[:, :, :, hl:hr, wl:wr] += model_output * mask + weight[:, :, :, hl:hr, wl:wr] += mask + value = value / weight + + return value + + + def forward(self, hidden_states, timestep, prompt_emb, image_rotary_emb=None, tiled=False, tile_size=90, tile_stride=30): + if tiled: + return TileWorker2Dto3D().tiled_forward( + forward_fn=lambda x: self.forward(x, timestep, prompt_emb), +
model_input=hidden_states, + tile_size=tile_size, tile_stride=tile_stride, + tile_device=hidden_states.device, tile_dtype=hidden_states.dtype, + computation_device=self.context_embedder.weight.device, computation_dtype=self.context_embedder.weight.dtype + ) + num_frames, height, width = hidden_states.shape[-3:] + if image_rotary_emb is None: + image_rotary_emb = self.prepare_rotary_positional_embeddings(height, width, num_frames, device=self.context_embedder.weight.device) + hidden_states = self.patchify(hidden_states) + time_emb = self.time_embedder(timestep, dtype=hidden_states.dtype) + prompt_emb = self.context_embedder(prompt_emb) + for block in self.blocks: + hidden_states, prompt_emb = block(hidden_states, prompt_emb, time_emb, image_rotary_emb) + + hidden_states = torch.cat([prompt_emb, hidden_states], dim=1) + hidden_states = self.norm_final(hidden_states) + hidden_states = hidden_states[:, prompt_emb.shape[1]:] + hidden_states = self.norm_out(hidden_states, prompt_emb, time_emb) + hidden_states = self.proj_out(hidden_states) + hidden_states = self.unpatchify(hidden_states, height, width) + + return hidden_states + + + @staticmethod + def state_dict_converter(): + return CogDiTStateDictConverter() + + + @staticmethod + def from_pretrained(file_path, torch_dtype=torch.bfloat16): + model = CogDiT().to(torch_dtype) + state_dict = load_state_dict_from_folder(file_path, torch_dtype=torch_dtype) + state_dict = CogDiT.state_dict_converter().from_diffusers(state_dict) + model.load_state_dict(state_dict) + return model + + + +class CogDiTStateDictConverter: + def __init__(self): + pass + + + def from_diffusers(self, state_dict): + rename_dict = { + "patch_embed.proj.weight": "patchify.proj.weight", + "patch_embed.proj.bias": "patchify.proj.bias", + "patch_embed.text_proj.weight": "context_embedder.weight", + "patch_embed.text_proj.bias": "context_embedder.bias", + "time_embedding.linear_1.weight": "time_embedder.timestep_embedder.0.weight", + 
"time_embedding.linear_1.bias": "time_embedder.timestep_embedder.0.bias", + "time_embedding.linear_2.weight": "time_embedder.timestep_embedder.2.weight", + "time_embedding.linear_2.bias": "time_embedder.timestep_embedder.2.bias", + + "norm_final.weight": "norm_final.weight", + "norm_final.bias": "norm_final.bias", + "norm_out.linear.weight": "norm_out.linear.weight", + "norm_out.linear.bias": "norm_out.linear.bias", + "norm_out.norm.weight": "norm_out.norm.weight", + "norm_out.norm.bias": "norm_out.norm.bias", + "proj_out.weight": "proj_out.weight", + "proj_out.bias": "proj_out.bias", + } + suffix_dict = { + "norm1.linear.weight": "norm1.linear.weight", + "norm1.linear.bias": "norm1.linear.bias", + "norm1.norm.weight": "norm1.norm.weight", + "norm1.norm.bias": "norm1.norm.bias", + "attn1.norm_q.weight": "norm_q.weight", + "attn1.norm_q.bias": "norm_q.bias", + "attn1.norm_k.weight": "norm_k.weight", + "attn1.norm_k.bias": "norm_k.bias", + "attn1.to_q.weight": "attn1.to_q.weight", + "attn1.to_q.bias": "attn1.to_q.bias", + "attn1.to_k.weight": "attn1.to_k.weight", + "attn1.to_k.bias": "attn1.to_k.bias", + "attn1.to_v.weight": "attn1.to_v.weight", + "attn1.to_v.bias": "attn1.to_v.bias", + "attn1.to_out.0.weight": "attn1.to_out.weight", + "attn1.to_out.0.bias": "attn1.to_out.bias", + "norm2.linear.weight": "norm2.linear.weight", + "norm2.linear.bias": "norm2.linear.bias", + "norm2.norm.weight": "norm2.norm.weight", + "norm2.norm.bias": "norm2.norm.bias", + "ff.net.0.proj.weight": "ff.0.weight", + "ff.net.0.proj.bias": "ff.0.bias", + "ff.net.2.weight": "ff.2.weight", + "ff.net.2.bias": "ff.2.bias", + } + state_dict_ = {} + for name, param in state_dict.items(): + if name in rename_dict: + if name == "patch_embed.proj.weight": + param = param.unsqueeze(2) + state_dict_[rename_dict[name]] = param + else: + names = name.split(".") + if names[0] == "transformer_blocks": + suffix = ".".join(names[2:]) + state_dict_[f"blocks.{names[1]}." 
+ suffix_dict[suffix]] = param + return state_dict_ + + + def from_civitai(self, state_dict): + return self.from_diffusers(state_dict) diff --git a/diffsynth/models/cog_vae.py b/diffsynth/models/cog_vae.py new file mode 100644 index 0000000..24ab3b3 --- /dev/null +++ b/diffsynth/models/cog_vae.py @@ -0,0 +1,518 @@ +import torch +from einops import rearrange, repeat +from .tiler import TileWorker2Dto3D + + + +class Downsample3D(torch.nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int = 3, + stride: int = 2, + padding: int = 0, + compress_time: bool = False, + ): + super().__init__() + + self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding) + self.compress_time = compress_time + + def forward(self, x: torch.Tensor, xq: torch.Tensor) -> torch.Tensor: + if self.compress_time: + batch_size, channels, frames, height, width = x.shape + + # (batch_size, channels, frames, height, width) -> (batch_size, height, width, channels, frames) -> (batch_size * height * width, channels, frames) + x = x.permute(0, 3, 4, 1, 2).reshape(batch_size * height * width, channels, frames) + + if x.shape[-1] % 2 == 1: + x_first, x_rest = x[..., 0], x[..., 1:] + if x_rest.shape[-1] > 0: + # (batch_size * height * width, channels, frames - 1) -> (batch_size * height * width, channels, (frames - 1) // 2) + x_rest = torch.nn.functional.avg_pool1d(x_rest, kernel_size=2, stride=2) + + x = torch.cat([x_first[..., None], x_rest], dim=-1) + # (batch_size * height * width, channels, (frames // 2) + 1) -> (batch_size, height, width, channels, (frames // 2) + 1) -> (batch_size, channels, (frames // 2) + 1, height, width) + x = x.reshape(batch_size, height, width, channels, x.shape[-1]).permute(0, 3, 4, 1, 2) + else: + # (batch_size * height * width, channels, frames) -> (batch_size * height * width, channels, frames // 2) + x = torch.nn.functional.avg_pool1d(x, kernel_size=2, stride=2) + # (batch_size * 
height * width, channels, frames // 2) -> (batch_size, height, width, channels, frames // 2) -> (batch_size, channels, frames // 2, height, width) + x = x.reshape(batch_size, height, width, channels, x.shape[-1]).permute(0, 3, 4, 1, 2) + + # Pad the tensor + pad = (0, 1, 0, 1) + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + batch_size, channels, frames, height, width = x.shape + # (batch_size, channels, frames, height, width) -> (batch_size, frames, channels, height, width) -> (batch_size * frames, channels, height, width) + x = x.permute(0, 2, 1, 3, 4).reshape(batch_size * frames, channels, height, width) + x = self.conv(x) + # (batch_size * frames, channels, height, width) -> (batch_size, frames, channels, height, width) -> (batch_size, channels, frames, height, width) + x = x.reshape(batch_size, frames, x.shape[1], x.shape[2], x.shape[3]).permute(0, 2, 1, 3, 4) + return x + + + +class Upsample3D(torch.nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int = 3, + stride: int = 1, + padding: int = 1, + compress_time: bool = False, + ) -> None: + super().__init__() + self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding) + self.compress_time = compress_time + + def forward(self, inputs: torch.Tensor, xq: torch.Tensor) -> torch.Tensor: + if self.compress_time: + if inputs.shape[2] > 1 and inputs.shape[2] % 2 == 1: + # split first frame + x_first, x_rest = inputs[:, :, 0], inputs[:, :, 1:] + + x_first = torch.nn.functional.interpolate(x_first, scale_factor=2.0) + x_rest = torch.nn.functional.interpolate(x_rest, scale_factor=2.0) + x_first = x_first[:, :, None, :, :] + inputs = torch.cat([x_first, x_rest], dim=2) + elif inputs.shape[2] > 1: + inputs = torch.nn.functional.interpolate(inputs, scale_factor=2.0) + else: + inputs = inputs.squeeze(2) + inputs = torch.nn.functional.interpolate(inputs, scale_factor=2.0) + inputs = inputs[:, :, None, :, :] + 
else: + # only interpolate 2D + b, c, t, h, w = inputs.shape + inputs = inputs.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w) + inputs = torch.nn.functional.interpolate(inputs, scale_factor=2.0) + inputs = inputs.reshape(b, t, c, *inputs.shape[2:]).permute(0, 2, 1, 3, 4) + + b, c, t, h, w = inputs.shape + inputs = inputs.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w) + inputs = self.conv(inputs) + inputs = inputs.reshape(b, t, *inputs.shape[1:]).permute(0, 2, 1, 3, 4) + + return inputs + + + +class CogVideoXSpatialNorm3D(torch.nn.Module): + def __init__(self, f_channels, zq_channels, groups): + super().__init__() + self.norm_layer = torch.nn.GroupNorm(num_channels=f_channels, num_groups=groups, eps=1e-6, affine=True) + self.conv_y = torch.nn.Conv3d(zq_channels, f_channels, kernel_size=1, stride=1) + self.conv_b = torch.nn.Conv3d(zq_channels, f_channels, kernel_size=1, stride=1) + + + def forward(self, f: torch.Tensor, zq: torch.Tensor) -> torch.Tensor: + if f.shape[2] > 1 and f.shape[2] % 2 == 1: + f_first, f_rest = f[:, :, :1], f[:, :, 1:] + f_first_size, f_rest_size = f_first.shape[-3:], f_rest.shape[-3:] + z_first, z_rest = zq[:, :, :1], zq[:, :, 1:] + z_first = torch.nn.functional.interpolate(z_first, size=f_first_size) + z_rest = torch.nn.functional.interpolate(z_rest, size=f_rest_size) + zq = torch.cat([z_first, z_rest], dim=2) + else: + zq = torch.nn.functional.interpolate(zq, size=f.shape[-3:]) + + norm_f = self.norm_layer(f) + new_f = norm_f * self.conv_y(zq) + self.conv_b(zq) + return new_f + + + +class Resnet3DBlock(torch.nn.Module): + def __init__(self, in_channels, out_channels, spatial_norm_dim, groups, eps=1e-6, use_conv_shortcut=False): + super().__init__() + self.nonlinearity = torch.nn.SiLU() + if spatial_norm_dim is None: + self.norm1 = torch.nn.GroupNorm(num_channels=in_channels, num_groups=groups, eps=eps) + self.norm2 = torch.nn.GroupNorm(num_channels=out_channels, num_groups=groups, eps=eps) + else: + self.norm1 = 
CogVideoXSpatialNorm3D(in_channels, spatial_norm_dim, groups) + self.norm2 = CogVideoXSpatialNorm3D(out_channels, spatial_norm_dim, groups) + + self.conv1 = CachedConv3d(in_channels, out_channels, kernel_size=3, padding=(0, 1, 1)) + + self.conv2 = CachedConv3d(out_channels, out_channels, kernel_size=3, padding=(0, 1, 1)) + + if in_channels != out_channels: + if use_conv_shortcut: + self.conv_shortcut = CachedConv3d(in_channels, out_channels, kernel_size=3, padding=(0, 1, 1)) + else: + self.conv_shortcut = torch.nn.Conv3d(in_channels, out_channels, kernel_size=1) + else: + self.conv_shortcut = lambda x: x + + + def forward(self, hidden_states, zq): + residual = hidden_states + + hidden_states = self.norm1(hidden_states, zq) if isinstance(self.norm1, CogVideoXSpatialNorm3D) else self.norm1(hidden_states) + hidden_states = self.nonlinearity(hidden_states) + hidden_states = self.conv1(hidden_states) + + hidden_states = self.norm2(hidden_states, zq) if isinstance(self.norm2, CogVideoXSpatialNorm3D) else self.norm2(hidden_states) + hidden_states = self.nonlinearity(hidden_states) + hidden_states = self.conv2(hidden_states) + + hidden_states = hidden_states + self.conv_shortcut(residual) + + return hidden_states + + + +class CachedConv3d(torch.nn.Conv3d): + def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0): + super().__init__(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding) + self.cached_tensor = None + + + def clear_cache(self): + self.cached_tensor = None + + + def forward(self, input: torch.Tensor, use_cache = True) -> torch.Tensor: + if use_cache: + if self.cached_tensor is None: + self.cached_tensor = torch.concat([input[:, :, :1]] * 2, dim=2) + input = torch.concat([self.cached_tensor, input], dim=2) + self.cached_tensor = input[:, :, -2:] + return super().forward(input) + + + +class CogVAEDecoder(torch.nn.Module): + def __init__(self): + super().__init__() + self.scaling_factor = 0.7 + self.conv_in = 
CachedConv3d(16, 512, kernel_size=3, stride=1, padding=(0, 1, 1)) + + self.blocks = torch.nn.ModuleList([ + Resnet3DBlock(512, 512, 16, 32), + Resnet3DBlock(512, 512, 16, 32), + Resnet3DBlock(512, 512, 16, 32), + Resnet3DBlock(512, 512, 16, 32), + Resnet3DBlock(512, 512, 16, 32), + Resnet3DBlock(512, 512, 16, 32), + Upsample3D(512, 512, compress_time=True), + Resnet3DBlock(512, 256, 16, 32), + Resnet3DBlock(256, 256, 16, 32), + Resnet3DBlock(256, 256, 16, 32), + Resnet3DBlock(256, 256, 16, 32), + Upsample3D(256, 256, compress_time=True), + Resnet3DBlock(256, 256, 16, 32), + Resnet3DBlock(256, 256, 16, 32), + Resnet3DBlock(256, 256, 16, 32), + Resnet3DBlock(256, 256, 16, 32), + Upsample3D(256, 256, compress_time=False), + Resnet3DBlock(256, 128, 16, 32), + Resnet3DBlock(128, 128, 16, 32), + Resnet3DBlock(128, 128, 16, 32), + Resnet3DBlock(128, 128, 16, 32), + ]) + + self.norm_out = CogVideoXSpatialNorm3D(128, 16, 32) + self.conv_act = torch.nn.SiLU() + self.conv_out = CachedConv3d(128, 3, kernel_size=3, stride=1, padding=(0, 1, 1)) + + + def forward(self, sample): + sample = sample / self.scaling_factor + hidden_states = self.conv_in(sample) + + for block in self.blocks: + hidden_states = block(hidden_states, sample) + + hidden_states = self.norm_out(hidden_states, sample) + hidden_states = self.conv_act(hidden_states) + hidden_states = self.conv_out(hidden_states) + + return hidden_states + + + def decode_video(self, sample, tiled=True, tile_size=(60, 90), tile_stride=(30, 45), progress_bar=lambda x:x): + if tiled: + B, C, T, H, W = sample.shape + return TileWorker2Dto3D().tiled_forward( + forward_fn=lambda x: self.decode_small_video(x), + model_input=sample, + tile_size=tile_size, tile_stride=tile_stride, + tile_device=sample.device, tile_dtype=sample.dtype, + computation_device=sample.device, computation_dtype=sample.dtype, + scales=(3/16, (T//2*8+T%2)/T, 8, 8), + progress_bar=progress_bar + ) + else: + return self.decode_small_video(sample) + + + def 
decode_small_video(self, sample): + B, C, T, H, W = sample.shape + computation_device = self.conv_in.weight.device + computation_dtype = self.conv_in.weight.dtype + value = [] + for i in range(T//2): + tl = i*2 + T%2 - (T%2 and i==0) + tr = i*2 + 2 + T%2 + model_input = sample[:, :, tl: tr, :, :].to(dtype=computation_dtype, device=computation_device) + model_output = self.forward(model_input).to(dtype=sample.dtype, device=sample.device) + value.append(model_output) + value = torch.concat(value, dim=2) + for name, module in self.named_modules(): + if isinstance(module, CachedConv3d): + module.clear_cache() + return value + + + @staticmethod + def state_dict_converter(): + return CogVAEDecoderStateDictConverter() + + + +class CogVAEEncoder(torch.nn.Module): + def __init__(self): + super().__init__() + self.scaling_factor = 0.7 + self.conv_in = CachedConv3d(3, 128, kernel_size=3, stride=1, padding=(0, 1, 1)) + + self.blocks = torch.nn.ModuleList([ + Resnet3DBlock(128, 128, None, 32), + Resnet3DBlock(128, 128, None, 32), + Resnet3DBlock(128, 128, None, 32), + Downsample3D(128, 128, compress_time=True), + Resnet3DBlock(128, 256, None, 32), + Resnet3DBlock(256, 256, None, 32), + Resnet3DBlock(256, 256, None, 32), + Downsample3D(256, 256, compress_time=True), + Resnet3DBlock(256, 256, None, 32), + Resnet3DBlock(256, 256, None, 32), + Resnet3DBlock(256, 256, None, 32), + Downsample3D(256, 256, compress_time=False), + Resnet3DBlock(256, 512, None, 32), + Resnet3DBlock(512, 512, None, 32), + Resnet3DBlock(512, 512, None, 32), + Resnet3DBlock(512, 512, None, 32), + Resnet3DBlock(512, 512, None, 32), + ]) + + self.norm_out = torch.nn.GroupNorm(32, 512, eps=1e-06, affine=True) + self.conv_act = torch.nn.SiLU() + self.conv_out = CachedConv3d(512, 32, kernel_size=3, stride=1, padding=(0, 1, 1)) + + + def forward(self, sample): + hidden_states = self.conv_in(sample) + + for block in self.blocks: + hidden_states = block(hidden_states, sample) + + hidden_states = 
self.norm_out(hidden_states) + hidden_states = self.conv_act(hidden_states) + hidden_states = self.conv_out(hidden_states)[:, :16] + hidden_states = hidden_states * self.scaling_factor + + return hidden_states + + + def encode_video(self, sample, tiled=True, tile_size=(60, 90), tile_stride=(30, 45), progress_bar=lambda x:x): + if tiled: + B, C, T, H, W = sample.shape + return TileWorker2Dto3D().tiled_forward( + forward_fn=lambda x: self.encode_small_video(x), + model_input=sample, + tile_size=(i * 8 for i in tile_size), tile_stride=(i * 8 for i in tile_stride), + tile_device=sample.device, tile_dtype=sample.dtype, + computation_device=sample.device, computation_dtype=sample.dtype, + scales=(16/3, (T//4+T%2)/T, 1/8, 1/8), + progress_bar=progress_bar + ) + else: + return self.encode_small_video(sample) + + + def encode_small_video(self, sample): + B, C, T, H, W = sample.shape + computation_device = self.conv_in.weight.device + computation_dtype = self.conv_in.weight.dtype + value = [] + for i in range(T//8): + t = i*8 + T%2 - (T%2 and i==0) + t_ = i*8 + 8 + T%2 + model_input = sample[:, :, t: t_, :, :].to(dtype=computation_dtype, device=computation_device) + model_output = self.forward(model_input).to(dtype=sample.dtype, device=sample.device) + value.append(model_output) + value = torch.concat(value, dim=2) + for name, module in self.named_modules(): + if isinstance(module, CachedConv3d): + module.clear_cache() + return value + + + @staticmethod + def state_dict_converter(): + return CogVAEEncoderStateDictConverter() + + + +class CogVAEEncoderStateDictConverter: + def __init__(self): + pass + + + def from_diffusers(self, state_dict): + rename_dict = { + "encoder.conv_in.conv.weight": "conv_in.weight", + "encoder.conv_in.conv.bias": "conv_in.bias", + "encoder.down_blocks.0.downsamplers.0.conv.weight": "blocks.3.conv.weight", + "encoder.down_blocks.0.downsamplers.0.conv.bias": "blocks.3.conv.bias", + "encoder.down_blocks.1.downsamplers.0.conv.weight": 
"blocks.7.conv.weight", + "encoder.down_blocks.1.downsamplers.0.conv.bias": "blocks.7.conv.bias", + "encoder.down_blocks.2.downsamplers.0.conv.weight": "blocks.11.conv.weight", + "encoder.down_blocks.2.downsamplers.0.conv.bias": "blocks.11.conv.bias", + "encoder.norm_out.weight": "norm_out.weight", + "encoder.norm_out.bias": "norm_out.bias", + "encoder.conv_out.conv.weight": "conv_out.weight", + "encoder.conv_out.conv.bias": "conv_out.bias", + } + prefix_dict = { + "encoder.down_blocks.0.resnets.0.": "blocks.0.", + "encoder.down_blocks.0.resnets.1.": "blocks.1.", + "encoder.down_blocks.0.resnets.2.": "blocks.2.", + "encoder.down_blocks.1.resnets.0.": "blocks.4.", + "encoder.down_blocks.1.resnets.1.": "blocks.5.", + "encoder.down_blocks.1.resnets.2.": "blocks.6.", + "encoder.down_blocks.2.resnets.0.": "blocks.8.", + "encoder.down_blocks.2.resnets.1.": "blocks.9.", + "encoder.down_blocks.2.resnets.2.": "blocks.10.", + "encoder.down_blocks.3.resnets.0.": "blocks.12.", + "encoder.down_blocks.3.resnets.1.": "blocks.13.", + "encoder.down_blocks.3.resnets.2.": "blocks.14.", + "encoder.mid_block.resnets.0.": "blocks.15.", + "encoder.mid_block.resnets.1.": "blocks.16.", + } + suffix_dict = { + "norm1.norm_layer.weight": "norm1.norm_layer.weight", + "norm1.norm_layer.bias": "norm1.norm_layer.bias", + "norm1.conv_y.conv.weight": "norm1.conv_y.weight", + "norm1.conv_y.conv.bias": "norm1.conv_y.bias", + "norm1.conv_b.conv.weight": "norm1.conv_b.weight", + "norm1.conv_b.conv.bias": "norm1.conv_b.bias", + "norm2.norm_layer.weight": "norm2.norm_layer.weight", + "norm2.norm_layer.bias": "norm2.norm_layer.bias", + "norm2.conv_y.conv.weight": "norm2.conv_y.weight", + "norm2.conv_y.conv.bias": "norm2.conv_y.bias", + "norm2.conv_b.conv.weight": "norm2.conv_b.weight", + "norm2.conv_b.conv.bias": "norm2.conv_b.bias", + "conv1.conv.weight": "conv1.weight", + "conv1.conv.bias": "conv1.bias", + "conv2.conv.weight": "conv2.weight", + "conv2.conv.bias": "conv2.bias", + "conv_shortcut.weight": 
"conv_shortcut.weight", + "conv_shortcut.bias": "conv_shortcut.bias", + "norm1.weight": "norm1.weight", + "norm1.bias": "norm1.bias", + "norm2.weight": "norm2.weight", + "norm2.bias": "norm2.bias", + } + state_dict_ = {} + for name, param in state_dict.items(): + if name in rename_dict: + state_dict_[rename_dict[name]] = param + else: + for prefix in prefix_dict: + if name.startswith(prefix): + suffix = name[len(prefix):] + state_dict_[prefix_dict[prefix] + suffix_dict[suffix]] = param + return state_dict_ + + + def from_civitai(self, state_dict): + return self.from_diffusers(state_dict) + + + +class CogVAEDecoderStateDictConverter: + def __init__(self): + pass + + + def from_diffusers(self, state_dict): + rename_dict = { + "decoder.conv_in.conv.weight": "conv_in.weight", + "decoder.conv_in.conv.bias": "conv_in.bias", + "decoder.up_blocks.0.upsamplers.0.conv.weight": "blocks.6.conv.weight", + "decoder.up_blocks.0.upsamplers.0.conv.bias": "blocks.6.conv.bias", + "decoder.up_blocks.1.upsamplers.0.conv.weight": "blocks.11.conv.weight", + "decoder.up_blocks.1.upsamplers.0.conv.bias": "blocks.11.conv.bias", + "decoder.up_blocks.2.upsamplers.0.conv.weight": "blocks.16.conv.weight", + "decoder.up_blocks.2.upsamplers.0.conv.bias": "blocks.16.conv.bias", + "decoder.norm_out.norm_layer.weight": "norm_out.norm_layer.weight", + "decoder.norm_out.norm_layer.bias": "norm_out.norm_layer.bias", + "decoder.norm_out.conv_y.conv.weight": "norm_out.conv_y.weight", + "decoder.norm_out.conv_y.conv.bias": "norm_out.conv_y.bias", + "decoder.norm_out.conv_b.conv.weight": "norm_out.conv_b.weight", + "decoder.norm_out.conv_b.conv.bias": "norm_out.conv_b.bias", + "decoder.conv_out.conv.weight": "conv_out.weight", + "decoder.conv_out.conv.bias": "conv_out.bias" + } + prefix_dict = { + "decoder.mid_block.resnets.0.": "blocks.0.", + "decoder.mid_block.resnets.1.": "blocks.1.", + "decoder.up_blocks.0.resnets.0.": "blocks.2.", + "decoder.up_blocks.0.resnets.1.": "blocks.3.", + 
"decoder.up_blocks.0.resnets.2.": "blocks.4.", + "decoder.up_blocks.0.resnets.3.": "blocks.5.", + "decoder.up_blocks.1.resnets.0.": "blocks.7.", + "decoder.up_blocks.1.resnets.1.": "blocks.8.", + "decoder.up_blocks.1.resnets.2.": "blocks.9.", + "decoder.up_blocks.1.resnets.3.": "blocks.10.", + "decoder.up_blocks.2.resnets.0.": "blocks.12.", + "decoder.up_blocks.2.resnets.1.": "blocks.13.", + "decoder.up_blocks.2.resnets.2.": "blocks.14.", + "decoder.up_blocks.2.resnets.3.": "blocks.15.", + "decoder.up_blocks.3.resnets.0.": "blocks.17.", + "decoder.up_blocks.3.resnets.1.": "blocks.18.", + "decoder.up_blocks.3.resnets.2.": "blocks.19.", + "decoder.up_blocks.3.resnets.3.": "blocks.20.", + } + suffix_dict = { + "norm1.norm_layer.weight": "norm1.norm_layer.weight", + "norm1.norm_layer.bias": "norm1.norm_layer.bias", + "norm1.conv_y.conv.weight": "norm1.conv_y.weight", + "norm1.conv_y.conv.bias": "norm1.conv_y.bias", + "norm1.conv_b.conv.weight": "norm1.conv_b.weight", + "norm1.conv_b.conv.bias": "norm1.conv_b.bias", + "norm2.norm_layer.weight": "norm2.norm_layer.weight", + "norm2.norm_layer.bias": "norm2.norm_layer.bias", + "norm2.conv_y.conv.weight": "norm2.conv_y.weight", + "norm2.conv_y.conv.bias": "norm2.conv_y.bias", + "norm2.conv_b.conv.weight": "norm2.conv_b.weight", + "norm2.conv_b.conv.bias": "norm2.conv_b.bias", + "conv1.conv.weight": "conv1.weight", + "conv1.conv.bias": "conv1.bias", + "conv2.conv.weight": "conv2.weight", + "conv2.conv.bias": "conv2.bias", + "conv_shortcut.weight": "conv_shortcut.weight", + "conv_shortcut.bias": "conv_shortcut.bias", + } + state_dict_ = {} + for name, param in state_dict.items(): + if name in rename_dict: + state_dict_[rename_dict[name]] = param + else: + for prefix in prefix_dict: + if name.startswith(prefix): + suffix = name[len(prefix):] + state_dict_[prefix_dict[prefix] + suffix_dict[suffix]] = param + return state_dict_ + + + def from_civitai(self, state_dict): + return self.from_diffusers(state_dict) + diff --git 
a/diffsynth/models/flux_dit.py b/diffsynth/models/flux_dit.py index 6cca984..eb1cde1 100644 --- a/diffsynth/models/flux_dit.py +++ b/diffsynth/models/flux_dit.py @@ -464,9 +464,9 @@ class FluxDiTStateDictConverter: name_ = ".".join(names[:2] + [rename_dict_single[middle]] + [suffix[1:]]) state_dict_[name_] = param else: - print(name) + pass else: - print(name) + pass for name in list(state_dict_.keys()): if ".proj_in_besides_attn." in name: name_ = name.replace(".proj_in_besides_attn.", ".linear.") @@ -570,6 +570,6 @@ class FluxDiTStateDictConverter: rename = f"single_blocks.{names[1]}." + suffix_rename_dict[".".join(names[2:])] state_dict_[rename] = param else: - print(name) + pass return state_dict_ \ No newline at end of file diff --git a/diffsynth/models/lora.py b/diffsynth/models/lora.py index 0419364..8f0ed70 100644 --- a/diffsynth/models/lora.py +++ b/diffsynth/models/lora.py @@ -4,6 +4,7 @@ from .sdxl_unet import SDXLUNet from .sd_text_encoder import SDTextEncoder from .sdxl_text_encoder import SDXLTextEncoder, SDXLTextEncoder2 from .sd3_dit import SD3DiT +from .flux_dit import FluxDiT from .hunyuan_dit import HunyuanDiT @@ -17,6 +18,13 @@ class LoRAFromCivitai: def convert_state_dict(self, state_dict, lora_prefix="lora_unet_", alpha=1.0): + for key in state_dict: + if ".lora_up" in key: + return self.convert_state_dict_up_down(state_dict, lora_prefix, alpha) + return self.convert_state_dict_AB(state_dict, lora_prefix, alpha) + + + def convert_state_dict_up_down(self, state_dict, lora_prefix="lora_unet_", alpha=1.0): renamed_lora_prefix = self.renamed_lora_prefix.get(lora_prefix, "") state_dict_ = {} for key in state_dict: @@ -39,6 +47,29 @@ class LoRAFromCivitai: return state_dict_ + def convert_state_dict_AB(self, state_dict, lora_prefix="", alpha=1.0, device="cuda", torch_dtype=torch.float16): + state_dict_ = {} + for key in state_dict: + if ".lora_B." 
not in key: + continue + if not key.startswith(lora_prefix): + continue + weight_up = state_dict[key].to(device=device, dtype=torch_dtype) + weight_down = state_dict[key.replace(".lora_B.", ".lora_A.")].to(device=device, dtype=torch_dtype) + if len(weight_up.shape) == 4: + weight_up = weight_up.squeeze(3).squeeze(2) + weight_down = weight_down.squeeze(3).squeeze(2) + lora_weight = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3) + else: + lora_weight = alpha * torch.mm(weight_up, weight_down) + keys = key.split(".") + keys.pop(keys.index("lora_B")) + target_name = ".".join(keys) + target_name = target_name[len(lora_prefix):] + state_dict_[target_name] = lora_weight.cpu() + return state_dict_ + + def load(self, model, state_dict_lora, lora_prefix, alpha=1.0, model_resource=None): state_dict_model = model.state_dict() state_dict_lora = self.convert_state_dict(state_dict_lora, lora_prefix=lora_prefix, alpha=alpha) @@ -134,6 +165,23 @@ class SDXLLoRAFromCivitai(LoRAFromCivitai): } +class FluxLoRAFromCivitai(LoRAFromCivitai): + def __init__(self): + super().__init__() + self.supported_model_classes = [FluxDiT, FluxDiT] + self.lora_prefix = ["lora_unet_", "transformer."] + self.renamed_lora_prefix = {} + self.special_keys = { + "single.blocks": "single_blocks", + "double.blocks": "double_blocks", + "img.attn": "img_attn", + "img.mlp": "img_mlp", + "img.mod": "img_mod", + "txt.attn": "txt_attn", + "txt.mlp": "txt_mlp", + "txt.mod": "txt_mod", + } + class GeneralLoRAFromPeft: def __init__(self): @@ -192,4 +240,8 @@ class GeneralLoRAFromPeft: return "", "" except: pass - return None \ No newline at end of file + return None + + +def get_lora_loaders(): + return [SDLoRAFromCivitai(), SDXLLoRAFromCivitai(), GeneralLoRAFromPeft(), FluxLoRAFromCivitai()] diff --git a/diffsynth/models/model_manager.py b/diffsynth/models/model_manager.py index 574d7ed..150565d 100644 --- a/diffsynth/models/model_manager.py +++ b/diffsynth/models/model_manager.py @@ -10,7 +10,7 
@@ from .sd_text_encoder import SDTextEncoder from .sd_unet import SDUNet from .sd_vae_encoder import SDVAEEncoder from .sd_vae_decoder import SDVAEDecoder -from .lora import SDLoRAFromCivitai, SDXLLoRAFromCivitai, GeneralLoRAFromPeft +from .lora import get_lora_loaders from .sdxl_text_encoder import SDXLTextEncoder, SDXLTextEncoder2 from .sdxl_unet import SDXLUNet @@ -43,93 +43,17 @@ from .flux_dit import FluxDiT from .flux_text_encoder import FluxTextEncoder1, FluxTextEncoder2 from .flux_vae import FluxVAEEncoder, FluxVAEDecoder +from .cog_vae import CogVAEEncoder, CogVAEDecoder +from .cog_dit import CogDiT + +from ..extensions.RIFE import IFNet +from ..extensions.ESRGAN import RRDBNet + from ..configs.model_config import model_loader_configs, huggingface_model_loader_configs, patch_model_loader_configs +from .utils import load_state_dict -def load_state_dict(file_path, torch_dtype=None): - if file_path.endswith(".safetensors"): - return load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype) - else: - return load_state_dict_from_bin(file_path, torch_dtype=torch_dtype) - - -def load_state_dict_from_safetensors(file_path, torch_dtype=None): - state_dict = {} - with safe_open(file_path, framework="pt", device="cpu") as f: - for k in f.keys(): - state_dict[k] = f.get_tensor(k) - if torch_dtype is not None: - state_dict[k] = state_dict[k].to(torch_dtype) - return state_dict - - -def load_state_dict_from_bin(file_path, torch_dtype=None): - state_dict = torch.load(file_path, map_location="cpu") - if torch_dtype is not None: - for i in state_dict: - if isinstance(state_dict[i], torch.Tensor): - state_dict[i] = state_dict[i].to(torch_dtype) - return state_dict - - -def search_for_embeddings(state_dict): - embeddings = [] - for k in state_dict: - if isinstance(state_dict[k], torch.Tensor): - embeddings.append(state_dict[k]) - elif isinstance(state_dict[k], dict): - embeddings += search_for_embeddings(state_dict[k]) - return embeddings - - -def 
search_parameter(param, state_dict): - for name, param_ in state_dict.items(): - if param.numel() == param_.numel(): - if param.shape == param_.shape: - if torch.dist(param, param_) < 1e-3: - return name - else: - if torch.dist(param.flatten(), param_.flatten()) < 1e-3: - return name - return None - - -def build_rename_dict(source_state_dict, target_state_dict, split_qkv=False): - matched_keys = set() - with torch.no_grad(): - for name in source_state_dict: - rename = search_parameter(source_state_dict[name], target_state_dict) - if rename is not None: - print(f'"{name}": "{rename}",') - matched_keys.add(rename) - elif split_qkv and len(source_state_dict[name].shape)>=1 and source_state_dict[name].shape[0]%3==0: - length = source_state_dict[name].shape[0] // 3 - rename = [] - for i in range(3): - rename.append(search_parameter(source_state_dict[name][i*length: i*length+length], target_state_dict)) - if None not in rename: - print(f'"{name}": {rename},') - for rename_ in rename: - matched_keys.add(rename_) - for name in target_state_dict: - if name not in matched_keys: - print("Cannot find", name, target_state_dict[name].shape) - - -def search_for_files(folder, extensions): - files = [] - if os.path.isdir(folder): - for file in sorted(os.listdir(folder)): - files += search_for_files(os.path.join(folder, file), extensions) - elif os.path.isfile(folder): - for extension in extensions: - if folder.endswith(extension): - files.append(folder) - break - return files - - def convert_state_dict_keys_to_single_str(state_dict, with_shape=True): keys = [] for key, value in state_dict.items(): @@ -195,7 +119,10 @@ def load_model_from_huggingface_folder(file_path, model_names, model_classes, to model = model_class.from_pretrained(file_path, torch_dtype=torch_dtype).eval() if torch_dtype == torch.float16 and hasattr(model, "half"): model = model.half() - model = model.to(device=device) + try: + model = model.to(device=device) + except: + pass loaded_model_names.append(model_name) 
loaded_models.append(model) return loaded_model_names, loaded_models @@ -356,7 +283,7 @@ class ModelDetectorFromHuggingfaceFolder: return False with open(os.path.join(file_path, "config.json"), "r") as f: config = json.load(f) - if "architectures" not in config: + if "architectures" not in config and "_class_name" not in config: return False return True @@ -365,7 +292,8 @@ class ModelDetectorFromHuggingfaceFolder: with open(os.path.join(file_path, "config.json"), "r") as f: config = json.load(f) loaded_model_names, loaded_models = [], [] - for architecture in config["architectures"]: + architectures = config["architectures"] if "architectures" in config else [config["_class_name"]] + for architecture in architectures: huggingface_lib, model_name, redirected_architecture = self.architecture_dict[architecture] if redirected_architecture is not None: architecture = redirected_architecture @@ -478,7 +406,7 @@ class ModelManager: if len(state_dict) == 0: state_dict = load_state_dict(file_path) for model_name, model, model_path in zip(self.model_name, self.model, self.model_path): - for lora in [SDLoRAFromCivitai(), SDXLLoRAFromCivitai(), GeneralLoRAFromPeft()]: + for lora in get_lora_loaders(): match_results = lora.match(model, state_dict) if match_results is not None: print(f" Adding LoRA to {model_name} ({model_path}).") diff --git a/diffsynth/models/tiler.py b/diffsynth/models/tiler.py index af37ff6..6f36cdf 100644 --- a/diffsynth/models/tiler.py +++ b/diffsynth/models/tiler.py @@ -103,4 +103,78 @@ class TileWorker: # Done! model_output = model_output.to(device=inference_device, dtype=inference_dtype) - return model_output \ No newline at end of file + return model_output + + + +class TileWorker2Dto3D: + """ + Process 3D tensors, but only enable TileWorker on 2D. 
+ """ + def __init__(self): + pass + + + def build_mask(self, T, H, W, dtype, device, is_bound, border_width): + t = repeat(torch.arange(T), "T -> T H W", T=T, H=H, W=W) + h = repeat(torch.arange(H), "H -> T H W", T=T, H=H, W=W) + w = repeat(torch.arange(W), "W -> T H W", T=T, H=H, W=W) + border_width = (H + W) // 4 if border_width is None else border_width + pad = torch.ones_like(h) * border_width + mask = torch.stack([ + pad if is_bound[0] else t + 1, + pad if is_bound[1] else T - t, + pad if is_bound[2] else h + 1, + pad if is_bound[3] else H - h, + pad if is_bound[4] else w + 1, + pad if is_bound[5] else W - w + ]).min(dim=0).values + mask = mask.clip(1, border_width) + mask = (mask / border_width).to(dtype=dtype, device=device) + mask = rearrange(mask, "T H W -> 1 1 T H W") + return mask + + + def tiled_forward( + self, + forward_fn, + model_input, + tile_size, tile_stride, + tile_device="cpu", tile_dtype=torch.float32, + computation_device="cuda", computation_dtype=torch.float32, + border_width=None, scales=[1, 1, 1, 1], + progress_bar=lambda x:x + ): + B, C, T, H, W = model_input.shape + scale_C, scale_T, scale_H, scale_W = scales + tile_size_H, tile_size_W = tile_size + tile_stride_H, tile_stride_W = tile_stride + + value = torch.zeros((B, int(C*scale_C), int(T*scale_T), int(H*scale_H), int(W*scale_W)), dtype=tile_dtype, device=tile_device) + weight = torch.zeros((1, 1, int(T*scale_T), int(H*scale_H), int(W*scale_W)), dtype=tile_dtype, device=tile_device) + + # Split tasks + tasks = [] + for h in range(0, H, tile_stride_H): + for w in range(0, W, tile_stride_W): + if (h-tile_stride_H >= 0 and h-tile_stride_H+tile_size_H >= H) or (w-tile_stride_W >= 0 and w-tile_stride_W+tile_size_W >= W): + continue + h_, w_ = h + tile_size_H, w + tile_size_W + if h_ > H: h, h_ = max(H - tile_size_H, 0), H + if w_ > W: w, w_ = max(W - tile_size_W, 0), W + tasks.append((h, h_, w, w_)) + + # Run + for hl, hr, wl, wr in progress_bar(tasks): + mask = self.build_mask( + 
int(T*scale_T), int((hr-hl)*scale_H), int((wr-wl)*scale_W), + tile_dtype, tile_device, + is_bound=(True, True, hl==0, hr>=H, wl==0, wr>=W), + border_width=border_width + ) + grid_input = model_input[:, :, :, hl:hr, wl:wr].to(dtype=computation_dtype, device=computation_device) + grid_output = forward_fn(grid_input).to(dtype=tile_dtype, device=tile_device) + value[:, :, :, int(hl*scale_H):int(hr*scale_H), int(wl*scale_W):int(wr*scale_W)] += grid_output * mask + weight[:, :, :, int(hl*scale_H):int(hr*scale_H), int(wl*scale_W):int(wr*scale_W)] += mask + value = value / weight + return value \ No newline at end of file diff --git a/diffsynth/models/utils.py b/diffsynth/models/utils.py new file mode 100644 index 0000000..e05a6ea --- /dev/null +++ b/diffsynth/models/utils.py @@ -0,0 +1,96 @@ +import torch, os +from safetensors import safe_open + + + +def load_state_dict_from_folder(file_path, torch_dtype=None): + state_dict = {} + for file_name in os.listdir(file_path): + if "." in file_name and file_name.split(".")[-1] in [ + "safetensors", "bin", "ckpt", "pth", "pt" + ]: + state_dict.update(load_state_dict(os.path.join(file_path, file_name), torch_dtype=torch_dtype)) + return state_dict + + +def load_state_dict(file_path, torch_dtype=None): + if file_path.endswith(".safetensors"): + return load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype) + else: + return load_state_dict_from_bin(file_path, torch_dtype=torch_dtype) + + +def load_state_dict_from_safetensors(file_path, torch_dtype=None): + state_dict = {} + with safe_open(file_path, framework="pt", device="cpu") as f: + for k in f.keys(): + state_dict[k] = f.get_tensor(k) + if torch_dtype is not None: + state_dict[k] = state_dict[k].to(torch_dtype) + return state_dict + + +def load_state_dict_from_bin(file_path, torch_dtype=None): + state_dict = torch.load(file_path, map_location="cpu") + if torch_dtype is not None: + for i in state_dict: + if isinstance(state_dict[i], torch.Tensor): + state_dict[i] = 
state_dict[i].to(torch_dtype) + return state_dict + + +def search_for_embeddings(state_dict): + embeddings = [] + for k in state_dict: + if isinstance(state_dict[k], torch.Tensor): + embeddings.append(state_dict[k]) + elif isinstance(state_dict[k], dict): + embeddings += search_for_embeddings(state_dict[k]) + return embeddings + + +def search_parameter(param, state_dict): + for name, param_ in state_dict.items(): + if param.numel() == param_.numel(): + if param.shape == param_.shape: + if torch.dist(param, param_) < 1e-3: + return name + else: + if torch.dist(param.flatten(), param_.flatten()) < 1e-3: + return name + return None + + +def build_rename_dict(source_state_dict, target_state_dict, split_qkv=False): + matched_keys = set() + with torch.no_grad(): + for name in source_state_dict: + rename = search_parameter(source_state_dict[name], target_state_dict) + if rename is not None: + print(f'"{name}": "{rename}",') + matched_keys.add(rename) + elif split_qkv and len(source_state_dict[name].shape)>=1 and source_state_dict[name].shape[0]%3==0: + length = source_state_dict[name].shape[0] // 3 + rename = [] + for i in range(3): + rename.append(search_parameter(source_state_dict[name][i*length: i*length+length], target_state_dict)) + if None not in rename: + print(f'"{name}": {rename},') + for rename_ in rename: + matched_keys.add(rename_) + for name in target_state_dict: + if name not in matched_keys: + print("Cannot find", name, target_state_dict[name].shape) + + +def search_for_files(folder, extensions): + files = [] + if os.path.isdir(folder): + for file in sorted(os.listdir(folder)): + files += search_for_files(os.path.join(folder, file), extensions) + elif os.path.isfile(folder): + for extension in extensions: + if folder.endswith(extension): + files.append(folder) + break + return files diff --git a/diffsynth/pipelines/__init__.py b/diffsynth/pipelines/__init__.py index 79c0393..79d9fc5 100644 --- a/diffsynth/pipelines/__init__.py +++ 
b/diffsynth/pipelines/__init__.py @@ -6,5 +6,6 @@ from .sd3_image import SD3ImagePipeline from .hunyuan_image import HunyuanDiTImagePipeline from .svd_video import SVDVideoPipeline from .flux_image import FluxImagePipeline +from .cog_video import CogVideoPipeline from .pipeline_runner import SDVideoPipelineRunner KolorsImagePipeline = SDXLImagePipeline diff --git a/diffsynth/pipelines/base.py b/diffsynth/pipelines/base.py index 78e66b5..2feb405 100644 --- a/diffsynth/pipelines/base.py +++ b/diffsynth/pipelines/base.py @@ -50,4 +50,13 @@ class BasePipeline(torch.nn.Module): noise_pred_locals = [inference_callback(prompt_emb_local) for prompt_emb_local in prompt_emb_locals] noise_pred = self.merge_latents(noise_pred_global, noise_pred_locals, masks, mask_scales) return noise_pred + + + def extend_prompt(self, prompt, local_prompts, masks, mask_scales): + extended_prompt_dict = self.prompter.extend_prompt(prompt) + prompt = extended_prompt_dict.get("prompt", prompt) + local_prompts += extended_prompt_dict.get("prompts", []) + masks += extended_prompt_dict.get("masks", []) + mask_scales += [5.0] * len(extended_prompt_dict.get("masks", [])) + return prompt, local_prompts, masks, mask_scales \ No newline at end of file diff --git a/diffsynth/pipelines/cog_video.py b/diffsynth/pipelines/cog_video.py new file mode 100644 index 0000000..777ce75 --- /dev/null +++ b/diffsynth/pipelines/cog_video.py @@ -0,0 +1,131 @@ +from ..models import ModelManager, FluxTextEncoder2, CogDiT, CogVAEEncoder, CogVAEDecoder +from ..prompters import CogPrompter +from ..schedulers import EnhancedDDIMScheduler +from .base import BasePipeline +import torch +from tqdm import tqdm +from PIL import Image +import numpy as np +from einops import rearrange + + + +class CogVideoPipeline(BasePipeline): + + def __init__(self, device="cuda", torch_dtype=torch.float16): + super().__init__(device=device, torch_dtype=torch_dtype) + self.scheduler = EnhancedDDIMScheduler(rescale_zero_terminal_snr=True, 
prediction_type="v_prediction") + self.prompter = CogPrompter() + # models + self.text_encoder: FluxTextEncoder2 = None + self.dit: CogDiT = None + self.vae_encoder: CogVAEEncoder = None + self.vae_decoder: CogVAEDecoder = None + + + def fetch_models(self, model_manager: ModelManager, prompt_refiner_classes=[]): + self.text_encoder = model_manager.fetch_model("flux_text_encoder_2") + self.dit = model_manager.fetch_model("cog_dit") + self.vae_encoder = model_manager.fetch_model("cog_vae_encoder") + self.vae_decoder = model_manager.fetch_model("cog_vae_decoder") + self.prompter.fetch_models(self.text_encoder) + self.prompter.load_prompt_refiners(model_manager, prompt_refiner_classes) + + + @staticmethod + def from_model_manager(model_manager: ModelManager, prompt_refiner_classes=[]): + pipe = CogVideoPipeline( + device=model_manager.device, + torch_dtype=model_manager.torch_dtype + ) + pipe.fetch_models(model_manager, prompt_refiner_classes) + return pipe + + + def tensor2video(self, frames): + frames = rearrange(frames, "C T H W -> T H W C") + frames = ((frames.float() + 1) * 127.5).clip(0, 255).cpu().numpy().astype(np.uint8) + frames = [Image.fromarray(frame) for frame in frames] + return frames + + + def encode_prompt(self, prompt, positive=True): + prompt_emb = self.prompter.encode_prompt(prompt, device=self.device, positive=positive) + return {"prompt_emb": prompt_emb} + + + def prepare_extra_input(self, latents): + return {"image_rotary_emb": self.dit.prepare_rotary_positional_embeddings(latents.shape[3], latents.shape[4], latents.shape[2], device=self.device)} + + + @torch.no_grad() + def __call__( + self, + prompt, + negative_prompt="", + input_video=None, + cfg_scale=7.0, + denoising_strength=1.0, + num_frames=49, + height=480, + width=720, + num_inference_steps=20, + tiled=False, + tile_size=(60, 90), + tile_stride=(30, 45), + progress_bar_cmd=tqdm, + progress_bar_st=None, + ): + # Tiler parameters + tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, 
"tile_stride": tile_stride} + + # Prepare scheduler + self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength) + + # Prepare latent tensors + noise = torch.randn((1, 16, num_frames // 4 + 1, height//8, width//8), device="cpu", dtype=self.torch_dtype) + if denoising_strength == 1.0: + latents = noise.clone() + else: + input_video = self.preprocess_images(input_video) + input_video = torch.stack(input_video, dim=2) + latents = self.vae_encoder.encode_video(input_video, **tiler_kwargs, progress_bar=progress_bar_cmd).to(dtype=self.torch_dtype) + latents = self.scheduler.add_noise(latents, noise, self.scheduler.timesteps[0]) + if not tiled: latents = latents.to(self.device) + + # Encode prompt + prompt_emb_posi = self.encode_prompt(prompt, positive=True) + if cfg_scale != 1.0: + prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False) + + # Extra input + extra_input = self.prepare_extra_input(latents) + + # Denoise + for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)): + timestep = timestep.unsqueeze(0).to(self.device) + + # Classifier-free guidance + noise_pred_posi = self.dit( + latents, timestep=timestep, **prompt_emb_posi, **tiler_kwargs, **extra_input + ) + if cfg_scale != 1.0: + noise_pred_nega = self.dit( + latents, timestep=timestep, **prompt_emb_nega, **tiler_kwargs, **extra_input + ) + noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega) + else: + noise_pred = noise_pred_posi + + # DDIM + latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents) + + # Update progress bar + if progress_bar_st is not None: + progress_bar_st.progress(progress_id / len(self.scheduler.timesteps)) + + # Decode image + video = self.vae_decoder.decode_video(latents.to("cpu"), **tiler_kwargs, progress_bar=progress_bar_cmd) + video = self.tensor2video(video[0]) + + return video diff --git a/diffsynth/pipelines/flux_image.py 
b/diffsynth/pipelines/flux_image.py index 74de285..8d6a246 100644 --- a/diffsynth/pipelines/flux_image.py +++ b/diffsynth/pipelines/flux_image.py @@ -25,7 +25,7 @@ class FluxImagePipeline(BasePipeline): return self.dit - def fetch_models(self, model_manager: ModelManager, prompt_refiner_classes=[]): + def fetch_models(self, model_manager: ModelManager, prompt_refiner_classes=[], prompt_extender_classes=[]): self.text_encoder_1 = model_manager.fetch_model("flux_text_encoder_1") self.text_encoder_2 = model_manager.fetch_model("flux_text_encoder_2") self.dit = model_manager.fetch_model("flux_dit") @@ -33,15 +33,16 @@ class FluxImagePipeline(BasePipeline): self.vae_encoder = model_manager.fetch_model("flux_vae_encoder") self.prompter.fetch_models(self.text_encoder_1, self.text_encoder_2) self.prompter.load_prompt_refiners(model_manager, prompt_refiner_classes) + self.prompter.load_prompt_extenders(model_manager, prompt_extender_classes) @staticmethod - def from_model_manager(model_manager: ModelManager, prompt_refiner_classes=[]): + def from_model_manager(model_manager: ModelManager, prompt_refiner_classes=[],prompt_extender_classes=[]): pipe = FluxImagePipeline( device=model_manager.device, torch_dtype=model_manager.torch_dtype, ) - pipe.fetch_models(model_manager, prompt_refiner_classes) + pipe.fetch_models(model_manager, prompt_refiner_classes,prompt_extender_classes) return pipe @@ -105,6 +106,9 @@ class FluxImagePipeline(BasePipeline): else: latents = torch.randn((1, 16, height//8, width//8), device=self.device, dtype=self.torch_dtype) + # Extend prompt + prompt, local_prompts, masks, mask_scales = self.extend_prompt(prompt, local_prompts, masks, mask_scales) + # Encode prompts prompt_emb_posi = self.encode_prompt(prompt, positive=True) if cfg_scale != 1.0: diff --git a/diffsynth/prompters/__init__.py b/diffsynth/prompters/__init__.py index 4c6a20a..6c7c7bf 100644 --- a/diffsynth/prompters/__init__.py +++ b/diffsynth/prompters/__init__.py @@ -5,3 +5,5 @@ from 
.sd3_prompter import SD3Prompter from .hunyuan_dit_prompter import HunyuanDiTPrompter from .kolors_prompter import KolorsPrompter from .flux_prompter import FluxPrompter +from .omost import OmostPromter +from .cog_prompter import CogPrompter diff --git a/diffsynth/prompters/base_prompter.py b/diffsynth/prompters/base_prompter.py index de9a40d..9f0101a 100644 --- a/diffsynth/prompters/base_prompter.py +++ b/diffsynth/prompters/base_prompter.py @@ -37,14 +37,20 @@ def tokenize_long_prompt(tokenizer, prompt, max_length=None): class BasePrompter: - def __init__(self, refiners=[]): + def __init__(self, refiners=[], extenders=[]): self.refiners = refiners + self.extenders = extenders - def load_prompt_refiners(self, model_nameger: ModelManager, refiner_classes=[]): + def load_prompt_refiners(self, model_manager: ModelManager, refiner_classes=[]): for refiner_class in refiner_classes: - refiner = refiner_class.from_model_manager(model_nameger) + refiner = refiner_class.from_model_manager(model_manager) self.refiners.append(refiner) + + def load_prompt_extenders(self,model_manager:ModelManager,extender_classes=[]): + for extender_class in extender_classes: + extender = extender_class.from_model_manager(model_manager) + self.extenders.append(extender) @torch.no_grad() @@ -55,3 +61,10 @@ class BasePrompter: for refiner in self.refiners: prompt = refiner(prompt, positive=positive) return prompt + + @torch.no_grad() + def extend_prompt(self, prompt:str, positive=True): + extended_prompt = dict(prompt=prompt) + for extender in self.extenders: + extended_prompt = extender(extended_prompt) + return extended_prompt \ No newline at end of file diff --git a/diffsynth/prompters/cog_prompter.py b/diffsynth/prompters/cog_prompter.py new file mode 100644 index 0000000..a1ab84a --- /dev/null +++ b/diffsynth/prompters/cog_prompter.py @@ -0,0 +1,46 @@ +from .base_prompter import BasePrompter +from ..models.flux_text_encoder import FluxTextEncoder2 +from transformers import T5TokenizerFast 
+import os + + +class CogPrompter(BasePrompter): + def __init__( + self, + tokenizer_path=None + ): + if tokenizer_path is None: + base_path = os.path.dirname(os.path.dirname(__file__)) + tokenizer_path = os.path.join(base_path, "tokenizer_configs/cog/tokenizer") + super().__init__() + self.tokenizer = T5TokenizerFast.from_pretrained(tokenizer_path) + self.text_encoder: FluxTextEncoder2 = None + + + def fetch_models(self, text_encoder: FluxTextEncoder2 = None): + self.text_encoder = text_encoder + + + def encode_prompt_using_t5(self, prompt, text_encoder, tokenizer, max_length, device): + input_ids = tokenizer( + prompt, + return_tensors="pt", + padding="max_length", + max_length=max_length, + truncation=True, + ).input_ids.to(device) + prompt_emb = text_encoder(input_ids) + prompt_emb = prompt_emb.reshape((1, prompt_emb.shape[0]*prompt_emb.shape[1], -1)) + + return prompt_emb + + + def encode_prompt( + self, + prompt, + positive=True, + device="cuda" + ): + prompt = self.process_prompt(prompt, positive=positive) + prompt_emb = self.encode_prompt_using_t5(prompt, self.text_encoder, self.tokenizer, 226, device) + return prompt_emb diff --git a/diffsynth/prompters/omost.py b/diffsynth/prompters/omost.py new file mode 100644 index 0000000..39999ce --- /dev/null +++ b/diffsynth/prompters/omost.py @@ -0,0 +1,311 @@ +from transformers import AutoTokenizer, TextIteratorStreamer +import difflib +import torch +import numpy as np +import re +from ..models.model_manager import ModelManager +from PIL import Image + +valid_colors = { # r, g, b + 'aliceblue': (240, 248, 255), 'antiquewhite': (250, 235, 215), 'aqua': (0, 255, 255), + 'aquamarine': (127, 255, 212), 'azure': (240, 255, 255), 'beige': (245, 245, 220), + 'bisque': (255, 228, 196), 'black': (0, 0, 0), 'blanchedalmond': (255, 235, 205), 'blue': (0, 0, 255), + 'blueviolet': (138, 43, 226), 'brown': (165, 42, 42), 'burlywood': (222, 184, 135), + 'cadetblue': (95, 158, 160), 'chartreuse': (127, 255, 0), 'chocolate': (210, 
105, 30), + 'coral': (255, 127, 80), 'cornflowerblue': (100, 149, 237), 'cornsilk': (255, 248, 220), + 'crimson': (220, 20, 60), 'cyan': (0, 255, 255), 'darkblue': (0, 0, 139), 'darkcyan': (0, 139, 139), + 'darkgoldenrod': (184, 134, 11), 'darkgray': (169, 169, 169), 'darkgrey': (169, 169, 169), + 'darkgreen': (0, 100, 0), 'darkkhaki': (189, 183, 107), 'darkmagenta': (139, 0, 139), + 'darkolivegreen': (85, 107, 47), 'darkorange': (255, 140, 0), 'darkorchid': (153, 50, 204), + 'darkred': (139, 0, 0), 'darksalmon': (233, 150, 122), 'darkseagreen': (143, 188, 143), + 'darkslateblue': (72, 61, 139), 'darkslategray': (47, 79, 79), 'darkslategrey': (47, 79, 79), + 'darkturquoise': (0, 206, 209), 'darkviolet': (148, 0, 211), 'deeppink': (255, 20, 147), + 'deepskyblue': (0, 191, 255), 'dimgray': (105, 105, 105), 'dimgrey': (105, 105, 105), + 'dodgerblue': (30, 144, 255), 'firebrick': (178, 34, 34), 'floralwhite': (255, 250, 240), + 'forestgreen': (34, 139, 34), 'fuchsia': (255, 0, 255), 'gainsboro': (220, 220, 220), + 'ghostwhite': (248, 248, 255), 'gold': (255, 215, 0), 'goldenrod': (218, 165, 32), + 'gray': (128, 128, 128), 'grey': (128, 128, 128), 'green': (0, 128, 0), 'greenyellow': (173, 255, 47), + 'honeydew': (240, 255, 240), 'hotpink': (255, 105, 180), 'indianred': (205, 92, 92), + 'indigo': (75, 0, 130), 'ivory': (255, 255, 240), 'khaki': (240, 230, 140), 'lavender': (230, 230, 250), + 'lavenderblush': (255, 240, 245), 'lawngreen': (124, 252, 0), 'lemonchiffon': (255, 250, 205), + 'lightblue': (173, 216, 230), 'lightcoral': (240, 128, 128), 'lightcyan': (224, 255, 255), + 'lightgoldenrodyellow': (250, 250, 210), 'lightgray': (211, 211, 211), 'lightgrey': (211, 211, 211), + 'lightgreen': (144, 238, 144), 'lightpink': (255, 182, 193), 'lightsalmon': (255, 160, 122), + 'lightseagreen': (32, 178, 170), 'lightskyblue': (135, 206, 250), 'lightslategray': (119, 136, 153), + 'lightslategrey': (119, 136, 153), 'lightsteelblue': (176, 196, 222), 'lightyellow': (255, 255, 
224), + 'lime': (0, 255, 0), 'limegreen': (50, 205, 50), 'linen': (250, 240, 230), 'magenta': (255, 0, 255), + 'maroon': (128, 0, 0), 'mediumaquamarine': (102, 205, 170), 'mediumblue': (0, 0, 205), + 'mediumorchid': (186, 85, 211), 'mediumpurple': (147, 112, 219), 'mediumseagreen': (60, 179, 113), + 'mediumslateblue': (123, 104, 238), 'mediumspringgreen': (0, 250, 154), + 'mediumturquoise': (72, 209, 204), 'mediumvioletred': (199, 21, 133), 'midnightblue': (25, 25, 112), + 'mintcream': (245, 255, 250), 'mistyrose': (255, 228, 225), 'moccasin': (255, 228, 181), + 'navajowhite': (255, 222, 173), 'navy': (0, 0, 128), 'navyblue': (0, 0, 128), + 'oldlace': (253, 245, 230), 'olive': (128, 128, 0), 'olivedrab': (107, 142, 35), + 'orange': (255, 165, 0), 'orangered': (255, 69, 0), 'orchid': (218, 112, 214), + 'palegoldenrod': (238, 232, 170), 'palegreen': (152, 251, 152), 'paleturquoise': (175, 238, 238), + 'palevioletred': (219, 112, 147), 'papayawhip': (255, 239, 213), 'peachpuff': (255, 218, 185), + 'peru': (205, 133, 63), 'pink': (255, 192, 203), 'plum': (221, 160, 221), 'powderblue': (176, 224, 230), + 'purple': (128, 0, 128), 'rebeccapurple': (102, 51, 153), 'red': (255, 0, 0), + 'rosybrown': (188, 143, 143), 'royalblue': (65, 105, 225), 'saddlebrown': (139, 69, 19), + 'salmon': (250, 128, 114), 'sandybrown': (244, 164, 96), 'seagreen': (46, 139, 87), + 'seashell': (255, 245, 238), 'sienna': (160, 82, 45), 'silver': (192, 192, 192), + 'skyblue': (135, 206, 235), 'slateblue': (106, 90, 205), 'slategray': (112, 128, 144), + 'slategrey': (112, 128, 144), 'snow': (255, 250, 250), 'springgreen': (0, 255, 127), + 'steelblue': (70, 130, 180), 'tan': (210, 180, 140), 'teal': (0, 128, 128), 'thistle': (216, 191, 216), + 'tomato': (255, 99, 71), 'turquoise': (64, 224, 208), 'violet': (238, 130, 238), + 'wheat': (245, 222, 179), 'white': (255, 255, 255), 'whitesmoke': (245, 245, 245), + 'yellow': (255, 255, 0), 'yellowgreen': (154, 205, 50) +} + +valid_locations = { # x, y in 
90*90 + 'in the center': (45, 45), + 'on the left': (15, 45), + 'on the right': (75, 45), + 'on the top': (45, 15), + 'on the bottom': (45, 75), + 'on the top-left': (15, 15), + 'on the top-right': (75, 15), + 'on the bottom-left': (15, 75), + 'on the bottom-right': (75, 75) +} + +valid_offsets = { # x, y in 90*90 + 'no offset': (0, 0), + 'slightly to the left': (-10, 0), + 'slightly to the right': (10, 0), + 'slightly to the upper': (0, -10), + 'slightly to the lower': (0, 10), + 'slightly to the upper-left': (-10, -10), + 'slightly to the upper-right': (10, -10), + 'slightly to the lower-left': (-10, 10), + 'slightly to the lower-right': (10, 10)} + +valid_areas = { # w, h in 90*90 + "a small square area": (50, 50), + "a small vertical area": (40, 60), + "a small horizontal area": (60, 40), + "a medium-sized square area": (60, 60), + "a medium-sized vertical area": (50, 80), + "a medium-sized horizontal area": (80, 50), + "a large square area": (70, 70), + "a large vertical area": (60, 90), + "a large horizontal area": (90, 60) +} + +def safe_str(x): + return x.strip(',. ') + '.' + +def closest_name(input_str, options): + input_str = input_str.lower() + + closest_match = difflib.get_close_matches(input_str, list(options.keys()), n=1, cutoff=0.5) + assert isinstance(closest_match, list) and len(closest_match) > 0, f'The value [{input_str}] is not valid!' + result = closest_match[0] + + if result != input_str: + print(f'Automatically corrected [{input_str}] -> [{result}].') + + return result + +class Canvas: + @staticmethod + def from_bot_response(response: str): + + matched = re.search(r'```python\n(.*?)\n```', response, re.DOTALL) + assert matched, 'Response does not contain codes!' + code_content = matched.group(1) + assert 'canvas = Canvas()' in code_content, 'Code block must include valid canvas var!' 
+ local_vars = {'Canvas': Canvas} + exec(code_content, {}, local_vars) + canvas = local_vars.get('canvas', None) + assert isinstance(canvas, Canvas), 'Code block must produce valid canvas var!' + return canvas + + def __init__(self): + self.components = [] + self.color = None + self.record_tags = True + self.prefixes = [] + self.suffixes = [] + return + + def set_global_description(self, description: str, detailed_descriptions: list[str], tags: str, + HTML_web_color_name: str): + assert isinstance(description, str), 'Global description is not valid!' + assert isinstance(detailed_descriptions, list) and all(isinstance(item, str) for item in detailed_descriptions), \ + 'Global detailed_descriptions is not valid!' + assert isinstance(tags, str), 'Global tags is not valid!' + + HTML_web_color_name = closest_name(HTML_web_color_name, valid_colors) + self.color = np.array([[valid_colors[HTML_web_color_name]]], dtype=np.uint8) + + self.prefixes = [description] + self.suffixes = detailed_descriptions + + if self.record_tags: + self.suffixes = self.suffixes + [tags] + + self.prefixes = [safe_str(x) for x in self.prefixes] + self.suffixes = [safe_str(x) for x in self.suffixes] + + return + + def add_local_description(self, location: str, offset: str, area: str, distance_to_viewer: float, description: str, + detailed_descriptions: list[str], tags: str, atmosphere: str, style: str, + quality_meta: str, HTML_web_color_name: str): + assert isinstance(description, str), 'Local description is wrong!' + assert isinstance(distance_to_viewer, (int, float)) and distance_to_viewer > 0, \ + f'The distance_to_viewer for [{description}] is not positive float number!' + assert isinstance(detailed_descriptions, list) and all(isinstance(item, str) for item in detailed_descriptions), \ + f'The detailed_descriptions for [{description}] is not valid!' + assert isinstance(tags, str), f'The tags for [{description}] is not valid!' 
+ assert isinstance(atmosphere, str), f'The atmosphere for [{description}] is not valid!' + assert isinstance(style, str), f'The style for [{description}] is not valid!' + assert isinstance(quality_meta, str), f'The quality_meta for [{description}] is not valid!' + + location = closest_name(location, valid_locations) + offset = closest_name(offset, valid_offsets) + area = closest_name(area, valid_areas) + HTML_web_color_name = closest_name(HTML_web_color_name, valid_colors) + + xb, yb = valid_locations[location] + xo, yo = valid_offsets[offset] + w, h = valid_areas[area] + rect = (yb + yo - h // 2, yb + yo + h // 2, xb + xo - w // 2, xb + xo + w // 2) + rect = [max(0, min(90, i)) for i in rect] + color = np.array([[valid_colors[HTML_web_color_name]]], dtype=np.uint8) + + prefixes = self.prefixes + [description] + suffixes = detailed_descriptions + + if self.record_tags: + suffixes = suffixes + [tags, atmosphere, style, quality_meta] + + prefixes = [safe_str(x) for x in prefixes] + suffixes = [safe_str(x) for x in suffixes] + + self.components.append(dict( + rect=rect, + distance_to_viewer=distance_to_viewer, + color=color, + prefixes=prefixes, + suffixes=suffixes + )) + + return + + def process(self): + # sort components + self.components = sorted(self.components, key=lambda x: x['distance_to_viewer'], reverse=True) + + # compute initial latent + # print(self.color) + initial_latent = np.zeros(shape=(90, 90, 3), dtype=np.float32) + self.color + + for component in self.components: + a, b, c, d = component['rect'] + initial_latent[a:b, c:d] = 0.7 * component['color'] + 0.3 * initial_latent[a:b, c:d] + + initial_latent = initial_latent.clip(0, 255).astype(np.uint8) + + # compute conditions + + bag_of_conditions = [ + dict(mask=np.ones(shape=(90, 90), dtype=np.float32), prefixes=self.prefixes, suffixes=self.suffixes) + ] + + for i, component in enumerate(self.components): + a, b, c, d = component['rect'] + m = np.zeros(shape=(90, 90), dtype=np.float32) + m[a:b, c:d] = 
1.0 + bag_of_conditions.append(dict( + mask=m, + prefixes=component['prefixes'], + suffixes=component['suffixes'] + )) + + return dict( + initial_latent=initial_latent, + bag_of_conditions=bag_of_conditions, + ) + + +class OmostPromter(torch.nn.Module): + + def __init__(self,model = None,tokenizer = None, template = "",device="cpu"): + super().__init__() + self.model=model + self.tokenizer = tokenizer + self.device = device + if template == "": + template = r'''You are a helpful AI assistant to compose images using the below python class `Canvas`: + ```python + class Canvas: + def set_global_description(self, description: str, detailed_descriptions: list[str], tags: str, HTML_web_color_name: str): + pass + + def add_local_description(self, location: str, offset: str, area: str, distance_to_viewer: float, description: str, detailed_descriptions: list[str], tags: str, atmosphere: str, style: str, quality_meta: str, HTML_web_color_name: str): + assert location in ["in the center", "on the left", "on the right", "on the top", "on the bottom", "on the top-left", "on the top-right", "on the bottom-left", "on the bottom-right"] + assert offset in ["no offset", "slightly to the left", "slightly to the right", "slightly to the upper", "slightly to the lower", "slightly to the upper-left", "slightly to the upper-right", "slightly to the lower-left", "slightly to the lower-right"] + assert area in ["a small square area", "a small vertical area", "a small horizontal area", "a medium-sized square area", "a medium-sized vertical area", "a medium-sized horizontal area", "a large square area", "a large vertical area", "a large horizontal area"] + assert distance_to_viewer > 0 + pass + ```''' + self.template = template + + @staticmethod + def from_model_manager(model_manager: ModelManager): + model, model_path = model_manager.fetch_model("omost_prompt", require_model_path=True) + tokenizer = AutoTokenizer.from_pretrained(model_path) + omost = OmostPromter( + model=model, + 
tokenizer=tokenizer, + ) + return omost + + + def __call__(self,prompt_dict:dict): + raw_prompt=prompt_dict["prompt"] + conversation = [{"role": "system", "content": self.template}] + conversation.append({"role": "user", "content": raw_prompt}) + + input_ids = self.tokenizer.apply_chat_template(conversation, return_tensors="pt", add_generation_prompt=True).to(self.device) + streamer = TextIteratorStreamer(self.tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True) + + generate_kwargs = dict( + input_ids=input_ids, + streamer=streamer, + # stopping_criteria=stopping_criteria, + # max_new_tokens=max_new_tokens, + do_sample=True, + # temperature=temperature, + # top_p=top_p, + ) + self.model.generate(**generate_kwargs) + outputs = [] + for text in streamer: + outputs.append(text) + llm_outputs = "".join(outputs) + + canvas = Canvas.from_bot_response(llm_outputs) + canvas_output = canvas.process() + + prompts = [" ".join(_["prefixes"]+_["suffixes"]) for _ in canvas_output["bag_of_conditions"]] + canvas_output["prompt"] = prompts[0] + canvas_output["prompts"] = prompts[1:] + + raw_masks = [_["mask"] for _ in canvas_output["bag_of_conditions"]] + masks=[] + for mask in raw_masks: + mask[mask>0.5]=255 + mask = np.stack([mask] * 3, axis=-1).astype("uint8") + masks.append(Image.fromarray(mask)) + + canvas_output["masks"] = masks + + prompt_dict.update(canvas_output) + return prompt_dict + + + + \ No newline at end of file diff --git a/diffsynth/prompters/prompt_refiners.py b/diffsynth/prompters/prompt_refiners.py index 4ba469a..28205ea 100644 --- a/diffsynth/prompters/prompt_refiners.py +++ b/diffsynth/prompters/prompt_refiners.py @@ -1,8 +1,7 @@ from transformers import AutoTokenizer from ..models.model_manager import ModelManager import torch - - +from .omost import OmostPromter class BeautifulPrompt(torch.nn.Module): def __init__(self, tokenizer_path=None, model=None, template=""): @@ -117,8 +116,8 @@ class Translator(torch.nn.Module): @staticmethod - def 
from_model_manager(model_nameger: ModelManager): - model, model_path = model_nameger.fetch_model("translator", require_model_path=True) + def from_model_manager(model_manager: ModelManager): + model, model_path = model_manager.fetch_model("translator", require_model_path=True) translator = Translator(tokenizer_path=model_path, model=model) return translator diff --git a/diffsynth/prompters/sd_prompter.py b/diffsynth/prompters/sd_prompter.py index f5a59f7..e3b31ea 100644 --- a/diffsynth/prompters/sd_prompter.py +++ b/diffsynth/prompters/sd_prompter.py @@ -1,5 +1,5 @@ from .base_prompter import BasePrompter, tokenize_long_prompt -from ..models.model_manager import ModelManager, load_state_dict, search_for_embeddings +from ..models.utils import load_state_dict, search_for_embeddings from ..models import SDTextEncoder from transformers import CLIPTokenizer import torch, os diff --git a/diffsynth/schedulers/ddim.py b/diffsynth/schedulers/ddim.py index f310639..d42c9c3 100644 --- a/diffsynth/schedulers/ddim.py +++ b/diffsynth/schedulers/ddim.py @@ -3,7 +3,7 @@ import torch, math class EnhancedDDIMScheduler(): - def __init__(self, num_train_timesteps=1000, beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", prediction_type="epsilon"): + def __init__(self, num_train_timesteps=1000, beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", prediction_type="epsilon", rescale_zero_terminal_snr=False): self.num_train_timesteps = num_train_timesteps if beta_schedule == "scaled_linear": betas = torch.square(torch.linspace(math.sqrt(beta_start), math.sqrt(beta_end), num_train_timesteps, dtype=torch.float32)) @@ -11,11 +11,33 @@ class EnhancedDDIMScheduler(): betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) else: raise NotImplementedError(f"{beta_schedule} is not implemented") - self.alphas_cumprod = torch.cumprod(1.0 - betas, dim=0).tolist() + self.alphas_cumprod = torch.cumprod(1.0 - betas, dim=0) + if 
rescale_zero_terminal_snr: + self.alphas_cumprod = self.rescale_zero_terminal_snr(self.alphas_cumprod) + self.alphas_cumprod = self.alphas_cumprod.tolist() self.set_timesteps(10) self.prediction_type = prediction_type + def rescale_zero_terminal_snr(self, alphas_cumprod): + alphas_bar_sqrt = alphas_cumprod.sqrt() + + # Store old values. + alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() + alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() + + # Shift so the last timestep is zero. + alphas_bar_sqrt -= alphas_bar_sqrt_T + + # Scale so the first timestep is back to the old value. + alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) + + # Convert alphas_bar_sqrt to betas + alphas_bar = alphas_bar_sqrt.square() # Revert sqrt + + return alphas_bar + + def set_timesteps(self, num_inference_steps, denoising_strength=1.0): # The timesteps are aligned to 999...0, which is different from other implementations, # but I think this implementation is more reasonable in theory. diff --git a/diffsynth/tokenizer_configs/cog/tokenizer/added_tokens.json b/diffsynth/tokenizer_configs/cog/tokenizer/added_tokens.json new file mode 100644 index 0000000..3f51320 --- /dev/null +++ b/diffsynth/tokenizer_configs/cog/tokenizer/added_tokens.json @@ -0,0 +1,102 @@ +{ + "": 32099, + "": 32089, + "": 32088, + "": 32087, + "": 32086, + "": 32085, + "": 32084, + "": 32083, + "": 32082, + "": 32081, + "": 32080, + "": 32098, + "": 32079, + "": 32078, + "": 32077, + "": 32076, + "": 32075, + "": 32074, + "": 32073, + "": 32072, + "": 32071, + "": 32070, + "": 32097, + "": 32069, + "": 32068, + "": 32067, + "": 32066, + "": 32065, + "": 32064, + "": 32063, + "": 32062, + "": 32061, + "": 32060, + "": 32096, + "": 32059, + "": 32058, + "": 32057, + "": 32056, + "": 32055, + "": 32054, + "": 32053, + "": 32052, + "": 32051, + "": 32050, + "": 32095, + "": 32049, + "": 32048, + "": 32047, + "": 32046, + "": 32045, + "": 32044, + "": 32043, + "": 32042, + "": 32041, + "": 32040, + 
"": 32094, + "": 32039, + "": 32038, + "": 32037, + "": 32036, + "": 32035, + "": 32034, + "": 32033, + "": 32032, + "": 32031, + "": 32030, + "": 32093, + "": 32029, + "": 32028, + "": 32027, + "": 32026, + "": 32025, + "": 32024, + "": 32023, + "": 32022, + "": 32021, + "": 32020, + "": 32092, + "": 32019, + "": 32018, + "": 32017, + "": 32016, + "": 32015, + "": 32014, + "": 32013, + "": 32012, + "": 32011, + "": 32010, + "": 32091, + "": 32009, + "": 32008, + "": 32007, + "": 32006, + "": 32005, + "": 32004, + "": 32003, + "": 32002, + "": 32001, + "": 32000, + "": 32090 +} diff --git a/diffsynth/tokenizer_configs/cog/tokenizer/special_tokens_map.json b/diffsynth/tokenizer_configs/cog/tokenizer/special_tokens_map.json new file mode 100644 index 0000000..17ade34 --- /dev/null +++ b/diffsynth/tokenizer_configs/cog/tokenizer/special_tokens_map.json @@ -0,0 +1,125 @@ +{ + "additional_special_tokens": [ + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "" + ], + "eos_token": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false + }, + "pad_token": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false + }, + "unk_token": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false + } +} diff --git a/diffsynth/tokenizer_configs/cog/tokenizer/spiece.model 
b/diffsynth/tokenizer_configs/cog/tokenizer/spiece.model new file mode 100644 index 0000000..4e28ff6 Binary files /dev/null and b/diffsynth/tokenizer_configs/cog/tokenizer/spiece.model differ diff --git a/diffsynth/tokenizer_configs/cog/tokenizer/tokenizer_config.json b/diffsynth/tokenizer_configs/cog/tokenizer/tokenizer_config.json new file mode 100644 index 0000000..161715a --- /dev/null +++ b/diffsynth/tokenizer_configs/cog/tokenizer/tokenizer_config.json @@ -0,0 +1,940 @@ +{ + "add_prefix_space": true, + "added_tokens_decoder": { + "0": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "1": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "2": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32000": { + "content": "", + "lstrip": true, + "normalized": false, + "rstrip": true, + "single_word": false, + "special": true + }, + "32001": { + "content": "", + "lstrip": true, + "normalized": false, + "rstrip": true, + "single_word": false, + "special": true + }, + "32002": { + "content": "", + "lstrip": true, + "normalized": false, + "rstrip": true, + "single_word": false, + "special": true + }, + "32003": { + "content": "", + "lstrip": true, + "normalized": false, + "rstrip": true, + "single_word": false, + "special": true + }, + "32004": { + "content": "", + "lstrip": true, + "normalized": false, + "rstrip": true, + "single_word": false, + "special": true + }, + "32005": { + "content": "", + "lstrip": true, + "normalized": false, + "rstrip": true, + "single_word": false, + "special": true + }, + "32006": { + "content": "", + "lstrip": true, + "normalized": false, + "rstrip": true, + "single_word": false, + "special": true + }, + "32007": { + "content": "", + "lstrip": true, + "normalized": false, + "rstrip": true, + "single_word": 
false, + "special": true + }, + "32008": { + "content": "", + "lstrip": true, + "normalized": false, + "rstrip": true, + "single_word": false, + "special": true + }, + "32009": { + "content": "", + "lstrip": true, + "normalized": false, + "rstrip": true, + "single_word": false, + "special": true + }, + "32010": { + "content": "", + "lstrip": true, + "normalized": false, + "rstrip": true, + "single_word": false, + "special": true + }, + "32011": { + "content": "", + "lstrip": true, + "normalized": false, + "rstrip": true, + "single_word": false, + "special": true + }, + "32012": { + "content": "", + "lstrip": true, + "normalized": false, + "rstrip": true, + "single_word": false, + "special": true + }, + "32013": { + "content": "", + "lstrip": true, + "normalized": false, + "rstrip": true, + "single_word": false, + "special": true + }, + "32014": { + "content": "", + "lstrip": true, + "normalized": false, + "rstrip": true, + "single_word": false, + "special": true + }, + "32015": { + "content": "", + "lstrip": true, + "normalized": false, + "rstrip": true, + "single_word": false, + "special": true + }, + "32016": { + "content": "", + "lstrip": true, + "normalized": false, + "rstrip": true, + "single_word": false, + "special": true + }, + "32017": { + "content": "", + "lstrip": true, + "normalized": false, + "rstrip": true, + "single_word": false, + "special": true + }, + "32018": { + "content": "", + "lstrip": true, + "normalized": false, + "rstrip": true, + "single_word": false, + "special": true + }, + "32019": { + "content": "", + "lstrip": true, + "normalized": false, + "rstrip": true, + "single_word": false, + "special": true + }, + "32020": { + "content": "", + "lstrip": true, + "normalized": false, + "rstrip": true, + "single_word": false, + "special": true + }, + "32021": { + "content": "", + "lstrip": true, + "normalized": false, + "rstrip": true, + "single_word": false, + "special": true + }, + "32022": { + "content": "", + "lstrip": true, + "normalized": 
false, + "rstrip": true, + "single_word": false, + "special": true + }, + "32023": { + "content": "", + "lstrip": true, + "normalized": false, + "rstrip": true, + "single_word": false, + "special": true + }, + "32024": { + "content": "", + "lstrip": true, + "normalized": false, + "rstrip": true, + "single_word": false, + "special": true + }, + "32025": { + "content": "", + "lstrip": true, + "normalized": false, + "rstrip": true, + "single_word": false, + "special": true + }, + "32026": { + "content": "", + "lstrip": true, + "normalized": false, + "rstrip": true, + "single_word": false, + "special": true + }, + "32027": { + "content": "", + "lstrip": true, + "normalized": false, + "rstrip": true, + "single_word": false, + "special": true + }, + "32028": { + "content": "", + "lstrip": true, + "normalized": false, + "rstrip": true, + "single_word": false, + "special": true + }, + "32029": { + "content": "", + "lstrip": true, + "normalized": false, + "rstrip": true, + "single_word": false, + "special": true + }, + "32030": { + "content": "", + "lstrip": true, + "normalized": false, + "rstrip": true, + "single_word": false, + "special": true + }, + "32031": { + "content": "", + "lstrip": true, + "normalized": false, + "rstrip": true, + "single_word": false, + "special": true + }, + "32032": { + "content": "", + "lstrip": true, + "normalized": false, + "rstrip": true, + "single_word": false, + "special": true + }, + "32033": { + "content": "", + "lstrip": true, + "normalized": false, + "rstrip": true, + "single_word": false, + "special": true + }, + "32034": { + "content": "", + "lstrip": true, + "normalized": false, + "rstrip": true, + "single_word": false, + "special": true + }, + "32035": { + "content": "", + "lstrip": true, + "normalized": false, + "rstrip": true, + "single_word": false, + "special": true + }, + "32036": { + "content": "", + "lstrip": true, + "normalized": false, + "rstrip": true, + "single_word": false, + "special": true + }, + "32037": { + 
"content": "", + "lstrip": true, + "normalized": false, + "rstrip": true, + "single_word": false, + "special": true + }, + "32038": { + "content": "", + "lstrip": true, + "normalized": false, + "rstrip": true, + "single_word": false, + "special": true + }, + "32039": { + "content": "", + "lstrip": true, + "normalized": false, + "rstrip": true, + "single_word": false, + "special": true + }, + "32040": { + "content": "", + "lstrip": true, + "normalized": false, + "rstrip": true, + "single_word": false, + "special": true + }, + "32041": { + "content": "", + "lstrip": true, + "normalized": false, + "rstrip": true, + "single_word": false, + "special": true + }, + "32042": { + "content": "", + "lstrip": true, + "normalized": false, + "rstrip": true, + "single_word": false, + "special": true + }, + "32043": { + "content": "", + "lstrip": true, + "normalized": false, + "rstrip": true, + "single_word": false, + "special": true + }, + "32044": { + "content": "", + "lstrip": true, + "normalized": false, + "rstrip": true, + "single_word": false, + "special": true + }, + "32045": { + "content": "", + "lstrip": true, + "normalized": false, + "rstrip": true, + "single_word": false, + "special": true + }, + "32046": { + "content": "", + "lstrip": true, + "normalized": false, + "rstrip": true, + "single_word": false, + "special": true + }, + "32047": { + "content": "", + "lstrip": true, + "normalized": false, + "rstrip": true, + "single_word": false, + "special": true + }, + "32048": { + "content": "", + "lstrip": true, + "normalized": false, + "rstrip": true, + "single_word": false, + "special": true + }, + "32049": { + "content": "", + "lstrip": true, + "normalized": false, + "rstrip": true, + "single_word": false, + "special": true + }, + "32050": { + "content": "", + "lstrip": true, + "normalized": false, + "rstrip": true, + "single_word": false, + "special": true + }, + "32051": { + "content": "", + "lstrip": true, + "normalized": false, + "rstrip": true, + "single_word": 
false, + "special": true + }, + "32052": { + "content": "", + "lstrip": true, + "normalized": false, + "rstrip": true, + "single_word": false, + "special": true + }, + "32053": { + "content": "", + "lstrip": true, + "normalized": false, + "rstrip": true, + "single_word": false, + "special": true + }, + "32054": { + "content": "", + "lstrip": true, + "normalized": false, + "rstrip": true, + "single_word": false, + "special": true + }, + "32055": { + "content": "", + "lstrip": true, + "normalized": false, + "rstrip": true, + "single_word": false, + "special": true + }, + "32056": { + "content": "", + "lstrip": true, + "normalized": false, + "rstrip": true, + "single_word": false, + "special": true + }, + "32057": { + "content": "", + "lstrip": true, + "normalized": false, + "rstrip": true, + "single_word": false, + "special": true + }, + "32058": { + "content": "", + "lstrip": true, + "normalized": false, + "rstrip": true, + "single_word": false, + "special": true + }, + "32059": { + "content": "", + "lstrip": true, + "normalized": false, + "rstrip": true, + "single_word": false, + "special": true + }, + "32060": { + "content": "", + "lstrip": true, + "normalized": false, + "rstrip": true, + "single_word": false, + "special": true + }, + "32061": { + "content": "", + "lstrip": true, + "normalized": false, + "rstrip": true, + "single_word": false, + "special": true + }, + "32062": { + "content": "", + "lstrip": true, + "normalized": false, + "rstrip": true, + "single_word": false, + "special": true + }, + "32063": { + "content": "", + "lstrip": true, + "normalized": false, + "rstrip": true, + "single_word": false, + "special": true + }, + "32064": { + "content": "", + "lstrip": true, + "normalized": false, + "rstrip": true, + "single_word": false, + "special": true + }, + "32065": { + "content": "", + "lstrip": true, + "normalized": false, + "rstrip": true, + "single_word": false, + "special": true + }, + "32066": { + "content": "", + "lstrip": true, + "normalized": 
false, + "rstrip": true, + "single_word": false, + "special": true + }, + "32067": { + "content": "", + "lstrip": true, + "normalized": false, + "rstrip": true, + "single_word": false, + "special": true + }, + "32068": { + "content": "", + "lstrip": true, + "normalized": false, + "rstrip": true, + "single_word": false, + "special": true + }, + "32069": { + "content": "", + "lstrip": true, + "normalized": false, + "rstrip": true, + "single_word": false, + "special": true + }, + "32070": { + "content": "", + "lstrip": true, + "normalized": false, + "rstrip": true, + "single_word": false, + "special": true + }, + "32071": { + "content": "", + "lstrip": true, + "normalized": false, + "rstrip": true, + "single_word": false, + "special": true + }, + "32072": { + "content": "", + "lstrip": true, + "normalized": false, + "rstrip": true, + "single_word": false, + "special": true + }, + "32073": { + "content": "", + "lstrip": true, + "normalized": false, + "rstrip": true, + "single_word": false, + "special": true + }, + "32074": { + "content": "", + "lstrip": true, + "normalized": false, + "rstrip": true, + "single_word": false, + "special": true + }, + "32075": { + "content": "", + "lstrip": true, + "normalized": false, + "rstrip": true, + "single_word": false, + "special": true + }, + "32076": { + "content": "", + "lstrip": true, + "normalized": false, + "rstrip": true, + "single_word": false, + "special": true + }, + "32077": { + "content": "", + "lstrip": true, + "normalized": false, + "rstrip": true, + "single_word": false, + "special": true + }, + "32078": { + "content": "", + "lstrip": true, + "normalized": false, + "rstrip": true, + "single_word": false, + "special": true + }, + "32079": { + "content": "", + "lstrip": true, + "normalized": false, + "rstrip": true, + "single_word": false, + "special": true + }, + "32080": { + "content": "", + "lstrip": true, + "normalized": false, + "rstrip": true, + "single_word": false, + "special": true + }, + "32081": { + 
"content": "", + "lstrip": true, + "normalized": false, + "rstrip": true, + "single_word": false, + "special": true + }, + "32082": { + "content": "", + "lstrip": true, + "normalized": false, + "rstrip": true, + "single_word": false, + "special": true + }, + "32083": { + "content": "", + "lstrip": true, + "normalized": false, + "rstrip": true, + "single_word": false, + "special": true + }, + "32084": { + "content": "", + "lstrip": true, + "normalized": false, + "rstrip": true, + "single_word": false, + "special": true + }, + "32085": { + "content": "", + "lstrip": true, + "normalized": false, + "rstrip": true, + "single_word": false, + "special": true + }, + "32086": { + "content": "", + "lstrip": true, + "normalized": false, + "rstrip": true, + "single_word": false, + "special": true + }, + "32087": { + "content": "", + "lstrip": true, + "normalized": false, + "rstrip": true, + "single_word": false, + "special": true + }, + "32088": { + "content": "", + "lstrip": true, + "normalized": false, + "rstrip": true, + "single_word": false, + "special": true + }, + "32089": { + "content": "", + "lstrip": true, + "normalized": false, + "rstrip": true, + "single_word": false, + "special": true + }, + "32090": { + "content": "", + "lstrip": true, + "normalized": false, + "rstrip": true, + "single_word": false, + "special": true + }, + "32091": { + "content": "", + "lstrip": true, + "normalized": false, + "rstrip": true, + "single_word": false, + "special": true + }, + "32092": { + "content": "", + "lstrip": true, + "normalized": false, + "rstrip": true, + "single_word": false, + "special": true + }, + "32093": { + "content": "", + "lstrip": true, + "normalized": false, + "rstrip": true, + "single_word": false, + "special": true + }, + "32094": { + "content": "", + "lstrip": true, + "normalized": false, + "rstrip": true, + "single_word": false, + "special": true + }, + "32095": { + "content": "", + "lstrip": true, + "normalized": false, + "rstrip": true, + "single_word": 
false, + "special": true + }, + "32096": { + "content": "", + "lstrip": true, + "normalized": false, + "rstrip": true, + "single_word": false, + "special": true + }, + "32097": { + "content": "", + "lstrip": true, + "normalized": false, + "rstrip": true, + "single_word": false, + "special": true + }, + "32098": { + "content": "", + "lstrip": true, + "normalized": false, + "rstrip": true, + "single_word": false, + "special": true + }, + "32099": { + "content": "", + "lstrip": true, + "normalized": false, + "rstrip": true, + "single_word": false, + "special": true + } + }, + "additional_special_tokens": [ + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "" + ], + "clean_up_tokenization_spaces": true, + "eos_token": "", + "extra_ids": 100, + "legacy": true, + "model_max_length": 226, + "pad_token": "", + "sp_model_kwargs": {}, + "tokenizer_class": "T5Tokenizer", + "unk_token": "" +} diff --git a/examples/image_synthesis/omost_flux_text_to_image.py b/examples/image_synthesis/omost_flux_text_to_image.py new file mode 100644 index 0000000..7562342 --- /dev/null +++ b/examples/image_synthesis/omost_flux_text_to_image.py @@ -0,0 +1,24 @@ +import torch +from diffsynth import download_models, ModelManager, OmostPromter, FluxImagePipeline + + +download_models(["OmostPrompt"]) +download_models(["FLUX.1-dev"]) + +model_manager = ModelManager(torch_dtype=torch.bfloat16) +model_manager.load_models([ + 
"models/OmostPrompt/omost-llama-3-8b-4bits", + "models/FLUX/FLUX.1-dev/text_encoder/model.safetensors", + "models/FLUX/FLUX.1-dev/text_encoder_2", + "models/FLUX/FLUX.1-dev/ae.safetensors", + "models/FLUX/FLUX.1-dev/flux1-dev.safetensors" +]) + +pipe = FluxImagePipeline.from_model_manager(model_manager, prompt_extender_classes=[OmostPromter]) + +torch.manual_seed(0) +image = pipe( + prompt="an image of a witch who is releasing ice and fire magic", + num_inference_steps=30, embedded_guidance=3.5 +) +image.save("image_omost.jpg") diff --git a/examples/video_synthesis/README.md b/examples/video_synthesis/README.md index f9802e6..2d2ff1c 100644 --- a/examples/video_synthesis/README.md +++ b/examples/video_synthesis/README.md @@ -1,8 +1,46 @@ # Text to Video -In DiffSynth Studio, we can use AnimateDiff and SVD to generate videos. However, these models usually generate terrible contents. We do not recommend users to use these models, until a more powerful video model emerges. +In DiffSynth Studio, we can use some video models to generate videos. -### Example 7: Text to Video +### Example: Text-to-Video using CogVideoX-5B (Experimental) + +See [cogvideo_text_to_video.py](cogvideo_text_to_video.py). + +First, we generate a video using prompt "an astronaut riding a horse on Mars". + +https://github.com/user-attachments/assets/4c91c1cd-e4a0-471a-bd8d-24d761262941 + +Then, we convert the astronaut to a robot. + +https://github.com/user-attachments/assets/225a00a4-2bc8-4740-8e86-a64b460a29ec + +Upscale the video using the model itself. + +https://github.com/user-attachments/assets/c02cb30c-de60-473c-8242-32c67b3155ad + +Make the video look smoother by interpolating frames. + +https://github.com/user-attachments/assets/f0e465b4-45df-4435-ab10-7a084ca2b0a0 + +Here is another example. + +First, we generate a video using prompt "a dog is running". + +https://github.com/user-attachments/assets/e3696297-99f5-4d0c-a5ca-1d1566db85b4 + +Then, we add a blue collar to the dog. 
+ +https://github.com/user-attachments/assets/7ff22be7-4390-4d33-ae6c-53f6f056e18d + +Upscale the video using the model itself. + +https://github.com/user-attachments/assets/a909c32c-0b7d-495c-a53c-d23a99a3d3e9 + +Make the video look smoother by interpolating frames. + +https://github.com/user-attachments/assets/ea37c150-97a0-4858-8003-0c2e5eef3331 + +### Example: Text-to-Video using AnimateDiff Generate a video using a Stable Diffusion model and an AnimateDiff model. We can break the limitation of number of frames! See [sd_text_to_video.py](./sd_text_to_video.py). diff --git a/examples/video_synthesis/cogvideo_text_to_video.py b/examples/video_synthesis/cogvideo_text_to_video.py new file mode 100644 index 0000000..22e4f36 --- /dev/null +++ b/examples/video_synthesis/cogvideo_text_to_video.py @@ -0,0 +1,73 @@ +from diffsynth import ModelManager, save_video, VideoData, download_models, CogVideoPipeline +from diffsynth.extensions.RIFE import RIFEInterpolater +import torch, os +os.environ["TOKENIZERS_PARALLELISM"] = "True" + + + +def text_to_video(model_manager, prompt, seed, output_path): + pipe = CogVideoPipeline.from_model_manager(model_manager) + torch.manual_seed(seed) + video = pipe( + prompt=prompt, + height=480, width=720, + cfg_scale=7.0, num_inference_steps=200 + ) + save_video(video, output_path, fps=8, quality=5) + + +def edit_video(model_manager, prompt, seed, input_path, output_path): + pipe = CogVideoPipeline.from_model_manager(model_manager) + input_video = VideoData(video_file=input_path) + torch.manual_seed(seed) + video = pipe( + prompt=prompt, + height=480, width=720, + cfg_scale=7.0, num_inference_steps=200, + input_video=input_video, denoising_strength=0.7 + ) + save_video(video, output_path, fps=8, quality=5) + + +def self_upscale(model_manager, prompt, seed, input_path, output_path): + pipe = CogVideoPipeline.from_model_manager(model_manager) + input_video = VideoData(video_file=input_path, height=480*2, width=720*2).raw_data() + 
torch.manual_seed(seed) + video = pipe( + prompt=prompt, + height=480*2, width=720*2, + cfg_scale=7.0, num_inference_steps=30, + input_video=input_video, denoising_strength=0.4, tiled=True + ) + save_video(video, output_path, fps=8, quality=7) + + +def interpolate_video(model_manager, input_path, output_path): + rife = RIFEInterpolater.from_model_manager(model_manager) + video = VideoData(video_file=input_path).raw_data() + video = rife.interpolate(video, num_iter=2) + save_video(video, output_path, fps=32, quality=5) + + + +download_models(["CogVideoX-5B", "RIFE"]) + +model_manager = ModelManager(torch_dtype=torch.bfloat16) +model_manager.load_models([ + "models/CogVideo/CogVideoX-5b/text_encoder", + "models/CogVideo/CogVideoX-5b/transformer", + "models/CogVideo/CogVideoX-5b/vae/diffusion_pytorch_model.safetensors", + "models/RIFE/flownet.pkl", +]) + +# Example 1 +text_to_video(model_manager, "an astronaut riding a horse on Mars.", 0, "1_video_1.mp4") +edit_video(model_manager, "a white robot riding a horse on Mars.", 1, "1_video_1.mp4", "1_video_2.mp4") +self_upscale(model_manager, "a white robot riding a horse on Mars.", 2, "1_video_2.mp4", "1_video_3.mp4") +interpolate_video(model_manager, "1_video_3.mp4", "1_video_4.mp4") + +# Example 2 +text_to_video(model_manager, "a dog is running.", 1, "2_video_1.mp4") +edit_video(model_manager, "a dog with blue collar.", 2, "2_video_1.mp4", "2_video_2.mp4") +self_upscale(model_manager, "a dog with blue collar.", 3, "2_video_2.mp4", "2_video_3.mp4") +interpolate_video(model_manager, "2_video_3.mp4", "2_video_4.mp4") diff --git a/requirements.txt b/requirements.txt index 9159ab9..9af7c82 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,4 +7,5 @@ imageio[ffmpeg] safetensors einops sentencepiece +protobuf modelscope