merge data process to training script

This commit is contained in:
Artiprocher
2025-09-04 15:18:56 +08:00
parent cb8de6be1b
commit 144365b07d
6 changed files with 35 additions and 147 deletions

View File

@@ -520,14 +520,26 @@ def launch_training_task(
dataset: torch.utils.data.Dataset, dataset: torch.utils.data.Dataset,
model: DiffusionTrainingModule, model: DiffusionTrainingModule,
model_logger: ModelLogger, model_logger: ModelLogger,
optimizer: torch.optim.Optimizer, learning_rate: float = 1e-5,
scheduler: torch.optim.lr_scheduler.LRScheduler, weight_decay: float = 1e-2,
num_workers: int = 8, num_workers: int = 8,
save_steps: int = None, save_steps: int = None,
num_epochs: int = 1, num_epochs: int = 1,
gradient_accumulation_steps: int = 1, gradient_accumulation_steps: int = 1,
find_unused_parameters: bool = False, find_unused_parameters: bool = False,
args = None,
): ):
if args is not None:
learning_rate = args.learning_rate
weight_decay = args.weight_decay
num_workers = args.dataset_num_workers
save_steps = args.save_steps
num_epochs = args.num_epochs
gradient_accumulation_steps = args.gradient_accumulation_steps
find_unused_parameters = args.find_unused_parameters
optimizer = torch.optim.AdamW(model.trainable_modules(), lr=learning_rate, weight_decay=weight_decay)
scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer)
dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, collate_fn=lambda x: x[0], num_workers=num_workers) dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, collate_fn=lambda x: x[0], num_workers=num_workers)
accelerator = Accelerator( accelerator = Accelerator(
gradient_accumulation_steps=gradient_accumulation_steps, gradient_accumulation_steps=gradient_accumulation_steps,
@@ -557,8 +569,12 @@ def launch_data_process_task(
model: DiffusionTrainingModule, model: DiffusionTrainingModule,
model_logger: ModelLogger, model_logger: ModelLogger,
num_workers: int = 8, num_workers: int = 8,
args = None,
): ):
dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, collate_fn=lambda x: x[0], num_workers=num_workers) if args is not None:
num_workers = args.dataset_num_workers
dataloader = torch.utils.data.DataLoader(dataset, shuffle=False, collate_fn=lambda x: x[0], num_workers=num_workers)
accelerator = Accelerator() accelerator = Accelerator()
model, dataloader = accelerator.prepare(model, dataloader) model, dataloader = accelerator.prepare(model, dataloader)
@@ -568,7 +584,7 @@ def launch_data_process_task(
folder = os.path.join(model_logger.output_path, str(accelerator.process_index)) folder = os.path.join(model_logger.output_path, str(accelerator.process_index))
os.makedirs(folder, exist_ok=True) os.makedirs(folder, exist_ok=True)
save_path = os.path.join(model_logger.output_path, str(accelerator.process_index), f"{data_id}.pth") save_path = os.path.join(model_logger.output_path, str(accelerator.process_index), f"{data_id}.pth")
data = model(data) data = model(data, return_inputs=True)
torch.save(data, save_path) torch.save(data, save_path)
@@ -671,4 +687,5 @@ def qwen_image_parser():
parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay.") parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay.")
parser.add_argument("--processor_path", type=str, default=None, help="Path to the processor. If provided, the processor will be used for image editing.") parser.add_argument("--processor_path", type=str, default=None, help="Path to the processor. If provided, the processor will be used for image editing.")
parser.add_argument("--enable_fp8_training", default=False, action="store_true", help="Whether to enable FP8 training. Only available for LoRA training on a single GPU.") parser.add_argument("--enable_fp8_training", default=False, action="store_true", help="Whether to enable FP8 training. Only available for LoRA training on a single GPU.")
parser.add_argument("--task", type=str, default="sft", required=False, help="Task type.")
return parser return parser

View File

@@ -25,7 +25,7 @@ class FluxTrainingModule(DiffusionTrainingModule):
# Training mode # Training mode
self.switch_pipe_to_training_mode( self.switch_pipe_to_training_mode(
self, self.pipe, trainable_models, self.pipe, trainable_models,
lora_base_model, lora_target_modules, lora_rank, lora_checkpoint=lora_checkpoint, lora_base_model, lora_target_modules, lora_rank, lora_checkpoint=lora_checkpoint,
enable_fp8_training=False, enable_fp8_training=False,
) )

View File

