Merge branch 'main' into qwen-image-eligen

This commit is contained in:
mi804
2025-08-06 15:57:13 +08:00
23 changed files with 220 additions and 29 deletions

View File

@@ -383,5 +383,20 @@ class WanLoRAConverter:
return state_dict
class QwenImageLoRAConverter:
def __init__(self):
pass
@staticmethod
def align_to_opensource_format(state_dict, **kwargs):
state_dict = {name.replace(".default.", "."): param for name, param in state_dict.items()}
return state_dict
@staticmethod
def align_to_diffsynth_format(state_dict, **kwargs):
state_dict = {name.replace(".lora_A.weight", ".lora_A.default.weight").replace(".lora_B.weight", ".lora_B.default.weight"): param for name, param in state_dict.items()}
return state_dict
def get_lora_loaders():
return [SDLoRAFromCivitai(), SDXLLoRAFromCivitai(), FluxLoRAFromCivitai(), HunyuanVideoLoRAFromCivitai(), GeneralLoRAFromPeft()]

View File

@@ -335,7 +335,7 @@ class WanModel(torch.nn.Module):
else:
self.control_adapter = None
def patchify(self, x: torch.Tensor,control_camera_latents_input: torch.Tensor = None):
def patchify(self, x: torch.Tensor, control_camera_latents_input: Optional[torch.Tensor] = None):
x = self.patch_embedding(x)
if self.control_adapter is not None and control_camera_latents_input is not None:
y_camera = self.control_adapter(control_camera_latents_input)

View File

@@ -4,6 +4,7 @@ from PIL import Image
import pandas as pd
from tqdm import tqdm
from accelerate import Accelerator
from accelerate.utils import DistributedDataParallelKwargs
@@ -364,12 +365,15 @@ class ModelLogger:
self.output_path = output_path
self.remove_prefix_in_ckpt = remove_prefix_in_ckpt
self.state_dict_converter = state_dict_converter
def on_step_end(self, loss):
pass
self.num_steps = 0
def on_step_end(self, accelerator, model, save_steps=None):
self.num_steps += 1
if save_steps is not None and self.num_steps % save_steps == 0:
self.save_model(accelerator, model, f"step-{self.num_steps}.safetensors")
def on_epoch_end(self, accelerator, model, epoch_id):
accelerator.wait_for_everyone()
if accelerator.is_main_process:
@@ -381,6 +385,21 @@ class ModelLogger:
accelerator.save(state_dict, path, safe_serialization=True)
def on_training_end(self, accelerator, model, save_steps=None):
if save_steps is not None and self.num_steps % save_steps != 0:
self.save_model(accelerator, model, f"step-{self.num_steps}.safetensors")
def save_model(self, accelerator, model, file_name):
accelerator.wait_for_everyone()
if accelerator.is_main_process:
state_dict = accelerator.get_state_dict(model)
state_dict = accelerator.unwrap_model(model).export_trainable_state_dict(state_dict, remove_prefix=self.remove_prefix_in_ckpt)
state_dict = self.state_dict_converter(state_dict)
os.makedirs(self.output_path, exist_ok=True)
path = os.path.join(self.output_path, file_name)
accelerator.save(state_dict, path, safe_serialization=True)
def launch_training_task(
dataset: torch.utils.data.Dataset,
@@ -388,11 +407,17 @@ def launch_training_task(
model_logger: ModelLogger,
optimizer: torch.optim.Optimizer,
scheduler: torch.optim.lr_scheduler.LRScheduler,
num_workers: int = 8,
save_steps: int = None,
num_epochs: int = 1,
gradient_accumulation_steps: int = 1,
find_unused_parameters: bool = False,
):
dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, collate_fn=lambda x: x[0])
accelerator = Accelerator(gradient_accumulation_steps=gradient_accumulation_steps)
dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, collate_fn=lambda x: x[0], num_workers=num_workers)
accelerator = Accelerator(
gradient_accumulation_steps=gradient_accumulation_steps,
kwargs_handlers=[DistributedDataParallelKwargs(find_unused_parameters=find_unused_parameters)],
)
model, optimizer, dataloader, scheduler = accelerator.prepare(model, optimizer, dataloader, scheduler)
for epoch_id in range(num_epochs):
@@ -402,10 +427,11 @@ def launch_training_task(
loss = model(data)
accelerator.backward(loss)
optimizer.step()
model_logger.on_step_end(loss)
model_logger.on_step_end(accelerator, model, save_steps)
scheduler.step()
model_logger.on_epoch_end(accelerator, model, epoch_id)
if save_steps is None:
model_logger.on_epoch_end(accelerator, model, epoch_id)
model_logger.on_training_end(accelerator, model, save_steps)
def launch_data_process_task(model: DiffusionTrainingModule, dataset, output_path="./models"):
@@ -446,6 +472,9 @@ def wan_parser():
parser.add_argument("--gradient_accumulation_steps", type=int, default=1, help="Gradient accumulation steps.")
parser.add_argument("--max_timestep_boundary", type=float, default=1.0, help="Max timestep boundary (for mixed models, e.g., Wan-AI/Wan2.2-I2V-A14B).")
parser.add_argument("--min_timestep_boundary", type=float, default=0.0, help="Min timestep boundary (for mixed models, e.g., Wan-AI/Wan2.2-I2V-A14B).")
parser.add_argument("--find_unused_parameters", default=False, action="store_true", help="Whether to find unused parameters in DDP.")
parser.add_argument("--save_steps", type=int, default=None, help="Number of checkpoint saving invervals. If None, checkpoints will be saved every epoch.")
parser.add_argument("--dataset_num_workers", type=int, default=0, help="Number of workers for data loading.")
return parser
@@ -474,6 +503,9 @@ def flux_parser():
parser.add_argument("--use_gradient_checkpointing", default=False, action="store_true", help="Whether to use gradient checkpointing.")
parser.add_argument("--use_gradient_checkpointing_offload", default=False, action="store_true", help="Whether to offload gradient checkpointing to CPU memory.")
parser.add_argument("--gradient_accumulation_steps", type=int, default=1, help="Gradient accumulation steps.")
parser.add_argument("--find_unused_parameters", default=False, action="store_true", help="Whether to find unused parameters in DDP.")
parser.add_argument("--save_steps", type=int, default=None, help="Number of checkpoint saving invervals. If None, checkpoints will be saved every epoch.")
parser.add_argument("--dataset_num_workers", type=int, default=0, help="Number of workers for data loading.")
return parser
@@ -503,4 +535,7 @@ def qwen_image_parser():
parser.add_argument("--use_gradient_checkpointing", default=False, action="store_true", help="Whether to use gradient checkpointing.")
parser.add_argument("--use_gradient_checkpointing_offload", default=False, action="store_true", help="Whether to offload gradient checkpointing to CPU memory.")
parser.add_argument("--gradient_accumulation_steps", type=int, default=1, help="Gradient accumulation steps.")
parser.add_argument("--find_unused_parameters", default=False, action="store_true", help="Whether to find unused parameters in DDP.")
parser.add_argument("--save_steps", type=int, default=None, help="Number of checkpoint saving invervals. If None, checkpoints will be saved every epoch.")
parser.add_argument("--dataset_num_workers", type=int, default=0, help="Number of workers for data loading.")
return parser