From c795e35142cc44a8d924e682853066816ebe8468 Mon Sep 17 00:00:00 2001 From: lzws <63908509+lzws@users.noreply.github.com> Date: Fri, 22 Aug 2025 14:20:31 +0800 Subject: [PATCH 1/2] add wan2.2-fun-A14B inp, control and control-camera (#839) * update wan2.2-fun * update wan2.2-fun * update wan2.2-fun * add examples * update wan2.2-fun * update wan2.2-fun * Rename Wan2.2-Fun-A14B-Inp.py to Wan2.2-Fun-A14B-InP.py --------- Co-authored-by: lzw478614@alibaba-inc.com --- diffsynth/configs/model_config.py | 2 + .../models/wan_video_camera_controller.py | 6 ++- diffsynth/models/wan_video_dit.py | 37 ++++++++++++++++ diffsynth/pipelines/wan_video_new.py | 35 +++++++++++---- .../Wan2.2-Fun-A14B-Control-Camera.py | 43 +++++++++++++++++++ .../Wan2.2-Fun-A14B-Control.py | 35 +++++++++++++++ .../model_inference/Wan2.2-Fun-A14B-InP.py | 35 +++++++++++++++ 7 files changed, 183 insertions(+), 10 deletions(-) create mode 100644 examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control-Camera.py create mode 100644 examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control.py create mode 100644 examples/wanvideo/model_inference/Wan2.2-Fun-A14B-InP.py diff --git a/diffsynth/configs/model_config.py b/diffsynth/configs/model_config.py index 9de5bfb..b4b847f 100644 --- a/diffsynth/configs/model_config.py +++ b/diffsynth/configs/model_config.py @@ -150,6 +150,8 @@ model_loader_configs = [ (None, "b61c605c2adbd23124d152ed28e049ae", ["wan_video_dit"], [WanModel], "civitai"), (None, "1f5ab7703c6fc803fdded85ff040c316", ["wan_video_dit"], [WanModel], "civitai"), (None, "5b013604280dd715f8457c6ed6d6a626", ["wan_video_dit"], [WanModel], "civitai"), + (None, "2267d489f0ceb9f21836532952852ee5", ["wan_video_dit"], [WanModel], "civitai"), + (None, "47dbeab5e560db3180adf51dc0232fb1", ["wan_video_dit"], [WanModel], "civitai"), (None, "a61453409b67cd3246cf0c3bebad47ba", ["wan_video_dit", "wan_video_vace"], [WanModel, VaceWanModel], "civitai"), (None, "7a513e1f257a861512b1afd387a8ecd9", ["wan_video_dit", "wan_video_vace"], [WanModel, VaceWanModel], "civitai"), (None, "cb104773c6c2cb6df4f9529ad5c60d0b", ["wan_video_dit"], [WanModel], "diffusers"), diff --git a/diffsynth/models/wan_video_camera_controller.py b/diffsynth/models/wan_video_camera_controller.py index 026b558..45a44ee 100644 --- a/diffsynth/models/wan_video_camera_controller.py +++ b/diffsynth/models/wan_video_camera_controller.py @@ -182,7 +182,7 @@ def process_pose_file(cam_params, width=672, height=384, original_pose_width=128 def generate_camera_coordinates( - direction: Literal["Left", "Right", "Up", "Down", "LeftUp", "LeftDown", "RightUp", "RightDown"], + direction: Literal["Left", "Right", "Up", "Down", "LeftUp", "LeftDown", "RightUp", "RightDown", "In", "Out"], length: int, speed: float = 1/54, origin=(0, 0.532139961, 0.946026558, 0.5, 0.5, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0) @@ -198,5 +198,9 @@ def generate_camera_coordinates( coor[13] += speed if "Down" in direction: coor[13] -= speed + if "In" in direction: + coor[18] -= speed + if "Out" in direction: + coor[18] += speed coordinates.append(coor) return coordinates diff --git a/diffsynth/models/wan_video_dit.py b/diffsynth/models/wan_video_dit.py index 419f8cf..1a54728 100644 --- a/diffsynth/models/wan_video_dit.py +++ b/diffsynth/models/wan_video_dit.py @@ -294,6 +294,7 @@ class WanModel(torch.nn.Module): ): super().__init__() self.dim = dim + self.in_dim = in_dim self.freq_dim = freq_dim self.has_image_input = has_image_input self.patch_size = patch_size @@ -713,6 +714,42 @@ class WanModelStateDictConverter: "eps": 1e-6, "require_clip_embedding": False, } + elif hash_state_dict_keys(state_dict) == "2267d489f0ceb9f21836532952852ee5": + # Wan2.2-Fun-A14B-Control + config = { + "has_image_input": False, + "patch_size": [1, 2, 2], + "in_dim": 52, + "dim": 5120, + "ffn_dim": 13824, + "freq_dim": 256, + "text_dim": 4096, + "out_dim": 16, + "num_heads": 40, + "num_layers": 40, + "eps": 1e-6, + "has_ref_conv": True, + "require_clip_embedding": False, + } + elif hash_state_dict_keys(state_dict) == "47dbeab5e560db3180adf51dc0232fb1": + # Wan2.2-Fun-A14B-Control-Camera + config = { + "has_image_input": False, + "patch_size": [1, 2, 2], + "in_dim": 36, + "dim": 5120, + "ffn_dim": 13824, + "freq_dim": 256, + "text_dim": 4096, + "out_dim": 16, + "num_heads": 40, + "num_layers": 40, + "eps": 1e-6, + "has_ref_conv": False, + "add_control_adapter": True, + "in_dim_control_adapter": 24, + "require_clip_embedding": False, + } else: config = {} return state_dict, config diff --git a/diffsynth/pipelines/wan_video_new.py b/diffsynth/pipelines/wan_video_new.py index 89adbdf..53df7d9 100644 --- a/diffsynth/pipelines/wan_video_new.py +++ b/diffsynth/pipelines/wan_video_new.py @@ -663,22 +663,23 @@ class WanVideoUnit_ImageEmbedderFused(PipelineUnit): class WanVideoUnit_FunControl(PipelineUnit): def __init__(self): super().__init__( - input_params=("control_video", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride", "clip_feature", "y"), + input_params=("control_video", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride", "clip_feature", "y", "latents"), onload_model_names=("vae",) ) - def process(self, pipe: WanVideoPipeline, control_video, num_frames, height, width, tiled, tile_size, tile_stride, clip_feature, y): + def process(self, pipe: WanVideoPipeline, control_video, num_frames, height, width, tiled, tile_size, tile_stride, clip_feature, y, latents): if control_video is None: return {} pipe.load_models_to_device(self.onload_model_names) control_video = pipe.preprocess_video(control_video) control_latents = pipe.vae.encode(control_video, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device) control_latents = control_latents.to(dtype=pipe.torch_dtype, device=pipe.device) + y_dim = pipe.dit.in_dim-control_latents.shape[1]-latents.shape[1] if clip_feature is None or y is None: clip_feature = torch.zeros((1, 257, 1280), dtype=pipe.torch_dtype, device=pipe.device) - y = torch.zeros((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), dtype=pipe.torch_dtype, device=pipe.device) + y = torch.zeros((1, y_dim, (num_frames - 1) // 4 + 1, height//8, width//8), dtype=pipe.torch_dtype, device=pipe.device) else: - y = y[:, -16:] + y = y[:, -y_dim:] y = torch.concat([control_latents, y], dim=1) return {"clip_feature": clip_feature, "y": y} @@ -698,6 +699,8 @@ class WanVideoUnit_FunReference(PipelineUnit): reference_image = reference_image.resize((width, height)) reference_latents = pipe.preprocess_video([reference_image]) reference_latents = pipe.vae.encode(reference_latents, device=pipe.device) + if pipe.image_encoder is None: + return {"reference_latents": reference_latents} clip_feature = pipe.preprocess_image(reference_image) clip_feature = pipe.image_encoder.encode_image([clip_feature]) return {"reference_latents": reference_latents, "clip_feature": clip_feature} @@ -707,13 +710,14 @@ class WanVideoUnit_FunReference(PipelineUnit): class WanVideoUnit_FunCameraControl(PipelineUnit): def __init__(self): super().__init__( - input_params=("height", "width", "num_frames", "camera_control_direction", "camera_control_speed", "camera_control_origin", "latents", "input_image"), + input_params=("height", "width", "num_frames", "camera_control_direction", "camera_control_speed", "camera_control_origin", "latents", "input_image", "tiled", "tile_size", "tile_stride"), onload_model_names=("vae",) ) - def process(self, pipe: WanVideoPipeline, height, width, num_frames, camera_control_direction, camera_control_speed, camera_control_origin, latents, input_image): + def process(self, pipe: WanVideoPipeline, height, width, num_frames, camera_control_direction, camera_control_speed, camera_control_origin, latents, input_image, tiled, tile_size, tile_stride): if camera_control_direction is None: return {} + pipe.load_models_to_device(self.onload_model_names) camera_control_plucker_embedding = pipe.dit.control_adapter.process_camera_coordinates( camera_control_direction, num_frames, height, width, camera_control_speed, camera_control_origin) @@ -728,14 +732,27 @@ class WanVideoUnit_FunCameraControl(PipelineUnit): control_camera_latents = control_camera_latents.contiguous().view(b, f // 4, 4, c, h, w).transpose(2, 3) control_camera_latents = control_camera_latents.contiguous().view(b, f // 4, c * 4, h, w).transpose(1, 2) control_camera_latents_input = control_camera_latents.to(device=pipe.device, dtype=pipe.torch_dtype) - + input_image = input_image.resize((width, height)) input_latents = pipe.preprocess_video([input_image]) - pipe.load_models_to_device(self.onload_model_names) input_latents = pipe.vae.encode(input_latents, device=pipe.device) y = torch.zeros_like(latents).to(pipe.device) y[:, :, :1] = input_latents y = y.to(dtype=pipe.torch_dtype, device=pipe.device) + + if y.shape[1] != pipe.dit.in_dim - latents.shape[1]: + image = pipe.preprocess_image(input_image.resize((width, height))).to(pipe.device) + vae_input = torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device)], dim=1) + y = pipe.vae.encode([vae_input.to(dtype=pipe.torch_dtype, device=pipe.device)], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0] + y = y.to(dtype=pipe.torch_dtype, device=pipe.device) + msk = torch.ones(1, num_frames, height//8, width//8, device=pipe.device) + msk[:, 1:] = 0 + msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1) + msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8) + msk = msk.transpose(1, 2)[0] + y = torch.cat([msk,y]) + y = y.unsqueeze(0) + y = y.to(dtype=pipe.torch_dtype, device=pipe.device) return {"control_camera_latents_input": control_camera_latents_input, "y": y} @@ -1048,7 +1065,7 @@ def model_fn_wan_video( if clip_feature is not None and dit.require_clip_embedding: clip_embdding = dit.img_emb(clip_feature) context = torch.cat([clip_embdding, context], dim=1) - + # Add camera control x, (f, h, w) = dit.patchify(x, control_camera_latents_input) diff --git a/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control-Camera.py b/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control-Camera.py new file mode 100644 index 0000000..27cda27 --- /dev/null +++ b/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control-Camera.py @@ -0,0 +1,43 @@ +import torch +from diffsynth import save_video,VideoData +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig +from PIL import Image +from modelscope import dataset_snapshot_download + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-Control-Camera", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-Control-Camera", origin_file_pattern="low_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-Control-Camera", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-Control-Camera", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), + ], +) +pipe.enable_vram_management() + + +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/examples_in_diffsynth", + local_dir="./", + allow_file_pattern=f"data/examples/wan/input_image.jpg" +) +input_image = Image.open("data/examples/wan/input_image.jpg") + +video = pipe( + prompt="一艘小船正勇敢地乘风破浪前行。蔚蓝的大海波涛汹涌,白色的浪花拍打着船身,但小船毫不畏惧,坚定地驶向远方。阳光洒在水面上,闪烁着金色的光芒,为这壮丽的场景增添了一抹温暖。镜头拉近,可以看到船上的旗帜迎风飘扬,象征着不屈的精神与冒险的勇气。这段画面充满力量,激励人心,展现了面对挑战时的无畏与执着。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + seed=0, tiled=True, + input_image=input_image, + camera_control_direction="Left", camera_control_speed=0.01, +) +save_video(video, "video_left.mp4", fps=15, quality=5) + +video = pipe( + prompt="一艘小船正勇敢地乘风破浪前行。蔚蓝的大海波涛汹涌,白色的浪花拍打着船身,但小船毫不畏惧,坚定地驶向远方。阳光洒在水面上,闪烁着金色的光芒,为这壮丽的场景增添了一抹温暖。镜头拉近,可以看到船上的旗帜迎风飘扬,象征着不屈的精神与冒险的勇气。这段画面充满力量,激励人心,展现了面对挑战时的无畏与执着。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + seed=0, tiled=True, + input_image=input_image, + camera_control_direction="Up", camera_control_speed=0.01, +) +save_video(video, "video_up.mp4", fps=15, quality=5) \ No newline at end of file diff --git a/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control.py b/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control.py new file mode 100644 index 0000000..2941422 --- /dev/null +++ b/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control.py @@ -0,0 +1,35 @@ +import torch +from diffsynth import save_video,VideoData +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig +from PIL import Image +from modelscope import dataset_snapshot_download + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-Control", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-Control", origin_file_pattern="low_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-Control", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-Control", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), + ], +) +pipe.enable_vram_management() + +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/examples_in_diffsynth", + local_dir="./", + allow_file_pattern=["data/examples/wan/control_video.mp4", "data/examples/wan/reference_image_girl.png"] +) + +# Control video +control_video = VideoData("data/examples/wan/control_video.mp4", height=832, width=576) +reference_image = Image.open("data/examples/wan/reference_image_girl.png").resize((576, 832)) +video = pipe( + prompt="扁平风格动漫,一位长发少女优雅起舞。她五官精致,大眼睛明亮有神,黑色长发柔顺光泽。身穿淡蓝色T恤和深蓝色牛仔短裤。背景是粉色。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + control_video=control_video, reference_image=reference_image, + height=832, width=576, num_frames=49, + seed=1, tiled=True +) +save_video(video, "video.mp4", fps=15, quality=5) \ No newline at end of file diff --git a/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-InP.py b/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-InP.py new file mode 100644 index 0000000..c63e522 --- /dev/null +++ b/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-InP.py @@ -0,0 +1,35 @@ +import torch +from diffsynth import save_video +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig +from PIL import Image +from modelscope import dataset_snapshot_download + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-InP", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-InP", origin_file_pattern="low_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-InP", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-InP", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), + ], +) +pipe.enable_vram_management() + +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/examples_in_diffsynth", + local_dir="./", + allow_file_pattern=f"data/examples/wan/input_image.jpg" +) +image = Image.open("data/examples/wan/input_image.jpg") + +# First and last frame to video +video = pipe( + prompt="一艘小船正勇敢地乘风破浪前行。蔚蓝的大海波涛汹涌,白色的浪花拍打着船身,但小船毫不畏惧,坚定地驶向远方。阳光洒在水面上,闪烁着金色的光芒,为这壮丽的场景增添了一抹温暖。镜头拉近,可以看到船上的旗帜迎风飘扬,象征着不屈的精神与冒险的勇气。这段画面充满力量,激励人心,展现了面对挑战时的无畏与执着。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + input_image=image, + seed=0, tiled=True, + # You can input `end_image=xxx` to control the last frame of the video. + # The model will automatically generate the dynamic content between `input_image` and `end_image`. +) +save_video(video, "video.mp4", fps=15, quality=5) From ce0b9486555fb33a6e15da101d891e22a59c2017 Mon Sep 17 00:00:00 2001 From: Artiprocher Date: Mon, 25 Aug 2025 20:32:36 +0800 Subject: [PATCH 2/2] support qwen-image fp8 lora training --- diffsynth/pipelines/qwen_image.py | 29 +++++++++++++++++++++ diffsynth/trainers/utils.py | 7 ++++- examples/qwen_image/model_training/train.py | 14 +++++++--- 3 files changed, 46 insertions(+), 4 deletions(-) diff --git a/diffsynth/pipelines/qwen_image.py b/diffsynth/pipelines/qwen_image.py index f0a7496..383d9f5 100644 --- a/diffsynth/pipelines/qwen_image.py +++ b/diffsynth/pipelines/qwen_image.py @@ -150,6 +150,35 @@ class QwenImagePipeline(BasePipeline): return loss + def _enable_fp8_lora_training(self, dtype): + from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLRotaryEmbedding, Qwen2RMSNorm, Qwen2_5_VisionPatchEmbed, Qwen2_5_VisionRotaryEmbedding + from ..models.qwen_image_dit import RMSNorm + from ..models.qwen_image_vae import QwenImageRMS_norm + module_map = { + RMSNorm: AutoWrappedModule, + torch.nn.Linear: AutoWrappedLinear, + torch.nn.Conv3d: AutoWrappedModule, + torch.nn.Conv2d: AutoWrappedModule, + torch.nn.Embedding: AutoWrappedModule, + Qwen2_5_VLRotaryEmbedding: AutoWrappedModule, + Qwen2RMSNorm: AutoWrappedModule, + Qwen2_5_VisionPatchEmbed: AutoWrappedModule, + Qwen2_5_VisionRotaryEmbedding: AutoWrappedModule, + QwenImageRMS_norm: AutoWrappedModule, + } + model_config = dict( + offload_dtype=dtype, + offload_device="cuda", + onload_dtype=dtype, + onload_device="cuda", + computation_dtype=self.torch_dtype, + computation_device="cuda", + ) + enable_vram_management(self.text_encoder, module_map=module_map, module_config=model_config) + enable_vram_management(self.dit, module_map=module_map, module_config=model_config) + enable_vram_management(self.vae, module_map=module_map, module_config=model_config) + + def enable_vram_management(self, num_persistent_param_in_dit=None, vram_limit=None, vram_buffer=0.5, enable_dit_fp8_computation=False): self.vram_management_enabled = True if vram_limit is None: diff --git a/diffsynth/trainers/utils.py b/diffsynth/trainers/utils.py index 065b687..22ea31e 100644 --- a/diffsynth/trainers/utils.py +++ b/diffsynth/trainers/utils.py @@ -338,11 +338,15 @@ class DiffusionTrainingModule(torch.nn.Module): return trainable_param_names - def add_lora_to_model(self, model, target_modules, lora_rank, lora_alpha=None): + def add_lora_to_model(self, model, target_modules, lora_rank, lora_alpha=None, upcast_dtype=None): if lora_alpha is None: lora_alpha = lora_rank lora_config = LoraConfig(r=lora_rank, lora_alpha=lora_alpha, target_modules=target_modules) model = inject_adapter_in_model(lora_config, model) + if upcast_dtype is not None: + for param in model.parameters(): + if param.requires_grad: + param.data = param.to(upcast_dtype) return model @@ -555,4 +559,5 @@ def qwen_image_parser(): parser.add_argument("--dataset_num_workers", type=int, default=0, help="Number of workers for data loading.") parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay.") parser.add_argument("--processor_path", type=str, default=None, help="Path to the processor. If provided, the processor will be used for image editing.") + parser.add_argument("--enable_fp8_training", default=False, action="store_true", help="Whether to enable FP8 training. Only available for LoRA training on a single GPU.") return parser diff --git a/examples/qwen_image/model_training/train.py b/examples/qwen_image/model_training/train.py index 7418661..ee6752d 100644 --- a/examples/qwen_image/model_training/train.py +++ b/examples/qwen_image/model_training/train.py @@ -17,21 +17,27 @@ class QwenImageTrainingModule(DiffusionTrainingModule): use_gradient_checkpointing=True, use_gradient_checkpointing_offload=False, extra_inputs=None, + enable_fp8_training=False, ): super().__init__() # Load models + offload_dtype = torch.float8_e4m3fn if enable_fp8_training else None model_configs = [] if model_paths is not None: model_paths = json.loads(model_paths) - model_configs += [ModelConfig(path=path) for path in model_paths] + model_configs += [ModelConfig(path=path, offload_dtype=offload_dtype) for path in model_paths] if model_id_with_origin_paths is not None: model_id_with_origin_paths = model_id_with_origin_paths.split(",") - model_configs += [ModelConfig(model_id=i.split(":")[0], origin_file_pattern=i.split(":")[1]) for i in model_id_with_origin_paths] + model_configs += [ModelConfig(model_id=i.split(":")[0], origin_file_pattern=i.split(":")[1], offload_dtype=offload_dtype) for i in model_id_with_origin_paths] tokenizer_config = ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/") if tokenizer_path is None else ModelConfig(tokenizer_path) processor_config = ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="processor/") if processor_path is None else ModelConfig(processor_path) self.pipe = QwenImagePipeline.from_pretrained(torch_dtype=torch.bfloat16, device="cpu", model_configs=model_configs, tokenizer_config=tokenizer_config, processor_config=processor_config) + # Enable FP8 + if enable_fp8_training: + self.pipe._enable_fp8_lora_training(torch.float8_e4m3fn) + # Reset training scheduler (do it in each training step) self.pipe.scheduler.set_timesteps(1000, training=True) @@ -43,7 +49,8 @@ class QwenImageTrainingModule(DiffusionTrainingModule): model = self.add_lora_to_model( getattr(self.pipe, lora_base_model), target_modules=lora_target_modules.split(","), - lora_rank=lora_rank + lora_rank=lora_rank, + upcast_dtype=self.pipe.torch_dtype, ) if lora_checkpoint is not None: state_dict = load_state_dict(lora_checkpoint) @@ -126,6 +133,7 @@ if __name__ == "__main__": use_gradient_checkpointing=args.use_gradient_checkpointing, use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload, extra_inputs=args.extra_inputs, + enable_fp8_training=args.enable_fp8_training, ) model_logger = ModelLogger(args.output_path, remove_prefix_in_ckpt=args.remove_prefix_in_ckpt) optimizer = torch.optim.AdamW(model.trainable_modules(), lr=args.learning_rate, weight_decay=args.weight_decay)