diff --git a/diffsynth/pipelines/qwen_image.py b/diffsynth/pipelines/qwen_image.py
index 383d9f5..18183e5 100644
--- a/diffsynth/pipelines/qwen_image.py
+++ b/diffsynth/pipelines/qwen_image.py
@@ -174,9 +174,12 @@ class QwenImagePipeline(BasePipeline):
             computation_dtype=self.torch_dtype,
             computation_device="cuda",
         )
-        enable_vram_management(self.text_encoder, module_map=module_map, module_config=model_config)
-        enable_vram_management(self.dit, module_map=module_map, module_config=model_config)
-        enable_vram_management(self.vae, module_map=module_map, module_config=model_config)
+        if self.text_encoder is not None:
+            enable_vram_management(self.text_encoder, module_map=module_map, module_config=model_config)
+        if self.dit is not None:
+            enable_vram_management(self.dit, module_map=module_map, module_config=model_config)
+        if self.vae is not None:
+            enable_vram_management(self.vae, module_map=module_map, module_config=model_config)
 
 
     def enable_vram_management(self, num_persistent_param_in_dit=None, vram_limit=None, vram_buffer=0.5, enable_dit_fp8_computation=False):
diff --git a/diffsynth/trainers/unified_dataset.py b/diffsynth/trainers/unified_dataset.py
index fc40eaa..c764ebd 100644
--- a/diffsynth/trainers/unified_dataset.py
+++ b/diffsynth/trainers/unified_dataset.py
@@ -214,7 +214,7 @@ class LoadTorchPickle(DataProcessingOperator):
         self.map_location = map_location
 
     def __call__(self, data):
-        return torch.load(data, map_location=self.map_location)
+        return torch.load(data, map_location=self.map_location, weights_only=False)
 
 
 
@@ -306,7 +306,7 @@ class UnifiedDataset(torch.utils.data.Dataset):
 
     def __getitem__(self, data_id):
         if self.load_from_cache:
-            data = self.cached_data[data_id % len(self.data)].copy()
+            data = self.cached_data[data_id % len(self.cached_data)]
             data = self.cached_data_operator(data)
         else:
             data = self.data[data_id % len(self.data)].copy()
diff --git a/diffsynth/trainers/utils.py b/diffsynth/trainers/utils.py
index 5f02117..f0577a2 100644
--- a/diffsynth/trainers/utils.py
+++ b/diffsynth/trainers/utils.py
@@ -417,6 +417,13 @@ class DiffusionTrainingModule(torch.nn.Module):
             state_dict_[name] = param
         state_dict = state_dict_
         return state_dict
+
+
+    def transfer_data_to_device(self, data, device):
+        for key in data:
+            if isinstance(data[key], torch.Tensor):
+                data[key] = data[key].to(device)
+        return data
 
 
 
