refine code

This commit is contained in:
Artiprocher
2025-07-28 11:00:54 +08:00
parent bba44173d2
commit 5f68727ad3
22 changed files with 474 additions and 86 deletions

View File

@@ -426,7 +426,7 @@ class ModelManager:
self.load_model(file_path, model_names, device=device, torch_dtype=torch_dtype)
def fetch_model(self, model_name, file_path=None, require_model_path=False):
def fetch_model(self, model_name, file_path=None, require_model_path=False, index=None):
fetched_models = []
fetched_model_paths = []
for model, model_path, model_name_ in zip(self.model, self.model_path, self.model_name):
@@ -440,12 +440,25 @@ class ModelManager:
return None
if len(fetched_models) == 1:
print(f"Using {model_name} from {fetched_model_paths[0]}.")
model = fetched_models[0]
path = fetched_model_paths[0]
else:
print(f"More than one {model_name} models are loaded in model manager: {fetched_model_paths}. Using {model_name} from {fetched_model_paths[0]}.")
if index is None:
model = fetched_models[0]
path = fetched_model_paths[0]
print(f"More than one {model_name} models are loaded in model manager: {fetched_model_paths}. Using {model_name} from {fetched_model_paths[0]}.")
elif isinstance(index, int):
model = fetched_models[:index]
path = fetched_model_paths[:index]
print(f"More than one {model_name} models are loaded in model manager: {fetched_model_paths}. Using {model_name} from {fetched_model_paths[:index]}.")
else:
model = fetched_models
path = fetched_model_paths
print(f"More than one {model_name} models are loaded in model manager: {fetched_model_paths}. Using {model_name} from {fetched_model_paths}.")
if require_model_path:
return fetched_models[0], fetched_model_paths[0]
return model, path
else:
return fetched_models[0]
return model
def to(self, device):

View File

@@ -288,6 +288,9 @@ class WanModel(torch.nn.Module):
add_control_adapter: bool = False,
in_dim_control_adapter: int = 24,
seperated_timestep: bool = False,
require_vae_embedding: bool = True,
require_clip_embedding: bool = True,
fuse_vae_embedding_in_latents: bool = False,
):
super().__init__()
self.dim = dim
@@ -295,6 +298,9 @@ class WanModel(torch.nn.Module):
self.has_image_input = has_image_input
self.patch_size = patch_size
self.seperated_timestep = seperated_timestep
self.require_vae_embedding = require_vae_embedding
self.require_clip_embedding = require_clip_embedding
self.fuse_vae_embedding_in_latents = fuse_vae_embedding_in_latents
self.patch_embedding = nn.Conv3d(
in_dim, dim, kernel_size=patch_size, stride=patch_size)
@@ -352,7 +358,6 @@ class WanModel(torch.nn.Module):
context: torch.Tensor,
clip_feature: Optional[torch.Tensor] = None,
y: Optional[torch.Tensor] = None,
fused_y: Optional[torch.Tensor] = None,
use_gradient_checkpointing: bool = False,
use_gradient_checkpointing_offload: bool = False,
**kwargs,
@@ -366,8 +371,6 @@ class WanModel(torch.nn.Module):
x = torch.cat([x, y], dim=1) # (b, c_x + c_y, f, h, w)
clip_embdding = self.img_emb(clip_feature)
context = torch.cat([clip_embdding, context], dim=1)
if fused_y is not None:
x = torch.cat([x, fused_y], dim=1) # (b, c_x + c_y + c_fused_y, f, h, w)
x, (f, h, w) = self.patchify(x)
@@ -690,6 +693,9 @@ class WanModelStateDictConverter:
"num_layers": 30,
"eps": 1e-6,
"seperated_timestep": True,
"require_clip_embedding": False,
"require_vae_embedding": False,
"fuse_vae_embedding_in_latents": True,
}
elif hash_state_dict_keys(state_dict) == "5b013604280dd715f8457c6ed6d6a626":
# Wan-AI/Wan2.2-I2V-A14B
@@ -705,6 +711,7 @@ class WanModelStateDictConverter:
"num_heads": 40,
"num_layers": 40,
"eps": 1e-6,
"require_clip_embedding": False,
}
else:
config = {}

View File

