mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
kontext
This commit is contained in:
@@ -102,6 +102,7 @@ class FluxImagePipeline(BasePipeline):
|
||||
FluxImageUnit_InputImageEmbedder(),
|
||||
FluxImageUnit_ImageIDs(),
|
||||
FluxImageUnit_EmbeddedGuidanceEmbedder(),
|
||||
FluxImageUnit_Kontext(),
|
||||
FluxImageUnit_InfiniteYou(),
|
||||
FluxImageUnit_ControlNet(),
|
||||
FluxImageUnit_IPAdapter(),
|
||||
@@ -211,6 +212,8 @@ class FluxImagePipeline(BasePipeline):
|
||||
multidiffusion_prompts=(),
|
||||
multidiffusion_masks=(),
|
||||
multidiffusion_scales=(),
|
||||
# Kontext
|
||||
kontext_images: Union[list[Image.Image], Image.Image] = None,
|
||||
# ControlNet
|
||||
controlnet_inputs: list[ControlNetInput] = None,
|
||||
# IP-Adapter
|
||||
@@ -257,6 +260,7 @@ class FluxImagePipeline(BasePipeline):
|
||||
"seed": seed, "rand_device": rand_device,
|
||||
"sigma_shift": sigma_shift, "num_inference_steps": num_inference_steps,
|
||||
"multidiffusion_prompts": multidiffusion_prompts, "multidiffusion_masks": multidiffusion_masks, "multidiffusion_scales": multidiffusion_scales,
|
||||
"kontext_images": kontext_images,
|
||||
"controlnet_inputs": controlnet_inputs,
|
||||
"ipadapter_images": ipadapter_images, "ipadapter_scale": ipadapter_scale,
|
||||
"eligen_entity_prompts": eligen_entity_prompts, "eligen_entity_masks": eligen_entity_masks, "eligen_enable_on_negative": eligen_enable_on_negative, "eligen_enable_inpaint": eligen_enable_inpaint,
|
||||
@@ -378,6 +382,32 @@ class FluxImageUnit_EmbeddedGuidanceEmbedder(PipelineUnit):
|
||||
|
||||
|
||||
|
||||
class FluxImageUnit_Kontext(PipelineUnit):
    """Pipeline unit that encodes Kontext reference images into packed latents
    and matching positional ids for the FLUX DiT."""

    def __init__(self):
        super().__init__(input_params=("kontext_images", "tiled", "tile_size", "tile_stride"))

    def process(self, pipe: FluxImagePipeline, kontext_images, tiled, tile_size, tile_stride):
        """Return ``kontext_latents`` / ``kontext_image_ids`` for the DiT,
        or an empty dict when no reference images were supplied.

        Accepts a single ``PIL.Image`` or a list of them; latents of all
        reference images are concatenated along the token dimension.
        """
        if kontext_images is None:
            return {}
        images = kontext_images if isinstance(kontext_images, list) else [kontext_images]

        latent_chunks, id_chunks = [], []
        for image in images:
            latent = pipe.vae_encoder(
                pipe.preprocess_image(image),
                tiled=tiled, tile_size=tile_size, tile_stride=tile_stride,
            )
            ids = pipe.dit.prepare_image_ids(latent)
            # First id channel is set to 1 — presumably flags these tokens as
            # reference (non-target) image tokens; TODO confirm against dit.
            ids[..., 0] = 1
            id_chunks.append(ids)
            latent_chunks.append(pipe.dit.patchify(latent))
        return {
            "kontext_latents": torch.cat(latent_chunks, dim=1),
            "kontext_image_ids": torch.cat(id_chunks, dim=-2),
        }
