mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-19 14:58:12 +00:00
Compare commits
17 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2cefc20ed6 | ||
|
|
02a4c8df9f | ||
|
|
582e33ad51 | ||
|
|
491bbf5369 | ||
|
|
0c92f3b2cc | ||
|
|
427232cbc0 | ||
|
|
2899283c01 | ||
|
|
9cff769fbd | ||
|
|
23e33273f1 | ||
|
|
f191353cf4 | ||
|
|
66a094fc84 | ||
|
|
3681adc5ac | ||
|
|
7434ec8fcd | ||
|
|
0699212665 | ||
|
|
f47de78b59 | ||
|
|
5fdc8039ec | ||
|
|
46d4616e23 |
12
README.md
12
README.md
@@ -17,6 +17,7 @@ DiffSynth Studio is a Diffusion engine. We have restructured architectures inclu
|
||||
|
||||
Until now, DiffSynth Studio has supported the following models:
|
||||
|
||||
* [StepVideo](https://github.com/stepfun-ai/Step-Video-T2V)
|
||||
* [HunyuanVideo](https://github.com/Tencent/HunyuanVideo)
|
||||
* [CogVideoX](https://huggingface.co/THUDM/CogVideoX-5b)
|
||||
* [FLUX](https://huggingface.co/black-forest-labs/FLUX.1-dev)
|
||||
@@ -34,11 +35,14 @@ Until now, DiffSynth Studio has supported the following models:
|
||||
* [Stable Diffusion](https://huggingface.co/runwayml/stable-diffusion-v1-5)
|
||||
|
||||
## News
|
||||
|
||||
- **February 17, 2024** We support [StepVideo](https://modelscope.cn/models/stepfun-ai/stepvideo-t2v/summary)! State-of-the-art video synthesis model! See [./examples/stepvideo](./examples/stepvideo/).
|
||||
|
||||
- **December 31, 2024** We propose EliGen, a novel framework for precise entity-level controlled text-to-image generation, complemented by an inpainting fusion pipeline to extend its capabilities to image inpainting tasks. EliGen seamlessly integrates with existing community models, such as IP-Adapter and In-Context LoRA, enhancing its versatility. For more details, see [./examples/EntityControl](./examples/EntityControl/).
|
||||
* Paper: [EliGen: Entity-Level Controlled Image Generation with Regional Attention](https://arxiv.org/abs/2501.01097)
|
||||
* Github: [DiffSynth-Studio](https://github.com/modelscope/DiffSynth-Studio)
|
||||
* Model: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen)
|
||||
* Training dataset: Coming soon
|
||||
- Paper: [EliGen: Entity-Level Controlled Image Generation with Regional Attention](https://arxiv.org/abs/2501.01097)
|
||||
- Model: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen)
|
||||
- Online Demo: [ModelScope EliGen Studio](https://www.modelscope.cn/studios/DiffSynth-Studio/EliGen)
|
||||
- Training Dataset: [EliGen Train Set](https://www.modelscope.cn/datasets/DiffSynth-Studio/EliGenTrainSet)
|
||||
|
||||
- **December 19, 2024** We implement advanced VRAM management for HunyuanVideo, making it possible to generate videos at a resolution of 129x720x1280 using 24GB of VRAM, or at 129x512x384 resolution with just 6GB of VRAM. Please refer to [./examples/HunyuanVideo/](./examples/HunyuanVideo/) for more details.
|
||||
|
||||
|
||||
@@ -51,6 +51,10 @@ from ..extensions.ESRGAN import RRDBNet
|
||||
|
||||
from ..models.hunyuan_video_dit import HunyuanVideoDiT
|
||||
|
||||
from ..models.stepvideo_vae import StepVideoVAE
|
||||
from ..models.stepvideo_dit import StepVideoModel
|
||||
|
||||
from ..models.wanx_vae import WanXVideoVAE
|
||||
|
||||
model_loader_configs = [
|
||||
# These configs are provided for detecting model type automatically.
|
||||
@@ -103,6 +107,9 @@ model_loader_configs = [
|
||||
(None, "aeb82dce778a03dcb4d726cb03f3c43f", ["hunyuan_video_vae_decoder", "hunyuan_video_vae_encoder"], [HunyuanVideoVAEDecoder, HunyuanVideoVAEEncoder], "diffusers"),
|
||||
(None, "b9588f02e78f5ccafc9d7c0294e46308", ["hunyuan_video_dit"], [HunyuanVideoDiT], "civitai"),
|
||||
(None, "84ef4bd4757f60e906b54aa6a7815dc6", ["hunyuan_video_dit"], [HunyuanVideoDiT], "civitai"),
|
||||
(None, "68beaf8429b7c11aa8ca05b1bd0058bd", ["stepvideo_vae"], [StepVideoVAE], "civitai"),
|
||||
(None, "5c0216a2132b082c10cb7a0e0377e681", ["stepvideo_dit"], [StepVideoModel], "civitai"),
|
||||
(None, "1378ea763357eea97acdef78e65d6d96", ["wanxvideo_vae"], [WanXVideoVAE], "civitai")
|
||||
]
|
||||
huggingface_model_loader_configs = [
|
||||
# These configs are provided for detecting model type automatically.
|
||||
@@ -115,7 +122,8 @@ huggingface_model_loader_configs = [
|
||||
("T5EncoderModel", "diffsynth.models.flux_text_encoder", "flux_text_encoder_2", "FluxTextEncoder2"),
|
||||
("CogVideoXTransformer3DModel", "diffsynth.models.cog_dit", "cog_dit", "CogDiT"),
|
||||
("SiglipModel", "transformers.models.siglip.modeling_siglip", "siglip_vision_model", "SiglipVisionModel"),
|
||||
("LlamaForCausalLM", "diffsynth.models.hunyuan_video_text_encoder", "hunyuan_video_text_encoder_2", "HunyuanVideoLLMEncoder")
|
||||
("LlamaForCausalLM", "diffsynth.models.hunyuan_video_text_encoder", "hunyuan_video_text_encoder_2", "HunyuanVideoLLMEncoder"),
|
||||
("Step1Model", "diffsynth.models.stepvideo_text_encoder", "stepvideo_text_encoder_2", "STEP1TextEncoder"),
|
||||
]
|
||||
patch_model_loader_configs = [
|
||||
# These configs are provided for detecting model type automatically.
|
||||
|
||||
@@ -135,8 +135,8 @@ class VideoData:
|
||||
frame.save(os.path.join(folder, f"{i}.png"))
|
||||
|
||||
|
||||
def save_video(frames, save_path, fps, quality=9):
|
||||
writer = imageio.get_writer(save_path, fps=fps, quality=quality)
|
||||
def save_video(frames, save_path, fps, quality=9, ffmpeg_params=None):
|
||||
writer = imageio.get_writer(save_path, fps=fps, quality=quality, ffmpeg_params=ffmpeg_params)
|
||||
for frame in tqdm(frames, desc="Saving video"):
|
||||
frame = np.array(frame)
|
||||
writer.append_data(frame)
|
||||
|
||||
@@ -80,7 +80,10 @@ def load_model_from_single_file(state_dict, model_names, model_classes, model_re
|
||||
def load_model_from_huggingface_folder(file_path, model_names, model_classes, torch_dtype, device):
|
||||
loaded_model_names, loaded_models = [], []
|
||||
for model_name, model_class in zip(model_names, model_classes):
|
||||
model = model_class.from_pretrained(file_path, torch_dtype=torch_dtype).eval()
|
||||
if torch_dtype in [torch.float32, torch.float16, torch.bfloat16]:
|
||||
model = model_class.from_pretrained(file_path, torch_dtype=torch_dtype).eval()
|
||||
else:
|
||||
model = model_class.from_pretrained(file_path).eval().to(dtype=torch_dtype)
|
||||
if torch_dtype == torch.float16 and hasattr(model, "half"):
|
||||
model = model.half()
|
||||
try:
|
||||
@@ -155,7 +158,7 @@ class ModelDetectorFromSingleFile:
|
||||
|
||||
|
||||
def match(self, file_path="", state_dict={}):
|
||||
if os.path.isdir(file_path):
|
||||
if isinstance(file_path, str) and os.path.isdir(file_path):
|
||||
return False
|
||||
if len(state_dict) == 0:
|
||||
state_dict = load_state_dict(file_path)
|
||||
@@ -197,7 +200,7 @@ class ModelDetectorFromSplitedSingleFile(ModelDetectorFromSingleFile):
|
||||
|
||||
|
||||
def match(self, file_path="", state_dict={}):
|
||||
if os.path.isdir(file_path):
|
||||
if isinstance(file_path, str) and os.path.isdir(file_path):
|
||||
return False
|
||||
if len(state_dict) == 0:
|
||||
state_dict = load_state_dict(file_path)
|
||||
@@ -240,7 +243,7 @@ class ModelDetectorFromHuggingfaceFolder:
|
||||
|
||||
|
||||
def match(self, file_path="", state_dict={}):
|
||||
if os.path.isfile(file_path):
|
||||
if not isinstance(file_path, str) or os.path.isfile(file_path):
|
||||
return False
|
||||
file_list = os.listdir(file_path)
|
||||
if "config.json" not in file_list:
|
||||
@@ -281,7 +284,7 @@ class ModelDetectorFromPatchedSingleFile:
|
||||
|
||||
|
||||
def match(self, file_path="", state_dict={}):
|
||||
if os.path.isdir(file_path):
|
||||
if not isinstance(file_path, str) or os.path.isdir(file_path):
|
||||
return False
|
||||
if len(state_dict) == 0:
|
||||
state_dict = load_state_dict(file_path)
|
||||
@@ -387,7 +390,11 @@ class ModelManager:
|
||||
print(f"Loading models from: {file_path}")
|
||||
if device is None: device = self.device
|
||||
if torch_dtype is None: torch_dtype = self.torch_dtype
|
||||
if os.path.isfile(file_path):
|
||||
if isinstance(file_path, list):
|
||||
state_dict = {}
|
||||
for path in file_path:
|
||||
state_dict.update(load_state_dict(path))
|
||||
elif os.path.isfile(file_path):
|
||||
state_dict = load_state_dict(file_path)
|
||||
else:
|
||||
state_dict = None
|
||||
|
||||
@@ -9,7 +9,8 @@ class SD3TextEncoder1(SDTextEncoder):
|
||||
super().__init__(vocab_size=vocab_size)
|
||||
|
||||
def forward(self, input_ids, clip_skip=2, extra_mask=None):
|
||||
embeds = self.token_embedding(input_ids) + self.position_embeds
|
||||
embeds = self.token_embedding(input_ids)
|
||||
embeds = embeds + self.position_embeds.to(dtype=embeds.dtype, device=input_ids.device)
|
||||
attn_mask = self.attn_mask.to(device=embeds.device, dtype=embeds.dtype)
|
||||
if extra_mask is not None:
|
||||
attn_mask[:, extra_mask[0]==0] = float("-inf")
|
||||
|
||||
940
diffsynth/models/stepvideo_dit.py
Normal file
940
diffsynth/models/stepvideo_dit.py
Normal file
@@ -0,0 +1,940 @@
|
||||
# Copyright 2025 StepFun Inc. All Rights Reserved.
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
# of this software and associated documentation files (the "Software"), to deal
|
||||
# in the Software without restriction, including without limitation the rights
|
||||
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
# copies of the Software, and to permit persons to whom the Software is
|
||||
# furnished to do so, subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be included in all
|
||||
# copies or substantial portions of the Software.
|
||||
# ==============================================================================
|
||||
from typing import Dict, Optional, Tuple
|
||||
import torch, math
|
||||
from torch import nn
|
||||
from einops import rearrange, repeat
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
elementwise_affine=True,
|
||||
eps: float = 1e-6,
|
||||
device=None,
|
||||
dtype=None,
|
||||
):
|
||||
"""
|
||||
Initialize the RMSNorm normalization layer.
|
||||
|
||||
Args:
|
||||
dim (int): The dimension of the input tensor.
|
||||
eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
|
||||
|
||||
Attributes:
|
||||
eps (float): A small value added to the denominator for numerical stability.
|
||||
weight (nn.Parameter): Learnable scaling parameter.
|
||||
|
||||
"""
|
||||
factory_kwargs = {"device": device, "dtype": dtype}
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
if elementwise_affine:
|
||||
self.weight = nn.Parameter(torch.ones(dim, **factory_kwargs))
|
||||
|
||||
def _norm(self, x):
|
||||
"""
|
||||
Apply the RMSNorm normalization to the input tensor.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): The input tensor.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The normalized tensor.
|
||||
|
||||
"""
|
||||
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Forward pass through the RMSNorm layer.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): The input tensor.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The output tensor after applying RMSNorm.
|
||||
|
||||
"""
|
||||
output = self._norm(x.float()).type_as(x)
|
||||
if hasattr(self, "weight"):
|
||||
output = output * self.weight
|
||||
return output
|
||||
|
||||
|
||||
ACTIVATION_FUNCTIONS = {
|
||||
"swish": nn.SiLU(),
|
||||
"silu": nn.SiLU(),
|
||||
"mish": nn.Mish(),
|
||||
"gelu": nn.GELU(),
|
||||
"relu": nn.ReLU(),
|
||||
}
|
||||
|
||||
|
||||
def get_activation(act_fn: str) -> nn.Module:
|
||||
"""Helper function to get activation function from string.
|
||||
|
||||
Args:
|
||||
act_fn (str): Name of activation function.
|
||||
|
||||
Returns:
|
||||
nn.Module: Activation function.
|
||||
"""
|
||||
|
||||
act_fn = act_fn.lower()
|
||||
if act_fn in ACTIVATION_FUNCTIONS:
|
||||
return ACTIVATION_FUNCTIONS[act_fn]
|
||||
else:
|
||||
raise ValueError(f"Unsupported activation function: {act_fn}")
|
||||
|
||||
|
||||
def get_timestep_embedding(
|
||||
timesteps: torch.Tensor,
|
||||
embedding_dim: int,
|
||||
flip_sin_to_cos: bool = False,
|
||||
downscale_freq_shift: float = 1,
|
||||
scale: float = 1,
|
||||
max_period: int = 10000,
|
||||
):
|
||||
"""
|
||||
This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
|
||||
|
||||
:param timesteps: a 1-D Tensor of N indices, one per batch element.
|
||||
These may be fractional.
|
||||
:param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the
|
||||
embeddings. :return: an [N x dim] Tensor of positional embeddings.
|
||||
"""
|
||||
assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
|
||||
|
||||
half_dim = embedding_dim // 2
|
||||
exponent = -math.log(max_period) * torch.arange(
|
||||
start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
|
||||
)
|
||||
exponent = exponent / (half_dim - downscale_freq_shift)
|
||||
|
||||
emb = torch.exp(exponent)
|
||||
emb = timesteps[:, None].float() * emb[None, :]
|
||||
|
||||
# scale embeddings
|
||||
emb = scale * emb
|
||||
|
||||
# concat sine and cosine embeddings
|
||||
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
|
||||
|
||||
# flip sine and cosine embeddings
|
||||
if flip_sin_to_cos:
|
||||
emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
|
||||
|
||||
# zero pad
|
||||
if embedding_dim % 2 == 1:
|
||||
emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
|
||||
return emb
|
||||
|
||||
|
||||
class Timesteps(nn.Module):
|
||||
def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float):
|
||||
super().__init__()
|
||||
self.num_channels = num_channels
|
||||
self.flip_sin_to_cos = flip_sin_to_cos
|
||||
self.downscale_freq_shift = downscale_freq_shift
|
||||
|
||||
def forward(self, timesteps):
|
||||
t_emb = get_timestep_embedding(
|
||||
timesteps,
|
||||
self.num_channels,
|
||||
flip_sin_to_cos=self.flip_sin_to_cos,
|
||||
downscale_freq_shift=self.downscale_freq_shift,
|
||||
)
|
||||
return t_emb
|
||||
|
||||
|
||||
class TimestepEmbedding(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
time_embed_dim: int,
|
||||
act_fn: str = "silu",
|
||||
out_dim: int = None,
|
||||
post_act_fn: Optional[str] = None,
|
||||
cond_proj_dim=None,
|
||||
sample_proj_bias=True
|
||||
):
|
||||
super().__init__()
|
||||
linear_cls = nn.Linear
|
||||
|
||||
self.linear_1 = linear_cls(
|
||||
in_channels,
|
||||
time_embed_dim,
|
||||
bias=sample_proj_bias,
|
||||
)
|
||||
|
||||
if cond_proj_dim is not None:
|
||||
self.cond_proj = linear_cls(
|
||||
cond_proj_dim,
|
||||
in_channels,
|
||||
bias=False,
|
||||
)
|
||||
else:
|
||||
self.cond_proj = None
|
||||
|
||||
self.act = get_activation(act_fn)
|
||||
|
||||
if out_dim is not None:
|
||||
time_embed_dim_out = out_dim
|
||||
else:
|
||||
time_embed_dim_out = time_embed_dim
|
||||
|
||||
self.linear_2 = linear_cls(
|
||||
time_embed_dim,
|
||||
time_embed_dim_out,
|
||||
bias=sample_proj_bias,
|
||||
)
|
||||
|
||||
if post_act_fn is None:
|
||||
self.post_act = None
|
||||
else:
|
||||
self.post_act = get_activation(post_act_fn)
|
||||
|
||||
def forward(self, sample, condition=None):
|
||||
if condition is not None:
|
||||
sample = sample + self.cond_proj(condition)
|
||||
sample = self.linear_1(sample)
|
||||
|
||||
if self.act is not None:
|
||||
sample = self.act(sample)
|
||||
|
||||
sample = self.linear_2(sample)
|
||||
|
||||
if self.post_act is not None:
|
||||
sample = self.post_act(sample)
|
||||
return sample
|
||||
|
||||
|
||||
class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module):
|
||||
def __init__(self, embedding_dim, size_emb_dim, use_additional_conditions: bool = False):
|
||||
super().__init__()
|
||||
|
||||
self.outdim = size_emb_dim
|
||||
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
|
||||
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
|
||||
|
||||
self.use_additional_conditions = use_additional_conditions
|
||||
if self.use_additional_conditions:
|
||||
self.additional_condition_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
|
||||
self.resolution_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim)
|
||||
self.nframe_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
|
||||
self.fps_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
|
||||
|
||||
def forward(self, timestep, resolution=None, nframe=None, fps=None):
|
||||
hidden_dtype = timestep.dtype
|
||||
|
||||
timesteps_proj = self.time_proj(timestep)
|
||||
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D)
|
||||
|
||||
if self.use_additional_conditions:
|
||||
batch_size = timestep.shape[0]
|
||||
resolution_emb = self.additional_condition_proj(resolution.flatten()).to(hidden_dtype)
|
||||
resolution_emb = self.resolution_embedder(resolution_emb).reshape(batch_size, -1)
|
||||
nframe_emb = self.additional_condition_proj(nframe.flatten()).to(hidden_dtype)
|
||||
nframe_emb = self.nframe_embedder(nframe_emb).reshape(batch_size, -1)
|
||||
conditioning = timesteps_emb + resolution_emb + nframe_emb
|
||||
|
||||
if fps is not None:
|
||||
fps_emb = self.additional_condition_proj(fps.flatten()).to(hidden_dtype)
|
||||
fps_emb = self.fps_embedder(fps_emb).reshape(batch_size, -1)
|
||||
conditioning = conditioning + fps_emb
|
||||
else:
|
||||
conditioning = timesteps_emb
|
||||
|
||||
return conditioning
|
||||
|
||||
|
||||
class AdaLayerNormSingle(nn.Module):
|
||||
r"""
|
||||
Norm layer adaptive layer norm single (adaLN-single).
|
||||
|
||||
As proposed in PixArt-Alpha (see: https://arxiv.org/abs/2310.00426; Section 2.3).
|
||||
|
||||
Parameters:
|
||||
embedding_dim (`int`): The size of each embedding vector.
|
||||
use_additional_conditions (`bool`): To use additional conditions for normalization or not.
|
||||
"""
|
||||
def __init__(self, embedding_dim: int, use_additional_conditions: bool = False, time_step_rescale=1000):
|
||||
super().__init__()
|
||||
|
||||
self.emb = PixArtAlphaCombinedTimestepSizeEmbeddings(
|
||||
embedding_dim, size_emb_dim=embedding_dim // 2, use_additional_conditions=use_additional_conditions
|
||||
)
|
||||
|
||||
self.silu = nn.SiLU()
|
||||
self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)
|
||||
|
||||
self.time_step_rescale = time_step_rescale ## timestep usually in [0, 1], we rescale it to [0,1000] for stability
|
||||
|
||||
def forward(
|
||||
self,
|
||||
timestep: torch.Tensor,
|
||||
added_cond_kwargs: Dict[str, torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
embedded_timestep = self.emb(timestep*self.time_step_rescale, **added_cond_kwargs)
|
||||
|
||||
out = self.linear(self.silu(embedded_timestep))
|
||||
|
||||
return out, embedded_timestep
|
||||
|
||||
|
||||
class PixArtAlphaTextProjection(nn.Module):
|
||||
"""
|
||||
Projects caption embeddings. Also handles dropout for classifier-free guidance.
|
||||
|
||||
Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
|
||||
"""
|
||||
|
||||
def __init__(self, in_features, hidden_size):
|
||||
super().__init__()
|
||||
self.linear_1 = nn.Linear(
|
||||
in_features,
|
||||
hidden_size,
|
||||
bias=True,
|
||||
)
|
||||
self.act_1 = nn.GELU(approximate="tanh")
|
||||
self.linear_2 = nn.Linear(
|
||||
hidden_size,
|
||||
hidden_size,
|
||||
bias=True,
|
||||
)
|
||||
|
||||
def forward(self, caption):
|
||||
hidden_states = self.linear_1(caption)
|
||||
hidden_states = self.act_1(hidden_states)
|
||||
hidden_states = self.linear_2(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def attn_processor(self, attn_type):
|
||||
if attn_type == 'torch':
|
||||
return self.torch_attn_func
|
||||
elif attn_type == 'parallel':
|
||||
return self.parallel_attn_func
|
||||
else:
|
||||
raise Exception('Not supported attention type...')
|
||||
|
||||
def torch_attn_func(
|
||||
self,
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
attn_mask=None,
|
||||
causal=False,
|
||||
drop_rate=0.0,
|
||||
**kwargs
|
||||
):
|
||||
|
||||
if attn_mask is not None and attn_mask.dtype != torch.bool:
|
||||
attn_mask = attn_mask.to(q.dtype)
|
||||
|
||||
if attn_mask is not None and attn_mask.ndim == 3: ## no head
|
||||
n_heads = q.shape[2]
|
||||
attn_mask = attn_mask.unsqueeze(1).repeat(1, n_heads, 1, 1)
|
||||
|
||||
q, k, v = map(lambda x: rearrange(x, 'b s h d -> b h s d'), (q, k, v))
|
||||
if attn_mask is not None:
|
||||
attn_mask = attn_mask.to(q.device)
|
||||
x = torch.nn.functional.scaled_dot_product_attention(
|
||||
q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal
|
||||
)
|
||||
x = rearrange(x, 'b h s d -> b s h d')
|
||||
return x
|
||||
|
||||
|
||||
class RoPE1D:
|
||||
def __init__(self, freq=1e4, F0=1.0, scaling_factor=1.0):
|
||||
self.base = freq
|
||||
self.F0 = F0
|
||||
self.scaling_factor = scaling_factor
|
||||
self.cache = {}
|
||||
|
||||
def get_cos_sin(self, D, seq_len, device, dtype):
|
||||
if (D, seq_len, device, dtype) not in self.cache:
|
||||
inv_freq = 1.0 / (self.base ** (torch.arange(0, D, 2).float().to(device) / D))
|
||||
t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype)
|
||||
freqs = torch.einsum("i,j->ij", t, inv_freq).to(dtype)
|
||||
freqs = torch.cat((freqs, freqs), dim=-1)
|
||||
cos = freqs.cos() # (Seq, Dim)
|
||||
sin = freqs.sin()
|
||||
self.cache[D, seq_len, device, dtype] = (cos, sin)
|
||||
return self.cache[D, seq_len, device, dtype]
|
||||
|
||||
@staticmethod
|
||||
def rotate_half(x):
|
||||
x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
def apply_rope1d(self, tokens, pos1d, cos, sin):
|
||||
assert pos1d.ndim == 2
|
||||
cos = torch.nn.functional.embedding(pos1d, cos)[:, :, None, :]
|
||||
sin = torch.nn.functional.embedding(pos1d, sin)[:, :, None, :]
|
||||
return (tokens * cos) + (self.rotate_half(tokens) * sin)
|
||||
|
||||
def __call__(self, tokens, positions):
|
||||
"""
|
||||
input:
|
||||
* tokens: batch_size x ntokens x nheads x dim
|
||||
* positions: batch_size x ntokens (t position of each token)
|
||||
output:
|
||||
* tokens after appplying RoPE2D (batch_size x ntokens x nheads x dim)
|
||||
"""
|
||||
D = tokens.size(3)
|
||||
assert positions.ndim == 2 # Batch, Seq
|
||||
cos, sin = self.get_cos_sin(D, int(positions.max()) + 1, tokens.device, tokens.dtype)
|
||||
tokens = self.apply_rope1d(tokens, positions, cos, sin)
|
||||
return tokens
|
||||
|
||||
|
||||
class RoPE3D(RoPE1D):
|
||||
def __init__(self, freq=1e4, F0=1.0, scaling_factor=1.0):
|
||||
super(RoPE3D, self).__init__(freq, F0, scaling_factor)
|
||||
self.position_cache = {}
|
||||
|
||||
def get_mesh_3d(self, rope_positions, bsz):
|
||||
f, h, w = rope_positions
|
||||
|
||||
if f"{f}-{h}-{w}" not in self.position_cache:
|
||||
x = torch.arange(f, device='cpu')
|
||||
y = torch.arange(h, device='cpu')
|
||||
z = torch.arange(w, device='cpu')
|
||||
self.position_cache[f"{f}-{h}-{w}"] = torch.cartesian_prod(x, y, z).view(1, f*h*w, 3).expand(bsz, -1, 3)
|
||||
return self.position_cache[f"{f}-{h}-{w}"]
|
||||
|
||||
def __call__(self, tokens, rope_positions, ch_split, parallel=False):
|
||||
"""
|
||||
input:
|
||||
* tokens: batch_size x ntokens x nheads x dim
|
||||
* rope_positions: list of (f, h, w)
|
||||
output:
|
||||
* tokens after appplying RoPE2D (batch_size x ntokens x nheads x dim)
|
||||
"""
|
||||
assert sum(ch_split) == tokens.size(-1);
|
||||
|
||||
mesh_grid = self.get_mesh_3d(rope_positions, bsz=tokens.shape[0])
|
||||
out = []
|
||||
for i, (D, x) in enumerate(zip(ch_split, torch.split(tokens, ch_split, dim=-1))):
|
||||
cos, sin = self.get_cos_sin(D, int(mesh_grid.max()) + 1, tokens.device, tokens.dtype)
|
||||
|
||||
if parallel:
|
||||
pass
|
||||
else:
|
||||
mesh = mesh_grid[:, :, i].clone()
|
||||
x = self.apply_rope1d(x, mesh.to(tokens.device), cos, sin)
|
||||
out.append(x)
|
||||
|
||||
tokens = torch.cat(out, dim=-1)
|
||||
return tokens
|
||||
|
||||
|
||||
class SelfAttention(Attention):
|
||||
def __init__(self, hidden_dim, head_dim, bias=False, with_rope=True, with_qk_norm=True, attn_type='torch'):
|
||||
super().__init__()
|
||||
self.head_dim = head_dim
|
||||
self.n_heads = hidden_dim // head_dim
|
||||
|
||||
self.wqkv = nn.Linear(hidden_dim, hidden_dim*3, bias=bias)
|
||||
self.wo = nn.Linear(hidden_dim, hidden_dim, bias=bias)
|
||||
|
||||
self.with_rope = with_rope
|
||||
self.with_qk_norm = with_qk_norm
|
||||
if self.with_qk_norm:
|
||||
self.q_norm = RMSNorm(head_dim, elementwise_affine=True)
|
||||
self.k_norm = RMSNorm(head_dim, elementwise_affine=True)
|
||||
|
||||
if self.with_rope:
|
||||
self.rope_3d = RoPE3D(freq=1e4, F0=1.0, scaling_factor=1.0)
|
||||
self.rope_ch_split = [64, 32, 32]
|
||||
|
||||
self.core_attention = self.attn_processor(attn_type=attn_type)
|
||||
self.parallel = attn_type=='parallel'
|
||||
|
||||
def apply_rope3d(self, x, fhw_positions, rope_ch_split, parallel=True):
|
||||
x = self.rope_3d(x, fhw_positions, rope_ch_split, parallel)
|
||||
return x
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x,
|
||||
cu_seqlens=None,
|
||||
max_seqlen=None,
|
||||
rope_positions=None,
|
||||
attn_mask=None
|
||||
):
|
||||
xqkv = self.wqkv(x)
|
||||
xqkv = xqkv.view(*x.shape[:-1], self.n_heads, 3*self.head_dim)
|
||||
|
||||
xq, xk, xv = torch.split(xqkv, [self.head_dim]*3, dim=-1) ## seq_len, n, dim
|
||||
|
||||
if self.with_qk_norm:
|
||||
xq = self.q_norm(xq)
|
||||
xk = self.k_norm(xk)
|
||||
|
||||
if self.with_rope:
|
||||
xq = self.apply_rope3d(xq, rope_positions, self.rope_ch_split, parallel=self.parallel)
|
||||
xk = self.apply_rope3d(xk, rope_positions, self.rope_ch_split, parallel=self.parallel)
|
||||
|
||||
output = self.core_attention(
|
||||
xq,
|
||||
xk,
|
||||
xv,
|
||||
cu_seqlens=cu_seqlens,
|
||||
max_seqlen=max_seqlen,
|
||||
attn_mask=attn_mask
|
||||
)
|
||||
output = rearrange(output, 'b s h d -> b s (h d)')
|
||||
output = self.wo(output)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class CrossAttention(Attention):
|
||||
def __init__(self, hidden_dim, head_dim, bias=False, with_qk_norm=True, attn_type='torch'):
|
||||
super().__init__()
|
||||
self.head_dim = head_dim
|
||||
self.n_heads = hidden_dim // head_dim
|
||||
|
||||
self.wq = nn.Linear(hidden_dim, hidden_dim, bias=bias)
|
||||
self.wkv = nn.Linear(hidden_dim, hidden_dim*2, bias=bias)
|
||||
self.wo = nn.Linear(hidden_dim, hidden_dim, bias=bias)
|
||||
|
||||
self.with_qk_norm = with_qk_norm
|
||||
if self.with_qk_norm:
|
||||
self.q_norm = RMSNorm(head_dim, elementwise_affine=True)
|
||||
self.k_norm = RMSNorm(head_dim, elementwise_affine=True)
|
||||
|
||||
self.core_attention = self.attn_processor(attn_type=attn_type)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
attn_mask=None
|
||||
):
|
||||
xq = self.wq(x)
|
||||
xq = xq.view(*xq.shape[:-1], self.n_heads, self.head_dim)
|
||||
|
||||
xkv = self.wkv(encoder_hidden_states)
|
||||
xkv = xkv.view(*xkv.shape[:-1], self.n_heads, 2*self.head_dim)
|
||||
|
||||
xk, xv = torch.split(xkv, [self.head_dim]*2, dim=-1) ## seq_len, n, dim
|
||||
|
||||
if self.with_qk_norm:
|
||||
xq = self.q_norm(xq)
|
||||
xk = self.k_norm(xk)
|
||||
|
||||
output = self.core_attention(
|
||||
xq,
|
||||
xk,
|
||||
xv,
|
||||
attn_mask=attn_mask
|
||||
)
|
||||
|
||||
output = rearrange(output, 'b s h d -> b s (h d)')
|
||||
output = self.wo(output)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class GELU(nn.Module):
|
||||
r"""
|
||||
GELU activation function with tanh approximation support with `approximate="tanh"`.
|
||||
|
||||
Parameters:
|
||||
dim_in (`int`): The number of channels in the input.
|
||||
dim_out (`int`): The number of channels in the output.
|
||||
approximate (`str`, *optional*, defaults to `"none"`): If `"tanh"`, use tanh approximation.
|
||||
bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
|
||||
"""
|
||||
|
||||
def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True):
|
||||
super().__init__()
|
||||
self.proj = nn.Linear(dim_in, dim_out, bias=bias)
|
||||
self.approximate = approximate
|
||||
|
||||
def gelu(self, gate: torch.Tensor) -> torch.Tensor:
|
||||
return torch.nn.functional.gelu(gate, approximate=self.approximate)
|
||||
|
||||
def forward(self, hidden_states):
|
||||
hidden_states = self.proj(hidden_states)
|
||||
hidden_states = self.gelu(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
inner_dim: Optional[int] = None,
|
||||
dim_out: Optional[int] = None,
|
||||
mult: int = 4,
|
||||
bias: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
inner_dim = dim*mult if inner_dim is None else inner_dim
|
||||
dim_out = dim if dim_out is None else dim_out
|
||||
self.net = nn.ModuleList([
|
||||
GELU(dim, inner_dim, approximate="tanh", bias=bias),
|
||||
nn.Identity(),
|
||||
nn.Linear(inner_dim, dim_out, bias=bias)
|
||||
])
|
||||
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
|
||||
for module in self.net:
|
||||
hidden_states = module(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
def modulate(x, scale, shift):
|
||||
x = x * (1 + scale) + shift
|
||||
return x
|
||||
|
||||
|
||||
def gate(x, gate):
|
||||
x = gate * x
|
||||
return x
|
||||
|
||||
|
||||
class StepVideoTransformerBlock(nn.Module):
|
||||
r"""
|
||||
A basic Transformer block.
|
||||
|
||||
Parameters:
|
||||
dim (`int`): The number of channels in the input and output.
|
||||
num_attention_heads (`int`): The number of heads to use for multi-head attention.
|
||||
attention_head_dim (`int`): The number of channels in each head.
|
||||
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
||||
cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
|
||||
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
|
||||
num_embeds_ada_norm (:
|
||||
obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
|
||||
attention_bias (:
|
||||
obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
|
||||
only_cross_attention (`bool`, *optional*):
|
||||
Whether to use only cross-attention layers. In this case two cross attention layers are used.
|
||||
double_self_attention (`bool`, *optional*):
|
||||
Whether to use two self-attention layers. In this case no cross attention layers are used.
|
||||
upcast_attention (`bool`, *optional*):
|
||||
Whether to upcast the attention computation to float32. This is useful for mixed precision training.
|
||||
norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
|
||||
Whether to use learnable elementwise affine parameters for normalization.
|
||||
norm_type (`str`, *optional*, defaults to `"layer_norm"`):
|
||||
The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
|
||||
final_dropout (`bool` *optional*, defaults to False):
|
||||
Whether to apply a final dropout after the last feed-forward layer.
|
||||
attention_type (`str`, *optional*, defaults to `"default"`):
|
||||
The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
|
||||
positional_embeddings (`str`, *optional*, defaults to `None`):
|
||||
The type of positional embeddings to apply to.
|
||||
num_positional_embeddings (`int`, *optional*, defaults to `None`):
|
||||
The maximum number of positional embeddings to apply.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
attention_head_dim: int,
|
||||
norm_eps: float = 1e-5,
|
||||
ff_inner_dim: Optional[int] = None,
|
||||
ff_bias: bool = False,
|
||||
attention_type: str = 'parallel'
|
||||
):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.norm1 = nn.LayerNorm(dim, eps=norm_eps)
|
||||
self.attn1 = SelfAttention(dim, attention_head_dim, bias=False, with_rope=True, with_qk_norm=True, attn_type=attention_type)
|
||||
|
||||
self.norm2 = nn.LayerNorm(dim, eps=norm_eps)
|
||||
self.attn2 = CrossAttention(dim, attention_head_dim, bias=False, with_qk_norm=True, attn_type='torch')
|
||||
|
||||
self.ff = FeedForward(dim=dim, inner_dim=ff_inner_dim, dim_out=dim, bias=ff_bias)
|
||||
|
||||
self.scale_shift_table = nn.Parameter(torch.randn(6, dim) /dim**0.5)
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(
|
||||
self,
|
||||
q: torch.Tensor,
|
||||
kv: Optional[torch.Tensor] = None,
|
||||
timestep: Optional[torch.LongTensor] = None,
|
||||
attn_mask = None,
|
||||
rope_positions: list = None,
|
||||
) -> torch.Tensor:
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
|
||||
torch.clone(chunk) for chunk in (self.scale_shift_table[None].to(dtype=q.dtype, device=q.device) + timestep.reshape(-1, 6, self.dim)).chunk(6, dim=1)
|
||||
)
|
||||
|
||||
scale_shift_q = modulate(self.norm1(q), scale_msa, shift_msa)
|
||||
|
||||
attn_q = self.attn1(
|
||||
scale_shift_q,
|
||||
rope_positions=rope_positions
|
||||
)
|
||||
|
||||
q = gate(attn_q, gate_msa) + q
|
||||
|
||||
attn_q = self.attn2(
|
||||
q,
|
||||
kv,
|
||||
attn_mask
|
||||
)
|
||||
|
||||
q = attn_q + q
|
||||
|
||||
scale_shift_q = modulate(self.norm2(q), scale_mlp, shift_mlp)
|
||||
|
||||
ff_output = self.ff(scale_shift_q)
|
||||
|
||||
q = gate(ff_output, gate_mlp) + q
|
||||
|
||||
return q
|
||||
|
||||
|
||||
class PatchEmbed(nn.Module):
|
||||
"""2D Image to Patch Embedding"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
patch_size=64,
|
||||
in_channels=3,
|
||||
embed_dim=768,
|
||||
layer_norm=False,
|
||||
flatten=True,
|
||||
bias=True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.flatten = flatten
|
||||
self.layer_norm = layer_norm
|
||||
|
||||
self.proj = nn.Conv2d(
|
||||
in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
|
||||
)
|
||||
|
||||
def forward(self, latent):
|
||||
latent = self.proj(latent).to(latent.dtype)
|
||||
if self.flatten:
|
||||
latent = latent.flatten(2).transpose(1, 2) # BCHW -> BNC
|
||||
if self.layer_norm:
|
||||
latent = self.norm(latent)
|
||||
|
||||
return latent
|
||||
|
||||
|
||||
class StepVideoModel(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
num_attention_heads: int = 48,
|
||||
attention_head_dim: int = 128,
|
||||
in_channels: int = 64,
|
||||
out_channels: Optional[int] = 64,
|
||||
num_layers: int = 48,
|
||||
dropout: float = 0.0,
|
||||
patch_size: int = 1,
|
||||
norm_type: str = "ada_norm_single",
|
||||
norm_elementwise_affine: bool = False,
|
||||
norm_eps: float = 1e-6,
|
||||
use_additional_conditions: Optional[bool] = False,
|
||||
caption_channels: Optional[int]|list|tuple = [6144, 1024],
|
||||
attention_type: Optional[str] = "torch",
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# Set some common variables used across the board.
|
||||
self.inner_dim = num_attention_heads * attention_head_dim
|
||||
self.out_channels = in_channels if out_channels is None else out_channels
|
||||
|
||||
self.use_additional_conditions = use_additional_conditions
|
||||
|
||||
self.pos_embed = PatchEmbed(
|
||||
patch_size=patch_size,
|
||||
in_channels=in_channels,
|
||||
embed_dim=self.inner_dim,
|
||||
)
|
||||
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
StepVideoTransformerBlock(
|
||||
dim=self.inner_dim,
|
||||
attention_head_dim=attention_head_dim,
|
||||
attention_type=attention_type
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
# 3. Output blocks.
|
||||
self.norm_out = nn.LayerNorm(self.inner_dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
|
||||
self.scale_shift_table = nn.Parameter(torch.randn(2, self.inner_dim) / self.inner_dim**0.5)
|
||||
self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels)
|
||||
self.patch_size = patch_size
|
||||
|
||||
self.adaln_single = AdaLayerNormSingle(
|
||||
self.inner_dim, use_additional_conditions=self.use_additional_conditions
|
||||
)
|
||||
|
||||
if isinstance(caption_channels, int):
|
||||
caption_channel = caption_channels
|
||||
else:
|
||||
caption_channel, clip_channel = caption_channels
|
||||
self.clip_projection = nn.Linear(clip_channel, self.inner_dim)
|
||||
|
||||
self.caption_norm = nn.LayerNorm(caption_channel, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
|
||||
|
||||
self.caption_projection = PixArtAlphaTextProjection(
|
||||
in_features=caption_channel, hidden_size=self.inner_dim
|
||||
)
|
||||
|
||||
self.parallel = attention_type=='parallel'
|
||||
|
||||
def patchfy(self, hidden_states):
|
||||
hidden_states = rearrange(hidden_states, 'b f c h w -> (b f) c h w')
|
||||
hidden_states = self.pos_embed(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
def prepare_attn_mask(self, encoder_attention_mask, encoder_hidden_states, q_seqlen):
|
||||
kv_seqlens = encoder_attention_mask.sum(dim=1).int()
|
||||
mask = torch.zeros([len(kv_seqlens), q_seqlen, max(kv_seqlens)], dtype=torch.bool, device=encoder_attention_mask.device)
|
||||
encoder_hidden_states = encoder_hidden_states[:,: max(kv_seqlens)]
|
||||
for i, kv_len in enumerate(kv_seqlens):
|
||||
mask[i, :, :kv_len] = 1
|
||||
return encoder_hidden_states, mask
|
||||
|
||||
|
||||
def block_forward(
|
||||
self,
|
||||
hidden_states,
|
||||
encoder_hidden_states=None,
|
||||
timestep=None,
|
||||
rope_positions=None,
|
||||
attn_mask=None,
|
||||
parallel=True
|
||||
):
|
||||
for block in tqdm(self.transformer_blocks, desc="Transformer blocks"):
|
||||
hidden_states = block(
|
||||
hidden_states,
|
||||
encoder_hidden_states,
|
||||
timestep=timestep,
|
||||
attn_mask=attn_mask,
|
||||
rope_positions=rope_positions
|
||||
)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
encoder_hidden_states_2: Optional[torch.Tensor] = None,
|
||||
timestep: Optional[torch.LongTensor] = None,
|
||||
added_cond_kwargs: Dict[str, torch.Tensor] = None,
|
||||
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||
fps: torch.Tensor=None,
|
||||
return_dict: bool = False,
|
||||
):
|
||||
assert hidden_states.ndim==5; "hidden_states's shape should be (bsz, f, ch, h ,w)"
|
||||
|
||||
bsz, frame, _, height, width = hidden_states.shape
|
||||
height, width = height // self.patch_size, width // self.patch_size
|
||||
|
||||
hidden_states = self.patchfy(hidden_states)
|
||||
len_frame = hidden_states.shape[1]
|
||||
|
||||
if self.use_additional_conditions:
|
||||
added_cond_kwargs = {
|
||||
"resolution": torch.tensor([(height, width)]*bsz, device=hidden_states.device, dtype=hidden_states.dtype),
|
||||
"nframe": torch.tensor([frame]*bsz, device=hidden_states.device, dtype=hidden_states.dtype),
|
||||
"fps": fps
|
||||
}
|
||||
else:
|
||||
added_cond_kwargs = {}
|
||||
|
||||
timestep, embedded_timestep = self.adaln_single(
|
||||
timestep, added_cond_kwargs=added_cond_kwargs
|
||||
)
|
||||
|
||||
encoder_hidden_states = self.caption_projection(self.caption_norm(encoder_hidden_states))
|
||||
|
||||
if encoder_hidden_states_2 is not None and hasattr(self, 'clip_projection'):
|
||||
clip_embedding = self.clip_projection(encoder_hidden_states_2)
|
||||
encoder_hidden_states = torch.cat([clip_embedding, encoder_hidden_states], dim=1)
|
||||
|
||||
hidden_states = rearrange(hidden_states, '(b f) l d-> b (f l) d', b=bsz, f=frame, l=len_frame).contiguous()
|
||||
encoder_hidden_states, attn_mask = self.prepare_attn_mask(encoder_attention_mask, encoder_hidden_states, q_seqlen=frame*len_frame)
|
||||
|
||||
hidden_states = self.block_forward(
|
||||
hidden_states,
|
||||
encoder_hidden_states,
|
||||
timestep=timestep,
|
||||
rope_positions=[frame, height, width],
|
||||
attn_mask=attn_mask,
|
||||
parallel=self.parallel
|
||||
)
|
||||
|
||||
hidden_states = rearrange(hidden_states, 'b (f l) d -> (b f) l d', b=bsz, f=frame, l=len_frame)
|
||||
|
||||
embedded_timestep = repeat(embedded_timestep, 'b d -> (b f) d', f=frame).contiguous()
|
||||
|
||||
shift, scale = (self.scale_shift_table[None].to(dtype=embedded_timestep.dtype, device=embedded_timestep.device) + embedded_timestep[:, None]).chunk(2, dim=1)
|
||||
hidden_states = self.norm_out(hidden_states)
|
||||
# Modulation
|
||||
hidden_states = hidden_states * (1 + scale) + shift
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
|
||||
# unpatchify
|
||||
hidden_states = hidden_states.reshape(
|
||||
shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
|
||||
)
|
||||
|
||||
hidden_states = rearrange(hidden_states, 'n h w p q c -> n c h p w q')
|
||||
output = hidden_states.reshape(
|
||||
shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
|
||||
)
|
||||
|
||||
output = rearrange(output, '(b f) c h w -> b f c h w', f=frame)
|
||||
|
||||
if return_dict:
|
||||
return {'x': output}
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return StepVideoDiTStateDictConverter()
|
||||
|
||||
|
||||
class StepVideoDiTStateDictConverter:
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def from_diffusers(self, state_dict):
|
||||
return state_dict
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
return state_dict
|
||||
|
||||
|
||||
|
||||
553
diffsynth/models/stepvideo_text_encoder.py
Normal file
553
diffsynth/models/stepvideo_text_encoder.py
Normal file
@@ -0,0 +1,553 @@
|
||||
# Copyright 2025 StepFun Inc. All Rights Reserved.
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
# of this software and associated documentation files (the "Software"), to deal
|
||||
# in the Software without restriction, including without limitation the rights
|
||||
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
# copies of the Software, and to permit persons to whom the Software is
|
||||
# furnished to do so, subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be included in all
|
||||
# copies or substantial portions of the Software.
|
||||
# ==============================================================================
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from .stepvideo_dit import RMSNorm
|
||||
from safetensors.torch import load_file
|
||||
from transformers.modeling_utils import PretrainedConfig, PreTrainedModel
|
||||
from einops import rearrange
|
||||
import json
|
||||
from typing import List
|
||||
from functools import wraps
|
||||
import warnings
|
||||
|
||||
|
||||
|
||||
class EmptyInitOnDevice(torch.overrides.TorchFunctionMode):
|
||||
def __init__(self, device=None):
|
||||
self.device = device
|
||||
|
||||
def __torch_function__(self, func, types, args=(), kwargs=None):
|
||||
kwargs = kwargs or {}
|
||||
if getattr(func, '__module__', None) == 'torch.nn.init':
|
||||
if 'tensor' in kwargs:
|
||||
return kwargs['tensor']
|
||||
else:
|
||||
return args[0]
|
||||
if self.device is not None and func in torch.utils._device._device_constructors() and kwargs.get('device') is None:
|
||||
kwargs['device'] = self.device
|
||||
return func(*args, **kwargs)
|
||||
|
||||
|
||||
def with_empty_init(func):
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
with EmptyInitOnDevice('cpu'):
|
||||
return func(*args, **kwargs)
|
||||
return wrapper
|
||||
|
||||
|
||||
|
||||
class LLaMaEmbedding(nn.Module):
|
||||
"""Language model embeddings.
|
||||
|
||||
Arguments:
|
||||
hidden_size: hidden size
|
||||
vocab_size: vocabulary size
|
||||
max_sequence_length: maximum size of sequence. This
|
||||
is used for positional embedding
|
||||
embedding_dropout_prob: dropout probability for embeddings
|
||||
init_method: weight initialization method
|
||||
num_tokentypes: size of the token-type embeddings. 0 value
|
||||
will ignore this embedding
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
cfg,
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_size = cfg.hidden_size
|
||||
self.params_dtype = cfg.params_dtype
|
||||
self.fp32_residual_connection = cfg.fp32_residual_connection
|
||||
self.embedding_weights_in_fp32 = cfg.embedding_weights_in_fp32
|
||||
self.word_embeddings = torch.nn.Embedding(
|
||||
cfg.padded_vocab_size, self.hidden_size,
|
||||
)
|
||||
self.embedding_dropout = torch.nn.Dropout(cfg.hidden_dropout)
|
||||
|
||||
def forward(self, input_ids):
|
||||
# Embeddings.
|
||||
if self.embedding_weights_in_fp32:
|
||||
self.word_embeddings = self.word_embeddings.to(torch.float32)
|
||||
embeddings = self.word_embeddings(input_ids)
|
||||
if self.embedding_weights_in_fp32:
|
||||
embeddings = embeddings.to(self.params_dtype)
|
||||
self.word_embeddings = self.word_embeddings.to(self.params_dtype)
|
||||
|
||||
# Data format change to avoid explicit tranposes : [b s h] --> [s b h].
|
||||
embeddings = embeddings.transpose(0, 1).contiguous()
|
||||
|
||||
# If the input flag for fp32 residual connection is set, convert for float.
|
||||
if self.fp32_residual_connection:
|
||||
embeddings = embeddings.float()
|
||||
|
||||
# Dropout.
|
||||
embeddings = self.embedding_dropout(embeddings)
|
||||
|
||||
return embeddings
|
||||
|
||||
|
||||
|
||||
class StepChatTokenizer:
|
||||
"""Step Chat Tokenizer"""
|
||||
|
||||
def __init__(
|
||||
self, model_file, name="StepChatTokenizer",
|
||||
bot_token="<|BOT|>", # Begin of Turn
|
||||
eot_token="<|EOT|>", # End of Turn
|
||||
call_start_token="<|CALL_START|>", # Call Start
|
||||
call_end_token="<|CALL_END|>", # Call End
|
||||
think_start_token="<|THINK_START|>", # Think Start
|
||||
think_end_token="<|THINK_END|>", # Think End
|
||||
mask_start_token="<|MASK_1e69f|>", # Mask start
|
||||
mask_end_token="<|UNMASK_1e69f|>", # Mask end
|
||||
):
|
||||
import sentencepiece
|
||||
|
||||
self._tokenizer = sentencepiece.SentencePieceProcessor(model_file=model_file)
|
||||
|
||||
self._vocab = {}
|
||||
self._inv_vocab = {}
|
||||
|
||||
self._special_tokens = {}
|
||||
self._inv_special_tokens = {}
|
||||
|
||||
self._t5_tokens = []
|
||||
|
||||
for idx in range(self._tokenizer.get_piece_size()):
|
||||
text = self._tokenizer.id_to_piece(idx)
|
||||
self._inv_vocab[idx] = text
|
||||
self._vocab[text] = idx
|
||||
|
||||
if self._tokenizer.is_control(idx) or self._tokenizer.is_unknown(idx):
|
||||
self._special_tokens[text] = idx
|
||||
self._inv_special_tokens[idx] = text
|
||||
|
||||
self._unk_id = self._tokenizer.unk_id()
|
||||
self._bos_id = self._tokenizer.bos_id()
|
||||
self._eos_id = self._tokenizer.eos_id()
|
||||
|
||||
for token in [
|
||||
bot_token, eot_token, call_start_token, call_end_token,
|
||||
think_start_token, think_end_token
|
||||
]:
|
||||
assert token in self._vocab, f"Token '{token}' not found in tokenizer"
|
||||
assert token in self._special_tokens, f"Token '{token}' is not a special token"
|
||||
|
||||
for token in [mask_start_token, mask_end_token]:
|
||||
assert token in self._vocab, f"Token '{token}' not found in tokenizer"
|
||||
|
||||
self._bot_id = self._tokenizer.piece_to_id(bot_token)
|
||||
self._eot_id = self._tokenizer.piece_to_id(eot_token)
|
||||
self._call_start_id = self._tokenizer.piece_to_id(call_start_token)
|
||||
self._call_end_id = self._tokenizer.piece_to_id(call_end_token)
|
||||
self._think_start_id = self._tokenizer.piece_to_id(think_start_token)
|
||||
self._think_end_id = self._tokenizer.piece_to_id(think_end_token)
|
||||
self._mask_start_id = self._tokenizer.piece_to_id(mask_start_token)
|
||||
self._mask_end_id = self._tokenizer.piece_to_id(mask_end_token)
|
||||
|
||||
self._underline_id = self._tokenizer.piece_to_id("\u2581")
|
||||
|
||||
@property
|
||||
def vocab(self):
|
||||
return self._vocab
|
||||
|
||||
@property
|
||||
def inv_vocab(self):
|
||||
return self._inv_vocab
|
||||
|
||||
@property
|
||||
def vocab_size(self):
|
||||
return self._tokenizer.vocab_size()
|
||||
|
||||
def tokenize(self, text: str) -> List[int]:
|
||||
return self._tokenizer.encode_as_ids(text)
|
||||
|
||||
def detokenize(self, token_ids: List[int]) -> str:
|
||||
return self._tokenizer.decode_ids(token_ids)
|
||||
|
||||
|
||||
class Tokens:
|
||||
def __init__(self, input_ids, cu_input_ids, attention_mask, cu_seqlens, max_seq_len) -> None:
|
||||
self.input_ids = input_ids
|
||||
self.attention_mask = attention_mask
|
||||
self.cu_input_ids = cu_input_ids
|
||||
self.cu_seqlens = cu_seqlens
|
||||
self.max_seq_len = max_seq_len
|
||||
def to(self, device):
|
||||
self.input_ids = self.input_ids.to(device)
|
||||
self.attention_mask = self.attention_mask.to(device)
|
||||
self.cu_input_ids = self.cu_input_ids.to(device)
|
||||
self.cu_seqlens = self.cu_seqlens.to(device)
|
||||
return self
|
||||
|
||||
class Wrapped_StepChatTokenizer(StepChatTokenizer):
|
||||
def __call__(self, text, max_length=320, padding="max_length", truncation=True, return_tensors="pt"):
|
||||
# [bos, ..., eos, pad, pad, ..., pad]
|
||||
self.BOS = 1
|
||||
self.EOS = 2
|
||||
self.PAD = 2
|
||||
out_tokens = []
|
||||
attn_mask = []
|
||||
if len(text) == 0:
|
||||
part_tokens = [self.BOS] + [self.EOS]
|
||||
valid_size = len(part_tokens)
|
||||
if len(part_tokens) < max_length:
|
||||
part_tokens += [self.PAD] * (max_length - valid_size)
|
||||
out_tokens.append(part_tokens)
|
||||
attn_mask.append([1]*valid_size+[0]*(max_length-valid_size))
|
||||
else:
|
||||
for part in text:
|
||||
part_tokens = self.tokenize(part)
|
||||
part_tokens = part_tokens[:(max_length - 2)] # leave 2 space for bos and eos
|
||||
part_tokens = [self.BOS] + part_tokens + [self.EOS]
|
||||
valid_size = len(part_tokens)
|
||||
if len(part_tokens) < max_length:
|
||||
part_tokens += [self.PAD] * (max_length - valid_size)
|
||||
out_tokens.append(part_tokens)
|
||||
attn_mask.append([1]*valid_size+[0]*(max_length-valid_size))
|
||||
|
||||
out_tokens = torch.tensor(out_tokens, dtype=torch.long)
|
||||
attn_mask = torch.tensor(attn_mask, dtype=torch.long)
|
||||
|
||||
# padding y based on tp size
|
||||
padded_len = 0
|
||||
padded_flag = True if padded_len > 0 else False
|
||||
if padded_flag:
|
||||
pad_tokens = torch.tensor([[self.PAD] * max_length], device=out_tokens.device)
|
||||
pad_attn_mask = torch.tensor([[1]*padded_len+[0]*(max_length-padded_len)], device=attn_mask.device)
|
||||
out_tokens = torch.cat([out_tokens, pad_tokens], dim=0)
|
||||
attn_mask = torch.cat([attn_mask, pad_attn_mask], dim=0)
|
||||
|
||||
# cu_seqlens
|
||||
cu_out_tokens = out_tokens.masked_select(attn_mask != 0).unsqueeze(0)
|
||||
seqlen = attn_mask.sum(dim=1).tolist()
|
||||
cu_seqlens = torch.cumsum(torch.tensor([0]+seqlen), 0).to(device=out_tokens.device,dtype=torch.int32)
|
||||
max_seq_len = max(seqlen)
|
||||
return Tokens(out_tokens, cu_out_tokens, attn_mask, cu_seqlens, max_seq_len)
|
||||
|
||||
|
||||
|
||||
def flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=True,
|
||||
return_attn_probs=False, tp_group_rank=0, tp_group_size=1):
|
||||
softmax_scale = q.size(-1) ** (-0.5) if softmax_scale is None else softmax_scale
|
||||
if hasattr(torch.ops.Optimus, "fwd"):
|
||||
results = torch.ops.Optimus.fwd(q, k, v, None, dropout_p, softmax_scale, causal, return_attn_probs, None, tp_group_rank, tp_group_size)[0]
|
||||
else:
|
||||
warnings.warn("Cannot load `torch.ops.Optimus.fwd`. Using `torch.nn.functional.scaled_dot_product_attention` instead.")
|
||||
results = torch.nn.functional.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=True, scale=softmax_scale).transpose(1, 2)
|
||||
return results
|
||||
|
||||
|
||||
class FlashSelfAttention(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
attention_dropout=0.0,
|
||||
):
|
||||
super().__init__()
|
||||
self.dropout_p = attention_dropout
|
||||
|
||||
|
||||
def forward(self, q, k, v, cu_seqlens=None, max_seq_len=None):
|
||||
if cu_seqlens is None:
|
||||
output = flash_attn_func(q, k, v, dropout_p=self.dropout_p)
|
||||
else:
|
||||
raise ValueError('cu_seqlens is not supported!')
|
||||
|
||||
return output
|
||||
|
||||
|
||||
|
||||
def safediv(n, d):
|
||||
q, r = divmod(n, d)
|
||||
assert r == 0
|
||||
return q
|
||||
|
||||
|
||||
class MultiQueryAttention(nn.Module):
|
||||
def __init__(self, cfg, layer_id=None):
|
||||
super().__init__()
|
||||
|
||||
self.head_dim = cfg.hidden_size // cfg.num_attention_heads
|
||||
self.max_seq_len = cfg.seq_length
|
||||
self.use_flash_attention = cfg.use_flash_attn
|
||||
assert self.use_flash_attention, 'FlashAttention is required!'
|
||||
|
||||
self.n_groups = cfg.num_attention_groups
|
||||
self.tp_size = 1
|
||||
self.n_local_heads = cfg.num_attention_heads
|
||||
self.n_local_groups = self.n_groups
|
||||
|
||||
self.wqkv = nn.Linear(
|
||||
cfg.hidden_size,
|
||||
cfg.hidden_size + self.head_dim * 2 * self.n_groups,
|
||||
bias=False,
|
||||
)
|
||||
self.wo = nn.Linear(
|
||||
cfg.hidden_size,
|
||||
cfg.hidden_size,
|
||||
bias=False,
|
||||
)
|
||||
|
||||
assert self.use_flash_attention, 'non-Flash attention not supported yet.'
|
||||
self.core_attention = FlashSelfAttention(attention_dropout=cfg.attention_dropout)
|
||||
|
||||
self.layer_id = layer_id
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
mask: Optional[torch.Tensor],
|
||||
cu_seqlens: Optional[torch.Tensor],
|
||||
max_seq_len: Optional[torch.Tensor],
|
||||
):
|
||||
seqlen, bsz, dim = x.shape
|
||||
xqkv = self.wqkv(x)
|
||||
|
||||
xq, xkv = torch.split(
|
||||
xqkv,
|
||||
(dim // self.tp_size,
|
||||
self.head_dim*2*self.n_groups // self.tp_size
|
||||
),
|
||||
dim=-1,
|
||||
)
|
||||
|
||||
# gather on 1st dimention
|
||||
xq = xq.view(seqlen, bsz, self.n_local_heads, self.head_dim)
|
||||
xkv = xkv.view(seqlen, bsz, self.n_local_groups, 2 * self.head_dim)
|
||||
xk, xv = xkv.chunk(2, -1)
|
||||
|
||||
# rotary embedding + flash attn
|
||||
xq = rearrange(xq, "s b h d -> b s h d")
|
||||
xk = rearrange(xk, "s b h d -> b s h d")
|
||||
xv = rearrange(xv, "s b h d -> b s h d")
|
||||
|
||||
q_per_kv = self.n_local_heads // self.n_local_groups
|
||||
if q_per_kv > 1:
|
||||
b, s, h, d = xk.size()
|
||||
if h == 1:
|
||||
xk = xk.expand(b, s, q_per_kv, d)
|
||||
xv = xv.expand(b, s, q_per_kv, d)
|
||||
else:
|
||||
''' To cover the cases where h > 1, we have
|
||||
the following implementation, which is equivalent to:
|
||||
xk = xk.repeat_interleave(q_per_kv, dim=-2)
|
||||
xv = xv.repeat_interleave(q_per_kv, dim=-2)
|
||||
but can avoid calling aten::item() that involves cpu.
|
||||
'''
|
||||
idx = torch.arange(q_per_kv * h, device=xk.device).reshape(q_per_kv, -1).permute(1, 0).flatten()
|
||||
xk = torch.index_select(xk.repeat(1, 1, q_per_kv, 1), 2, idx).contiguous()
|
||||
xv = torch.index_select(xv.repeat(1, 1, q_per_kv, 1), 2, idx).contiguous()
|
||||
|
||||
if self.use_flash_attention:
|
||||
output = self.core_attention(xq, xk, xv,
|
||||
cu_seqlens=cu_seqlens,
|
||||
max_seq_len=max_seq_len)
|
||||
# reduce-scatter only support first dimention now
|
||||
output = rearrange(output, "b s h d -> s b (h d)").contiguous()
|
||||
else:
|
||||
xq, xk, xv = [
|
||||
rearrange(x, "b s ... -> s b ...").contiguous()
|
||||
for x in (xq, xk, xv)
|
||||
]
|
||||
output = self.core_attention(xq, xk, xv, mask)
|
||||
output = self.wo(output)
|
||||
return output
|
||||
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
cfg,
|
||||
dim: int,
|
||||
hidden_dim: int,
|
||||
layer_id: int,
|
||||
multiple_of: int=256,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
|
||||
def swiglu(x):
|
||||
x = torch.chunk(x, 2, dim=-1)
|
||||
return F.silu(x[0]) * x[1]
|
||||
self.swiglu = swiglu
|
||||
|
||||
self.w1 = nn.Linear(
|
||||
dim,
|
||||
2 * hidden_dim,
|
||||
bias=False,
|
||||
)
|
||||
self.w2 = nn.Linear(
|
||||
hidden_dim,
|
||||
dim,
|
||||
bias=False,
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.swiglu(self.w1(x))
|
||||
output = self.w2(x)
|
||||
return output
|
||||
|
||||
|
||||
|
||||
class TransformerBlock(nn.Module):
|
||||
def __init__(
|
||||
self, cfg, layer_id: int
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.n_heads = cfg.num_attention_heads
|
||||
self.dim = cfg.hidden_size
|
||||
self.head_dim = cfg.hidden_size // cfg.num_attention_heads
|
||||
self.attention = MultiQueryAttention(
|
||||
cfg,
|
||||
layer_id=layer_id,
|
||||
)
|
||||
|
||||
self.feed_forward = FeedForward(
|
||||
cfg,
|
||||
dim=cfg.hidden_size,
|
||||
hidden_dim=cfg.ffn_hidden_size,
|
||||
layer_id=layer_id,
|
||||
)
|
||||
self.layer_id = layer_id
|
||||
self.attention_norm = RMSNorm(
|
||||
cfg.hidden_size,
|
||||
eps=cfg.layernorm_epsilon,
|
||||
)
|
||||
self.ffn_norm = RMSNorm(
|
||||
cfg.hidden_size,
|
||||
eps=cfg.layernorm_epsilon,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
mask: Optional[torch.Tensor],
|
||||
cu_seqlens: Optional[torch.Tensor],
|
||||
max_seq_len: Optional[torch.Tensor],
|
||||
):
|
||||
residual = self.attention.forward(
|
||||
self.attention_norm(x), mask,
|
||||
cu_seqlens, max_seq_len
|
||||
)
|
||||
h = x + residual
|
||||
ffn_res = self.feed_forward.forward(self.ffn_norm(h))
|
||||
out = h + ffn_res
|
||||
return out
|
||||
|
||||
|
||||
class Transformer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
max_seq_size=8192,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_layers = config.num_layers
|
||||
self.layers = self._build_layers(config)
|
||||
|
||||
def _build_layers(self, config):
|
||||
layers = torch.nn.ModuleList()
|
||||
for layer_id in range(self.num_layers):
|
||||
layers.append(
|
||||
TransformerBlock(
|
||||
config,
|
||||
layer_id=layer_id + 1 ,
|
||||
)
|
||||
)
|
||||
return layers
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
cu_seqlens=None,
|
||||
max_seq_len=None,
|
||||
):
|
||||
|
||||
if max_seq_len is not None and not isinstance(max_seq_len, torch.Tensor):
|
||||
max_seq_len = torch.tensor(max_seq_len, dtype=torch.int32, device="cpu")
|
||||
|
||||
for lid, layer in enumerate(self.layers):
|
||||
hidden_states = layer(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
cu_seqlens,
|
||||
max_seq_len,
|
||||
)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class Step1Model(PreTrainedModel):
|
||||
config_class=PretrainedConfig
|
||||
@with_empty_init
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
):
|
||||
super().__init__(config)
|
||||
self.tok_embeddings = LLaMaEmbedding(config)
|
||||
self.transformer = Transformer(config)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
):
|
||||
|
||||
hidden_states = self.tok_embeddings(input_ids)
|
||||
|
||||
hidden_states = self.transformer(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
)
|
||||
return hidden_states
|
||||
|
||||
|
||||
|
||||
class STEP1TextEncoder(torch.nn.Module):
|
||||
def __init__(self, model_dir, max_length=320):
|
||||
super(STEP1TextEncoder, self).__init__()
|
||||
self.max_length = max_length
|
||||
self.text_tokenizer = Wrapped_StepChatTokenizer(os.path.join(model_dir, 'step1_chat_tokenizer.model'))
|
||||
text_encoder = Step1Model.from_pretrained(model_dir)
|
||||
self.text_encoder = text_encoder.eval().to(torch.bfloat16)
|
||||
|
||||
@staticmethod
|
||||
def from_pretrained(path, torch_dtype=torch.bfloat16):
|
||||
model = STEP1TextEncoder(path).to(torch_dtype)
|
||||
return model
|
||||
|
||||
@torch.no_grad
|
||||
def forward(self, prompts, with_mask=True, max_length=None, device="cuda"):
|
||||
self.device = device
|
||||
with torch.no_grad(), torch.amp.autocast(dtype=torch.bfloat16, device_type=device):
|
||||
if type(prompts) is str:
|
||||
prompts = [prompts]
|
||||
|
||||
txt_tokens = self.text_tokenizer(
|
||||
prompts, max_length=max_length or self.max_length, padding="max_length", truncation=True, return_tensors="pt"
|
||||
)
|
||||
y = self.text_encoder(
|
||||
txt_tokens.input_ids.to(self.device),
|
||||
attention_mask=txt_tokens.attention_mask.to(self.device) if with_mask else None
|
||||
)
|
||||
y_mask = txt_tokens.attention_mask
|
||||
return y.transpose(0,1), y_mask
|
||||
|
||||
1132
diffsynth/models/stepvideo_vae.py
Normal file
1132
diffsynth/models/stepvideo_vae.py
Normal file
File diff suppressed because it is too large
Load Diff
254
diffsynth/models/wanx_text_encoder.py
Normal file
254
diffsynth/models/wanx_text_encoder.py
Normal file
@@ -0,0 +1,254 @@
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def fp16_clamp(x):
|
||||
if x.dtype == torch.float16 and torch.isinf(x).any():
|
||||
clamp = torch.finfo(x.dtype).max - 1000
|
||||
x = torch.clamp(x, min=-clamp, max=clamp)
|
||||
return x
|
||||
|
||||
|
||||
class GELU(nn.Module):
|
||||
|
||||
def forward(self, x):
|
||||
return 0.5 * x * (1.0 + torch.tanh(
|
||||
math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
|
||||
|
||||
|
||||
class T5LayerNorm(nn.Module):
|
||||
|
||||
def __init__(self, dim, eps=1e-6):
|
||||
super(T5LayerNorm, self).__init__()
|
||||
self.dim = dim
|
||||
self.eps = eps
|
||||
self.weight = nn.Parameter(torch.ones(dim))
|
||||
|
||||
def forward(self, x):
|
||||
x = x * torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) +
|
||||
self.eps)
|
||||
if self.weight.dtype in [torch.float16, torch.bfloat16]:
|
||||
x = x.type_as(self.weight)
|
||||
return self.weight * x
|
||||
|
||||
|
||||
class T5Attention(nn.Module):
|
||||
|
||||
def __init__(self, dim, dim_attn, num_heads, dropout=0.1):
|
||||
assert dim_attn % num_heads == 0
|
||||
super(T5Attention, self).__init__()
|
||||
self.dim = dim
|
||||
self.dim_attn = dim_attn
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = dim_attn // num_heads
|
||||
|
||||
# layers
|
||||
self.q = nn.Linear(dim, dim_attn, bias=False)
|
||||
self.k = nn.Linear(dim, dim_attn, bias=False)
|
||||
self.v = nn.Linear(dim, dim_attn, bias=False)
|
||||
self.o = nn.Linear(dim_attn, dim, bias=False)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
def forward(self, x, context=None, mask=None, pos_bias=None):
|
||||
"""
|
||||
x: [B, L1, C].
|
||||
context: [B, L2, C] or None.
|
||||
mask: [B, L2] or [B, L1, L2] or None.
|
||||
"""
|
||||
# check inputs
|
||||
context = x if context is None else context
|
||||
b, n, c = x.size(0), self.num_heads, self.head_dim
|
||||
|
||||
# compute query, key, value
|
||||
q = self.q(x).view(b, -1, n, c)
|
||||
k = self.k(context).view(b, -1, n, c)
|
||||
v = self.v(context).view(b, -1, n, c)
|
||||
|
||||
# attention bias
|
||||
attn_bias = x.new_zeros(b, n, q.size(1), k.size(1))
|
||||
if pos_bias is not None:
|
||||
attn_bias += pos_bias
|
||||
if mask is not None:
|
||||
assert mask.ndim in [2, 3]
|
||||
mask = mask.view(b, 1, 1,
|
||||
-1) if mask.ndim == 2 else mask.unsqueeze(1)
|
||||
attn_bias.masked_fill_(mask == 0, torch.finfo(x.dtype).min)
|
||||
|
||||
# compute attention (T5 does not use scaling)
|
||||
attn = torch.einsum('binc,bjnc->bnij', q, k) + attn_bias
|
||||
attn = F.softmax(attn.float(), dim=-1).type_as(attn)
|
||||
x = torch.einsum('bnij,bjnc->binc', attn, v)
|
||||
|
||||
# output
|
||||
x = x.reshape(b, -1, n * c)
|
||||
x = self.o(x)
|
||||
x = self.dropout(x)
|
||||
return x
|
||||
|
||||
|
||||
class T5FeedForward(nn.Module):
|
||||
|
||||
def __init__(self, dim, dim_ffn, dropout=0.1):
|
||||
super(T5FeedForward, self).__init__()
|
||||
self.dim = dim
|
||||
self.dim_ffn = dim_ffn
|
||||
|
||||
# layers
|
||||
self.gate = nn.Sequential(nn.Linear(dim, dim_ffn, bias=False), GELU())
|
||||
self.fc1 = nn.Linear(dim, dim_ffn, bias=False)
|
||||
self.fc2 = nn.Linear(dim_ffn, dim, bias=False)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.fc1(x) * self.gate(x)
|
||||
x = self.dropout(x)
|
||||
x = self.fc2(x)
|
||||
x = self.dropout(x)
|
||||
return x
|
||||
|
||||
|
||||
class T5SelfAttention(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
dim,
|
||||
dim_attn,
|
||||
dim_ffn,
|
||||
num_heads,
|
||||
num_buckets,
|
||||
shared_pos=True,
|
||||
dropout=0.1):
|
||||
super(T5SelfAttention, self).__init__()
|
||||
self.dim = dim
|
||||
self.dim_attn = dim_attn
|
||||
self.dim_ffn = dim_ffn
|
||||
self.num_heads = num_heads
|
||||
self.num_buckets = num_buckets
|
||||
self.shared_pos = shared_pos
|
||||
|
||||
# layers
|
||||
self.norm1 = T5LayerNorm(dim)
|
||||
self.attn = T5Attention(dim, dim_attn, num_heads, dropout)
|
||||
self.norm2 = T5LayerNorm(dim)
|
||||
self.ffn = T5FeedForward(dim, dim_ffn, dropout)
|
||||
self.pos_embedding = None if shared_pos else T5RelativeEmbedding(
|
||||
num_buckets, num_heads, bidirectional=True)
|
||||
|
||||
def forward(self, x, mask=None, pos_bias=None):
|
||||
e = pos_bias if self.shared_pos else self.pos_embedding(
|
||||
x.size(1), x.size(1))
|
||||
x = fp16_clamp(x + self.attn(self.norm1(x), mask=mask, pos_bias=e))
|
||||
x = fp16_clamp(x + self.ffn(self.norm2(x)))
|
||||
return x
|
||||
|
||||
|
||||
class T5RelativeEmbedding(nn.Module):
|
||||
|
||||
def __init__(self, num_buckets, num_heads, bidirectional, max_dist=128):
|
||||
super(T5RelativeEmbedding, self).__init__()
|
||||
self.num_buckets = num_buckets
|
||||
self.num_heads = num_heads
|
||||
self.bidirectional = bidirectional
|
||||
self.max_dist = max_dist
|
||||
|
||||
# layers
|
||||
self.embedding = nn.Embedding(num_buckets, num_heads)
|
||||
|
||||
def forward(self, lq, lk):
|
||||
device = self.embedding.weight.device
|
||||
# rel_pos = torch.arange(lk).unsqueeze(0).to(device) - \
|
||||
# torch.arange(lq).unsqueeze(1).to(device)
|
||||
rel_pos = torch.arange(lk, device=device).unsqueeze(0) - \
|
||||
torch.arange(lq, device=device).unsqueeze(1)
|
||||
rel_pos = self._relative_position_bucket(rel_pos)
|
||||
rel_pos_embeds = self.embedding(rel_pos)
|
||||
rel_pos_embeds = rel_pos_embeds.permute(2, 0, 1).unsqueeze(
|
||||
0) # [1, N, Lq, Lk]
|
||||
return rel_pos_embeds.contiguous()
|
||||
|
||||
def _relative_position_bucket(self, rel_pos):
|
||||
# preprocess
|
||||
if self.bidirectional:
|
||||
num_buckets = self.num_buckets // 2
|
||||
rel_buckets = (rel_pos > 0).long() * num_buckets
|
||||
rel_pos = torch.abs(rel_pos)
|
||||
else:
|
||||
num_buckets = self.num_buckets
|
||||
rel_buckets = 0
|
||||
rel_pos = -torch.min(rel_pos, torch.zeros_like(rel_pos))
|
||||
|
||||
# embeddings for small and large positions
|
||||
max_exact = num_buckets // 2
|
||||
rel_pos_large = max_exact + (torch.log(rel_pos.float() / max_exact) /
|
||||
math.log(self.max_dist / max_exact) *
|
||||
(num_buckets - max_exact)).long()
|
||||
rel_pos_large = torch.min(
|
||||
rel_pos_large, torch.full_like(rel_pos_large, num_buckets - 1))
|
||||
rel_buckets += torch.where(rel_pos < max_exact, rel_pos, rel_pos_large)
|
||||
return rel_buckets
|
||||
|
||||
def init_weights(m):
|
||||
if isinstance(m, T5LayerNorm):
|
||||
nn.init.ones_(m.weight)
|
||||
elif isinstance(m, T5FeedForward):
|
||||
nn.init.normal_(m.gate[0].weight, std=m.dim**-0.5)
|
||||
nn.init.normal_(m.fc1.weight, std=m.dim**-0.5)
|
||||
nn.init.normal_(m.fc2.weight, std=m.dim_ffn**-0.5)
|
||||
elif isinstance(m, T5Attention):
|
||||
nn.init.normal_(m.q.weight, std=(m.dim * m.dim_attn)**-0.5)
|
||||
nn.init.normal_(m.k.weight, std=m.dim**-0.5)
|
||||
nn.init.normal_(m.v.weight, std=m.dim**-0.5)
|
||||
nn.init.normal_(m.o.weight, std=(m.num_heads * m.dim_attn)**-0.5)
|
||||
elif isinstance(m, T5RelativeEmbedding):
|
||||
nn.init.normal_(
|
||||
m.embedding.weight, std=(2 * m.num_buckets * m.num_heads)**-0.5)
|
||||
|
||||
|
||||
class WanXTextEncoder(torch.nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
vocab=256384,
|
||||
dim=4096,
|
||||
dim_attn=4096,
|
||||
dim_ffn=10240,
|
||||
num_heads=64,
|
||||
num_layers=24,
|
||||
num_buckets=32,
|
||||
shared_pos=False,
|
||||
dropout=0.1):
|
||||
super(WanXTextEncoder, self).__init__()
|
||||
self.dim = dim
|
||||
self.dim_attn = dim_attn
|
||||
self.dim_ffn = dim_ffn
|
||||
self.num_heads = num_heads
|
||||
self.num_layers = num_layers
|
||||
self.num_buckets = num_buckets
|
||||
self.shared_pos = shared_pos
|
||||
|
||||
# layers
|
||||
self.token_embedding = vocab if isinstance(vocab, nn.Embedding) \
|
||||
else nn.Embedding(vocab, dim)
|
||||
self.pos_embedding = T5RelativeEmbedding(
|
||||
num_buckets, num_heads, bidirectional=True) if shared_pos else None
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.blocks = nn.ModuleList([
|
||||
T5SelfAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets,
|
||||
shared_pos, dropout) for _ in range(num_layers)
|
||||
])
|
||||
self.norm = T5LayerNorm(dim)
|
||||
|
||||
# initialize weights
|
||||
self.apply(init_weights)
|
||||
|
||||
def forward(self, ids, mask=None):
|
||||
x = self.token_embedding(ids)
|
||||
x = self.dropout(x)
|
||||
e = self.pos_embedding(x.size(1),
|
||||
x.size(1)) if self.shared_pos else None
|
||||
for block in self.blocks:
|
||||
x = block(x, mask, pos_bias=e)
|
||||
x = self.norm(x)
|
||||
x = self.dropout(x)
|
||||
return x
|
||||
794
diffsynth/models/wanx_vae.py
Normal file
794
diffsynth/models/wanx_vae.py
Normal file
@@ -0,0 +1,794 @@
|
||||
from einops import rearrange, repeat
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from tqdm import tqdm
|
||||
|
||||
CACHE_T = 2
|
||||
|
||||
def block_causal_mask(x, block_size):
|
||||
# params
|
||||
b, n, s, _, device = *x.size(), x.device
|
||||
assert s % block_size == 0
|
||||
num_blocks = s // block_size
|
||||
|
||||
# build mask
|
||||
mask = torch.zeros(b, n, s, s, dtype=torch.bool, device=device)
|
||||
for i in range(num_blocks):
|
||||
mask[:, :,
|
||||
i * block_size:(i + 1) * block_size, :(i + 1) * block_size] = 1
|
||||
return mask
|
||||
|
||||
|
||||
class CausalConv3d(nn.Conv3d):
|
||||
"""
|
||||
Causal 3d convolusion.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self._padding = (self.padding[2], self.padding[2], self.padding[1],
|
||||
self.padding[1], 2 * self.padding[0], 0)
|
||||
self.padding = (0, 0, 0)
|
||||
|
||||
def forward(self, x, cache_x=None):
|
||||
padding = list(self._padding)
|
||||
if cache_x is not None and self._padding[4] > 0:
|
||||
cache_x = cache_x.to(x.device)
|
||||
x = torch.cat([cache_x, x], dim=2)
|
||||
padding[4] -= cache_x.shape[2]
|
||||
x = F.pad(x, padding)
|
||||
|
||||
return super().forward(x)
|
||||
|
||||
|
||||
class RMS_norm(nn.Module):
|
||||
|
||||
def __init__(self, dim, channel_first=True, images=True, bias=False):
|
||||
super().__init__()
|
||||
broadcastable_dims = (1, 1, 1) if not images else (1, 1)
|
||||
shape = (dim, *broadcastable_dims) if channel_first else (dim,)
|
||||
|
||||
self.channel_first = channel_first
|
||||
self.scale = dim**0.5
|
||||
self.gamma = nn.Parameter(torch.ones(shape))
|
||||
self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.
|
||||
|
||||
def forward(self, x):
|
||||
return F.normalize(
|
||||
x, dim=(1 if self.channel_first else
|
||||
-1)) * self.scale * self.gamma + self.bias
|
||||
|
||||
|
||||
class Upsample(nn.Upsample):
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Fix bfloat16 support for nearest neighbor interpolation.
|
||||
"""
|
||||
return super().forward(x.float()).type_as(x)
|
||||
|
||||
|
||||
class Resample(nn.Module):
|
||||
|
||||
def __init__(self, dim, mode):
|
||||
assert mode in ('none', 'upsample2d', 'upsample3d', 'downsample2d',
|
||||
'downsample3d')
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.mode = mode
|
||||
|
||||
# layers
|
||||
if mode == 'upsample2d':
|
||||
self.resample = nn.Sequential(
|
||||
Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
|
||||
nn.Conv2d(dim, dim // 2, 3, padding=1))
|
||||
elif mode == 'upsample3d':
|
||||
self.resample = nn.Sequential(
|
||||
Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
|
||||
nn.Conv2d(dim, dim // 2, 3, padding=1))
|
||||
self.time_conv = CausalConv3d(dim,
|
||||
dim * 2, (3, 1, 1),
|
||||
padding=(1, 0, 0))
|
||||
|
||||
elif mode == 'downsample2d':
|
||||
self.resample = nn.Sequential(
|
||||
nn.ZeroPad2d((0, 1, 0, 1)),
|
||||
nn.Conv2d(dim, dim, 3, stride=(2, 2)))
|
||||
elif mode == 'downsample3d':
|
||||
self.resample = nn.Sequential(
|
||||
nn.ZeroPad2d((0, 1, 0, 1)),
|
||||
nn.Conv2d(dim, dim, 3, stride=(2, 2)))
|
||||
self.time_conv = CausalConv3d(dim,
|
||||
dim, (3, 1, 1),
|
||||
stride=(2, 1, 1),
|
||||
padding=(0, 0, 0))
|
||||
|
||||
else:
|
||||
self.resample = nn.Identity()
|
||||
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
||||
b, c, t, h, w = x.size()
|
||||
if self.mode == 'upsample3d':
|
||||
if feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
if feat_cache[idx] is None:
|
||||
feat_cache[idx] = 'Rep'
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
|
||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||
if cache_x.shape[2] < 2 and feat_cache[
|
||||
idx] is not None and feat_cache[idx] != 'Rep':
|
||||
# cache last frame of last two chunk
|
||||
cache_x = torch.cat([
|
||||
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
||||
cache_x.device), cache_x
|
||||
],
|
||||
dim=2)
|
||||
if cache_x.shape[2] < 2 and feat_cache[
|
||||
idx] is not None and feat_cache[idx] == 'Rep':
|
||||
cache_x = torch.cat([
|
||||
torch.zeros_like(cache_x).to(cache_x.device),
|
||||
cache_x
|
||||
],
|
||||
dim=2)
|
||||
if feat_cache[idx] == 'Rep':
|
||||
x = self.time_conv(x)
|
||||
else:
|
||||
x = self.time_conv(x, feat_cache[idx])
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
|
||||
x = x.reshape(b, 2, c, t, h, w)
|
||||
x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]),
|
||||
3)
|
||||
x = x.reshape(b, c, t * 2, h, w)
|
||||
t = x.shape[2]
|
||||
x = rearrange(x, 'b c t h w -> (b t) c h w')
|
||||
x = self.resample(x)
|
||||
x = rearrange(x, '(b t) c h w -> b c t h w', t=t)
|
||||
|
||||
if self.mode == 'downsample3d':
|
||||
if feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
if feat_cache[idx] is None:
|
||||
feat_cache[idx] = x.clone()
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
cache_x = x[:, :, -1:, :, :].clone()
|
||||
x = self.time_conv(
|
||||
torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
return x
|
||||
|
||||
def init_weight(self, conv):
|
||||
conv_weight = conv.weight
|
||||
nn.init.zeros_(conv_weight)
|
||||
c1, c2, t, h, w = conv_weight.size()
|
||||
one_matrix = torch.eye(c1, c2)
|
||||
init_matrix = one_matrix
|
||||
nn.init.zeros_(conv_weight)
|
||||
conv_weight.data[:, :, 1, 0, 0] = init_matrix
|
||||
conv.weight.data.copy_(conv_weight)
|
||||
nn.init.zeros_(conv.bias.data)
|
||||
|
||||
def init_weight2(self, conv):
|
||||
conv_weight = conv.weight.data
|
||||
nn.init.zeros_(conv_weight)
|
||||
c1, c2, t, h, w = conv_weight.size()
|
||||
init_matrix = torch.eye(c1 // 2, c2)
|
||||
conv_weight[:c1 // 2, :, -1, 0, 0] = init_matrix
|
||||
conv_weight[c1 // 2:, :, -1, 0, 0] = init_matrix
|
||||
conv.weight.data.copy_(conv_weight)
|
||||
nn.init.zeros_(conv.bias.data)
|
||||
|
||||
|
||||
class ResidualBlock(nn.Module):
|
||||
|
||||
def __init__(self, in_dim, out_dim, dropout=0.0):
|
||||
super().__init__()
|
||||
self.in_dim = in_dim
|
||||
self.out_dim = out_dim
|
||||
|
||||
# layers
|
||||
self.residual = nn.Sequential(
|
||||
RMS_norm(in_dim, images=False), nn.SiLU(),
|
||||
CausalConv3d(in_dim, out_dim, 3, padding=1),
|
||||
RMS_norm(out_dim, images=False), nn.SiLU(), nn.Dropout(dropout),
|
||||
CausalConv3d(out_dim, out_dim, 3, padding=1))
|
||||
self.shortcut = CausalConv3d(in_dim, out_dim, 1) \
|
||||
if in_dim != out_dim else nn.Identity()
|
||||
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
||||
h = self.shortcut(x)
|
||||
for layer in self.residual:
|
||||
if isinstance(layer, CausalConv3d) and feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
||||
# cache last frame of last two chunk
|
||||
cache_x = torch.cat([
|
||||
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
||||
cache_x.device), cache_x
|
||||
],
|
||||
dim=2)
|
||||
x = layer(x, feat_cache[idx])
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
x = layer(x)
|
||||
return x + h
|
||||
|
||||
|
||||
class AttentionBlock(nn.Module):
|
||||
"""
|
||||
Causal self-attention with a single head.
|
||||
"""
|
||||
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
|
||||
# layers
|
||||
self.norm = RMS_norm(dim)
|
||||
self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
|
||||
self.proj = nn.Conv2d(dim, dim, 1)
|
||||
|
||||
# zero out the last layer params
|
||||
nn.init.zeros_(self.proj.weight)
|
||||
|
||||
def forward(self, x):
|
||||
identity = x
|
||||
b, c, t, h, w = x.size()
|
||||
x = rearrange(x, 'b c t h w -> (b t) c h w')
|
||||
x = self.norm(x)
|
||||
# compute query, key, value
|
||||
q, k, v = self.to_qkv(x).reshape(b * t, 1, c * 3, -1).permute(
|
||||
0, 1, 3, 2).contiguous().chunk(3, dim=-1)
|
||||
|
||||
# apply attention
|
||||
x = F.scaled_dot_product_attention(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
#attn_mask=block_causal_mask(q, block_size=h * w)
|
||||
)
|
||||
x = x.squeeze(1).permute(0, 2, 1).reshape(b * t, c, h, w)
|
||||
|
||||
# output
|
||||
x = self.proj(x)
|
||||
x = rearrange(x, '(b t) c h w-> b c t h w', t=t)
|
||||
return x + identity
|
||||
|
||||
|
||||
class Encoder3d(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
dim=128,
|
||||
z_dim=4,
|
||||
dim_mult=[1, 2, 4, 4],
|
||||
num_res_blocks=2,
|
||||
attn_scales=[],
|
||||
temperal_downsample=[True, True, False],
|
||||
dropout=0.0):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.z_dim = z_dim
|
||||
self.dim_mult = dim_mult
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.attn_scales = attn_scales
|
||||
self.temperal_downsample = temperal_downsample
|
||||
|
||||
# dimensions
|
||||
dims = [dim * u for u in [1] + dim_mult]
|
||||
scale = 1.0
|
||||
|
||||
# init block
|
||||
self.conv1 = CausalConv3d(3, dims[0], 3, padding=1)
|
||||
|
||||
# downsample blocks
|
||||
downsamples = []
|
||||
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
|
||||
# residual (+attention) blocks
|
||||
for _ in range(num_res_blocks):
|
||||
downsamples.append(ResidualBlock(in_dim, out_dim, dropout))
|
||||
if scale in attn_scales:
|
||||
downsamples.append(AttentionBlock(out_dim))
|
||||
in_dim = out_dim
|
||||
|
||||
# downsample block
|
||||
if i != len(dim_mult) - 1:
|
||||
mode = 'downsample3d' if temperal_downsample[
|
||||
i] else 'downsample2d'
|
||||
downsamples.append(Resample(out_dim, mode=mode))
|
||||
scale /= 2.0
|
||||
self.downsamples = nn.Sequential(*downsamples)
|
||||
|
||||
# middle blocks
|
||||
self.middle = nn.Sequential(ResidualBlock(out_dim, out_dim, dropout),
|
||||
AttentionBlock(out_dim),
|
||||
ResidualBlock(out_dim, out_dim, dropout))
|
||||
|
||||
# output blocks
|
||||
self.head = nn.Sequential(RMS_norm(out_dim, images=False), nn.SiLU(),
|
||||
CausalConv3d(out_dim, z_dim, 3, padding=1))
|
||||
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
||||
if feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
||||
# cache last frame of last two chunk
|
||||
cache_x = torch.cat([
|
||||
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
||||
cache_x.device), cache_x
|
||||
],
|
||||
dim=2)
|
||||
x = self.conv1(x, feat_cache[idx])
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
x = self.conv1(x)
|
||||
|
||||
## downsamples
|
||||
for layer in self.downsamples:
|
||||
if feat_cache is not None:
|
||||
x = layer(x, feat_cache, feat_idx)
|
||||
else:
|
||||
x = layer(x)
|
||||
|
||||
## middle
|
||||
for layer in self.middle:
|
||||
if isinstance(layer, ResidualBlock) and feat_cache is not None:
|
||||
x = layer(x, feat_cache, feat_idx)
|
||||
else:
|
||||
x = layer(x)
|
||||
|
||||
## head
|
||||
for layer in self.head:
|
||||
if isinstance(layer, CausalConv3d) and feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
||||
# cache last frame of last two chunk
|
||||
cache_x = torch.cat([
|
||||
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
||||
cache_x.device), cache_x
|
||||
],
|
||||
dim=2)
|
||||
x = layer(x, feat_cache[idx])
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
x = layer(x)
|
||||
return x
|
||||
|
||||
|
||||
class Decoder3d(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
dim=128,
|
||||
z_dim=4,
|
||||
dim_mult=[1, 2, 4, 4],
|
||||
num_res_blocks=2,
|
||||
attn_scales=[],
|
||||
temperal_upsample=[False, True, True],
|
||||
dropout=0.0):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.z_dim = z_dim
|
||||
self.dim_mult = dim_mult
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.attn_scales = attn_scales
|
||||
self.temperal_upsample = temperal_upsample
|
||||
|
||||
# dimensions
|
||||
dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
|
||||
scale = 1.0 / 2**(len(dim_mult) - 2)
|
||||
|
||||
# init block
|
||||
self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)
|
||||
|
||||
# middle blocks
|
||||
self.middle = nn.Sequential(ResidualBlock(dims[0], dims[0], dropout),
|
||||
AttentionBlock(dims[0]),
|
||||
ResidualBlock(dims[0], dims[0], dropout))
|
||||
|
||||
# upsample blocks
|
||||
upsamples = []
|
||||
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
|
||||
# residual (+attention) blocks
|
||||
if i == 1 or i == 2 or i == 3:
|
||||
in_dim = in_dim // 2
|
||||
for _ in range(num_res_blocks + 1):
|
||||
upsamples.append(ResidualBlock(in_dim, out_dim, dropout))
|
||||
if scale in attn_scales:
|
||||
upsamples.append(AttentionBlock(out_dim))
|
||||
in_dim = out_dim
|
||||
|
||||
# upsample block
|
||||
if i != len(dim_mult) - 1:
|
||||
mode = 'upsample3d' if temperal_upsample[i] else 'upsample2d'
|
||||
upsamples.append(Resample(out_dim, mode=mode))
|
||||
scale *= 2.0
|
||||
self.upsamples = nn.Sequential(*upsamples)
|
||||
|
||||
# output blocks
|
||||
self.head = nn.Sequential(RMS_norm(out_dim, images=False), nn.SiLU(),
|
||||
CausalConv3d(out_dim, 3, 3, padding=1))
|
||||
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
||||
## conv1
|
||||
if feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
||||
# cache last frame of last two chunk
|
||||
cache_x = torch.cat([
|
||||
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
||||
cache_x.device), cache_x
|
||||
],
|
||||
dim=2)
|
||||
x = self.conv1(x, feat_cache[idx])
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
x = self.conv1(x)
|
||||
|
||||
## middle
|
||||
for layer in self.middle:
|
||||
if isinstance(layer, ResidualBlock) and feat_cache is not None:
|
||||
x = layer(x, feat_cache, feat_idx)
|
||||
else:
|
||||
x = layer(x)
|
||||
|
||||
## upsamples
|
||||
for layer in self.upsamples:
|
||||
if feat_cache is not None:
|
||||
x = layer(x, feat_cache, feat_idx)
|
||||
else:
|
||||
x = layer(x)
|
||||
|
||||
## head
|
||||
for layer in self.head:
|
||||
if isinstance(layer, CausalConv3d) and feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
||||
# cache last frame of last two chunk
|
||||
cache_x = torch.cat([
|
||||
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
||||
cache_x.device), cache_x
|
||||
],
|
||||
dim=2)
|
||||
x = layer(x, feat_cache[idx])
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
x = layer(x)
|
||||
return x
|
||||
|
||||
|
||||
def count_conv3d(model):
|
||||
count = 0
|
||||
for m in model.modules():
|
||||
if isinstance(m, CausalConv3d):
|
||||
count += 1
|
||||
return count
|
||||
|
||||
|
||||
class VideoVAE_(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
dim=96,
|
||||
z_dim=16,
|
||||
dim_mult=[1, 2, 4, 4],
|
||||
num_res_blocks=2,
|
||||
attn_scales=[],
|
||||
temperal_downsample=[False, True, True],
|
||||
dropout=0.0):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.z_dim = z_dim
|
||||
self.dim_mult = dim_mult
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.attn_scales = attn_scales
|
||||
self.temperal_downsample = temperal_downsample
|
||||
self.temperal_upsample = temperal_downsample[::-1]
|
||||
|
||||
# modules
|
||||
self.encoder = Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks,
|
||||
attn_scales, self.temperal_downsample, dropout)
|
||||
self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
|
||||
self.conv2 = CausalConv3d(z_dim, z_dim, 1)
|
||||
self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks,
|
||||
attn_scales, self.temperal_upsample, dropout)
|
||||
|
||||
def forward(self, x):
|
||||
mu, log_var = self.encode(x)
|
||||
z = self.reparameterize(mu, log_var)
|
||||
x_recon = self.decode(z)
|
||||
return x_recon, mu, log_var
|
||||
|
||||
def encode(self, x, scale):
|
||||
self.clear_cache()
|
||||
## cache
|
||||
t = x.shape[2]
|
||||
iter_ = 1 + (t - 1) // 4
|
||||
|
||||
for i in range(iter_):
|
||||
self._enc_conv_idx = [0]
|
||||
if i == 0:
|
||||
out = self.encoder(x[:, :, :1, :, :],
|
||||
feat_cache=self._enc_feat_map,
|
||||
feat_idx=self._enc_conv_idx)
|
||||
else:
|
||||
out_ = self.encoder(x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
|
||||
feat_cache=self._enc_feat_map,
|
||||
feat_idx=self._enc_conv_idx)
|
||||
out = torch.cat([out, out_], 2)
|
||||
mu, log_var = self.conv1(out).chunk(2, dim=1)
|
||||
if isinstance(scale[0], torch.Tensor):
|
||||
scale = [s.to(dtype=mu.dtype, device=mu.device) for s in scale]
|
||||
mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(
|
||||
1, self.z_dim, 1, 1, 1)
|
||||
else:
|
||||
scale = scale.to(dtype=mu.dtype, device=mu.device)
|
||||
mu = (mu - scale[0]) * scale[1]
|
||||
return mu
|
||||
|
||||
def decode(self, z, scale):
|
||||
self.clear_cache()
|
||||
# z: [b,c,t,h,w]
|
||||
if isinstance(scale[0], torch.Tensor):
|
||||
scale = [s.to(dtype=z.dtype, device=z.device) for s in scale]
|
||||
z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(
|
||||
1, self.z_dim, 1, 1, 1)
|
||||
else:
|
||||
scale = scale.to(dtype=z.dtype, device=z.device)
|
||||
z = z / scale[1] + scale[0]
|
||||
iter_ = z.shape[2]
|
||||
x = self.conv2(z)
|
||||
for i in range(iter_):
|
||||
self._conv_idx = [0]
|
||||
if i == 0:
|
||||
out = self.decoder(x[:, :, i:i + 1, :, :],
|
||||
feat_cache=self._feat_map,
|
||||
feat_idx=self._conv_idx)
|
||||
else:
|
||||
out_ = self.decoder(x[:, :, i:i + 1, :, :],
|
||||
feat_cache=self._feat_map,
|
||||
feat_idx=self._conv_idx)
|
||||
out = torch.cat([out, out_], 2) # may add tensor offload
|
||||
return out
|
||||
|
||||
def reparameterize(self, mu, log_var):
|
||||
std = torch.exp(0.5 * log_var)
|
||||
eps = torch.randn_like(std)
|
||||
return eps * std + mu
|
||||
|
||||
def sample(self, imgs, deterministic=False):
|
||||
mu, log_var = self.encode(imgs)
|
||||
if deterministic:
|
||||
return mu
|
||||
std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0))
|
||||
return mu + std * torch.randn_like(std)
|
||||
|
||||
def clear_cache(self):
|
||||
self._conv_num = count_conv3d(self.decoder)
|
||||
self._conv_idx = [0]
|
||||
self._feat_map = [None] * self._conv_num
|
||||
# cache encode
|
||||
self._enc_conv_num = count_conv3d(self.encoder)
|
||||
self._enc_conv_idx = [0]
|
||||
self._enc_feat_map = [None] * self._enc_conv_num
|
||||
|
||||
|
||||
class WanXVideoVAE(nn.Module):
|
||||
|
||||
def __init__(self, z_dim=16):
|
||||
super().__init__()
|
||||
|
||||
mean = [
|
||||
-0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508,
|
||||
0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921
|
||||
]
|
||||
std = [
|
||||
2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743,
|
||||
3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160
|
||||
]
|
||||
self.mean = torch.tensor(mean)
|
||||
self.std = torch.tensor(std)
|
||||
self.scale = [self.mean, 1.0 / self.std]
|
||||
|
||||
# init model
|
||||
self.model = VideoVAE_(z_dim=z_dim).eval().requires_grad_(False)
|
||||
self.upsampling_factor = 8
|
||||
|
||||
|
||||
def build_1d_mask(self, length, left_bound, right_bound, border_width):
|
||||
x = torch.ones((length,))
|
||||
if not left_bound:
|
||||
x[:border_width] = (torch.arange(border_width) + 1) / border_width
|
||||
if not right_bound:
|
||||
x[-border_width:] = torch.flip((torch.arange(border_width) + 1) / border_width, dims=(0,))
|
||||
return x
|
||||
|
||||
|
||||
def build_mask(self, data, is_bound, border_width):
|
||||
_, _, _, H, W = data.shape
|
||||
h = self.build_1d_mask(H, is_bound[0], is_bound[1], border_width[0])
|
||||
w = self.build_1d_mask(W, is_bound[2], is_bound[3], border_width[1])
|
||||
|
||||
h = repeat(h, "H -> H W", H=H, W=W)
|
||||
w = repeat(w, "W -> H W", H=H, W=W)
|
||||
|
||||
mask = torch.stack([h, w]).min(dim=0).values
|
||||
mask = rearrange(mask, "H W -> 1 1 1 H W")
|
||||
return mask
|
||||
|
||||
|
||||
def tiled_decode(self, hidden_states, device, tile_size, tile_stride):
|
||||
_, _, T, H, W = hidden_states.shape
|
||||
size_h, size_w = tile_size
|
||||
stride_h, stride_w = tile_stride
|
||||
|
||||
# Split tasks
|
||||
tasks = []
|
||||
for h in range(0, H, stride_h):
|
||||
if (h-stride_h >= 0 and h-stride_h+size_h >= H): continue
|
||||
for w in range(0, W, stride_w):
|
||||
if (w-stride_w >= 0 and w-stride_w+size_w >= W): continue
|
||||
h_, w_ = h + size_h, w + size_w
|
||||
tasks.append((h, h_, w, w_))
|
||||
|
||||
data_device = "cpu"
|
||||
computation_device = device
|
||||
|
||||
out_T = T * 4 - 3
|
||||
weight = torch.zeros((1, 1, out_T, H * self.upsampling_factor, W * self.upsampling_factor), dtype=hidden_states.dtype, device=data_device)
|
||||
values = torch.zeros((1, 3, out_T, H * self.upsampling_factor, W * self.upsampling_factor), dtype=hidden_states.dtype, device=data_device)
|
||||
|
||||
for h, h_, w, w_ in tqdm(tasks, desc="VAE decoding"):
|
||||
hidden_states_batch = hidden_states[:, :, :, h:h_, w:w_].to(computation_device)
|
||||
hidden_states_batch = self.model.decode(hidden_states_batch, self.scale).to(data_device)
|
||||
|
||||
mask = self.build_mask(
|
||||
hidden_states_batch,
|
||||
is_bound=(h==0, h_>=H, w==0, w_>=W),
|
||||
border_width=((size_h - stride_h) * self.upsampling_factor, (size_w - stride_w) * self.upsampling_factor)
|
||||
).to(dtype=hidden_states.dtype, device=data_device)
|
||||
|
||||
target_h = h * self.upsampling_factor
|
||||
target_w = w * self.upsampling_factor
|
||||
values[
|
||||
:,
|
||||
:,
|
||||
:,
|
||||
target_h:target_h + hidden_states_batch.shape[3],
|
||||
target_w:target_w + hidden_states_batch.shape[4],
|
||||
] += hidden_states_batch * mask
|
||||
weight[
|
||||
:,
|
||||
:,
|
||||
:,
|
||||
target_h: target_h + hidden_states_batch.shape[3],
|
||||
target_w: target_w + hidden_states_batch.shape[4],
|
||||
] += mask
|
||||
values = values / weight
|
||||
values = values.float().clamp_(-1, 1)
|
||||
return values
|
||||
|
||||
|
||||
def tiled_encode(self, video, device, tile_size, tile_stride):
|
||||
_, _, T, H, W = video.shape
|
||||
size_h, size_w = tile_size
|
||||
stride_h, stride_w = tile_stride
|
||||
|
||||
# Split tasks
|
||||
tasks = []
|
||||
for h in range(0, H, stride_h):
|
||||
if (h-stride_h >= 0 and h-stride_h+size_h >= H): continue
|
||||
for w in range(0, W, stride_w):
|
||||
if (w-stride_w >= 0 and w-stride_w+size_w >= W): continue
|
||||
h_, w_ = h + size_h, w + size_w
|
||||
tasks.append((h, h_, w, w_))
|
||||
|
||||
data_device = "cpu"
|
||||
computation_device = device
|
||||
|
||||
out_T = (T + 3) // 4
|
||||
weight = torch.zeros((1, 1, out_T, H // self.upsampling_factor, W // self.upsampling_factor), dtype=video.dtype, device=data_device)
|
||||
values = torch.zeros((1, 16, out_T, H // self.upsampling_factor, W // self.upsampling_factor), dtype=video.dtype, device=data_device)
|
||||
|
||||
for h, h_, w, w_ in tqdm(tasks, desc="VAE encoding"):
|
||||
hidden_states_batch = video[:, :, :, h:h_, w:w_].to(computation_device)
|
||||
hidden_states_batch = self.model.encode(hidden_states_batch, self.scale).to(data_device)
|
||||
|
||||
mask = self.build_mask(
|
||||
hidden_states_batch,
|
||||
is_bound=(h==0, h_>=H, w==0, w_>=W),
|
||||
border_width=((size_h - stride_h) // self.upsampling_factor, (size_w - stride_w) // self.upsampling_factor)
|
||||
).to(dtype=video.dtype, device=data_device)
|
||||
|
||||
target_h = h // self.upsampling_factor
|
||||
target_w = w // self.upsampling_factor
|
||||
values[
|
||||
:,
|
||||
:,
|
||||
:,
|
||||
target_h:target_h + hidden_states_batch.shape[3],
|
||||
target_w:target_w + hidden_states_batch.shape[4],
|
||||
] += hidden_states_batch * mask
|
||||
weight[
|
||||
:,
|
||||
:,
|
||||
:,
|
||||
target_h: target_h + hidden_states_batch.shape[3],
|
||||
target_w: target_w + hidden_states_batch.shape[4],
|
||||
] += mask
|
||||
values = values / weight
|
||||
values = values.float()
|
||||
return values
|
||||
|
||||
|
||||
def single_encode(self, video, device):
|
||||
video = video.to(device)
|
||||
x = self.model.encode(video, self.scale)
|
||||
return x.float()
|
||||
|
||||
|
||||
def single_decode(self, hidden_state, device):
|
||||
hidden_state = hidden_state.to(device)
|
||||
video = self.model.decode(hidden_state, self.scale)
|
||||
return video.float().clamp_(-1, 1)
|
||||
|
||||
|
||||
def encode(self, videos, device, tiled=False, tile_size=(272, 272), tile_stride=(144, 128)):
|
||||
|
||||
videos = [video.to("cpu") for video in videos]
|
||||
hidden_states = []
|
||||
for video in videos:
|
||||
video = video.unsqueeze(0)
|
||||
if tiled:
|
||||
assert tile_size[0] % self.upsampling_factor == 0 and tile_size[1] % self.upsampling_factor == 0, f"tile_size must be devisible by {self.upsampling_factor}"
|
||||
hidden_state = self.tiled_encode(video, device, tile_size, tile_stride)
|
||||
else:
|
||||
hidden_state = self.single_encode(video, device)
|
||||
hidden_state = hidden_state.squeeze(0)
|
||||
hidden_states.append(hidden_state)
|
||||
return hidden_states
|
||||
|
||||
|
||||
def decode(self, hidden_states, device, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)):
|
||||
hidden_states = [hidden_state.to("cpu") for hidden_state in hidden_states]
|
||||
videos = []
|
||||
for hidden_state in hidden_states:
|
||||
hidden_state = hidden_state.unsqueeze(0)
|
||||
if tiled:
|
||||
video = self.tiled_decode(hidden_state, device, tile_size, tile_stride)
|
||||
else:
|
||||
video = self.single_decode(hidden_state, device)
|
||||
video = video.squeeze(0)
|
||||
videos.append(video)
|
||||
return videos
|
||||
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return WanXVideoVAEStateDictConverter()
|
||||
|
||||
|
||||
class WanXVideoVAEStateDictConverter:
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
state_dict_ = {}
|
||||
for name in state_dict['model_state']:
|
||||
state_dict_['model.' + name] = state_dict['model_state'][name]
|
||||
return state_dict_
|
||||
@@ -10,4 +10,5 @@ from .cog_video import CogVideoPipeline
|
||||
from .omnigen_image import OmnigenImagePipeline
|
||||
from .pipeline_runner import SDVideoPipelineRunner
|
||||
from .hunyuan_video import HunyuanVideoPipeline
|
||||
from .step_video import StepVideoPipeline
|
||||
KolorsImagePipeline = SDXLImagePipeline
|
||||
|
||||
@@ -101,12 +101,22 @@ class BasePipeline(torch.nn.Module):
|
||||
if model_name not in loadmodel_names:
|
||||
model = getattr(self, model_name)
|
||||
if model is not None:
|
||||
model.cpu()
|
||||
if hasattr(model, "vram_management_enabled") and model.vram_management_enabled:
|
||||
for module in model.modules():
|
||||
if hasattr(module, "offload"):
|
||||
module.offload()
|
||||
else:
|
||||
model.cpu()
|
||||
# load the needed models to device
|
||||
for model_name in loadmodel_names:
|
||||
model = getattr(self, model_name)
|
||||
if model is not None:
|
||||
model.to(self.device)
|
||||
if hasattr(model, "vram_management_enabled") and model.vram_management_enabled:
|
||||
for module in model.modules():
|
||||
if hasattr(module, "onload"):
|
||||
module.onload()
|
||||
else:
|
||||
model.to(self.device)
|
||||
# fresh the cuda cache
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
@@ -11,6 +11,9 @@ from PIL import Image
|
||||
from ..models.tiler import FastTileWorker
|
||||
from transformers import SiglipVisionModel
|
||||
from copy import deepcopy
|
||||
from transformers.models.t5.modeling_t5 import T5LayerNorm, T5DenseActDense, T5DenseGatedActDense
|
||||
from ..models.flux_dit import RMSNorm
|
||||
from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear
|
||||
|
||||
|
||||
class FluxImagePipeline(BasePipeline):
|
||||
@@ -31,6 +34,105 @@ class FluxImagePipeline(BasePipeline):
|
||||
self.model_names = ['text_encoder_1', 'text_encoder_2', 'dit', 'vae_decoder', 'vae_encoder', 'controlnet', 'ipadapter', 'ipadapter_image_encoder']
|
||||
|
||||
|
||||
def enable_vram_management(self, num_persistent_param_in_dit=None):
|
||||
dtype = next(iter(self.text_encoder_1.parameters())).dtype
|
||||
enable_vram_management(
|
||||
self.text_encoder_1,
|
||||
module_map = {
|
||||
torch.nn.Linear: AutoWrappedLinear,
|
||||
torch.nn.Embedding: AutoWrappedModule,
|
||||
torch.nn.LayerNorm: AutoWrappedModule,
|
||||
},
|
||||
module_config = dict(
|
||||
offload_dtype=dtype,
|
||||
offload_device="cpu",
|
||||
onload_dtype=dtype,
|
||||
onload_device="cpu",
|
||||
computation_dtype=self.torch_dtype,
|
||||
computation_device=self.device,
|
||||
),
|
||||
)
|
||||
dtype = next(iter(self.text_encoder_2.parameters())).dtype
|
||||
enable_vram_management(
|
||||
self.text_encoder_2,
|
||||
module_map = {
|
||||
torch.nn.Linear: AutoWrappedLinear,
|
||||
torch.nn.Embedding: AutoWrappedModule,
|
||||
T5LayerNorm: AutoWrappedModule,
|
||||
T5DenseActDense: AutoWrappedModule,
|
||||
T5DenseGatedActDense: AutoWrappedModule,
|
||||
},
|
||||
module_config = dict(
|
||||
offload_dtype=dtype,
|
||||
offload_device="cpu",
|
||||
onload_dtype=dtype,
|
||||
onload_device="cpu",
|
||||
computation_dtype=self.torch_dtype,
|
||||
computation_device=self.device,
|
||||
),
|
||||
)
|
||||
dtype = next(iter(self.dit.parameters())).dtype
|
||||
enable_vram_management(
|
||||
self.dit,
|
||||
module_map = {
|
||||
RMSNorm: AutoWrappedModule,
|
||||
torch.nn.Linear: AutoWrappedLinear,
|
||||
},
|
||||
module_config = dict(
|
||||
offload_dtype=dtype,
|
||||
offload_device="cpu",
|
||||
onload_dtype=dtype,
|
||||
onload_device="cuda",
|
||||
computation_dtype=self.torch_dtype,
|
||||
computation_device=self.device,
|
||||
),
|
||||
max_num_param=num_persistent_param_in_dit,
|
||||
overflow_module_config = dict(
|
||||
offload_dtype=dtype,
|
||||
offload_device="cpu",
|
||||
onload_dtype=dtype,
|
||||
onload_device="cpu",
|
||||
computation_dtype=self.torch_dtype,
|
||||
computation_device=self.device,
|
||||
),
|
||||
)
|
||||
dtype = next(iter(self.vae_decoder.parameters())).dtype
|
||||
enable_vram_management(
|
||||
self.vae_decoder,
|
||||
module_map = {
|
||||
torch.nn.Linear: AutoWrappedLinear,
|
||||
torch.nn.Conv2d: AutoWrappedModule,
|
||||
torch.nn.GroupNorm: AutoWrappedModule,
|
||||
},
|
||||
module_config = dict(
|
||||
offload_dtype=dtype,
|
||||
offload_device="cpu",
|
||||
onload_dtype=dtype,
|
||||
onload_device="cpu",
|
||||
computation_dtype=self.torch_dtype,
|
||||
computation_device=self.device,
|
||||
),
|
||||
)
|
||||
dtype = next(iter(self.vae_encoder.parameters())).dtype
|
||||
enable_vram_management(
|
||||
self.vae_encoder,
|
||||
module_map = {
|
||||
torch.nn.Linear: AutoWrappedLinear,
|
||||
torch.nn.Conv2d: AutoWrappedModule,
|
||||
torch.nn.GroupNorm: AutoWrappedModule,
|
||||
},
|
||||
module_config = dict(
|
||||
offload_dtype=dtype,
|
||||
offload_device="cpu",
|
||||
onload_dtype=dtype,
|
||||
onload_device="cpu",
|
||||
computation_dtype=self.torch_dtype,
|
||||
computation_device=self.device,
|
||||
),
|
||||
)
|
||||
self.enable_cpu_offload()
|
||||
|
||||
|
||||
def denoising_model(self):
|
||||
return self.dit
|
||||
|
||||
@@ -62,10 +164,10 @@ class FluxImagePipeline(BasePipeline):
|
||||
|
||||
|
||||
@staticmethod
|
||||
def from_model_manager(model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[], prompt_extender_classes=[], device=None):
|
||||
def from_model_manager(model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[], prompt_extender_classes=[], device=None, torch_dtype=None):
|
||||
pipe = FluxImagePipeline(
|
||||
device=model_manager.device if device is None else device,
|
||||
torch_dtype=model_manager.torch_dtype,
|
||||
torch_dtype=model_manager.torch_dtype if torch_dtype is None else torch_dtype,
|
||||
)
|
||||
pipe.fetch_models(model_manager, controlnet_config_units, prompt_refiner_classes, prompt_extender_classes)
|
||||
return pipe
|
||||
|
||||
209
diffsynth/pipelines/step_video.py
Normal file
209
diffsynth/pipelines/step_video.py
Normal file
@@ -0,0 +1,209 @@
|
||||
from ..models import ModelManager
|
||||
from ..models.hunyuan_dit_text_encoder import HunyuanDiTCLIPTextEncoder
|
||||
from ..models.stepvideo_text_encoder import STEP1TextEncoder
|
||||
from ..models.stepvideo_dit import StepVideoModel
|
||||
from ..models.stepvideo_vae import StepVideoVAE
|
||||
from ..schedulers.flow_match import FlowMatchScheduler
|
||||
from .base import BasePipeline
|
||||
from ..prompters import StepVideoPrompter
|
||||
import torch
|
||||
from einops import rearrange
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear
|
||||
from transformers.models.bert.modeling_bert import BertEmbeddings
|
||||
from ..models.stepvideo_dit import RMSNorm
|
||||
from ..models.stepvideo_vae import CausalConv, CausalConvAfterNorm, Upsample2D, BaseGroupNorm
|
||||
|
||||
|
||||
|
||||
class StepVideoPipeline(BasePipeline):
|
||||
|
||||
def __init__(self, device="cuda", torch_dtype=torch.float16):
|
||||
super().__init__(device=device, torch_dtype=torch_dtype)
|
||||
self.scheduler = FlowMatchScheduler(sigma_min=0.0, extra_one_step=True, shift=13.0, reverse_sigmas=True, num_train_timesteps=1)
|
||||
self.prompter = StepVideoPrompter()
|
||||
self.text_encoder_1: HunyuanDiTCLIPTextEncoder = None
|
||||
self.text_encoder_2: STEP1TextEncoder = None
|
||||
self.dit: StepVideoModel = None
|
||||
self.vae: StepVideoVAE = None
|
||||
self.model_names = ['text_encoder_1', 'text_encoder_2', 'dit', 'vae']
|
||||
|
||||
|
||||
def enable_vram_management(self, num_persistent_param_in_dit=None):
|
||||
dtype = next(iter(self.text_encoder_1.parameters())).dtype
|
||||
enable_vram_management(
|
||||
self.text_encoder_1,
|
||||
module_map = {
|
||||
torch.nn.Linear: AutoWrappedLinear,
|
||||
BertEmbeddings: AutoWrappedModule,
|
||||
torch.nn.LayerNorm: AutoWrappedModule,
|
||||
},
|
||||
module_config = dict(
|
||||
offload_dtype=dtype,
|
||||
offload_device="cpu",
|
||||
onload_dtype=dtype,
|
||||
onload_device="cpu",
|
||||
computation_dtype=torch.float32,
|
||||
computation_device=self.device,
|
||||
),
|
||||
)
|
||||
dtype = next(iter(self.text_encoder_2.parameters())).dtype
|
||||
enable_vram_management(
|
||||
self.text_encoder_2,
|
||||
module_map = {
|
||||
torch.nn.Linear: AutoWrappedLinear,
|
||||
RMSNorm: AutoWrappedModule,
|
||||
torch.nn.Embedding: AutoWrappedModule,
|
||||
},
|
||||
module_config = dict(
|
||||
offload_dtype=dtype,
|
||||
offload_device="cpu",
|
||||
onload_dtype=dtype,
|
||||
onload_device="cpu",
|
||||
computation_dtype=self.torch_dtype,
|
||||
computation_device=self.device,
|
||||
),
|
||||
)
|
||||
dtype = next(iter(self.dit.parameters())).dtype
|
||||
enable_vram_management(
|
||||
self.dit,
|
||||
module_map = {
|
||||
torch.nn.Linear: AutoWrappedLinear,
|
||||
torch.nn.Conv2d: AutoWrappedModule,
|
||||
torch.nn.LayerNorm: AutoWrappedModule,
|
||||
RMSNorm: AutoWrappedModule,
|
||||
},
|
||||
module_config = dict(
|
||||
offload_dtype=dtype,
|
||||
offload_device="cpu",
|
||||
onload_dtype=dtype,
|
||||
onload_device=self.device,
|
||||
computation_dtype=self.torch_dtype,
|
||||
computation_device=self.device,
|
||||
),
|
||||
max_num_param=num_persistent_param_in_dit,
|
||||
overflow_module_config = dict(
|
||||
offload_dtype=dtype,
|
||||
offload_device="cpu",
|
||||
onload_dtype=dtype,
|
||||
onload_device="cpu",
|
||||
computation_dtype=self.torch_dtype,
|
||||
computation_device=self.device,
|
||||
),
|
||||
)
|
||||
dtype = next(iter(self.vae.parameters())).dtype
|
||||
enable_vram_management(
|
||||
self.vae,
|
||||
module_map = {
|
||||
torch.nn.Linear: AutoWrappedLinear,
|
||||
torch.nn.Conv3d: AutoWrappedModule,
|
||||
CausalConv: AutoWrappedModule,
|
||||
CausalConvAfterNorm: AutoWrappedModule,
|
||||
Upsample2D: AutoWrappedModule,
|
||||
BaseGroupNorm: AutoWrappedModule,
|
||||
},
|
||||
module_config = dict(
|
||||
offload_dtype=dtype,
|
||||
offload_device="cpu",
|
||||
onload_dtype=dtype,
|
||||
onload_device="cpu",
|
||||
computation_dtype=self.torch_dtype,
|
||||
computation_device=self.device,
|
||||
),
|
||||
)
|
||||
self.enable_cpu_offload()
|
||||
|
||||
|
||||
def fetch_models(self, model_manager: ModelManager):
|
||||
self.text_encoder_1 = model_manager.fetch_model("hunyuan_dit_clip_text_encoder")
|
||||
self.text_encoder_2 = model_manager.fetch_model("stepvideo_text_encoder_2")
|
||||
self.dit = model_manager.fetch_model("stepvideo_dit")
|
||||
self.vae = model_manager.fetch_model("stepvideo_vae")
|
||||
self.prompter.fetch_models(self.text_encoder_1, self.text_encoder_2)
|
||||
|
||||
|
||||
@staticmethod
|
||||
def from_model_manager(model_manager: ModelManager, torch_dtype=None, device=None):
|
||||
if device is None: device = model_manager.device
|
||||
if torch_dtype is None: torch_dtype = model_manager.torch_dtype
|
||||
pipe = StepVideoPipeline(device=device, torch_dtype=torch_dtype)
|
||||
pipe.fetch_models(model_manager)
|
||||
return pipe
|
||||
|
||||
|
||||
def encode_prompt(self, prompt, positive=True):
|
||||
clip_embeds, llm_embeds, llm_mask = self.prompter.encode_prompt(prompt, device=self.device, positive=positive)
|
||||
clip_embeds = clip_embeds.to(dtype=self.torch_dtype, device=self.device)
|
||||
llm_embeds = llm_embeds.to(dtype=self.torch_dtype, device=self.device)
|
||||
llm_mask = llm_mask.to(dtype=self.torch_dtype, device=self.device)
|
||||
return {"encoder_hidden_states_2": clip_embeds, "encoder_hidden_states": llm_embeds, "encoder_attention_mask": llm_mask}
|
||||
|
||||
|
||||
def tensor2video(self, frames):
|
||||
frames = rearrange(frames, "C T H W -> T H W C")
|
||||
frames = ((frames.float() + 1) * 127.5).clip(0, 255).cpu().numpy().astype(np.uint8)
|
||||
frames = [Image.fromarray(frame) for frame in frames]
|
||||
return frames
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
prompt,
|
||||
negative_prompt="",
|
||||
input_video=None,
|
||||
denoising_strength=1.0,
|
||||
seed=None,
|
||||
rand_device="cpu",
|
||||
height=544,
|
||||
width=992,
|
||||
num_frames=204,
|
||||
cfg_scale=9.0,
|
||||
num_inference_steps=30,
|
||||
tiled=True,
|
||||
tile_size=(34, 34),
|
||||
tile_stride=(16, 16),
|
||||
smooth_scale=0.6,
|
||||
progress_bar_cmd=lambda x: x,
|
||||
progress_bar_st=None,
|
||||
):
|
||||
# Tiler parameters
|
||||
tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
|
||||
|
||||
# Scheduler
|
||||
self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
|
||||
|
||||
# Initialize noise
|
||||
latents = self.generate_noise((1, max(num_frames//17*3, 1), 64, height//16, width//16), seed=seed, device=rand_device, dtype=self.torch_dtype).to(self.device)
|
||||
|
||||
# Encode prompts
|
||||
self.load_models_to_device(["text_encoder_1", "text_encoder_2"])
|
||||
prompt_emb_posi = self.encode_prompt(prompt, positive=True)
|
||||
if cfg_scale != 1.0:
|
||||
prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False)
|
||||
|
||||
# Denoise
|
||||
self.load_models_to_device(["dit"])
|
||||
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
||||
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
|
||||
print(f"Step {progress_id + 1} / {len(self.scheduler.timesteps)}")
|
||||
|
||||
# Inference
|
||||
noise_pred_posi = self.dit(latents, timestep=timestep, **prompt_emb_posi)
|
||||
if cfg_scale != 1.0:
|
||||
noise_pred_nega = self.dit(latents, timestep=timestep, **prompt_emb_nega)
|
||||
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
|
||||
else:
|
||||
noise_pred = noise_pred_posi
|
||||
|
||||
# Scheduler
|
||||
latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents)
|
||||
|
||||
# Decode
|
||||
self.load_models_to_device(['vae'])
|
||||
frames = self.vae.decode(latents, device=self.device, smooth_scale=smooth_scale, **tiler_kwargs)
|
||||
self.load_models_to_device([])
|
||||
frames = self.tensor2video(frames[0])
|
||||
|
||||
return frames
|
||||
@@ -8,3 +8,5 @@ from .flux_prompter import FluxPrompter
|
||||
from .omost import OmostPromter
|
||||
from .cog_prompter import CogPrompter
|
||||
from .hunyuan_video_prompter import HunyuanVideoPrompter
|
||||
from .stepvideo_prompter import StepVideoPrompter
|
||||
from .wanx_prompter import WanXPrompter
|
||||
56
diffsynth/prompters/stepvideo_prompter.py
Normal file
56
diffsynth/prompters/stepvideo_prompter.py
Normal file
@@ -0,0 +1,56 @@
|
||||
from .base_prompter import BasePrompter
|
||||
from ..models.hunyuan_dit_text_encoder import HunyuanDiTCLIPTextEncoder
|
||||
from ..models.stepvideo_text_encoder import STEP1TextEncoder
|
||||
from transformers import BertTokenizer
|
||||
import os, torch
|
||||
|
||||
|
||||
class StepVideoPrompter(BasePrompter):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer_1_path=None,
|
||||
):
|
||||
if tokenizer_1_path is None:
|
||||
base_path = os.path.dirname(os.path.dirname(__file__))
|
||||
tokenizer_1_path = os.path.join(
|
||||
base_path, "tokenizer_configs/hunyuan_dit/tokenizer")
|
||||
super().__init__()
|
||||
self.tokenizer_1 = BertTokenizer.from_pretrained(tokenizer_1_path)
|
||||
|
||||
def fetch_models(self, text_encoder_1: HunyuanDiTCLIPTextEncoder = None, text_encoder_2: STEP1TextEncoder = None):
|
||||
self.text_encoder_1 = text_encoder_1
|
||||
self.text_encoder_2 = text_encoder_2
|
||||
|
||||
def encode_prompt_using_clip(self, prompt, max_length, device):
|
||||
text_inputs = self.tokenizer_1(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=max_length,
|
||||
truncation=True,
|
||||
return_attention_mask=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
prompt_embeds = self.text_encoder_1(
|
||||
text_inputs.input_ids.to(device),
|
||||
attention_mask=text_inputs.attention_mask.to(device),
|
||||
)
|
||||
return prompt_embeds
|
||||
|
||||
def encode_prompt_using_llm(self, prompt, max_length, device):
|
||||
y, y_mask = self.text_encoder_2(prompt, max_length=max_length, device=device)
|
||||
return y, y_mask
|
||||
|
||||
def encode_prompt(self,
|
||||
prompt,
|
||||
positive=True,
|
||||
device="cuda"):
|
||||
|
||||
prompt = self.process_prompt(prompt, positive=positive)
|
||||
|
||||
clip_embeds = self.encode_prompt_using_clip(prompt, max_length=77, device=device)
|
||||
llm_embeds, llm_mask = self.encode_prompt_using_llm(prompt, max_length=320, device=device)
|
||||
|
||||
llm_mask = torch.nn.functional.pad(llm_mask, (clip_embeds.shape[1], 0), value=1)
|
||||
|
||||
return clip_embeds, llm_embeds, llm_mask
|
||||
103
diffsynth/prompters/wanx_prompter.py
Normal file
103
diffsynth/prompters/wanx_prompter.py
Normal file
@@ -0,0 +1,103 @@
|
||||
from .base_prompter import BasePrompter
|
||||
from ..models.wanx_text_encoder import WanXTextEncoder
|
||||
from transformers import AutoTokenizer
|
||||
import os, torch
|
||||
import ftfy
|
||||
import html
|
||||
import string
|
||||
|
||||
import regex as re
|
||||
|
||||
|
||||
def basic_clean(text):
|
||||
text = ftfy.fix_text(text)
|
||||
text = html.unescape(html.unescape(text))
|
||||
return text.strip()
|
||||
|
||||
def whitespace_clean(text):
|
||||
text = re.sub(r'\s+', ' ', text)
|
||||
text = text.strip()
|
||||
return text
|
||||
|
||||
def canonicalize(text, keep_punctuation_exact_string=None):
|
||||
text = text.replace('_', ' ')
|
||||
if keep_punctuation_exact_string:
|
||||
text = keep_punctuation_exact_string.join(
|
||||
part.translate(str.maketrans('', '', string.punctuation))
|
||||
for part in text.split(keep_punctuation_exact_string))
|
||||
else:
|
||||
text = text.translate(str.maketrans('', '', string.punctuation))
|
||||
text = text.lower()
|
||||
text = re.sub(r'\s+', ' ', text)
|
||||
return text.strip()
|
||||
|
||||
class HuggingfaceTokenizer:
|
||||
|
||||
def __init__(self, name, seq_len=None, clean=None, **kwargs):
|
||||
assert clean in (None, 'whitespace', 'lower', 'canonicalize')
|
||||
self.name = name
|
||||
self.seq_len = seq_len
|
||||
self.clean = clean
|
||||
|
||||
# init tokenizer
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(name, **kwargs)
|
||||
self.vocab_size = self.tokenizer.vocab_size
|
||||
|
||||
def __call__(self, sequence, **kwargs):
|
||||
return_mask = kwargs.pop('return_mask', False)
|
||||
|
||||
# arguments
|
||||
_kwargs = {'return_tensors': 'pt'}
|
||||
if self.seq_len is not None:
|
||||
_kwargs.update({
|
||||
'padding': 'max_length',
|
||||
'truncation': True,
|
||||
'max_length': self.seq_len
|
||||
})
|
||||
_kwargs.update(**kwargs)
|
||||
|
||||
# tokenization
|
||||
if isinstance(sequence, str):
|
||||
sequence = [sequence]
|
||||
if self.clean:
|
||||
sequence = [self._clean(u) for u in sequence]
|
||||
ids = self.tokenizer(sequence, **_kwargs)
|
||||
|
||||
# output
|
||||
if return_mask:
|
||||
return ids.input_ids, ids.attention_mask
|
||||
else:
|
||||
return ids.input_ids
|
||||
|
||||
def _clean(self, text):
|
||||
if self.clean == 'whitespace':
|
||||
text = whitespace_clean(basic_clean(text))
|
||||
elif self.clean == 'lower':
|
||||
text = whitespace_clean(basic_clean(text)).lower()
|
||||
elif self.clean == 'canonicalize':
|
||||
text = canonicalize(basic_clean(text))
|
||||
return text
|
||||
|
||||
class WanXPrompter(BasePrompter):
|
||||
|
||||
def __init__(self, tokenizer_path=None, text_len=512):
|
||||
if tokenizer_path is None:
|
||||
base_path = os.path.dirname(os.path.dirname(__file__))
|
||||
tokenizer_path = os.path.join(
|
||||
base_path, "tokenizer_configs/hunyuan_dit/tokenizer")
|
||||
super().__init__()
|
||||
self.tokenizer = HuggingfaceTokenizer(name=tokenizer_path, seq_len=text_len, clean='whitespace')
|
||||
self.text_encoder = None
|
||||
|
||||
def fetch_models(self, text_encoder: WanXTextEncoder = None):
|
||||
self.text_encoder = text_encoder
|
||||
|
||||
def encode_prompt(self, prompt, device="cuda"):
|
||||
ids, mask = self.tokenizer(prompt, return_mask=True, add_special_tokens=True)
|
||||
ids = ids.to(device)
|
||||
mask = mask.to(device)
|
||||
seq_lens = mask.gt(0).sum(dim=1).long()
|
||||
prompt_emb = self.text_encoder(ids, mask)
|
||||
prompt_emb = [u[:v] for u, v in zip(prompt_emb, seq_lens)]
|
||||
return prompt_emb
|
||||
|
||||
@@ -4,13 +4,14 @@ import torch
|
||||
|
||||
class FlowMatchScheduler():
|
||||
|
||||
def __init__(self, num_inference_steps=100, num_train_timesteps=1000, shift=3.0, sigma_max=1.0, sigma_min=0.003/1.002, inverse_timesteps=False, extra_one_step=False):
|
||||
def __init__(self, num_inference_steps=100, num_train_timesteps=1000, shift=3.0, sigma_max=1.0, sigma_min=0.003/1.002, inverse_timesteps=False, extra_one_step=False, reverse_sigmas=False):
|
||||
self.num_train_timesteps = num_train_timesteps
|
||||
self.shift = shift
|
||||
self.sigma_max = sigma_max
|
||||
self.sigma_min = sigma_min
|
||||
self.inverse_timesteps = inverse_timesteps
|
||||
self.extra_one_step = extra_one_step
|
||||
self.reverse_sigmas = reverse_sigmas
|
||||
self.set_timesteps(num_inference_steps)
|
||||
|
||||
|
||||
@@ -23,6 +24,8 @@ class FlowMatchScheduler():
|
||||
if self.inverse_timesteps:
|
||||
self.sigmas = torch.flip(self.sigmas, dims=[0])
|
||||
self.sigmas = self.shift * self.sigmas / (1 + (self.shift - 1) * self.sigmas)
|
||||
if self.reverse_sigmas:
|
||||
self.sigmas = 1 - self.sigmas
|
||||
self.timesteps = self.sigmas * self.num_train_timesteps
|
||||
if training:
|
||||
x = self.timesteps
|
||||
@@ -38,7 +41,7 @@ class FlowMatchScheduler():
|
||||
timestep_id = torch.argmin((self.timesteps - timestep).abs())
|
||||
sigma = self.sigmas[timestep_id]
|
||||
if to_final or timestep_id + 1 >= len(self.timesteps):
|
||||
sigma_ = 1 if self.inverse_timesteps else 0
|
||||
sigma_ = 1 if (self.inverse_timesteps or self.reverse_sigmas) else 0
|
||||
else:
|
||||
sigma_ = self.sigmas[timestep_id + 1]
|
||||
prev_sample = sample + model_output * (sigma_ - sigma)
|
||||
|
||||
1
diffsynth/vram_management/__init__.py
Normal file
1
diffsynth/vram_management/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from .layers import *
|
||||
95
diffsynth/vram_management/layers.py
Normal file
95
diffsynth/vram_management/layers.py
Normal file
@@ -0,0 +1,95 @@
|
||||
import torch, copy
|
||||
from ..models.utils import init_weights_on_device
|
||||
|
||||
|
||||
def cast_to(weight, dtype, device):
|
||||
r = torch.empty_like(weight, dtype=dtype, device=device)
|
||||
r.copy_(weight)
|
||||
return r
|
||||
|
||||
|
||||
class AutoWrappedModule(torch.nn.Module):
|
||||
def __init__(self, module: torch.nn.Module, offload_dtype, offload_device, onload_dtype, onload_device, computation_dtype, computation_device):
|
||||
super().__init__()
|
||||
self.module = module.to(dtype=offload_dtype, device=offload_device)
|
||||
self.offload_dtype = offload_dtype
|
||||
self.offload_device = offload_device
|
||||
self.onload_dtype = onload_dtype
|
||||
self.onload_device = onload_device
|
||||
self.computation_dtype = computation_dtype
|
||||
self.computation_device = computation_device
|
||||
self.state = 0
|
||||
|
||||
def offload(self):
|
||||
if self.state == 1 and (self.offload_dtype != self.onload_dtype or self.offload_device != self.onload_device):
|
||||
self.module.to(dtype=self.offload_dtype, device=self.offload_device)
|
||||
self.state = 0
|
||||
|
||||
def onload(self):
|
||||
if self.state == 0 and (self.offload_dtype != self.onload_dtype or self.offload_device != self.onload_device):
|
||||
self.module.to(dtype=self.onload_dtype, device=self.onload_device)
|
||||
self.state = 1
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
if self.onload_dtype == self.computation_dtype and self.onload_device == self.computation_device:
|
||||
module = self.module
|
||||
else:
|
||||
module = copy.deepcopy(self.module).to(dtype=self.computation_dtype, device=self.computation_device)
|
||||
return module(*args, **kwargs)
|
||||
|
||||
|
||||
class AutoWrappedLinear(torch.nn.Linear):
|
||||
def __init__(self, module: torch.nn.Linear, offload_dtype, offload_device, onload_dtype, onload_device, computation_dtype, computation_device):
|
||||
with init_weights_on_device(device=torch.device("meta")):
|
||||
super().__init__(in_features=module.in_features, out_features=module.out_features, bias=module.bias is not None, dtype=offload_dtype, device=offload_device)
|
||||
self.weight = module.weight
|
||||
self.bias = module.bias
|
||||
self.offload_dtype = offload_dtype
|
||||
self.offload_device = offload_device
|
||||
self.onload_dtype = onload_dtype
|
||||
self.onload_device = onload_device
|
||||
self.computation_dtype = computation_dtype
|
||||
self.computation_device = computation_device
|
||||
self.state = 0
|
||||
|
||||
def offload(self):
|
||||
if self.state == 1 and (self.offload_dtype != self.onload_dtype or self.offload_device != self.onload_device):
|
||||
self.to(dtype=self.offload_dtype, device=self.offload_device)
|
||||
self.state = 0
|
||||
|
||||
def onload(self):
|
||||
if self.state == 0 and (self.offload_dtype != self.onload_dtype or self.offload_device != self.onload_device):
|
||||
self.to(dtype=self.onload_dtype, device=self.onload_device)
|
||||
self.state = 1
|
||||
|
||||
def forward(self, x, *args, **kwargs):
|
||||
if self.onload_dtype == self.computation_dtype and self.onload_device == self.computation_device:
|
||||
weight, bias = self.weight, self.bias
|
||||
else:
|
||||
weight = cast_to(self.weight, self.computation_dtype, self.computation_device)
|
||||
bias = None if self.bias is None else cast_to(self.bias, self.computation_dtype, self.computation_device)
|
||||
return torch.nn.functional.linear(x, weight, bias)
|
||||
|
||||
|
||||
def enable_vram_management_recursively(model: torch.nn.Module, module_map: dict, module_config: dict, max_num_param=None, overflow_module_config: dict = None, total_num_param=0):
|
||||
for name, module in model.named_children():
|
||||
for source_module, target_module in module_map.items():
|
||||
if isinstance(module, source_module):
|
||||
num_param = sum(p.numel() for p in module.parameters())
|
||||
if max_num_param is not None and total_num_param + num_param > max_num_param:
|
||||
module_config_ = overflow_module_config
|
||||
else:
|
||||
module_config_ = module_config
|
||||
module_ = target_module(module, **module_config_)
|
||||
setattr(model, name, module_)
|
||||
total_num_param += num_param
|
||||
break
|
||||
else:
|
||||
total_num_param = enable_vram_management_recursively(module, module_map, module_config, max_num_param, overflow_module_config, total_num_param)
|
||||
return total_num_param
|
||||
|
||||
|
||||
def enable_vram_management(model: torch.nn.Module, module_map: dict, module_config: dict, max_num_param=None, overflow_module_config: dict = None):
|
||||
enable_vram_management_recursively(model, module_map, module_config, max_num_param, overflow_module_config, total_num_param=0)
|
||||
model.vram_management_enabled = True
|
||||
|
||||
@@ -8,7 +8,8 @@ We propose EliGen, a novel approach that leverages fine-grained entity-level inf
|
||||
* Github: [DiffSynth-Studio](https://github.com/modelscope/DiffSynth-Studio)
|
||||
* Model: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen)
|
||||
* Online Demo: [ModelScope EliGen Studio](https://www.modelscope.cn/studios/DiffSynth-Studio/EliGen)
|
||||
* Training dataset: Coming soon
|
||||
* Training Dataset: [EliGen Train Set](https://www.modelscope.cn/datasets/DiffSynth-Studio/EliGenTrainSet)
|
||||
|
||||
|
||||
## Methodology
|
||||
|
||||
|
||||
18
examples/WanX/test_prompter.py
Normal file
18
examples/WanX/test_prompter.py
Normal file
@@ -0,0 +1,18 @@
|
||||
import torch
|
||||
from diffsynth.prompters import WanXPrompter
|
||||
from diffsynth.models.wanx_text_encoder import WanXTextEncoder
|
||||
|
||||
prompter = WanXPrompter('models/WanX/google/umt5-xxl')
|
||||
text_encoder = WanXTextEncoder()
|
||||
text_encoder.load_state_dict(torch.load('models/WanX/models_t5_umt5-xxl-enc-bf16.pth', map_location='cpu'))
|
||||
text_encoder = text_encoder.eval().requires_grad_(False).to(dtype=torch.bfloat16, device='cuda')
|
||||
|
||||
prompter.fetch_models(text_encoder)
|
||||
|
||||
prompt = '维京战士双手挥舞着大斧,对抗猛犸象,黄昏,雪地中,漫天飞雪'
|
||||
neg_prompt = '色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走'
|
||||
|
||||
prompt_emb = prompter.encode_prompt(prompt)
|
||||
neg_prompt_emb = prompter.encode_prompt(neg_prompt)
|
||||
print(prompt_emb[0]) # torch.Size([31, 4096])
|
||||
print(neg_prompt_emb[0]) # torch.Size([126, 4096])
|
||||
46
examples/WanX/test_vae.py
Normal file
46
examples/WanX/test_vae.py
Normal file
@@ -0,0 +1,46 @@
|
||||
import torch
|
||||
import torchvision
|
||||
import imageio
|
||||
from diffsynth import ModelManager
|
||||
|
||||
def save_video(tensor,
|
||||
save_file=None,
|
||||
fps=30,
|
||||
nrow=8,
|
||||
normalize=True,
|
||||
value_range=(-1, 1)):
|
||||
|
||||
tensor = tensor.clamp(min(value_range), max(value_range))
|
||||
tensor = torch.stack([
|
||||
torchvision.utils.make_grid(
|
||||
u, nrow=nrow, normalize=normalize, value_range=value_range)
|
||||
for u in tensor.unbind(2)
|
||||
],
|
||||
dim=1).permute(1, 2, 3, 0) #frame, h, w, 3
|
||||
tensor = (tensor * 255).type(torch.uint8).cpu()
|
||||
|
||||
# write video
|
||||
writer = imageio.get_writer(
|
||||
save_file, fps=fps, codec='libx264', quality=8)
|
||||
for frame in tensor.numpy():
|
||||
writer.append_data(frame)
|
||||
writer.close()
|
||||
|
||||
torch.cuda.memory._record_memory_history()
|
||||
|
||||
model_manager = ModelManager(torch_dtype=torch.float, device="cuda")
|
||||
model_manager.load_models([
|
||||
"models/WanX/vae.pth",
|
||||
])
|
||||
|
||||
vae = model_manager.fetch_model('wanxvideo_vae')
|
||||
|
||||
latents = [torch.load('sample.pt')]
|
||||
videos = vae.decode(latents, device=latents[0].device, tiled=True)
|
||||
back_encode = vae.encode(videos, device=latents[0].device, tiled=True)
|
||||
|
||||
videos_back_encode = vae.decode(back_encode, device=latents[0].device, tiled=False)
|
||||
torch.cuda.memory._dump_snapshot("my_snapshot.pickle")
|
||||
|
||||
save_video(videos[0][None], save_file='example.mp4', fps=16, nrow=1)
|
||||
save_video(videos_back_encode[0][None], save_file='example_backencode.mp4', fps=16, nrow=1)
|
||||
19
examples/stepvideo/README.md
Normal file
19
examples/stepvideo/README.md
Normal file
@@ -0,0 +1,19 @@
|
||||
# Stepvideo
|
||||
|
||||
StepVideo is a state-of-the-art (SoTA) text-to-video pre-trained model with 30 billion parameters and the capability to generate videos up to 204 frames.
|
||||
|
||||
* Model: https://modelscope.cn/models/stepfun-ai/stepvideo-t2v/summary
|
||||
* GitHub: https://github.com/stepfun-ai/Step-Video-T2V
|
||||
* Technical report: https://arxiv.org/abs/2502.10248
|
||||
|
||||
## Examples
|
||||
|
||||
For original BF16 version, please see [`./stepvideo_text_to_video.py`](./stepvideo_text_to_video.py). 80G VRAM required.
|
||||
|
||||
We also support auto-offload, which can reduce the VRAM requirement to **24GB**; however, it requires 2x time for inference. Please see [`./stepvideo_text_to_video_low_vram.py`](./stepvideo_text_to_video_low_vram.py).
|
||||
|
||||
https://github.com/user-attachments/assets/5954fdaa-a3cf-45a3-bd35-886e3cc4581b
|
||||
|
||||
For FP8 quantized version, please see [`./stepvideo_text_to_video_quantized.py`](./stepvideo_text_to_video_quantized.py). 40G VRAM required.
|
||||
|
||||
https://github.com/user-attachments/assets/f3697f4e-bc08-47d2-b00a-32d7dfa272ad
|
||||
50
examples/stepvideo/stepvideo_text_to_video.py
Normal file
50
examples/stepvideo/stepvideo_text_to_video.py
Normal file
@@ -0,0 +1,50 @@
|
||||
from modelscope import snapshot_download
|
||||
from diffsynth import ModelManager, StepVideoPipeline, save_video
|
||||
import torch
|
||||
|
||||
|
||||
# Download models
|
||||
snapshot_download(model_id="stepfun-ai/stepvideo-t2v", cache_dir="models")
|
||||
|
||||
# Load the compiled attention for the LLM text encoder.
|
||||
# If you encounter errors here. Please select other compiled file that matches your environment or delete this line.
|
||||
torch.ops.load_library("models/stepfun-ai/stepvideo-t2v/lib/liboptimus_ths-torch2.5-cu124.cpython-310-x86_64-linux-gnu.so")
|
||||
|
||||
# Load models
|
||||
model_manager = ModelManager()
|
||||
model_manager.load_models(
|
||||
["models/stepfun-ai/stepvideo-t2v/hunyuan_clip/clip_text_encoder/pytorch_model.bin"],
|
||||
torch_dtype=torch.float32, device="cpu"
|
||||
)
|
||||
model_manager.load_models(
|
||||
[
|
||||
"models/stepfun-ai/stepvideo-t2v/step_llm",
|
||||
"models/stepfun-ai/stepvideo-t2v/vae/vae_v2.safetensors",
|
||||
[
|
||||
"models/stepfun-ai/stepvideo-t2v/transformer/diffusion_pytorch_model-00001-of-00006.safetensors",
|
||||
"models/stepfun-ai/stepvideo-t2v/transformer/diffusion_pytorch_model-00002-of-00006.safetensors",
|
||||
"models/stepfun-ai/stepvideo-t2v/transformer/diffusion_pytorch_model-00003-of-00006.safetensors",
|
||||
"models/stepfun-ai/stepvideo-t2v/transformer/diffusion_pytorch_model-00004-of-00006.safetensors",
|
||||
"models/stepfun-ai/stepvideo-t2v/transformer/diffusion_pytorch_model-00005-of-00006.safetensors",
|
||||
"models/stepfun-ai/stepvideo-t2v/transformer/diffusion_pytorch_model-00006-of-00006.safetensors",
|
||||
]
|
||||
],
|
||||
torch_dtype=torch.bfloat16, device="cpu"
|
||||
)
|
||||
pipe = StepVideoPipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device="cuda")
|
||||
|
||||
# Enable VRAM management
|
||||
# This model requires 80G VRAM.
|
||||
# In order to reduce VRAM required, please set `num_persistent_param_in_dit` to a small number.
|
||||
pipe.enable_vram_management(num_persistent_param_in_dit=None)
|
||||
|
||||
# Run!
|
||||
video = pipe(
|
||||
prompt="一名宇航员在月球上发现一块石碑,上面印有“stepfun”字样,闪闪发光。超高清、HDR 视频、环境光、杜比全景声、画面稳定、流畅动作、逼真的细节、专业级构图、超现实主义、自然、生动、超细节、清晰。",
|
||||
negative_prompt="画面暗、低分辨率、不良手、文本、缺少手指、多余的手指、裁剪、低质量、颗粒状、签名、水印、用户名、模糊。",
|
||||
num_inference_steps=30, cfg_scale=9, num_frames=51, seed=1
|
||||
)
|
||||
save_video(
|
||||
video, "video.mp4", fps=25, quality=5,
|
||||
ffmpeg_params=["-vf", "atadenoise=0a=0.1:0b=0.1:1a=0.1:1b=0.1"]
|
||||
)
|
||||
54
examples/stepvideo/stepvideo_text_to_video_low_vram.py
Normal file
54
examples/stepvideo/stepvideo_text_to_video_low_vram.py
Normal file
@@ -0,0 +1,54 @@
|
||||
from modelscope import snapshot_download
|
||||
from diffsynth import ModelManager, StepVideoPipeline, save_video
|
||||
import torch
|
||||
|
||||
|
||||
# Download models
|
||||
snapshot_download(model_id="stepfun-ai/stepvideo-t2v", cache_dir="models")
|
||||
|
||||
# Load the compiled attention for the LLM text encoder.
|
||||
# If you encounter errors here. Please select other compiled file that matches your environment or delete this line.
|
||||
torch.ops.load_library("models/stepfun-ai/stepvideo-t2v/lib/liboptimus_ths-torch2.5-cu124.cpython-310-x86_64-linux-gnu.so")
|
||||
|
||||
# Load models
|
||||
model_manager = ModelManager()
|
||||
model_manager.load_models(
|
||||
["models/stepfun-ai/stepvideo-t2v/hunyuan_clip/clip_text_encoder/pytorch_model.bin"],
|
||||
torch_dtype=torch.float32, device="cpu"
|
||||
)
|
||||
model_manager.load_models(
|
||||
[
|
||||
"models/stepfun-ai/stepvideo-t2v/step_llm",
|
||||
[
|
||||
"models/stepfun-ai/stepvideo-t2v/transformer/diffusion_pytorch_model-00001-of-00006.safetensors",
|
||||
"models/stepfun-ai/stepvideo-t2v/transformer/diffusion_pytorch_model-00002-of-00006.safetensors",
|
||||
"models/stepfun-ai/stepvideo-t2v/transformer/diffusion_pytorch_model-00003-of-00006.safetensors",
|
||||
"models/stepfun-ai/stepvideo-t2v/transformer/diffusion_pytorch_model-00004-of-00006.safetensors",
|
||||
"models/stepfun-ai/stepvideo-t2v/transformer/diffusion_pytorch_model-00005-of-00006.safetensors",
|
||||
"models/stepfun-ai/stepvideo-t2v/transformer/diffusion_pytorch_model-00006-of-00006.safetensors",
|
||||
]
|
||||
],
|
||||
torch_dtype=torch.bfloat16, device="cpu" # You can set torch_dtype=torch.bfloat16 to reduce RAM (not VRAM) usage.
|
||||
)
|
||||
model_manager.load_models(
|
||||
["models/stepfun-ai/stepvideo-t2v/vae/vae_v2.safetensors"],
|
||||
torch_dtype=torch.bfloat16, device="cpu"
|
||||
)
|
||||
pipe = StepVideoPipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device="cuda")
|
||||
|
||||
# Enable VRAM management
|
||||
# This model requires 24G VRAM.
|
||||
# In order to speed up, please set `num_persistent_param_in_dit` to a large number or None (unlimited).
|
||||
pipe.enable_vram_management(num_persistent_param_in_dit=0)
|
||||
|
||||
# Run!
|
||||
video = pipe(
|
||||
prompt="一名宇航员在月球上发现一块石碑,上面印有“stepfun”字样,闪闪发光。超高清、HDR 视频、环境光、杜比全景声、画面稳定、流畅动作、逼真的细节、专业级构图、超现实主义、自然、生动、超细节、清晰。",
|
||||
negative_prompt="画面暗、低分辨率、不良手、文本、缺少手指、多余的手指、裁剪、低质量、颗粒状、签名、水印、用户名、模糊。",
|
||||
num_inference_steps=30, cfg_scale=9, num_frames=51, seed=1,
|
||||
tiled=True, tile_size=(34, 34), tile_stride=(16, 16)
|
||||
)
|
||||
save_video(
|
||||
video, "video.mp4", fps=25, quality=5,
|
||||
ffmpeg_params=["-vf", "atadenoise=0a=0.1:0b=0.1:1a=0.1:1b=0.1"]
|
||||
)
|
||||
53
examples/stepvideo/stepvideo_text_to_video_quantized.py
Normal file
53
examples/stepvideo/stepvideo_text_to_video_quantized.py
Normal file
@@ -0,0 +1,53 @@
|
||||
from modelscope import snapshot_download
|
||||
from diffsynth import ModelManager, StepVideoPipeline, save_video
|
||||
import torch
|
||||
|
||||
|
||||
# Download models
|
||||
snapshot_download(model_id="stepfun-ai/stepvideo-t2v", cache_dir="models")
|
||||
|
||||
# Load the compiled attention for the LLM text encoder.
|
||||
# If you encounter errors here. Please select other compiled file that matches your environment or delete this line.
|
||||
torch.ops.load_library("models/stepfun-ai/stepvideo-t2v/lib/liboptimus_ths-torch2.5-cu124.cpython-310-x86_64-linux-gnu.so")
|
||||
|
||||
# Load models
|
||||
model_manager = ModelManager()
|
||||
model_manager.load_models(
|
||||
["models/stepfun-ai/stepvideo-t2v/hunyuan_clip/clip_text_encoder/pytorch_model.bin"],
|
||||
torch_dtype=torch.float32, device="cpu"
|
||||
)
|
||||
model_manager.load_models(
|
||||
[
|
||||
"models/stepfun-ai/stepvideo-t2v/step_llm",
|
||||
[
|
||||
"models/stepfun-ai/stepvideo-t2v/transformer/diffusion_pytorch_model-00001-of-00006.safetensors",
|
||||
"models/stepfun-ai/stepvideo-t2v/transformer/diffusion_pytorch_model-00002-of-00006.safetensors",
|
||||
"models/stepfun-ai/stepvideo-t2v/transformer/diffusion_pytorch_model-00003-of-00006.safetensors",
|
||||
"models/stepfun-ai/stepvideo-t2v/transformer/diffusion_pytorch_model-00004-of-00006.safetensors",
|
||||
"models/stepfun-ai/stepvideo-t2v/transformer/diffusion_pytorch_model-00005-of-00006.safetensors",
|
||||
"models/stepfun-ai/stepvideo-t2v/transformer/diffusion_pytorch_model-00006-of-00006.safetensors",
|
||||
]
|
||||
],
|
||||
torch_dtype=torch.float8_e4m3fn, device="cpu"
|
||||
)
|
||||
model_manager.load_models(
|
||||
["models/stepfun-ai/stepvideo-t2v/vae/vae_v2.safetensors"],
|
||||
torch_dtype=torch.bfloat16, device="cpu"
|
||||
)
|
||||
pipe = StepVideoPipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device="cuda")
|
||||
|
||||
# Enable VRAM management
|
||||
# This model requires 40G VRAM.
|
||||
# In order to reduce VRAM required, please set `num_persistent_param_in_dit` to a small number.
|
||||
pipe.enable_vram_management(num_persistent_param_in_dit=None)
|
||||
|
||||
# Run!
|
||||
video = pipe(
|
||||
prompt="一名宇航员在月球上发现一块石碑,上面印有“stepfun”字样,闪闪发光。超高清、HDR 视频、环境光、杜比全景声、画面稳定、流畅动作、逼真的细节、专业级构图、超现实主义、自然、生动、超细节、清晰。",
|
||||
negative_prompt="画面暗、低分辨率、不良手、文本、缺少手指、多余的手指、裁剪、低质量、颗粒状、签名、水印、用户名、模糊。",
|
||||
num_inference_steps=30, cfg_scale=9, num_frames=51, seed=1
|
||||
)
|
||||
save_video(
|
||||
video, "video.mp4", fps=25, quality=5,
|
||||
ffmpeg_params=["-vf", "atadenoise=0a=0.1:0b=0.1:1a=0.1:1b=0.1"]
|
||||
)
|
||||
3
examples/vram_management/README.md
Normal file
3
examples/vram_management/README.md
Normal file
@@ -0,0 +1,3 @@
|
||||
# VRAM Management
|
||||
|
||||
Experimental feature. Still under development.
|
||||
25
examples/vram_management/flux_text_to_image.py
Normal file
25
examples/vram_management/flux_text_to_image.py
Normal file
@@ -0,0 +1,25 @@
|
||||
import torch
|
||||
from diffsynth import ModelManager, FluxImagePipeline
|
||||
|
||||
|
||||
model_manager = ModelManager(
|
||||
file_path_list=[
|
||||
"models/FLUX/FLUX.1-dev/text_encoder/model.safetensors",
|
||||
"models/FLUX/FLUX.1-dev/text_encoder_2",
|
||||
"models/FLUX/FLUX.1-dev/flux1-dev.safetensors",
|
||||
"models/FLUX/FLUX.1-dev/ae.safetensors",
|
||||
],
|
||||
torch_dtype=torch.float8_e4m3fn,
|
||||
device="cpu"
|
||||
)
|
||||
pipe = FluxImagePipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device="cuda")
|
||||
|
||||
# Enable VRAM management
|
||||
# `num_persistent_param_in_dit` indicates the number of parameters that reside persistently in VRAM within the DiT model.
|
||||
# When `num_persistent_param_in_dit=None`, it means all parameters reside persistently in memory.
|
||||
# When `num_persistent_param_in_dit=7*10**9`, it indicates that 7 billion parameters reside persistently in memory.
|
||||
# When `num_persistent_param_in_dit=0`, it means no parameters reside persistently in memory, and they are loaded layer by layer during inference.
|
||||
pipe.enable_vram_management(num_persistent_param_in_dit=None)
|
||||
|
||||
image = pipe(prompt="a beautiful orange cat", seed=0)
|
||||
image.save("image.jpg")
|
||||
Reference in New Issue
Block a user