diff --git a/diffsynth/diffusion/flow_match.py b/diffsynth/diffusion/flow_match.py index 891ae0e..fd3cb76 100644 --- a/diffsynth/diffusion/flow_match.py +++ b/diffsynth/diffusion/flow_match.py @@ -89,7 +89,7 @@ class FlowMatchScheduler(): return float(mu) @staticmethod - def set_timesteps_flux2(num_inference_steps=100, denoising_strength=1.0, dynamic_shift_len=None): + def set_timesteps_flux2(num_inference_steps=100, denoising_strength=1.0, dynamic_shift_len=1024//16*1024//16): sigma_min = 1 / num_inference_steps sigma_max = 1.0 num_train_timesteps = 1000 diff --git a/diffsynth/diffusion/training_module.py b/diffsynth/diffusion/training_module.py index 034336d..e3b3329 100644 --- a/diffsynth/diffusion/training_module.py +++ b/diffsynth/diffusion/training_module.py @@ -29,6 +29,8 @@ class DiffusionTrainingModule(torch.nn.Module): def add_lora_to_model(self, model, target_modules, lora_rank, lora_alpha=None, upcast_dtype=None): if lora_alpha is None: lora_alpha = lora_rank + if isinstance(target_modules, list) and len(target_modules) == 1: + target_modules = target_modules[0] lora_config = LoraConfig(r=lora_rank, lora_alpha=lora_alpha, target_modules=target_modules) model = inject_adapter_in_model(lora_config, model) if upcast_dtype is not None: diff --git a/examples/flux2/model_training/lora/FLUX.2-dev.sh b/examples/flux2/model_training/lora/FLUX.2-dev.sh index fa5547d..4b1e74b 100644 --- a/examples/flux2/model_training/lora/FLUX.2-dev.sh +++ b/examples/flux2/model_training/lora/FLUX.2-dev.sh @@ -1,4 +1,4 @@ -accelerate launch train.py \ +accelerate launch examples/flux2/model_training/train.py \ --dataset_base_path data/example_image_dataset \ --dataset_metadata_path data/example_image_dataset/metadata.csv \ --max_pixels 1048576 \ @@ -9,13 +9,13 @@ accelerate launch train.py \ --remove_prefix_in_ckpt "pipe.dit." \ --output_path "./models/train/FLUX.2-dev-LoRA-splited-cache" \ --lora_base_model "dit" \ - --lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_qkv_mlp_proj" \ + --lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_qkv_mlp_proj,to_out.0,to_add_out,linear_in,linear_out,single_transformer_blocks.0.attn.to_out,single_transformer_blocks.1.attn.to_out,single_transformer_blocks.2.attn.to_out,single_transformer_blocks.3.attn.to_out,single_transformer_blocks.4.attn.to_out,single_transformer_blocks.5.attn.to_out,single_transformer_blocks.6.attn.to_out,single_transformer_blocks.7.attn.to_out,single_transformer_blocks.8.attn.to_out,single_transformer_blocks.9.attn.to_out,single_transformer_blocks.10.attn.to_out,single_transformer_blocks.11.attn.to_out,single_transformer_blocks.12.attn.to_out,single_transformer_blocks.13.attn.to_out,single_transformer_blocks.14.attn.to_out,single_transformer_blocks.15.attn.to_out,single_transformer_blocks.16.attn.to_out,single_transformer_blocks.17.attn.to_out,single_transformer_blocks.18.attn.to_out,single_transformer_blocks.19.attn.to_out,single_transformer_blocks.20.attn.to_out,single_transformer_blocks.21.attn.to_out,single_transformer_blocks.22.attn.to_out,single_transformer_blocks.23.attn.to_out,single_transformer_blocks.24.attn.to_out,single_transformer_blocks.25.attn.to_out,single_transformer_blocks.26.attn.to_out,single_transformer_blocks.27.attn.to_out,single_transformer_blocks.28.attn.to_out,single_transformer_blocks.29.attn.to_out,single_transformer_blocks.30.attn.to_out,single_transformer_blocks.31.attn.to_out,single_transformer_blocks.32.attn.to_out,single_transformer_blocks.33.attn.to_out,single_transformer_blocks.34.attn.to_out,single_transformer_blocks.35.attn.to_out,single_transformer_blocks.36.attn.to_out,single_transformer_blocks.37.attn.to_out,single_transformer_blocks.38.attn.to_out,single_transformer_blocks.39.attn.to_out,single_transformer_blocks.40.attn.to_out,single_transformer_blocks.41.attn.to_out,single_transformer_blocks.42.attn.to_out,single_transformer_blocks.43.attn.to_out,single_transformer_blocks.44.attn.to_out,single_transformer_blocks.45.attn.to_out,single_transformer_blocks.46.attn.to_out,single_transformer_blocks.47.attn.to_out" \ --lora_rank 32 \ --use_gradient_checkpointing \ --dataset_num_workers 8 \ --task "sft:data_process" -accelerate launch train.py \ +accelerate launch examples/flux2/model_training/train.py \ --dataset_base_path "./models/train/FLUX.2-dev-LoRA-splited-cache" \ --max_pixels 1048576 \ --dataset_repeat 50 \ @@ -25,7 +25,7 @@ accelerate launch train.py \ --remove_prefix_in_ckpt "pipe.dit." \ --output_path "./models/train/FLUX.2-dev-LoRA-splited" \ --lora_base_model "dit" \ - --lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_qkv_mlp_proj" \ + --lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_qkv_mlp_proj,to_out.0,to_add_out,linear_in,linear_out,single_transformer_blocks.0.attn.to_out,single_transformer_blocks.1.attn.to_out,single_transformer_blocks.2.attn.to_out,single_transformer_blocks.3.attn.to_out,single_transformer_blocks.4.attn.to_out,single_transformer_blocks.5.attn.to_out,single_transformer_blocks.6.attn.to_out,single_transformer_blocks.7.attn.to_out,single_transformer_blocks.8.attn.to_out,single_transformer_blocks.9.attn.to_out,single_transformer_blocks.10.attn.to_out,single_transformer_blocks.11.attn.to_out,single_transformer_blocks.12.attn.to_out,single_transformer_blocks.13.attn.to_out,single_transformer_blocks.14.attn.to_out,single_transformer_blocks.15.attn.to_out,single_transformer_blocks.16.attn.to_out,single_transformer_blocks.17.attn.to_out,single_transformer_blocks.18.attn.to_out,single_transformer_blocks.19.attn.to_out,single_transformer_blocks.20.attn.to_out,single_transformer_blocks.21.attn.to_out,single_transformer_blocks.22.attn.to_out,single_transformer_blocks.23.attn.to_out,single_transformer_blocks.24.attn.to_out,single_transformer_blocks.25.attn.to_out,single_transformer_blocks.26.attn.to_out,single_transformer_blocks.27.attn.to_out,single_transformer_blocks.28.attn.to_out,single_transformer_blocks.29.attn.to_out,single_transformer_blocks.30.attn.to_out,single_transformer_blocks.31.attn.to_out,single_transformer_blocks.32.attn.to_out,single_transformer_blocks.33.attn.to_out,single_transformer_blocks.34.attn.to_out,single_transformer_blocks.35.attn.to_out,single_transformer_blocks.36.attn.to_out,single_transformer_blocks.37.attn.to_out,single_transformer_blocks.38.attn.to_out,single_transformer_blocks.39.attn.to_out,single_transformer_blocks.40.attn.to_out,single_transformer_blocks.41.attn.to_out,single_transformer_blocks.42.attn.to_out,single_transformer_blocks.43.attn.to_out,single_transformer_blocks.44.attn.to_out,single_transformer_blocks.45.attn.to_out,single_transformer_blocks.46.attn.to_out,single_transformer_blocks.47.attn.to_out" \ --lora_rank 32 \ --use_gradient_checkpointing \ --dataset_num_workers 8 \ diff --git a/examples/flux2/model_training/train.py b/examples/flux2/model_training/train.py index aa61482..30408a1 100644 --- a/examples/flux2/model_training/train.py +++ b/examples/flux2/model_training/train.py @@ -80,7 +80,7 @@ class Flux2ImageTrainingModule(DiffusionTrainingModule): return loss -def qwen_image_parser(): +def flux2_parser(): parser = argparse.ArgumentParser(description="Simple example of a training script.") parser = add_general_config(parser) parser = add_image_size_config(parser) @@ -89,7 +89,7 @@ def qwen_image_parser(): if __name__ == "__main__": - parser = qwen_image_parser() + parser = flux2_parser() args = parser.parse_args() accelerator = accelerate.Accelerator( gradient_accumulation_steps=args.gradient_accumulation_steps, diff --git a/examples/flux2/model_training/validate_lora/FLUX.2-dev.py b/examples/flux2/model_training/validate_lora/FLUX.2-dev.py index f1f628a..e67e2a7 100644 --- a/examples/flux2/model_training/validate_lora/FLUX.2-dev.py +++ b/examples/flux2/model_training/validate_lora/FLUX.2-dev.py @@ -23,6 +23,6 @@ pipe = Flux2ImagePipeline.from_pretrained( tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="tokenizer/"), ) pipe.load_lora(pipe.dit, "./models/train/FLUX.2-dev-LoRA-splited/epoch-4.safetensors") -prompt = "a dog is jumping" +prompt = "a dog" image = pipe(prompt, seed=0) image.save("image_FLUX.2-dev_lora.jpg")