mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-19 06:48:12 +00:00
@@ -3,6 +3,193 @@ from .sd3_dit import TimestepEmbeddings, RMSNorm
|
||||
from .utils import init_weights_on_device
|
||||
from einops import rearrange, repeat
|
||||
from tqdm import tqdm
|
||||
from typing import Union, Tuple, List
|
||||
|
||||
|
||||
def HunyuanVideoRope(latents):
|
||||
def _to_tuple(x, dim=2):
|
||||
if isinstance(x, int):
|
||||
return (x,) * dim
|
||||
elif len(x) == dim:
|
||||
return x
|
||||
else:
|
||||
raise ValueError(f"Expected length {dim} or int, but got {x}")
|
||||
|
||||
|
||||
def get_meshgrid_nd(start, *args, dim=2):
|
||||
"""
|
||||
Get n-D meshgrid with start, stop and num.
|
||||
|
||||
Args:
|
||||
start (int or tuple): If len(args) == 0, start is num; If len(args) == 1, start is start, args[0] is stop,
|
||||
step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num. For n-dim, start/stop/num
|
||||
should be int or n-tuple. If n-tuple is provided, the meshgrid will be stacked following the dim order in
|
||||
n-tuples.
|
||||
*args: See above.
|
||||
dim (int): Dimension of the meshgrid. Defaults to 2.
|
||||
|
||||
Returns:
|
||||
grid (np.ndarray): [dim, ...]
|
||||
"""
|
||||
if len(args) == 0:
|
||||
# start is grid_size
|
||||
num = _to_tuple(start, dim=dim)
|
||||
start = (0,) * dim
|
||||
stop = num
|
||||
elif len(args) == 1:
|
||||
# start is start, args[0] is stop, step is 1
|
||||
start = _to_tuple(start, dim=dim)
|
||||
stop = _to_tuple(args[0], dim=dim)
|
||||
num = [stop[i] - start[i] for i in range(dim)]
|
||||
elif len(args) == 2:
|
||||
# start is start, args[0] is stop, args[1] is num
|
||||
start = _to_tuple(start, dim=dim) # Left-Top eg: 12,0
|
||||
stop = _to_tuple(args[0], dim=dim) # Right-Bottom eg: 20,32
|
||||
num = _to_tuple(args[1], dim=dim) # Target Size eg: 32,124
|
||||
else:
|
||||
raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}")
|
||||
|
||||
# PyTorch implement of np.linspace(start[i], stop[i], num[i], endpoint=False)
|
||||
axis_grid = []
|
||||
for i in range(dim):
|
||||
a, b, n = start[i], stop[i], num[i]
|
||||
g = torch.linspace(a, b, n + 1, dtype=torch.float32)[:n]
|
||||
axis_grid.append(g)
|
||||
grid = torch.meshgrid(*axis_grid, indexing="ij") # dim x [W, H, D]
|
||||
grid = torch.stack(grid, dim=0) # [dim, W, H, D]
|
||||
|
||||
return grid
|
||||
|
||||
|
||||
def get_1d_rotary_pos_embed(
|
||||
dim: int,
|
||||
pos: Union[torch.FloatTensor, int],
|
||||
theta: float = 10000.0,
|
||||
use_real: bool = False,
|
||||
theta_rescale_factor: float = 1.0,
|
||||
interpolation_factor: float = 1.0,
|
||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
||||
"""
|
||||
Precompute the frequency tensor for complex exponential (cis) with given dimensions.
|
||||
(Note: `cis` means `cos + i * sin`, where i is the imaginary unit.)
|
||||
|
||||
This function calculates a frequency tensor with complex exponential using the given dimension 'dim'
|
||||
and the end index 'end'. The 'theta' parameter scales the frequencies.
|
||||
The returned tensor contains complex values in complex64 data type.
|
||||
|
||||
Args:
|
||||
dim (int): Dimension of the frequency tensor.
|
||||
pos (int or torch.FloatTensor): Position indices for the frequency tensor. [S] or scalar
|
||||
theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.
|
||||
use_real (bool, optional): If True, return real part and imaginary part separately.
|
||||
Otherwise, return complex numbers.
|
||||
theta_rescale_factor (float, optional): Rescale factor for theta. Defaults to 1.0.
|
||||
|
||||
Returns:
|
||||
freqs_cis: Precomputed frequency tensor with complex exponential. [S, D/2]
|
||||
freqs_cos, freqs_sin: Precomputed frequency tensor with real and imaginary parts separately. [S, D]
|
||||
"""
|
||||
if isinstance(pos, int):
|
||||
pos = torch.arange(pos).float()
|
||||
|
||||
# proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
|
||||
# has some connection to NTK literature
|
||||
if theta_rescale_factor != 1.0:
|
||||
theta *= theta_rescale_factor ** (dim / (dim - 2))
|
||||
|
||||
freqs = 1.0 / (
|
||||
theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)
|
||||
) # [D/2]
|
||||
# assert interpolation_factor == 1.0, f"interpolation_factor: {interpolation_factor}"
|
||||
freqs = torch.outer(pos * interpolation_factor, freqs) # [S, D/2]
|
||||
if use_real:
|
||||
freqs_cos = freqs.cos().repeat_interleave(2, dim=1) # [S, D]
|
||||
freqs_sin = freqs.sin().repeat_interleave(2, dim=1) # [S, D]
|
||||
return freqs_cos, freqs_sin
|
||||
else:
|
||||
freqs_cis = torch.polar(
|
||||
torch.ones_like(freqs), freqs
|
||||
) # complex64 # [S, D/2]
|
||||
return freqs_cis
|
||||
|
||||
|
||||
def get_nd_rotary_pos_embed(
|
||||
rope_dim_list,
|
||||
start,
|
||||
*args,
|
||||
theta=10000.0,
|
||||
use_real=False,
|
||||
theta_rescale_factor: Union[float, List[float]] = 1.0,
|
||||
interpolation_factor: Union[float, List[float]] = 1.0,
|
||||
):
|
||||
"""
|
||||
This is a n-d version of precompute_freqs_cis, which is a RoPE for tokens with n-d structure.
|
||||
|
||||
Args:
|
||||
rope_dim_list (list of int): Dimension of each rope. len(rope_dim_list) should equal to n.
|
||||
sum(rope_dim_list) should equal to head_dim of attention layer.
|
||||
start (int | tuple of int | list of int): If len(args) == 0, start is num; If len(args) == 1, start is start,
|
||||
args[0] is stop, step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num.
|
||||
*args: See above.
|
||||
theta (float): Scaling factor for frequency computation. Defaults to 10000.0.
|
||||
use_real (bool): If True, return real part and imaginary part separately. Otherwise, return complex numbers.
|
||||
Some libraries such as TensorRT does not support complex64 data type. So it is useful to provide a real
|
||||
part and an imaginary part separately.
|
||||
theta_rescale_factor (float): Rescale factor for theta. Defaults to 1.0.
|
||||
|
||||
Returns:
|
||||
pos_embed (torch.Tensor): [HW, D/2]
|
||||
"""
|
||||
|
||||
grid = get_meshgrid_nd(
|
||||
start, *args, dim=len(rope_dim_list)
|
||||
) # [3, W, H, D] / [2, W, H]
|
||||
|
||||
if isinstance(theta_rescale_factor, int) or isinstance(theta_rescale_factor, float):
|
||||
theta_rescale_factor = [theta_rescale_factor] * len(rope_dim_list)
|
||||
elif isinstance(theta_rescale_factor, list) and len(theta_rescale_factor) == 1:
|
||||
theta_rescale_factor = [theta_rescale_factor[0]] * len(rope_dim_list)
|
||||
assert len(theta_rescale_factor) == len(
|
||||
rope_dim_list
|
||||
), "len(theta_rescale_factor) should equal to len(rope_dim_list)"
|
||||
|
||||
if isinstance(interpolation_factor, int) or isinstance(interpolation_factor, float):
|
||||
interpolation_factor = [interpolation_factor] * len(rope_dim_list)
|
||||
elif isinstance(interpolation_factor, list) and len(interpolation_factor) == 1:
|
||||
interpolation_factor = [interpolation_factor[0]] * len(rope_dim_list)
|
||||
assert len(interpolation_factor) == len(
|
||||
rope_dim_list
|
||||
), "len(interpolation_factor) should equal to len(rope_dim_list)"
|
||||
|
||||
# use 1/ndim of dimensions to encode grid_axis
|
||||
embs = []
|
||||
for i in range(len(rope_dim_list)):
|
||||
emb = get_1d_rotary_pos_embed(
|
||||
rope_dim_list[i],
|
||||
grid[i].reshape(-1),
|
||||
theta,
|
||||
use_real=use_real,
|
||||
theta_rescale_factor=theta_rescale_factor[i],
|
||||
interpolation_factor=interpolation_factor[i],
|
||||
) # 2 x [WHD, rope_dim_list[i]]
|
||||
embs.append(emb)
|
||||
|
||||
if use_real:
|
||||
cos = torch.cat([emb[0] for emb in embs], dim=1) # (WHD, D/2)
|
||||
sin = torch.cat([emb[1] for emb in embs], dim=1) # (WHD, D/2)
|
||||
return cos, sin
|
||||
else:
|
||||
emb = torch.cat(embs, dim=1) # (WHD, D/2)
|
||||
return emb
|
||||
|
||||
freqs_cos, freqs_sin = get_nd_rotary_pos_embed(
|
||||
[16, 56, 56],
|
||||
[latents.shape[2], latents.shape[3] // 2, latents.shape[4] // 2],
|
||||
theta=256,
|
||||
use_real=True,
|
||||
theta_rescale_factor=1,
|
||||
)
|
||||
return freqs_cos, freqs_sin
|
||||
|
||||
|
||||
class PatchEmbed(torch.nn.Module):
|
||||
@@ -406,13 +593,16 @@ class HunyuanVideoDiT(torch.nn.Module):
|
||||
model.to(device)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def prepare_freqs(self, latents):
|
||||
return HunyuanVideoRope(latents)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
t: torch.Tensor,
|
||||
text_states: torch.Tensor = None,
|
||||
prompt_emb: torch.Tensor = None,
|
||||
text_mask: torch.Tensor = None,
|
||||
text_states_2: torch.Tensor = None,
|
||||
pooled_prompt_emb: torch.Tensor = None,
|
||||
freqs_cos: torch.Tensor = None,
|
||||
freqs_sin: torch.Tensor = None,
|
||||
guidance: torch.Tensor = None,
|
||||
@@ -420,9 +610,9 @@ class HunyuanVideoDiT(torch.nn.Module):
|
||||
):
|
||||
B, C, T, H, W = x.shape
|
||||
|
||||
vec = self.time_in(t, dtype=torch.float32) + self.vector_in(text_states_2) + self.guidance_in(guidance, dtype=torch.float32)
|
||||
vec = self.time_in(t, dtype=torch.float32) + self.vector_in(pooled_prompt_emb) + self.guidance_in(guidance * 1000, dtype=torch.float32)
|
||||
img = self.img_in(x)
|
||||
txt = self.txt_in(text_states, t, text_mask)
|
||||
txt = self.txt_in(prompt_emb, t, text_mask)
|
||||
|
||||
for block in tqdm(self.double_blocks, desc="Double stream blocks"):
|
||||
img, txt = block(img, txt, vec, (freqs_cos, freqs_sin))
|
||||
|
||||
@@ -2,15 +2,17 @@ import torch
|
||||
from transformers import T5EncoderModel, T5Config
|
||||
from .sd_text_encoder import SDTextEncoder
|
||||
from .sdxl_text_encoder import SDXLTextEncoder2, SDXLTextEncoder2StateDictConverter
|
||||
|
||||
|
||||
|
||||
class SD3TextEncoder1(SDTextEncoder):
|
||||
def __init__(self, vocab_size=49408):
|
||||
super().__init__(vocab_size=vocab_size)
|
||||
|
||||
def forward(self, input_ids, clip_skip=2):
|
||||
def forward(self, input_ids, clip_skip=2, extra_mask=None):
|
||||
embeds = self.token_embedding(input_ids) + self.position_embeds
|
||||
attn_mask = self.attn_mask.to(device=embeds.device, dtype=embeds.dtype)
|
||||
if extra_mask is not None:
|
||||
attn_mask[:, extra_mask[0]==0] = float("-inf")
|
||||
for encoder_id, encoder in enumerate(self.encoders):
|
||||
embeds = encoder(embeds, attn_mask=attn_mask)
|
||||
if encoder_id + clip_skip == len(self.encoders):
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
from ..models import ModelManager, SD3TextEncoder1, HunyuanVideoVAEDecoder
|
||||
from ..models.hunyuan_video_dit import HunyuanVideoDiT
|
||||
from ..schedulers.flow_match import FlowMatchScheduler
|
||||
from .base import BasePipeline
|
||||
from ..prompters import HunyuanVideoPrompter
|
||||
import torch
|
||||
@@ -13,36 +15,48 @@ class HunyuanVideoPipeline(BasePipeline):
|
||||
|
||||
def __init__(self, device="cuda", torch_dtype=torch.float16):
|
||||
super().__init__(device=device, torch_dtype=torch_dtype)
|
||||
# 参照diffsynth的排序,text_encoder_1指CLIP;text_encoder_2指llm,与hunyuanvideo源代码刚好相反
|
||||
self.scheduler = FlowMatchScheduler(shift=7.0, sigma_min=0.0, extra_one_step=True)
|
||||
self.prompter = HunyuanVideoPrompter()
|
||||
self.text_encoder_1: SD3TextEncoder1 = None
|
||||
self.text_encoder_2: LlamaModel = None
|
||||
self.dit: HunyuanVideoDiT = None
|
||||
self.vae_decoder: HunyuanVideoVAEDecoder = None
|
||||
self.model_names = ['text_encoder_1', 'text_encoder_2', 'vae_decoder']
|
||||
self.model_names = ['text_encoder_1', 'text_encoder_2', 'dit', 'vae_decoder']
|
||||
|
||||
|
||||
def fetch_models(self, model_manager: ModelManager):
|
||||
self.text_encoder_1 = model_manager.fetch_model("sd3_text_encoder_1")
|
||||
self.text_encoder_2 = model_manager.fetch_model("hunyuan_video_text_encoder_2")
|
||||
self.dit = model_manager.fetch_model("hunyuan_video_dit")
|
||||
self.vae_decoder = model_manager.fetch_model("hunyuan_video_vae_decoder")
|
||||
self.prompter.fetch_models(self.text_encoder_1, self.text_encoder_2)
|
||||
|
||||
@staticmethod
|
||||
def from_model_manager(model_manager: ModelManager, device=None):
|
||||
|
||||
pipe = HunyuanVideoPipeline(
|
||||
device=model_manager.device if device is None else device,
|
||||
torch_dtype=model_manager.torch_dtype,
|
||||
)
|
||||
@staticmethod
|
||||
def from_model_manager(model_manager: ModelManager, torch_dtype=None, device=None, enable_vram_management=True):
|
||||
if device is None: device = model_manager.device
|
||||
if torch_dtype is None: torch_dtype = model_manager.torch_dtype
|
||||
pipe = HunyuanVideoPipeline(device=device, torch_dtype=torch_dtype)
|
||||
pipe.fetch_models(model_manager)
|
||||
# VRAM management is automatically enabled.
|
||||
if enable_vram_management:
|
||||
pipe.enable_cpu_offload()
|
||||
pipe.dit.enable_auto_offload(dtype=torch_dtype, device=device)
|
||||
return pipe
|
||||
|
||||
|
||||
def encode_prompt(self, prompt, positive=True, clip_sequence_length=77, llm_sequence_length=256):
|
||||
prompt_emb, pooled_prompt_emb = self.prompter.encode_prompt(prompt,
|
||||
device=self.device,
|
||||
positive=positive,
|
||||
clip_sequence_length=clip_sequence_length,
|
||||
llm_sequence_length=llm_sequence_length)
|
||||
return {"prompt_emb": prompt_emb, "pooled_prompt_emb": pooled_prompt_emb}
|
||||
prompt_emb, pooled_prompt_emb, text_mask = self.prompter.encode_prompt(
|
||||
prompt, device=self.device, positive=positive, clip_sequence_length=clip_sequence_length, llm_sequence_length=llm_sequence_length
|
||||
)
|
||||
return {"prompt_emb": prompt_emb, "pooled_prompt_emb": pooled_prompt_emb, "text_mask": text_mask}
|
||||
|
||||
|
||||
def prepare_extra_input(self, latents=None, guidance=1.0):
|
||||
freqs_cos, freqs_sin = self.dit.prepare_freqs(latents)
|
||||
guidance = torch.Tensor([guidance] * latents.shape[0]).to(device=latents.device, dtype=latents.dtype)
|
||||
return {"freqs_cos": freqs_cos, "freqs_sin": freqs_sin, "guidance": guidance}
|
||||
|
||||
|
||||
def tensor2video(self, frames):
|
||||
frames = rearrange(frames, "C T H W -> T H W C")
|
||||
@@ -50,21 +64,48 @@ class HunyuanVideoPipeline(BasePipeline):
|
||||
frames = [Image.fromarray(frame) for frame in frames]
|
||||
return frames
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
prompt,
|
||||
negative_prompt="",
|
||||
seed=None,
|
||||
progress_bar_cmd=tqdm,
|
||||
height=720,
|
||||
width=1280,
|
||||
num_frames=129,
|
||||
embedded_guidance=6.0,
|
||||
cfg_scale=1.0,
|
||||
num_inference_steps=30,
|
||||
progress_bar_cmd=lambda x: x,
|
||||
progress_bar_st=None,
|
||||
):
|
||||
# encode prompt
|
||||
# prompt_emb_posi = self.encode_prompt(prompt, positive=True)
|
||||
latents = self.generate_noise((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
|
||||
|
||||
self.load_models_to_device(["text_encoder_1", "text_encoder_2"])
|
||||
prompt_emb_posi = self.encode_prompt(prompt, positive=True)
|
||||
if cfg_scale != 1.0:
|
||||
prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False)
|
||||
|
||||
# test data
|
||||
latents = torch.load('latents.pt').to(device=self.device, dtype=self.torch_dtype) # torch.Size([1, 16, 33, 90, 160])
|
||||
latents = latents[:, :, :2, :, :]
|
||||
extra_input = self.prepare_extra_input(latents, guidance=embedded_guidance)
|
||||
|
||||
self.scheduler.set_timesteps(num_inference_steps)
|
||||
|
||||
self.load_models_to_device([])
|
||||
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
||||
timestep = timestep.unsqueeze(0).to(self.device)
|
||||
with torch.autocast(device_type=self.device, dtype=self.torch_dtype):
|
||||
print(f"Step {progress_id + 1} / {len(self.scheduler.timesteps)}")
|
||||
|
||||
noise_pred_posi = self.dit(latents, timestep, **prompt_emb_posi, **extra_input)
|
||||
if cfg_scale != 1.0:
|
||||
noise_pred_nega = self.dit(latents, timestep, **prompt_emb_nega, **extra_input)
|
||||
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
|
||||
else:
|
||||
noise_pred = noise_pred_posi
|
||||
|
||||
latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents)
|
||||
|
||||
# Tiler parameters
|
||||
tiler_kwargs = dict(use_temporal_tiling=False, use_spatial_tiling=False, sample_ssize=256, sample_tsize=64)
|
||||
# decode
|
||||
|
||||
@@ -70,12 +70,17 @@ class HunyuanVideoPrompter(BasePrompter):
|
||||
raise TypeError(f"Unsupported prompt type: {type(text)}")
|
||||
|
||||
def encode_prompt_using_clip(self, prompt, max_length, device):
|
||||
input_ids = self.tokenizer_1(prompt,
|
||||
return_tensors="pt",
|
||||
padding="max_length",
|
||||
max_length=max_length,
|
||||
truncation=True).input_ids.to(device)
|
||||
return self.text_encoder_1(input_ids=input_ids)[0]
|
||||
tokenized_result = self.tokenizer_1(
|
||||
prompt,
|
||||
return_tensors="pt",
|
||||
padding="max_length",
|
||||
max_length=max_length,
|
||||
truncation=True,
|
||||
return_attention_mask=True
|
||||
)
|
||||
input_ids = tokenized_result.input_ids.to(device)
|
||||
attention_mask = tokenized_result.attention_mask.to(device)
|
||||
return self.text_encoder_1(input_ids=input_ids, extra_mask=attention_mask)[0]
|
||||
|
||||
def encode_prompt_using_llm(self,
|
||||
prompt,
|
||||
@@ -110,7 +115,7 @@ class HunyuanVideoPrompter(BasePrompter):
|
||||
last_hidden_state = last_hidden_state[:, crop_start:]
|
||||
attention_mask = (attention_mask[:, crop_start:] if use_attention_mask else None)
|
||||
|
||||
return last_hidden_state
|
||||
return last_hidden_state, attention_mask
|
||||
|
||||
def encode_prompt(self,
|
||||
prompt,
|
||||
@@ -142,8 +147,8 @@ class HunyuanVideoPrompter(BasePrompter):
|
||||
pooled_prompt_emb = self.encode_prompt_using_clip(prompt, clip_sequence_length, device)
|
||||
|
||||
# LLM
|
||||
prompt_emb = self.encode_prompt_using_llm(
|
||||
prompt_emb, attention_mask = self.encode_prompt_using_llm(
|
||||
prompt_formated, llm_sequence_length, device, crop_start,
|
||||
hidden_state_skip_layer, apply_final_norm, use_attention_mask)
|
||||
|
||||
return prompt_emb, pooled_prompt_emb
|
||||
return prompt_emb, pooled_prompt_emb, attention_mask
|
||||
|
||||
@@ -4,18 +4,22 @@ import torch
|
||||
|
||||
class FlowMatchScheduler():
|
||||
|
||||
def __init__(self, num_inference_steps=100, num_train_timesteps=1000, shift=3.0, sigma_max=1.0, sigma_min=0.003/1.002, inverse_timesteps=False):
|
||||
def __init__(self, num_inference_steps=100, num_train_timesteps=1000, shift=3.0, sigma_max=1.0, sigma_min=0.003/1.002, inverse_timesteps=False, extra_one_step=False):
|
||||
self.num_train_timesteps = num_train_timesteps
|
||||
self.shift = shift
|
||||
self.sigma_max = sigma_max
|
||||
self.sigma_min = sigma_min
|
||||
self.inverse_timesteps = inverse_timesteps
|
||||
self.extra_one_step = extra_one_step
|
||||
self.set_timesteps(num_inference_steps)
|
||||
|
||||
|
||||
def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0, training=False):
|
||||
sigma_start = self.sigma_min + (self.sigma_max - self.sigma_min) * denoising_strength
|
||||
self.sigmas = torch.linspace(sigma_start, self.sigma_min, num_inference_steps)
|
||||
if self.extra_one_step:
|
||||
self.sigmas = torch.linspace(sigma_start, self.sigma_min, num_inference_steps + 1)[:-1]
|
||||
else:
|
||||
self.sigmas = torch.linspace(sigma_start, self.sigma_min, num_inference_steps)
|
||||
if self.inverse_timesteps:
|
||||
self.sigmas = torch.flip(self.sigmas, dims=[0])
|
||||
self.sigmas = self.shift * self.sigmas / (1 + (self.shift - 1) * self.sigmas)
|
||||
|
||||
Reference in New Issue
Block a user