@@ -230,16 +230,17 @@ class WanVideoPipeline(BasePipeline):
self.vae: WanVideoVAE = None
self.motion_controller: WanMotionControllerModel = None
self.vace: VaceWanModel = None
self.in_iteration_models = ("dit", "dit2", "motion_controller", "vace")
self.in_iteration_models = ("dit", "motion_controller", "vace")
self.in_iteration_models_2 = ("dit2", "motion_controller", "vace")
self.unit_runner = PipelineUnitRunner()
self.units = [
WanVideoUnit_ShapeChecker(),
WanVideoUnit_NoiseInitializer(),
WanVideoUnit_InputVideoEmbedder(),
WanVideoUnit_PromptEmbedder(),
WanVideoUnit_ImageEmbedder(),
WanVideoUnit_ImageVaeEmbedder(),
WanVideoUnit_ImageEmbedderNoClip(),
WanVideoUnit_ImageEmbedderVAE(),
WanVideoUnit_ImageEmbedderCLIP(),
WanVideoUnit_ImageEmbedderFused(),
WanVideoUnit_FunControl(),
WanVideoUnit_FunReference(),
WanVideoUnit_FunCameraControl(),
@@ -259,7 +260,9 @@ class WanVideoPipeline(BasePipeline):
def training_loss(self, **inputs):
timestep_id = torch.randint(0, self.scheduler.num_train_timesteps, (1,))
max_timestep_boundary = int(inputs.get("max_timestep_boundary", 1) * self.scheduler.num_train_timesteps)
min_timestep_boundary = int(inputs.get("min_timestep_boundary", 0) * self.scheduler.num_train_timesteps)
timestep_id = torch.randint(min_timestep_boundary, max_timestep_boundary, (1,))
timestep = self.scheduler.timesteps[timestep_id].to(dtype=self.torch_dtype, device=self.device)
inputs["latents"] = self.scheduler.add_noise(inputs["input_latents"], inputs["noise"], timestep)
@@ -517,6 +520,11 @@ class WanVideoPipeline(BasePipeline):
pipe.image_encoder = model_manager.fetch_model("wan_video_image_encoder")
pipe.motion_controller = model_manager.fetch_model("wan_video_motion_controller")
pipe.vace = model_manager.fetch_model("wan_video_vace")
# Size division factor
if pipe.vae is not None:
pipe.height_division_factor = pipe.vae.upsampling_factor * 2
pipe.width_division_factor = pipe.vae.upsampling_factor * 2
# Initialize tokenizer
tokenizer_config.download_if_necessary(local_model_path, skip_download=skip_download)
@@ -564,7 +572,7 @@ class WanVideoPipeline(BasePipeline):
cfg_scale: Optional[float] = 5.0,
cfg_merge: Optional[bool] = False,
# Boundary
boundary: Optional[float] = 0.875,
switch_DiT_boundary: Optional[float] = 0.875,
# Scheduler
num_inference_steps: Optional[int] = 50,
sigma_shift: Optional[float] = 5.0,
@@ -617,11 +625,14 @@ class WanVideoPipeline(BasePipeline):
self.load_models_to_device(self.in_iteration_models)
models = {name: getattr(self, name) for name in self.in_iteration_models}
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
# switch high_noise DiT to low_noise DiT
if models.get("dit2") is not None and timestep.item() < boundary * self.scheduler.num_train_timesteps:
self.load_models_to_device(["dit2", "motion_controller", "vace"])
models["dit"] = models.pop("dit2")
# Switch DiT if necessary
if timestep.item() < switch_DiT_boundary * self.scheduler.num_train_timesteps and self.dit2 is not None and not models["dit"] is self.dit2:
self.load_models_to_device(self.in_iteration_models_2)
models["dit"] = self.dit2
# Timestep
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
# Inference
noise_pred_posi = self.model_fn(**models, **inputs_shared, **inputs_posi, timestep=timestep)
if cfg_scale != 1.0:
@@ -775,6 +786,9 @@ class WanVideoUnit_PromptEmbedder(PipelineUnit):
class WanVideoUnit_ImageEmbedder(PipelineUnit):
"""
Deprecated
"""
def __init__(self):
super().__init__(
input_params=("input_image", "end_image", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride"),
@@ -811,70 +825,38 @@ class WanVideoUnit_ImageEmbedder(PipelineUnit):
return {"clip_feature": clip_context, "y": y}
class WanVideoUnit_ImageVaeEmbedder(PipelineUnit):
"""
Encode input image to latents using VAE. This unit is for Wan-AI/Wan2.2-TI2V-5B.
"""
class WanVideoUnit_ImageEmbedderCLIP(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("input_image", "noise", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride"),
onload_model_names=("vae")
input_params=("input_image", "end_image", "height", "width"),
onload_model_names=("image_encoder",)
)
def process(self, pipe: WanVideoPipeline, input_image, noise, num_frames, height, width, tiled, tile_size, tile_stride):
if input_image is None or not pipe.dit.seperated_timestep:
def process(self, pipe: WanVideoPipeline, input_image, end_image, height, width):
if input_image is None or pipe.image_encoder is None or not pipe.dit.require_clip_embedding:
return {}
pipe.load_models_to_device(self.onload_model_names)
image = pipe.preprocess_image(input_image.resize((width, height))).transpose(0, 1).to(pipe.device)
z = pipe.vae.encode([image.to(dtype=pipe.torch_dtype, device=pipe.device)], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
image = pipe.preprocess_image(input_image.resize((width, height))).to(pipe.device)
clip_context = pipe.image_encoder.encode_image([image])
if end_image is not None:
end_image = pipe.preprocess_image(end_image.resize((width, height))).to(pipe.device)
if pipe.dit.has_image_pos_emb:
clip_context = torch.concat([clip_context, pipe.image_encoder.encode_image([end_image])], dim=1)
clip_context = clip_context.to(dtype=pipe.torch_dtype, device=pipe.device)
return {"clip_feature": clip_context}
_, mask2 = self.masks_like([noise.squeeze(0)], zero=True)
latents = (1. - mask2[0]) * z + mask2[0] * noise.squeeze(0)
latents = latents.unsqueeze(0)
seq_len = ((num_frames - 1) // 4 + 1) * (height // pipe.vae.upsampling_factor) * (width // pipe.vae.upsampling_factor) // (2 * 2)
if hasattr(pipe, "use_unified_sequence_parallel") and pipe.use_unified_sequence_parallel:
import math
seq_len = int(math.ceil(seq_len / pipe.sp_size)) * pipe.sp_size
return {"latents": latents, "latent_mask_for_timestep": mask2[0].unsqueeze(0), "seq_len": seq_len}
@staticmethod
def masks_like(tensor, zero=False, generator=None, p=0.2):
assert isinstance(tensor, list)
out1 = [torch.ones(u.shape, dtype=u.dtype, device=u.device) for u in tensor]
out2 = [torch.ones(u.shape, dtype=u.dtype, device=u.device) for u in tensor]
if zero:
if generator is not None:
for u, v in zip(out1, out2):
random_num = torch.rand(1, generator=generator, device=generator.device).item()
if random_num < p:
u[:, 0] = torch.normal(mean=-3.5, std=0.5, size=(1,), device=u.device, generator=generator).expand_as(u[:, 0]).exp()
v[:, 0] = torch.zeros_like(v[:, 0])
else:
u[:, 0] = u[:, 0]
v[:, 0] = v[:, 0]
else:
for u, v in zip(out1, out2):
u[:, 0] = torch.zeros_like(u[:, 0])
v[:, 0] = torch.zeros_like(v[:, 0])
return out1, out2
class WanVideoUnit_ImageEmbedderNoClip(PipelineUnit):
"""
Encode input image to fused_y using only VAE. This unit is for Wan-AI/Wan2.2-I2V-A14B.
"""
class WanVideoUnit_ImageEmbedderVAE(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("input_image", "end_image", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride"),
onload_model_names=("vae")
onload_model_names=("vae",)
)
def process(self, pipe: WanVideoPipeline, input_image, end_image, num_frames, height, width, tiled, tile_size, tile_stride):
if input_image is None or pipe.image_encoder is not None or pipe.dit.seperated_timestep:
if input_image is None or not pipe.dit.require_vae_embedding:
return {}
pipe.load_models_to_device(self.onload_model_names)
image = pipe.preprocess_image(input_image.resize((width, height))).to(pipe.device)
@@ -896,14 +878,36 @@ class WanVideoUnit_ImageEmbedderNoClip(PipelineUnit):
y = torch.concat([msk, y])
y = y.unsqueeze(0)
y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
return {"fused_y": y}
return {"y": y}
class WanVideoUnit_ImageEmbedderFused(PipelineUnit):
"""
Encode input image to latents using VAE. This unit is for Wan-AI/Wan2.2-TI2V-5B.
"""
def __init__(self):
super().__init__(
input_params=("input_image", "latents", "height", "width", "tiled", "tile_size", "tile_stride"),
onload_model_names=("vae",)
)
def process(self, pipe: WanVideoPipeline, input_image, latents, height, width, tiled, tile_size, tile_stride):
if input_image is None or not pipe.dit.fuse_vae_embedding_in_latents:
return {}
pipe.load_models_to_device(self.onload_model_names)
image = pipe.preprocess_image(input_image.resize((width, height))).transpose(0, 1)
z = pipe.vae.encode([image], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
latents[:, :, 0: 1] = z
return {"latents": latents, "fuse_vae_embedding_in_latents": True}
class WanVideoUnit_FunControl(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("control_video", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride", "clip_feature", "y"),
onload_model_names=("vae")
onload_model_names=("vae",)
)
def process(self, pipe: WanVideoPipeline, control_video, num_frames, height, width, tiled, tile_size, tile_stride, clip_feature, y):
@@ -927,7 +931,7 @@ class WanVideoUnit_FunReference(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("reference_image", "height", "width", "reference_image"),
onload_model_names=("vae")
onload_model_names=("vae",)
)
def process(self, pipe: WanVideoPipeline, reference_image, height, width):
@@ -1200,7 +1204,6 @@ def model_fn_wan_video(
context: torch.Tensor = None,
clip_feature: Optional[torch.Tensor] = None,
y: Optional[torch.Tensor] = None,
fused_y: Optional[torch.Tensor] = None,
reference_latents = None,
vace_context = None,
vace_scale = 1.0,
@@ -1213,6 +1216,7 @@ def model_fn_wan_video(
use_gradient_checkpointing: bool = False,
use_gradient_checkpointing_offload: bool = False,
control_camera_latents_input = None,
fuse_vae_embedding_in_latents: bool = False,
**kwargs,
):
if sliding_window_size is not None and sliding_window_stride is not None:
@@ -1247,15 +1251,19 @@ def model_fn_wan_video(
get_sequence_parallel_world_size,
get_sp_group)
if dit.seperated_timestep and "latent_mask_for_timestep" in kwargs:
temp_ts = (kwargs["latent_mask_for_timestep"][0][0][:, ::2, ::2] * timestep).flatten()
temp_ts= torch.cat([temp_ts, temp_ts.new_ones(kwargs["seq_len"] - temp_ts.size(0)) * timestep])
timestep = temp_ts.unsqueeze(0).flatten()
t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep).unflatten(0, (latents.size(0), kwargs["seq_len"])))
# Timestep
if dit.seperated_timestep and fuse_vae_embedding_in_latents:
timestep = torch.concat([
torch.zeros((1, latents.shape[3] * latents.shape[4] // 4), dtype=latents.dtype, device=latents.device),
torch.ones((latents.shape[2] - 1, latents.shape[3] * latents.shape[4] // 4), dtype=latents.dtype, device=latents.device) * timestep
]).flatten()
t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep).unsqueeze(0))
t_mod = dit.time_projection(t).unflatten(2, (6, dit.dim))
else:
t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep))
t_mod = dit.time_projection(t).unflatten(1, (6, dit.dim))
# Motion Controller
if motion_bucket_id is not None and motion_controller is not None:
t_mod = t_mod + motion_controller(motion_bucket_id).unflatten(1, (6, dit.dim))
context = dit.text_embedding(context)
@@ -1267,16 +1275,15 @@ def model_fn_wan_video(
if timestep.shape[0] != context.shape[0]:
timestep = torch.concat([timestep] * context.shape[0], dim=0)
if dit.has_image_input:
x = torch.cat([x, y], dim=1) # (b, c_x + c_y, f, h, w)
# Image Embedding
if y is not None and dit.require_vae_embedding:
x = torch.cat([x, y], dim=1)
if clip_feature is not None and dit.require_clip_embedding:
clip_embdding = dit.img_emb(clip_feature)
context = torch.cat([clip_embdding, context], dim=1)
if fused_y is not None:
x = torch.cat([x, fused_y], dim=1) # (b, c_x + c_y + c_fused_y, f, h, w)
# Add camera control
x, (f, h, w) = dit.patchify(x, control_camera_latents_input)
# Reference image
if reference_latents is not None:

View File

@@ -434,6 +434,8 @@ def wan_parser():
parser.add_argument("--extra_inputs", default=None, help="Additional model inputs, comma-separated.")
parser.add_argument("--use_gradient_checkpointing_offload", default=False, action="store_true", help="Whether to offload gradient checkpointing to CPU memory.")
parser.add_argument("--gradient_accumulation_steps", type=int, default=1, help="Gradient accumulation steps.")
parser.add_argument("--max_timestep_boundary", type=float, default=1.0, help="Max timestep boundary (for mixed models, e.g., Wan-AI/Wan2.2-I2V-A14B).")
parser.add_argument("--min_timestep_boundary", type=float, default=0.0, help="Min timestep boundary (for mixed models, e.g., Wan-AI/Wan2.2-I2V-A14B).")
return parser

View File

@@ -65,6 +65,9 @@ save_video(video, "video1.mp4", fps=15, quality=5)
|[Wan-AI/Wan2.1-VACE-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-1.3B)|`vace_control_video`, `vace_reference_image`|[code](./model_inference/Wan2.1-VACE-1.3B.py)|[code](./model_training/full/Wan2.1-VACE-1.3B.sh)|[code](./model_training/validate_full/Wan2.1-VACE-1.3B.py)|[code](./model_training/lora/Wan2.1-VACE-1.3B.sh)|[code](./model_training/validate_lora/Wan2.1-VACE-1.3B.py)|
|[Wan-AI/Wan2.1-VACE-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-14B)|`vace_control_video`, `vace_reference_image`|[code](./model_inference/Wan2.1-VACE-14B.py)|[code](./model_training/full/Wan2.1-VACE-14B.sh)|[code](./model_training/validate_full/Wan2.1-VACE-14B.py)|[code](./model_training/lora/Wan2.1-VACE-14B.sh)|[code](./model_training/validate_lora/Wan2.1-VACE-14B.py)|
|[DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1)|`motion_bucket_id`|[code](./model_inference/Wan2.1-1.3b-speedcontrol-v1.py)|[code](./model_training/full/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](./model_training/validate_full/Wan2.1-1.3b-speedcontrol-v1.py)|[code](./model_training/lora/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](./model_training/validate_lora/Wan2.1-1.3b-speedcontrol-v1.py)|
|[Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B)|`input_image`|[code](./model_inference/Wan2.2-I2V-A14B.py)|[code](./model_training/full/Wan2.2-I2V-A14B.sh)|[code](./model_training/validate_full/Wan2.2-I2V-A14B.py)|[code](./model_training/lora/Wan2.2-I2V-A14B.sh)|[code](./model_training/validate_lora/Wan2.2-I2V-A14B.py)|
|[Wan-AI/Wan2.2-T2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-T2V-A14B)||[code](./model_inference/Wan2.2-T2V-A14B.py)|[code](./model_training/full/Wan2.2-T2V-A14B.sh)|[code](./model_training/validate_full/Wan2.2-T2V-A14B.py)|[code](./model_training/lora/Wan2.2-T2V-A14B.sh)|[code](./model_training/validate_lora/Wan2.2-T2V-A14B.py)|
|[Wan-AI/Wan2.2-TI2V-5B](https://modelscope.cn/models/Wan-AI/Wan2.2-TI2V-5B)|`input_image`|[code](./model_inference/Wan2.2-TI2V-5B.py)|[code](./model_training/full/Wan2.2-TI2V-5B.sh)|[code](./model_training/validate_full/Wan2.2-TI2V-5B.py)|[code](./model_training/lora/Wan2.2-TI2V-5B.sh)|[code](./model_training/validate_lora/Wan2.2-TI2V-5B.py)|
## Model Inference

View File

@@ -65,6 +65,9 @@ save_video(video, "video1.mp4", fps=15, quality=5)
|[Wan-AI/Wan2.1-VACE-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-1.3B)|`vace_control_video`, `vace_reference_image`|[code](./model_inference/Wan2.1-VACE-1.3B.py)|[code](./model_training/full/Wan2.1-VACE-1.3B.sh)|[code](./model_training/validate_full/Wan2.1-VACE-1.3B.py)|[code](./model_training/lora/Wan2.1-VACE-1.3B.sh)|[code](./model_training/validate_lora/Wan2.1-VACE-1.3B.py)|
|[Wan-AI/Wan2.1-VACE-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-14B)|`vace_control_video`, `vace_reference_image`|[code](./model_inference/Wan2.1-VACE-14B.py)|[code](./model_training/full/Wan2.1-VACE-14B.sh)|[code](./model_training/validate_full/Wan2.1-VACE-14B.py)|[code](./model_training/lora/Wan2.1-VACE-14B.sh)|[code](./model_training/validate_lora/Wan2.1-VACE-14B.py)|
|[DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1)|`motion_bucket_id`|[code](./model_inference/Wan2.1-1.3b-speedcontrol-v1.py)|[code](./model_training/full/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](./model_training/validate_full/Wan2.1-1.3b-speedcontrol-v1.py)|[code](./model_training/lora/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](./model_training/validate_lora/Wan2.1-1.3b-speedcontrol-v1.py)|
|[Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B)|`input_image`|[code](./model_inference/Wan2.2-I2V-A14B.py)|[code](./model_training/full/Wan2.2-I2V-A14B.sh)|[code](./model_training/validate_full/Wan2.2-I2V-A14B.py)|[code](./model_training/lora/Wan2.2-I2V-A14B.sh)|[code](./model_training/validate_lora/Wan2.2-I2V-A14B.py)|
|[Wan-AI/Wan2.2-T2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-T2V-A14B)||[code](./model_inference/Wan2.2-T2V-A14B.py)|[code](./model_training/full/Wan2.2-T2V-A14B.sh)|[code](./model_training/validate_full/Wan2.2-T2V-A14B.py)|[code](./model_training/lora/Wan2.2-T2V-A14B.sh)|[code](./model_training/validate_lora/Wan2.2-T2V-A14B.py)|
|[Wan-AI/Wan2.2-TI2V-5B](https://modelscope.cn/models/Wan-AI/Wan2.2-TI2V-5B)|`input_image`|[code](./model_inference/Wan2.2-TI2V-5B.py)|[code](./model_training/full/Wan2.2-TI2V-5B.sh)|[code](./model_training/validate_full/Wan2.2-TI2V-5B.py)|[code](./model_training/lora/Wan2.2-TI2V-5B.sh)|[code](./model_training/validate_lora/Wan2.2-TI2V-5B.py)|
## 模型推理

View File

@@ -22,7 +22,7 @@ dataset_snapshot_download(
allow_file_pattern=["data/examples/wan/cat_fightning.jpg"]
)
input_image = Image.open("data/examples/wan/cat_fightning.jpg").resize((832, 480))
# Text-to-video
video = pipe(
prompt="Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
negative_prompt="色调艳丽过曝静态细节模糊不清字幕风格作品画作画面静止整体发灰最差质量低质量JPEG压缩残留丑陋的残缺的多余的手指画得不好的手部画得不好的脸部畸形的毁容的形态畸形的肢体手指融合静止不动的画面杂乱的背景三条腿背景人很多倒着走",

View File

@@ -1,9 +1,6 @@
import torch
from diffsynth import save_video
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
from modelscope import snapshot_download
snapshot_download("Wan-AI/Wan2.2-T2V-A14B", local_dir="models/Wan-AI/Wan2.2-T2V-A14B")
pipe = WanVideoPipeline.from_pretrained(

View File

@@ -10,7 +10,7 @@ pipe = WanVideoPipeline.from_pretrained(
model_configs=[
ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"),
ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="Wan2.2_VAE.safetensors", offload_device="cpu"),
ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="Wan2.2_VAE.pth", offload_device="cpu"),
],
)
pipe.enable_vram_management()

View File

@@ -0,0 +1,35 @@
accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \
--dataset_base_path data/example_video_dataset \
--dataset_metadata_path data/example_video_dataset/metadata.csv \
--height 480 \
--width 832 \
--num_frames 49 \
--dataset_repeat 100 \
--model_id_with_origin_paths "Wan-AI/Wan2.2-I2V-A14B:high_noise_model/diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.2-I2V-A14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.2-I2V-A14B:Wan2.1_VAE.pth" \
--learning_rate 1e-5 \
--num_epochs 2 \
--remove_prefix_in_ckpt "pipe.dit." \
--output_path "./models/train/Wan2.2-I2V-A14B_high_noise_full" \
--trainable_models "dit" \
--extra_inputs "input_image" \
--use_gradient_checkpointing_offload \
--max_timestep_boundary 1 \
--min_timestep_boundary 0.875
accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \
--dataset_base_path data/example_video_dataset \
--dataset_metadata_path data/example_video_dataset/metadata.csv \
--height 480 \
--width 832 \
--num_frames 49 \
--dataset_repeat 100 \
--model_id_with_origin_paths "Wan-AI/Wan2.2-I2V-A14B:low_noise_model/diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.2-I2V-A14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.2-I2V-A14B:Wan2.1_VAE.pth" \
--learning_rate 1e-5 \
--num_epochs 2 \
--remove_prefix_in_ckpt "pipe.dit." \
--output_path "./models/train/Wan2.2-I2V-A14B_low_noise_full" \
--trainable_models "dit" \
--extra_inputs "input_image" \
--use_gradient_checkpointing_offload \
--max_timestep_boundary 0.875 \
--min_timestep_boundary 0

View File

@@ -0,0 +1,31 @@
accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \
--dataset_base_path data/example_video_dataset \
--dataset_metadata_path data/example_video_dataset/metadata.csv \
--height 480 \
--width 832 \
--num_frames 49 \
--dataset_repeat 100 \
--model_id_with_origin_paths "Wan-AI/Wan2.2-T2V-A14B:high_noise_model/diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.2-T2V-A14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.2-T2V-A14B:Wan2.1_VAE.pth" \
--learning_rate 1e-5 \
--num_epochs 2 \
--remove_prefix_in_ckpt "pipe.dit." \
--output_path "./models/train/Wan2.2-T2V-A14B_high_noise_full" \
--trainable_models "dit" \
--max_timestep_boundary 1 \
--min_timestep_boundary 0.875
accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \
--dataset_base_path data/example_video_dataset \
--dataset_metadata_path data/example_video_dataset/metadata.csv \
--height 480 \
--width 832 \
--num_frames 49 \
--dataset_repeat 100 \
--model_id_with_origin_paths "Wan-AI/Wan2.2-T2V-A14B:low_noise_model/diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.2-T2V-A14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.2-T2V-A14B:Wan2.1_VAE.pth" \
--learning_rate 1e-5 \
--num_epochs 2 \
--remove_prefix_in_ckpt "pipe.dit." \
--output_path "./models/train/Wan2.2-T2V-A14B_low_noise_full" \
--trainable_models "dit" \
--max_timestep_boundary 0.875 \
--min_timestep_boundary 0

View File

@@ -0,0 +1,14 @@
accelerate launch examples/wanvideo/model_training/train.py \
--dataset_base_path data/example_video_dataset \
--dataset_metadata_path data/example_video_dataset/metadata.csv \
--height 480 \
--width 832 \
--num_frames 49 \
--dataset_repeat 100 \
--model_id_with_origin_paths "Wan-AI/Wan2.2-TI2V-5B:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.2-TI2V-5B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.2-TI2V-5B:Wan2.2_VAE.pth" \
--learning_rate 1e-5 \
--num_epochs 2 \
--remove_prefix_in_ckpt "pipe.dit." \
--output_path "./models/train/Wan2.2-TI2V-5B_full" \
--trainable_models "dit" \
--extra_inputs "input_image"

View File

@@ -0,0 +1,37 @@
accelerate launch examples/wanvideo/model_training/train.py \
--dataset_base_path data/example_video_dataset \
--dataset_metadata_path data/example_video_dataset/metadata.csv \
--height 480 \
--width 832 \
--num_frames 49 \
--dataset_repeat 100 \
--model_id_with_origin_paths "Wan-AI/Wan2.2-I2V-A14B:high_noise_model/diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.2-I2V-A14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.2-I2V-A14B:Wan2.1_VAE.pth" \
--learning_rate 1e-4 \
--num_epochs 5 \
--remove_prefix_in_ckpt "pipe.dit." \
--output_path "./models/train/Wan2.2-I2V-A14B_high_noise_lora" \
--lora_base_model "dit" \
--lora_target_modules "q,k,v,o,ffn.0,ffn.2" \
--lora_rank 32 \
--extra_inputs "input_image" \
--max_timestep_boundary 1 \
--min_timestep_boundary 0.875
accelerate launch examples/wanvideo/model_training/train.py \
--dataset_base_path data/example_video_dataset \
--dataset_metadata_path data/example_video_dataset/metadata.csv \
--height 480 \
--width 832 \
--num_frames 49 \
--dataset_repeat 100 \
--model_id_with_origin_paths "Wan-AI/Wan2.2-I2V-A14B:low_noise_model/diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.2-I2V-A14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.2-I2V-A14B:Wan2.1_VAE.pth" \
--learning_rate 1e-4 \
--num_epochs 5 \
--remove_prefix_in_ckpt "pipe.dit." \
--output_path "./models/train/Wan2.2-I2V-A14B_low_noise_lora" \
--lora_base_model "dit" \
--lora_target_modules "q,k,v,o,ffn.0,ffn.2" \
--lora_rank 32 \
--extra_inputs "input_image" \
--max_timestep_boundary 0.875 \
--min_timestep_boundary 0

View File

@@ -0,0 +1,36 @@
accelerate launch examples/wanvideo/model_training/train.py \
--dataset_base_path data/example_video_dataset \
--dataset_metadata_path data/example_video_dataset/metadata.csv \
--height 480 \
--width 832 \
--num_frames 49 \
--dataset_repeat 100 \
--model_id_with_origin_paths "Wan-AI/Wan2.2-T2V-A14B:high_noise_model/diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.2-T2V-A14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.2-T2V-A14B:Wan2.1_VAE.pth" \
--learning_rate 1e-4 \
--num_epochs 5 \
--remove_prefix_in_ckpt "pipe.dit." \
--output_path "./models/train/Wan2.2-T2V-A14B_high_noise_lora" \
--lora_base_model "dit" \
--lora_target_modules "q,k,v,o,ffn.0,ffn.2" \
--lora_rank 32 \
--max_timestep_boundary 1 \
--min_timestep_boundary 0.875
accelerate launch examples/wanvideo/model_training/train.py \
--dataset_base_path data/example_video_dataset \
--dataset_metadata_path data/example_video_dataset/metadata.csv \
--height 480 \
--width 832 \
--num_frames 49 \
--dataset_repeat 100 \
--model_id_with_origin_paths "Wan-AI/Wan2.2-T2V-A14B:low_noise_model/diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.2-T2V-A14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.2-T2V-A14B:Wan2.1_VAE.pth" \
--learning_rate 1e-4 \
--num_epochs 5 \
--remove_prefix_in_ckpt "pipe.dit." \
--output_path "./models/train/Wan2.2-T2V-A14B_low_noise_lora" \
--lora_base_model "dit" \
--lora_target_modules "q,k,v,o,ffn.0,ffn.2" \
--lora_rank 32 \
--max_timestep_boundary 0.875 \
--min_timestep_boundary 0

View File

@@ -0,0 +1,16 @@
accelerate launch examples/wanvideo/model_training/train.py \
--dataset_base_path data/example_video_dataset \
--dataset_metadata_path data/example_video_dataset/metadata.csv \
--height 480 \
--width 832 \
--num_frames 49 \
--dataset_repeat 100 \
--model_id_with_origin_paths "Wan-AI/Wan2.2-TI2V-5B:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.2-TI2V-5B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.2-TI2V-5B:Wan2.2_VAE.pth" \
--learning_rate 1e-4 \
--num_epochs 5 \
--remove_prefix_in_ckpt "pipe.dit." \
--output_path "./models/train/Wan2.2-TI2V-5B_lora" \
--lora_base_model "dit" \
--lora_target_modules "q,k,v,o,ffn.0,ffn.2" \
--lora_rank 32 \
--extra_inputs "input_image"

View File

@@ -14,6 +14,8 @@ class WanTrainingModule(DiffusionTrainingModule):
use_gradient_checkpointing=True,
use_gradient_checkpointing_offload=False,
extra_inputs=None,
max_timestep_boundary=1.0,
min_timestep_boundary=0.0,
):
super().__init__()
# Load models
@@ -45,6 +47,8 @@ class WanTrainingModule(DiffusionTrainingModule):
self.use_gradient_checkpointing = use_gradient_checkpointing
self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload
self.extra_inputs = extra_inputs.split(",") if extra_inputs is not None else []
self.max_timestep_boundary = max_timestep_boundary
self.min_timestep_boundary = min_timestep_boundary
def forward_preprocess(self, data):
@@ -69,6 +73,8 @@ class WanTrainingModule(DiffusionTrainingModule):
"use_gradient_checkpointing_offload": self.use_gradient_checkpointing_offload,
"cfg_merge": False,
"vace_scale": 1,
"max_timestep_boundary": self.max_timestep_boundary,
"min_timestep_boundary": self.min_timestep_boundary,
}
# Extra inputs
@@ -106,6 +112,8 @@ if __name__ == "__main__":
lora_rank=args.lora_rank,
use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload,
extra_inputs=args.extra_inputs,
max_timestep_boundary=args.max_timestep_boundary,
min_timestep_boundary=args.min_timestep_boundary,
)
model_logger = ModelLogger(
args.output_path,

View File

@@ -0,0 +1,33 @@
import torch
from PIL import Image
from diffsynth import save_video, VideoData, load_state_dict
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
from modelscope import dataset_snapshot_download
pipe = WanVideoPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="Wan-AI/Wan2.2-I2V-A14B", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"),
ModelConfig(model_id="Wan-AI/Wan2.2-I2V-A14B", origin_file_pattern="low_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"),
ModelConfig(model_id="Wan-AI/Wan2.2-I2V-A14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
ModelConfig(model_id="Wan-AI/Wan2.2-I2V-A14B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
],
)
state_dict = load_state_dict("models/train/Wan2.2-I2V-A14B_high_noise_full/epoch-1.safetensors")
pipe.dit.load_state_dict(state_dict)
state_dict = load_state_dict("models/train/Wan2.2-I2V-A14B_low_noise_full/epoch-1.safetensors")
pipe.dit2.load_state_dict(state_dict)
pipe.enable_vram_management()
input_image = VideoData("data/example_video_dataset/video1.mp4", height=480, width=832)[0]
video = pipe(
prompt="from sunset to night, a small town, light, house, river",
negative_prompt="色调艳丽过曝静态细节模糊不清字幕风格作品画作画面静止整体发灰最差质量低质量JPEG压缩残留丑陋的残缺的多余的手指画得不好的手部画得不好的脸部畸形的毁容的形态畸形的肢体手指融合静止不动的画面杂乱的背景三条腿背景人很多倒着走",
input_image=input_image,
num_frames=49,
seed=1, tiled=False,
)
save_video(video, "video_Wan2.2-I2V-A14B.mp4", fps=15, quality=5)

View File

@@ -0,0 +1,28 @@
import torch
from PIL import Image
from diffsynth import save_video, VideoData, load_state_dict
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
pipe = WanVideoPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="Wan-AI/Wan2.2-T2V-A14B", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"),
ModelConfig(model_id="Wan-AI/Wan2.2-T2V-A14B", origin_file_pattern="low_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"),
ModelConfig(model_id="Wan-AI/Wan2.2-T2V-A14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
ModelConfig(model_id="Wan-AI/Wan2.2-T2V-A14B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
],
)
state_dict = load_state_dict("models/train/Wan2.2-T2V-A14B_high_noise_full/epoch-1.safetensors")
pipe.dit.load_state_dict(state_dict)
state_dict = load_state_dict("models/train/Wan2.2-T2V-A14B_low_noise_full/epoch-1.safetensors")
pipe.dit2.load_state_dict(state_dict)
pipe.enable_vram_management()
video = pipe(
prompt="from sunset to night, a small town, light, house, river",
negative_prompt="色调艳丽过曝静态细节模糊不清字幕风格作品画作画面静止整体发灰最差质量低质量JPEG压缩残留丑陋的残缺的多余的手指画得不好的手部画得不好的脸部畸形的毁容的形态畸形的肢体手指融合静止不动的画面杂乱的背景三条腿背景人很多倒着走",
seed=1, tiled=True
)
save_video(video, "video_Wan2.2-T2V-A14B.mp4", fps=15, quality=5)

View File

@@ -0,0 +1,30 @@
import torch
from PIL import Image
from diffsynth import save_video, VideoData, load_state_dict
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
from modelscope import dataset_snapshot_download
pipe = WanVideoPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"),
ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="Wan2.2_VAE.pth", offload_device="cpu"),
],
)
state_dict = load_state_dict("models/train/Wan2.2-TI2V-5B_full/epoch-1.safetensors")
pipe.dit.load_state_dict(state_dict)
pipe.enable_vram_management()
input_image = VideoData("data/example_video_dataset/video1.mp4", height=480, width=832)[0]
video = pipe(
prompt="from sunset to night, a small town, light, house, river",
negative_prompt="色调艳丽过曝静态细节模糊不清字幕风格作品画作画面静止整体发灰最差质量低质量JPEG压缩残留丑陋的残缺的多余的手指画得不好的手部画得不好的脸部畸形的毁容的形态畸形的肢体手指融合静止不动的画面杂乱的背景三条腿背景人很多倒着走",
input_image=input_image,
num_frames=49,
seed=1, tiled=False,
)
save_video(video, "video_Wan2.2-TI2V-5B.mp4", fps=15, quality=5)

View File

@@ -0,0 +1,31 @@
import torch
from PIL import Image
from diffsynth import save_video, VideoData
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
from modelscope import dataset_snapshot_download
pipe = WanVideoPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="Wan-AI/Wan2.2-I2V-A14B", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"),
ModelConfig(model_id="Wan-AI/Wan2.2-I2V-A14B", origin_file_pattern="low_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"),
ModelConfig(model_id="Wan-AI/Wan2.2-I2V-A14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
ModelConfig(model_id="Wan-AI/Wan2.2-I2V-A14B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
],
)
pipe.load_lora(pipe.dit, "models/train/Wan2.2-I2V-A14B_high_noise_lora/epoch-4.safetensors", alpha=1)
pipe.load_lora(pipe.dit2, "models/train/Wan2.2-I2V-A14B_low_noise_lora/epoch-4.safetensors", alpha=1)
pipe.enable_vram_management()
input_image = VideoData("data/example_video_dataset/video1.mp4", height=480, width=832)[0]
video = pipe(
prompt="from sunset to night, a small town, light, house, river",
negative_prompt="色调艳丽过曝静态细节模糊不清字幕风格作品画作画面静止整体发灰最差质量低质量JPEG压缩残留丑陋的残缺的多余的手指画得不好的手部画得不好的脸部畸形的毁容的形态畸形的肢体手指融合静止不动的画面杂乱的背景三条腿背景人很多倒着走",
input_image=input_image,
num_frames=49,
seed=1, tiled=False,
)
save_video(video, "video_Wan2.2-I2V-A14B.mp4", fps=15, quality=5)

View File

@@ -0,0 +1,28 @@
import torch
from PIL import Image
from diffsynth import save_video, VideoData
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
from modelscope import dataset_snapshot_download
pipe = WanVideoPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="Wan-AI/Wan2.2-T2V-A14B", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"),
ModelConfig(model_id="Wan-AI/Wan2.2-T2V-A14B", origin_file_pattern="low_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"),
ModelConfig(model_id="Wan-AI/Wan2.2-T2V-A14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
ModelConfig(model_id="Wan-AI/Wan2.2-T2V-A14B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
],
)
pipe.load_lora(pipe.dit, "models/train/Wan2.2-T2V-A14B_high_noise_lora/epoch-4.safetensors", alpha=1)
pipe.load_lora(pipe.dit2, "models/train/Wan2.2-T2V-A14B_low_noise_lora/epoch-4.safetensors", alpha=1)
pipe.enable_vram_management()
video = pipe(
prompt="from sunset to night, a small town, light, house, river",
negative_prompt="色调艳丽过曝静态细节模糊不清字幕风格作品画作画面静止整体发灰最差质量低质量JPEG压缩残留丑陋的残缺的多余的手指画得不好的手部画得不好的脸部畸形的毁容的形态畸形的肢体手指融合静止不动的画面杂乱的背景三条腿背景人很多倒着走",
num_frames=49,
seed=1, tiled=True
)
save_video(video, "video_Wan2.2-T2V-A14B.mp4", fps=15, quality=5)

View File

@@ -0,0 +1,29 @@
import torch
from PIL import Image
from diffsynth import save_video, VideoData, load_state_dict
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
from modelscope import dataset_snapshot_download
pipe = WanVideoPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"),
ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="Wan2.2_VAE.pth", offload_device="cpu"),
],
)
pipe.load_lora(pipe.dit, "models/train/Wan2.2-TI2V-5B_lora/epoch-4.safetensors", alpha=1)
pipe.enable_vram_management()
input_image = VideoData("data/example_video_dataset/video1.mp4", height=480, width=832)[0]
video = pipe(
prompt="from sunset to night, a small town, light, house, river",
negative_prompt="色调艳丽过曝静态细节模糊不清字幕风格作品画作画面静止整体发灰最差质量低质量JPEG压缩残留丑陋的残缺的多余的手指画得不好的手部画得不好的脸部畸形的毁容的形态畸形的肢体手指融合静止不动的画面杂乱的背景三条腿背景人很多倒着走",
input_image=input_image,
num_frames=49,
seed=1, tiled=False,
)
save_video(video, "video_Wan2.2-TI2V-5B.mp4", fps=15, quality=5)