mirror of https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
support wan2.2 A14B I2V&T2V
@@ -226,10 +226,11 @@ class WanVideoPipeline(BasePipeline):
         self.text_encoder: WanTextEncoder = None
         self.image_encoder: WanImageEncoder = None
         self.dit: WanModel = None
+        self.dit2: WanModel = None
         self.vae: WanVideoVAE = None
         self.motion_controller: WanMotionControllerModel = None
         self.vace: VaceWanModel = None
-        self.in_iteration_models = ("dit", "motion_controller", "vace")
+        self.in_iteration_models = ("dit", "dit2", "motion_controller", "vace")
         self.unit_runner = PipelineUnitRunner()
         self.units = [
             WanVideoUnit_ShapeChecker(),
@@ -238,6 +239,7 @@ class WanVideoPipeline(BasePipeline):
             WanVideoUnit_PromptEmbedder(),
             WanVideoUnit_ImageEmbedder(),
             WanVideoUnit_ImageVaeEmbedder(),
+            WanVideoUnit_ImageEmbedderNoClip(),
             WanVideoUnit_FunControl(),
             WanVideoUnit_FunReference(),
             WanVideoUnit_FunCameraControl(),
@@ -329,6 +331,37 @@ class WanVideoPipeline(BasePipeline):
             ),
             vram_limit=vram_limit,
         )
+        if self.dit2 is not None:
+            dtype = next(iter(self.dit2.parameters())).dtype
+            device = "cpu" if vram_limit is not None else self.device
+            enable_vram_management(
+                self.dit2,
+                module_map = {
+                    torch.nn.Linear: AutoWrappedLinear,
+                    torch.nn.Conv3d: AutoWrappedModule,
+                    torch.nn.LayerNorm: WanAutoCastLayerNorm,
+                    RMSNorm: AutoWrappedModule,
+                    torch.nn.Conv2d: AutoWrappedModule,
+                },
+                module_config = dict(
+                    offload_dtype=dtype,
+                    offload_device="cpu",
+                    onload_dtype=dtype,
+                    onload_device=device,
+                    computation_dtype=self.torch_dtype,
+                    computation_device=self.device,
+                ),
+                max_num_param=num_persistent_param_in_dit,
+                overflow_module_config = dict(
+                    offload_dtype=dtype,
+                    offload_device="cpu",
+                    onload_dtype=dtype,
+                    onload_device="cpu",
+                    computation_dtype=self.torch_dtype,
+                    computation_device=self.device,
+                ),
+                vram_limit=vram_limit,
+            )
         if self.vae is not None:
             dtype = next(iter(self.vae.parameters())).dtype
             enable_vram_management(
@@ -427,6 +460,10 @@ class WanVideoPipeline(BasePipeline):
             for block in self.dit.blocks:
                 block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn)
             self.dit.forward = types.MethodType(usp_dit_forward, self.dit)
+            if self.dit2 is not None:
+                for block in self.dit2.blocks:
+                    block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn)
+                self.dit2.forward = types.MethodType(usp_dit_forward, self.dit2)
             self.sp_size = get_sequence_parallel_world_size()
             self.use_unified_sequence_parallel = True
@@ -473,6 +510,9 @@ class WanVideoPipeline(BasePipeline):
         # Load models
         pipe.text_encoder = model_manager.fetch_model("wan_video_text_encoder")
         pipe.dit = model_manager.fetch_model("wan_video_dit")
+        num_dits = len([model_name for model_name in model_manager.model_name if model_name == "wan_video_dit"])
+        if num_dits == 2:
+            pipe.dit2 = [model for model, model_name in zip(model_manager.model, model_manager.model_name) if model_name == "wan_video_dit"][-1]
         pipe.vae = model_manager.fetch_model("wan_video_vae")
         pipe.image_encoder = model_manager.fetch_model("wan_video_image_encoder")
         pipe.motion_controller = model_manager.fetch_model("wan_video_motion_controller")
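With this loader change, passing two "wan_video_dit" checkpoints to the ModelManager makes the second one available as pipe.dit2. A minimal usage sketch, assuming hypothetical local checkpoint paths and the ModelManager / from_model_manager entry points used in the hunk above; since the first "wan_video_dit" is fetched as pipe.dit and the last becomes pipe.dit2, the high-noise expert should be listed before the low-noise one:

import torch
from diffsynth import ModelManager, WanVideoPipeline

# Hypothetical paths; substitute the real Wan2.2 A14B checkpoint files.
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
model_manager.load_models([
    "models/Wan2.2-I2V-A14B/high_noise_model.safetensors",  # loaded first  -> pipe.dit
    "models/Wan2.2-I2V-A14B/low_noise_model.safetensors",   # loaded second -> pipe.dit2
    "models/Wan2.2-I2V-A14B/models_t5_umt5-xxl-enc-bf16.pth",
    "models/Wan2.2-I2V-A14B/Wan2.1_VAE.pth",
])
pipe = WanVideoPipeline.from_model_manager(model_manager, device="cuda")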
@@ -523,6 +563,8 @@ class WanVideoPipeline(BasePipeline):
         # Classifier-free guidance
         cfg_scale: Optional[float] = 5.0,
         cfg_merge: Optional[bool] = False,
+        # Boundary
+        boundary: Optional[float] = 0.875,
         # Scheduler
         num_inference_steps: Optional[int] = 50,
         sigma_shift: Optional[float] = 5.0,
@@ -575,8 +617,12 @@ class WanVideoPipeline(BasePipeline):
         self.load_models_to_device(self.in_iteration_models)
         models = {name: getattr(self, name) for name in self.in_iteration_models}
         for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
+            # switch high_noise DiT to low_noise DiT
+            if models.get("dit2") is not None and timestep.item() < boundary * self.scheduler.num_train_timesteps:
+                print("switching to low noise DiT")
+                self.load_models_to_device(["dit2", "motion_controller", "vace"])
+                models["dit"] = models.pop("dit2")
             timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)

             # Inference
             noise_pred_posi = self.model_fn(**models, **inputs_shared, **inputs_posi, timestep=timestep)
             if cfg_scale != 1.0:
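For orientation, the hand-off above is a one-time switch on the denoising timestep: with boundary=0.875 and a 1000-step training schedule (an assumed num_train_timesteps), the high-noise expert (dit) handles timesteps >= 875 and the low-noise expert (dit2) takes over below that; models.pop("dit2") then ensures the check never fires again. A small standalone sketch of the same rule:

# Sketch of the expert hand-off rule, assuming num_train_timesteps = 1000.
num_train_timesteps = 1000
boundary = 0.875
switch_point = boundary * num_train_timesteps  # 875.0

for timestep in [999.0, 950.0, 875.0, 874.0, 500.0, 1.0]:
    expert = "high-noise dit" if timestep >= switch_point else "low-noise dit2"
    print(f"t={timestep:6.1f} -> {expert}")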
@@ -737,7 +783,7 @@ class WanVideoUnit_ImageEmbedder(PipelineUnit):
         )

     def process(self, pipe: WanVideoPipeline, input_image, end_image, num_frames, height, width, tiled, tile_size, tile_stride):
-        if input_image is None or pipe.dit.seperated_timestep:
+        if input_image is None or pipe.image_encoder is None:
             return {}
         pipe.load_models_to_device(self.onload_model_names)
         image = pipe.preprocess_image(input_image.resize((width, height))).to(pipe.device)
@@ -767,6 +813,9 @@ class WanVideoUnit_ImageEmbedder(PipelineUnit):


 class WanVideoUnit_ImageVaeEmbedder(PipelineUnit):
+    """
+    Encode input image to latents using VAE. This unit is for Wan-AI/Wan2.2-TI2V-5B.
+    """
     def __init__(self):
         super().__init__(
             input_params=("input_image", "noise", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride"),
@@ -815,6 +864,42 @@ class WanVideoUnit_ImageVaeEmbedder(PipelineUnit):
         return out1, out2


+class WanVideoUnit_ImageEmbedderNoClip(PipelineUnit):
+    """
+    Encode input image to fused_y using only VAE. This unit is for Wan-AI/Wan2.2-I2V-A14B.
+    """
+    def __init__(self):
+        super().__init__(
+            input_params=("input_image", "end_image", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride"),
+            onload_model_names=("vae",)
+        )
+
+    def process(self, pipe: WanVideoPipeline, input_image, end_image, num_frames, height, width, tiled, tile_size, tile_stride):
+        if input_image is None or pipe.image_encoder is not None or pipe.dit.seperated_timestep:
+            return {}
+        pipe.load_models_to_device(self.onload_model_names)
+        image = pipe.preprocess_image(input_image.resize((width, height))).to(pipe.device)
+        msk = torch.ones(1, num_frames, height//8, width//8, device=pipe.device)
+        msk[:, 1:] = 0
+        if end_image is not None:
+            end_image = pipe.preprocess_image(end_image.resize((width, height))).to(pipe.device)
+            vae_input = torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-2, height, width).to(image.device), end_image.transpose(0, 1)], dim=1)
+            msk[:, -1:] = 1
+        else:
+            vae_input = torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device)], dim=1)
+
+        msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
+        msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8)
+        msk = msk.transpose(1, 2)[0]
+
+        y = pipe.vae.encode([vae_input.to(dtype=pipe.torch_dtype, device=pipe.device)], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
+        y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
+        y = torch.concat([msk, y])
+        y = y.unsqueeze(0)
+        y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
+        return {"fused_y": y}
+
+
 class WanVideoUnit_FunControl(PipelineUnit):
     def __init__(self):
         super().__init__(
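For reference, here is a shape walk-through of the mask/latent fusion above, assuming 81 frames at 480x832 and a Wan-style VAE with 4x temporal / 8x spatial compression and 16 latent channels (these numbers are illustrative, not taken from the diff):

import torch

# Shape walk-through for the fused_y construction, assuming num_frames=81,
# height=480, width=832, and 16 VAE latent channels.
num_frames, height, width = 81, 480, 832

msk = torch.ones(1, num_frames, height // 8, width // 8)            # (1, 81, 60, 104)
msk[:, 1:] = 0                                                       # only the first frame is "known"
msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)  # (1, 84, 60, 104)
msk = msk.view(1, msk.shape[1] // 4, 4, height // 8, width // 8)     # (1, 21, 4, 60, 104)
msk = msk.transpose(1, 2)[0]                                         # (4, 21, 60, 104)

latent_frames = (num_frames - 1) // 4 + 1                            # 21 latent frames
y = torch.zeros(16, latent_frames, height // 8, width // 8)          # stand-in for the VAE latents
fused_y = torch.concat([msk, y]).unsqueeze(0)                        # (1, 20, 21, 60, 104)
print(fused_y.shape)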
@@ -1116,6 +1201,7 @@ def model_fn_wan_video(
     context: torch.Tensor = None,
     clip_feature: Optional[torch.Tensor] = None,
     y: Optional[torch.Tensor] = None,
+    fused_y: Optional[torch.Tensor] = None,
     reference_latents = None,
     vace_context = None,
     vace_scale = 1.0,
@@ -1181,11 +1267,13 @@ def model_fn_wan_video(
         x = torch.concat([x] * context.shape[0], dim=0)
         if timestep.shape[0] != context.shape[0]:
             timestep = torch.concat([timestep] * context.shape[0], dim=0)

     if dit.has_image_input:
         x = torch.cat([x, y], dim=1)  # (b, c_x + c_y, f, h, w)
         clip_embdding = dit.img_emb(clip_feature)
         context = torch.cat([clip_embdding, context], dim=1)
+    if fused_y is not None:
+        x = torch.cat([x, fused_y], dim=1)  # (b, c_x + c_y + c_fused_y, f, h, w)

     # Add camera control
     x, (f, h, w) = dit.patchify(x, control_camera_latents_input)
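Continuing the shapes from the earlier sketch, the new fused_y branch simply widens the DiT input along the channel dimension before patchify; the 36-channel total below assumes the A14B I2V DiT expects 16 noisy-latent channels plus the 20 fused conditioning channels:

import torch

# Channel bookkeeping before patchify (illustrative shapes; the 36-channel
# total is an assumption about the A14B I2V DiT's input width).
x = torch.zeros(1, 16, 21, 60, 104)         # noisy video latents
fused_y = torch.zeros(1, 20, 21, 60, 104)   # 4 mask channels + 16 image latents
x = torch.cat([x, fused_y], dim=1)
print(x.shape)                              # torch.Size([1, 36, 21, 60, 104])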