mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-22 00:38:11 +00:00
refactor wan dit
This commit is contained in:
File diff suppressed because it is too large
Load Diff
@@ -14,7 +14,7 @@ from tqdm import tqdm
|
|||||||
|
|
||||||
from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear
|
from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear
|
||||||
from ..models.wan_video_text_encoder import T5RelativeEmbedding, T5LayerNorm
|
from ..models.wan_video_text_encoder import T5RelativeEmbedding, T5LayerNorm
|
||||||
from ..models.wan_video_dit import WanLayerNorm, WanRMSNorm
|
from ..models.wan_video_dit import RMSNorm
|
||||||
from ..models.wan_video_vae import RMS_norm, CausalConv3d, Upsample
|
from ..models.wan_video_vae import RMS_norm, CausalConv3d, Upsample
|
||||||
|
|
||||||
|
|
||||||
@@ -60,8 +60,8 @@ class WanVideoPipeline(BasePipeline):
|
|||||||
torch.nn.Linear: AutoWrappedLinear,
|
torch.nn.Linear: AutoWrappedLinear,
|
||||||
torch.nn.Conv3d: AutoWrappedModule,
|
torch.nn.Conv3d: AutoWrappedModule,
|
||||||
torch.nn.LayerNorm: AutoWrappedModule,
|
torch.nn.LayerNorm: AutoWrappedModule,
|
||||||
WanLayerNorm: AutoWrappedModule,
|
torch.nn.LayerNorm: AutoWrappedModule,
|
||||||
WanRMSNorm: AutoWrappedModule,
|
RMSNorm: AutoWrappedModule,
|
||||||
},
|
},
|
||||||
module_config = dict(
|
module_config = dict(
|
||||||
offload_dtype=dtype,
|
offload_dtype=dtype,
|
||||||
@@ -224,7 +224,8 @@ class WanVideoPipeline(BasePipeline):
|
|||||||
self.scheduler.set_timesteps(num_inference_steps, denoising_strength, shift=sigma_shift)
|
self.scheduler.set_timesteps(num_inference_steps, denoising_strength, shift=sigma_shift)
|
||||||
|
|
||||||
# Initialize noise
|
# Initialize noise
|
||||||
noise = self.generate_noise((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), seed=seed, device=rand_device, dtype=torch.float32).to(self.device)
|
noise = self.generate_noise((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), seed=seed, device=rand_device, dtype=torch.float32)
|
||||||
|
noise = noise.to(dtype=self.torch_dtype, device=self.device)
|
||||||
if input_video is not None:
|
if input_video is not None:
|
||||||
self.load_models_to_device(['vae'])
|
self.load_models_to_device(['vae'])
|
||||||
input_video = self.preprocess_images(input_video)
|
input_video = self.preprocess_images(input_video)
|
||||||
@@ -252,20 +253,19 @@ class WanVideoPipeline(BasePipeline):
|
|||||||
|
|
||||||
# Denoise
|
# Denoise
|
||||||
self.load_models_to_device(["dit"])
|
self.load_models_to_device(["dit"])
|
||||||
with torch.amp.autocast(dtype=torch.bfloat16, device_type=torch.device(self.device).type):
|
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
||||||
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
|
||||||
timestep = timestep.unsqueeze(0).to(dtype=torch.float32, device=self.device)
|
|
||||||
|
|
||||||
# Inference
|
# Inference
|
||||||
noise_pred_posi = self.dit(latents, timestep=timestep, **prompt_emb_posi, **image_emb, **extra_input)
|
noise_pred_posi = self.dit(latents, timestep=timestep, **prompt_emb_posi, **image_emb, **extra_input)
|
||||||
if cfg_scale != 1.0:
|
if cfg_scale != 1.0:
|
||||||
noise_pred_nega = self.dit(latents, timestep=timestep, **prompt_emb_nega, **image_emb, **extra_input)
|
noise_pred_nega = self.dit(latents, timestep=timestep, **prompt_emb_nega, **image_emb, **extra_input)
|
||||||
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
|
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
|
||||||
else:
|
else:
|
||||||
noise_pred = noise_pred_posi
|
noise_pred = noise_pred_posi
|
||||||
|
|
||||||
# Scheduler
|
# Scheduler
|
||||||
latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents)
|
latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents)
|
||||||
|
|
||||||
# Decode
|
# Decode
|
||||||
self.load_models_to_device(['vae'])
|
self.load_models_to_device(['vae'])
|
||||||
|
|||||||
@@ -104,5 +104,6 @@ class WanPrompter(BasePrompter):
|
|||||||
mask = mask.to(device)
|
mask = mask.to(device)
|
||||||
seq_lens = mask.gt(0).sum(dim=1).long()
|
seq_lens = mask.gt(0).sum(dim=1).long()
|
||||||
prompt_emb = self.text_encoder(ids, mask)
|
prompt_emb = self.text_encoder(ids, mask)
|
||||||
prompt_emb = [u[:v] for u, v in zip(prompt_emb, seq_lens)]
|
for i, v in enumerate(seq_lens):
|
||||||
|
prompt_emb[:, v:] = 0
|
||||||
return prompt_emb
|
return prompt_emb
|
||||||
|
|||||||
Reference in New Issue
Block a user