@@ -484,7 +491,10 @@ def launch_training_task(
     for data in tqdm(dataloader):
         with accelerator.accumulate(model):
             optimizer.zero_grad()
-            loss = model(data)
+            if dataset.load_from_cache:
+                loss = model({}, inputs=data)
+            else:
+                loss = model(data)
             accelerator.backward(loss)
             optimizer.step()
             model_logger.on_step_end(accelerator, model, save_steps)
@@ -494,16 +504,24 @@ def launch_training_task(
     model_logger.on_training_end(accelerator, model, save_steps)
 
 
-def launch_data_process_task(model: DiffusionTrainingModule, dataset, output_path="./models"):
-    dataloader = torch.utils.data.DataLoader(dataset, shuffle=False, collate_fn=lambda x: x[0])
+def launch_data_process_task(
+    dataset: torch.utils.data.Dataset,
+    model: DiffusionTrainingModule,
+    model_logger: ModelLogger,
+    num_workers: int = 8,
+):
+    dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, collate_fn=lambda x: x[0], num_workers=num_workers)
     accelerator = Accelerator()
     model, dataloader = accelerator.prepare(model, dataloader)
-    os.makedirs(os.path.join(output_path, "data_cache"), exist_ok=True)
-    for data_id, data in enumerate(tqdm(dataloader)):
-        with torch.no_grad():
-            inputs = model.forward_preprocess(data)
-            inputs = {key: inputs[key] for key in model.model_input_keys if key in inputs}
-            torch.save(inputs, os.path.join(output_path, "data_cache", f"{data_id}.pth"))
+
+    for data_id, data in enumerate(tqdm(dataloader)):
+        with accelerator.accumulate(model):
+            with torch.no_grad():
+                folder = os.path.join(model_logger.output_path, str(accelerator.process_index))
+                os.makedirs(folder, exist_ok=True)
+                save_path = os.path.join(folder, f"{data_id}.pth")
+                data = model(data)
+                torch.save(data, save_path)
 
 
 
diff --git a/examples/qwen_image/model_training/lora/Qwen-Image-Splited.sh b/examples/qwen_image/model_training/lora/Qwen-Image-Splited.sh
new file mode 100644
index 0000000..b456ca1
--- /dev/null
+++ b/examples/qwen_image/model_training/lora/Qwen-Image-Splited.sh
@@ -0,0 +1,25 @@
+accelerate launch examples/qwen_image/model_training/train_data_process.py \
+  --dataset_base_path data/example_image_dataset \
+  --dataset_metadata_path data/example_image_dataset/metadata.csv \
+  --max_pixels 1048576 \
+  --model_id_with_origin_paths "Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors" \
+  --output_path "./models/train/Qwen-Image_lora_cache" \
+  --use_gradient_checkpointing \
+  --dataset_num_workers 8
+
+accelerate launch examples/qwen_image/model_training/train.py \
+  --dataset_base_path models/train/Qwen-Image_lora_cache \
+  --max_pixels 1048576 \
+  --dataset_repeat 50 \
+  --model_id_with_origin_paths "Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors" \
+  --learning_rate 1e-4 \
+  --num_epochs 5 \
+  --remove_prefix_in_ckpt "pipe.dit." \
+  --output_path "./models/train/Qwen-Image_lora" \
+  --lora_base_model "dit" \
+  --lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1" \
+  --lora_rank 32 \
+  --use_gradient_checkpointing \
+  --dataset_num_workers 8 \
+  --find_unused_parameters \
+  --enable_fp8_training
diff --git a/examples/qwen_image/model_training/train.py b/examples/qwen_image/model_training/train.py
index 5faa763..b89e679 100644
--- a/examples/qwen_image/model_training/train.py
+++ b/examples/qwen_image/model_training/train.py
@@ -111,6 +111,7 @@ class QwenImageTrainingModule(DiffusionTrainingModule):
 
     def forward(self, data, inputs=None):
         if inputs is None: inputs = self.forward_preprocess(data)
+        else: inputs = self.transfer_data_to_device(inputs, self.pipe.device)
         models = {name: getattr(self.pipe, name) for name in self.pipe.in_iteration_models}
         loss = self.pipe.training_loss(**models, **inputs)
         return loss
diff --git a/examples/qwen_image/model_training/train_data_process.py b/examples/qwen_image/model_training/train_data_process.py
new file mode 100644
index 0000000..0f4f4fb
--- /dev/null
+++ b/examples/qwen_image/model_training/train_data_process.py
@@ -0,0 +1,154 @@
+import torch, os, json
+from diffsynth import load_state_dict
+from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
+from diffsynth.pipelines.flux_image_new import ControlNetInput
+from diffsynth.trainers.utils import DiffusionTrainingModule, ModelLogger, launch_data_process_task, qwen_image_parser
+from diffsynth.trainers.unified_dataset import UnifiedDataset
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+
+
+class QwenImageTrainingModule(DiffusionTrainingModule):
+    def __init__(
+        self,
+        model_paths=None, model_id_with_origin_paths=None,
+        tokenizer_path=None, processor_path=None,
+        trainable_models=None,
+        lora_base_model=None, lora_target_modules="", lora_rank=32, lora_checkpoint=None,
+        use_gradient_checkpointing=True,
+        use_gradient_checkpointing_offload=False,
+        extra_inputs=None,
+        enable_fp8_training=False,
+    ):
+        super().__init__()
+        # Load models
+        offload_dtype = torch.float8_e4m3fn if enable_fp8_training else None
+        model_configs = []
+        if model_paths is not None:
+            model_paths = json.loads(model_paths)
+            model_configs += [ModelConfig(path=path, offload_dtype=offload_dtype) for path in model_paths]
+        if model_id_with_origin_paths is not None:
+            model_id_with_origin_paths = model_id_with_origin_paths.split(",")
+            model_configs += [ModelConfig(model_id=i.split(":")[0], origin_file_pattern=i.split(":")[1], offload_dtype=offload_dtype) for i in model_id_with_origin_paths]
+        tokenizer_config = ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/") if tokenizer_path is None else ModelConfig(tokenizer_path)
+        processor_config = ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="processor/") if processor_path is None else ModelConfig(processor_path)
+        self.pipe = QwenImagePipeline.from_pretrained(torch_dtype=torch.bfloat16, device="cpu", model_configs=model_configs, tokenizer_config=tokenizer_config, processor_config=processor_config)
+
+        # Enable FP8
+        if enable_fp8_training:
+            self.pipe._enable_fp8_lora_training(torch.float8_e4m3fn)
+
+        # Reset training scheduler (do it in each training step)
+        self.pipe.scheduler.set_timesteps(1000, training=True)
+
+        # Freeze untrainable models
+        self.pipe.freeze_except([] if trainable_models is None else trainable_models.split(","))
+
+        # Add LoRA to the base models
+        if lora_base_model is not None:
+            model = self.add_lora_to_model(
+                getattr(self.pipe, lora_base_model),
+                target_modules=lora_target_modules.split(","),
+                lora_rank=lora_rank,
+                upcast_dtype=self.pipe.torch_dtype,
+            )
+            if lora_checkpoint is not None:
+                state_dict = load_state_dict(lora_checkpoint)
+                state_dict = self.mapping_lora_state_dict(state_dict)
+                load_result = model.load_state_dict(state_dict, strict=False)
+                print(f"LoRA checkpoint loaded: {lora_checkpoint}, total {len(state_dict)} keys")
+                if len(load_result[1]) > 0:
+                    print(f"Warning, LoRA key mismatch! Unexpected keys in LoRA checkpoint: {load_result[1]}")
+            setattr(self.pipe, lora_base_model, model)
+
+        # Store other configs
+        self.use_gradient_checkpointing = use_gradient_checkpointing
+        self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload
+        self.extra_inputs = extra_inputs.split(",") if extra_inputs is not None else []
+
+
+    def forward_preprocess(self, data):
+        # CFG-sensitive parameters
+        inputs_posi = {"prompt": data["prompt"]}
+        inputs_nega = {"negative_prompt": ""}
+
+        # CFG-unsensitive parameters
+        inputs_shared = {
+            # Assume you are using this pipeline for inference,
+            # please fill in the input parameters.
+            "input_image": data["image"],
+            "height": data["image"].size[1],
+            "width": data["image"].size[0],
+            # Please do not modify the following parameters
+            # unless you clearly know what this will cause.
+            "cfg_scale": 1,
+            "rand_device": self.pipe.device,
+            "use_gradient_checkpointing": self.use_gradient_checkpointing,
+            "use_gradient_checkpointing_offload": self.use_gradient_checkpointing_offload,
+            "edit_image_auto_resize": True,
+        }
+
+        # Extra inputs
+        controlnet_input, blockwise_controlnet_input = {}, {}
+        for extra_input in self.extra_inputs:
+            if extra_input.startswith("blockwise_controlnet_"):
+                blockwise_controlnet_input[extra_input.replace("blockwise_controlnet_", "")] = data[extra_input]
+            elif extra_input.startswith("controlnet_"):
+                controlnet_input[extra_input.replace("controlnet_", "")] = data[extra_input]
+            else:
+                inputs_shared[extra_input] = data[extra_input]
+        if len(controlnet_input) > 0:
+            inputs_shared["controlnet_inputs"] = [ControlNetInput(**controlnet_input)]
+        if len(blockwise_controlnet_input) > 0:
+            inputs_shared["blockwise_controlnet_inputs"] = [ControlNetInput(**blockwise_controlnet_input)]
+
+        # Pipeline units will automatically process the input parameters.
+        for unit in self.pipe.units:
+            inputs_shared, inputs_posi, inputs_nega = self.pipe.unit_runner(unit, self.pipe, inputs_shared, inputs_posi, inputs_nega)
+        return {**inputs_shared, **inputs_posi}
+
+
+    def forward(self, data, inputs=None):
+        if inputs is None: inputs = self.forward_preprocess(data)
+        return inputs
+
+
+
+if __name__ == "__main__":
+    parser = qwen_image_parser()
+    args = parser.parse_args()
+    dataset = UnifiedDataset(
+        base_path=args.dataset_base_path,
+        metadata_path=args.dataset_metadata_path,
+        repeat=1, # Set repeat = 1
+        data_file_keys=args.data_file_keys.split(","),
+        main_data_operator=UnifiedDataset.default_image_operator(
+            base_path=args.dataset_base_path,
+            max_pixels=args.max_pixels,
+            height=args.height,
+            width=args.width,
+            height_division_factor=16,
+            width_division_factor=16,
+        )
+    )
+    model = QwenImageTrainingModule(
+        model_paths=args.model_paths,
+        model_id_with_origin_paths=args.model_id_with_origin_paths,
+        tokenizer_path=args.tokenizer_path,
+        processor_path=args.processor_path,
+        trainable_models=args.trainable_models,
+        lora_base_model=args.lora_base_model,
+        lora_target_modules=args.lora_target_modules,
+        lora_rank=args.lora_rank,
+        lora_checkpoint=args.lora_checkpoint,
+        use_gradient_checkpointing=args.use_gradient_checkpointing,
+        use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload,
+        extra_inputs=args.extra_inputs,
+        enable_fp8_training=args.enable_fp8_training,
+    )
+    model_logger = ModelLogger(args.output_path, remove_prefix_in_ckpt=args.remove_prefix_in_ckpt)
+    launch_data_process_task(
+        dataset, model, model_logger,
+        num_workers=args.dataset_num_workers,
+    )
diff --git a/test.py b/test.py
new file mode 100644
index 0000000..8f14de3
--- /dev/null
+++ b/test.py
@@ -0,0 +1,6 @@
+import torch
+
+
+data = torch.load("models/train/Qwen-Image_lora_cache/0/0.pth", map_location="cpu", weights_only=False)
+for i in data:
+    print(i)
\ No newline at end of file