|
||||
|
||||
|
||||
|
||||
class FluxImageUnit_ControlNet(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
@@ -688,6 +718,8 @@ def model_fn_flux_image(
|
||||
guidance=None,
|
||||
text_ids=None,
|
||||
image_ids=None,
|
||||
kontext_latents=None,
|
||||
kontext_image_ids=None,
|
||||
controlnet_inputs=None,
|
||||
controlnet_conditionings=None,
|
||||
tiled=False,
|
||||
@@ -787,6 +819,11 @@ def model_fn_flux_image(
|
||||
height, width = hidden_states.shape[-2:]
|
||||
hidden_states = dit.patchify(hidden_states)
|
||||
|
||||
# Kontext
|
||||
if kontext_latents is not None:
|
||||
image_ids = torch.concat([image_ids, kontext_image_ids], dim=-2)
|
||||
hidden_states = torch.concat([hidden_states, kontext_latents], dim=1)
|
||||
|
||||
# Step1x
|
||||
if step1x_reference_latents is not None:
|
||||
step1x_reference_image_ids = dit.prepare_image_ids(step1x_reference_latents)
|
||||
@@ -827,7 +864,10 @@ def model_fn_flux_image(
|
||||
)
|
||||
# ControlNet
|
||||
if controlnet is not None and controlnet_conditionings is not None and controlnet_res_stack is not None:
|
||||
hidden_states = hidden_states + controlnet_res_stack[block_id]
|
||||
if kontext_latents is None:
|
||||
hidden_states = hidden_states + controlnet_res_stack[block_id]
|
||||
else:
|
||||
hidden_states[:, :-kontext_latents.shape[1]] = hidden_states[:, :-kontext_latents.shape[1]] + controlnet_res_stack[block_id]
|
||||
|
||||
# Single Blocks
|
||||
hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)
|
||||
@@ -846,7 +886,10 @@ def model_fn_flux_image(
|
||||
)
|
||||
# ControlNet
|
||||
if controlnet is not None and controlnet_conditionings is not None and controlnet_single_res_stack is not None:
|
||||
hidden_states[:, prompt_emb.shape[1]:] = hidden_states[:, prompt_emb.shape[1]:] + controlnet_single_res_stack[block_id]
|
||||
if kontext_latents is None:
|
||||
hidden_states[:, prompt_emb.shape[1]:] = hidden_states[:, prompt_emb.shape[1]:] + controlnet_single_res_stack[block_id]
|
||||
else:
|
||||
hidden_states[:, prompt_emb.shape[1]:-kontext_latents.shape[1]] = hidden_states[:, prompt_emb.shape[1]:-kontext_latents.shape[1]] + controlnet_single_res_stack[block_id]
|
||||
hidden_states = hidden_states[:, prompt_emb.shape[1]:]
|
||||
|
||||
if tea_cache is not None:
|
||||
@@ -858,6 +901,10 @@ def model_fn_flux_image(
|
||||
# Step1x
|
||||
if step1x_reference_latents is not None:
|
||||
hidden_states = hidden_states[:, :hidden_states.shape[1] // 2]
|
||||
|
||||
# Kontext
|
||||
if kontext_latents is not None:
|
||||
hidden_states = hidden_states[:, :-kontext_latents.shape[1]]
|
||||
|
||||
hidden_states = dit.unpatchify(hidden_states, height, width)
|
||||
|
||||
|
||||
@@ -1,3 +0,0 @@
|
||||
# Download the FLUX.1-Kontext-dev model weights from ModelScope into ./models,
# skipping the DiT weights under transformer/ (they are fetched separately).
from modelscope import snapshot_download
model_dir = snapshot_download('black-forest-labs/FLUX.1-Kontext-dev', cache_dir="models", ignore_file_pattern="transformer/*")
|
||||
54
examples/flux/model_inference/FLUX.1-Kontext-dev.py
Normal file
54
examples/flux/model_inference/FLUX.1-Kontext-dev.py
Normal file
@@ -0,0 +1,54 @@
|
||||
import torch
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig
from PIL import Image


# Build the FLUX.1-Kontext pipeline; the text encoders and VAE are
# shared with (and loaded from) FLUX.1-dev.
pipe = FluxImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="black-forest-labs/FLUX.1-Kontext-dev", origin_file_pattern="flux1-kontext-dev.safetensors"),
        ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
        ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"),
        ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
    ],
)

# Generate the base image that every Kontext edit below will reference.
image_1 = pipe(
    prompt="a beautiful Asian long-haired female college student.",
    embedded_guidance=2.5,
    seed=1,
)
image_1.save("image_1.jpg")

# Apply a series of instruction edits to the same base image.
# Each edit keeps the original seed numbering (2..5) and output filename.
edit_prompts = [
    "transform the style to anime style.",
    "let her smile.",
    "let the girl play basketball.",
    "move the girl to a park, let her sit on a chair.",
]
for index, edit_prompt in enumerate(edit_prompts, start=2):
    edited = pipe(
        prompt=edit_prompt,
        kontext_images=image_1,
        embedded_guidance=2.5,
        seed=index,
    )
    edited.save(f"image_{index}.jpg")
|
||||
17
examples/flux/model_training/lora/FLUX.1-Kontext-dev.sh
Normal file
17
examples/flux/model_training/lora/FLUX.1-Kontext-dev.sh
Normal file
@@ -0,0 +1,17 @@
|
||||
# LoRA fine-tuning of FLUX.1-Kontext-dev on the example image dataset.
# "image,kontext_images" in --data_file_keys makes the loader treat both
# metadata columns as image files; --extra_inputs feeds the kontext images
# to the pipeline as reference images during training.
accelerate launch examples/flux/model_training/train.py \
  --dataset_base_path data/example_image_dataset \
  --dataset_metadata_path data/example_image_dataset/metadata_kontext.csv \
  --data_file_keys "image,kontext_images" \
  --max_pixels 1048576 \
  --dataset_repeat 400 \
  --model_id_with_origin_paths "black-forest-labs/FLUX.1-Kontext-dev:flux1-kontext-dev.safetensors,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/,black-forest-labs/FLUX.1-dev:ae.safetensors" \
  --learning_rate 1e-4 \
  --num_epochs 5 \
  --remove_prefix_in_ckpt "pipe.dit." \
  --output_path "./models/train/FLUX.1-Kontext-dev_lora" \
  --lora_base_model "dit" \
  --lora_target_modules "a_to_qkv,b_to_qkv,ff_a.0,ff_a.2,ff_b.0,ff_b.2,a_to_out,b_to_out,proj_out,norm.linear,norm1_a.linear,norm1_b.linear,to_qkv_mlp" \
  --lora_rank 32 \
  --align_to_opensource_format \
  --extra_inputs "kontext_images" \
  --use_gradient_checkpointing
|
||||
@@ -0,0 +1,24 @@
|
||||
# Validation script for the FLUX.1-Kontext-dev LoRA: load the base pipeline,
# apply the trained LoRA weights, and run one reference-image edit.
import torch
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig
from PIL import Image


# Build the FLUX.1-Kontext pipeline; text encoders and VAE come from FLUX.1-dev.
pipe = FluxImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="black-forest-labs/FLUX.1-Kontext-dev", origin_file_pattern="flux1-kontext-dev.safetensors"),
        ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
        ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"),
        ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
    ],
)
# Apply the LoRA produced by the matching training script (last epoch of 5).
pipe.load_lora(pipe.dit, "models/train/FLUX.1-Kontext-dev_lora/epoch-4.safetensors", alpha=1)

# Edit the reference image according to the instruction prompt; the reference
# is resized to the generation resolution (768x768) before encoding.
image = pipe(
    prompt="Make the dog turn its head around.",
    kontext_images=Image.open("data/example_image_dataset/2.jpg").resize((768, 768)),
    height=768, width=768,
    seed=0
)
image.save("image_FLUX.1-Kontext-dev_lora.jpg")
|
||||
Reference in New Issue
Block a user