mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-21 08:08:13 +00:00
Compare commits
9 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1e8c201d3b | ||
|
|
d96709fb6a | ||
|
|
bf7b339efb | ||
|
|
b0abdaffb4 | ||
|
|
e9f29bc402 | ||
|
|
1a7f482fbd | ||
|
|
d93e8738cd | ||
|
|
7e5ce5d5c9 | ||
|
|
7aef554d83 |
@@ -381,6 +381,8 @@ https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/59fb2f7b-8de0-44
|
|||||||
|
|
||||||
## Update History
|
## Update History
|
||||||
|
|
||||||
|
- **September 22, 2025**: We have supported Direct Preference Optimization (DPO) training for Qwen-Image. Please refer to the [example code](examples/qwen_image/model_training/lora/Qwen-Image-DPO.sh) for the training script.
|
||||||
|
|
||||||
- **September 9, 2025**: Our training framework now supports multiple training modes and has been adapted for Qwen-Image. In addition to the standard SFT training mode, Direct Distill is now also supported; please refer to [our example code](./examples/qwen_image/model_training/lora/Qwen-Image-Distill-LoRA.sh). This feature is experimental, and we will continue to improve it to support comprehensive model training capabilities.
|
- **September 9, 2025**: Our training framework now supports multiple training modes and has been adapted for Qwen-Image. In addition to the standard SFT training mode, Direct Distill is now also supported; please refer to [our example code](./examples/qwen_image/model_training/lora/Qwen-Image-Distill-LoRA.sh). This feature is experimental, and we will continue to improve it to support comprehensive model training capabilities.
|
||||||
|
|
||||||
- **August 28, 2025** We support Wan2.2-S2V, an audio-driven cinematic video generation model open-sourced by Alibaba. See [./examples/wanvideo/](./examples/wanvideo/).
|
- **August 28, 2025** We support Wan2.2-S2V, an audio-driven cinematic video generation model open-sourced by Alibaba. See [./examples/wanvideo/](./examples/wanvideo/).
|
||||||
|
|||||||
@@ -397,6 +397,8 @@ https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/59fb2f7b-8de0-44
|
|||||||
|
|
||||||
## 更新历史
|
## 更新历史
|
||||||
|
|
||||||
|
- **2025年9月22日** 我们支持了 Qwen-Image 的直接偏好对齐 (DPO) 训练,训练脚本请参考[示例代码](examples/qwen_image/model_training/lora/Qwen-Image-DPO.sh)。
|
||||||
|
|
||||||
- **2025年9月9日** 我们的训练框架支持了多种训练模式,目前已适配 Qwen-Image,除标准 SFT 训练模式外,已支持 Direct Distill,请参考[我们的示例代码](./examples/qwen_image/model_training/lora/Qwen-Image-Distill-LoRA.sh)。这项功能是实验性的,我们将会继续完善已支持更全面的模型训练功能。
|
- **2025年9月9日** 我们的训练框架支持了多种训练模式,目前已适配 Qwen-Image,除标准 SFT 训练模式外,已支持 Direct Distill,请参考[我们的示例代码](./examples/qwen_image/model_training/lora/Qwen-Image-Distill-LoRA.sh)。这项功能是实验性的,我们将会继续完善已支持更全面的模型训练功能。
|
||||||
|
|
||||||
- **2025年8月28日** 我们支持了Wan2.2-S2V,一个音频驱动的电影级视频生成模型。请参见[./examples/wanvideo/](./examples/wanvideo/)。
|
- **2025年8月28日** 我们支持了Wan2.2-S2V,一个音频驱动的电影级视频生成模型。请参见[./examples/wanvideo/](./examples/wanvideo/)。
|
||||||
|
|||||||
@@ -2,7 +2,8 @@ from .cupy_kernels import remapping_kernel, patch_error_kernel, pairwise_patch_e
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import cupy as cp
|
import cupy as cp
|
||||||
import cv2
|
import cv2
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
class PatchMatcher:
|
class PatchMatcher:
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -233,13 +234,11 @@ class PyramidPatchMatcher:
|
|||||||
|
|
||||||
def resample_image(self, images, level):
|
def resample_image(self, images, level):
|
||||||
height, width = self.pyramid_heights[level], self.pyramid_widths[level]
|
height, width = self.pyramid_heights[level], self.pyramid_widths[level]
|
||||||
images = images.get()
|
images_torch = torch.as_tensor(images, device='cuda', dtype=torch.float32)
|
||||||
images_resample = []
|
images_torch = images_torch.permute(0, 3, 1, 2)
|
||||||
for image in images:
|
images_resample = F.interpolate(images_torch, size=(height, width), mode='area', align_corners=None)
|
||||||
image_resample = cv2.resize(image, (width, height), interpolation=cv2.INTER_AREA)
|
images_resample = images_resample.permute(0, 2, 3, 1).contiguous()
|
||||||
images_resample.append(image_resample)
|
return cp.asarray(images_resample)
|
||||||
images_resample = cp.array(np.stack(images_resample), dtype=cp.float32)
|
|
||||||
return images_resample
|
|
||||||
|
|
||||||
def initialize_nnf(self, batch_size):
|
def initialize_nnf(self, batch_size):
|
||||||
if self.initialize == "random":
|
if self.initialize == "random":
|
||||||
@@ -262,14 +261,16 @@ class PyramidPatchMatcher:
|
|||||||
def update_nnf(self, nnf, level):
|
def update_nnf(self, nnf, level):
|
||||||
# upscale
|
# upscale
|
||||||
nnf = nnf.repeat(2, axis=1).repeat(2, axis=2) * 2
|
nnf = nnf.repeat(2, axis=1).repeat(2, axis=2) * 2
|
||||||
nnf[:,[i for i in range(nnf.shape[0]) if i&1],:,0] += 1
|
nnf[:, 1::2, :, 0] += 1
|
||||||
nnf[:,:,[i for i in range(nnf.shape[0]) if i&1],1] += 1
|
nnf[:, :, 1::2, 1] += 1
|
||||||
# check if scale is 2
|
# check if scale is 2
|
||||||
height, width = self.pyramid_heights[level], self.pyramid_widths[level]
|
height, width = self.pyramid_heights[level], self.pyramid_widths[level]
|
||||||
if height != nnf.shape[0] * 2 or width != nnf.shape[1] * 2:
|
if height != nnf.shape[0] * 2 or width != nnf.shape[1] * 2:
|
||||||
nnf = nnf.get().astype(np.float32)
|
nnf_torch = torch.as_tensor(nnf, device='cuda', dtype=torch.float32)
|
||||||
nnf = [cv2.resize(n, (width, height), interpolation=cv2.INTER_LINEAR) for n in nnf]
|
nnf_torch = nnf_torch.permute(0, 3, 1, 2)
|
||||||
nnf = cp.array(np.stack(nnf), dtype=cp.int32)
|
nnf_resized = F.interpolate(nnf_torch, size=(height, width), mode='bilinear', align_corners=False)
|
||||||
|
nnf_resized = nnf_resized.permute(0, 2, 3, 1)
|
||||||
|
nnf = cp.asarray(nnf_resized).astype(cp.int32)
|
||||||
nnf = self.patch_matchers[level].clamp_bound(nnf)
|
nnf = self.patch_matchers[level].clamp_bound(nnf)
|
||||||
return nnf
|
return nnf
|
||||||
|
|
||||||
|
|||||||
@@ -140,8 +140,9 @@ class QwenImagePipeline(BasePipeline):
|
|||||||
timestep_id = torch.randint(0, self.scheduler.num_train_timesteps, (1,))
|
timestep_id = torch.randint(0, self.scheduler.num_train_timesteps, (1,))
|
||||||
timestep = self.scheduler.timesteps[timestep_id].to(dtype=self.torch_dtype, device=self.device)
|
timestep = self.scheduler.timesteps[timestep_id].to(dtype=self.torch_dtype, device=self.device)
|
||||||
|
|
||||||
inputs["latents"] = self.scheduler.add_noise(inputs["input_latents"], inputs["noise"], timestep)
|
noise = torch.randn_like(inputs["input_latents"])
|
||||||
training_target = self.scheduler.training_target(inputs["input_latents"], inputs["noise"], timestep)
|
inputs["latents"] = self.scheduler.add_noise(inputs["input_latents"], noise, timestep)
|
||||||
|
training_target = self.scheduler.training_target(inputs["input_latents"], noise, timestep)
|
||||||
|
|
||||||
noise_pred = self.model_fn(**inputs, timestep=timestep)
|
noise_pred = self.model_fn(**inputs, timestep=timestep)
|
||||||
|
|
||||||
@@ -371,6 +372,7 @@ class QwenImagePipeline(BasePipeline):
|
|||||||
rand_device: str = "cpu",
|
rand_device: str = "cpu",
|
||||||
# Steps
|
# Steps
|
||||||
num_inference_steps: int = 30,
|
num_inference_steps: int = 30,
|
||||||
|
exponential_shift_mu: float = None,
|
||||||
# Blockwise ControlNet
|
# Blockwise ControlNet
|
||||||
blockwise_controlnet_inputs: list[ControlNetInput] = None,
|
blockwise_controlnet_inputs: list[ControlNetInput] = None,
|
||||||
# EliGen
|
# EliGen
|
||||||
@@ -393,7 +395,7 @@ class QwenImagePipeline(BasePipeline):
|
|||||||
progress_bar_cmd = tqdm,
|
progress_bar_cmd = tqdm,
|
||||||
):
|
):
|
||||||
# Scheduler
|
# Scheduler
|
||||||
self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, dynamic_shift_len=(height // 16) * (width // 16))
|
self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, dynamic_shift_len=(height // 16) * (width // 16), exponential_shift_mu=exponential_shift_mu)
|
||||||
|
|
||||||
# Parameters
|
# Parameters
|
||||||
inputs_posi = {
|
inputs_posi = {
|
||||||
@@ -523,7 +525,7 @@ class QwenImageUnit_PromptEmbedder(PipelineUnit):
|
|||||||
return split_result
|
return split_result
|
||||||
|
|
||||||
def process(self, pipe: QwenImagePipeline, prompt, edit_image=None) -> dict:
|
def process(self, pipe: QwenImagePipeline, prompt, edit_image=None) -> dict:
|
||||||
if pipe.text_encoder is not None:
|
if pipe.text_encoder is not None and prompt is not None:
|
||||||
prompt = [prompt]
|
prompt = [prompt]
|
||||||
# If edit_image is None, use the default template for Qwen-Image, otherwise use the template for Qwen-Image-Edit
|
# If edit_image is None, use the default template for Qwen-Image, otherwise use the template for Qwen-Image-Edit
|
||||||
if edit_image is None:
|
if edit_image is None:
|
||||||
|
|||||||
@@ -31,7 +31,7 @@ class FlowMatchScheduler():
|
|||||||
self.set_timesteps(num_inference_steps)
|
self.set_timesteps(num_inference_steps)
|
||||||
|
|
||||||
|
|
||||||
def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0, training=False, shift=None, dynamic_shift_len=None):
|
def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0, training=False, shift=None, dynamic_shift_len=None, exponential_shift_mu=None):
|
||||||
if shift is not None:
|
if shift is not None:
|
||||||
self.shift = shift
|
self.shift = shift
|
||||||
sigma_start = self.sigma_min + (self.sigma_max - self.sigma_min) * denoising_strength
|
sigma_start = self.sigma_min + (self.sigma_max - self.sigma_min) * denoising_strength
|
||||||
@@ -42,7 +42,12 @@ class FlowMatchScheduler():
|
|||||||
if self.inverse_timesteps:
|
if self.inverse_timesteps:
|
||||||
self.sigmas = torch.flip(self.sigmas, dims=[0])
|
self.sigmas = torch.flip(self.sigmas, dims=[0])
|
||||||
if self.exponential_shift:
|
if self.exponential_shift:
|
||||||
mu = self.calculate_shift(dynamic_shift_len) if dynamic_shift_len is not None else self.exponential_shift_mu
|
if exponential_shift_mu is not None:
|
||||||
|
mu = exponential_shift_mu
|
||||||
|
elif dynamic_shift_len is not None:
|
||||||
|
mu = self.calculate_shift(dynamic_shift_len)
|
||||||
|
else:
|
||||||
|
mu = self.exponential_shift_mu
|
||||||
self.sigmas = math.exp(mu) / (math.exp(mu) + (1 / self.sigmas - 1))
|
self.sigmas = math.exp(mu) / (math.exp(mu) + (1 / self.sigmas - 1))
|
||||||
else:
|
else:
|
||||||
self.sigmas = self.shift * self.sigmas / (1 + (self.shift - 1) * self.sigmas)
|
self.sigmas = self.shift * self.sigmas / (1 + (self.shift - 1) * self.sigmas)
|
||||||
|
|||||||
@@ -396,6 +396,15 @@ class DiffusionTrainingModule(torch.nn.Module):
|
|||||||
param.data = param.to(upcast_dtype)
|
param.data = param.to(upcast_dtype)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
def disable_all_lora_layers(self, model):
|
||||||
|
for name, module in model.named_modules():
|
||||||
|
if hasattr(module, 'enable_adapters'):
|
||||||
|
module.enable_adapters(False)
|
||||||
|
|
||||||
|
def enable_all_lora_layers(self, model):
|
||||||
|
for name, module in model.named_modules():
|
||||||
|
if hasattr(module, 'enable_adapters'):
|
||||||
|
module.enable_adapters(True)
|
||||||
|
|
||||||
def mapping_lora_state_dict(self, state_dict):
|
def mapping_lora_state_dict(self, state_dict):
|
||||||
new_state_dict = {}
|
new_state_dict = {}
|
||||||
@@ -421,10 +430,12 @@ class DiffusionTrainingModule(torch.nn.Module):
|
|||||||
return state_dict
|
return state_dict
|
||||||
|
|
||||||
|
|
||||||
def transfer_data_to_device(self, data, device):
|
def transfer_data_to_device(self, data, device, torch_float_dtype=None):
|
||||||
for key in data:
|
for key in data:
|
||||||
if isinstance(data[key], torch.Tensor):
|
if isinstance(data[key], torch.Tensor):
|
||||||
data[key] = data[key].to(device)
|
data[key] = data[key].to(device)
|
||||||
|
if torch_float_dtype is not None and data[key].dtype in [torch.float, torch.float16, torch.bfloat16]:
|
||||||
|
data[key] = data[key].to(torch_float_dtype)
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
|
||||||
@@ -552,9 +563,9 @@ def launch_training_task(
|
|||||||
with accelerator.accumulate(model):
|
with accelerator.accumulate(model):
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
if dataset.load_from_cache:
|
if dataset.load_from_cache:
|
||||||
loss = model({}, inputs=data)
|
loss = model({}, inputs=data, accelerator=accelerator)
|
||||||
else:
|
else:
|
||||||
loss = model(data)
|
loss = model(data, accelerator=accelerator)
|
||||||
accelerator.backward(loss)
|
accelerator.backward(loss)
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
model_logger.on_step_end(accelerator, model, save_steps)
|
model_logger.on_step_end(accelerator, model, save_steps)
|
||||||
@@ -688,4 +699,5 @@ def qwen_image_parser():
|
|||||||
parser.add_argument("--processor_path", type=str, default=None, help="Path to the processor. If provided, the processor will be used for image editing.")
|
parser.add_argument("--processor_path", type=str, default=None, help="Path to the processor. If provided, the processor will be used for image editing.")
|
||||||
parser.add_argument("--enable_fp8_training", default=False, action="store_true", help="Whether to enable FP8 training. Only available for LoRA training on a single GPU.")
|
parser.add_argument("--enable_fp8_training", default=False, action="store_true", help="Whether to enable FP8 training. Only available for LoRA training on a single GPU.")
|
||||||
parser.add_argument("--task", type=str, default="sft", required=False, help="Task type.")
|
parser.add_argument("--task", type=str, default="sft", required=False, help="Task type.")
|
||||||
|
parser.add_argument("--beta_dpo", type=float, default=1000, help="hyperparameter beta for DPO loss, only used when task is dpo.")
|
||||||
return parser
|
return parser
|
||||||
|
|||||||
@@ -153,6 +153,19 @@ class BasePipeline(torch.nn.Module):
|
|||||||
latents_next = scheduler.step(noise_pred, timestep, latents)
|
latents_next = scheduler.step(noise_pred, timestep, latents)
|
||||||
return latents_next
|
return latents_next
|
||||||
|
|
||||||
|
def sample_timestep(self):
|
||||||
|
timestep_id = torch.randint(0, self.scheduler.num_train_timesteps, (1,))
|
||||||
|
timestep = self.scheduler.timesteps[timestep_id].to(dtype=self.torch_dtype, device=self.device)
|
||||||
|
return timestep
|
||||||
|
|
||||||
|
def training_loss_minimum(self, noise, timestep, **inputs):
|
||||||
|
inputs["latents"] = self.scheduler.add_noise(inputs["input_latents"], noise, timestep)
|
||||||
|
training_target = self.scheduler.training_target(inputs["input_latents"], noise, timestep)
|
||||||
|
noise_pred = self.model_fn(**inputs, timestep=timestep)
|
||||||
|
loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
|
||||||
|
loss = loss * self.scheduler.training_weight(timestep)
|
||||||
|
return loss
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|||||||
@@ -75,7 +75,7 @@ class FluxTrainingModule(DiffusionTrainingModule):
|
|||||||
return {**inputs_shared, **inputs_posi}
|
return {**inputs_shared, **inputs_posi}
|
||||||
|
|
||||||
|
|
||||||
def forward(self, data, inputs=None):
|
def forward(self, data, inputs=None, **kwargs):
|
||||||
if inputs is None: inputs = self.forward_preprocess(data)
|
if inputs is None: inputs = self.forward_preprocess(data)
|
||||||
models = {name: getattr(self.pipe, name) for name in self.pipe.in_iteration_models}
|
models = {name: getattr(self.pipe, name) for name in self.pipe.in_iteration_models}
|
||||||
loss = self.pipe.training_loss(**models, **inputs)
|
loss = self.pipe.training_loss(**models, **inputs)
|
||||||
|
|||||||
@@ -0,0 +1,24 @@
|
|||||||
|
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig, load_state_dict
|
||||||
|
from modelscope import snapshot_download
|
||||||
|
import torch, math
|
||||||
|
|
||||||
|
|
||||||
|
pipe = QwenImagePipeline.from_pretrained(
|
||||||
|
torch_dtype=torch.bfloat16,
|
||||||
|
device="cuda",
|
||||||
|
model_configs=[
|
||||||
|
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
|
||||||
|
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
|
||||||
|
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||||
|
],
|
||||||
|
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
|
||||||
|
)
|
||||||
|
|
||||||
|
snapshot_download("MusePublic/Qwen-Image-Distill", allow_file_pattern="qwen_image_distill_3step.safetensors", cache_dir="models")
|
||||||
|
lora_state_dict = load_state_dict("models/MusePublic/Qwen-Image-Distill/qwen_image_distill_3step.safetensors")
|
||||||
|
lora_state_dict = {i.replace("base_model.model.", ""): j for i, j in lora_state_dict.items()}
|
||||||
|
pipe.load_lora(pipe.dit, state_dict=lora_state_dict)
|
||||||
|
|
||||||
|
prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。"
|
||||||
|
image = pipe(prompt, seed=0, num_inference_steps=3, cfg_scale=1, exponential_shift_mu=math.log(2.5))
|
||||||
|
image.save("image.jpg")
|
||||||
25
examples/qwen_image/model_training/lora/Qwen-Image-DPO.sh
Normal file
25
examples/qwen_image/model_training/lora/Qwen-Image-DPO.sh
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
# dataset format:
|
||||||
|
# {
|
||||||
|
# "image": "path/to/win_image.png", # win image
|
||||||
|
# "lose_image": "path/to/lose_image.png", # lose image
|
||||||
|
# "prompt": "a photo of ...",
|
||||||
|
# }
|
||||||
|
accelerate launch examples/qwen_image/model_training/train.py \
|
||||||
|
--dataset_base_path data/example_image_dataset \
|
||||||
|
--dataset_metadata_path data/example_image_dataset/dpo.jsonl \
|
||||||
|
--data_file_keys "image,lose_image" \
|
||||||
|
--max_pixels 1048576 \
|
||||||
|
--dataset_repeat 400 \
|
||||||
|
--model_id_with_origin_paths "Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors" \
|
||||||
|
--learning_rate 1e-4 \
|
||||||
|
--num_epochs 5 \
|
||||||
|
--remove_prefix_in_ckpt "pipe.dit." \
|
||||||
|
--output_path "./models/train/Qwen-Image_DPO_lora" \
|
||||||
|
--lora_base_model "dit" \
|
||||||
|
--lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1" \
|
||||||
|
--lora_rank 32 \
|
||||||
|
--use_gradient_checkpointing \
|
||||||
|
--dataset_num_workers 8 \
|
||||||
|
--task dpo \
|
||||||
|
--beta_dpo 2500 \
|
||||||
|
--find_unused_parameters
|
||||||
@@ -1,5 +1,4 @@
|
|||||||
import torch, os, json
|
import torch, os
|
||||||
from diffsynth import load_state_dict
|
|
||||||
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
|
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
|
||||||
from diffsynth.pipelines.flux_image_new import ControlNetInput
|
from diffsynth.pipelines.flux_image_new import ControlNetInput
|
||||||
from diffsynth.trainers.utils import DiffusionTrainingModule, ModelLogger, qwen_image_parser, launch_training_task, launch_data_process_task
|
from diffsynth.trainers.utils import DiffusionTrainingModule, ModelLogger, qwen_image_parser, launch_training_task, launch_data_process_task
|
||||||
@@ -7,7 +6,6 @@ from diffsynth.trainers.unified_dataset import UnifiedDataset
|
|||||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class QwenImageTrainingModule(DiffusionTrainingModule):
|
class QwenImageTrainingModule(DiffusionTrainingModule):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -20,6 +18,7 @@ class QwenImageTrainingModule(DiffusionTrainingModule):
|
|||||||
extra_inputs=None,
|
extra_inputs=None,
|
||||||
enable_fp8_training=False,
|
enable_fp8_training=False,
|
||||||
task="sft",
|
task="sft",
|
||||||
|
beta_dpo=1000.,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
# Load models
|
# Load models
|
||||||
@@ -40,7 +39,8 @@ class QwenImageTrainingModule(DiffusionTrainingModule):
|
|||||||
self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload
|
self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload
|
||||||
self.extra_inputs = extra_inputs.split(",") if extra_inputs is not None else []
|
self.extra_inputs = extra_inputs.split(",") if extra_inputs is not None else []
|
||||||
self.task = task
|
self.task = task
|
||||||
|
self.lora_base_model = lora_base_model
|
||||||
|
self.beta_dpo = beta_dpo
|
||||||
|
|
||||||
def forward_preprocess(self, data):
|
def forward_preprocess(self, data):
|
||||||
# CFG-sensitive parameters
|
# CFG-sensitive parameters
|
||||||
@@ -82,11 +82,48 @@ class QwenImageTrainingModule(DiffusionTrainingModule):
|
|||||||
inputs_shared, inputs_posi, inputs_nega = self.pipe.unit_runner(unit, self.pipe, inputs_shared, inputs_posi, inputs_nega)
|
inputs_shared, inputs_posi, inputs_nega = self.pipe.unit_runner(unit, self.pipe, inputs_shared, inputs_posi, inputs_nega)
|
||||||
return {**inputs_shared, **inputs_posi}
|
return {**inputs_shared, **inputs_posi}
|
||||||
|
|
||||||
|
def forward_dpo(self, data, accelerator=None):
|
||||||
|
# Loss DPO: -logσ(−β(diff_policy − diff_ref))
|
||||||
|
# Prepare inputs
|
||||||
|
win_data = {key: data[key] for key in ["prompt", "image"]}
|
||||||
|
lose_data = {"prompt": None, "image": data["lose_image"]}
|
||||||
|
inputs_win = self.forward_preprocess(win_data)
|
||||||
|
inputs_lose = self.forward_preprocess(lose_data)
|
||||||
|
inputs_lose.update({key: inputs_win[key] for key in ["prompt", "prompt_emb", "prompt_emb_mask"]})
|
||||||
|
inputs_win.pop('noise')
|
||||||
|
inputs_lose.pop('noise')
|
||||||
|
models = {name: getattr(self.pipe, name) for name in self.pipe.in_iteration_models}
|
||||||
|
# sample timestep and noise
|
||||||
|
timestep = self.pipe.sample_timestep()
|
||||||
|
noise = torch.rand_like(inputs_win["latents"])
|
||||||
|
# compute diff_policy = loss_win - loss_lose
|
||||||
|
loss_win = self.pipe.training_loss_minimum(noise, timestep, **models, **inputs_win)
|
||||||
|
loss_lose = self.pipe.training_loss_minimum(noise, timestep, **models, **inputs_lose)
|
||||||
|
diff_policy = loss_win - loss_lose
|
||||||
|
# compute diff_ref
|
||||||
|
if self.lora_base_model is not None:
|
||||||
|
self.disable_all_lora_layers(accelerator.unwrap_model(self).pipe.dit)
|
||||||
|
# load the original model weights
|
||||||
|
with torch.no_grad():
|
||||||
|
loss_win_ref = self.pipe.training_loss_minimum(noise, timestep, **models, **inputs_win)
|
||||||
|
loss_lose_ref = self.pipe.training_loss_minimum(noise, timestep, **models, **inputs_lose)
|
||||||
|
diff_ref = loss_win_ref - loss_lose_ref
|
||||||
|
self.enable_all_lora_layers(accelerator.unwrap_model(self).pipe.dit)
|
||||||
|
else:
|
||||||
|
# TODO: may support full model training
|
||||||
|
raise NotImplementedError("DPO with full model training is not supported yet.")
|
||||||
|
# compute loss
|
||||||
|
loss = -1. * torch.nn.functional.logsigmoid(self.beta_dpo * (diff_ref - diff_policy)).mean()
|
||||||
|
return loss
|
||||||
|
|
||||||
def forward(self, data, inputs=None, return_inputs=False):
|
def forward(self, data, inputs=None, return_inputs=False, accelerator=None, **kwargs):
|
||||||
|
if self.task == "dpo":
|
||||||
|
return self.forward_dpo(data, accelerator=accelerator)
|
||||||
# Inputs
|
# Inputs
|
||||||
if inputs is None: inputs = self.forward_preprocess(data)
|
if inputs is None:
|
||||||
else: inputs = self.transfer_data_to_device(inputs, self.pipe.device)
|
inputs = self.forward_preprocess(data)
|
||||||
|
else:
|
||||||
|
inputs = self.transfer_data_to_device(inputs, self.pipe.device, self.pipe.torch_dtype)
|
||||||
if return_inputs: return inputs
|
if return_inputs: return inputs
|
||||||
|
|
||||||
# Loss
|
# Loss
|
||||||
@@ -135,11 +172,13 @@ if __name__ == "__main__":
|
|||||||
extra_inputs=args.extra_inputs,
|
extra_inputs=args.extra_inputs,
|
||||||
enable_fp8_training=args.enable_fp8_training,
|
enable_fp8_training=args.enable_fp8_training,
|
||||||
task=args.task,
|
task=args.task,
|
||||||
|
beta_dpo=args.beta_dpo,
|
||||||
)
|
)
|
||||||
model_logger = ModelLogger(args.output_path, remove_prefix_in_ckpt=args.remove_prefix_in_ckpt)
|
model_logger = ModelLogger(args.output_path, remove_prefix_in_ckpt=args.remove_prefix_in_ckpt)
|
||||||
launcher_map = {
|
launcher_map = {
|
||||||
"sft": launch_training_task,
|
"sft": launch_training_task,
|
||||||
"data_process": launch_data_process_task,
|
"data_process": launch_data_process_task,
|
||||||
"direct_distill": launch_training_task,
|
"direct_distill": launch_training_task,
|
||||||
|
"dpo": launch_training_task,
|
||||||
}
|
}
|
||||||
launcher_map[args.task](dataset, model, model_logger, args=args)
|
launcher_map[args.task](dataset, model, model_logger, args=args)
|
||||||
|
|||||||
@@ -0,0 +1,19 @@
|
|||||||
|
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
pipe = QwenImagePipeline.from_pretrained(
|
||||||
|
torch_dtype=torch.bfloat16,
|
||||||
|
device="cuda",
|
||||||
|
model_configs=[
|
||||||
|
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
|
||||||
|
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
|
||||||
|
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||||
|
],
|
||||||
|
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
|
||||||
|
)
|
||||||
|
pipe.load_lora(pipe.dit, "models/train/Qwen-Image_DPO_lora/epoch-4.safetensors")
|
||||||
|
prompt = "黑板上写着“群起效尤,心灵手巧”,字的颜色分别是 “群”: 橙色、“起”: 黑色、“效”: 蓝色、“尤”: 绿色、“心”: 紫色、“灵”: 粉色、“手”: 红色、“巧”: 白色"
|
||||||
|
for seed in range(0, 5):
|
||||||
|
image = pipe(prompt, seed=seed)
|
||||||
|
image.save(f"image_dpo_{seed}.jpg")
|
||||||
@@ -82,7 +82,7 @@ class WanTrainingModule(DiffusionTrainingModule):
|
|||||||
return {**inputs_shared, **inputs_posi}
|
return {**inputs_shared, **inputs_posi}
|
||||||
|
|
||||||
|
|
||||||
def forward(self, data, inputs=None):
|
def forward(self, data, inputs=None, **kwargs):
|
||||||
if inputs is None: inputs = self.forward_preprocess(data)
|
if inputs is None: inputs = self.forward_preprocess(data)
|
||||||
models = {name: getattr(self.pipe, name) for name in self.pipe.in_iteration_models}
|
models = {name: getattr(self.pipe, name) for name in self.pipe.in_iteration_models}
|
||||||
loss = self.pipe.training_loss(**models, **inputs)
|
loss = self.pipe.training_loss(**models, **inputs)
|
||||||
|
|||||||
Reference in New Issue
Block a user