@@ -1,11 +1,12 @@
accelerate launch examples/qwen_image/model_training/train_data_process.py \ accelerate launch examples/qwen_image/model_training/train.py \
--dataset_base_path data/example_image_dataset \ --dataset_base_path data/example_image_dataset \
--dataset_metadata_path data/example_image_dataset/metadata.csv \ --dataset_metadata_path data/example_image_dataset/metadata.csv \
--max_pixels 1048576 \ --max_pixels 1048576 \
--model_id_with_origin_paths "Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors" \ --model_id_with_origin_paths "Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors" \
--output_path "./models/train/Qwen-Image_lora_cache" \ --output_path "./models/train/Qwen-Image_lora_cache" \
--use_gradient_checkpointing \ --use_gradient_checkpointing \
--dataset_num_workers 8 --dataset_num_workers 8 \
--task data_process
accelerate launch examples/qwen_image/model_training/train.py \ accelerate launch examples/qwen_image/model_training/train.py \
--dataset_base_path models/train/Qwen-Image_lora_cache \ --dataset_base_path models/train/Qwen-Image_lora_cache \

View File

@@ -2,7 +2,7 @@ import torch, os, json
from diffsynth import load_state_dict from diffsynth import load_state_dict
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
from diffsynth.pipelines.flux_image_new import ControlNetInput from diffsynth.pipelines.flux_image_new import ControlNetInput
from diffsynth.trainers.utils import DiffusionTrainingModule, ModelLogger, launch_training_task, qwen_image_parser from diffsynth.trainers.utils import DiffusionTrainingModule, ModelLogger, qwen_image_parser, launch_training_task, launch_data_process_task
from diffsynth.trainers.unified_dataset import UnifiedDataset from diffsynth.trainers.unified_dataset import UnifiedDataset
os.environ["TOKENIZERS_PARALLELISM"] = "false" os.environ["TOKENIZERS_PARALLELISM"] = "false"
@@ -29,7 +29,7 @@ class QwenImageTrainingModule(DiffusionTrainingModule):
# Training mode # Training mode
self.switch_pipe_to_training_mode( self.switch_pipe_to_training_mode(
self, self.pipe, trainable_models, self.pipe, trainable_models,
lora_base_model, lora_target_modules, lora_rank, lora_checkpoint=lora_checkpoint, lora_base_model, lora_target_modules, lora_rank, lora_checkpoint=lora_checkpoint,
enable_fp8_training=enable_fp8_training, enable_fp8_training=enable_fp8_training,
) )
@@ -81,9 +81,10 @@ class QwenImageTrainingModule(DiffusionTrainingModule):
return {**inputs_shared, **inputs_posi} return {**inputs_shared, **inputs_posi}
def forward(self, data, inputs=None): def forward(self, data, inputs=None, return_inputs=False):
if inputs is None: inputs = self.forward_preprocess(data) if inputs is None: inputs = self.forward_preprocess(data)
else: inputs = self.transfer_data_to_device(inputs, self.pipe.device) else: inputs = self.transfer_data_to_device(inputs, self.pipe.device)
if return_inputs: return inputs
models = {name: getattr(self.pipe, name) for name in self.pipe.in_iteration_models} models = {name: getattr(self.pipe, name) for name in self.pipe.in_iteration_models}
loss = self.pipe.training_loss(**models, **inputs) loss = self.pipe.training_loss(**models, **inputs)
return loss return loss
@@ -123,13 +124,8 @@ if __name__ == "__main__":
enable_fp8_training=args.enable_fp8_training, enable_fp8_training=args.enable_fp8_training,
) )
model_logger = ModelLogger(args.output_path, remove_prefix_in_ckpt=args.remove_prefix_in_ckpt) model_logger = ModelLogger(args.output_path, remove_prefix_in_ckpt=args.remove_prefix_in_ckpt)
optimizer = torch.optim.AdamW(model.trainable_modules(), lr=args.learning_rate, weight_decay=args.weight_decay) launcher_map = {
scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer) "sft": launch_training_task,
launch_training_task( "data_process": launch_data_process_task
dataset, model, model_logger, optimizer, scheduler, }
num_epochs=args.num_epochs, launcher_map[args.task](dataset, model, model_logger, args=args)
gradient_accumulation_steps=args.gradient_accumulation_steps,
save_steps=args.save_steps,
find_unused_parameters=args.find_unused_parameters,
num_workers=args.dataset_num_workers,
)

View File

