Compare commits

...

6 Commits

Author SHA1 Message Date
Artiprocher
2abc97fc0f update fp8 linear computation 2025-08-07 13:40:36 +08:00
Zhongjie Duan
84ede171fd Merge pull request #752 from modelscope/qwen-image-lora-fromat
remove default in qwen-image lora
2025-08-06 15:42:03 +08:00
Artiprocher
6f4e38276e remove default in qwen-image lora 2025-08-06 15:41:22 +08:00
Zhongjie Duan
829ca3414b fmt fixes in wan_video_dit.py
fmt fixes in wan_video_dit.py
2025-08-06 14:39:25 +08:00
Yudong Jin
26461c1963 Update diffsynth/models/wan_video_dit.py
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2025-08-04 23:52:48 +08:00
Yudong Jin
0412fc7232 fmt fixes in wan_video_dit.py 2025-08-04 23:40:18 +08:00
4 changed files with 64 additions and 2 deletions

View File

@@ -383,5 +383,20 @@ class WanLoRAConverter:
return state_dict
class QwenImageLoRAConverter:
    """Convert Qwen-Image LoRA state dicts between naming conventions.

    DiffSynth/PEFT checkpoints carry an adapter segment (``.default.``) in
    their LoRA weight keys; the open-source format omits it. Both converters
    are pure key renames — parameter tensors are passed through untouched.
    """

    def __init__(self):
        pass

    @staticmethod
    def align_to_opensource_format(state_dict, **kwargs):
        """Drop the ``.default.`` adapter segment from every key."""
        converted = {}
        for key, param in state_dict.items():
            converted[key.replace(".default.", ".")] = param
        return converted

    @staticmethod
    def align_to_diffsynth_format(state_dict, **kwargs):
        """Insert the ``.default.`` adapter segment into lora_A/lora_B weight keys."""
        converted = {}
        for key, param in state_dict.items():
            new_key = key.replace(".lora_A.weight", ".lora_A.default.weight")
            new_key = new_key.replace(".lora_B.weight", ".lora_B.default.weight")
            converted[new_key] = param
        return converted
def get_lora_loaders():
    """Instantiate and return the LoRA loaders to try, in priority order."""
    loader_classes = (
        SDLoRAFromCivitai,
        SDXLLoRAFromCivitai,
        FluxLoRAFromCivitai,
        HunyuanVideoLoRAFromCivitai,
        GeneralLoRAFromPeft,
    )
    return [loader_cls() for loader_cls in loader_classes]

View File

@@ -335,7 +335,7 @@ class WanModel(torch.nn.Module):
else:
self.control_adapter = None
def patchify(self, x: torch.Tensor,control_camera_latents_input: torch.Tensor = None):
def patchify(self, x: torch.Tensor, control_camera_latents_input: Optional[torch.Tensor] = None):
x = self.patch_embedding(x)
if self.control_adapter is not None and control_camera_latents_input is not None:
y_camera = self.control_adapter(control_camera_latents_input)

View File

@@ -110,8 +110,47 @@ class AutoWrappedLinear(torch.nn.Linear, AutoTorchModule):
self.lora_A_weights = []
self.lora_B_weights = []
self.lora_merger = None
self.enable_fp8 = computation_dtype in [torch.float8_e4m3fn, torch.float8_e4m3fnuz]
def fp8_linear(
    self,
    input: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor | None = None,
) -> torch.Tensor:
    """Linear layer evaluated with an FP8 matmul via ``torch._scaled_mm``.

    The activation is dynamically scaled per row so its magnitude fits the
    FP8 range before casting; the weight is cast directly (unit scale).
    The result is produced in the input's original dtype and reshaped back
    to the input's leading dimensions.

    NOTE(review): requires FP8-capable hardware for ``torch._scaled_mm``.
    """
    device = input.device
    out_dtype = input.dtype
    leading_shape = input.shape
    # _scaled_mm needs 2-D operands: flatten all leading dims.
    x2d = input.reshape(-1, leading_shape[-1])
    row_max = x2d.abs().amax(dim=-1, keepdim=True)
    # 448.0 is the largest finite value of float8_e4m3fn; e4m3fnuz tops out
    # at half that, so halve the budget there. The scale handed to
    # _scaled_mm compensates, so the final result is unaffected.
    fp8_max = 224.0 if self.computation_dtype == torch.float8_e4m3fnuz else 448.0
    # Only scale down (clamp at 1.0): small activations are left as-is.
    scale_a = torch.clamp(row_max / fp8_max, min=1.0).float().to(device=device)
    scale_b = torch.ones((weight.shape[0], 1)).to(device=device)
    x2d = (x2d / (scale_a + 1e-8)).to(self.computation_dtype)
    w_fp8 = weight.to(self.computation_dtype)
    out = torch._scaled_mm(
        x2d,
        w_fp8.T,
        scale_a=scale_a,
        scale_b=scale_b.T,
        bias=bias,
        out_dtype=out_dtype,
    )
    return out.reshape(leading_shape[:-1] + out.shape[-1:])
def forward(self, x, *args, **kwargs):
# VRAM management
if self.state == 2:
weight, bias = self.weight, self.bias
else:
@@ -123,8 +162,14 @@ class AutoWrappedLinear(torch.nn.Linear, AutoTorchModule):
else:
weight = cast_to(self.weight, self.computation_dtype, self.computation_device)
bias = None if self.bias is None else cast_to(self.bias, self.computation_dtype, self.computation_device)
out = torch.nn.functional.linear(x, weight, bias)
# Linear forward
if self.enable_fp8:
out = self.fp8_linear(x, weight, bias)
else:
out = torch.nn.functional.linear(x, weight, bias)
# LoRA
if len(self.lora_A_weights) == 0:
# No LoRA
return out

View File

@@ -1,6 +1,7 @@
import torch, os, json
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
from diffsynth.trainers.utils import DiffusionTrainingModule, ImageDataset, ModelLogger, launch_training_task, qwen_image_parser
from diffsynth.models.lora import QwenImageLoRAConverter
os.environ["TOKENIZERS_PARALLELISM"] = "false"
@@ -108,6 +109,7 @@ if __name__ == "__main__":
model_logger = ModelLogger(
args.output_path,
remove_prefix_in_ckpt=args.remove_prefix_in_ckpt,
state_dict_converter=QwenImageLoRAConverter.align_to_opensource_format if args.align_to_opensource_format else lambda x:x,
)
optimizer = torch.optim.AdamW(model.trainable_modules(), lr=args.learning_rate)
scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer)