From 5f68727ad31c07ea3ce45f6dab2787854a711a19 Mon Sep 17 00:00:00 2001 From: Artiprocher Date: Mon, 28 Jul 2025 11:00:54 +0800 Subject: [PATCH] refine code --- diffsynth/models/model_manager.py | 21 ++- diffsynth/models/wan_video_dit.py | 13 +- diffsynth/pipelines/wan_video_new.py | 155 +++++++++--------- diffsynth/trainers/utils.py | 2 + examples/wanvideo/README.md | 3 + examples/wanvideo/README_zh.md | 3 + .../model_inference/Wan2.2-I2V-A14B.py | 2 +- .../model_inference/Wan2.2-T2V-A14B.py | 3 - .../model_inference/Wan2.2-TI2V-5B.py | 2 +- .../model_training/full/Wan2.2-I2V-A14B.sh | 35 ++++ .../model_training/full/Wan2.2-T2V-A14B.sh | 31 ++++ .../model_training/full/Wan2.2-TI2V-5B.sh | 14 ++ .../model_training/lora/Wan2.2-I2V-A14B.sh | 37 +++++ .../model_training/lora/Wan2.2-T2V-A14B.sh | 36 ++++ .../model_training/lora/Wan2.2-TI2V-5B.sh | 16 ++ examples/wanvideo/model_training/train.py | 8 + .../validate_full/Wan2.2-I2V-A14B.py | 33 ++++ .../validate_full/Wan2.2-T2V-A14B.py | 28 ++++ .../validate_full/Wan2.2-TI2V-5B.py | 30 ++++ .../validate_lora/Wan2.2-I2V-A14B.py | 31 ++++ .../validate_lora/Wan2.2-T2V-A14B.py | 28 ++++ .../validate_lora/Wan2.2-TI2V-5B.py | 29 ++++ 22 files changed, 474 insertions(+), 86 deletions(-) create mode 100644 examples/wanvideo/model_training/full/Wan2.2-I2V-A14B.sh create mode 100644 examples/wanvideo/model_training/full/Wan2.2-T2V-A14B.sh create mode 100644 examples/wanvideo/model_training/full/Wan2.2-TI2V-5B.sh create mode 100644 examples/wanvideo/model_training/lora/Wan2.2-I2V-A14B.sh create mode 100644 examples/wanvideo/model_training/lora/Wan2.2-T2V-A14B.sh create mode 100644 examples/wanvideo/model_training/lora/Wan2.2-TI2V-5B.sh create mode 100644 examples/wanvideo/model_training/validate_full/Wan2.2-I2V-A14B.py create mode 100644 examples/wanvideo/model_training/validate_full/Wan2.2-T2V-A14B.py create mode 100644 examples/wanvideo/model_training/validate_full/Wan2.2-TI2V-5B.py create mode 100644 
examples/wanvideo/model_training/validate_lora/Wan2.2-I2V-A14B.py create mode 100644 examples/wanvideo/model_training/validate_lora/Wan2.2-T2V-A14B.py create mode 100644 examples/wanvideo/model_training/validate_lora/Wan2.2-TI2V-5B.py diff --git a/diffsynth/models/model_manager.py b/diffsynth/models/model_manager.py index 7ae3c50..d46eedf 100644 --- a/diffsynth/models/model_manager.py +++ b/diffsynth/models/model_manager.py @@ -426,7 +426,7 @@ class ModelManager: self.load_model(file_path, model_names, device=device, torch_dtype=torch_dtype) - def fetch_model(self, model_name, file_path=None, require_model_path=False): + def fetch_model(self, model_name, file_path=None, require_model_path=False, index=None): fetched_models = [] fetched_model_paths = [] for model, model_path, model_name_ in zip(self.model, self.model_path, self.model_name): @@ -440,12 +440,25 @@ class ModelManager: return None if len(fetched_models) == 1: print(f"Using {model_name} from {fetched_model_paths[0]}.") + model = fetched_models[0] + path = fetched_model_paths[0] else: - print(f"More than one {model_name} models are loaded in model manager: {fetched_model_paths}. Using {model_name} from {fetched_model_paths[0]}.") + if index is None: + model = fetched_models[0] + path = fetched_model_paths[0] + print(f"More than one {model_name} models are loaded in model manager: {fetched_model_paths}. Using {model_name} from {fetched_model_paths[0]}.") + elif isinstance(index, int): + model = fetched_models[:index] + path = fetched_model_paths[:index] + print(f"More than one {model_name} models are loaded in model manager: {fetched_model_paths}. Using {model_name} from {fetched_model_paths[:index]}.") + else: + model = fetched_models + path = fetched_model_paths + print(f"More than one {model_name} models are loaded in model manager: {fetched_model_paths}. 
Using {model_name} from {fetched_model_paths}.") if require_model_path: - return fetched_models[0], fetched_model_paths[0] + return model, path else: - return fetched_models[0] + return model def to(self, device): diff --git a/diffsynth/models/wan_video_dit.py b/diffsynth/models/wan_video_dit.py index 3262057..ea473b0 100644 --- a/diffsynth/models/wan_video_dit.py +++ b/diffsynth/models/wan_video_dit.py @@ -288,6 +288,9 @@ class WanModel(torch.nn.Module): add_control_adapter: bool = False, in_dim_control_adapter: int = 24, seperated_timestep: bool = False, + require_vae_embedding: bool = True, + require_clip_embedding: bool = True, + fuse_vae_embedding_in_latents: bool = False, ): super().__init__() self.dim = dim @@ -295,6 +298,9 @@ class WanModel(torch.nn.Module): self.has_image_input = has_image_input self.patch_size = patch_size self.seperated_timestep = seperated_timestep + self.require_vae_embedding = require_vae_embedding + self.require_clip_embedding = require_clip_embedding + self.fuse_vae_embedding_in_latents = fuse_vae_embedding_in_latents self.patch_embedding = nn.Conv3d( in_dim, dim, kernel_size=patch_size, stride=patch_size) @@ -352,7 +358,6 @@ class WanModel(torch.nn.Module): context: torch.Tensor, clip_feature: Optional[torch.Tensor] = None, y: Optional[torch.Tensor] = None, - fused_y: Optional[torch.Tensor] = None, use_gradient_checkpointing: bool = False, use_gradient_checkpointing_offload: bool = False, **kwargs, @@ -366,8 +371,6 @@ class WanModel(torch.nn.Module): x = torch.cat([x, y], dim=1) # (b, c_x + c_y, f, h, w) clip_embdding = self.img_emb(clip_feature) context = torch.cat([clip_embdding, context], dim=1) - if fused_y is not None: - x = torch.cat([x, fused_y], dim=1) # (b, c_x + c_y + c_fused_y, f, h, w) x, (f, h, w) = self.patchify(x) @@ -690,6 +693,9 @@ class WanModelStateDictConverter: "num_layers": 30, "eps": 1e-6, "seperated_timestep": True, + "require_clip_embedding": False, + "require_vae_embedding": False, + 
"fuse_vae_embedding_in_latents": True, } elif hash_state_dict_keys(state_dict) == "5b013604280dd715f8457c6ed6d6a626": # Wan-AI/Wan2.2-I2V-A14B @@ -705,6 +711,7 @@ class WanModelStateDictConverter: "num_heads": 40, "num_layers": 40, "eps": 1e-6, + "require_clip_embedding": False, } else: config = {} diff --git a/diffsynth/pipelines/wan_video_new.py b/diffsynth/pipelines/wan_video_new.py index f27382d..9963aa1 100644 --- a/diffsynth/pipelines/wan_video_new.py +++ b/diffsynth/pipelines/wan_video_new.py @@ -230,16 +230,17 @@ class WanVideoPipeline(BasePipeline): self.vae: WanVideoVAE = None self.motion_controller: WanMotionControllerModel = None self.vace: VaceWanModel = None - self.in_iteration_models = ("dit", "dit2", "motion_controller", "vace") + self.in_iteration_models = ("dit", "motion_controller", "vace") + self.in_iteration_models_2 = ("dit2", "motion_controller", "vace") self.unit_runner = PipelineUnitRunner() self.units = [ WanVideoUnit_ShapeChecker(), WanVideoUnit_NoiseInitializer(), WanVideoUnit_InputVideoEmbedder(), WanVideoUnit_PromptEmbedder(), - WanVideoUnit_ImageEmbedder(), - WanVideoUnit_ImageVaeEmbedder(), - WanVideoUnit_ImageEmbedderNoClip(), + WanVideoUnit_ImageEmbedderVAE(), + WanVideoUnit_ImageEmbedderCLIP(), + WanVideoUnit_ImageEmbedderFused(), WanVideoUnit_FunControl(), WanVideoUnit_FunReference(), WanVideoUnit_FunCameraControl(), @@ -259,7 +260,9 @@ class WanVideoPipeline(BasePipeline): def training_loss(self, **inputs): - timestep_id = torch.randint(0, self.scheduler.num_train_timesteps, (1,)) + max_timestep_boundary = int(inputs.get("max_timestep_boundary", 1) * self.scheduler.num_train_timesteps) + min_timestep_boundary = int(inputs.get("min_timestep_boundary", 0) * self.scheduler.num_train_timesteps) + timestep_id = torch.randint(min_timestep_boundary, max_timestep_boundary, (1,)) timestep = self.scheduler.timesteps[timestep_id].to(dtype=self.torch_dtype, device=self.device) inputs["latents"] = 
self.scheduler.add_noise(inputs["input_latents"], inputs["noise"], timestep) @@ -517,6 +520,11 @@ class WanVideoPipeline(BasePipeline): pipe.image_encoder = model_manager.fetch_model("wan_video_image_encoder") pipe.motion_controller = model_manager.fetch_model("wan_video_motion_controller") pipe.vace = model_manager.fetch_model("wan_video_vace") + + # Size division factor + if pipe.vae is not None: + pipe.height_division_factor = pipe.vae.upsampling_factor * 2 + pipe.width_division_factor = pipe.vae.upsampling_factor * 2 # Initialize tokenizer tokenizer_config.download_if_necessary(local_model_path, skip_download=skip_download) @@ -564,7 +572,7 @@ class WanVideoPipeline(BasePipeline): cfg_scale: Optional[float] = 5.0, cfg_merge: Optional[bool] = False, # Boundary - boundary: Optional[float] = 0.875, + switch_DiT_boundary: Optional[float] = 0.875, # Scheduler num_inference_steps: Optional[int] = 50, sigma_shift: Optional[float] = 5.0, @@ -617,11 +625,14 @@ class WanVideoPipeline(BasePipeline): self.load_models_to_device(self.in_iteration_models) models = {name: getattr(self, name) for name in self.in_iteration_models} for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)): - # switch high_noise DiT to low_noise DiT - if models.get("dit2") is not None and timestep.item() < boundary * self.scheduler.num_train_timesteps: - self.load_models_to_device(["dit2", "motion_controller", "vace"]) - models["dit"] = models.pop("dit2") + # Switch DiT if necessary + if timestep.item() < switch_DiT_boundary * self.scheduler.num_train_timesteps and self.dit2 is not None and not models["dit"] is self.dit2: + self.load_models_to_device(self.in_iteration_models_2) + models["dit"] = self.dit2 + + # Timestep timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device) + # Inference noise_pred_posi = self.model_fn(**models, **inputs_shared, **inputs_posi, timestep=timestep) if cfg_scale != 1.0: @@ -775,6 +786,9 @@ class 
WanVideoUnit_PromptEmbedder(PipelineUnit): class WanVideoUnit_ImageEmbedder(PipelineUnit): + """ + Deprecated + """ def __init__(self): super().__init__( input_params=("input_image", "end_image", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride"), @@ -811,70 +825,38 @@ class WanVideoUnit_ImageEmbedder(PipelineUnit): return {"clip_feature": clip_context, "y": y} -class WanVideoUnit_ImageVaeEmbedder(PipelineUnit): - """ - Encode input image to latents using VAE. This unit is for Wan-AI/Wan2.2-TI2V-5B. - """ + +class WanVideoUnit_ImageEmbedderCLIP(PipelineUnit): def __init__(self): super().__init__( - input_params=("input_image", "noise", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride"), - onload_model_names=("vae") + input_params=("input_image", "end_image", "height", "width"), + onload_model_names=("image_encoder",) ) - def process(self, pipe: WanVideoPipeline, input_image, noise, num_frames, height, width, tiled, tile_size, tile_stride): - if input_image is None or not pipe.dit.seperated_timestep: + def process(self, pipe: WanVideoPipeline, input_image, end_image, height, width): + if input_image is None or pipe.image_encoder is None or not pipe.dit.require_clip_embedding: return {} pipe.load_models_to_device(self.onload_model_names) - image = pipe.preprocess_image(input_image.resize((width, height))).transpose(0, 1).to(pipe.device) - z = pipe.vae.encode([image.to(dtype=pipe.torch_dtype, device=pipe.device)], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0] + image = pipe.preprocess_image(input_image.resize((width, height))).to(pipe.device) + clip_context = pipe.image_encoder.encode_image([image]) + if end_image is not None: + end_image = pipe.preprocess_image(end_image.resize((width, height))).to(pipe.device) + if pipe.dit.has_image_pos_emb: + clip_context = torch.concat([clip_context, pipe.image_encoder.encode_image([end_image])], dim=1) + clip_context = 
clip_context.to(dtype=pipe.torch_dtype, device=pipe.device) + return {"clip_feature": clip_context} + - _, mask2 = self.masks_like([noise.squeeze(0)], zero=True) - latents = (1. - mask2[0]) * z + mask2[0] * noise.squeeze(0) - latents = latents.unsqueeze(0) - seq_len = ((num_frames - 1) // 4 + 1) * (height // pipe.vae.upsampling_factor) * (width // pipe.vae.upsampling_factor) // (2 * 2) - if hasattr(pipe, "use_unified_sequence_parallel") and pipe.use_unified_sequence_parallel: - import math - seq_len = int(math.ceil(seq_len / pipe.sp_size)) * pipe.sp_size - - return {"latents": latents, "latent_mask_for_timestep": mask2[0].unsqueeze(0), "seq_len": seq_len} - - @staticmethod - def masks_like(tensor, zero=False, generator=None, p=0.2): - assert isinstance(tensor, list) - out1 = [torch.ones(u.shape, dtype=u.dtype, device=u.device) for u in tensor] - out2 = [torch.ones(u.shape, dtype=u.dtype, device=u.device) for u in tensor] - - if zero: - if generator is not None: - for u, v in zip(out1, out2): - random_num = torch.rand(1, generator=generator, device=generator.device).item() - if random_num < p: - u[:, 0] = torch.normal(mean=-3.5, std=0.5, size=(1,), device=u.device, generator=generator).expand_as(u[:, 0]).exp() - v[:, 0] = torch.zeros_like(v[:, 0]) - else: - u[:, 0] = u[:, 0] - v[:, 0] = v[:, 0] - else: - for u, v in zip(out1, out2): - u[:, 0] = torch.zeros_like(u[:, 0]) - v[:, 0] = torch.zeros_like(v[:, 0]) - - return out1, out2 - - -class WanVideoUnit_ImageEmbedderNoClip(PipelineUnit): - """ - Encode input image to fused_y using only VAE. This unit is for Wan-AI/Wan2.2-I2V-A14B. 
- """ +class WanVideoUnit_ImageEmbedderVAE(PipelineUnit): def __init__(self): super().__init__( input_params=("input_image", "end_image", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride"), - onload_model_names=("vae") + onload_model_names=("vae",) ) def process(self, pipe: WanVideoPipeline, input_image, end_image, num_frames, height, width, tiled, tile_size, tile_stride): - if input_image is None or pipe.image_encoder is not None or pipe.dit.seperated_timestep: + if input_image is None or not pipe.dit.require_vae_embedding: return {} pipe.load_models_to_device(self.onload_model_names) image = pipe.preprocess_image(input_image.resize((width, height))).to(pipe.device) @@ -896,14 +878,36 @@ class WanVideoUnit_ImageEmbedderNoClip(PipelineUnit): y = torch.concat([msk, y]) y = y.unsqueeze(0) y = y.to(dtype=pipe.torch_dtype, device=pipe.device) - return {"fused_y": y} + return {"y": y} + + + +class WanVideoUnit_ImageEmbedderFused(PipelineUnit): + """ + Encode input image to latents using VAE. This unit is for Wan-AI/Wan2.2-TI2V-5B. 
+ """ + def __init__(self): + super().__init__( + input_params=("input_image", "latents", "height", "width", "tiled", "tile_size", "tile_stride"), + onload_model_names=("vae",) + ) + + def process(self, pipe: WanVideoPipeline, input_image, latents, height, width, tiled, tile_size, tile_stride): + if input_image is None or not pipe.dit.fuse_vae_embedding_in_latents: + return {} + pipe.load_models_to_device(self.onload_model_names) + image = pipe.preprocess_image(input_image.resize((width, height))).transpose(0, 1) + z = pipe.vae.encode([image], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) + latents[:, :, 0: 1] = z + return {"latents": latents, "fuse_vae_embedding_in_latents": True} + class WanVideoUnit_FunControl(PipelineUnit): def __init__(self): super().__init__( input_params=("control_video", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride", "clip_feature", "y"), - onload_model_names=("vae") + onload_model_names=("vae",) ) def process(self, pipe: WanVideoPipeline, control_video, num_frames, height, width, tiled, tile_size, tile_stride, clip_feature, y): @@ -927,7 +931,7 @@ class WanVideoUnit_FunReference(PipelineUnit): def __init__(self): super().__init__( input_params=("reference_image", "height", "width", "reference_image"), - onload_model_names=("vae") + onload_model_names=("vae",) ) def process(self, pipe: WanVideoPipeline, reference_image, height, width): @@ -1200,7 +1204,6 @@ def model_fn_wan_video( context: torch.Tensor = None, clip_feature: Optional[torch.Tensor] = None, y: Optional[torch.Tensor] = None, - fused_y: Optional[torch.Tensor] = None, reference_latents = None, vace_context = None, vace_scale = 1.0, @@ -1213,6 +1216,7 @@ def model_fn_wan_video( use_gradient_checkpointing: bool = False, use_gradient_checkpointing_offload: bool = False, control_camera_latents_input = None, + fuse_vae_embedding_in_latents: bool = False, **kwargs, ): if sliding_window_size is not None and sliding_window_stride is 
not None: @@ -1247,15 +1251,19 @@ def model_fn_wan_video( get_sequence_parallel_world_size, get_sp_group) - if dit.seperated_timestep and "latent_mask_for_timestep" in kwargs: - temp_ts = (kwargs["latent_mask_for_timestep"][0][0][:, ::2, ::2] * timestep).flatten() - temp_ts= torch.cat([temp_ts, temp_ts.new_ones(kwargs["seq_len"] - temp_ts.size(0)) * timestep]) - timestep = temp_ts.unsqueeze(0).flatten() - t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep).unflatten(0, (latents.size(0), kwargs["seq_len"]))) + # Timestep + if dit.seperated_timestep and fuse_vae_embedding_in_latents: + timestep = torch.concat([ + torch.zeros((1, latents.shape[3] * latents.shape[4] // 4), dtype=latents.dtype, device=latents.device), + torch.ones((latents.shape[2] - 1, latents.shape[3] * latents.shape[4] // 4), dtype=latents.dtype, device=latents.device) * timestep + ]).flatten() + t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep).unsqueeze(0)) t_mod = dit.time_projection(t).unflatten(2, (6, dit.dim)) else: t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep)) t_mod = dit.time_projection(t).unflatten(1, (6, dit.dim)) + + # Motion Controller if motion_bucket_id is not None and motion_controller is not None: t_mod = t_mod + motion_controller(motion_bucket_id).unflatten(1, (6, dit.dim)) context = dit.text_embedding(context) @@ -1267,16 +1275,15 @@ def model_fn_wan_video( if timestep.shape[0] != context.shape[0]: timestep = torch.concat([timestep] * context.shape[0], dim=0) - if dit.has_image_input: - x = torch.cat([x, y], dim=1) # (b, c_x + c_y, f, h, w) + # Image Embedding + if y is not None and dit.require_vae_embedding: + x = torch.cat([x, y], dim=1) + if clip_feature is not None and dit.require_clip_embedding: clip_embdding = dit.img_emb(clip_feature) context = torch.cat([clip_embdding, context], dim=1) - if fused_y is not None: - x = torch.cat([x, fused_y], dim=1) # (b, c_x + c_y + c_fused_y, f, h, w) # Add camera control x, 
(f, h, w) = dit.patchify(x, control_camera_latents_input) - # Reference image if reference_latents is not None: diff --git a/diffsynth/trainers/utils.py b/diffsynth/trainers/utils.py index b171857..b27fc34 100644 --- a/diffsynth/trainers/utils.py +++ b/diffsynth/trainers/utils.py @@ -434,6 +434,8 @@ def wan_parser(): parser.add_argument("--extra_inputs", default=None, help="Additional model inputs, comma-separated.") parser.add_argument("--use_gradient_checkpointing_offload", default=False, action="store_true", help="Whether to offload gradient checkpointing to CPU memory.") parser.add_argument("--gradient_accumulation_steps", type=int, default=1, help="Gradient accumulation steps.") + parser.add_argument("--max_timestep_boundary", type=float, default=1.0, help="Max timestep boundary (for mixed models, e.g., Wan-AI/Wan2.2-I2V-A14B).") + parser.add_argument("--min_timestep_boundary", type=float, default=0.0, help="Min timestep boundary (for mixed models, e.g., Wan-AI/Wan2.2-I2V-A14B).") return parser diff --git a/examples/wanvideo/README.md b/examples/wanvideo/README.md index 7b2ebbd..c75e518 100644 --- a/examples/wanvideo/README.md +++ b/examples/wanvideo/README.md @@ -65,6 +65,9 @@ save_video(video, "video1.mp4", fps=15, quality=5) |[Wan-AI/Wan2.1-VACE-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-1.3B)|`vace_control_video`, `vace_reference_image`|[code](./model_inference/Wan2.1-VACE-1.3B.py)|[code](./model_training/full/Wan2.1-VACE-1.3B.sh)|[code](./model_training/validate_full/Wan2.1-VACE-1.3B.py)|[code](./model_training/lora/Wan2.1-VACE-1.3B.sh)|[code](./model_training/validate_lora/Wan2.1-VACE-1.3B.py)| |[Wan-AI/Wan2.1-VACE-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-14B)|`vace_control_video`, 
`vace_reference_image`|[code](./model_inference/Wan2.1-VACE-14B.py)|[code](./model_training/full/Wan2.1-VACE-14B.sh)|[code](./model_training/validate_full/Wan2.1-VACE-14B.py)|[code](./model_training/lora/Wan2.1-VACE-14B.sh)|[code](./model_training/validate_lora/Wan2.1-VACE-14B.py)| |[DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1)|`motion_bucket_id`|[code](./model_inference/Wan2.1-1.3b-speedcontrol-v1.py)|[code](./model_training/full/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](./model_training/validate_full/Wan2.1-1.3b-speedcontrol-v1.py)|[code](./model_training/lora/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](./model_training/validate_lora/Wan2.1-1.3b-speedcontrol-v1.py)| +|[Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B)|`input_image`|[code](./model_inference/Wan2.2-I2V-A14B.py)|[code](./model_training/full/Wan2.2-I2V-A14B.sh)|[code](./model_training/validate_full/Wan2.2-I2V-A14B.py)|[code](./model_training/lora/Wan2.2-I2V-A14B.sh)|[code](./model_training/validate_lora/Wan2.2-I2V-A14B.py)| +|[Wan-AI/Wan2.2-T2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-T2V-A14B)||[code](./model_inference/Wan2.2-T2V-A14B.py)|[code](./model_training/full/Wan2.2-T2V-A14B.sh)|[code](./model_training/validate_full/Wan2.2-T2V-A14B.py)|[code](./model_training/lora/Wan2.2-T2V-A14B.sh)|[code](./model_training/validate_lora/Wan2.2-T2V-A14B.py)| +|[Wan-AI/Wan2.2-TI2V-5B](https://modelscope.cn/models/Wan-AI/Wan2.2-TI2V-5B)|`input_image`|[code](./model_inference/Wan2.2-TI2V-5B.py)|[code](./model_training/full/Wan2.2-TI2V-5B.sh)|[code](./model_training/validate_full/Wan2.2-TI2V-5B.py)|[code](./model_training/lora/Wan2.2-TI2V-5B.sh)|[code](./model_training/validate_lora/Wan2.2-TI2V-5B.py)| ## Model Inference diff --git a/examples/wanvideo/README_zh.md b/examples/wanvideo/README_zh.md index 860ff83..822c8e6 100644 --- a/examples/wanvideo/README_zh.md +++ b/examples/wanvideo/README_zh.md @@ -65,6 
+65,9 @@ save_video(video, "video1.mp4", fps=15, quality=5) |[Wan-AI/Wan2.1-VACE-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-1.3B)|`vace_control_video`, `vace_reference_image`|[code](./model_inference/Wan2.1-VACE-1.3B.py)|[code](./model_training/full/Wan2.1-VACE-1.3B.sh)|[code](./model_training/validate_full/Wan2.1-VACE-1.3B.py)|[code](./model_training/lora/Wan2.1-VACE-1.3B.sh)|[code](./model_training/validate_lora/Wan2.1-VACE-1.3B.py)| |[Wan-AI/Wan2.1-VACE-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-14B)|`vace_control_video`, `vace_reference_image`|[code](./model_inference/Wan2.1-VACE-14B.py)|[code](./model_training/full/Wan2.1-VACE-14B.sh)|[code](./model_training/validate_full/Wan2.1-VACE-14B.py)|[code](./model_training/lora/Wan2.1-VACE-14B.sh)|[code](./model_training/validate_lora/Wan2.1-VACE-14B.py)| |[DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1)|`motion_bucket_id`|[code](./model_inference/Wan2.1-1.3b-speedcontrol-v1.py)|[code](./model_training/full/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](./model_training/validate_full/Wan2.1-1.3b-speedcontrol-v1.py)|[code](./model_training/lora/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](./model_training/validate_lora/Wan2.1-1.3b-speedcontrol-v1.py)| +|[Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B)|`input_image`|[code](./model_inference/Wan2.2-I2V-A14B.py)|[code](./model_training/full/Wan2.2-I2V-A14B.sh)|[code](./model_training/validate_full/Wan2.2-I2V-A14B.py)|[code](./model_training/lora/Wan2.2-I2V-A14B.sh)|[code](./model_training/validate_lora/Wan2.2-I2V-A14B.py)| +|[Wan-AI/Wan2.2-T2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-T2V-A14B)||[code](./model_inference/Wan2.2-T2V-A14B.py)|[code](./model_training/full/Wan2.2-T2V-A14B.sh)|[code](./model_training/validate_full/Wan2.2-T2V-A14B.py)|[code](./model_training/lora/Wan2.2-T2V-A14B.sh)|[code](./model_training/validate_lora/Wan2.2-T2V-A14B.py)| 
+|[Wan-AI/Wan2.2-TI2V-5B](https://modelscope.cn/models/Wan-AI/Wan2.2-TI2V-5B)|`input_image`|[code](./model_inference/Wan2.2-TI2V-5B.py)|[code](./model_training/full/Wan2.2-TI2V-5B.sh)|[code](./model_training/validate_full/Wan2.2-TI2V-5B.py)|[code](./model_training/lora/Wan2.2-TI2V-5B.sh)|[code](./model_training/validate_lora/Wan2.2-TI2V-5B.py)| ## 模型推理 diff --git a/examples/wanvideo/model_inference/Wan2.2-I2V-A14B.py b/examples/wanvideo/model_inference/Wan2.2-I2V-A14B.py index 9782a2c..0c1be54 100644 --- a/examples/wanvideo/model_inference/Wan2.2-I2V-A14B.py +++ b/examples/wanvideo/model_inference/Wan2.2-I2V-A14B.py @@ -22,7 +22,7 @@ dataset_snapshot_download( allow_file_pattern=["data/examples/wan/cat_fightning.jpg"] ) input_image = Image.open("data/examples/wan/cat_fightning.jpg").resize((832, 480)) -# Text-to-video + video = pipe( prompt="Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.", negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", diff --git a/examples/wanvideo/model_inference/Wan2.2-T2V-A14B.py b/examples/wanvideo/model_inference/Wan2.2-T2V-A14B.py index de9ae5f..27b10d0 100644 --- a/examples/wanvideo/model_inference/Wan2.2-T2V-A14B.py +++ b/examples/wanvideo/model_inference/Wan2.2-T2V-A14B.py @@ -1,9 +1,6 @@ import torch from diffsynth import save_video from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig -from modelscope import snapshot_download - -snapshot_download("Wan-AI/Wan2.2-T2V-A14B", local_dir="models/Wan-AI/Wan2.2-T2V-A14B") pipe = WanVideoPipeline.from_pretrained( diff --git a/examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py b/examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py index f41a941..50d81c2 100644 --- a/examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py +++ b/examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py @@ -10,7 +10,7 @@ pipe = 
WanVideoPipeline.from_pretrained( model_configs=[ ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), - ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="Wan2.2_VAE.safetensors", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="Wan2.2_VAE.pth", offload_device="cpu"), ], ) pipe.enable_vram_management() diff --git a/examples/wanvideo/model_training/full/Wan2.2-I2V-A14B.sh b/examples/wanvideo/model_training/full/Wan2.2-I2V-A14B.sh new file mode 100644 index 0000000..2f531e7 --- /dev/null +++ b/examples/wanvideo/model_training/full/Wan2.2-I2V-A14B.sh @@ -0,0 +1,35 @@ +accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata.csv \ + --height 480 \ + --width 832 \ + --num_frames 49 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "Wan-AI/Wan2.2-I2V-A14B:high_noise_model/diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.2-I2V-A14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.2-I2V-A14B:Wan2.1_VAE.pth" \ + --learning_rate 1e-5 \ + --num_epochs 2 \ + --remove_prefix_in_ckpt "pipe.dit." 
\ + --output_path "./models/train/Wan2.2-I2V-A14B_high_noise_full" \ + --trainable_models "dit" \ + --extra_inputs "input_image" \ + --use_gradient_checkpointing_offload \ + --max_timestep_boundary 1 \ + --min_timestep_boundary 0.875 + +accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata.csv \ + --height 480 \ + --width 832 \ + --num_frames 49 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "Wan-AI/Wan2.2-I2V-A14B:low_noise_model/diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.2-I2V-A14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.2-I2V-A14B:Wan2.1_VAE.pth" \ + --learning_rate 1e-5 \ + --num_epochs 2 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/Wan2.2-I2V-A14B_low_noise_full" \ + --trainable_models "dit" \ + --extra_inputs "input_image" \ + --use_gradient_checkpointing_offload \ + --max_timestep_boundary 0.875 \ + --min_timestep_boundary 0 diff --git a/examples/wanvideo/model_training/full/Wan2.2-T2V-A14B.sh b/examples/wanvideo/model_training/full/Wan2.2-T2V-A14B.sh new file mode 100644 index 0000000..f634117 --- /dev/null +++ b/examples/wanvideo/model_training/full/Wan2.2-T2V-A14B.sh @@ -0,0 +1,31 @@ +accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata.csv \ + --height 480 \ + --width 832 \ + --num_frames 49 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "Wan-AI/Wan2.2-T2V-A14B:high_noise_model/diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.2-T2V-A14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.2-T2V-A14B:Wan2.1_VAE.pth" \ + --learning_rate 1e-5 \ + --num_epochs 2 \ + --remove_prefix_in_ckpt "pipe.dit." 
\ + --output_path "./models/train/Wan2.2-T2V-A14B_high_noise_full" \ + --trainable_models "dit" \ + --max_timestep_boundary 1 \ + --min_timestep_boundary 0.875 + +accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata.csv \ + --height 480 \ + --width 832 \ + --num_frames 49 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "Wan-AI/Wan2.2-T2V-A14B:low_noise_model/diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.2-T2V-A14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.2-T2V-A14B:Wan2.1_VAE.pth" \ + --learning_rate 1e-5 \ + --num_epochs 2 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/Wan2.2-T2V-A14B_low_noise_full" \ + --trainable_models "dit" \ + --max_timestep_boundary 0.875 \ + --min_timestep_boundary 0 diff --git a/examples/wanvideo/model_training/full/Wan2.2-TI2V-5B.sh b/examples/wanvideo/model_training/full/Wan2.2-TI2V-5B.sh new file mode 100644 index 0000000..def9f89 --- /dev/null +++ b/examples/wanvideo/model_training/full/Wan2.2-TI2V-5B.sh @@ -0,0 +1,14 @@ +accelerate launch examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata.csv \ + --height 480 \ + --width 832 \ + --num_frames 49 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "Wan-AI/Wan2.2-TI2V-5B:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.2-TI2V-5B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.2-TI2V-5B:Wan2.2_VAE.pth" \ + --learning_rate 1e-5 \ + --num_epochs 2 \ + --remove_prefix_in_ckpt "pipe.dit." 
\ + --output_path "./models/train/Wan2.2-TI2V-5B_full" \ + --trainable_models "dit" \ + --extra_inputs "input_image" \ No newline at end of file diff --git a/examples/wanvideo/model_training/lora/Wan2.2-I2V-A14B.sh b/examples/wanvideo/model_training/lora/Wan2.2-I2V-A14B.sh new file mode 100644 index 0000000..4201b47 --- /dev/null +++ b/examples/wanvideo/model_training/lora/Wan2.2-I2V-A14B.sh @@ -0,0 +1,37 @@ +accelerate launch examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata.csv \ + --height 480 \ + --width 832 \ + --num_frames 49 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "Wan-AI/Wan2.2-I2V-A14B:high_noise_model/diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.2-I2V-A14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.2-I2V-A14B:Wan2.1_VAE.pth" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/Wan2.2-I2V-A14B_high_noise_lora" \ + --lora_base_model "dit" \ + --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \ + --lora_rank 32 \ + --extra_inputs "input_image" \ + --max_timestep_boundary 1 \ + --min_timestep_boundary 0.875 + +accelerate launch examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata.csv \ + --height 480 \ + --width 832 \ + --num_frames 49 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "Wan-AI/Wan2.2-I2V-A14B:low_noise_model/diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.2-I2V-A14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.2-I2V-A14B:Wan2.1_VAE.pth" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." 
\ + --output_path "./models/train/Wan2.2-I2V-A14B_low_noise_lora" \ + --lora_base_model "dit" \ + --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \ + --lora_rank 32 \ + --extra_inputs "input_image" \ + --max_timestep_boundary 0.875 \ + --min_timestep_boundary 0 diff --git a/examples/wanvideo/model_training/lora/Wan2.2-T2V-A14B.sh b/examples/wanvideo/model_training/lora/Wan2.2-T2V-A14B.sh new file mode 100644 index 0000000..737896c --- /dev/null +++ b/examples/wanvideo/model_training/lora/Wan2.2-T2V-A14B.sh @@ -0,0 +1,36 @@ +accelerate launch examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata.csv \ + --height 480 \ + --width 832 \ + --num_frames 49 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "Wan-AI/Wan2.2-T2V-A14B:high_noise_model/diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.2-T2V-A14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.2-T2V-A14B:Wan2.1_VAE.pth" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/Wan2.2-T2V-A14B_high_noise_lora" \ + --lora_base_model "dit" \ + --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \ + --lora_rank 32 \ + --max_timestep_boundary 1 \ + --min_timestep_boundary 0.875 + + +accelerate launch examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata.csv \ + --height 480 \ + --width 832 \ + --num_frames 49 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "Wan-AI/Wan2.2-T2V-A14B:low_noise_model/diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.2-T2V-A14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.2-T2V-A14B:Wan2.1_VAE.pth" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." 
\ + --output_path "./models/train/Wan2.2-T2V-A14B_low_noise_lora" \ + --lora_base_model "dit" \ + --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \ + --lora_rank 32 \ + --max_timestep_boundary 0.875 \ + --min_timestep_boundary 0 diff --git a/examples/wanvideo/model_training/lora/Wan2.2-TI2V-5B.sh b/examples/wanvideo/model_training/lora/Wan2.2-TI2V-5B.sh new file mode 100644 index 0000000..6a33b57 --- /dev/null +++ b/examples/wanvideo/model_training/lora/Wan2.2-TI2V-5B.sh @@ -0,0 +1,16 @@ +accelerate launch examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata.csv \ + --height 480 \ + --width 832 \ + --num_frames 49 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "Wan-AI/Wan2.2-TI2V-5B:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.2-TI2V-5B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.2-TI2V-5B:Wan2.2_VAE.pth" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." 
\ + --output_path "./models/train/Wan2.2-TI2V-5B_lora" \ + --lora_base_model "dit" \ + --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \ + --lora_rank 32 \ + --extra_inputs "input_image" \ No newline at end of file diff --git a/examples/wanvideo/model_training/train.py b/examples/wanvideo/model_training/train.py index 877c5b8..93ec0bd 100644 --- a/examples/wanvideo/model_training/train.py +++ b/examples/wanvideo/model_training/train.py @@ -14,6 +14,8 @@ class WanTrainingModule(DiffusionTrainingModule): use_gradient_checkpointing=True, use_gradient_checkpointing_offload=False, extra_inputs=None, + max_timestep_boundary=1.0, + min_timestep_boundary=0.0, ): super().__init__() # Load models @@ -45,6 +47,8 @@ class WanTrainingModule(DiffusionTrainingModule): self.use_gradient_checkpointing = use_gradient_checkpointing self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload self.extra_inputs = extra_inputs.split(",") if extra_inputs is not None else [] + self.max_timestep_boundary = max_timestep_boundary + self.min_timestep_boundary = min_timestep_boundary def forward_preprocess(self, data): @@ -69,6 +73,8 @@ class WanTrainingModule(DiffusionTrainingModule): "use_gradient_checkpointing_offload": self.use_gradient_checkpointing_offload, "cfg_merge": False, "vace_scale": 1, + "max_timestep_boundary": self.max_timestep_boundary, + "min_timestep_boundary": self.min_timestep_boundary, } # Extra inputs @@ -106,6 +112,8 @@ if __name__ == "__main__": lora_rank=args.lora_rank, use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload, extra_inputs=args.extra_inputs, + max_timestep_boundary=args.max_timestep_boundary, + min_timestep_boundary=args.min_timestep_boundary, ) model_logger = ModelLogger( args.output_path, diff --git a/examples/wanvideo/model_training/validate_full/Wan2.2-I2V-A14B.py b/examples/wanvideo/model_training/validate_full/Wan2.2-I2V-A14B.py new file mode 100644 index 0000000..ddcdf5c --- /dev/null +++ 
b/examples/wanvideo/model_training/validate_full/Wan2.2-I2V-A14B.py
@@ -0,0 +1,37 @@
+# Validate full fine-tuning of Wan2.2-I2V-A14B: load the trained high-/low-noise
+# DiT weights from models/train/ and render one image-to-video sample.
+import torch
+from PIL import Image
+from diffsynth import save_video, VideoData, load_state_dict
+from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
+from modelscope import dataset_snapshot_download
+
+
+pipe = WanVideoPipeline.from_pretrained(
+    torch_dtype=torch.bfloat16,
+    device="cuda",
+    model_configs=[
+        ModelConfig(model_id="Wan-AI/Wan2.2-I2V-A14B", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"),
+        ModelConfig(model_id="Wan-AI/Wan2.2-I2V-A14B", origin_file_pattern="low_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"),
+        ModelConfig(model_id="Wan-AI/Wan2.2-I2V-A14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
+        ModelConfig(model_id="Wan-AI/Wan2.2-I2V-A14B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
+    ],
+)
+# pipe.dit takes the high-noise weights, pipe.dit2 the low-noise weights.
+state_dict = load_state_dict("models/train/Wan2.2-I2V-A14B_high_noise_full/epoch-1.safetensors")
+pipe.dit.load_state_dict(state_dict)
+state_dict = load_state_dict("models/train/Wan2.2-I2V-A14B_low_noise_full/epoch-1.safetensors")
+pipe.dit2.load_state_dict(state_dict)
+pipe.enable_vram_management()
+
+# Condition on item 0 of a dataset video (presumably its first frame).
+input_image = VideoData("data/example_video_dataset/video1.mp4", height=480, width=832)[0]
+
+video = pipe(
+    prompt="from sunset to night, a small town, light, house, river",
+    negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
+    input_image=input_image,
+    num_frames=49,
+    seed=1, tiled=False,
+)
+save_video(video, "video_Wan2.2-I2V-A14B.mp4", fps=15, quality=5)
diff --git a/examples/wanvideo/model_training/validate_full/Wan2.2-T2V-A14B.py b/examples/wanvideo/model_training/validate_full/Wan2.2-T2V-A14B.py
new file mode 100644
index 0000000..be0e000
--- /dev/null
+++ b/examples/wanvideo/model_training/validate_full/Wan2.2-T2V-A14B.py
@@ -0,0 +1,32 @@
+# Validate full fine-tuning of Wan2.2-T2V-A14B: load the trained high-/low-noise
+# DiT weights from models/train/ and render one text-to-video sample.
+import torch
+from PIL import Image
+from diffsynth import save_video, VideoData, load_state_dict
+from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
+
+
+pipe = WanVideoPipeline.from_pretrained(
+    torch_dtype=torch.bfloat16,
+    device="cuda",
+    model_configs=[
+        ModelConfig(model_id="Wan-AI/Wan2.2-T2V-A14B", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"),
+        ModelConfig(model_id="Wan-AI/Wan2.2-T2V-A14B", origin_file_pattern="low_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"),
+        ModelConfig(model_id="Wan-AI/Wan2.2-T2V-A14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
+        ModelConfig(model_id="Wan-AI/Wan2.2-T2V-A14B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
+    ],
+)
+# pipe.dit takes the high-noise weights, pipe.dit2 the low-noise weights.
+state_dict = load_state_dict("models/train/Wan2.2-T2V-A14B_high_noise_full/epoch-1.safetensors")
+pipe.dit.load_state_dict(state_dict)
+state_dict = load_state_dict("models/train/Wan2.2-T2V-A14B_low_noise_full/epoch-1.safetensors")
+pipe.dit2.load_state_dict(state_dict)
+pipe.enable_vram_management()
+
+video = pipe(
+    prompt="from sunset to night, a small town, light, house, river",
+    negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
+    num_frames=49,
+    seed=1, tiled=True
+)
+save_video(video, "video_Wan2.2-T2V-A14B.mp4", fps=15, quality=5)
diff --git a/examples/wanvideo/model_training/validate_full/Wan2.2-TI2V-5B.py b/examples/wanvideo/model_training/validate_full/Wan2.2-TI2V-5B.py
new file mode 100644
index 0000000..0f0ea5d
--- /dev/null
+++ b/examples/wanvideo/model_training/validate_full/Wan2.2-TI2V-5B.py
@@ -0,0 +1,33 @@
+# Validate full fine-tuning of Wan2.2-TI2V-5B: load the trained DiT weights from
+# models/train/ and render one sample conditioned on a prompt and an input image.
+import torch
+from PIL import Image
+from diffsynth import save_video, VideoData, load_state_dict
+from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
+from modelscope import dataset_snapshot_download
+
+
+pipe = WanVideoPipeline.from_pretrained(
+    torch_dtype=torch.bfloat16,
+    device="cuda",
+    model_configs=[
+        ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
+        ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"),
+        ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="Wan2.2_VAE.pth", offload_device="cpu"),
+    ],
+)
+state_dict = load_state_dict("models/train/Wan2.2-TI2V-5B_full/epoch-1.safetensors")
+pipe.dit.load_state_dict(state_dict)
+pipe.enable_vram_management()
+
+# Condition on item 0 of a dataset video (presumably its first frame).
+input_image = VideoData("data/example_video_dataset/video1.mp4", height=480, width=832)[0]
+
+video = pipe(
+    prompt="from sunset to night, a small town, light, house, river",
+    negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
+    input_image=input_image,
+    num_frames=49,
+    seed=1, tiled=False,
+)
+save_video(video, "video_Wan2.2-TI2V-5B.mp4", fps=15, quality=5)
diff --git a/examples/wanvideo/model_training/validate_lora/Wan2.2-I2V-A14B.py b/examples/wanvideo/model_training/validate_lora/Wan2.2-I2V-A14B.py
new file mode 100644
index 0000000..4a6bd9c
--- /dev/null
+++ b/examples/wanvideo/model_training/validate_lora/Wan2.2-I2V-A14B.py
@@ -0,0 +1,35 @@
+# Validate LoRA training of Wan2.2-I2V-A14B: apply the high-/low-noise LoRA
+# checkpoints from models/train/ and render one image-to-video sample.
+import torch
+from PIL import Image
+from diffsynth import save_video, VideoData
+from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
+from modelscope import dataset_snapshot_download
+
+
+pipe = WanVideoPipeline.from_pretrained(
+    torch_dtype=torch.bfloat16,
+    device="cuda",
+    model_configs=[
+        ModelConfig(model_id="Wan-AI/Wan2.2-I2V-A14B", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"),
+        ModelConfig(model_id="Wan-AI/Wan2.2-I2V-A14B", origin_file_pattern="low_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"),
+        ModelConfig(model_id="Wan-AI/Wan2.2-I2V-A14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
+        ModelConfig(model_id="Wan-AI/Wan2.2-I2V-A14B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
+    ],
+)
+# pipe.dit takes the high-noise LoRA, pipe.dit2 the low-noise LoRA.
+pipe.load_lora(pipe.dit, "models/train/Wan2.2-I2V-A14B_high_noise_lora/epoch-4.safetensors", alpha=1)
+pipe.load_lora(pipe.dit2, "models/train/Wan2.2-I2V-A14B_low_noise_lora/epoch-4.safetensors", alpha=1)
+pipe.enable_vram_management()
+
+# Condition on item 0 of a dataset video (presumably its first frame).
+input_image = VideoData("data/example_video_dataset/video1.mp4", height=480, width=832)[0]
+
+video = pipe(
+    prompt="from sunset to night, a small town, light, house, river",
+    negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
+    input_image=input_image,
+    num_frames=49,
+    seed=1, tiled=False,
+)
+save_video(video, "video_Wan2.2-I2V-A14B.mp4", fps=15, quality=5)
diff --git a/examples/wanvideo/model_training/validate_lora/Wan2.2-T2V-A14B.py b/examples/wanvideo/model_training/validate_lora/Wan2.2-T2V-A14B.py
new file mode 100644
index 0000000..ab43927
--- /dev/null
+++ b/examples/wanvideo/model_training/validate_lora/Wan2.2-T2V-A14B.py
@@ -0,0 +1,31 @@
+# Validate LoRA training of Wan2.2-T2V-A14B: apply the high-/low-noise LoRA
+# checkpoints from models/train/ and render one text-to-video sample.
+import torch
+from PIL import Image
+from diffsynth import save_video, VideoData
+from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
+from modelscope import dataset_snapshot_download
+
+
+pipe = WanVideoPipeline.from_pretrained(
+    torch_dtype=torch.bfloat16,
+    device="cuda",
+    model_configs=[
+        ModelConfig(model_id="Wan-AI/Wan2.2-T2V-A14B", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"),
+        ModelConfig(model_id="Wan-AI/Wan2.2-T2V-A14B", origin_file_pattern="low_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"),
+        ModelConfig(model_id="Wan-AI/Wan2.2-T2V-A14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
+        ModelConfig(model_id="Wan-AI/Wan2.2-T2V-A14B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
+    ],
+)
+# pipe.dit takes the high-noise LoRA, pipe.dit2 the low-noise LoRA.
+pipe.load_lora(pipe.dit, "models/train/Wan2.2-T2V-A14B_high_noise_lora/epoch-4.safetensors", alpha=1)
+pipe.load_lora(pipe.dit2, "models/train/Wan2.2-T2V-A14B_low_noise_lora/epoch-4.safetensors", alpha=1)
+pipe.enable_vram_management()
+
+video = pipe(
+    prompt="from sunset to night, a small town, light, house, river",
+    negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
+    num_frames=49,
+    seed=1, tiled=True
+)
+save_video(video, "video_Wan2.2-T2V-A14B.mp4", fps=15, quality=5)
diff --git a/examples/wanvideo/model_training/validate_lora/Wan2.2-TI2V-5B.py b/examples/wanvideo/model_training/validate_lora/Wan2.2-TI2V-5B.py
new file mode 100644
index 0000000..d5a9229
--- /dev/null
+++ b/examples/wanvideo/model_training/validate_lora/Wan2.2-TI2V-5B.py
@@ -0,0 +1,32 @@
+# Validate LoRA training of Wan2.2-TI2V-5B: apply the LoRA checkpoint from
+# models/train/ and render one sample conditioned on a prompt and an input image.
+import torch
+from PIL import Image
+from diffsynth import save_video, VideoData, load_state_dict
+from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
+from modelscope import dataset_snapshot_download
+
+
+pipe = WanVideoPipeline.from_pretrained(
+    torch_dtype=torch.bfloat16,
+    device="cuda",
+    model_configs=[
+        ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
+        ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"),
+        ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="Wan2.2_VAE.pth", offload_device="cpu"),
+    ],
+)
+pipe.load_lora(pipe.dit, "models/train/Wan2.2-TI2V-5B_lora/epoch-4.safetensors", alpha=1)
+pipe.enable_vram_management()
+
+# Condition on item 0 of a dataset video (presumably its first frame).
+input_image = VideoData("data/example_video_dataset/video1.mp4", height=480, width=832)[0]
+
+video = pipe(
+    prompt="from sunset to night, a small town, light, house, river",
+    negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
+    input_image=input_image,
+    num_frames=49,
+    seed=1, tiled=False,
+)
+save_video(video, "video_Wan2.2-TI2V-5B.mp4", fps=15, quality=5)