diff --git a/diffsynth/pipelines/flux_image_new.py b/diffsynth/pipelines/flux_image_new.py
index 3bf971d..1568efb 100644
--- a/diffsynth/pipelines/flux_image_new.py
+++ b/diffsynth/pipelines/flux_image_new.py
@@ -102,6 +102,7 @@ class FluxImagePipeline(BasePipeline):
             FluxImageUnit_InputImageEmbedder(),
             FluxImageUnit_ImageIDs(),
             FluxImageUnit_EmbeddedGuidanceEmbedder(),
+            FluxImageUnit_Kontext(),
             FluxImageUnit_InfiniteYou(),
             FluxImageUnit_ControlNet(),
             FluxImageUnit_IPAdapter(),
@@ -211,6 +212,8 @@ class FluxImagePipeline(BasePipeline):
         multidiffusion_prompts=(),
         multidiffusion_masks=(),
         multidiffusion_scales=(),
+        # Kontext
+        kontext_images: Union[list[Image.Image], Image.Image] = None,
         # ControlNet
         controlnet_inputs: list[ControlNetInput] = None,
         # IP-Adapter
@@ -257,6 +260,7 @@ class FluxImagePipeline(BasePipeline):
             "seed": seed, "rand_device": rand_device, "sigma_shift": sigma_shift, "num_inference_steps": num_inference_steps,
             "multidiffusion_prompts": multidiffusion_prompts, "multidiffusion_masks": multidiffusion_masks, "multidiffusion_scales": multidiffusion_scales,
+            "kontext_images": kontext_images,
             "controlnet_inputs": controlnet_inputs,
             "ipadapter_images": ipadapter_images, "ipadapter_scale": ipadapter_scale,
             "eligen_entity_prompts": eligen_entity_prompts, "eligen_entity_masks": eligen_entity_masks, "eligen_enable_on_negative": eligen_enable_on_negative, "eligen_enable_inpaint": eligen_enable_inpaint,
@@ -378,6 +382,32 @@ class FluxImageUnit_EmbeddedGuidanceEmbedder(PipelineUnit):


+class FluxImageUnit_Kontext(PipelineUnit):
+    def __init__(self):
+        super().__init__(input_params=("kontext_images", "tiled", "tile_size", "tile_stride"))
+
+    def process(self, pipe: FluxImagePipeline, kontext_images, tiled, tile_size, tile_stride):
+        if kontext_images is None:
+            return {}
+        if not isinstance(kontext_images, list):
+            kontext_images = [kontext_images]
+
+        kontext_latents = []
+        kontext_image_ids = []
+        for kontext_image in kontext_images:
+            kontext_image = pipe.preprocess_image(kontext_image)
+            kontext_latent = pipe.vae_encoder(kontext_image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
+            image_ids = pipe.dit.prepare_image_ids(kontext_latent)
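+            # Mark these as reference-image tokens: the first positional-ID
+            # channel is 0 for the latents being denoised, so setting it to 1
+            # places the Kontext image tokens at distinct RoPE coordinates.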
+            image_ids[..., 0] = 1
+            kontext_image_ids.append(image_ids)
+            kontext_latent = pipe.dit.patchify(kontext_latent)
+            kontext_latents.append(kontext_latent)
+        kontext_latents = torch.concat(kontext_latents, dim=1)
+        kontext_image_ids = torch.concat(kontext_image_ids, dim=-2)
+        return {"kontext_latents": kontext_latents, "kontext_image_ids": kontext_image_ids}
+
+
 class FluxImageUnit_ControlNet(PipelineUnit):
     def __init__(self):
         super().__init__(
@@ -688,6 +718,8 @@ def model_fn_flux_image(
     guidance=None,
     text_ids=None,
     image_ids=None,
+    kontext_latents=None,
+    kontext_image_ids=None,
     controlnet_inputs=None,
     controlnet_conditionings=None,
     tiled=False,
@@ -787,6 +819,11 @@ def model_fn_flux_image(
     height, width = hidden_states.shape[-2:]
     hidden_states = dit.patchify(hidden_states)

+    # Kontext
+    if kontext_latents is not None:
+        image_ids = torch.concat([image_ids, kontext_image_ids], dim=-2)
+        hidden_states = torch.concat([hidden_states, kontext_latents], dim=1)
+
     # Step1x
     if step1x_reference_latents is not None:
         step1x_reference_image_ids = dit.prepare_image_ids(step1x_reference_latents)
@@ -827,7 +864,10 @@ def model_fn_flux_image(
         )
         # ControlNet
         if controlnet is not None and controlnet_conditionings is not None and controlnet_res_stack is not None:
-            hidden_states = hidden_states + controlnet_res_stack[block_id]
+            if kontext_latents is None:
+                hidden_states = hidden_states + controlnet_res_stack[block_id]
+            else:
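+                # The ControlNet residual spans only the tokens being denoised,
+                # so it is added to the leading slice of the sequence while the
+                # appended Kontext tokens at the tail are left unchanged.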
+                hidden_states[:, :-kontext_latents.shape[1]] = hidden_states[:, :-kontext_latents.shape[1]] + controlnet_res_stack[block_id]

     # Single Blocks
     hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)
@@ -846,7 +886,10 @@ def model_fn_flux_image(
         )
         # ControlNet
         if controlnet is not None and controlnet_conditionings is not None and controlnet_single_res_stack is not None:
-            hidden_states[:, prompt_emb.shape[1]:] = hidden_states[:, prompt_emb.shape[1]:] + controlnet_single_res_stack[block_id]
+            if kontext_latents is None:
+                hidden_states[:, prompt_emb.shape[1]:] = hidden_states[:, prompt_emb.shape[1]:] + controlnet_single_res_stack[block_id]
+            else:
+                hidden_states[:, prompt_emb.shape[1]:-kontext_latents.shape[1]] = hidden_states[:, prompt_emb.shape[1]:-kontext_latents.shape[1]] + controlnet_single_res_stack[block_id]
     hidden_states = hidden_states[:, prompt_emb.shape[1]:]

     if tea_cache is not None:
@@ -858,6 +901,10 @@ def model_fn_flux_image(
     # Step1x
     if step1x_reference_latents is not None:
         hidden_states = hidden_states[:, :hidden_states.shape[1] // 2]
+
+    # Kontext
+    if kontext_latents is not None:
+        hidden_states = hidden_states[:, :-kontext_latents.shape[1]]

     hidden_states = dit.unpatchify(hidden_states, height, width)
diff --git a/download.py b/download.py
deleted file mode 100644
index 1404010..0000000
--- a/download.py
+++ /dev/null
@@ -1,3 +0,0 @@
-# Model download
-from modelscope import snapshot_download
-model_dir = snapshot_download('black-forest-labs/FLUX.1-Kontext-dev', cache_dir="models", ignore_file_pattern="transformer/*")
\ No newline at end of file
diff --git a/examples/flux/model_inference/FLUX.1-Kontext-dev.py b/examples/flux/model_inference/FLUX.1-Kontext-dev.py
new file mode 100644
index 0000000..3d0e921
--- /dev/null
+++ b/examples/flux/model_inference/FLUX.1-Kontext-dev.py
@@ -0,0 +1,54 @@
+import torch
+from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig
+from PIL import Image
+
+
+pipe = FluxImagePipeline.from_pretrained(
+    torch_dtype=torch.bfloat16,
+    device="cuda",
+    model_configs=[
+        ModelConfig(model_id="black-forest-labs/FLUX.1-Kontext-dev", origin_file_pattern="flux1-kontext-dev.safetensors"),
+        ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
+        ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"),
+        ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
+    ],
+)
+
+image_1 = pipe(
+    prompt="a beautiful Asian long-haired female college student.",
+    embedded_guidance=2.5,
+    seed=1,
+)
+image_1.save("image_1.jpg")
+
+image_2 = pipe(
+    prompt="transform the style to anime style.",
+    kontext_images=image_1,
+    embedded_guidance=2.5,
+    seed=2,
+)
+image_2.save("image_2.jpg")
+
+image_3 = pipe(
+    prompt="let her smile.",
+    kontext_images=image_1,
+    embedded_guidance=2.5,
+    seed=3,
+)
+image_3.save("image_3.jpg")
+
+image_4 = pipe(
+    prompt="let the girl play basketball.",
+    kontext_images=image_1,
+    embedded_guidance=2.5,
+    seed=4,
+)
+image_4.save("image_4.jpg")
+
+image_5 = pipe(
+    prompt="move the girl to a park, let her sit on a chair.",
+    kontext_images=image_1,
+    embedded_guidance=2.5,
+    seed=5,
+)
+image_5.save("image_5.jpg")
\ No newline at end of file
diff --git a/examples/flux/model_training/lora/FLUX.1-Kontext-dev.sh b/examples/flux/model_training/lora/FLUX.1-Kontext-dev.sh
new file mode 100644
index 0000000..814d7ad
--- /dev/null
+++ b/examples/flux/model_training/lora/FLUX.1-Kontext-dev.sh
@@ -0,0 +1,17 @@
+accelerate launch examples/flux/model_training/train.py \
+  --dataset_base_path data/example_image_dataset \
+  --dataset_metadata_path data/example_image_dataset/metadata_kontext.csv \
+  --data_file_keys "image,kontext_images" \
+  --max_pixels 1048576 \
+  --dataset_repeat 400 \
+  --model_id_with_origin_paths "black-forest-labs/FLUX.1-Kontext-dev:flux1-kontext-dev.safetensors,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/,black-forest-labs/FLUX.1-dev:ae.safetensors" \
+  --learning_rate 1e-4 \
+  --num_epochs 5 \
+  --remove_prefix_in_ckpt "pipe.dit." \
+  --output_path "./models/train/FLUX.1-Kontext-dev_lora" \
+  --lora_base_model "dit" \
+  --lora_target_modules "a_to_qkv,b_to_qkv,ff_a.0,ff_a.2,ff_b.0,ff_b.2,a_to_out,b_to_out,proj_out,norm.linear,norm1_a.linear,norm1_b.linear,to_qkv_mlp" \
+  --lora_rank 32 \
+  --align_to_opensource_format \
+  --extra_inputs "kontext_images" \
+  --use_gradient_checkpointing
diff --git a/examples/flux/model_training/validate_lora/FLUX.1-Kontext-dev.py b/examples/flux/model_training/validate_lora/FLUX.1-Kontext-dev.py
new file mode 100644
index 0000000..b61cd4b
--- /dev/null
+++ b/examples/flux/model_training/validate_lora/FLUX.1-Kontext-dev.py
@@ -0,0 +1,24 @@
+import torch
+from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig
+from PIL import Image
+
+
+pipe = FluxImagePipeline.from_pretrained(
+    torch_dtype=torch.bfloat16,
+    device="cuda",
+    model_configs=[
+        ModelConfig(model_id="black-forest-labs/FLUX.1-Kontext-dev", origin_file_pattern="flux1-kontext-dev.safetensors"),
+        ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
+        ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"),
+        ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
+    ],
+)
+pipe.load_lora(pipe.dit, "models/train/FLUX.1-Kontext-dev_lora/epoch-4.safetensors", alpha=1)
+
+image = pipe(
+    prompt="Make the dog turn its head around.",
+    kontext_images=Image.open("data/example_image_dataset/2.jpg").resize((768, 768)),
+    height=768, width=768,
+    seed=0
+)
+image.save("image_FLUX.1-Kontext-dev_lora.jpg")
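+# Note: epoch numbering starts at 0, so "epoch-4.safetensors" above is the
+# checkpoint from the last of the 5 epochs configured in the training script;
+# adjust the path if you train with a different --num_epochs.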