@@ -39,17 +39,21 @@ class WanVideoPipeline(BasePipeline):
        self.text_encoder: WanTextEncoder = None
        self.image_encoder: WanImageEncoder = None
        self.dit: WanModel = None
        self.dit2: WanModel = None
        self.vae: WanVideoVAE = None
        self.motion_controller: WanMotionControllerModel = None
        self.vace: VaceWanModel = None
        self.in_iteration_models = ("dit", "motion_controller", "vace")
        self.in_iteration_models_2 = ("dit2", "motion_controller", "vace")
        self.unit_runner = PipelineUnitRunner()
        self.units = [
            WanVideoUnit_ShapeChecker(),
            WanVideoUnit_NoiseInitializer(),
            WanVideoUnit_InputVideoEmbedder(),
            WanVideoUnit_PromptEmbedder(),
            WanVideoUnit_ImageEmbedder(),
            WanVideoUnit_ImageEmbedderVAE(),
            WanVideoUnit_ImageEmbedderCLIP(),
            WanVideoUnit_ImageEmbedderFused(),
            WanVideoUnit_FunControl(),
            WanVideoUnit_FunReference(),
            WanVideoUnit_FunCameraControl(),
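Each preprocessing step is packaged as a `PipelineUnit` and executed in order by the `PipelineUnitRunner`. A minimal sketch of that pattern follows; the real classes in DiffSynth-Studio carry more responsibilities (such as onloading the models named in `onload_model_names`), so this is an illustration, not the actual implementation:

```python
class PipelineUnit:
    """Minimal sketch: one named preprocessing step."""
    def __init__(self, input_params=(), onload_model_names=()):
        self.input_params = input_params
        self.onload_model_names = onload_model_names

    def process(self, pipe, **kwargs):
        return {}  # returns a dict merged back into the shared inputs

class PipelineUnitRunner:
    """Minimal sketch: feed each unit the inputs it declares, merge its outputs."""
    def __call__(self, unit, pipe, inputs):
        kwargs = {name: inputs.get(name) for name in unit.input_params}
        inputs.update(unit.process(pipe, **kwargs))
        return inputs
```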
@@ -69,7 +73,9 @@ class WanVideoPipeline(BasePipeline):
    def training_loss(self, **inputs):
-        timestep_id = torch.randint(0, self.scheduler.num_train_timesteps, (1,))
+        max_timestep_boundary = int(inputs.get("max_timestep_boundary", 1) * self.scheduler.num_train_timesteps)
+        min_timestep_boundary = int(inputs.get("min_timestep_boundary", 0) * self.scheduler.num_train_timesteps)
+        timestep_id = torch.randint(min_timestep_boundary, max_timestep_boundary, (1,))
        timestep = self.scheduler.timesteps[timestep_id].to(dtype=self.torch_dtype, device=self.device)

        inputs["latents"] = self.scheduler.add_noise(inputs["input_latents"], inputs["noise"], timestep)
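The boundary fractions restrict which part of the noise schedule a training run samples from, which is what lets the two Wan2.2 experts train on disjoint timestep ranges. A worked example, assuming 1000 training timesteps and the same 0.875 boundary used for inference-time switching:

```python
import torch

num_train_timesteps = 1000

# High-noise expert: sample only the top 12.5% of the schedule, mirroring
# the 0.875 boundary used for expert switching at inference time.
inputs = {"min_timestep_boundary": 0.875, "max_timestep_boundary": 1.0}
max_b = int(inputs.get("max_timestep_boundary", 1) * num_train_timesteps)  # 1000
min_b = int(inputs.get("min_timestep_boundary", 0) * num_train_timesteps)  # 875
timestep_id = torch.randint(min_b, max_b, (1,))  # uniform over [875, 1000)
```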
@@ -141,6 +147,37 @@ class WanVideoPipeline(BasePipeline):
            ),
            vram_limit=vram_limit,
        )
        if self.dit2 is not None:
            dtype = next(iter(self.dit2.parameters())).dtype
            device = "cpu" if vram_limit is not None else self.device
            enable_vram_management(
                self.dit2,
                module_map = {
                    torch.nn.Linear: AutoWrappedLinear,
                    torch.nn.Conv3d: AutoWrappedModule,
                    torch.nn.LayerNorm: WanAutoCastLayerNorm,
                    RMSNorm: AutoWrappedModule,
                    torch.nn.Conv2d: AutoWrappedModule,
                },
                module_config = dict(
                    offload_dtype=dtype,
                    offload_device="cpu",
                    onload_dtype=dtype,
                    onload_device=device,
                    computation_dtype=self.torch_dtype,
                    computation_device=self.device,
                ),
                max_num_param=num_persistent_param_in_dit,
                overflow_module_config = dict(
                    offload_dtype=dtype,
                    offload_device="cpu",
                    onload_dtype=dtype,
                    onload_device="cpu",
                    computation_dtype=self.torch_dtype,
                    computation_device=self.device,
                ),
                vram_limit=vram_limit,
            )
        if self.vae is not None:
            dtype = next(iter(self.vae.parameters())).dtype
            enable_vram_management(
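`enable_vram_management` wraps every submodule matched by `module_map` so that its parameters rest on the offload device and are cast to the computation dtype/device only when used; `max_num_param` caps how many parameters may stay resident, and modules beyond that cap fall back to the stricter `overflow_module_config` (onload on CPU). A toy illustration of the wrapping idea, not the library's actual implementation:

```python
import torch

class AutoWrappedLinearSketch(torch.nn.Linear):
    """Toy wrapped layer: weights rest on the offload device and are cast to
    the computation device/dtype only for the forward pass."""
    def __init__(self, in_features, out_features,
                 offload_device="cpu", computation_device="cpu",
                 computation_dtype=torch.float32):
        super().__init__(in_features, out_features, device=offload_device)
        self.computation_device = computation_device
        self.computation_dtype = computation_dtype

    def forward(self, x):
        w = self.weight.to(device=self.computation_device, dtype=self.computation_dtype)
        b = self.bias.to(device=self.computation_device, dtype=self.computation_dtype)
        return torch.nn.functional.linear(x, w, b)

layer = AutoWrappedLinearSketch(8, 8)
print(layer(torch.randn(2, 8)).shape)  # torch.Size([2, 8])
```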
@@ -239,6 +276,10 @@ class WanVideoPipeline(BasePipeline):
            for block in self.dit.blocks:
                block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn)
            self.dit.forward = types.MethodType(usp_dit_forward, self.dit)
            if self.dit2 is not None:
                for block in self.dit2.blocks:
                    block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn)
                self.dit2.forward = types.MethodType(usp_dit_forward, self.dit2)
            self.sp_size = get_sequence_parallel_world_size()
            self.use_unified_sequence_parallel = True
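`types.MethodType` rebinds a plain function as a bound method of one instance, so the USP-aware forwards replace the stock ones without modifying the class; this is also why `dit2` must be patched separately. A self-contained demonstration:

```python
import types

class Block:
    def forward(self, x):
        return x + 1

def usp_forward(self, x):
    return x + 100  # stand-in for usp_attn_forward: same signature, USP-aware body

block = Block()
block.forward = types.MethodType(usp_forward, block)  # patches this instance only
print(block.forward(1))    # 101
print(Block().forward(1))  # 2 -- other instances keep the original method
```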
@@ -283,10 +324,18 @@ class WanVideoPipeline(BasePipeline):
        # Load models
        pipe.text_encoder = model_manager.fetch_model("wan_video_text_encoder")
        pipe.dit = model_manager.fetch_model("wan_video_dit")
        num_dits = len([model_name for model_name in model_manager.model_name if model_name == "wan_video_dit"])
        if num_dits == 2:
            pipe.dit2 = [model for model, model_name in zip(model_manager.model, model_manager.model_name) if model_name == "wan_video_dit"][-1]
        pipe.vae = model_manager.fetch_model("wan_video_vae")
        pipe.image_encoder = model_manager.fetch_model("wan_video_image_encoder")
        pipe.motion_controller = model_manager.fetch_model("wan_video_motion_controller")
        pipe.vace = model_manager.fetch_model("wan_video_vace")

        # Size division factor
        if pipe.vae is not None:
            pipe.height_division_factor = pipe.vae.upsampling_factor * 2
            pipe.width_division_factor = pipe.vae.upsampling_factor * 2

        # Initialize tokenizer
        tokenizer_config.download_if_necessary(use_usp=use_usp)
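The division factor is the VAE's spatial downsampling factor times two; the extra factor of two is assumed here to come from the DiT's 2x2 spatial patchification. A small validation sketch with the common `upsampling_factor = 8`:

```python
def check_resolution(height, width, upsampling_factor=8):
    # division factor = VAE downsampling (e.g. 8) * 2 (assumed DiT 2x2 spatial patch)
    factor = upsampling_factor * 2
    if height % factor or width % factor:
        raise ValueError(f"height and width must be multiples of {factor}")
    return height // upsampling_factor, width // upsampling_factor  # latent H, W

print(check_resolution(480, 832))  # (60, 104)
```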
@@ -333,6 +382,8 @@ class WanVideoPipeline(BasePipeline):
        # Classifier-free guidance
        cfg_scale: Optional[float] = 5.0,
        cfg_merge: Optional[bool] = False,
        # Boundary
        switch_DiT_boundary: Optional[float] = 0.875,
        # Scheduler
        num_inference_steps: Optional[int] = 50,
        sigma_shift: Optional[float] = 5.0,
@@ -385,8 +436,14 @@ class WanVideoPipeline(BasePipeline):
        self.load_models_to_device(self.in_iteration_models)
        models = {name: getattr(self, name) for name in self.in_iteration_models}
        for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
            # Switch DiT if necessary
            if timestep.item() < switch_DiT_boundary * self.scheduler.num_train_timesteps and self.dit2 is not None and models["dit"] is not self.dit2:
                self.load_models_to_device(self.in_iteration_models_2)
                models["dit"] = self.dit2

            # Timestep
            timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)

            # Inference
            noise_pred_posi = self.model_fn(**models, **inputs_shared, **inputs_posi, timestep=timestep)
            if cfg_scale != 1.0:
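The switch keys off the raw timestep value rather than the step index: once the descending timestep drops below `switch_DiT_boundary * num_train_timesteps` (875 with the defaults), the loop loads the low-noise expert once, and the `models["dit"] is not self.dit2` check keeps the branch from re-triggering on later steps. A condensed sketch of the control flow, with hypothetical stand-ins for the two experts:

```python
import torch

timesteps = torch.linspace(999, 0, 50)  # descending denoising schedule
boundary = 0.875 * 1000                 # switch_DiT_boundary * num_train_timesteps

def dit_high_noise(t):  # hypothetical stand-in for the high-noise expert
    return "dit"
def dit_low_noise(t):   # hypothetical stand-in for the low-noise expert
    return "dit2"

current = dit_high_noise
for step, t in enumerate(timesteps):
    if t.item() < boundary and current is not dit_low_noise:
        print(f"switching experts at step {step} (t={t.item():.1f})")
        current = dit_low_noise  # one-time switch; all later steps reuse dit2
    _ = current(t)
```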
@@ -400,6 +457,8 @@ class WanVideoPipeline(BasePipeline):
            # Scheduler
            inputs_shared["latents"] = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], inputs_shared["latents"])
            if "first_frame_latents" in inputs_shared:
                inputs_shared["latents"][:, :, 0:1] = inputs_shared["first_frame_latents"]

            # VACE (TODO: remove it)
            if vace_reference_image is not None:
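Re-imposing `first_frame_latents` after every scheduler update keeps the conditioning frame exact while the remaining frames are denoised. A toy illustration with made-up shapes:

```python
import torch

latents = torch.randn(1, 16, 8, 4, 4)              # toy (B, C, T, H, W)
first_frame_latents = torch.zeros(1, 16, 1, 4, 4)  # clean encoded frame 0

latents = latents - 0.1 * torch.randn_like(latents)  # a denoising update drifts every frame
latents[:, :, 0:1] = first_frame_latents             # so frame 0 is written back afterwards
assert torch.equal(latents[:, :, 0:1], first_frame_latents)
```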
@@ -433,7 +492,8 @@ class WanVideoUnit_NoiseInitializer(PipelineUnit):
        length = (num_frames - 1) // 4 + 1
        if vace_reference_image is not None:
            length += 1
-        noise = pipe.generate_noise((1, 16, length, height//8, width//8), seed=seed, rand_device=rand_device)
+        shape = (1, pipe.vae.model.z_dim, length, height // pipe.vae.upsampling_factor, width // pipe.vae.upsampling_factor)
+        noise = pipe.generate_noise(shape, seed=seed, rand_device=rand_device)
        if vace_reference_image is not None:
            noise = torch.concat((noise[:, :, -1:], noise[:, :, :-1]), dim=2)
        return {"noise": noise}
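The latent length follows the VAE's 4x temporal compression (the first frame is encoded on its own), and the hard-coded `(1, 16, ..., //8)` shape is generalized to read `z_dim` and `upsampling_factor` from the VAE so the same unit works across Wan variants. With a VACE reference image, one extra latent frame is generated at the end and rotated to the front so the reference sits at index 0. Worked example:

```python
import torch

num_frames = 81
length = (num_frames - 1) // 4 + 1  # 4x temporal compression, frame 0 alone -> 21

noise = torch.randn(1, 16, length + 1, 60, 104)  # one extra frame for the VACE reference
noise = torch.concat((noise[:, :, -1:], noise[:, :, :-1]), dim=2)  # rotate it to index 0
print(noise.shape)  # torch.Size([1, 16, 22, 60, 104])
```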
@@ -482,6 +542,9 @@ class WanVideoUnit_PromptEmbedder(PipelineUnit):
class WanVideoUnit_ImageEmbedder(PipelineUnit):
    """
    Deprecated
    """
    def __init__(self):
        super().__init__(
            input_params=("input_image", "end_image", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride"),
@@ -489,7 +552,7 @@ class WanVideoUnit_ImageEmbedder(PipelineUnit):
        )

    def process(self, pipe: WanVideoPipeline, input_image, end_image, num_frames, height, width, tiled, tile_size, tile_stride):
-        if input_image is None:
+        if input_image is None or pipe.image_encoder is None:
            return {}
        pipe.load_models_to_device(self.onload_model_names)
        image = pipe.preprocess_image(input_image.resize((width, height))).to(pipe.device)
@@ -517,13 +580,90 @@ class WanVideoUnit_ImageEmbedder(PipelineUnit):
        y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
        return {"clip_feature": clip_context, "y": y}


class WanVideoUnit_ImageEmbedderCLIP(PipelineUnit):
    def __init__(self):
        super().__init__(
            input_params=("input_image", "end_image", "height", "width"),
            onload_model_names=("image_encoder",)
        )

    def process(self, pipe: WanVideoPipeline, input_image, end_image, height, width):
        if input_image is None or pipe.image_encoder is None or not pipe.dit.require_clip_embedding:
            return {}
        pipe.load_models_to_device(self.onload_model_names)
        image = pipe.preprocess_image(input_image.resize((width, height))).to(pipe.device)
        clip_context = pipe.image_encoder.encode_image([image])
        if end_image is not None:
            end_image = pipe.preprocess_image(end_image.resize((width, height))).to(pipe.device)
            if pipe.dit.has_image_pos_emb:
                clip_context = torch.concat([clip_context, pipe.image_encoder.encode_image([end_image])], dim=1)
        clip_context = clip_context.to(dtype=pipe.torch_dtype, device=pipe.device)
        return {"clip_feature": clip_context}


class WanVideoUnit_ImageEmbedderVAE(PipelineUnit):
    def __init__(self):
        super().__init__(
            input_params=("input_image", "end_image", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride"),
            onload_model_names=("vae",)
        )

    def process(self, pipe: WanVideoPipeline, input_image, end_image, num_frames, height, width, tiled, tile_size, tile_stride):
        if input_image is None or not pipe.dit.require_vae_embedding:
            return {}
        pipe.load_models_to_device(self.onload_model_names)
        image = pipe.preprocess_image(input_image.resize((width, height))).to(pipe.device)
        msk = torch.ones(1, num_frames, height//8, width//8, device=pipe.device)
        msk[:, 1:] = 0
        if end_image is not None:
            end_image = pipe.preprocess_image(end_image.resize((width, height))).to(pipe.device)
            vae_input = torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames - 2, height, width).to(image.device), end_image.transpose(0, 1)], dim=1)
            msk[:, -1:] = 1
        else:
            vae_input = torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames - 1, height, width).to(image.device)], dim=1)

        msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
        msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8)
        msk = msk.transpose(1, 2)[0]

        y = pipe.vae.encode([vae_input.to(dtype=pipe.torch_dtype, device=pipe.device)], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
        y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
        y = torch.concat([msk, y])
        y = y.unsqueeze(0)
        y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
        return {"y": y}
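The mask arithmetic above is easiest to follow as a shape walk-through: frame 0 is repeated four times so the total frame count becomes divisible by the VAE's 4x temporal packing, and the reshape/transpose turns the result into four mask channels per latent frame, ready to be concatenated with the encoded video:

```python
import torch

num_frames, h8, w8 = 81, 60, 104        # latent grid = height//8, width//8
msk = torch.ones(1, num_frames, h8, w8)
msk[:, 1:] = 0                          # only frame 0 is conditioned

# Repeat frame 0 four times: 84 frames total, divisible by the 4x temporal packing.
msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
msk = msk.view(1, msk.shape[1] // 4, 4, h8, w8)  # (1, 21, 4, 60, 104)
msk = msk.transpose(1, 2)[0]                     # (4, 21, 60, 104): 4 mask channels per latent frame
print(msk.shape)
```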
class WanVideoUnit_ImageEmbedderFused(PipelineUnit):
    """
    Encode the input image into latents with the VAE. This unit is for Wan-AI/Wan2.2-TI2V-5B.
    """
    def __init__(self):
        super().__init__(
            input_params=("input_image", "latents", "height", "width", "tiled", "tile_size", "tile_stride"),
            onload_model_names=("vae",)
        )

    def process(self, pipe: WanVideoPipeline, input_image, latents, height, width, tiled, tile_size, tile_stride):
        if input_image is None or not pipe.dit.fuse_vae_embedding_in_latents:
            return {}
        pipe.load_models_to_device(self.onload_model_names)
        image = pipe.preprocess_image(input_image.resize((width, height))).transpose(0, 1)
        z = pipe.vae.encode([image], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
        latents[:, :, 0:1] = z
        return {"latents": latents, "fuse_vae_embedding_in_latents": True, "first_frame_latents": z}
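In contrast to the mask-plus-channels conditioning above, the fused path writes the encoded first frame straight into latent frame 0 and flags it, so the DiT needs no extra input channels; the returned `first_frame_latents` are what the denoising loop clamps after each scheduler step. A toy illustration (shapes are made up):

```python
import torch

latents = torch.randn(1, 48, 21, 30, 52)  # toy (B, C, T, H, W) latent volume
z = torch.zeros(1, 48, 1, 30, 52)         # encoded first frame, toy shape

latents[:, :, 0:1] = z  # fuse the conditioning frame in place
# Downstream, fuse_vae_embedding_in_latents=True makes model_fn give frame-0
# tokens timestep 0, and first_frame_latents is re-imposed after each step.
```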
class WanVideoUnit_FunControl(PipelineUnit):
    def __init__(self):
        super().__init__(
            input_params=("control_video", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride", "clip_feature", "y"),
-            onload_model_names=("vae")
+            onload_model_names=("vae",)
        )

    def process(self, pipe: WanVideoPipeline, control_video, num_frames, height, width, tiled, tile_size, tile_stride, clip_feature, y):
@@ -547,7 +687,7 @@ class WanVideoUnit_FunReference(PipelineUnit):
    def __init__(self):
        super().__init__(
            input_params=("reference_image", "height", "width"),
-            onload_model_names=("vae")
+            onload_model_names=("vae",)
        )

    def process(self, pipe: WanVideoPipeline, reference_image, height, width):
@@ -832,6 +972,7 @@ def model_fn_wan_video(
    use_gradient_checkpointing: bool = False,
    use_gradient_checkpointing_offload: bool = False,
    control_camera_latents_input = None,
    fuse_vae_embedding_in_latents: bool = False,
    **kwargs,
):
    if sliding_window_size is not None and sliding_window_stride is not None:
@@ -865,9 +1006,20 @@ def model_fn_wan_video(
        from xfuser.core.distributed import (get_sequence_parallel_rank,
                                             get_sequence_parallel_world_size,
                                             get_sp_group)

    # Timestep
+    if dit.seperated_timestep and fuse_vae_embedding_in_latents:
+        timestep = torch.concat([
+            torch.zeros((1, latents.shape[3] * latents.shape[4] // 4), dtype=latents.dtype, device=latents.device),
+            torch.ones((latents.shape[2] - 1, latents.shape[3] * latents.shape[4] // 4), dtype=latents.dtype, device=latents.device) * timestep
+        ]).flatten()
+        t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep).unsqueeze(0))
+        t_mod = dit.time_projection(t).unflatten(2, (6, dit.dim))
+    else:
+        t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep))
+        t_mod = dit.time_projection(t).unflatten(1, (6, dit.dim))
-    t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep))
-    t_mod = dit.time_projection(t).unflatten(1, (6, dit.dim))
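When the first frame is fused into the latents, the model builds a per-token timestep vector instead of a scalar: tokens from latent frame 0 get timestep 0 (that frame is already clean) and all other tokens get the current timestep. The `// 4` reflects 2x2 spatial patchification, and the extra sequence dimension is why this branch unflattens dim 2 rather than dim 1. A worked shape example with toy dimensions:

```python
import torch

B, C, T, H, W = 1, 48, 21, 30, 52
latents = torch.randn(B, C, T, H, W)
timestep = torch.tensor(875.0)

tokens_per_frame = H * W // 4  # 2x2 spatial patches per latent frame
per_token_t = torch.concat([
    torch.zeros((1, tokens_per_frame)),                # frame 0: already clean, t = 0
    torch.ones((T - 1, tokens_per_frame)) * timestep,  # other frames: current t
]).flatten()
print(per_token_t.shape)  # torch.Size([8190]) == T * tokens_per_frame
```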
    # Motion Controller
    if motion_bucket_id is not None and motion_controller is not None:
        t_mod = t_mod + motion_controller(motion_bucket_id).unflatten(1, (6, dit.dim))
    context = dit.text_embedding(context)
@@ -878,15 +1030,16 @@ def model_fn_wan_video(
        x = torch.concat([x] * context.shape[0], dim=0)
        if timestep.shape[0] != context.shape[0]:
            timestep = torch.concat([timestep] * context.shape[0], dim=0)

-    if dit.has_image_input:
-        x = torch.cat([x, y], dim=1)  # (b, c_x + c_y, f, h, w)
+    # Image Embedding
+    if y is not None and dit.require_vae_embedding:
+        x = torch.cat([x, y], dim=1)
+    if clip_feature is not None and dit.require_clip_embedding:
        clip_embedding = dit.img_emb(clip_feature)
        context = torch.cat([clip_embedding, context], dim=1)

    # Add camera control
    x, (f, h, w) = dit.patchify(x, control_camera_latents_input)

    # Reference image
    if reference_latents is not None: