From 260e32217f649d1dc8a267cc303e42b3a409a129 Mon Sep 17 00:00:00 2001 From: Artiprocher Date: Tue, 2 Sep 2025 13:14:08 +0800 Subject: [PATCH 1/3] unified dataset --- diffsynth/trainers/unified_dataset.py | 334 ++++++++++++++++++++ examples/flux/model_training/train.py | 18 +- examples/qwen_image/model_training/train.py | 18 +- examples/wanvideo/model_training/train.py | 21 +- 4 files changed, 385 insertions(+), 6 deletions(-) create mode 100644 diffsynth/trainers/unified_dataset.py diff --git a/diffsynth/trainers/unified_dataset.py b/diffsynth/trainers/unified_dataset.py new file mode 100644 index 0000000..fc40eaa --- /dev/null +++ b/diffsynth/trainers/unified_dataset.py @@ -0,0 +1,334 @@ +import torch, torchvision, imageio, os, json, pandas +import imageio.v3 as iio +from PIL import Image + + + +class DataProcessingPipeline: + def __init__(self, operators=None): + self.operators: list[DataProcessingOperator] = [] if operators is None else operators + + def __call__(self, data): + for operator in self.operators: + data = operator(data) + return data + + def __rshift__(self, pipe): + if isinstance(pipe, DataProcessingOperator): + pipe = DataProcessingPipeline([pipe]) + return DataProcessingPipeline(self.operators + pipe.operators) + + + +class DataProcessingOperator: + def __call__(self, data): + raise NotImplementedError("DataProcessingOperator cannot be called directly.") + + def __rshift__(self, pipe): + if isinstance(pipe, DataProcessingOperator): + pipe = DataProcessingPipeline([pipe]) + return DataProcessingPipeline([self]).__rshift__(pipe) + + + +class DataProcessingOperatorRaw(DataProcessingOperator): + def __call__(self, data): + return data + + + +class ToInt(DataProcessingOperator): + def __call__(self, data): + return int(data) + + + +class ToFloat(DataProcessingOperator): + def __call__(self, data): + return float(data) + + + +class ToStr(DataProcessingOperator): + def __init__(self, none_value=""): + self.none_value = none_value + + def __call__(self, data): + if data is None: data = self.none_value + return str(data) + + + +class LoadImage(DataProcessingOperator): + def __init__(self, convert_RGB=True): + self.convert_RGB = convert_RGB + + def __call__(self, data: str): + image = Image.open(data) + if self.convert_RGB: image = image.convert("RGB") + return image + + + +class ImageCropAndResize(DataProcessingOperator): + def __init__(self, height, width, max_pixels, height_division_factor, width_division_factor): + self.height = height + self.width = width + self.max_pixels = max_pixels + self.height_division_factor = height_division_factor + self.width_division_factor = width_division_factor + + def crop_and_resize(self, image, target_height, target_width): + width, height = image.size + scale = max(target_width / width, target_height / height) + image = torchvision.transforms.functional.resize( + image, + (round(height*scale), round(width*scale)), + interpolation=torchvision.transforms.InterpolationMode.BILINEAR + ) + image = torchvision.transforms.functional.center_crop(image, (target_height, target_width)) + return image + + def get_height_width(self, image): + if self.height is None or self.width is None: + width, height = image.size + if width * height > self.max_pixels: + scale = (width * height / self.max_pixels) ** 0.5 + height, width = int(height / scale), int(width / scale) + height = height // self.height_division_factor * self.height_division_factor + width = width // self.width_division_factor * self.width_division_factor + else: + height, width = self.height, self.width + return height, width + + + def __call__(self, data: Image.Image): + image = self.crop_and_resize(data, *self.get_height_width(data)) + return image + + + +class ToList(DataProcessingOperator): + def __call__(self, data): + return [data] + + + +class LoadVideo(DataProcessingOperator): + def __init__(self, num_frames=81, time_division_factor=4, time_division_remainder=1, frame_processor=lambda x: x): + self.num_frames = num_frames + self.time_division_factor = time_division_factor + self.time_division_remainder = time_division_remainder + # frame_processor is build in the video loader for high efficiency. + self.frame_processor = frame_processor + + def get_num_frames(self, reader): + num_frames = self.num_frames + if int(reader.count_frames()) < num_frames: + num_frames = int(reader.count_frames()) + while num_frames > 1 and num_frames % self.time_division_factor != self.time_division_remainder: + num_frames -= 1 + return num_frames + + def __call__(self, data: str): + reader = imageio.get_reader(data) + num_frames = self.get_num_frames(reader) + frames = [] + for frame_id in range(num_frames): + frame = reader.get_data(frame_id) + frame = Image.fromarray(frame) + frame = self.frame_processor(frame) + frames.append(frame) + reader.close() + return frames + + + +class SequencialProcess(DataProcessingOperator): + def __init__(self, operator=lambda x: x): + self.operator = operator + + def __call__(self, data): + return [self.operator(i) for i in data] + + + +class LoadGIF(DataProcessingOperator): + def __init__(self, num_frames=81, time_division_factor=4, time_division_remainder=1, frame_processor=lambda x: x): + self.num_frames = num_frames + self.time_division_factor = time_division_factor + self.time_division_remainder = time_division_remainder + # frame_processor is build in the video loader for high efficiency. + self.frame_processor = frame_processor + + def get_num_frames(self, path): + num_frames = self.num_frames + images = iio.imread(path, mode="RGB") + if len(images) < num_frames: + num_frames = len(images) + while num_frames > 1 and num_frames % self.time_division_factor != self.time_division_remainder: + num_frames -= 1 + return num_frames + + def __call__(self, data: str): + num_frames = self.get_num_frames(data) + frames = [] + images = iio.imread(data, mode="RGB") + for img in images: + frame = Image.fromarray(img) + frame = self.frame_processor(frame) + frames.append(frame) + if len(frames) >= num_frames: + break + return frames + + + +class RouteByExtensionName(DataProcessingOperator): + def __init__(self, operator_map): + self.operator_map = operator_map + + def __call__(self, data: str): + file_ext_name = data.split(".")[-1].lower() + for ext_names, operator in self.operator_map: + if ext_names is None or file_ext_name in ext_names: + return operator(data) + raise ValueError(f"Unsupported file: {data}") + + + +class RouteByType(DataProcessingOperator): + def __init__(self, operator_map): + self.operator_map = operator_map + + def __call__(self, data): + for dtype, operator in self.operator_map: + if dtype is None or isinstance(data, dtype): + return operator(data) + raise ValueError(f"Unsupported data: {data}") + + + +class LoadTorchPickle(DataProcessingOperator): + def __init__(self, map_location="cpu"): + self.map_location = map_location + + def __call__(self, data): + return torch.load(data, map_location=self.map_location) + + + +class ToAbsolutePath(DataProcessingOperator): + def __init__(self, base_path=""): + self.base_path = base_path + + def __call__(self, data): + return os.path.join(self.base_path, data) + + + +class UnifiedDataset(torch.utils.data.Dataset): + def __init__( + self, + base_path=None, metadata_path=None, + repeat=1, + data_file_keys=tuple(), + main_data_operator=lambda x: x, + special_operator_map=None, + ): + self.base_path = base_path + self.metadata_path = metadata_path + self.repeat = repeat + self.data_file_keys = data_file_keys + self.main_data_operator = main_data_operator + self.cached_data_operator = LoadTorchPickle() + self.special_operator_map = {} if special_operator_map is None else special_operator_map + self.data = [] + self.cached_data = [] + self.load_from_cache = metadata_path is None + self.load_metadata(metadata_path) + + @staticmethod + def default_image_operator( + base_path="", + max_pixels=1920*1080, height=None, width=None, + height_division_factor=16, width_division_factor=16, + ): + return RouteByType(operator_map=[ + (str, ToAbsolutePath(base_path) >> LoadImage() >> ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor)), + (list, SequencialProcess(ToAbsolutePath(base_path) >> LoadImage() >> ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor))), + ]) + + @staticmethod + def default_video_operator( + base_path="", + max_pixels=1920*1080, height=None, width=None, + height_division_factor=16, width_division_factor=16, + num_frames=81, time_division_factor=4, time_division_remainder=1, + ): + return RouteByType(operator_map=[ + (str, ToAbsolutePath(base_path) >> RouteByExtensionName(operator_map=[ + (("jpg", "jpeg", "png", "webp"), LoadImage() >> ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor) >> ToList()), + (("gif",), LoadGIF(num_frames, time_division_factor, time_division_remainder) >> ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor)), + (("mp4", "avi", "mov", "wmv", "mkv", "flv", "webm"), LoadVideo( + num_frames, time_division_factor, time_division_remainder, + frame_processor=ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor), + )), + ])), + ]) + + def search_for_cached_data_files(self, path): + for file_name in os.listdir(path): + subpath = os.path.join(path, file_name) + if os.path.isdir(subpath): + self.search_for_cached_data_files(subpath) + elif subpath.endswith(".pth"): + self.cached_data.append(subpath) + + def load_metadata(self, metadata_path): + if metadata_path is None: + print("No metadata_path. Searching for cached data files.") + self.search_for_cached_data_files(self.base_path) + print(f"{len(self.cached_data)} cached data files found.") + elif metadata_path.endswith(".json"): + with open(metadata_path, "r") as f: + metadata = json.load(f) + self.data = metadata + elif metadata_path.endswith(".jsonl"): + metadata = [] + with open(metadata_path, 'r') as f: + for line in f: + metadata.append(json.loads(line.strip())) + self.data = metadata + else: + metadata = pandas.read_csv(metadata_path) + self.data = [metadata.iloc[i].to_dict() for i in range(len(metadata))] + + def __getitem__(self, data_id): + if self.load_from_cache: + data = self.cached_data[data_id % len(self.data)].copy() + data = self.cached_data_operator(data) + else: + data = self.data[data_id % len(self.data)].copy() + for key in self.data_file_keys: + if key in data: + if key in self.special_operator_map: + data[key] = self.special_operator_map[key] + elif key in self.data_file_keys: + data[key] = self.main_data_operator(data[key]) + return data + + def __len__(self): + if self.load_from_cache: + return len(self.cached_data) * self.repeat + else: + return len(self.data) * self.repeat + + def check_data_equal(self, data1, data2): + # Debug only + if len(data1) != len(data2): + return False + for k in data1: + if data1[k] != data2[k]: + return False + return True diff --git a/examples/flux/model_training/train.py b/examples/flux/model_training/train.py index 46eac56..a0db5c4 100644 --- a/examples/flux/model_training/train.py +++ b/examples/flux/model_training/train.py @@ -1,8 +1,9 @@ import torch, os, json from diffsynth import load_state_dict from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig, ControlNetInput -from diffsynth.trainers.utils import DiffusionTrainingModule, ImageDataset, ModelLogger, launch_training_task, flux_parser +from diffsynth.trainers.utils import DiffusionTrainingModule, ModelLogger, launch_training_task, flux_parser from diffsynth.models.lora import FluxLoRAConverter +from diffsynth.trainers.unified_dataset import UnifiedDataset os.environ["TOKENIZERS_PARALLELISM"] = "false" @@ -106,7 +107,20 @@ class FluxTrainingModule(DiffusionTrainingModule): if __name__ == "__main__": parser = flux_parser() args = parser.parse_args() - dataset = ImageDataset(args=args) + dataset = UnifiedDataset( + base_path=args.dataset_base_path, + metadata_path=args.dataset_metadata_path, + repeat=args.dataset_repeat, + data_file_keys=args.data_file_keys.split(","), + main_data_operator=UnifiedDataset.default_image_operator( + base_path=args.dataset_base_path, + max_pixels=args.max_pixels, + height=args.height, + width=args.width, + height_division_factor=16, + width_division_factor=16, + ) + ) model = FluxTrainingModule( model_paths=args.model_paths, model_id_with_origin_paths=args.model_id_with_origin_paths, diff --git a/examples/qwen_image/model_training/train.py b/examples/qwen_image/model_training/train.py index ee6752d..5faa763 100644 --- a/examples/qwen_image/model_training/train.py +++ b/examples/qwen_image/model_training/train.py @@ -2,7 +2,8 @@ import torch, os, json from diffsynth import load_state_dict from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig from diffsynth.pipelines.flux_image_new import ControlNetInput -from diffsynth.trainers.utils import DiffusionTrainingModule, ImageDataset, ModelLogger, launch_training_task, qwen_image_parser +from diffsynth.trainers.utils import DiffusionTrainingModule, ModelLogger, launch_training_task, qwen_image_parser +from diffsynth.trainers.unified_dataset import UnifiedDataset os.environ["TOKENIZERS_PARALLELISM"] = "false" @@ -119,7 +120,20 @@ class QwenImageTrainingModule(DiffusionTrainingModule): if __name__ == "__main__": parser = qwen_image_parser() args = parser.parse_args() - dataset = ImageDataset(args=args) + dataset = UnifiedDataset( + base_path=args.dataset_base_path, + metadata_path=args.dataset_metadata_path, + repeat=args.dataset_repeat, + data_file_keys=args.data_file_keys.split(","), + main_data_operator=UnifiedDataset.default_image_operator( + base_path=args.dataset_base_path, + max_pixels=args.max_pixels, + height=args.height, + width=args.width, + height_division_factor=16, + width_division_factor=16, + ) + ) model = QwenImageTrainingModule( model_paths=args.model_paths, model_id_with_origin_paths=args.model_id_with_origin_paths, diff --git a/examples/wanvideo/model_training/train.py b/examples/wanvideo/model_training/train.py index f2f437e..7df70da 100644 --- a/examples/wanvideo/model_training/train.py +++ b/examples/wanvideo/model_training/train.py @@ -1,7 +1,8 @@ import torch, os, json from diffsynth import load_state_dict from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig -from diffsynth.trainers.utils import DiffusionTrainingModule, VideoDataset, ModelLogger, launch_training_task, wan_parser +from diffsynth.trainers.utils import DiffusionTrainingModule, ModelLogger, launch_training_task, wan_parser +from diffsynth.trainers.unified_dataset import UnifiedDataset os.environ["TOKENIZERS_PARALLELISM"] = "false" @@ -112,7 +113,23 @@ class WanTrainingModule(DiffusionTrainingModule): if __name__ == "__main__": parser = wan_parser() args = parser.parse_args() - dataset = VideoDataset(args=args) + dataset = UnifiedDataset( + base_path=args.dataset_base_path, + metadata_path=args.dataset_metadata_path, + repeat=args.dataset_repeat, + data_file_keys=args.data_file_keys.split(","), + main_data_operator=UnifiedDataset.default_video_operator( + base_path=args.dataset_base_path, + max_pixels=args.max_pixels, + height=args.height, + width=args.width, + height_division_factor=16, + width_division_factor=16, + num_frames=args.num_frames, + time_division_factor=4, + time_division_remainder=1, + ), + ) model = WanTrainingModule( model_paths=args.model_paths, model_id_with_origin_paths=args.model_id_with_origin_paths, From b6da77e46880e9b6085cadb8e266059fae7dc9ae Mon Sep 17 00:00:00 2001 From: Artiprocher Date: Tue, 2 Sep 2025 16:44:14 +0800 Subject: [PATCH 2/3] qwen-image splited training --- diffsynth/pipelines/qwen_image.py | 9 +- diffsynth/trainers/unified_dataset.py | 4 +- diffsynth/trainers/utils.py | 36 +++- .../model_training/lora/Qwen-Image-Splited.sh | 25 +++ examples/qwen_image/model_training/train.py | 1 + .../model_training/train_data_process.py | 154 ++++++++++++++++++ test.py | 6 + 7 files changed, 221 insertions(+), 14 deletions(-) create mode 100644 examples/qwen_image/model_training/lora/Qwen-Image-Splited.sh create mode 100644 examples/qwen_image/model_training/train_data_process.py create mode 100644 test.py diff --git a/diffsynth/pipelines/qwen_image.py b/diffsynth/pipelines/qwen_image.py index 383d9f5..18183e5 100644 --- a/diffsynth/pipelines/qwen_image.py +++ b/diffsynth/pipelines/qwen_image.py @@ -174,9 +174,12 @@ class QwenImagePipeline(BasePipeline): computation_dtype=self.torch_dtype, computation_device="cuda", ) - enable_vram_management(self.text_encoder, module_map=module_map, module_config=model_config) - enable_vram_management(self.dit, module_map=module_map, module_config=model_config) - enable_vram_management(self.vae, module_map=module_map, module_config=model_config) + if self.text_encoder is not None: + enable_vram_management(self.text_encoder, module_map=module_map, module_config=model_config) + if self.dit is not None: + enable_vram_management(self.dit, module_map=module_map, module_config=model_config) + if self.vae is not None: + enable_vram_management(self.vae, module_map=module_map, module_config=model_config) def enable_vram_management(self, num_persistent_param_in_dit=None, vram_limit=None, vram_buffer=0.5, enable_dit_fp8_computation=False): diff --git a/diffsynth/trainers/unified_dataset.py b/diffsynth/trainers/unified_dataset.py index fc40eaa..c764ebd 100644 --- a/diffsynth/trainers/unified_dataset.py +++ b/diffsynth/trainers/unified_dataset.py @@ -214,7 +214,7 @@ class LoadTorchPickle(DataProcessingOperator): self.map_location = map_location def __call__(self, data): - return torch.load(data, map_location=self.map_location) + return torch.load(data, map_location=self.map_location, weights_only=False) @@ -306,7 +306,7 @@ class UnifiedDataset(torch.utils.data.Dataset): def __getitem__(self, data_id): if self.load_from_cache: - data = self.cached_data[data_id % len(self.data)].copy() + data = self.cached_data[data_id % len(self.cached_data)] data = self.cached_data_operator(data) else: data = self.data[data_id % len(self.data)].copy() diff --git a/diffsynth/trainers/utils.py b/diffsynth/trainers/utils.py index 5f02117..f0577a2 100644 --- a/diffsynth/trainers/utils.py +++ b/diffsynth/trainers/utils.py @@ -417,6 +417,13 @@ class DiffusionTrainingModule(torch.nn.Module): state_dict_[name] = param state_dict = state_dict_ return state_dict + + + def transfer_data_to_device(self, data, device): + for key in data: + if isinstance(data[key], torch.Tensor): + data[key] = data[key].to(device) + return data @@ -484,7 +491,10 @@ def launch_training_task( for data in tqdm(dataloader): with accelerator.accumulate(model): optimizer.zero_grad() - loss = model(data) + if dataset.load_from_cache: + loss = model({}, inputs=data) + else: + loss = model(data) accelerator.backward(loss) optimizer.step() model_logger.on_step_end(accelerator, model, save_steps) @@ -494,16 +504,24 @@ def launch_training_task( model_logger.on_training_end(accelerator, model, save_steps) -def launch_data_process_task(model: DiffusionTrainingModule, dataset, output_path="./models"): - dataloader = torch.utils.data.DataLoader(dataset, shuffle=False, collate_fn=lambda x: x[0]) +def launch_data_process_task( + dataset: torch.utils.data.Dataset, + model: DiffusionTrainingModule, + model_logger: ModelLogger, + num_workers: int = 8, +): + dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, collate_fn=lambda x: x[0], num_workers=num_workers) accelerator = Accelerator() model, dataloader = accelerator.prepare(model, dataloader) - os.makedirs(os.path.join(output_path, "data_cache"), exist_ok=True) - for data_id, data in enumerate(tqdm(dataloader)): - with torch.no_grad(): - inputs = model.forward_preprocess(data) - inputs = {key: inputs[key] for key in model.model_input_keys if key in inputs} - torch.save(inputs, os.path.join(output_path, "data_cache", f"{data_id}.pth")) + + for data_id, data in tqdm(enumerate(dataloader)): + with accelerator.accumulate(model): + with torch.no_grad(): + folder = os.path.join(model_logger.output_path, str(accelerator.process_index)) + os.makedirs(folder, exist_ok=True) + save_path = os.path.join(model_logger.output_path, str(accelerator.process_index), f"{data_id}.pth") + data = model(data) + torch.save(data, save_path) diff --git a/examples/qwen_image/model_training/lora/Qwen-Image-Splited.sh b/examples/qwen_image/model_training/lora/Qwen-Image-Splited.sh new file mode 100644 index 0000000..b456ca1 --- /dev/null +++ b/examples/qwen_image/model_training/lora/Qwen-Image-Splited.sh @@ -0,0 +1,25 @@ +accelerate launch examples/qwen_image/model_training/train_data_process.py \ + --dataset_base_path data/example_image_dataset \ + --dataset_metadata_path data/example_image_dataset/metadata.csv \ + --max_pixels 1048576 \ + --model_id_with_origin_paths "Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors" \ + --output_path "./models/train/Qwen-Image_lora_cache" \ + --use_gradient_checkpointing \ + --dataset_num_workers 8 + +accelerate launch examples/qwen_image/model_training/train.py \ + --dataset_base_path models/train/Qwen-Image_lora_cache \ + --max_pixels 1048576 \ + --dataset_repeat 50 \ + --model_id_with_origin_paths "Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/Qwen-Image_lora" \ + --lora_base_model "dit" \ + --lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1" \ + --lora_rank 32 \ + --use_gradient_checkpointing \ + --dataset_num_workers 8 \ + --find_unused_parameters \ + --enable_fp8_training diff --git a/examples/qwen_image/model_training/train.py b/examples/qwen_image/model_training/train.py index 5faa763..b89e679 100644 --- a/examples/qwen_image/model_training/train.py +++ b/examples/qwen_image/model_training/train.py @@ -111,6 +111,7 @@ class QwenImageTrainingModule(DiffusionTrainingModule): def forward(self, data, inputs=None): if inputs is None: inputs = self.forward_preprocess(data) + else: inputs = self.transfer_data_to_device(inputs, self.pipe.device) models = {name: getattr(self.pipe, name) for name in self.pipe.in_iteration_models} loss = self.pipe.training_loss(**models, **inputs) return loss diff --git a/examples/qwen_image/model_training/train_data_process.py b/examples/qwen_image/model_training/train_data_process.py new file mode 100644 index 0000000..0f4f4fb --- /dev/null +++ b/examples/qwen_image/model_training/train_data_process.py @@ -0,0 +1,154 @@ +import torch, os, json +from diffsynth import load_state_dict +from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig +from diffsynth.pipelines.flux_image_new import ControlNetInput +from diffsynth.trainers.utils import DiffusionTrainingModule, ModelLogger, launch_data_process_task, qwen_image_parser +from diffsynth.trainers.unified_dataset import UnifiedDataset +os.environ["TOKENIZERS_PARALLELISM"] = "false" + + + +class QwenImageTrainingModule(DiffusionTrainingModule): + def __init__( + self, + model_paths=None, model_id_with_origin_paths=None, + tokenizer_path=None, processor_path=None, + trainable_models=None, + lora_base_model=None, lora_target_modules="", lora_rank=32, lora_checkpoint=None, + use_gradient_checkpointing=True, + use_gradient_checkpointing_offload=False, + extra_inputs=None, + enable_fp8_training=False, + ): + super().__init__() + # Load models + offload_dtype = torch.float8_e4m3fn if enable_fp8_training else None + model_configs = [] + if model_paths is not None: + model_paths = json.loads(model_paths) + model_configs += [ModelConfig(path=path, offload_dtype=offload_dtype) for path in model_paths] + if model_id_with_origin_paths is not None: + model_id_with_origin_paths = model_id_with_origin_paths.split(",") + model_configs += [ModelConfig(model_id=i.split(":")[0], origin_file_pattern=i.split(":")[1], offload_dtype=offload_dtype) for i in model_id_with_origin_paths] + + tokenizer_config = ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/") if tokenizer_path is None else ModelConfig(tokenizer_path) + processor_config = ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="processor/") if processor_path is None else ModelConfig(processor_path) + self.pipe = QwenImagePipeline.from_pretrained(torch_dtype=torch.bfloat16, device="cpu", model_configs=model_configs, tokenizer_config=tokenizer_config, processor_config=processor_config) + + # Enable FP8 + if enable_fp8_training: + self.pipe._enable_fp8_lora_training(torch.float8_e4m3fn) + + # Reset training scheduler (do it in each training step) + self.pipe.scheduler.set_timesteps(1000, training=True) + + # Freeze untrainable models + self.pipe.freeze_except([] if trainable_models is None else trainable_models.split(",")) + + # Add LoRA to the base models + if lora_base_model is not None: + model = self.add_lora_to_model( + getattr(self.pipe, lora_base_model), + target_modules=lora_target_modules.split(","), + lora_rank=lora_rank, + upcast_dtype=self.pipe.torch_dtype, + ) + if lora_checkpoint is not None: + state_dict = load_state_dict(lora_checkpoint) + state_dict = self.mapping_lora_state_dict(state_dict) + load_result = model.load_state_dict(state_dict, strict=False) + print(f"LoRA checkpoint loaded: {lora_checkpoint}, total {len(state_dict)} keys") + if len(load_result[1]) > 0: + print(f"Warning, LoRA key mismatch! Unexpected keys in LoRA checkpoint: {load_result[1]}") + setattr(self.pipe, lora_base_model, model) + + # Store other configs + self.use_gradient_checkpointing = use_gradient_checkpointing + self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload + self.extra_inputs = extra_inputs.split(",") if extra_inputs is not None else [] + + + def forward_preprocess(self, data): + # CFG-sensitive parameters + inputs_posi = {"prompt": data["prompt"]} + inputs_nega = {"negative_prompt": ""} + + # CFG-unsensitive parameters + inputs_shared = { + # Assume you are using this pipeline for inference, + # please fill in the input parameters. + "input_image": data["image"], + "height": data["image"].size[1], + "width": data["image"].size[0], + # Please do not modify the following parameters + # unless you clearly know what this will cause. + "cfg_scale": 1, + "rand_device": self.pipe.device, + "use_gradient_checkpointing": self.use_gradient_checkpointing, + "use_gradient_checkpointing_offload": self.use_gradient_checkpointing_offload, + "edit_image_auto_resize": True, + } + + # Extra inputs + controlnet_input, blockwise_controlnet_input = {}, {} + for extra_input in self.extra_inputs: + if extra_input.startswith("blockwise_controlnet_"): + blockwise_controlnet_input[extra_input.replace("blockwise_controlnet_", "")] = data[extra_input] + elif extra_input.startswith("controlnet_"): + controlnet_input[extra_input.replace("controlnet_", "")] = data[extra_input] + else: + inputs_shared[extra_input] = data[extra_input] + if len(controlnet_input) > 0: + inputs_shared["controlnet_inputs"] = [ControlNetInput(**controlnet_input)] + if len(blockwise_controlnet_input) > 0: + inputs_shared["blockwise_controlnet_inputs"] = [ControlNetInput(**blockwise_controlnet_input)] + + # Pipeline units will automatically process the input parameters. + for unit in self.pipe.units: + inputs_shared, inputs_posi, inputs_nega = self.pipe.unit_runner(unit, self.pipe, inputs_shared, inputs_posi, inputs_nega) + return {**inputs_shared, **inputs_posi} + + + def forward(self, data, inputs=None): + if inputs is None: inputs = self.forward_preprocess(data) + return inputs + + + +if __name__ == "__main__": + parser = qwen_image_parser() + args = parser.parse_args() + dataset = UnifiedDataset( + base_path=args.dataset_base_path, + metadata_path=args.dataset_metadata_path, + repeat=1, # Set repeat = 1 + data_file_keys=args.data_file_keys.split(","), + main_data_operator=UnifiedDataset.default_image_operator( + base_path=args.dataset_base_path, + max_pixels=args.max_pixels, + height=args.height, + width=args.width, + height_division_factor=16, + width_division_factor=16, + ) + ) + model = QwenImageTrainingModule( + model_paths=args.model_paths, + model_id_with_origin_paths=args.model_id_with_origin_paths, + tokenizer_path=args.tokenizer_path, + processor_path=args.processor_path, + trainable_models=args.trainable_models, + lora_base_model=args.lora_base_model, + lora_target_modules=args.lora_target_modules, + lora_rank=args.lora_rank, + lora_checkpoint=args.lora_checkpoint, + use_gradient_checkpointing=args.use_gradient_checkpointing, + use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload, + extra_inputs=args.extra_inputs, + enable_fp8_training=args.enable_fp8_training, + ) + model_logger = ModelLogger(args.output_path, remove_prefix_in_ckpt=args.remove_prefix_in_ckpt) + launch_data_process_task( + dataset, model, model_logger, + num_workers=args.dataset_num_workers, + ) diff --git a/test.py b/test.py new file mode 100644 index 0000000..8f14de3 --- /dev/null +++ b/test.py @@ -0,0 +1,6 @@ +import torch + + +data = torch.load("models/train/Qwen-Image_lora_cache/0/0.pth", map_location="cpu", weights_only=False) +for i in data: + print(i) \ No newline at end of file From 958ebf135278f3534582aa8ed8a34aaf0d283990 Mon Sep 17 00:00:00 2001 From: Artiprocher Date: Tue, 2 Sep 2025 16:44:36 +0800 Subject: [PATCH 3/3] remove testing script --- test.py | 6 ------ 1 file changed, 6 deletions(-) delete mode 100644 test.py diff --git a/test.py b/test.py deleted file mode 100644 index 8f14de3..0000000 --- a/test.py +++ /dev/null @@ -1,6 +0,0 @@ -import torch - - -data = torch.load("models/train/Qwen-Image_lora_cache/0/0.pth", map_location="cpu", weights_only=False) -for i in data: - print(i) \ No newline at end of file