@@ -1,126 +0,0 @@
import torch, os, json
from diffsynth import load_state_dict
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
from diffsynth.pipelines.flux_image_new import ControlNetInput
from diffsynth.trainers.utils import DiffusionTrainingModule, ModelLogger, launch_data_process_task, qwen_image_parser
from diffsynth.trainers.unified_dataset import UnifiedDataset
os.environ["TOKENIZERS_PARALLELISM"] = "false"
class QwenImageTrainingModule(DiffusionTrainingModule):
def __init__(
self,
model_paths=None, model_id_with_origin_paths=None,
tokenizer_path=None, processor_path=None,
trainable_models=None,
lora_base_model=None, lora_target_modules="", lora_rank=32, lora_checkpoint=None,
use_gradient_checkpointing=True,
use_gradient_checkpointing_offload=False,
extra_inputs=None,
enable_fp8_training=False,
):
super().__init__()
# Load models
model_configs = self.parse_model_configs(model_paths, model_id_with_origin_paths, enable_fp8_training=enable_fp8_training)
tokenizer_config = ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/") if tokenizer_path is None else ModelConfig(tokenizer_path)
processor_config = ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="processor/") if processor_path is None else ModelConfig(processor_path)
self.pipe = QwenImagePipeline.from_pretrained(torch_dtype=torch.bfloat16, device="cpu", model_configs=model_configs, tokenizer_config=tokenizer_config, processor_config=processor_config)
# Training mode
self.switch_pipe_to_training_mode(
self, self.pipe, trainable_models,
lora_base_model, lora_target_modules, lora_rank, lora_checkpoint=lora_checkpoint,
enable_fp8_training=enable_fp8_training,
)
# Store other configs
self.use_gradient_checkpointing = use_gradient_checkpointing
self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload
self.extra_inputs = extra_inputs.split(",") if extra_inputs is not None else []
def forward_preprocess(self, data):
# CFG-sensitive parameters
inputs_posi = {"prompt": data["prompt"]}
inputs_nega = {"negative_prompt": ""}
# CFG-unsensitive parameters
inputs_shared = {
# Assume you are using this pipeline for inference,
# please fill in the input parameters.
"input_image": data["image"],
"height": data["image"].size[1],
"width": data["image"].size[0],
# Please do not modify the following parameters
# unless you clearly know what this will cause.
"cfg_scale": 1,
"rand_device": self.pipe.device,
"use_gradient_checkpointing": self.use_gradient_checkpointing,
"use_gradient_checkpointing_offload": self.use_gradient_checkpointing_offload,
"edit_image_auto_resize": True,
}
# Extra inputs
controlnet_input, blockwise_controlnet_input = {}, {}
for extra_input in self.extra_inputs:
if extra_input.startswith("blockwise_controlnet_"):
blockwise_controlnet_input[extra_input.replace("blockwise_controlnet_", "")] = data[extra_input]
elif extra_input.startswith("controlnet_"):
controlnet_input[extra_input.replace("controlnet_", "")] = data[extra_input]
else:
inputs_shared[extra_input] = data[extra_input]
if len(controlnet_input) > 0:
inputs_shared["controlnet_inputs"] = [ControlNetInput(**controlnet_input)]
if len(blockwise_controlnet_input) > 0:
inputs_shared["blockwise_controlnet_inputs"] = [ControlNetInput(**blockwise_controlnet_input)]
# Pipeline units will automatically process the input parameters.
for unit in self.pipe.units:
inputs_shared, inputs_posi, inputs_nega = self.pipe.unit_runner(unit, self.pipe, inputs_shared, inputs_posi, inputs_nega)
return {**inputs_shared, **inputs_posi}
def forward(self, data, inputs=None):
if inputs is None: inputs = self.forward_preprocess(data)
return inputs
if __name__ == "__main__":
parser = qwen_image_parser()
args = parser.parse_args()
dataset = UnifiedDataset(
base_path=args.dataset_base_path,
metadata_path=args.dataset_metadata_path,
repeat=1, # Set repeat = 1
data_file_keys=args.data_file_keys.split(","),
main_data_operator=UnifiedDataset.default_image_operator(
base_path=args.dataset_base_path,
max_pixels=args.max_pixels,
height=args.height,
width=args.width,
height_division_factor=16,
width_division_factor=16,
)
)
model = QwenImageTrainingModule(
model_paths=args.model_paths,
model_id_with_origin_paths=args.model_id_with_origin_paths,
tokenizer_path=args.tokenizer_path,
processor_path=args.processor_path,
trainable_models=args.trainable_models,
lora_base_model=args.lora_base_model,
lora_target_modules=args.lora_target_modules,
lora_rank=args.lora_rank,
lora_checkpoint=args.lora_checkpoint,
use_gradient_checkpointing=args.use_gradient_checkpointing,
use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload,
extra_inputs=args.extra_inputs,
enable_fp8_training=args.enable_fp8_training,
)
model_logger = ModelLogger(args.output_path, remove_prefix_in_ckpt=args.remove_prefix_in_ckpt)
launch_data_process_task(
dataset, model, model_logger,
num_workers=args.dataset_num_workers,
)

View File

@@ -26,7 +26,7 @@ class WanTrainingModule(DiffusionTrainingModule):
# Training mode # Training mode
self.switch_pipe_to_training_mode( self.switch_pipe_to_training_mode(
self, self.pipe, trainable_models, self.pipe, trainable_models,
lora_base_model, lora_target_modules, lora_rank, lora_checkpoint=lora_checkpoint, lora_base_model, lora_target_modules, lora_rank, lora_checkpoint=lora_checkpoint,
enable_fp8_training=False, enable_fp8_training=False,
) )