mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
wan-refactor
This commit is contained in:
@@ -68,6 +68,7 @@ class WanVideoPipeline(BasePipeline):
|
||||
torch.nn.Conv3d: AutoWrappedModule,
|
||||
torch.nn.LayerNorm: AutoWrappedModule,
|
||||
RMSNorm: AutoWrappedModule,
|
||||
torch.nn.Conv2d: AutoWrappedModule,
|
||||
},
|
||||
module_config = dict(
|
||||
offload_dtype=dtype,
|
||||
@@ -237,6 +238,18 @@ class WanVideoPipeline(BasePipeline):
|
||||
return latents
|
||||
|
||||
|
||||
def prepare_reference_image(self, reference_image, height, width):
|
||||
if reference_image is not None:
|
||||
self.load_models_to_device(["vae"])
|
||||
reference_image = reference_image.resize((width, height))
|
||||
reference_image = self.preprocess_images([reference_image])
|
||||
reference_image = torch.stack(reference_image, dim=2).to(dtype=self.torch_dtype, device=self.device)
|
||||
reference_latents = self.vae.encode(reference_image, device=self.device)
|
||||
return {"reference_latents": reference_latents}
|
||||
else:
|
||||
return {}
|
||||
|
||||
|
||||
def prepare_controlnet_kwargs(self, control_video, num_frames, height, width, clip_feature=None, y=None, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
|
||||
if control_video is not None:
|
||||
control_latents = self.encode_control_video(control_video, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||
@@ -339,6 +352,7 @@ class WanVideoPipeline(BasePipeline):
|
||||
end_image=None,
|
||||
input_video=None,
|
||||
control_video=None,
|
||||
reference_image=None,
|
||||
vace_video=None,
|
||||
vace_video_mask=None,
|
||||
vace_reference_image=None,
|
||||
@@ -398,6 +412,9 @@ class WanVideoPipeline(BasePipeline):
|
||||
else:
|
||||
image_emb = {}
|
||||
|
||||
# Reference image
|
||||
reference_image_kwargs = self.prepare_reference_image(reference_image, height, width)
|
||||
|
||||
# ControlNet
|
||||
if control_video is not None:
|
||||
self.load_models_to_device(["image_encoder", "vae"])
|
||||
@@ -435,14 +452,14 @@ class WanVideoPipeline(BasePipeline):
|
||||
self.dit, motion_controller=self.motion_controller, vace=self.vace,
|
||||
x=latents, timestep=timestep,
|
||||
**prompt_emb_posi, **image_emb, **extra_input,
|
||||
**tea_cache_posi, **usp_kwargs, **motion_kwargs, **vace_kwargs,
|
||||
**tea_cache_posi, **usp_kwargs, **motion_kwargs, **vace_kwargs, **reference_image_kwargs,
|
||||
)
|
||||
if cfg_scale != 1.0:
|
||||
noise_pred_nega = model_fn_wan_video(
|
||||
self.dit, motion_controller=self.motion_controller, vace=self.vace,
|
||||
x=latents, timestep=timestep,
|
||||
**prompt_emb_nega, **image_emb, **extra_input,
|
||||
**tea_cache_nega, **usp_kwargs, **motion_kwargs, **vace_kwargs,
|
||||
**tea_cache_nega, **usp_kwargs, **motion_kwargs, **vace_kwargs, **reference_image_kwargs,
|
||||
)
|
||||
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
|
||||
else:
|
||||
@@ -526,6 +543,7 @@ def model_fn_wan_video(
|
||||
context: torch.Tensor = None,
|
||||
clip_feature: Optional[torch.Tensor] = None,
|
||||
y: Optional[torch.Tensor] = None,
|
||||
reference_latents = None,
|
||||
vace_context = None,
|
||||
vace_scale = 1.0,
|
||||
tea_cache: TeaCache = None,
|
||||
@@ -552,6 +570,12 @@ def model_fn_wan_video(
|
||||
|
||||
x, (f, h, w) = dit.patchify(x)
|
||||
|
||||
# Reference image
|
||||
if reference_latents is not None:
|
||||
reference_latents = dit.ref_conv(reference_latents[:, :, 0]).flatten(2).transpose(1, 2)
|
||||
x = torch.concat([reference_latents, x], dim=1)
|
||||
f += 1
|
||||
|
||||
freqs = torch.cat([
|
||||
dit.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
|
||||
dit.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
|
||||
@@ -580,6 +604,10 @@ def model_fn_wan_video(
|
||||
x = x + vace_hints[vace.vace_layers_mapping[block_id]] * vace_scale
|
||||
if tea_cache is not None:
|
||||
tea_cache.store(x)
|
||||
|
||||
if reference_latents is not None:
|
||||
x = x[:, reference_latents.shape[1]:]
|
||||
f -= 1
|
||||
|
||||
x = dit.head(x, t)
|
||||
if use_unified_sequence_parallel:
|
||||
|
||||
1166
diffsynth/pipelines/wan_video_new.py
Normal file
1166
diffsynth/pipelines/wan_video_new.py
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user