diff --git a/diffsynth/pipelines/flux_image_new.py b/diffsynth/pipelines/flux_image_new.py index b27a7d3..7773f7b 100644 --- a/diffsynth/pipelines/flux_image_new.py +++ b/diffsynth/pipelines/flux_image_new.py @@ -15,6 +15,7 @@ from typing_extensions import Literal from ..schedulers import FlowMatchScheduler from ..prompters import FluxPrompter from ..models import ModelManager, load_state_dict, SD3TextEncoder1, FluxTextEncoder2, FluxDiT, FluxVAEEncoder, FluxVAEDecoder +from ..models.step1x_connector import Qwen2Connector from ..models.tiler import FastTileWorker from .wan_video_new import BasePipeline, ModelConfig, PipelineUnitRunner, PipelineUnit from ..lora.flux_lora import FluxLoRALoader @@ -36,7 +37,9 @@ class FluxImagePipeline(BasePipeline): self.vae_decoder: FluxVAEDecoder = None self.vae_encoder: FluxVAEEncoder = None self.unit_runner = PipelineUnitRunner() - self.in_iteration_models = ("dit", ) + self.qwenvl = None + self.step1x_connector: Qwen2Connector = None + self.in_iteration_models = ("dit", "step1x_connector") self.units = [ FluxImageUnit_ShapeChecker(), FluxImageUnit_NoiseInitializer(), @@ -46,6 +49,9 @@ class FluxImagePipeline(BasePipeline): FluxImageUnit_EmbeddedGuidanceEmbedder(), FluxImageUnit_IPAdapter(), FluxImageUnit_EntityControl(), + FluxImageUnit_TeaCache(), + FluxImageUnit_Flex(), + FluxImageUnit_Step1x(), ] self.model_fn = model_fn_flux_image @@ -105,6 +111,9 @@ class FluxImagePipeline(BasePipeline): pipe.prompter.fetch_models(pipe.text_encoder_1, pipe.text_encoder_2) pipe.ipadapter = model_manager.fetch_model("flux_ipadapter") pipe.ipadapter_image_encoder = model_manager.fetch_model("siglip_vision_model") + # Step1x + pipe.qwenvl = model_manager.fetch_model("qwenvl") + pipe.step1x_connector = model_manager.fetch_model("step1x_connector") return pipe @@ -374,6 +383,73 @@ class FluxImageUnit_EntityControl(PipelineUnit): return inputs_shared, inputs_posi, inputs_nega +class FluxImageUnit_Step1x(PipelineUnit): + def __init__(self): + super().__init__(take_over=True,onload_model_names=("qwenvl","vae_encoder")) + + def process(self, pipe: FluxImagePipeline, inputs_shared: dict, inputs_posi: dict, inputs_nega: dict): + image = inputs_shared.get("step1x_reference_image",None) + if image is None: + return inputs_shared, inputs_posi, inputs_nega + else: + pipe.load_models_to_device(self.onload_model_names) + prompt = inputs_posi["prompt"] + nega_prompt = inputs_nega["negative_prompt"] + captions = [prompt, nega_prompt] + ref_images = [image, image] + embs, masks = pipe.qwenvl(captions, ref_images) + image = pipe.preprocess_image(image).to(device=pipe.device, dtype=pipe.torch_dtype) + image = pipe.vae_encoder(image) + inputs_posi.update({"step1x_llm_embedding": embs[0:1], "step1x_mask": masks[0:1], "step1x_reference_latents": image}) + inputs_nega.update({"step1x_llm_embedding": embs[1:2], "step1x_mask": masks[1:2], "step1x_reference_latents": image}) + return inputs_shared, inputs_posi, inputs_nega + + +class FluxImageUnit_TeaCache(PipelineUnit): + def __init__(self): + super().__init__(input_params=("num_inference_steps","tea_cache_l1_thresh")) + + def process(self, pipe: FluxImagePipeline, num_inference_steps, tea_cache_l1_thresh): + if tea_cache_l1_thresh is None: + return {} + else: + return {"tea_cache": TeaCache(num_inference_steps=num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh)} + +class FluxImageUnit_Flex(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("latents", "flex_inpaint_image", "flex_inpaint_mask", "flex_control_image", "flex_control_strength", "flex_control_stop", "tiled", "tile_size", "tile_stride"), + onload_model_names=("vae_encoder",) + ) + + def process(self, pipe: FluxImagePipeline, latents, flex_inpaint_image, flex_inpaint_mask, flex_control_image, flex_control_strength, flex_control_stop, tiled, tile_size, tile_stride): + if pipe.dit.input_dim == 196: + pipe.load_models_to_device(self.onload_model_names) + if flex_inpaint_image is None: + flex_inpaint_image = torch.zeros_like(latents) + else: + flex_inpaint_image = pipe.preprocess_image(flex_inpaint_image).to(device=pipe.device, dtype=pipe.torch_dtype) + flex_inpaint_image = pipe.vae_encoder(flex_inpaint_image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) + if flex_inpaint_mask is None: + flex_inpaint_mask = torch.ones_like(latents)[:, 0:1, :, :] + else: + flex_inpaint_mask = flex_inpaint_mask.resize((latents.shape[3], latents.shape[2])) + flex_inpaint_mask = pipe.preprocess_image(flex_inpaint_mask).to(device=pipe.device, dtype=pipe.torch_dtype) + flex_inpaint_mask = (flex_inpaint_mask[:, 0:1, :, :] + 1) / 2 + flex_inpaint_image = flex_inpaint_image * (1 - flex_inpaint_mask) + if flex_control_image is None: + flex_control_image = torch.zeros_like(latents) + else: + flex_control_image = pipe.preprocess_image(flex_control_image).to(device=pipe.device, dtype=pipe.torch_dtype) + flex_control_image = pipe.vae_encoder(flex_control_image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) * flex_control_strength + flex_condition = torch.concat([flex_inpaint_image, flex_inpaint_mask, flex_control_image], dim=1) + flex_uncondition = torch.concat([flex_inpaint_image, flex_inpaint_mask, torch.zeros_like(flex_control_image)], dim=1) + flex_control_stop_timestep = pipe.scheduler.timesteps[int(flex_control_stop * (len(pipe.scheduler.timesteps) - 1))] + return {"flex_condition": flex_condition, "flex_uncondition": flex_uncondition, "flex_control_stop_timestep": flex_control_stop_timestep} + else: + return {} + + class TeaCache: def __init__(self, num_inference_steps, rel_l1_thresh): self.num_inference_steps = num_inference_steps diff --git a/examples/flux/flex_text_to_image.py b/examples/flux/flex_text_to_image.py new file mode 100644 index 0000000..aa33641 --- /dev/null +++ b/examples/flux/flex_text_to_image.py @@ -0,0 +1,50 @@ +import torch +from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig +from diffsynth.controlnets.processors import Annotator +import numpy as np +from PIL import Image + + +pipe = FluxImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="ostris/Flex.2-preview", origin_file_pattern="Flex.2-preview.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"), + ], +) + +image = pipe( + prompt="portrait of a beautiful Asian girl, long hair, red t-shirt, sunshine, beach", + num_inference_steps=50, embedded_guidance=3.5, + seed=0 +) +image.save(f"image_1.jpg") + +mask = np.zeros((1024, 1024, 3), dtype=np.uint8) +mask[200:400, 400:700] = 255 +mask = Image.fromarray(mask) +mask.save(f"image_mask.jpg") + +inpaint_image = image + +image = pipe( + prompt="portrait of a beautiful Asian girl with sunglasses, long hair, red t-shirt, sunshine, beach", + num_inference_steps=50, embedded_guidance=3.5, + flex_inpaint_image=inpaint_image, flex_inpaint_mask=mask, + seed=4 +) +image.save(f"image_2_new.jpg") + +control_image = Annotator("canny")(image) +control_image.save("image_control.jpg") + +image = pipe( + prompt="portrait of a beautiful Asian girl with sunglasses, long hair, yellow t-shirt, sunshine, beach", + num_inference_steps=50, embedded_guidance=3.5, + flex_control_image=control_image, + seed=4 +) +image.save(f"image_3_new.jpg") diff --git a/examples/flux/flux_teacache.py b/examples/flux/flux_teacache.py new file mode 100644 index 0000000..a325324 --- /dev/null +++ b/examples/flux/flux_teacache.py @@ -0,0 +1,24 @@ +import torch +from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig + + +pipe = FluxImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"), + ], +) + + +prompt = "CG, masterpiece, best quality, solo, long hair, wavy hair, silver hair, blue eyes, blue dress, medium breasts, dress, underwater, air bubble, floating hair, refraction, portrait. The girl's flowing silver hair shimmers with every color of the rainbow and cascades down, merging with the floating flora around her." + +for tea_cache_l1_thresh in [None, 0.2, 0.4, 0.6, 0.8]: + image = pipe( + prompt=prompt, embedded_guidance=3.5, seed=0, + num_inference_steps=50, tea_cache_l1_thresh=tea_cache_l1_thresh + ) + image.save(f"image_{tea_cache_l1_thresh}.png") diff --git a/examples/flux/step1x.py b/examples/flux/step1x.py new file mode 100644 index 0000000..eca7605 --- /dev/null +++ b/examples/flux/step1x.py @@ -0,0 +1,44 @@ +import torch +from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig +from modelscope import snapshot_download +from PIL import Image +import numpy as np + + +snapshot_download("Qwen/Qwen2.5-VL-7B-Instruct", cache_dir="./models") +snapshot_download("stepfun-ai/Step1X-Edit", cache_dir="./models") + +pipe = FluxImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(path="models/Qwen/Qwen2.5-VL-7B-Instruct"), + ModelConfig(model_id="stepfun-ai/Step1X-Edit", origin_file_pattern="step1x-edit-i1258.safetensors"), + ModelConfig(model_id="stepfun-ai/Step1X-Edit", origin_file_pattern="vae.safetensors"), + ], +) + + +pipe.enable_vram_management() + +image = Image.fromarray(np.zeros((1248, 832, 3), dtype=np.uint8) + 255) +image = pipe( + prompt="draw red flowers in Chinese ink painting style", + step1x_reference_image=image, + width=832, height=1248, cfg_scale=6, + seed=1, + rand_device='cuda' +) +image.save("image_1.jpg") + + + +image = pipe( + prompt="add more flowers in Chinese ink painting style", + step1x_reference_image=image, + width=832, height=1248, cfg_scale=6, + seed=2, + rand_device='cuda' +) +image.save("image_2.jpg") +